From 0864549ecc26e734bae3dcf40e0d761232f8bdad Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Mon, 28 Oct 2019 18:19:12 -0700 Subject: Use the user supplied TCP MSS when creating a new active socket This change supports using a user supplied TCP MSS for new active TCP connections. Note, the user supplied MSS must be less than or equal to the maximum possible MSS for a TCP connection's route. If it is greater than the maximum possible MSS, the maximum possible MSS will be used as the connection's MSS instead. This change does not use this user supplied MSS for connections accepted from listening sockets - that will come in a later change. Test: Test that outgoing TCP SYN segments contain a TCP MSS option with the user supplied MSS if it is not greater than the maximum possible MSS for the route. PiperOrigin-RevId: 277185125 --- pkg/tcpip/stack/nic.go | 5 +++++ 1 file changed, 5 insertions(+) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index a867f8c00..a01a208b8 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -89,6 +89,11 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback // TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For // example, make sure that the link address it provides is a valid // unicast ethernet address. + + // TODO(b/143357959): RFC 8200 section 5 requires that IPv6 endpoints + // observe an MTU of at least 1280 bytes. Ensure that this requirement + // of IPv6 is supported on this endpoint's LinkEndpoint. + nic := &NIC{ stack: stack, id: id, -- cgit v1.2.3 From 7d80e85835fbe47b2395eedf287cf902ed78599a Mon Sep 17 00:00:00 2001 From: Ian Gudger Date: Tue, 29 Oct 2019 11:19:04 -0700 Subject: Allow waiting for Endpoint worker goroutines to finish. Updates #837 PiperOrigin-RevId: 277325162 --- pkg/tcpip/stack/registration.go | 14 ++++++++++++++ pkg/tcpip/stack/transport_demuxer.go | 20 ++++++++++++++++++++ pkg/tcpip/stack/transport_test.go | 5 +++-- pkg/tcpip/transport/icmp/endpoint.go | 3 +++ pkg/tcpip/transport/raw/endpoint.go | 3 +++ pkg/tcpip/transport/tcp/endpoint.go | 16 ++++++++++++++++ pkg/tcpip/transport/udp/endpoint.go | 3 +++ 7 files changed, 62 insertions(+), 2 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 0869fb084..0360187b8 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -67,6 +67,20 @@ type TransportEndpoint interface { // HandleControlPacket is called by the stack when new control (e.g., // ICMP) packets arrive to this transport endpoint. HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) + + // Close puts the endpoint in a closed state and frees all resources + // associated with it. This cleanup may happen asynchronously. Wait can + // be used to block on this asynchronous cleanup. + Close() + + // Wait waits for any worker goroutines owned by the endpoint to stop. + // + // An endpoint can be requested to stop its worker goroutines by calling + // its Close method. + // + // Wait will not block if the endpoint hasn't started any goroutines + // yet, even if it might later. + Wait() } // RawTransportEndpoint is the interface that needs to be implemented by raw diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 97a1aec4b..9aff90a3d 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -240,6 +240,26 @@ func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, v ep.mu.RUnlock() // Don't use defer for performance reasons. } +// Close implements stack.TransportEndpoint.Close. +func (ep *multiPortEndpoint) Close() { + ep.mu.RLock() + eps := append([]TransportEndpoint(nil), ep.endpointsArr...) + ep.mu.RUnlock() + for _, e := range eps { + e.Close() + } +} + +// Wait implements stack.TransportEndpoint.Wait. +func (ep *multiPortEndpoint) Wait() { + ep.mu.RLock() + eps := append([]TransportEndpoint(nil), ep.endpointsArr...) + ep.mu.RUnlock() + for _, e := range eps { + e.Wait() + } +} + // singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint // list. The list might be empty already. func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error { diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 86c62be25..db951c9ce 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -225,8 +225,9 @@ func (f *fakeTransportEndpoint) IPTables() (iptables.IPTables, error) { return iptables.IPTables{}, nil } -func (f *fakeTransportEndpoint) Resume(*stack.Stack) { -} +func (f *fakeTransportEndpoint) Resume(*stack.Stack) {} + +func (f *fakeTransportEndpoint) Wait() {} type fakeTransportGoodOption bool diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 043467519..d0dd383fd 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -798,3 +798,6 @@ func (e *endpoint) Info() tcpip.EndpointInfo { func (e *endpoint) Stats() tcpip.EndpointStats { return &e.stats } + +// Wait implements stack.TransportEndpoint.Wait. +func (*endpoint) Wait() {} diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 308f10d24..951d317ed 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -641,3 +641,6 @@ func (e *endpoint) Info() tcpip.EndpointInfo { func (e *endpoint) Stats() tcpip.EndpointStats { return &e.stats } + +// Wait implements stack.TransportEndpoint.Wait. +func (*endpoint) Wait() {} diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 8234a8b53..ce8307cee 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -2399,6 +2399,22 @@ func (e *endpoint) Stats() tcpip.EndpointStats { return &e.stats } +// Wait implements stack.TransportEndpoint.Wait. +func (e *endpoint) Wait() { + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + e.waiterQueue.EventRegister(&waitEntry, waiter.EventHUp) + defer e.waiterQueue.EventUnregister(&waitEntry) + for { + e.mu.Lock() + running := e.workerRunning + e.mu.Unlock() + if !running { + break + } + <-notifyCh + } +} + func mssForRoute(r *stack.Route) uint16 { // TODO(b/143359391): Respect TCP Min and Max size. return uint16(r.MTU() - header.TCPMinimumSize) diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 91c8487f3..cda302bb7 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1234,6 +1234,9 @@ func (e *endpoint) Stats() tcpip.EndpointStats { return &e.stats } +// Wait implements tcpip.Endpoint.Wait. +func (*endpoint) Wait() {} + func isBroadcastOrMulticast(a tcpip.Address) bool { return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a) } -- cgit v1.2.3 From a2c51efe3669f0380042b2375eae79e403d3680c Mon Sep 17 00:00:00 2001 From: Ian Gudger Date: Tue, 29 Oct 2019 16:13:43 -0700 Subject: Add endpoint tracking to the stack. In the future this will replace DanglingEndpoints. DanglingEndpoints must be kept for now due to issues with save/restore. This is arguably a cleaner design and allows the stack to know which transport endpoints might still be using its link endpoints. Updates #837 PiperOrigin-RevId: 277386633 --- pkg/sentry/inet/BUILD | 5 ++- pkg/sentry/inet/inet.go | 12 ++++++ pkg/sentry/inet/test_stack.go | 16 +++++++- pkg/sentry/socket/hostinet/BUILD | 1 + pkg/sentry/socket/hostinet/stack.go | 10 +++++ pkg/sentry/socket/netstack/stack.go | 15 +++++++ pkg/sentry/socket/rpcinet/BUILD | 1 + pkg/sentry/socket/rpcinet/stack.go | 10 +++++ pkg/tcpip/stack/stack.go | 59 ++++++++++++++++++++++++++-- pkg/tcpip/stack/transport_demuxer.go | 65 +++++++++++++++++++++++-------- pkg/tcpip/stack/transport_test.go | 3 +- pkg/tcpip/transport/tcp/endpoint.go | 5 ++- pkg/tcpip/transport/tcp/endpoint_state.go | 7 +++- 13 files changed, 182 insertions(+), 27 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/sentry/inet/BUILD b/pkg/sentry/inet/BUILD index d5284f0d9..8d60ad4ad 100644 --- a/pkg/sentry/inet/BUILD +++ b/pkg/sentry/inet/BUILD @@ -13,5 +13,8 @@ go_library( "test_stack.go", ], importpath = "gvisor.dev/gvisor/pkg/sentry/inet", - deps = ["//pkg/sentry/context"], + deps = [ + "//pkg/sentry/context", + "//pkg/tcpip/stack", + ], ) diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go index bc6cb1095..a7dfb78a7 100644 --- a/pkg/sentry/inet/inet.go +++ b/pkg/sentry/inet/inet.go @@ -15,6 +15,8 @@ // Package inet defines semantics for IP stacks. package inet +import "gvisor.dev/gvisor/pkg/tcpip/stack" + // Stack represents a TCP/IP stack. type Stack interface { // Interfaces returns all network interfaces as a mapping from interface @@ -58,6 +60,16 @@ type Stack interface { // Resume restarts the network stack after restore. Resume() + + // RegisteredEndpoints returns all endpoints which are currently registered. + RegisteredEndpoints() []stack.TransportEndpoint + + // CleanupEndpoints returns endpoints currently in the cleanup state. + CleanupEndpoints() []stack.TransportEndpoint + + // RestoreCleanupEndpoints adds endpoints to cleanup tracking. This is useful + // for restoring a stack after a save. + RestoreCleanupEndpoints([]stack.TransportEndpoint) } // Interface contains information about a network interface. diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go index b9eed7c3a..dcfcbd97e 100644 --- a/pkg/sentry/inet/test_stack.go +++ b/pkg/sentry/inet/test_stack.go @@ -14,6 +14,8 @@ package inet +import "gvisor.dev/gvisor/pkg/tcpip/stack" + // TestStack is a dummy implementation of Stack for tests. type TestStack struct { InterfacesMap map[int32]Interface @@ -94,5 +96,17 @@ func (s *TestStack) RouteTable() []Route { } // Resume implements Stack.Resume. -func (s *TestStack) Resume() { +func (s *TestStack) Resume() {} + +// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints. +func (s *TestStack) RegisteredEndpoints() []stack.TransportEndpoint { + return nil } + +// CleanupEndpoints implements inet.Stack.CleanupEndpoints. +func (s *TestStack) CleanupEndpoints() []stack.TransportEndpoint { + return nil +} + +// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints. +func (s *TestStack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {} diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index 4d174dda4..8b66a719d 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -32,6 +32,7 @@ go_library( "//pkg/sentry/usermem", "//pkg/syserr", "//pkg/syserror", + "//pkg/tcpip/stack", "//pkg/waiter", ], ) diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index d4387f5d4..e67b46c9e 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -31,6 +31,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) var defaultRecvBufSize = inet.TCPBufferSize{ @@ -442,3 +443,12 @@ func (s *Stack) RouteTable() []inet.Route { // Resume implements inet.Stack.Resume. func (s *Stack) Resume() {} + +// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints. +func (s *Stack) RegisteredEndpoints() []stack.TransportEndpoint { return nil } + +// CleanupEndpoints implements inet.Stack.CleanupEndpoints. +func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { return nil } + +// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints. +func (s *Stack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {} diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index d5db8c17c..a0db2d4fd 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -291,3 +291,18 @@ func (s *Stack) FillDefaultIPTables() { func (s *Stack) Resume() { s.Stack.Resume() } + +// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints. +func (s *Stack) RegisteredEndpoints() []stack.TransportEndpoint { + return s.Stack.RegisteredEndpoints() +} + +// CleanupEndpoints implements inet.Stack.CleanupEndpoints. +func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { + return s.Stack.CleanupEndpoints() +} + +// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints. +func (s *Stack) RestoreCleanupEndpoints(es []stack.TransportEndpoint) { + s.Stack.RestoreCleanupEndpoints(es) +} diff --git a/pkg/sentry/socket/rpcinet/BUILD b/pkg/sentry/socket/rpcinet/BUILD index 3a6baa308..4668b87d1 100644 --- a/pkg/sentry/socket/rpcinet/BUILD +++ b/pkg/sentry/socket/rpcinet/BUILD @@ -37,6 +37,7 @@ go_library( "//pkg/syserror", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/stack", "//pkg/unet", "//pkg/waiter", ], diff --git a/pkg/sentry/socket/rpcinet/stack.go b/pkg/sentry/socket/rpcinet/stack.go index 5dcb6b455..f7878a760 100644 --- a/pkg/sentry/socket/rpcinet/stack.go +++ b/pkg/sentry/socket/rpcinet/stack.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/conn" "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/notifier" "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/unet" ) @@ -165,3 +166,12 @@ func (s *Stack) RouteTable() []inet.Route { // Resume implements inet.Stack.Resume. func (s *Stack) Resume() {} + +// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints. +func (s *Stack) RegisteredEndpoints() []stack.TransportEndpoint { return nil } + +// CleanupEndpoints implements inet.Stack.CleanupEndpoints. +func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { return nil } + +// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints. +func (s *Stack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {} diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 242d2150c..360c54b2d 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -361,9 +361,10 @@ type Stack struct { linkAddrCache *linkAddrCache - mu sync.RWMutex - nics map[tcpip.NICID]*NIC - forwarding bool + mu sync.RWMutex + nics map[tcpip.NICID]*NIC + forwarding bool + cleanupEndpoints map[TransportEndpoint]struct{} // route is the route table passed in by the user via SetRouteTable(), // it is used by FindRoute() to build a route for a specific @@ -513,6 +514,7 @@ func New(opts Options) *Stack { networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver), nics: make(map[tcpip.NICID]*NIC), + cleanupEndpoints: make(map[TransportEndpoint]struct{}), linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts), PortManager: ports.NewPortManager(), clock: clock, @@ -1136,6 +1138,25 @@ func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice) } +// StartTransportEndpointCleanup removes the endpoint with the given id from +// the stack transport dispatcher. It also transitions it to the cleanup stage. +func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { + s.mu.Lock() + defer s.mu.Unlock() + + s.cleanupEndpoints[ep] = struct{}{} + + s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice) +} + +// CompleteTransportEndpointCleanup removes the endpoint from the cleanup +// stage. +func (s *Stack) CompleteTransportEndpointCleanup(ep TransportEndpoint) { + s.mu.Lock() + delete(s.cleanupEndpoints, ep) + s.mu.Unlock() +} + // RegisterRawTransportEndpoint registers the given endpoint with the stack // transport dispatcher. Received packets that match the provided transport // protocol will be delivered to the given endpoint. @@ -1157,6 +1178,38 @@ func (s *Stack) RegisterRestoredEndpoint(e ResumableEndpoint) { s.mu.Unlock() } +// RegisteredEndpoints returns all endpoints which are currently registered. +func (s *Stack) RegisteredEndpoints() []TransportEndpoint { + s.mu.Lock() + defer s.mu.Unlock() + var es []TransportEndpoint + for _, e := range s.demux.protocol { + es = append(es, e.transportEndpoints()...) + } + return es +} + +// CleanupEndpoints returns endpoints currently in the cleanup state. +func (s *Stack) CleanupEndpoints() []TransportEndpoint { + s.mu.Lock() + es := make([]TransportEndpoint, 0, len(s.cleanupEndpoints)) + for e := range s.cleanupEndpoints { + es = append(es, e) + } + s.mu.Unlock() + return es +} + +// RestoreCleanupEndpoints adds endpoints to cleanup tracking. This is useful +// for restoring a stack after a save. +func (s *Stack) RestoreCleanupEndpoints(es []TransportEndpoint) { + s.mu.Lock() + for _, e := range es { + s.cleanupEndpoints[e] = struct{}{} + } + s.mu.Unlock() +} + // Resume restarts the stack after a restore. This must be called after the // entire system has been restored. func (s *Stack) Resume() { diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 9aff90a3d..f633632f0 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -41,6 +41,31 @@ type transportEndpoints struct { rawEndpoints []RawTransportEndpoint } +// unregisterEndpoint unregisters the endpoint with the given id such that it +// won't receive any more packets. +func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { + eps.mu.Lock() + defer eps.mu.Unlock() + epsByNic, ok := eps.endpoints[id] + if !ok { + return + } + if !epsByNic.unregisterEndpoint(bindToDevice, ep) { + return + } + delete(eps.endpoints, id) +} + +func (eps *transportEndpoints) transportEndpoints() []TransportEndpoint { + eps.mu.RLock() + defer eps.mu.RUnlock() + es := make([]TransportEndpoint, 0, len(eps.endpoints)) + for _, e := range eps.endpoints { + es = append(es, e.transportEndpoints()...) + } + return es +} + type endpointsByNic struct { mu sync.RWMutex endpoints map[tcpip.NICID]*multiPortEndpoint @@ -48,6 +73,16 @@ type endpointsByNic struct { seed uint32 } +func (epsByNic *endpointsByNic) transportEndpoints() []TransportEndpoint { + epsByNic.mu.RLock() + defer epsByNic.mu.RUnlock() + var eps []TransportEndpoint + for _, ep := range epsByNic.endpoints { + eps = append(eps, ep.transportEndpoints()...) + } + return eps +} + // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) { @@ -127,21 +162,6 @@ func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t T return len(epsByNic.endpoints) == 0 } -// unregisterEndpoint unregisters the endpoint with the given id such that it -// won't receive any more packets. -func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { - eps.mu.Lock() - defer eps.mu.Unlock() - epsByNic, ok := eps.endpoints[id] - if !ok { - return - } - if !epsByNic.unregisterEndpoint(bindToDevice, ep) { - return - } - delete(eps.endpoints, id) -} - // transportDemuxer demultiplexes packets targeted at a transport endpoint // (i.e., after they've been parsed by the network layer). It does two levels // of demultiplexing: first based on the network and transport protocols, then @@ -183,14 +203,27 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum // multiPortEndpoint is a container for TransportEndpoints which are bound to // the same pair of address and port. endpointsArr always has at least one // element. +// +// FIXME(gvisor.dev/issue/873): Restore this properly. Currently, we just save +// this to ensure that the underlying endpoints get saved/restored, but not not +// use the restored copy. +// +// +stateify savable type multiPortEndpoint struct { - mu sync.RWMutex + mu sync.RWMutex `state:"nosave"` endpointsArr []TransportEndpoint endpointsMap map[TransportEndpoint]int // reuse indicates if more than one endpoint is allowed. reuse bool } +func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint { + ep.mu.RLock() + eps := append([]TransportEndpoint(nil), ep.endpointsArr...) + ep.mu.RUnlock() + return eps +} + // reciprocalScale scales a value into range [0, n). // // This is similar to val % n, but faster. diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index db951c9ce..ae6fda3a9 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -218,8 +218,7 @@ func (f *fakeTransportEndpoint) State() uint32 { return 0 } -func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) { -} +func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) {} func (f *fakeTransportEndpoint) IPTables() (iptables.IPTables, error) { return iptables.IPTables{}, nil diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index ce8307cee..8a3ca0f1b 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -686,7 +686,7 @@ func (e *endpoint) Close() { // in Listen() when trying to register. if e.state == StateListen && e.isPortReserved { if e.isRegistered { - e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) + e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) e.isRegistered = false } @@ -747,7 +747,7 @@ func (e *endpoint) cleanupLocked() { e.workerCleanup = false if e.isRegistered { - e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) + e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice) e.isRegistered = false } @@ -757,6 +757,7 @@ func (e *endpoint) cleanupLocked() { } e.route.Release() + e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) } diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index eae17237e..19f003b6b 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -193,8 +193,10 @@ func (e *endpoint) Resume(s *stack.Stack) { if len(e.BindAddr) == 0 { e.BindAddr = e.ID.LocalAddress } - if err := e.Bind(tcpip.FullAddress{Addr: e.BindAddr, Port: e.ID.LocalPort}); err != nil { - panic("endpoint binding failed: " + err.String()) + addr := e.BindAddr + port := e.ID.LocalPort + if err := e.Bind(tcpip.FullAddress{Addr: addr, Port: port}); err != nil { + panic(fmt.Sprintf("endpoint binding [%v]:%d failed: %v", addr, port, err)) } } @@ -265,6 +267,7 @@ func (e *endpoint) Resume(s *stack.Stack) { } fallthrough case StateError: + e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) } } -- cgit v1.2.3 From dc21c5ca16dbc43755185ffdf53764c7bb4c3a12 Mon Sep 17 00:00:00 2001 From: Ian Gudger Date: Tue, 29 Oct 2019 17:21:01 -0700 Subject: Add Close and Wait methods to stack. Link endpoints still don't have a unified way to be requested to stop. Updates #837 PiperOrigin-RevId: 277398952 --- pkg/tcpip/stack/stack.go | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 360c54b2d..6d6ddc0ff 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1210,6 +1210,37 @@ func (s *Stack) RestoreCleanupEndpoints(es []TransportEndpoint) { s.mu.Unlock() } +// Close closes all currently registered transport endpoints. +// +// Endpoints created or modified during this call may not get closed. +func (s *Stack) Close() { + for _, e := range s.RegisteredEndpoints() { + e.Close() + } +} + +// Wait waits for all transport and link endpoints to halt their worker +// goroutines. +// +// Endpoints created or modified during this call may not get waited on. +// +// Note that link endpoints must be stopped via an implementation specific +// mechanism. +func (s *Stack) Wait() { + for _, e := range s.RegisteredEndpoints() { + e.Wait() + } + for _, e := range s.CleanupEndpoints() { + e.Wait() + } + + s.mu.RLock() + defer s.mu.RUnlock() + for _, n := range s.nics { + n.linkEP.Wait() + } +} + // Resume restarts the stack after a restore. This must be called after the // entire system has been restored. func (s *Stack) Resume() { -- cgit v1.2.3 From db37483cb6acf55b66132d534bb734f09555b1cf Mon Sep 17 00:00:00 2001 From: Andrei Vagin Date: Wed, 30 Oct 2019 15:32:20 -0700 Subject: Store endpoints inside multiPortEndpoint in a sorted order It is required to guarantee the same order of endpoints after save/restore. PiperOrigin-RevId: 277598665 --- pkg/tcpip/stack/registration.go | 3 + pkg/tcpip/stack/stack.go | 29 ++++++++ pkg/tcpip/stack/transport_demuxer.go | 10 +++ pkg/tcpip/stack/transport_test.go | 11 ++- pkg/tcpip/transport/icmp/endpoint.go | 7 ++ pkg/tcpip/transport/tcp/endpoint.go | 7 ++ pkg/tcpip/transport/udp/endpoint.go | 7 ++ runsc/boot/loader.go | 5 +- test/syscalls/linux/socket_inet_loopback.cc | 107 ++++++++++++++++++++++++++++ 9 files changed, 181 insertions(+), 5 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 0360187b8..94015ba54 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -60,6 +60,9 @@ const ( // TransportEndpoint is the interface that needs to be implemented by transport // protocol (e.g., tcp, udp) endpoints that can handle packets. type TransportEndpoint interface { + // UniqueID returns an unique ID for this transport endpoint. + UniqueID() uint64 + // HandlePacket is called by the stack when new packets arrive to // this transport endpoint. HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 6d6ddc0ff..115a6fcb8 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -22,6 +22,7 @@ package stack import ( "encoding/binary" "sync" + "sync/atomic" "time" "golang.org/x/time/rate" @@ -344,6 +345,13 @@ type ResumableEndpoint interface { Resume(*Stack) } +// uniqueIDGenerator is a default unique ID generator. +type uniqueIDGenerator uint64 + +func (u *uniqueIDGenerator) UniqueID() uint64 { + return atomic.AddUint64((*uint64)(u), 1) +} + // Stack is a networking stack, with all supported protocols, NICs, and route // table. type Stack struct { @@ -411,6 +419,14 @@ type Stack struct { // ndpDisp is the NDP event dispatcher that is used to send the netstack // integrator NDP related events. ndpDisp NDPDispatcher + + // uniqueIDGenerator is a generator of unique identifiers. + uniqueIDGenerator UniqueID +} + +// UniqueID is an abstract generator of unique identifiers. +type UniqueID interface { + UniqueID() uint64 } // Options contains optional Stack configuration. @@ -434,6 +450,9 @@ type Options struct { // stack (false). HandleLocal bool + // UniqueID is an optional generator of unique identifiers. + UniqueID UniqueID + // NDPConfigs is the default NDP configurations used by interfaces. // // By default, NDPConfigs will have a zero value for its @@ -506,6 +525,10 @@ func New(opts Options) *Stack { clock = &tcpip.StdClock{} } + if opts.UniqueID == nil { + opts.UniqueID = new(uniqueIDGenerator) + } + // Make sure opts.NDPConfigs contains valid values only. opts.NDPConfigs.validate() @@ -524,6 +547,7 @@ func New(opts Options) *Stack { portSeed: generateRandUint32(), ndpConfigs: opts.NDPConfigs, autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal, + uniqueIDGenerator: opts.UniqueID, ndpDisp: opts.NDPDisp, } @@ -551,6 +575,11 @@ func New(opts Options) *Stack { return s } +// UniqueID returns a unique identifier. +func (s *Stack) UniqueID() uint64 { + return s.uniqueIDGenerator.UniqueID() +} + // SetNetworkProtocolOption allows configuring individual protocol level // options. This method returns an error if the protocol is not supported or // option is not supported by the protocol implementation or the provided value diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index f633632f0..ccd3d030e 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -17,6 +17,7 @@ package stack import ( "fmt" "math/rand" + "sort" "sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -310,6 +311,15 @@ func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePo // endpointsMap. This will allow us to remove endpoint from the array fast. ep.endpointsMap[t] = len(ep.endpointsArr) ep.endpointsArr = append(ep.endpointsArr, t) + + // ep.endpointsArr is sorted by endpoint unique IDs, so that endpoints + // can be restored in the same order. + sort.Slice(ep.endpointsArr, func(i, j int) bool { + return ep.endpointsArr[i].UniqueID() < ep.endpointsArr[j].UniqueID() + }) + for i, e := range ep.endpointsArr { + ep.endpointsMap[e] = i + } return nil } diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index ae6fda3a9..203e79f56 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -43,6 +43,7 @@ type fakeTransportEndpoint struct { proto *fakeTransportProtocol peerAddr tcpip.Address route stack.Route + uniqueID uint64 // acceptQueue is non-nil iff bound. acceptQueue []fakeTransportEndpoint @@ -56,8 +57,8 @@ func (f *fakeTransportEndpoint) Stats() tcpip.EndpointStats { return nil } -func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber) tcpip.Endpoint { - return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto} +func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint { + return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} } func (f *fakeTransportEndpoint) Close() { @@ -144,6 +145,10 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { return nil } +func (f *fakeTransportEndpoint) UniqueID() uint64 { + return f.uniqueID +} + func (f *fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error { return nil } @@ -251,7 +256,7 @@ func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber { } func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return newFakeTransportEndpoint(stack, f, netProto), nil + return newFakeTransportEndpoint(stack, f, netProto, stack.UniqueID()), nil } func (f *fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index d0dd383fd..114a69b4e 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -58,6 +58,7 @@ type endpoint struct { // immutable. stack *stack.Stack `state:"manual"` waiterQueue *waiter.Queue + uniqueID uint64 // The following fields are used to manage the receive queue, and are // protected by rcvMu. @@ -90,9 +91,15 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt rcvBufSizeMax: 32 * 1024, sndBufSize: 32 * 1024, state: stateInitial, + uniqueID: s.UniqueID(), }, nil } +// UniqueID implements stack.TransportEndpoint.UniqueID. +func (e *endpoint) UniqueID() uint64 { + return e.uniqueID +} + // Close puts the endpoint in a closed state and frees all resources // associated with it. func (e *endpoint) Close() { diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 8a3ca0f1b..a1efd8d55 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -287,6 +287,7 @@ type endpoint struct { // change throughout the lifetime of the endpoint. stack *stack.Stack `state:"manual"` waiterQueue *waiter.Queue `state:"wait"` + uniqueID uint64 // lastError represents the last error that the endpoint reported; // access to it is protected by the following mutex. @@ -504,6 +505,11 @@ type endpoint struct { stats Stats `state:"nosave"` } +// UniqueID implements stack.TransportEndpoint.UniqueID. +func (e *endpoint) UniqueID() uint64 { + return e.uniqueID +} + // calculateAdvertisedMSS calculates the MSS to advertise. // // If userMSS is non-zero and is not greater than the maximum possible MSS for @@ -565,6 +571,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue interval: 75 * time.Second, count: 9, }, + uniqueID: s.UniqueID(), } var ss SendBufferSizeOption diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index cda302bb7..68977dc25 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -80,6 +80,7 @@ type endpoint struct { // change throughout the lifetime of the endpoint. stack *stack.Stack `state:"manual"` waiterQueue *waiter.Queue + uniqueID uint64 // The following fields are used to manage the receive queue, and are // protected by rcvMu. @@ -160,9 +161,15 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue rcvBufSizeMax: 32 * 1024, sndBufSize: 32 * 1024, state: StateInitial, + uniqueID: s.UniqueID(), } } +// UniqueID implements stack.TransportEndpoint.UniqueID. +func (e *endpoint) UniqueID() uint64 { + return e.uniqueID +} + // Close puts the endpoint in a closed state and frees all resources // associated with it. func (e *endpoint) Close() { diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index 0c0eba99e..86df384f8 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -232,7 +232,7 @@ func New(args Args) (*Loader, error) { // this point. Netns is configured before Run() is called. Netstack is // configured using a control uRPC message. Host network is configured inside // Run(). - networkStack, err := newEmptyNetworkStack(args.Conf, k) + networkStack, err := newEmptyNetworkStack(args.Conf, k, k) if err != nil { return nil, fmt.Errorf("creating network: %v", err) } @@ -905,7 +905,7 @@ func (l *Loader) WaitExit() kernel.ExitStatus { return l.k.GlobalInit().ExitStatus() } -func newEmptyNetworkStack(conf *Config, clock tcpip.Clock) (inet.Stack, error) { +func newEmptyNetworkStack(conf *Config, clock tcpip.Clock, uniqueID stack.UniqueID) (inet.Stack, error) { switch conf.Network { case NetworkHost: return hostinet.NewStack(), nil @@ -923,6 +923,7 @@ func newEmptyNetworkStack(conf *Config, clock tcpip.Clock) (inet.Stack, error) { // Enable raw sockets for users with sufficient // privileges. RawFactory: raw.EndpointFactory{}, + UniqueID: uniqueID, })} // Enable SACK Recovery. diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index 322ee07ad..ab375aaaf 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -16,6 +16,7 @@ #include #include #include +#include #include #include @@ -516,6 +517,112 @@ TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread) { EquivalentWithin((kConnectAttempts / kThreadCount), 0.10)); } +TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThreadShort) { + auto const& param = GetParam(); + + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + sockaddr_storage listen_addr = listener.addr; + sockaddr_storage conn_addr = connector.addr; + constexpr int kThreadCount = 3; + + // TODO(b/141211329): endpointsByNic.seed has to be saved/restored. + const DisableSave ds141211329; + + // Create listening sockets. + FileDescriptor listener_fds[kThreadCount]; + for (int i = 0; i < kThreadCount; i++) { + listener_fds[i] = + ASSERT_NO_ERRNO_AND_VALUE(Socket(listener.family(), SOCK_DGRAM, 0)); + int fd = listener_fds[i].get(); + + ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT( + bind(fd, reinterpret_cast(&listen_addr), listener.addr_len), + SyscallSucceeds()); + + // On the first bind we need to determine which port was bound. + if (i != 0) { + continue; + } + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT( + getsockname(listener_fds[0].get(), + reinterpret_cast(&listen_addr), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + ASSERT_NO_ERRNO(SetAddrPort(listener.family(), &listen_addr, port)); + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + } + + constexpr int kConnectAttempts = 10; + FileDescriptor client_fds[kConnectAttempts]; + + // Do the first run without save/restore. + DisableSave ds; + for (int i = 0; i < kConnectAttempts; i++) { + client_fds[i] = + ASSERT_NO_ERRNO_AND_VALUE(Socket(connector.family(), SOCK_DGRAM, 0)); + EXPECT_THAT(RetryEINTR(sendto)(client_fds[i].get(), &i, sizeof(i), 0, + reinterpret_cast(&conn_addr), + connector.addr_len), + SyscallSucceedsWithValue(sizeof(i))); + } + ds.reset(); + + // Check that a mapping of client and server sockets has + // not been change after save/restore. + for (int i = 0; i < kConnectAttempts; i++) { + EXPECT_THAT(RetryEINTR(sendto)(client_fds[i].get(), &i, sizeof(i), 0, + reinterpret_cast(&conn_addr), + connector.addr_len), + SyscallSucceedsWithValue(sizeof(i))); + } + + int epollfd; + ASSERT_THAT(epollfd = epoll_create1(0), SyscallSucceeds()); + + for (int i = 0; i < kThreadCount; i++) { + int fd = listener_fds[i].get(); + struct epoll_event ev; + ev.data.fd = fd; + ev.events = EPOLLIN; + ASSERT_THAT(epoll_ctl(epollfd, EPOLL_CTL_ADD, fd, &ev), SyscallSucceeds()); + } + + std::map portToFD; + + for (int i = 0; i < kConnectAttempts * 2; i++) { + struct sockaddr_storage addr = {}; + socklen_t addrlen = sizeof(addr); + struct epoll_event ev; + int data, fd; + + ASSERT_THAT(epoll_wait(epollfd, &ev, 1, -1), SyscallSucceedsWithValue(1)); + + fd = ev.data.fd; + EXPECT_THAT(RetryEINTR(recvfrom)(fd, &data, sizeof(data), 0, + reinterpret_cast(&addr), + &addrlen), + SyscallSucceedsWithValue(sizeof(data))); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(connector.family(), addr)); + auto prev_port = portToFD.find(port); + // Check that all packets from one client have been delivered to the same + // server socket. + if (prev_port == portToFD.end()) { + portToFD[port] = ev.data.fd; + } else { + EXPECT_EQ(portToFD[port], ev.data.fd); + } + } +} + INSTANTIATE_TEST_SUITE_P( All, SocketInetReusePortTest, ::testing::Values( -- cgit v1.2.3 From 3246040447c6d0a08cc12c5721480c06f77f5dfe Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Wed, 30 Oct 2019 17:00:29 -0700 Subject: Deep copy dispatcher views. When VectorisedViews were passed up the stack from packet_dispatchers, we were passing a sub-slice of the dispatcher's views fields. The dispatchers then immediately set those views to nil. This wasn't caught before because every implementer copied the data in these views before returning. PiperOrigin-RevId: 277615351 --- pkg/tcpip/link/fdbased/endpoint_test.go | 10 +++++++--- pkg/tcpip/link/fdbased/packet_dispatchers.go | 4 ++-- pkg/tcpip/stack/nic.go | 2 +- pkg/tcpip/stack/registration.go | 16 ++++++++++++++++ pkg/tcpip/transport/icmp/endpoint.go | 5 +---- pkg/tcpip/transport/packet/endpoint.go | 8 ++------ pkg/tcpip/transport/raw/endpoint.go | 6 +----- pkg/tcpip/transport/udp/endpoint.go | 5 +---- 8 files changed, 31 insertions(+), 25 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index 59378b96c..e7c05ca4f 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -45,7 +45,7 @@ const ( type packetInfo struct { raddr tcpip.LinkAddress proto tcpip.NetworkProtocolNumber - contents buffer.View + contents buffer.VectorisedView linkHeader buffer.View } @@ -94,7 +94,7 @@ func (c *context) cleanup() { } func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) { - c.ch <- packetInfo{remote, protocol, vv.ToView(), linkHeader} + c.ch <- packetInfo{remote, protocol, vv, linkHeader} } func TestNoEthernetProperties(t *testing.T) { @@ -319,13 +319,17 @@ func TestDeliverPacket(t *testing.T) { want := packetInfo{ raddr: raddr, proto: proto, - contents: b, + contents: buffer.View(b).ToVectorisedView(), linkHeader: buffer.View(hdr), } if !eth { want.proto = header.IPv4ProtocolNumber want.raddr = "" } + // want.contents will be a single view, + // so make pi do the same for the + // DeepEqual check. + pi.contents = pi.contents.ToView().ToVectorisedView() if !reflect.DeepEqual(want, pi) { t.Fatalf("Unexpected received packet: %+v, want %+v", pi, want) } diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go index 12168a1dc..3331b6453 100644 --- a/pkg/tcpip/link/fdbased/packet_dispatchers.go +++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go @@ -139,7 +139,7 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { } used := d.capViews(n, BufConfig) - vv := buffer.NewVectorisedView(n, d.views[:used]) + vv := buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)) vv.TrimFront(d.e.hdrSize) d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, vv, buffer.View(eth)) @@ -293,7 +293,7 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { } used := d.capViews(k, int(n), BufConfig) - vv := buffer.NewVectorisedView(int(n), d.views[k][:used]) + vv := buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)) vv.TrimFront(d.e.hdrSize) d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, vv, buffer.View(eth)) diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index a01a208b8..fe8f83d58 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -762,7 +762,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link } n.mu.RUnlock() for _, ep := range packetEPs { - ep.HandlePacket(n.id, local, protocol, vv, linkHeader) + ep.HandlePacket(n.id, local, protocol, vv.Clone(nil), linkHeader) } if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber { diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 94015ba54..d7c124e81 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -65,10 +65,14 @@ type TransportEndpoint interface { // HandlePacket is called by the stack when new packets arrive to // this transport endpoint. + // + // HandlePacket takes ownership of vv. HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) // HandleControlPacket is called by the stack when new control (e.g., // ICMP) packets arrive to this transport endpoint. + // + // HandleControlPacket takes ownership of vv. HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) // Close puts the endpoint in a closed state and frees all resources @@ -94,6 +98,8 @@ type RawTransportEndpoint interface { // HandlePacket is called by the stack when new packets arrive to // this transport endpoint. The packet contains all data from the link // layer up. + // + // HandlePacket takes ownership of packet and netHeader. HandlePacket(r *Route, netHeader buffer.View, packet buffer.VectorisedView) } @@ -110,6 +116,8 @@ type PacketEndpoint interface { // // linkHeader may have a length of 0, in which case the PacketEndpoint // should construct its own ethernet header for applications. + // + // HandlePacket takes ownership of packet and linkHeader. HandlePacket(nicid tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, packet buffer.VectorisedView, linkHeader buffer.View) } @@ -160,10 +168,14 @@ type TransportDispatcher interface { // DeliverTransportPacket delivers packets to the appropriate // transport protocol endpoint. It also returns the network layer // header for the enpoint to inspect or pass up the stack. + // + // DeliverTransportPacket takes ownership of vv and netHeader. DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) // DeliverTransportControlPacket delivers control packets to the // appropriate transport protocol endpoint. + // + // DeliverTransportControlPacket takes ownership of vv. DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView) } @@ -237,6 +249,8 @@ type NetworkEndpoint interface { // HandlePacket is called by the link layer when new packets arrive to // this network endpoint. + // + // HandlePacket takes ownership of vv. HandlePacket(r *Route, vv buffer.VectorisedView) // Close is called when the endpoint is reomved from a stack. @@ -282,6 +296,8 @@ type NetworkDispatcher interface { // DeliverNetworkPacket finds the appropriate network protocol endpoint // and hands the packet over for further processing. linkHeader may have // length 0 when the caller does not have ethernet data. + // + // DeliverNetworkPacket takes ownership of vv and linkHeader. DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) } diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 114a69b4e..33405eb7d 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -31,9 +31,6 @@ type icmpPacket struct { senderAddress tcpip.FullAddress data buffer.VectorisedView `state:".(buffer.VectorisedView)"` timestamp int64 - // views is used as buffer for data when its length is large - // enough to store a VectorisedView. - views [8]buffer.View `state:"nosave"` } type endpointState int @@ -767,7 +764,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv }, } - pkt.data = vv.Clone(pkt.views[:]) + pkt.data = vv e.rcvList.PushBack(pkt) e.rcvBufSize += pkt.data.Size() diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 73cdaa265..ead83b83d 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -41,10 +41,6 @@ type packet struct { // data holds the actual packet data, including any headers and // payload. data buffer.VectorisedView `state:".(buffer.VectorisedView)"` - // views is pre-allocated space to back data. As long as the packet is - // made up of fewer than 8 buffer.Views, no extra allocation is - // necessary to store packet data. - views [8]buffer.View `state:"nosave"` // timestampNS is the unix time at which the packet was received. timestampNS int64 // senderAddr is the network address of the sender. @@ -310,7 +306,7 @@ func (ep *endpoint) HandlePacket(nicid tcpip.NICID, localAddr tcpip.LinkAddress, if ep.cooked { // Cooked packets can simply be queued. - packet.data = vv.Clone(packet.views[:]) + packet.data = vv } else { // Raw packets need their ethernet headers prepended before // queueing. @@ -328,7 +324,7 @@ func (ep *endpoint) HandlePacket(nicid tcpip.NICID, localAddr tcpip.LinkAddress, } combinedVV := buffer.View(ethHeader).ToVectorisedView() combinedVV.Append(vv) - packet.data = combinedVV.Clone(packet.views[:]) + packet.data = combinedVV } packet.timestampNS = ep.stack.NowNanoseconds() diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 951d317ed..23922a30e 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -42,10 +42,6 @@ type rawPacket struct { // data holds the actual packet data, including any headers and // payload. data buffer.VectorisedView `state:".(buffer.VectorisedView)"` - // views is pre-allocated space to back data. As long as the packet is - // made up of fewer than 8 buffer.Views, no extra allocation is - // necessary to store packet data. - views [8]buffer.View `state:"nosave"` // timestampNS is the unix time at which the packet was received. timestampNS int64 // senderAddr is the network address of the sender. @@ -609,7 +605,7 @@ func (e *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv bu combinedVV := netHeader.ToVectorisedView() combinedVV.Append(vv) - pkt.data = combinedVV.Clone(pkt.views[:]) + pkt.data = combinedVV pkt.timestampNS = e.stack.NowNanoseconds() e.rcvList.PushBack(pkt) diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 68977dc25..03bd5c8fd 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -31,9 +31,6 @@ type udpPacket struct { senderAddress tcpip.FullAddress data buffer.VectorisedView `state:".(buffer.VectorisedView)"` timestamp int64 - // views is used as buffer for data when its length is large - // enough to store a VectorisedView. - views [8]buffer.View `state:"nosave"` } // EndpointState represents the state of a UDP endpoint. @@ -1202,7 +1199,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv Port: hdr.SourcePort(), }, } - pkt.data = vv.Clone(pkt.views[:]) + pkt.data = vv e.rcvList.PushBack(pkt) e.rcvBufSize += vv.Size() -- cgit v1.2.3 From a824b48ceac4e2e3bacd23d63e72881c76d669c8 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Wed, 6 Nov 2019 10:38:02 -0800 Subject: Validate incoming NDP Router Advertisements, as per RFC 4861 section 6.1.2 This change validates incoming NDP Router Advertisements as per RFC 4861 section 6.1.2. It also includes the skeleton to handle Router Advertiements that arrive on some NIC. Tests: Unittest to make sure only valid NDP Router Advertisements are received/ not dropped. PiperOrigin-RevId: 278891972 --- pkg/tcpip/network/ipv6/icmp.go | 51 +++++++++- pkg/tcpip/network/ipv6/icmp_test.go | 4 +- pkg/tcpip/network/ipv6/ndp_test.go | 189 +++++++++++++++++++++++++++++++++++- pkg/tcpip/stack/ndp.go | 55 +++++++++++ pkg/tcpip/stack/nic.go | 13 ++- pkg/tcpip/stack/stack.go | 16 +++ 6 files changed, 322 insertions(+), 6 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index c3f1dd488..05e8c075b 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -86,7 +86,8 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V // As per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1, 7.1.2 and // 8.1, nodes MUST silently drop NDP packets where the Hop Limit field - // in the IPv6 header is not set to 255. + // in the IPv6 header is not set to 255, or the ICMPv6 Code field is not + // set to 0. switch h.Type() { case header.ICMPv6NeighborSolicit, header.ICMPv6NeighborAdvert, @@ -97,6 +98,11 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V received.Invalid.Increment() return } + + if h.Code() != 0 { + received.Invalid.Increment() + return + } } // TODO(b/112892170): Meaningfully handle all ICMP types. @@ -309,8 +315,51 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V received.RouterSolicit.Increment() case header.ICMPv6RouterAdvert: + routerAddr := iph.SourceAddress() + + // + // Validate the RA as per RFC 4861 section 6.1.2. + // + + // Is the IP Source Address a link-local address? + if !header.IsV6LinkLocalAddress(routerAddr) { + // ...No, silently drop the packet. + received.Invalid.Increment() + return + } + + p := h.NDPPayload() + + // Is the NDP payload of sufficient size to hold a Router + // Advertisement? + if len(p) < header.NDPRAMinimumSize { + // ...No, silently drop the packet. + received.Invalid.Increment() + return + } + + ra := header.NDPRouterAdvert(p) + opts := ra.Options() + + // Are options valid as per the wire format? + if _, err := opts.Iter(true); err != nil { + // ...No, silently drop the packet. + received.Invalid.Increment() + return + } + + // + // At this point, we have a valid Router Advertisement, as far + // as RFC 4861 section 6.1.2 is concerned. + // + received.RouterAdvert.Increment() + // Tell the NIC to handle the RA. + stack := r.Stack() + rxNICID := r.NICID() + stack.HandleNDPRA(rxNICID, routerAddr, ra) + case header.ICMPv6RedirectMsg: received.RedirectMsg.Increment() diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index b112303b6..d686f79ce 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -131,7 +131,7 @@ func TestICMPCounts(t *testing.T) { {header.ICMPv6EchoRequest, header.ICMPv6EchoMinimumSize}, {header.ICMPv6EchoReply, header.ICMPv6EchoMinimumSize}, {header.ICMPv6RouterSolicit, header.ICMPv6MinimumSize}, - {header.ICMPv6RouterAdvert, header.ICMPv6MinimumSize}, + {header.ICMPv6RouterAdvert, header.ICMPv6HeaderSize + header.NDPRAMinimumSize}, {header.ICMPv6NeighborSolicit, header.ICMPv6NeighborSolicitMinimumSize}, {header.ICMPv6NeighborAdvert, header.ICMPv6NeighborAdvertSize}, {header.ICMPv6RedirectMsg, header.ICMPv6MinimumSize}, @@ -426,7 +426,7 @@ func TestICMPChecksumValidationSimple(t *testing.T) { { "RouterAdvert", header.ICMPv6RouterAdvert, - header.ICMPv6MinimumSize, + header.ICMPv6HeaderSize + header.NDPRAMinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.RouterAdvert }, diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index c32716f2e..69ab7ba12 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" ) @@ -109,7 +110,7 @@ func TestHopLimitValidation(t *testing.T) { {"RouterSolicit", header.ICMPv6RouterSolicit, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.RouterSolicit }}, - {"RouterAdvert", header.ICMPv6RouterAdvert, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + {"RouterAdvert", header.ICMPv6RouterAdvert, header.ICMPv6HeaderSize + header.NDPRAMinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.RouterAdvert }}, {"NeighborSolicit", header.ICMPv6NeighborSolicit, header.ICMPv6NeighborSolicitMinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { @@ -179,3 +180,189 @@ func TestHopLimitValidation(t *testing.T) { }) } } + +// TestRouterAdvertValidation tests that when the NIC is configured to handle +// NDP Router Advertisement packets, it validates the Router Advertisement +// properly before handling them. +func TestRouterAdvertValidation(t *testing.T) { + tests := []struct { + name string + src tcpip.Address + hopLimit uint8 + code uint8 + ndpPayload []byte + expectedSuccess bool + }{ + { + "OK", + lladdr0, + 255, + 0, + []byte{ + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + }, + true, + }, + { + "NonLinkLocalSourceAddr", + addr1, + 255, + 0, + []byte{ + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + }, + false, + }, + { + "HopLimitNot255", + lladdr0, + 254, + 0, + []byte{ + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + }, + false, + }, + { + "NonZeroCode", + lladdr0, + 255, + 1, + []byte{ + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + }, + false, + }, + { + "NDPPayloadTooSmall", + lladdr0, + 255, + 0, + []byte{ + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, + }, + false, + }, + { + "OKWithOptions", + lladdr0, + 255, + 0, + []byte{ + // RA payload + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + + // Option #1 (TargetLinkLayerAddress) + 2, 1, 0, 0, 0, 0, 0, 0, + + // Option #2 (unrecognized) + 255, 1, 0, 0, 0, 0, 0, 0, + + // Option #3 (PrefixInformation) + 3, 4, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + }, + true, + }, + { + "OptionWithZeroLength", + lladdr0, + 255, + 0, + []byte{ + // RA payload + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + + // Option #1 (TargetLinkLayerAddress) + // Invalid as it has 0 length. + 2, 0, 0, 0, 0, 0, 0, 0, + + // Option #2 (unrecognized) + 255, 1, 0, 0, 0, 0, 0, 0, + + // Option #3 (PrefixInformation) + 3, 4, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + }, + false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e := channel.New(10, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + icmpSize := header.ICMPv6HeaderSize + len(test.ndpPayload) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) + pkt := header.ICMPv6(hdr.Prepend(icmpSize)) + pkt.SetType(header.ICMPv6RouterAdvert) + pkt.SetCode(test.code) + copy(pkt.NDPPayload(), test.ndpPayload) + payloadLength := hdr.UsedLength() + pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(icmp.ProtocolNumber6), + HopLimit: test.hopLimit, + SrcAddr: test.src, + DstAddr: header.IPv6AllNodesMulticastAddress, + }) + + stats := s.Stats().ICMP.V6PacketsReceived + invalid := stats.Invalid + rxRA := stats.RouterAdvert + + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := rxRA.Value(); got != 0 { + t.Fatalf("got rxRA = %d, want = 0", got) + } + + e.Inject(header.IPv6ProtocolNumber, hdr.View().ToVectorisedView()) + + if test.expectedSuccess { + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := rxRA.Value(); got != 1 { + t.Fatalf("got rxRA = %d, want = 1", got) + } + + } else { + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + if got := rxRA.Value(); got != 0 { + t.Fatalf("got rxRA = %d, want = 0", got) + } + } + }) + } +} diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 03ddebdbd..d5352bb5f 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -38,6 +38,19 @@ const ( // Default = 1s (from RFC 4861 section 10). defaultRetransmitTimer = time.Second + // defaultHandleRAs is the default configuration for whether or not to + // handle incoming Router Advertisements as a host. + // + // Default = true. + defaultHandleRAs = true + + // defaultDiscoverDefaultRouters is the default configuration for + // whether or not to discover default routers from incoming Router + // Advertisements as a host. + // + // Default = true. + defaultDiscoverDefaultRouters = true + // minimumRetransmitTimer is the minimum amount of time to wait between // sending NDP Neighbor solicitation messages. Note, RFC 4861 does // not impose a minimum Retransmit Timer, but we do here to make sure @@ -49,6 +62,13 @@ const ( // // Min = 1ms. minimumRetransmitTimer = time.Millisecond + + // MaxDiscoveredDefaultRouters is the maximum number of discovered + // default routers. The stack should stop discovering new routers after + // discovering MaxDiscoveredDefaultRouters routers. + // + // Max = 10. + MaxDiscoveredDefaultRouters = 10 ) // NDPDispatcher is the interface integrators of netstack must implement to @@ -80,6 +100,15 @@ type NDPConfigurations struct { // // Must be greater than 0.5s. RetransmitTimer time.Duration + + // HandleRAs determines whether or not Router Advertisements will be + // processed. + HandleRAs bool + + // DiscoverDefaultRouters determines whether or not default routers will + // be discovered from Router Advertisements. This configuration is + // ignored if HandleRAs is false. + DiscoverDefaultRouters bool } // DefaultNDPConfigurations returns an NDPConfigurations populated with @@ -88,6 +117,8 @@ func DefaultNDPConfigurations() NDPConfigurations { return NDPConfigurations{ DupAddrDetectTransmits: defaultDupAddrDetectTransmits, RetransmitTimer: defaultRetransmitTimer, + HandleRAs: defaultHandleRAs, + DiscoverDefaultRouters: defaultDiscoverDefaultRouters, } } @@ -112,6 +143,9 @@ type ndpState struct { // The DAD state to send the next NS message, or resolve the address. dad map[tcpip.Address]dadState + + // The default routers discovered through Router Advertisements. + defaultRouters map[tcpip.Address]defaultRouterState } // dadState holds the Duplicate Address Detection timer and channel to signal @@ -127,6 +161,12 @@ type dadState struct { done *bool } +// defaultRouterState holds data associated with a default router discovered by +// a Router Advertisement. +type defaultRouterState struct { + invalidationTimer *time.Timer +} + // startDuplicateAddressDetection performs Duplicate Address Detection. // // This function must only be called by IPv6 addresses that are currently @@ -319,3 +359,18 @@ func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) { go ndp.nic.stack.ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, false, nil) } } + +// handleRA handles a Router Advertisement message that arrived on the NIC +// this ndp is for. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { + // Is the NIC configured to handle RAs at all? + if !ndp.configs.HandleRAs { + return + } + + // TODO(b/140882146): Do Router Discovery. + // TODO(b/140948104): Do Prefix Discovery. + // TODO(b/141556115): Do Parameter Discovery. +} diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index fe8f83d58..12969c74e 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -115,8 +115,9 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback }, }, ndp: ndpState{ - configs: stack.ndpConfigs, - dad: make(map[tcpip.Address]dadState), + configs: stack.ndpConfigs, + dad: make(map[tcpip.Address]dadState), + defaultRouters: make(map[tcpip.Address]defaultRouterState), }, } nic.ndp.nic = nic @@ -960,6 +961,14 @@ func (n *NIC) setNDPConfigs(c NDPConfigurations) { n.mu.Unlock() } +// handleNDPRA handles an NDP Router Advertisement message that arrived on n. +func (n *NIC) handleNDPRA(ip tcpip.Address, ra header.NDPRouterAdvert) { + n.mu.Lock() + defer n.mu.Unlock() + + n.ndp.handleRA(ip, ra) +} + type networkEndpointKind int32 const ( diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 115a6fcb8..8b141cafd 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1557,6 +1557,22 @@ func (s *Stack) SetNDPConfigurations(id tcpip.NICID, c NDPConfigurations) *tcpip return nil } +// HandleNDPRA provides a NIC with ID id a validated NDP Router Advertisement +// message that it needs to handle. +func (s *Stack) HandleNDPRA(id tcpip.NICID, ip tcpip.Address, ra header.NDPRouterAdvert) *tcpip.Error { + s.mu.Lock() + defer s.mu.Unlock() + + nic, ok := s.nics[id] + if !ok { + return tcpip.ErrUnknownNICID + } + + nic.handleNDPRA(ip, ra) + + return nil +} + // PortSeed returns a 32 bit value that can be used as a seed value for port // picking. // -- cgit v1.2.3 From e1b21f3c8ca989dc94b25526fda1bb107691f1af Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Wed, 6 Nov 2019 14:24:38 -0800 Subject: Use PacketBuffers, rather than VectorisedViews, in netstack. PacketBuffers are analogous to Linux's sk_buff. They hold all information about a packet, headers, and payload. This is important for: * iptables to access various headers of packets * Preventing the clutter of passing different net and link headers along with VectorisedViews to packet handling functions. This change only affects the incoming packet path, and a future change will change the outgoing path. Benchmark Regular PacketBufferPtr PacketBufferConcrete -------------------------------------------------------------------------------- BM_Recvmsg 400.715MB/s 373.676MB/s 396.276MB/s BM_Sendmsg 361.832MB/s 333.003MB/s 335.571MB/s BM_Recvfrom 453.336MB/s 393.321MB/s 381.650MB/s BM_Sendto 378.052MB/s 372.134MB/s 341.342MB/s BM_SendmsgTCP/0/1k 353.711MB/s 316.216MB/s 322.747MB/s BM_SendmsgTCP/0/2k 600.681MB/s 588.776MB/s 565.050MB/s BM_SendmsgTCP/0/4k 995.301MB/s 888.808MB/s 941.888MB/s BM_SendmsgTCP/0/8k 1.517GB/s 1.274GB/s 1.345GB/s BM_SendmsgTCP/0/16k 1.872GB/s 1.586GB/s 1.698GB/s BM_SendmsgTCP/0/32k 1.017GB/s 1.020GB/s 1.133GB/s BM_SendmsgTCP/0/64k 475.626MB/s 584.587MB/s 627.027MB/s BM_SendmsgTCP/0/128k 416.371MB/s 503.434MB/s 409.850MB/s BM_SendmsgTCP/0/256k 323.449MB/s 449.599MB/s 388.852MB/s BM_SendmsgTCP/0/512k 243.992MB/s 267.676MB/s 314.474MB/s BM_SendmsgTCP/0/1M 95.138MB/s 95.874MB/s 95.417MB/s BM_SendmsgTCP/0/2M 96.261MB/s 94.977MB/s 96.005MB/s BM_SendmsgTCP/0/4M 96.512MB/s 95.978MB/s 95.370MB/s BM_SendmsgTCP/0/8M 95.603MB/s 95.541MB/s 94.935MB/s BM_SendmsgTCP/0/16M 94.598MB/s 94.696MB/s 94.521MB/s BM_SendmsgTCP/0/32M 94.006MB/s 94.671MB/s 94.768MB/s BM_SendmsgTCP/0/64M 94.133MB/s 94.333MB/s 94.746MB/s BM_SendmsgTCP/0/128M 93.615MB/s 93.497MB/s 93.573MB/s BM_SendmsgTCP/0/256M 93.241MB/s 95.100MB/s 93.272MB/s BM_SendmsgTCP/1/1k 303.644MB/s 316.074MB/s 308.430MB/s BM_SendmsgTCP/1/2k 537.093MB/s 584.962MB/s 529.020MB/s BM_SendmsgTCP/1/4k 882.362MB/s 939.087MB/s 892.285MB/s BM_SendmsgTCP/1/8k 1.272GB/s 1.394GB/s 1.296GB/s BM_SendmsgTCP/1/16k 1.802GB/s 2.019GB/s 1.830GB/s BM_SendmsgTCP/1/32k 2.084GB/s 2.173GB/s 2.156GB/s BM_SendmsgTCP/1/64k 2.515GB/s 2.463GB/s 2.473GB/s BM_SendmsgTCP/1/128k 2.811GB/s 3.004GB/s 2.946GB/s BM_SendmsgTCP/1/256k 3.008GB/s 3.159GB/s 3.171GB/s BM_SendmsgTCP/1/512k 2.980GB/s 3.150GB/s 3.126GB/s BM_SendmsgTCP/1/1M 2.165GB/s 2.233GB/s 2.163GB/s BM_SendmsgTCP/1/2M 2.370GB/s 2.219GB/s 2.453GB/s BM_SendmsgTCP/1/4M 2.005GB/s 2.091GB/s 2.214GB/s BM_SendmsgTCP/1/8M 2.111GB/s 2.013GB/s 2.109GB/s BM_SendmsgTCP/1/16M 1.902GB/s 1.868GB/s 1.897GB/s BM_SendmsgTCP/1/32M 1.655GB/s 1.665GB/s 1.635GB/s BM_SendmsgTCP/1/64M 1.575GB/s 1.547GB/s 1.575GB/s BM_SendmsgTCP/1/128M 1.524GB/s 1.584GB/s 1.580GB/s BM_SendmsgTCP/1/256M 1.579GB/s 1.607GB/s 1.593GB/s PiperOrigin-RevId: 278940079 --- pkg/tcpip/BUILD | 2 + pkg/tcpip/link/channel/channel.go | 10 ++-- pkg/tcpip/link/fdbased/endpoint.go | 4 +- pkg/tcpip/link/fdbased/endpoint_test.go | 27 ++++----- pkg/tcpip/link/fdbased/mmap.go | 5 +- pkg/tcpip/link/fdbased/packet_dispatchers.go | 18 ++++-- pkg/tcpip/link/loopback/loopback.go | 10 +++- pkg/tcpip/link/muxed/injectable.go | 4 +- pkg/tcpip/link/sharedmem/sharedmem.go | 7 ++- pkg/tcpip/link/sharedmem/sharedmem_test.go | 9 ++- pkg/tcpip/link/sniffer/sniffer.go | 12 ++-- pkg/tcpip/link/waitable/waitable.go | 4 +- pkg/tcpip/link/waitable/waitable_test.go | 8 +-- pkg/tcpip/network/arp/arp.go | 4 +- pkg/tcpip/network/arp/arp_test.go | 4 +- pkg/tcpip/network/ip_test.go | 34 ++++++++---- pkg/tcpip/network/ipv4/icmp.go | 34 +++++++----- pkg/tcpip/network/ipv4/ipv4.go | 43 +++++++++------ pkg/tcpip/network/ipv4/ipv4_test.go | 4 +- pkg/tcpip/network/ipv6/icmp.go | 48 ++++++++-------- pkg/tcpip/network/ipv6/icmp_test.go | 24 +++++--- pkg/tcpip/network/ipv6/ipv6.go | 28 ++++++---- pkg/tcpip/network/ipv6/ipv6_test.go | 8 ++- pkg/tcpip/network/ipv6/ndp_test.go | 8 ++- pkg/tcpip/packet_buffer.go | 54 ++++++++++++++++++ pkg/tcpip/packet_buffer_state.go | 26 +++++++++ pkg/tcpip/stack/ndp_test.go | 4 +- pkg/tcpip/stack/nic.go | 48 ++++++++-------- pkg/tcpip/stack/registration.go | 64 +++++++++++++--------- pkg/tcpip/stack/stack.go | 4 +- pkg/tcpip/stack/stack_test.go | 50 +++++++++++------ pkg/tcpip/stack/transport_demuxer.go | 53 +++++++++--------- pkg/tcpip/stack/transport_demuxer_test.go | 4 +- pkg/tcpip/stack/transport_test.go | 34 ++++++++---- pkg/tcpip/transport/icmp/endpoint.go | 18 +++--- pkg/tcpip/transport/icmp/protocol.go | 2 +- pkg/tcpip/transport/packet/endpoint.go | 19 ++++--- pkg/tcpip/transport/raw/endpoint.go | 17 +++--- pkg/tcpip/transport/tcp/endpoint.go | 6 +- pkg/tcpip/transport/tcp/forwarder.go | 5 +- pkg/tcpip/transport/tcp/protocol.go | 4 +- pkg/tcpip/transport/tcp/segment.go | 5 +- pkg/tcpip/transport/tcp/testing/context/context.go | 16 ++++-- pkg/tcpip/transport/udp/endpoint.go | 20 +++---- pkg/tcpip/transport/udp/forwarder.go | 9 ++- pkg/tcpip/transport/udp/protocol.go | 30 +++++----- pkg/tcpip/transport/udp/udp_test.go | 19 +++++-- test/syscalls/linux/raw_socket_icmp.cc | 2 +- 48 files changed, 542 insertions(+), 330 deletions(-) create mode 100644 pkg/tcpip/packet_buffer.go create mode 100644 pkg/tcpip/packet_buffer_state.go (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index 3c2b2b5ea..65d4d0cd8 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -6,6 +6,8 @@ package(licenses = ["notice"]) go_library( name = "tcpip", srcs = [ + "packet_buffer.go", + "packet_buffer_state.go", "tcpip.go", "time_unsafe.go", ], diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index 14f197a77..22eefb564 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -65,14 +65,14 @@ func (e *Endpoint) Drain() int { } } -// Inject injects an inbound packet. -func (e *Endpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { - e.InjectLinkAddr(protocol, "", vv) +// InjectInbound injects an inbound packet. +func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { + e.InjectLinkAddr(protocol, "", pkt) } // InjectLinkAddr injects an inbound packet with a remote link address. -func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, vv buffer.VectorisedView) { - e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, vv.Clone(nil), nil /* linkHeader */) +func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt tcpip.PacketBuffer) { + e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, pkt) } // Attach saves the stack network-layer dispatcher for use later when packets diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index ae4858529..edef7db26 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -598,8 +598,8 @@ func (e *InjectableEndpoint) Attach(dispatcher stack.NetworkDispatcher) { } // InjectInbound injects an inbound packet. -func (e *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { - e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, vv, nil /* linkHeader */) +func (e *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { + e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, pkt) } // NewInjectable creates a new fd-based InjectableEndpoint. diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index e7c05ca4f..7e08e033b 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -43,10 +43,9 @@ const ( ) type packetInfo struct { - raddr tcpip.LinkAddress - proto tcpip.NetworkProtocolNumber - contents buffer.VectorisedView - linkHeader buffer.View + raddr tcpip.LinkAddress + proto tcpip.NetworkProtocolNumber + contents tcpip.PacketBuffer } type context struct { @@ -93,8 +92,8 @@ func (c *context) cleanup() { syscall.Close(c.fds[1]) } -func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) { - c.ch <- packetInfo{remote, protocol, vv, linkHeader} +func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { + c.ch <- packetInfo{remote, protocol, pkt} } func TestNoEthernetProperties(t *testing.T) { @@ -317,19 +316,21 @@ func TestDeliverPacket(t *testing.T) { select { case pi := <-c.ch: want := packetInfo{ - raddr: raddr, - proto: proto, - contents: buffer.View(b).ToVectorisedView(), - linkHeader: buffer.View(hdr), + raddr: raddr, + proto: proto, + contents: tcpip.PacketBuffer{ + Data: buffer.View(b).ToVectorisedView(), + LinkHeader: buffer.View(hdr), + }, } if !eth { want.proto = header.IPv4ProtocolNumber want.raddr = "" } - // want.contents will be a single view, - // so make pi do the same for the + // want.contents.Data will be a single + // view, so make pi do the same for the // DeepEqual check. - pi.contents = pi.contents.ToView().ToVectorisedView() + pi.contents.Data = pi.contents.Data.ToView().ToVectorisedView() if !reflect.DeepEqual(want, pi) { t.Fatalf("Unexpected received packet: %+v, want %+v", pi, want) } diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go index 554d45715..62ed1e569 100644 --- a/pkg/tcpip/link/fdbased/mmap.go +++ b/pkg/tcpip/link/fdbased/mmap.go @@ -190,6 +190,9 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) { } pkt = pkt[d.e.hdrSize:] - d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, buffer.NewVectorisedView(len(pkt), []buffer.View{buffer.View(pkt)}), buffer.View(eth)) + d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, tcpip.PacketBuffer{ + Data: buffer.View(pkt).ToVectorisedView(), + LinkHeader: buffer.View(eth), + }) return true, nil } diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go index 3331b6453..c67d684ce 100644 --- a/pkg/tcpip/link/fdbased/packet_dispatchers.go +++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go @@ -139,10 +139,13 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { } used := d.capViews(n, BufConfig) - vv := buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)) - vv.TrimFront(d.e.hdrSize) + pkt := tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)), + LinkHeader: buffer.View(eth), + } + pkt.Data.TrimFront(d.e.hdrSize) - d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, vv, buffer.View(eth)) + d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, pkt) // Prepare e.views for another packet: release used views. for i := 0; i < used; i++ { @@ -293,9 +296,12 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { } used := d.capViews(k, int(n), BufConfig) - vv := buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)) - vv.TrimFront(d.e.hdrSize) - d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, vv, buffer.View(eth)) + pkt := tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)), + LinkHeader: buffer.View(eth), + } + pkt.Data.TrimFront(d.e.hdrSize) + d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, pkt) // Prepare e.views for another packet: release used views. for i := 0; i < used; i++ { diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index a3b48fa73..bc5d8a2f3 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -80,12 +80,13 @@ func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prependa views := make([]buffer.View, 1, 1+len(payload.Views())) views[0] = hdr.View() views = append(views, payload.Views()...) - vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) // Because we're immediately turning around and writing the packet back to the // rx path, we intentionally don't preserve the remote and local link // addresses from the stack.Route we're passed. - e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, vv, nil /* linkHeader */) + e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(len(views[0])+payload.Size(), views), + }) return nil } @@ -105,7 +106,10 @@ func (e *endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error { // There should be an ethernet header at the beginning of packet. linkHeader := header.Ethernet(packet.First()[:header.EthernetMinimumSize]) packet.TrimFront(len(linkHeader)) - e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), packet, buffer.View(linkHeader)) + e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), tcpip.PacketBuffer{ + Data: packet, + LinkHeader: buffer.View(linkHeader), + }) return nil } diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index 682b60291..9a8e8ebfe 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -80,8 +80,8 @@ func (m *InjectableEndpoint) IsAttached() bool { } // InjectInbound implements stack.InjectableLinkEndpoint. -func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { - m.dispatcher.DeliverNetworkPacket(m, "" /* remote */, "" /* local */, protocol, vv, nil /* linkHeader */) +func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { + m.dispatcher.DeliverNetworkPacket(m, "" /* remote */, "" /* local */, protocol, pkt) } // WritePackets writes outbound packets to the appropriate diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 279e2b457..2bace5298 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -273,8 +273,11 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) { } // Send packet up the stack. - eth := header.Ethernet(b) - d.DeliverNetworkPacket(e, eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), buffer.View(b[header.EthernetMinimumSize:]).ToVectorisedView(), buffer.View(eth)) + eth := header.Ethernet(b[:header.EthernetMinimumSize]) + d.DeliverNetworkPacket(e, eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), tcpip.PacketBuffer{ + Data: buffer.View(b[header.EthernetMinimumSize:]).ToVectorisedView(), + LinkHeader: buffer.View(eth), + }) } // Clean state. diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index f3e9705c9..199406886 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -131,13 +131,12 @@ func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress return c } -func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) { +func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { c.mu.Lock() c.packets = append(c.packets, packetInfo{ - addr: remoteLinkAddr, - proto: proto, - vv: vv.Clone(nil), - linkHeader: linkHeader, + addr: remoteLinkAddr, + proto: proto, + vv: pkt.Data.Clone(nil), }) c.mu.Unlock() diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 39757ea2a..d71a03cd2 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -116,19 +116,19 @@ func NewWithFile(lower stack.LinkEndpoint, file *os.File, snapLen uint32) (stack // DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is // called by the link-layer endpoint being wrapped when a packet arrives, and // logs the packet before forwarding to the actual dispatcher. -func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) { +func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil { - logPacket("recv", protocol, vv.First(), nil) + logPacket("recv", protocol, pkt.Data.First(), nil) } if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 { - vs := vv.Views() - length := vv.Size() + vs := pkt.Data.Views() + length := pkt.Data.Size() if length > int(e.maxPCAPLen) { length = int(e.maxPCAPLen) } buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length)) - if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(vv.Size()))); err != nil { + if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(pkt.Data.Size()))); err != nil { panic(err) } for _, v := range vs { @@ -147,7 +147,7 @@ func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local panic(err) } } - e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, vv, linkHeader) + e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, pkt) } // Attach implements the stack.LinkEndpoint interface. It saves the dispatcher diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index a04fc1062..b440970e0 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -50,12 +50,12 @@ func New(lower stack.LinkEndpoint) *Endpoint { // It is called by the link-layer endpoint being wrapped when a packet arrives, // and only forwards to the actual dispatcher if Wait or WaitDispatch haven't // been called. -func (e *Endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) { +func (e *Endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { if !e.dispatchGate.Enter() { return } - e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, vv, linkHeader) + e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, pkt) e.dispatchGate.Leave() } diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index 5f0f8fa2d..df2e70e54 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -35,7 +35,7 @@ type countedEndpoint struct { dispatcher stack.NetworkDispatcher } -func (e *countedEndpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) { +func (e *countedEndpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { e.dispatchCount++ } @@ -120,21 +120,21 @@ func TestWaitDispatch(t *testing.T) { } // Dispatch and check that it goes through. - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{}, buffer.View{}) + ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, tcpip.PacketBuffer{}) if want := 1; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } // Wait on writes, then try to dispatch. It must go through. wep.WaitWrite() - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{}, buffer.View{}) + ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, tcpip.PacketBuffer{}) if want := 2; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } // Wait on dispatches, then try to dispatch. It must not go through. wep.WaitDispatch() - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{}, buffer.View{}) + ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, tcpip.PacketBuffer{}) if want := 2; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 46178459e..4161ebf87 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -92,8 +92,8 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.Vect return tcpip.ErrNotSupported } -func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { - v := vv.First() +func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { + v := pkt.Data.First() h := header.ARP(v) if !h.IsValid() { return diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 88b57ec03..47098bfdc 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -102,7 +102,9 @@ func TestDirectRequest(t *testing.T) { inject := func(addr tcpip.Address) { copy(h.ProtocolAddressTarget(), addr) - c.linkEP.Inject(arp.ProtocolNumber, v.ToVectorisedView()) + c.linkEP.InjectInbound(arp.ProtocolNumber, tcpip.PacketBuffer{ + Data: v.ToVectorisedView(), + }) } for i, address := range []tcpip.Address{stackAddr1, stackAddr2} { diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 666d8b92a..fe499d47e 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -96,16 +96,16 @@ func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buff // DeliverTransportPacket is called by network endpoints after parsing incoming // packets. This is used by the test object to verify that the results of the // parsing are expected. -func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) { - t.checkValues(protocol, vv, r.RemoteAddress, r.LocalAddress) +func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) { + t.checkValues(protocol, pkt.Data, r.RemoteAddress, r.LocalAddress) t.dataCalls++ } // DeliverTransportControlPacket is called by network endpoints after parsing // incoming control (ICMP) packets. This is used by the test object to verify // that the results of the parsing are expected. -func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { - t.checkValues(trans, vv, remote, local) +func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { + t.checkValues(trans, pkt.Data, remote, local) if typ != t.typ { t.t.Errorf("typ = %v, want %v", typ, t.typ) } @@ -279,7 +279,9 @@ func TestIPv4Receive(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - ep.HandlePacket(&r, view.ToVectorisedView()) + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: view.ToVectorisedView(), + }) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -367,7 +369,9 @@ func TestIPv4ReceiveControl(t *testing.T) { o.extra = c.expectedExtra vv := view[:len(view)-c.trunc].ToVectorisedView() - ep.HandlePacket(&r, vv) + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: vv, + }) if want := c.expectedCount; o.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) } @@ -430,13 +434,17 @@ func TestIPv4FragmentationReceive(t *testing.T) { } // Send first segment. - ep.HandlePacket(&r, frag1.ToVectorisedView()) + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: frag1.ToVectorisedView(), + }) if o.dataCalls != 0 { t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls) } // Send second segment. - ep.HandlePacket(&r, frag2.ToVectorisedView()) + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: frag2.ToVectorisedView(), + }) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -509,7 +517,9 @@ func TestIPv6Receive(t *testing.T) { t.Fatalf("could not find route: %v", err) } - ep.HandlePacket(&r, view.ToVectorisedView()) + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: view.ToVectorisedView(), + }) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -618,12 +628,12 @@ func TestIPv6ReceiveControl(t *testing.T) { o.typ = c.expectedTyp o.extra = c.expectedExtra - vv := view[:len(view)-c.trunc].ToVectorisedView() - // Set ICMPv6 checksum. icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{})) - ep.HandlePacket(&r, vv) + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: view[:len(view)-c.trunc].ToVectorisedView(), + }) if want := c.expectedCount; o.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) } diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 50b363dc4..ce771631c 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -15,6 +15,7 @@ package ipv4 import ( + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -24,8 +25,8 @@ import ( // the original packet that caused the ICMP one to be sent. This information is // used to find out which transport endpoint must be notified about the ICMP // packet. -func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { - h := header.IPv4(vv.First()) +func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { + h := header.IPv4(pkt.Data.First()) // We don't use IsValid() here because ICMP only requires that the IP // header plus 8 bytes of the transport header be included. So it's @@ -39,7 +40,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer. } hlen := int(h.HeaderLength()) - if vv.Size() < hlen || h.FragmentOffset() != 0 { + if pkt.Data.Size() < hlen || h.FragmentOffset() != 0 { // We won't be able to handle this if it doesn't contain the // full IPv4 header, or if it's a fragment not at offset 0 // (because it won't have the transport header). @@ -47,15 +48,15 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer. } // Skip the ip header, then deliver control message. - vv.TrimFront(hlen) + pkt.Data.TrimFront(hlen) p := h.TransportProtocol() - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv) + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } -func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) { +func (e *endpoint) handleICMP(r *stack.Route, pkt tcpip.PacketBuffer) { stats := r.Stats() received := stats.ICMP.V4PacketsReceived - v := vv.First() + v := pkt.Data.First() if len(v) < header.ICMPv4MinimumSize { received.Invalid.Increment() return @@ -73,20 +74,23 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V // checksum. We'll have to reset this before we hand the packet // off. h.SetChecksum(0) - gotChecksum := ^header.ChecksumVV(vv, 0 /* initial */) + gotChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */) if gotChecksum != wantChecksum { // It's possible that a raw socket expects to receive this. h.SetChecksum(wantChecksum) - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) received.Invalid.Increment() return } // It's possible that a raw socket expects to receive this. h.SetChecksum(wantChecksum) - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, tcpip.PacketBuffer{ + Data: pkt.Data.Clone(nil), + NetworkHeader: append(buffer.View(nil), pkt.NetworkHeader...), + }) - vv := vv.Clone(nil) + vv := pkt.Data.Clone(nil) vv.TrimFront(header.ICMPv4MinimumSize) hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize) pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) @@ -104,19 +108,19 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V case header.ICMPv4EchoReply: received.EchoReply.Increment() - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) case header.ICMPv4DstUnreachable: received.DstUnreachable.Increment() - vv.TrimFront(header.ICMPv4MinimumSize) + pkt.Data.TrimFront(header.ICMPv4MinimumSize) switch h.Code() { case header.ICMPv4PortUnreachable: - e.handleControl(stack.ControlPortUnreachable, 0, vv) + e.handleControl(stack.ControlPortUnreachable, 0, pkt) case header.ICMPv4FragmentationNeeded: mtu := uint32(h.MTU()) - e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv) + e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) } case header.ICMPv4SrcQuench: diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 1339f8474..26f1402ed 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -198,7 +198,7 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, hdr buff return nil } -func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) { +func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv4 { ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) length := uint16(hdr.UsedLength() + payloadSize) id := uint32(0) @@ -218,19 +218,24 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS DstAddr: r.RemoteAddress, }) ip.SetChecksum(^ip.CalculateChecksum()) + return ip } // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error { - e.addIPHeader(r, &hdr, payload.Size(), params) + ip := e.addIPHeader(r, &hdr, payload.Size(), params) if loop&stack.PacketLoop != 0 { views := make([]buffer.View, 1, 1+len(payload.Views())) views[0] = hdr.View() views = append(views, payload.Views()...) - vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) loopedR := r.MakeLoopedRoute() - e.HandlePacket(&loopedR, vv) + + e.HandlePacket(&loopedR, tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(len(views[0])+payload.Size(), views), + NetworkHeader: buffer.View(ip), + }) + loopedR.Release() } if loop&stack.PacketOut == 0 { @@ -301,7 +306,10 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.Vect ip.SetChecksum(^ip.CalculateChecksum()) if loop&stack.PacketLoop != 0 { - e.HandlePacket(r, payload) + e.HandlePacket(r, tcpip.PacketBuffer{ + Data: payload, + NetworkHeader: buffer.View(ip), + }) } if loop&stack.PacketOut == 0 { return nil @@ -314,22 +322,23 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.Vect // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { - headerView := vv.First() +func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { + headerView := pkt.Data.First() h := header.IPv4(headerView) - if !h.IsValid(vv.Size()) { + if !h.IsValid(pkt.Data.Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() return } + pkt.NetworkHeader = headerView[:h.HeaderLength()] hlen := int(h.HeaderLength()) tlen := int(h.TotalLength()) - vv.TrimFront(hlen) - vv.CapLength(tlen - hlen) + pkt.Data.TrimFront(hlen) + pkt.Data.CapLength(tlen - hlen) more := (h.Flags() & header.IPv4FlagMoreFragments) != 0 if more || h.FragmentOffset() != 0 { - if vv.Size() == 0 { + if pkt.Data.Size() == 0 { // Drop the packet as it's marked as a fragment but has // no payload. r.Stats().IP.MalformedPacketsReceived.Increment() @@ -337,10 +346,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { return } // The packet is a fragment, let's try to reassemble it. - last := h.FragmentOffset() + uint16(vv.Size()) - 1 + last := h.FragmentOffset() + uint16(pkt.Data.Size()) - 1 // Drop the packet if the fragmentOffset is incorrect. i.e the - // combination of fragmentOffset and vv.size() causes a wrap - // around resulting in last being less than the offset. + // combination of fragmentOffset and pkt.Data.size() causes a + // wrap around resulting in last being less than the offset. if last < h.FragmentOffset() { r.Stats().IP.MalformedPacketsReceived.Increment() r.Stats().IP.MalformedFragmentsReceived.Increment() @@ -348,7 +357,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { } var ready bool var err error - vv, ready, err = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, vv) + pkt.Data, ready, err = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, pkt.Data) if err != nil { r.Stats().IP.MalformedPacketsReceived.Increment() r.Stats().IP.MalformedFragmentsReceived.Increment() @@ -361,11 +370,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { p := h.TransportProtocol() if p == header.ICMPv4ProtocolNumber { headerView.CapLength(hlen) - e.handleICMP(r, headerView, vv) + e.handleICMP(r, pkt) return } r.Stats().IP.PacketsDelivered.Increment() - e.dispatcher.DeliverTransportPacket(r, p, headerView, vv) + e.dispatcher.DeliverTransportPacket(r, p, pkt) } // Close cleans up resources associated with the endpoint. diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 99f84acd7..f100d84ee 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -464,7 +464,9 @@ func TestInvalidFragments(t *testing.T) { s.CreateNIC(nicid, sniffer.New(ep)) for _, pkt := range tc.packets { - ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, buffer.NewVectorisedView(len(pkt), []buffer.View{pkt})) + ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}), + }) } if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), tc.wantMalformedIPPackets; got != want { diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 05e8c075b..58f8e80df 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -25,8 +25,8 @@ import ( // the original packet that caused the ICMP one to be sent. This information is // used to find out which transport endpoint must be notified about the ICMP // packet. -func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { - h := header.IPv6(vv.First()) +func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { + h := header.IPv6(pkt.Data.First()) // We don't use IsValid() here because ICMP only requires that up to // 1280 bytes of the original packet be included. So it's likely that it @@ -40,10 +40,10 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer. // Skip the IP header, then handle the fragmentation header if there // is one. - vv.TrimFront(header.IPv6MinimumSize) + pkt.Data.TrimFront(header.IPv6MinimumSize) p := h.TransportProtocol() if p == header.IPv6FragmentHeader { - f := header.IPv6Fragment(vv.First()) + f := header.IPv6Fragment(pkt.Data.First()) if !f.IsValid() || f.FragmentOffset() != 0 { // We can't handle fragments that aren't at offset 0 // because they don't have the transport headers. @@ -52,19 +52,19 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer. // Skip fragmentation header and find out the actual protocol // number. - vv.TrimFront(header.IPv6FragmentHeaderSize) + pkt.Data.TrimFront(header.IPv6FragmentHeaderSize) p = f.TransportProtocol() } // Deliver the control packet to the transport endpoint. - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv) + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } -func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) { +func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.PacketBuffer) { stats := r.Stats().ICMP sent := stats.V6PacketsSent received := stats.V6PacketsReceived - v := vv.First() + v := pkt.Data.First() if len(v) < header.ICMPv6MinimumSize { received.Invalid.Increment() return @@ -77,7 +77,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V // Only the first view in vv is accounted for by h. To account for the // rest of vv, a shallow copy is made and the first view is removed. // This copy is used as extra payload during the checksum calculation. - payload := vv + payload := pkt.Data payload.RemoveFirst() if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want { received.Invalid.Increment() @@ -113,9 +113,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V received.Invalid.Increment() return } - vv.TrimFront(header.ICMPv6PacketTooBigMinimumSize) + pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize) mtu := h.MTU() - e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv) + e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) case header.ICMPv6DstUnreachable: received.DstUnreachable.Increment() @@ -123,10 +123,10 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V received.Invalid.Increment() return } - vv.TrimFront(header.ICMPv6DstUnreachableMinimumSize) + pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize) switch h.Code() { case header.ICMPv6PortUnreachable: - e.handleControl(stack.ControlPortUnreachable, 0, vv) + e.handleControl(stack.ControlPortUnreachable, 0, pkt) } case header.ICMPv6NeighborSolicit: @@ -189,9 +189,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V header.NDPTargetLinkLayerAddressOption(r.LocalLinkAddress[:]), } hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length())) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) - pkt.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(pkt.NDPPayload()) + packet := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) + packet.SetType(header.ICMPv6NeighborAdvert) + na := header.NDPNeighborAdvert(packet.NDPPayload()) na.SetSolicitedFlag(true) na.SetOverrideFlag(true) na.SetTargetAddress(targetAddr) @@ -209,7 +209,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V r := r.Clone() defer r.Release() r.LocalAddress = targetAddr - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) // TODO(tamird/ghanan): there exists an explicit NDP option that is // used to update the neighbor table with link addresses for a @@ -285,13 +285,13 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V received.Invalid.Increment() return } - vv.TrimFront(header.ICMPv6EchoMinimumSize) + pkt.Data.TrimFront(header.ICMPv6EchoMinimumSize) hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) - copy(pkt, h) - pkt.SetType(header.ICMPv6EchoReply) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, vv)) - if err := r.WritePacket(nil /* gso */, hdr, vv, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil { + packet := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) + copy(packet, h) + packet.SetType(header.ICMPv6EchoReply) + packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) + if err := r.WritePacket(nil /* gso */, hdr, pkt.Data, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil { sent.Dropped.Increment() return } @@ -303,7 +303,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V received.Invalid.Increment() return } - e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, netHeader, vv) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, pkt) case header.ICMPv6TimeExceeded: received.TimeExceeded.Increment() diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index d686f79ce..6037a1ef8 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -65,7 +65,7 @@ type stubDispatcher struct { stack.TransportDispatcher } -func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, buffer.View, buffer.VectorisedView) { +func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, tcpip.PacketBuffer) { } type stubLinkAddressCache struct { @@ -147,7 +147,9 @@ func TestICMPCounts(t *testing.T) { SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - ep.HandlePacket(&r, hdr.View().ToVectorisedView()) + ep.HandlePacket(&r, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) } for _, typ := range types { @@ -280,7 +282,9 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header. views := []buffer.View{pkt.Header, pkt.Payload} size := len(pkt.Header) + len(pkt.Payload) vv := buffer.NewVectorisedView(size, views) - args.dst.InjectLinkAddr(pkt.Proto, args.dst.LinkAddress(), vv) + args.dst.InjectLinkAddr(pkt.Proto, args.dst.LinkAddress(), tcpip.PacketBuffer{ + Data: vv, + }) } if pkt.Proto != ProtocolNumber { @@ -498,7 +502,9 @@ func TestICMPChecksumValidationSimple(t *testing.T) { SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.Inject(ProtocolNumber, hdr.View().ToVectorisedView()) + e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) } stats := s.Stats().ICMP.V6PacketsReceived @@ -673,7 +679,9 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.Inject(ProtocolNumber, hdr.View().ToVectorisedView()) + e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) } stats := s.Stats().ICMP.V6PacketsReceived @@ -849,9 +857,9 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.Inject(ProtocolNumber, - buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, - []buffer.View{hdr.View(), payload})) + e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}), + }) } stats := s.Stats().ICMP.V6PacketsReceived diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 5898f8f9e..805d1739c 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -97,7 +97,7 @@ func (e *endpoint) GSOMaxSize() uint32 { return 0 } -func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) { +func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv6 { length := uint16(hdr.UsedLength() + payloadSize) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ @@ -108,19 +108,24 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) + return ip } // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error { - e.addIPHeader(r, &hdr, payload.Size(), params) + ip := e.addIPHeader(r, &hdr, payload.Size(), params) if loop&stack.PacketLoop != 0 { views := make([]buffer.View, 1, 1+len(payload.Views())) views[0] = hdr.View() views = append(views, payload.Views()...) - vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) loopedR := r.MakeLoopedRoute() - e.HandlePacket(&loopedR, vv) + + e.HandlePacket(&loopedR, tcpip.PacketBuffer{ + Data: buffer.NewVectorisedView(len(views[0])+payload.Size(), views), + NetworkHeader: buffer.View(ip), + }) + loopedR.Release() } if loop&stack.PacketOut == 0 { @@ -160,24 +165,25 @@ func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.Vector // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { - headerView := vv.First() +func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { + headerView := pkt.Data.First() h := header.IPv6(headerView) - if !h.IsValid(vv.Size()) { + if !h.IsValid(pkt.Data.Size()) { return } - vv.TrimFront(header.IPv6MinimumSize) - vv.CapLength(int(h.PayloadLength())) + pkt.NetworkHeader = headerView[:header.IPv6MinimumSize] + pkt.Data.TrimFront(header.IPv6MinimumSize) + pkt.Data.CapLength(int(h.PayloadLength())) p := h.TransportProtocol() if p == header.ICMPv6ProtocolNumber { - e.handleICMP(r, headerView, vv) + e.handleICMP(r, headerView, pkt) return } r.Stats().IP.PacketsDelivered.Increment() - e.dispatcher.DeliverTransportPacket(r, p, headerView, vv) + e.dispatcher.DeliverTransportPacket(r, p, pkt) } // Close cleans up resources associated with the endpoint. diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index deaa9b7f3..1cbfa7278 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -55,7 +55,9 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst DstAddr: dst, }) - e.Inject(ProtocolNumber, hdr.View().ToVectorisedView()) + e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) stats := s.Stats().ICMP.V6PacketsReceived @@ -111,7 +113,9 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst DstAddr: dst, }) - e.Inject(ProtocolNumber, hdr.View().ToVectorisedView()) + e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) stat := s.Stats().UDP.PacketsReceived diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 69ab7ba12..0dbce14a0 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -98,7 +98,9 @@ func TestHopLimitValidation(t *testing.T) { SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - ep.HandlePacket(r, hdr.View().ToVectorisedView()) + ep.HandlePacket(r, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) } types := []struct { @@ -345,7 +347,9 @@ func TestRouterAdvertValidation(t *testing.T) { t.Fatalf("got rxRA = %d, want = 0", got) } - e.Inject(header.IPv6ProtocolNumber, hdr.View().ToVectorisedView()) + e.InjectInbound(header.IPv6ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) if test.expectedSuccess { if got := invalid.Value(); got != 0 { diff --git a/pkg/tcpip/packet_buffer.go b/pkg/tcpip/packet_buffer.go new file mode 100644 index 000000000..10b04239d --- /dev/null +++ b/pkg/tcpip/packet_buffer.go @@ -0,0 +1,54 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at // +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcpip + +import "gvisor.dev/gvisor/pkg/tcpip/buffer" + +// A PacketBuffer contains all the data of a network packet. +// +// As a PacketBuffer traverses up the stack, it may be necessary to pass it to +// multiple endpoints. Clone() should be called in such cases so that +// modifications to the Data field do not affect other copies. +// +// +stateify savable +type PacketBuffer struct { + // Data holds the payload of the packet. For inbound packets, it also + // holds the headers, which are consumed as the packet moves up the + // stack. Headers are guaranteed not to be split across views. + // + // The bytes backing Data are immutable, but Data itself may be trimmed + // or otherwise modified. + Data buffer.VectorisedView + + // The bytes backing these views are immutable. Each field may be nil + // if either it has not been set yet or no such header exists (e.g. + // packets sent via loopback may not have a link header). + // + // These fields may be Views into other Views. SR dosen't support this, + // so deep copies are necessary in some cases. + LinkHeader buffer.View + NetworkHeader buffer.View + TransportHeader buffer.View +} + +// Clone makes a copy of pk. It clones the Data field, which creates a new +// VectorisedView but does not deep copy the underlying bytes. +func (pk PacketBuffer) Clone() PacketBuffer { + return PacketBuffer{ + Data: pk.Data.Clone(nil), + LinkHeader: pk.LinkHeader, + NetworkHeader: pk.NetworkHeader, + TransportHeader: pk.TransportHeader, + } +} diff --git a/pkg/tcpip/packet_buffer_state.go b/pkg/tcpip/packet_buffer_state.go new file mode 100644 index 000000000..04c4cf136 --- /dev/null +++ b/pkg/tcpip/packet_buffer_state.go @@ -0,0 +1,26 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcpip + +import "gvisor.dev/gvisor/pkg/tcpip/buffer" + +// beforeSave is invoked by stateify. +func (pk *PacketBuffer) beforeSave() { + // Non-Data fields may be slices of the Data field. This causes + // problems for SR, so during save we make each header independent. + pk.LinkHeader = append(buffer.View(nil), pk.LinkHeader...) + pk.NetworkHeader = append(buffer.View(nil), pk.NetworkHeader...) + pk.TransportHeader = append(buffer.View(nil), pk.TransportHeader...) +} diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 525a25218..cc789b5af 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -328,7 +328,9 @@ func TestDADFail(t *testing.T) { // Receive a packet to simulate multiple nodes owning or // attempting to own the same address. hdr := test.makeBuf(addr1) - e.Inject(header.IPv6ProtocolNumber, hdr.View().ToVectorisedView()) + e.InjectInbound(header.IPv6ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) stat := test.getStat(s.Stats().ICMP.V6PacketsReceived) if got := stat.Value(); got != 1 { diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 12969c74e..28a28ae6e 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -723,10 +723,10 @@ func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { return nil } -func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, vv buffer.VectorisedView) { +func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt tcpip.PacketBuffer) { r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */) r.RemoteLinkAddress = remotelinkAddr - ref.ep.HandlePacket(&r, vv) + ref.ep.HandlePacket(&r, pkt) ref.decRef() } @@ -736,9 +736,9 @@ func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, // Note that the ownership of the slice backing vv is retained by the caller. // This rule applies only to the slice itself, not to the items of the slice; // the ownership of the items is not retained by the caller. -func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) { +func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { n.stats.Rx.Packets.Increment() - n.stats.Rx.Bytes.IncrementBy(uint64(vv.Size())) + n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data.Size())) netProto, ok := n.stack.networkProtocols[protocol] if !ok { @@ -763,22 +763,22 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link } n.mu.RUnlock() for _, ep := range packetEPs { - ep.HandlePacket(n.id, local, protocol, vv.Clone(nil), linkHeader) + ep.HandlePacket(n.id, local, protocol, pkt.Clone()) } if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber { n.stack.stats.IP.PacketsReceived.Increment() } - if len(vv.First()) < netProto.MinimumPacketSize() { + if len(pkt.Data.First()) < netProto.MinimumPacketSize() { n.stack.stats.MalformedRcvdPackets.Increment() return } - src, dst := netProto.ParseAddresses(vv.First()) + src, dst := netProto.ParseAddresses(pkt.Data.First()) if ref := n.getRef(protocol, dst); ref != nil { - handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv) + handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, pkt) return } @@ -806,20 +806,20 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link if ok { r.RemoteAddress = src // TODO(b/123449044): Update the source NIC as well. - ref.ep.HandlePacket(&r, vv) + ref.ep.HandlePacket(&r, pkt) ref.decRef() } else { // n doesn't have a destination endpoint. // Send the packet out of n. - hdr := buffer.NewPrependableFromView(vv.First()) - vv.RemoveFirst() + hdr := buffer.NewPrependableFromView(pkt.Data.First()) + pkt.Data.RemoveFirst() // TODO(b/128629022): use route.WritePacket. - if err := n.linkEP.WritePacket(&r, nil /* gso */, hdr, vv, protocol); err != nil { + if err := n.linkEP.WritePacket(&r, nil /* gso */, hdr, pkt.Data, protocol); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() } else { n.stats.Tx.Packets.Increment() - n.stats.Tx.Bytes.IncrementBy(uint64(hdr.UsedLength() + vv.Size())) + n.stats.Tx.Bytes.IncrementBy(uint64(hdr.UsedLength() + pkt.Data.Size())) } } return @@ -833,7 +833,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. -func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) { +func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) { state, ok := n.stack.transportProtocols[protocol] if !ok { n.stack.stats.UnknownProtocolRcvdPackets.Increment() @@ -845,41 +845,41 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // Raw socket packets are delivered based solely on the transport // protocol number. We do not inspect the payload to ensure it's // validly formed. - n.stack.demux.deliverRawPacket(r, protocol, netHeader, vv) + n.stack.demux.deliverRawPacket(r, protocol, pkt) - if len(vv.First()) < transProto.MinimumPacketSize() { + if len(pkt.Data.First()) < transProto.MinimumPacketSize() { n.stack.stats.MalformedRcvdPackets.Increment() return } - srcPort, dstPort, err := transProto.ParsePorts(vv.First()) + srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) if err != nil { n.stack.stats.MalformedRcvdPackets.Increment() return } id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress} - if n.stack.demux.deliverPacket(r, protocol, netHeader, vv, id) { + if n.stack.demux.deliverPacket(r, protocol, pkt, id) { return } // Try to deliver to per-stack default handler. if state.defaultHandler != nil { - if state.defaultHandler(r, id, netHeader, vv) { + if state.defaultHandler(r, id, pkt) { return } } // We could not find an appropriate destination for this packet, so // deliver it to the global handler. - if !transProto.HandleUnknownDestinationPacket(r, id, netHeader, vv) { + if !transProto.HandleUnknownDestinationPacket(r, id, pkt) { n.stack.stats.MalformedRcvdPackets.Increment() } } // DeliverTransportControlPacket delivers control packets to the appropriate // transport protocol endpoint. -func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView) { +func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) { state, ok := n.stack.transportProtocols[trans] if !ok { return @@ -890,17 +890,17 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp // ICMPv4 only guarantees that 8 bytes of the transport protocol will // be present in the payload. We know that the ports are within the // first 8 bytes for all known transport protocols. - if len(vv.First()) < 8 { + if len(pkt.Data.First()) < 8 { return } - srcPort, dstPort, err := transProto.ParsePorts(vv.First()) + srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) if err != nil { return } id := TransportEndpointID{srcPort, local, dstPort, remote} - if n.stack.demux.deliverControlPacket(n, net, trans, typ, extra, vv, id) { + if n.stack.demux.deliverControlPacket(n, net, trans, typ, extra, pkt, id) { return } } diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index d7c124e81..5806d294c 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -64,16 +64,15 @@ type TransportEndpoint interface { UniqueID() uint64 // HandlePacket is called by the stack when new packets arrive to - // this transport endpoint. + // this transport endpoint. It sets pkt.TransportHeader. // - // HandlePacket takes ownership of vv. - HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) + // HandlePacket takes ownership of pkt. + HandlePacket(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) - // HandleControlPacket is called by the stack when new control (e.g., + // HandleControlPacket is called by the stack when new control (e.g. // ICMP) packets arrive to this transport endpoint. - // - // HandleControlPacket takes ownership of vv. - HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) + // HandleControlPacket takes ownership of pkt. + HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) // Close puts the endpoint in a closed state and frees all resources // associated with it. This cleanup may happen asynchronously. Wait can @@ -99,8 +98,8 @@ type RawTransportEndpoint interface { // this transport endpoint. The packet contains all data from the link // layer up. // - // HandlePacket takes ownership of packet and netHeader. - HandlePacket(r *Route, netHeader buffer.View, packet buffer.VectorisedView) + // HandlePacket takes ownership of pkt. + HandlePacket(r *Route, pkt tcpip.PacketBuffer) } // PacketEndpoint is the interface that needs to be implemented by packet @@ -117,8 +116,8 @@ type PacketEndpoint interface { // linkHeader may have a length of 0, in which case the PacketEndpoint // should construct its own ethernet header for applications. // - // HandlePacket takes ownership of packet and linkHeader. - HandlePacket(nicid tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, packet buffer.VectorisedView, linkHeader buffer.View) + // HandlePacket takes ownership of pkt. + HandlePacket(nicid tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) } // TransportProtocol is the interface that needs to be implemented by transport @@ -148,7 +147,9 @@ type TransportProtocol interface { // // The return value indicates whether the packet was well-formed (for // stats purposes only). - HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool + // + // HandleUnknownDestinationPacket takes ownership of pkt. + HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) bool // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the @@ -166,17 +167,21 @@ type TransportProtocol interface { // the network layer. type TransportDispatcher interface { // DeliverTransportPacket delivers packets to the appropriate - // transport protocol endpoint. It also returns the network layer - // header for the enpoint to inspect or pass up the stack. + // transport protocol endpoint. + // + // pkt.NetworkHeader must be set before calling DeliverTransportPacket. // - // DeliverTransportPacket takes ownership of vv and netHeader. - DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) + // DeliverTransportPacket takes ownership of pkt. + DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) // DeliverTransportControlPacket delivers control packets to the // appropriate transport protocol endpoint. // - // DeliverTransportControlPacket takes ownership of vv. - DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView) + // pkt.NetworkHeader must be set before calling + // DeliverTransportControlPacket. + // + // DeliverTransportControlPacket takes ownership of pkt. + DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) } // PacketLooping specifies where an outbound packet should be sent. @@ -248,10 +253,10 @@ type NetworkEndpoint interface { NICID() tcpip.NICID // HandlePacket is called by the link layer when new packets arrive to - // this network endpoint. + // this network endpoint. It sets pkt.NetworkHeader. // - // HandlePacket takes ownership of vv. - HandlePacket(r *Route, vv buffer.VectorisedView) + // HandlePacket takes ownership of pkt. + HandlePacket(r *Route, pkt tcpip.PacketBuffer) // Close is called when the endpoint is reomved from a stack. Close() @@ -294,11 +299,14 @@ type NetworkProtocol interface { // the data link layer. type NetworkDispatcher interface { // DeliverNetworkPacket finds the appropriate network protocol endpoint - // and hands the packet over for further processing. linkHeader may have - // length 0 when the caller does not have ethernet data. + // and hands the packet over for further processing. + // + // pkt.LinkHeader may or may not be set before calling + // DeliverNetworkPacket. Some packets do not have link headers (e.g. + // packets sent via loopback), and won't have the field set. // - // DeliverNetworkPacket takes ownership of vv and linkHeader. - DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) + // DeliverNetworkPacket takes ownership of pkt. + DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) } // LinkEndpointCapabilities is the type associated with the capabilities @@ -329,7 +337,9 @@ const ( // LinkEndpoint is the interface implemented by data link layer protocols (e.g., // ethernet, loopback, raw) and used by network layer protocols to send packets -// out through the implementer's data link endpoint. +// out through the implementer's data link endpoint. When a link header exists, +// it sets each tcpip.PacketBuffer's LinkHeader field before passing it up the +// stack. type LinkEndpoint interface { // MTU is the maximum transmission unit for this endpoint. This is // usually dictated by the backing physical network; when such a @@ -395,7 +405,7 @@ type InjectableLinkEndpoint interface { LinkEndpoint // InjectInbound injects an inbound packet. - InjectInbound(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) + InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) // InjectOutbound writes a fully formed outbound packet directly to the // link. diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 8b141cafd..08599d765 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -51,7 +51,7 @@ const ( type transportProtocolState struct { proto TransportProtocol - defaultHandler func(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool + defaultHandler func(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) bool } // TCPProbeFunc is the expected function type for a TCP probe function to be @@ -641,7 +641,7 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, // // It must be called only during initialization of the stack. Changing it as the // stack is operating is not supported. -func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, buffer.View, buffer.VectorisedView) bool) { +func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, tcpip.PacketBuffer) bool) { state := s.transportProtocols[p] if state != nil { state.defaultHandler = h diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 9dae853d0..1fac5477f 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -86,28 +86,28 @@ func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID { return &f.id } -func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { +func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { // Increment the received packet count in the protocol descriptor. f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ // Consume the network header. - b := vv.First() - vv.TrimFront(fakeNetHeaderLen) + b := pkt.Data.First() + pkt.Data.TrimFront(fakeNetHeaderLen) // Handle control packets. if b[2] == uint8(fakeControlProtocol) { - nb := vv.First() + nb := pkt.Data.First() if len(nb) < fakeNetHeaderLen { return } - vv.TrimFront(fakeNetHeaderLen) - f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, vv) + pkt.Data.TrimFront(fakeNetHeaderLen) + f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, pkt) return } // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), buffer.View([]byte{}), vv) + f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), pkt) } func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { @@ -138,7 +138,9 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr bu views[0] = hdr.View() views = append(views, payload.Views()...) vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) - f.HandlePacket(r, vv) + f.HandlePacket(r, tcpip.PacketBuffer{ + Data: vv, + }) } if loop&stack.PacketOut == 0 { return nil @@ -259,7 +261,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet with wrong address is not delivered. buf[0] = 3 - ep.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeNet.packetCount[1] != 0 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) } @@ -269,7 +273,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to first endpoint. buf[0] = 1 - ep.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -279,7 +285,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to second endpoint. buf[0] = 2 - ep.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -288,7 +296,9 @@ func TestNetworkReceive(t *testing.T) { } // Make sure packet is not delivered if protocol number is wrong. - ep.Inject(fakeNetNumber-1, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber-1, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -298,7 +308,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet that is too small is dropped. buf.CapLength(2) - ep.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -373,7 +385,9 @@ func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte b func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) { t.Helper() - ep.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if got := fakeNet.PacketCount(localAddrByte); got != want { t.Errorf("receive packet count: got = %d, want %d", got, want) } @@ -1795,7 +1809,9 @@ func TestNICStats(t *testing.T) { // Send a packet to address 1. buf := buffer.NewView(30) - ep1.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want { t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) } @@ -1855,7 +1871,9 @@ func TestNICForwarding(t *testing.T) { // Send a packet to address 3. buf := buffer.NewView(30) buf[0] = 3 - ep1.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) select { case <-ep2.C: diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index ccd3d030e..594570216 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -21,7 +21,6 @@ import ( "sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -86,7 +85,7 @@ func (epsByNic *endpointsByNic) transportEndpoints() []TransportEndpoint { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) { +func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) { epsByNic.mu.RLock() mpep, ok := epsByNic.endpoints[r.ref.nic.ID()] @@ -100,18 +99,18 @@ func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, v // If this is a broadcast or multicast datagram, deliver the datagram to all // endpoints bound to the right device. if isMulticastOrBroadcast(id.LocalAddress) { - mpep.handlePacketAll(r, id, vv) + mpep.handlePacketAll(r, id, pkt) epsByNic.mu.RUnlock() // Don't use defer for performance reasons. return } // multiPortEndpoints are guaranteed to have at least one element. - selectEndpoint(id, mpep, epsByNic.seed).HandlePacket(r, id, vv) + selectEndpoint(id, mpep, epsByNic.seed).HandlePacket(r, id, pkt) epsByNic.mu.RUnlock() // Don't use defer for performance reasons. } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) { +func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) { epsByNic.mu.RLock() defer epsByNic.mu.RUnlock() @@ -127,7 +126,7 @@ func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpoint // broadcast like we are doing with handlePacket above? // multiPortEndpoints are guaranteed to have at least one element. - selectEndpoint(id, mpep, epsByNic.seed).HandleControlPacket(id, typ, extra, vv) + selectEndpoint(id, mpep, epsByNic.seed).HandleControlPacket(id, typ, extra, pkt) } // registerEndpoint returns true if it succeeds. It fails and returns @@ -258,18 +257,16 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32 return mpep.endpointsArr[idx] } -func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, vv buffer.VectorisedView) { +func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) { ep.mu.RLock() for i, endpoint := range ep.endpointsArr { - // HandlePacket modifies vv, so each endpoint needs its own copy except for - // the final one. + // HandlePacket takes ownership of pkt, so each endpoint needs + // its own copy except for the final one. if i == len(ep.endpointsArr)-1 { - endpoint.HandlePacket(r, id, vv) + endpoint.HandlePacket(r, id, pkt) break } - vvCopy := buffer.NewView(vv.Size()) - copy(vvCopy, vv.ToView()) - endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView()) + endpoint.HandlePacket(r, id, pkt.Clone()) } ep.mu.RUnlock() // Don't use defer for performance reasons. } @@ -395,7 +392,7 @@ var loopbackSubnet = func() tcpip.Subnet { // deliverPacket attempts to find one or more matching transport endpoints, and // then, if matches are found, delivers the packet to them. Returns true if it // found one or more endpoints, false otherwise. -func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView, id TransportEndpointID) bool { +func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] if !ok { return false @@ -408,8 +405,8 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // transport endpoints. var destEps []*endpointsByNic if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) { - destEps = d.findAllEndpointsLocked(eps, vv, id) - } else if ep := d.findEndpointLocked(eps, vv, id); ep != nil { + destEps = d.findAllEndpointsLocked(eps, id) + } else if ep := d.findEndpointLocked(eps, id); ep != nil { destEps = append(destEps, ep) } @@ -424,17 +421,19 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto return false } - // Deliver the packet. - for _, ep := range destEps { - ep.handlePacket(r, id, vv) + // HandlePacket takes ownership of pkt, so each endpoint needs its own + // copy except for the final one. + for _, ep := range destEps[:len(destEps)-1] { + ep.handlePacket(r, id, pkt.Clone()) } + destEps[len(destEps)-1].handlePacket(r, id, pkt) return true } // deliverRawPacket attempts to deliver the given packet and returns whether it // was delivered successfully. -func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) bool { +func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) bool { eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] if !ok { return false @@ -448,7 +447,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr for _, rawEP := range eps.rawEndpoints { // Each endpoint gets its own copy of the packet for the sake // of save/restore. - rawEP.HandlePacket(r, buffer.NewViewFromBytes(netHeader), vv.ToView().ToVectorisedView()) + rawEP.HandlePacket(r, pkt) foundRaw = true } eps.mu.RUnlock() @@ -458,7 +457,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr // deliverControlPacket attempts to deliver the given control packet. Returns // true if it found an endpoint, false otherwise. -func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool { +func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt tcpip.PacketBuffer, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{net, trans}] if !ok { return false @@ -466,7 +465,7 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco // Try to find the endpoint. eps.mu.RLock() - ep := d.findEndpointLocked(eps, vv, id) + ep := d.findEndpointLocked(eps, id) eps.mu.RUnlock() // Fail if we didn't find one. @@ -475,12 +474,12 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco } // Deliver the packet. - ep.handleControlPacket(n, id, typ, extra, vv) + ep.handleControlPacket(n, id, typ, extra, pkt) return true } -func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) []*endpointsByNic { +func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id TransportEndpointID) []*endpointsByNic { var matchedEPs []*endpointsByNic // Try to find a match with the id as provided. if ep, ok := eps.endpoints[id]; ok { @@ -514,8 +513,8 @@ func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, vv bu // findEndpointLocked returns the endpoint that most closely matches the given // id. -func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) *endpointsByNic { - if matchedEPs := d.findAllEndpointsLocked(eps, vv, id); len(matchedEPs) > 0 { +func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, id TransportEndpointID) *endpointsByNic { + if matchedEPs := d.findAllEndpointsLocked(eps, id); len(matchedEPs) > 0 { return matchedEPs[0] } return nil diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 210233dc0..f54117c4e 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -156,7 +156,9 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEPs[linkEpName].Inject(ipv6.ProtocolNumber, buf.ToVectorisedView()) + c.linkEPs[linkEpName].InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) } func TestTransportDemuxerRegister(t *testing.T) { diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 203e79f56..2cacea99a 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -197,7 +197,7 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro return tcpip.FullAddress{}, nil } -func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ buffer.VectorisedView) { +func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ tcpip.PacketBuffer) { // Increment the number of received packets. f.proto.packetCount++ if f.acceptQueue != nil { @@ -214,7 +214,7 @@ func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportE } } -func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, buffer.VectorisedView) { +func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, tcpip.PacketBuffer) { // Increment the number of received control packets. f.proto.controlCount++ } @@ -271,7 +271,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp return 0, 0, nil } -func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.View, buffer.VectorisedView) bool { +func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool { return true } @@ -342,7 +342,9 @@ func TestTransportReceive(t *testing.T) { // Make sure packet with wrong protocol is not delivered. buf[0] = 1 buf[2] = 0 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeTrans.packetCount != 0 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) } @@ -351,7 +353,9 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 3 buf[2] = byte(fakeTransNumber) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeTrans.packetCount != 0 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) } @@ -360,7 +364,9 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 2 buf[2] = byte(fakeTransNumber) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeTrans.packetCount != 1 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 1) } @@ -413,7 +419,9 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 0 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = 0 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeTrans.controlCount != 0 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) } @@ -422,7 +430,9 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 3 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeTrans.controlCount != 0 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) } @@ -431,7 +441,9 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 2 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeTrans.controlCount != 1 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 1) } @@ -584,7 +596,9 @@ func TestTransportForwarding(t *testing.T) { req[0] = 1 req[1] = 3 req[2] = byte(fakeTransNumber) - ep2.Inject(fakeNetNumber, req.ToVectorisedView()) + ep2.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: req.ToVectorisedView(), + }) aep, _, err := ep.Accept() if err != nil || aep == nil { diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 33405eb7d..0092d0ea9 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -718,18 +718,18 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { // Only accept echo replies. switch e.NetProto { case header.IPv4ProtocolNumber: - h := header.ICMPv4(vv.First()) + h := header.ICMPv4(pkt.Data.First()) if h.Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } case header.IPv6ProtocolNumber: - h := header.ICMPv6(vv.First()) + h := header.ICMPv6(pkt.Data.First()) if h.Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() @@ -757,19 +757,19 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv wasEmpty := e.rcvBufSize == 0 // Push new packet into receive list and increment the buffer size. - pkt := &icmpPacket{ + packet := &icmpPacket{ senderAddress: tcpip.FullAddress{ NIC: r.NICID(), Addr: id.RemoteAddress, }, } - pkt.data = vv + packet.data = pkt.Data - e.rcvList.PushBack(pkt) - e.rcvBufSize += pkt.data.Size() + e.rcvList.PushBack(packet) + e.rcvBufSize += packet.data.Size() - pkt.timestamp = e.stack.NowNanoseconds() + packet.timestamp = e.stack.NowNanoseconds() e.rcvMu.Unlock() e.stats.PacketsReceived.Increment() @@ -780,7 +780,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { } // State implements tcpip.Endpoint.State. The ICMP endpoint currently doesn't diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index bfb16f7c3..9ce500e80 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -104,7 +104,7 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.View, buffer.VectorisedView) bool { +func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool { return true } diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index ead83b83d..26335094e 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -266,7 +266,7 @@ func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } // HandlePacket implements stack.PacketEndpoint.HandlePacket. -func (ep *endpoint) HandlePacket(nicid tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, ethHeader buffer.View) { +func (ep *endpoint) HandlePacket(nicid tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { ep.rcvMu.Lock() // Drop the packet if our buffer is currently full. @@ -289,9 +289,9 @@ func (ep *endpoint) HandlePacket(nicid tcpip.NICID, localAddr tcpip.LinkAddress, // Push new packet into receive list and increment the buffer size. var packet packet // TODO(b/129292371): Return network protocol. - if len(ethHeader) > 0 { + if len(pkt.LinkHeader) > 0 { // Get info directly from the ethernet header. - hdr := header.Ethernet(ethHeader) + hdr := header.Ethernet(pkt.LinkHeader) packet.senderAddr = tcpip.FullAddress{ NIC: nicid, Addr: tcpip.Address(hdr.SourceAddress()), @@ -306,11 +306,12 @@ func (ep *endpoint) HandlePacket(nicid tcpip.NICID, localAddr tcpip.LinkAddress, if ep.cooked { // Cooked packets can simply be queued. - packet.data = vv + packet.data = pkt.Data } else { // Raw packets need their ethernet headers prepended before // queueing. - if len(ethHeader) == 0 { + var linkHeader buffer.View + if len(pkt.LinkHeader) == 0 { // We weren't provided with an actual ethernet header, // so fake one. ethFields := header.EthernetFields{ @@ -320,10 +321,12 @@ func (ep *endpoint) HandlePacket(nicid tcpip.NICID, localAddr tcpip.LinkAddress, } fakeHeader := make(header.Ethernet, header.EthernetMinimumSize) fakeHeader.Encode(ðFields) - ethHeader = buffer.View(fakeHeader) + linkHeader = buffer.View(fakeHeader) + } else { + linkHeader = append(buffer.View(nil), pkt.LinkHeader...) } - combinedVV := buffer.View(ethHeader).ToVectorisedView() - combinedVV.Append(vv) + combinedVV := linkHeader.ToVectorisedView() + combinedVV.Append(pkt.Data) packet.data = combinedVV } packet.timestampNS = ep.stack.NowNanoseconds() diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 23922a30e..230a1537a 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -555,7 +555,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } // HandlePacket implements stack.RawTransportEndpoint.HandlePacket. -func (e *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) { +func (e *endpoint) HandlePacket(route *stack.Route, pkt tcpip.PacketBuffer) { e.rcvMu.Lock() // Drop the packet if our buffer is currently full. @@ -596,20 +596,21 @@ func (e *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv bu wasEmpty := e.rcvBufSize == 0 // Push new packet into receive list and increment the buffer size. - pkt := &rawPacket{ + packet := &rawPacket{ senderAddr: tcpip.FullAddress{ NIC: route.NICID(), Addr: route.RemoteAddress, }, } - combinedVV := netHeader.ToVectorisedView() - combinedVV.Append(vv) - pkt.data = combinedVV - pkt.timestampNS = e.stack.NowNanoseconds() + networkHeader := append(buffer.View(nil), pkt.NetworkHeader...) + combinedVV := networkHeader.ToVectorisedView() + combinedVV.Append(pkt.Data) + packet.data = combinedVV + packet.timestampNS = e.stack.NowNanoseconds() - e.rcvList.PushBack(pkt) - e.rcvBufSize += pkt.data.Size() + e.rcvList.PushBack(packet) + e.rcvBufSize += packet.data.Size() e.rcvMu.Unlock() e.stats.PacketsReceived.Increment() diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index a1efd8d55..e31464c9b 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -2029,8 +2029,8 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { - s := newSegment(r, id, vv) +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { + s := newSegment(r, id, pkt) if !s.parse() { e.stack.Stats().MalformedRcvdPackets.Increment() e.stack.Stats().TCP.InvalidSegmentsReceived.Increment() @@ -2065,7 +2065,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { switch typ { case stack.ControlPacketTooBig: e.sndBufMu.Lock() diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 63666f0b3..4983bca81 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -18,7 +18,6 @@ import ( "sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -63,8 +62,8 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward // // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. -func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { - s := newSegment(r, id, vv) +func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool { + s := newSegment(r, id, pkt) defer s.decRef() // We only care about well-formed SYN packets. diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index db40785d3..c4f1a84bb 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -126,8 +126,8 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // a reset is sent in response to any incoming segment except another reset. In // particular, SYNs addressed to a non-existent connection are rejected by this // means." -func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { - s := newSegment(r, id, vv) +func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool { + s := newSegment(r, id, pkt) defer s.decRef() if !s.parse() || !s.csumValid { diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index c4a89525e..1c10da5ca 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -18,6 +18,7 @@ import ( "sync/atomic" "time" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" @@ -60,13 +61,13 @@ type segment struct { xmitTime time.Time `state:".(unixTime)"` } -func newSegment(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) *segment { +func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) *segment { s := &segment{ refCnt: 1, id: id, route: r.Clone(), } - s.data = vv.Clone(s.views[:]) + s.data = pkt.Data.Clone(s.views[:]) s.rcvdTime = time.Now() return s } diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index ef823e4ae..4854e719d 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -302,7 +302,9 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt copy(icmp[header.ICMPv4PayloadOffset:], p2) // Inject packet. - c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView()) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) } // BuildSegment builds a TCP segment based on the given Headers and payload. @@ -350,13 +352,17 @@ func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView // SendSegment sends a TCP segment that has already been built and written to a // buffer.VectorisedView. func (c *Context) SendSegment(s buffer.VectorisedView) { - c.linkEP.Inject(ipv4.ProtocolNumber, s) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ + Data: s, + }) } // SendPacket builds and sends a TCP segment(with the provided payload & TCP // headers) in an IPv4 packet via the link layer endpoint. func (c *Context) SendPacket(payload []byte, h *Headers) { - c.linkEP.Inject(ipv4.ProtocolNumber, c.BuildSegment(payload, h)) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ + Data: c.BuildSegment(payload, h), + }) } // SendAck sends an ACK packet. @@ -518,7 +524,9 @@ func (c *Context) SendV6Packet(payload []byte, h *Headers) { t.SetChecksum(^t.CalculateChecksum(xsum)) // Inject packet. - c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView()) + c.linkEP.InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) } // CreateConnected creates a connected TCP endpoint. diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 03bd5c8fd..4e11de9db 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1158,17 +1158,17 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { // Get the header then trim it from the view. - hdr := header.UDP(vv.First()) - if int(hdr.Length()) > vv.Size() { + hdr := header.UDP(pkt.Data.First()) + if int(hdr.Length()) > pkt.Data.Size() { // Malformed packet. e.stack.Stats().UDP.MalformedPacketsReceived.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } - vv.TrimFront(header.UDPMinimumSize) + pkt.Data.TrimFront(header.UDPMinimumSize) e.rcvMu.Lock() e.stack.Stats().UDP.PacketsReceived.Increment() @@ -1192,18 +1192,18 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv wasEmpty := e.rcvBufSize == 0 // Push new packet into receive list and increment the buffer size. - pkt := &udpPacket{ + packet := &udpPacket{ senderAddress: tcpip.FullAddress{ NIC: r.NICID(), Addr: id.RemoteAddress, Port: hdr.SourcePort(), }, } - pkt.data = vv - e.rcvList.PushBack(pkt) - e.rcvBufSize += vv.Size() + packet.data = pkt.Data + e.rcvList.PushBack(packet) + e.rcvBufSize += pkt.Data.Size() - pkt.timestamp = e.stack.NowNanoseconds() + packet.timestamp = e.stack.NowNanoseconds() e.rcvMu.Unlock() @@ -1214,7 +1214,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { } // State implements tcpip.Endpoint.State. diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index d399ec722..fc706ede2 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -16,7 +16,6 @@ package udp import ( "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" ) @@ -44,12 +43,12 @@ func NewForwarder(s *stack.Stack, handler func(*ForwarderRequest)) *Forwarder { // // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. -func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { +func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool { f.handler(&ForwarderRequest{ stack: f.stack, route: r, id: id, - vv: vv, + pkt: pkt, }) return true @@ -62,7 +61,7 @@ type ForwarderRequest struct { stack *stack.Stack route *stack.Route id stack.TransportEndpointID - vv buffer.VectorisedView + pkt tcpip.PacketBuffer } // ID returns the 4-tuple (src address, src port, dst address, dst port) that @@ -90,7 +89,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, ep.rcvReady = true ep.rcvMu.Unlock() - ep.HandlePacket(r.route, r.id, r.vv) + ep.HandlePacket(r.route, r.id, r.pkt) return ep, nil } diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 5c3358a5e..43f11b700 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -66,10 +66,10 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { +func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool { // Get the header then trim it from the view. - hdr := header.UDP(vv.First()) - if int(hdr.Length()) > vv.Size() { + hdr := header.UDP(pkt.Data.First()) + if int(hdr.Length()) > pkt.Data.Size() { // Malformed packet. r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() return true @@ -116,20 +116,18 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans } headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize available := int(mtu) - headerLen - payloadLen := len(netHeader) + vv.Size() + payloadLen := len(pkt.NetworkHeader) + pkt.Data.Size() if payloadLen > available { payloadLen = available } - // The buffers used by vv and netHeader may be used elsewhere - // in the system. For example, a raw or packet socket may use - // what UDP considers an unreachable destination. Thus we deep - // copy vv and netHeader to prevent multiple ownership and SR - // errors. - newNetHeader := make(buffer.View, len(netHeader)) - copy(newNetHeader, netHeader) - payload := buffer.NewVectorisedView(len(newNetHeader), []buffer.View{newNetHeader}) - payload.Append(vv.ToView().ToVectorisedView()) + // The buffers used by pkt may be used elsewhere in the system. + // For example, a raw or packet socket may use what UDP + // considers an unreachable destination. Thus we deep copy pkt + // to prevent multiple ownership and SR errors. + newNetHeader := append(buffer.View(nil), pkt.NetworkHeader...) + payload := newNetHeader.ToVectorisedView() + payload.Append(pkt.Data.ToView().ToVectorisedView()) payload.CapLength(payloadLen) hdr := buffer.NewPrependable(headerLen) @@ -158,12 +156,12 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans } headerLen := int(r.MaxHeaderLength()) + header.ICMPv6DstUnreachableMinimumSize available := int(mtu) - headerLen - payloadLen := len(netHeader) + vv.Size() + payloadLen := len(pkt.NetworkHeader) + pkt.Data.Size() if payloadLen > available { payloadLen = available } - payload := buffer.NewVectorisedView(len(netHeader), []buffer.View{netHeader}) - payload.Append(vv) + payload := buffer.NewVectorisedView(len(pkt.NetworkHeader), []buffer.View{pkt.NetworkHeader}) + payload.Append(pkt.Data) payload.CapLength(payloadLen) hdr := buffer.NewPrependable(headerLen) diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index b724d788c..30ee9801b 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -397,7 +397,8 @@ func (c *testContext) injectPacket(flow testFlow, payload []byte) { func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) - copy(buf[len(buf)-len(payload):], payload) + payloadStart := len(buf) - len(payload) + copy(buf[payloadStart:], payload) // Initialize the IP header. ip := header.IPv6(buf) @@ -431,7 +432,11 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView()) + c.linkEP.InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + NetworkHeader: buffer.View(ip), + TransportHeader: buffer.View(u), + }) } // injectV4Packet creates a V4 test packet with the given payload and header @@ -441,7 +446,8 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload)) - copy(buf[len(buf)-len(payload):], payload) + payloadStart := len(buf) - len(payload) + copy(buf[payloadStart:], payload) // Initialize the IP header. ip := header.IPv4(buf) @@ -471,7 +477,12 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView()) + + c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + NetworkHeader: buffer.View(ip), + TransportHeader: buffer.View(u), + }) } func newPayload() []byte { diff --git a/test/syscalls/linux/raw_socket_icmp.cc b/test/syscalls/linux/raw_socket_icmp.cc index 8bcaba6f1..3de898df7 100644 --- a/test/syscalls/linux/raw_socket_icmp.cc +++ b/test/syscalls/linux/raw_socket_icmp.cc @@ -129,7 +129,7 @@ TEST_F(RawSocketICMPTest, SendAndReceiveBadChecksum) { EXPECT_THAT(RetryEINTR(recv)(s_, recv_buf, sizeof(recv_buf), MSG_DONTWAIT), SyscallFailsWithErrno(EAGAIN)); } -// + // Send and receive an ICMP packet. TEST_F(RawSocketICMPTest, SendAndReceive) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); -- cgit v1.2.3 From e63db5e7bbf8decc6f799965f54fcf7aa6673527 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Wed, 6 Nov 2019 16:28:25 -0800 Subject: Discover default routers from Router Advertisements This change allows the netstack to do NDP's Router Discovery as outlined by RFC 4861 section 6.3.4. Note, this change will not break existing uses of netstack as the default configuration for the stack options is set in such a way that Router Discovery will not be performed. See `stack.Options` and `stack.NDPConfigurations` for more details. This change introduces 2 options required to take advantage of Router Discovery, all available under NDPConfigurations: - HandleRAs: Whether or not NDP RAs are processes - DiscoverDefaultRouters: Whether or not Router Discovery is performed Another note: for a NIC to process Router Advertisements, it must not be a router itself. Currently the netstack does not have per-interface routing configuration; the routing/forwarding configuration is controlled stack-wide. Therefore, if the stack is configured to enable forwarding/routing, no Router Advertisements will be processed. Tests: Unittest to make sure that Router Discovery and updates to the routing table only occur if explicitly configured to do so. Unittest to make sure at max stack.MaxDiscoveredDefaultRouters discovered default routers are remembered. PiperOrigin-RevId: 278965143 --- pkg/tcpip/header/ipv6.go | 4 +- pkg/tcpip/stack/ndp.go | 166 ++++++++++++++++- pkg/tcpip/stack/ndp_test.go | 426 ++++++++++++++++++++++++++++++++++++++++++-- pkg/tcpip/tcpip.go | 7 + 4 files changed, 586 insertions(+), 17 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index f1e60911b..0caa51c1e 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -92,7 +92,9 @@ const ( IPv6Any tcpip.Address = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" ) -// IPv6EmptySubnet is the empty IPv6 subnet. +// IPv6EmptySubnet is the empty IPv6 subnet. It may also be known as the +// catch-all or wildcard subnet. That is, all IPv6 addresses are considered to +// be contained within this subnet. var IPv6EmptySubnet = func() tcpip.Subnet { subnet, err := tcpip.NewSubnet(IPv6Any, tcpip.AddressMask(IPv6Any)) if err != nil { diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index d5352bb5f..a216242d8 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -67,6 +67,9 @@ const ( // default routers. The stack should stop discovering new routers after // discovering MaxDiscoveredDefaultRouters routers. // + // This value MUST be at minimum 2 as per RFC 4861 section 6.3.4, and + // SHOULD be more. + // // Max = 10. MaxDiscoveredDefaultRouters = 10 ) @@ -85,6 +88,24 @@ type NDPDispatcher interface { // This function is permitted to block indefinitely without interfering // with the stack's operation. OnDuplicateAddressDetectionStatus(nicid tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) + + // OnDefaultRouterDiscovered will be called when a new default router is + // discovered. Implementations must return true along with a new valid + // route table if the newly discovered router should be remembered. If + // an implementation returns false, the second return value will be + // ignored. + // + // This function is not permitted to block indefinitely. This function + // is also not permitted to call into the stack. + OnDefaultRouterDiscovered(nicid tcpip.NICID, addr tcpip.Address) (bool, []tcpip.Route) + + // OnDefaultRouterInvalidated will be called when a discovered default + // router is invalidated. Implementers must return a new valid route + // table. + // + // This function is not permitted to block indefinitely. This function + // is also not permitted to call into the stack. + OnDefaultRouterInvalidated(nicid tcpip.NICID, addr tcpip.Address) []tcpip.Route } // NDPConfigurations is the NDP configurations for the netstack. @@ -165,6 +186,22 @@ type dadState struct { // a Router Advertisement. type defaultRouterState struct { invalidationTimer *time.Timer + + // Used to signal the timer not to invalidate the default router (R) in + // a race condition (T1 is a goroutine that handles an RA from R and T2 + // is the goroutine that handles R's invalidation timer firing): + // T1: Receive a new RA from R + // T1: Obtain the NIC's lock before processing the RA + // T2: R's invalidation timer fires, and gets blocked on obtaining the + // NIC's lock + // T1: Refreshes/extends R's lifetime & releases NIC's lock + // T2: Obtains NIC's lock & invalidates R immediately + // + // To resolve this, T1 will check to see if the timer already fired, and + // signal the timer using this channel to not invalidate R, so that once + // T2 obtains the lock, it will see that there is an event on this + // channel and do nothing further. + doNotInvalidateC chan struct{} } // startDuplicateAddressDetection performs Duplicate Address Detection. @@ -361,16 +398,137 @@ func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) { } // handleRA handles a Router Advertisement message that arrived on the NIC -// this ndp is for. +// this ndp is for. Does nothing if the NIC is configured to not handle RAs. // -// The NIC that ndp belongs to MUST be locked. +// The NIC that ndp belongs to and its associated stack MUST be locked. func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { // Is the NIC configured to handle RAs at all? - if !ndp.configs.HandleRAs { + // + // Currently, the stack does not determine router interface status on a + // per-interface basis; it is a stack-wide configuration, so we check + // stack's forwarding flag to determine if the NIC is a routing + // interface. + if !ndp.configs.HandleRAs || ndp.nic.stack.forwarding { return } - // TODO(b/140882146): Do Router Discovery. + // Is the NIC configured to discover default routers? + if ndp.configs.DiscoverDefaultRouters { + rtr, ok := ndp.defaultRouters[ip] + rl := ra.RouterLifetime() + switch { + case !ok && rl != 0: + // This is a new default router we are discovering. + // + // Only remember it if we currently know about less than + // MaxDiscoveredDefaultRouters routers. + if len(ndp.defaultRouters) < MaxDiscoveredDefaultRouters { + ndp.rememberDefaultRouter(ip, rl) + } + + case ok && rl != 0: + // This is an already discovered default router. Update + // the invalidation timer. + timer := rtr.invalidationTimer + + // We should ALWAYS have an invalidation timer for a + // discovered router. + if timer == nil { + panic("ndphandlera: RA invalidation timer should not be nil") + } + + if !timer.Stop() { + // If we reach this point, then we know the + // timer fired after we already took the NIC + // lock. Signal the timer so that once it + // obtains the lock, it doesn't actually + // invalidate the router as we just got a new + // RA that refreshes its lifetime to a non-zero + // value. See + // defaultRouterState.doNotInvalidateC for more + // details. + rtr.doNotInvalidateC <- struct{}{} + } + + timer.Reset(rl) + + case ok && rl == 0: + // We know about the router but it is no longer to be + // used as a default router so invalidate it. + ndp.invalidateDefaultRouter(ip) + } + } + // TODO(b/140948104): Do Prefix Discovery. // TODO(b/141556115): Do Parameter Discovery. } + +// invalidateDefaultRouter invalidates a discovered default router. +// +// The NIC that ndp belongs to and its associated stack MUST be locked. +func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { + rtr, ok := ndp.defaultRouters[ip] + + // Is the router still discovered? + if !ok { + // ...Nope, do nothing further. + return + } + + rtr.invalidationTimer.Stop() + rtr.invalidationTimer = nil + close(rtr.doNotInvalidateC) + rtr.doNotInvalidateC = nil + + delete(ndp.defaultRouters, ip) + + // Let the integrator know a discovered default router is invalidated. + if ndp.nic.stack.ndpDisp != nil { + ndp.nic.stack.routeTable = ndp.nic.stack.ndpDisp.OnDefaultRouterInvalidated(ndp.nic.ID(), ip) + } +} + +// rememberDefaultRouter remembers a newly discovered default router with IPv6 +// link-local address ip with lifetime rl. +// +// The router identified by ip MUST NOT already be known by the NIC. +// +// The NIC that ndp belongs to and its associated stack MUST be locked. +func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { + if ndp.nic.stack.ndpDisp == nil { + return + } + + // Inform the integrator when we discovered a default router. + remember, routeTable := ndp.nic.stack.ndpDisp.OnDefaultRouterDiscovered(ndp.nic.ID(), ip) + if !remember { + // Informed by the integrator to not remember the router, do + // nothing further. + return + } + + // Used to signal the timer not to invalidate the default router (R) in + // a race condition. See defaultRouterState.doNotInvalidateC for more + // details. + doNotInvalidateC := make(chan struct{}, 1) + + ndp.defaultRouters[ip] = defaultRouterState{ + invalidationTimer: time.AfterFunc(rl, func() { + ndp.nic.stack.mu.Lock() + defer ndp.nic.stack.mu.Unlock() + ndp.nic.mu.Lock() + defer ndp.nic.mu.Unlock() + + select { + case <-doNotInvalidateC: + return + default: + } + + ndp.invalidateDefaultRouter(ip) + }), + doNotInvalidateC: doNotInvalidateC, + } + + ndp.nic.stack.routeTable = routeTable +} diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index cc789b5af..0dbe4da9d 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -15,9 +15,12 @@ package stack_test import ( + "encoding/binary" + "fmt" "testing" "time" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" @@ -29,10 +32,19 @@ import ( ) const ( - addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03" - linkAddr1 = "\x02\x02\x03\x04\x05\x06" + addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03" + linkAddr1 = "\x02\x02\x03\x04\x05\x06" + linkAddr2 = "\x02\x02\x03\x04\x05\x07" + linkAddr3 = "\x02\x02\x03\x04\x05\x08" + defaultTimeout = 250 * time.Millisecond +) + +var ( + llAddr1 = header.LinkLocalAddr(linkAddr1) + llAddr2 = header.LinkLocalAddr(linkAddr2) + llAddr3 = header.LinkLocalAddr(linkAddr3) ) // TestDADDisabled tests that an address successfully resolves immediately @@ -77,26 +89,86 @@ type ndpDADEvent struct { err *tcpip.Error } +type ndpRouterEvent struct { + nicid tcpip.NICID + addr tcpip.Address + // true if router was discovered, false if invalidated. + discovered bool +} + var _ stack.NDPDispatcher = (*ndpDispatcher)(nil) // ndpDispatcher implements NDPDispatcher so tests can know when various NDP // related events happen for test purposes. type ndpDispatcher struct { - dadC chan ndpDADEvent + dadC chan ndpDADEvent + routerC chan ndpRouterEvent + rememberRouter bool + routeTable []tcpip.Route } // Implements stack.NDPDispatcher.OnDuplicateAddressDetectionStatus. -// -// If the DAD event matches what we are expecting, send signal on n.dadC. func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicid tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) { - n.dadC <- ndpDADEvent{ - nicid, - addr, - resolved, - err, + if n.dadC != nil { + n.dadC <- ndpDADEvent{ + nicid, + addr, + resolved, + err, + } } } +// Implements stack.NDPDispatcher.OnDefaultRouterDiscovered. +func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicid tcpip.NICID, addr tcpip.Address) (bool, []tcpip.Route) { + if n.routerC != nil { + n.routerC <- ndpRouterEvent{ + nicid, + addr, + true, + } + } + + if !n.rememberRouter { + return false, nil + } + + rt := append([]tcpip.Route(nil), n.routeTable...) + rt = append(rt, tcpip.Route{ + Destination: header.IPv6EmptySubnet, + Gateway: addr, + NIC: nicid, + }) + n.routeTable = rt + return true, rt +} + +// Implements stack.NDPDispatcher.OnDefaultRouterInvalidated. +func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicid tcpip.NICID, addr tcpip.Address) []tcpip.Route { + if n.routerC != nil { + n.routerC <- ndpRouterEvent{ + nicid, + addr, + false, + } + } + + var rt []tcpip.Route + exclude := tcpip.Route{ + Destination: header.IPv6EmptySubnet, + Gateway: addr, + NIC: nicid, + } + + for _, r := range n.routeTable { + if r != exclude { + rt = append(rt, r) + } + } + n.routeTable = rt + return rt +} + // TestDADResolve tests that an address successfully resolves after performing // DAD for various values of DupAddrDetectTransmits and RetransmitTimer. // Included in the subtests is a test to make sure that an invalid @@ -609,3 +681,333 @@ func TestSetNDPConfigurations(t *testing.T) { }) } } + +// raBuf returns a valid NDP Router Advertisement. +// +// Note, raBuf does not populate any of the RA fields other than the +// Router Lifetime. +func raBuf(ip tcpip.Address, rl uint16) tcpip.PacketBuffer { + icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) + pkt := header.ICMPv6(hdr.Prepend(icmpSize)) + pkt.SetType(header.ICMPv6RouterAdvert) + pkt.SetCode(0) + // Populate the Router Lifetime. + binary.BigEndian.PutUint16(pkt.NDPPayload()[2:], rl) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, ip, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + iph := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + iph.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(icmp.ProtocolNumber6), + HopLimit: header.NDPHopLimit, + SrcAddr: ip, + DstAddr: header.IPv6AllNodesMulticastAddress, + }) + + return tcpip.PacketBuffer{Data: hdr.View().ToVectorisedView()} +} + +// TestNoRouterDiscovery tests that router discovery will not be performed if +// configured not to. +func TestNoRouterDiscovery(t *testing.T) { + // Being configured to discover routers means handle and + // discover are set to true and forwarding is set to false. + // This tests all possible combinations of the configurations, + // except for the configuration where handle = true, discover = + // true and forwarding = false (the required configuration to do + // router discovery) - that will done in other tests. + for i := 0; i < 7; i++ { + handle := i&1 != 0 + discover := i&2 != 0 + forwarding := i&4 == 0 + + t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverDefaultRouters(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) { + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, 10), + } + e := channel.New(10, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: handle, + DiscoverDefaultRouters: discover, + }, + NDPDisp: &ndpDisp, + }) + s.SetForwarding(forwarding) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Rx an RA with non-zero lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) + select { + case <-ndpDisp.routerC: + t.Fatal("unexpectedly discovered a router when configured not to") + case <-time.After(defaultTimeout): + } + }) + } +} + +// TestRouterDiscoveryDispatcherNoRemember tests that the stack does not +// remember a discovered router when the dispatcher asks it not to. +func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, 10), + } + e := channel.New(10, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + routeTable := []tcpip.Route{ + { + header.IPv6EmptySubnet, + llAddr3, + 1, + }, + } + s.SetRouteTable(routeTable) + + // Rx an RA with short lifetime. + lifetime := time.Duration(1) + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, uint16(lifetime))) + select { + case r := <-ndpDisp.routerC: + if r.nicid != 1 { + t.Fatalf("got r.nicid = %d, want = 1", r.nicid) + } + if r.addr != llAddr2 { + t.Fatalf("got r.addr = %s, want = %s", r.addr, llAddr2) + } + if !r.discovered { + t.Fatal("got r.discovered = false, want = true") + } + case <-time.After(defaultTimeout): + t.Fatal("timeout waiting for router discovery event") + } + + // Original route table should not have been modified. + if got := s.GetRouteTable(); !cmp.Equal(got, routeTable) { + t.Fatalf("got GetRouteTable = %v, want = %v", got, routeTable) + } + + // Wait for the normal invalidation time plus an extra second to + // make sure we do not actually receive any invalidation events as + // we should not have remembered the router in the first place. + select { + case <-ndpDisp.routerC: + t.Fatal("should not have received any router events") + case <-time.After(lifetime*time.Second + defaultTimeout): + } + + // Original route table should not have been modified. + if got := s.GetRouteTable(); !cmp.Equal(got, routeTable) { + t.Fatalf("got GetRouteTable = %v, want = %v", got, routeTable) + } +} + +func TestRouterDiscovery(t *testing.T) { + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, 10), + rememberRouter: true, + } + e := channel.New(10, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: true, + }, + NDPDisp: &ndpDisp, + }) + + waitForEvent := func(addr tcpip.Address, discovered bool, timeout time.Duration) { + t.Helper() + + select { + case r := <-ndpDisp.routerC: + if r.nicid != 1 { + t.Fatalf("got r.nicid = %d, want = 1", r.nicid) + } + if r.addr != addr { + t.Fatalf("got r.addr = %s, want = %s", r.addr, addr) + } + if r.discovered != discovered { + t.Fatalf("got r.discovered = %t, want = %t", r.discovered, discovered) + } + case <-time.After(timeout): + t.Fatal("timeout waiting for router discovery event") + } + } + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Rx an RA from lladdr2 with zero lifetime. It should not be + // remembered. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) + select { + case <-ndpDisp.routerC: + t.Fatal("unexpectedly discovered a router with 0 lifetime") + case <-time.After(defaultTimeout): + } + + // Rx an RA from lladdr2 with a huge lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) + waitForEvent(llAddr2, true, defaultTimeout) + + // Should have a default route through the discovered router. + if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr2, 1}}; !cmp.Equal(got, want) { + t.Fatalf("got GetRouteTable = %v, want = %v", got, want) + } + + // Rx an RA from another router (lladdr3) with non-zero lifetime. + l3Lifetime := time.Duration(6) + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, uint16(l3Lifetime))) + waitForEvent(llAddr3, true, defaultTimeout) + + // Should have default routes through the discovered routers. + if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr2, 1}, {header.IPv6EmptySubnet, llAddr3, 1}}; !cmp.Equal(got, want) { + t.Fatalf("got GetRouteTable = %v, want = %v", got, want) + } + + // Rx an RA from lladdr2 with lesser lifetime. + l2Lifetime := time.Duration(2) + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, uint16(l2Lifetime))) + select { + case <-ndpDisp.routerC: + t.Fatal("Should not receive a router event when updating lifetimes for known routers") + case <-time.After(defaultTimeout): + } + + // Should still have a default route through the discovered routers. + if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr2, 1}, {header.IPv6EmptySubnet, llAddr3, 1}}; !cmp.Equal(got, want) { + t.Fatalf("got GetRouteTable = %v, want = %v", got, want) + } + + // Wait for lladdr2's router invalidation timer to fire. The lifetime + // of the router should have been updated to the most recent (smaller) + // lifetime. + // + // Wait for the normal lifetime plus an extra bit for the + // router to get invalidated. If we don't get an invalidation + // event after this time, then something is wrong. + waitForEvent(llAddr2, false, l2Lifetime*time.Second+defaultTimeout) + + // Should no longer have the default route through lladdr2. + if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr3, 1}}; !cmp.Equal(got, want) { + t.Fatalf("got GetRouteTable = %v, want = %v", got, want) + } + + // Rx an RA from lladdr2 with huge lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) + waitForEvent(llAddr2, true, defaultTimeout) + + // Should have a default route through the discovered routers. + if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr3, 1}, {header.IPv6EmptySubnet, llAddr2, 1}}; !cmp.Equal(got, want) { + t.Fatalf("got GetRouteTable = %v, want = %v", got, want) + } + + // Rx an RA from lladdr2 with zero lifetime. It should be invalidated. + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) + waitForEvent(llAddr2, false, defaultTimeout) + + // Should have deleted the default route through the router that just + // got invalidated. + if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr3, 1}}; !cmp.Equal(got, want) { + t.Fatalf("got GetRouteTable = %v, want = %v", got, want) + } + + // Wait for lladdr3's router invalidation timer to fire. The lifetime + // of the router should have been updated to the most recent (smaller) + // lifetime. + // + // Wait for the normal lifetime plus an extra bit for the + // router to get invalidated. If we don't get an invalidation + // event after this time, then something is wrong. + waitForEvent(llAddr3, false, l3Lifetime*time.Second+defaultTimeout) + + // Should not have any routes now that all discovered routers have been + // invalidated. + if got := len(s.GetRouteTable()); got != 0 { + t.Fatalf("got len(s.GetRouteTable()) = %d, want = 0", got) + } +} + +// TestRouterDiscoveryMaxRouters tests that only +// stack.MaxDiscoveredDefaultRouters discovered routers are remembered. +func TestRouterDiscoveryMaxRouters(t *testing.T) { + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, 10), + rememberRouter: true, + } + e := channel.New(10, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + expectedRt := [stack.MaxDiscoveredDefaultRouters]tcpip.Route{} + + // Receive an RA from 2 more than the max number of discovered routers. + for i := 1; i <= stack.MaxDiscoveredDefaultRouters+2; i++ { + linkAddr := []byte{2, 2, 3, 4, 5, 0} + linkAddr[5] = byte(i) + llAddr := header.LinkLocalAddr(tcpip.LinkAddress(linkAddr)) + + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr, 5)) + + if i <= stack.MaxDiscoveredDefaultRouters { + expectedRt[i-1] = tcpip.Route{header.IPv6EmptySubnet, llAddr, 1} + select { + case r := <-ndpDisp.routerC: + if r.nicid != 1 { + t.Fatalf("got r.nicid = %d, want = 1", r.nicid) + } + if r.addr != llAddr { + t.Fatalf("got r.addr = %s, want = %s", r.addr, llAddr) + } + if !r.discovered { + t.Fatal("got r.discovered = false, want = true") + } + case <-time.After(defaultTimeout): + t.Fatal("timeout waiting for router discovery event") + } + + } else { + select { + case <-ndpDisp.routerC: + t.Fatal("should not have discovered a new router after we already discovered the max number of routers") + case <-time.After(defaultTimeout): + } + } + } + + // Should only have default routes for the first + // stack.MaxDiscoveredDefaultRouters discovered routers. + if got := s.GetRouteTable(); !cmp.Equal(got, expectedRt[:]) { + t.Fatalf("got GetRouteTable = %v, want = %v", got, expectedRt) + } +} diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 03be7d3d4..3edb513d4 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -231,6 +231,13 @@ func (s *Subnet) Broadcast() Address { return Address(addr) } +// Equal returns true if s equals o. +// +// Needed to use cmp.Equal on Subnet as its fields are unexported. +func (s Subnet) Equal(o Subnet) bool { + return s == o +} + // NICID is a number that uniquely identifies a NIC. type NICID int32 -- cgit v1.2.3 From 0c424ea73198866066ddc5e7047a3a357d313f46 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Wed, 6 Nov 2019 19:39:57 -0800 Subject: Rename nicid to nicID to follow go-readability initialisms https://github.com/golang/go/wiki/CodeReviewComments#initialisms This change does not introduce any new functionality. It just renames variables from `nicid` to `nicID`. PiperOrigin-RevId: 278992966 --- pkg/tcpip/network/arp/arp.go | 12 ++--- pkg/tcpip/network/ipv4/ipv4.go | 8 +-- pkg/tcpip/network/ipv4/ipv4_test.go | 4 +- pkg/tcpip/network/ipv6/icmp.go | 8 +-- pkg/tcpip/network/ipv6/ipv6.go | 8 +-- pkg/tcpip/stack/ndp.go | 8 +-- pkg/tcpip/stack/ndp_test.go | 48 ++++++++--------- pkg/tcpip/stack/registration.go | 12 ++--- pkg/tcpip/stack/stack.go | 30 +++++------ pkg/tcpip/stack/stack_test.go | 86 +++++++++++++++---------------- pkg/tcpip/stack/transport_demuxer_test.go | 8 +-- pkg/tcpip/transport/icmp/endpoint.go | 26 +++++----- pkg/tcpip/transport/packet/endpoint.go | 6 +-- pkg/tcpip/transport/tcp/endpoint.go | 18 +++---- pkg/tcpip/transport/udp/endpoint.go | 50 +++++++++--------- 15 files changed, 166 insertions(+), 166 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 4161ebf87..0ee509ebe 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -42,7 +42,7 @@ const ( // endpoint implements stack.NetworkEndpoint. type endpoint struct { - nicid tcpip.NICID + nicID tcpip.NICID linkEP stack.LinkEndpoint linkAddrCache stack.LinkAddressCache } @@ -58,7 +58,7 @@ func (e *endpoint) MTU() uint32 { } func (e *endpoint) NICID() tcpip.NICID { - return e.nicid + return e.nicID } func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { @@ -102,7 +102,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { switch h.Op() { case header.ARPRequest: localAddr := tcpip.Address(h.ProtocolAddressTarget()) - if e.linkAddrCache.CheckLocalAddress(e.nicid, header.IPv4ProtocolNumber, localAddr) == 0 { + if e.linkAddrCache.CheckLocalAddress(e.nicID, header.IPv4ProtocolNumber, localAddr) == 0 { return // we have no useful answer, ignore the request } hdr := buffer.NewPrependable(int(e.linkEP.MaxHeaderLength()) + header.ARPSize) @@ -118,7 +118,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { case header.ARPReply: addr := tcpip.Address(h.ProtocolAddressSender()) linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) - e.linkAddrCache.AddLinkAddress(e.nicid, addr, linkAddr) + e.linkAddrCache.AddLinkAddress(e.nicID, addr, linkAddr) } } @@ -135,12 +135,12 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress } -func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { if addrWithPrefix.Address != ProtocolAddress { return nil, tcpip.ErrBadLocalAddress } return &endpoint{ - nicid: nicid, + nicID: nicID, linkEP: sender, linkAddrCache: linkAddrCache, }, nil diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 26f1402ed..ac16c8add 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -47,7 +47,7 @@ const ( ) type endpoint struct { - nicid tcpip.NICID + nicID tcpip.NICID id stack.NetworkEndpointID prefixLen int linkEP stack.LinkEndpoint @@ -57,9 +57,9 @@ type endpoint struct { } // NewEndpoint creates a new ipv4 endpoint. -func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { e := &endpoint{ - nicid: nicid, + nicID: nicID, id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, prefixLen: addrWithPrefix.PrefixLen, linkEP: linkEP, @@ -89,7 +89,7 @@ func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { // NICID returns the ID of the NIC this endpoint belongs to. func (e *endpoint) NICID() tcpip.NICID { - return e.nicid + return e.nicID } // ID returns the ipv4 endpoint ID. diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index f100d84ee..01dfb5f20 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -451,7 +451,7 @@ func TestInvalidFragments(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - const nicid tcpip.NICID = 42 + const nicID tcpip.NICID = 42 s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ ipv4.NewProtocol(), @@ -461,7 +461,7 @@ func TestInvalidFragments(t *testing.T) { var linkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x30}) var remoteLinkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x31}) ep := channel.New(10, 1500, linkAddr) - s.CreateNIC(nicid, sniffer.New(ep)) + s.CreateNIC(nicID, sniffer.New(ep)) for _, pkt := range tc.packets { ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, tcpip.PacketBuffer{ diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 58f8e80df..6629951c6 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -180,7 +180,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P // rxNICID so the packet is processed as defined in RFC 4861, // as per RFC 4862 section 5.4.3. - if e.linkAddrCache.CheckLocalAddress(e.nicid, ProtocolNumber, targetAddr) == 0 { + if e.linkAddrCache.CheckLocalAddress(e.nicID, ProtocolNumber, targetAddr) == 0 { // We don't have a useful answer; the best we can do is ignore the request. return } @@ -218,7 +218,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P // // Furthermore, the entirety of NDP handling here seems to be // contradicted by RFC 4861. - e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress) + e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, r.RemoteLinkAddress) // RFC 4861 Neighbor Discovery for IP version 6 (IPv6) // @@ -274,9 +274,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P // inform the netstack integration that a duplicate address was // detected outside of DAD. - e.linkAddrCache.AddLinkAddress(e.nicid, targetAddr, r.RemoteLinkAddress) + e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, r.RemoteLinkAddress) if targetAddr != r.RemoteAddress { - e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress) + e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, r.RemoteLinkAddress) } case header.ICMPv6EchoRequest: diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 805d1739c..4cee848a1 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -43,7 +43,7 @@ const ( ) type endpoint struct { - nicid tcpip.NICID + nicID tcpip.NICID id stack.NetworkEndpointID prefixLen int linkEP stack.LinkEndpoint @@ -65,7 +65,7 @@ func (e *endpoint) MTU() uint32 { // NICID returns the ID of the NIC this endpoint belongs to. func (e *endpoint) NICID() tcpip.NICID { - return e.nicid + return e.nicID } // ID returns the ipv6 endpoint ID. @@ -218,9 +218,9 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { } // NewEndpoint creates a new ipv6 endpoint. -func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { return &endpoint{ - nicid: nicid, + nicID: nicID, id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, prefixLen: addrWithPrefix.PrefixLen, linkEP: linkEP, diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index a216242d8..8e49f7a56 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -78,7 +78,7 @@ const ( // receive and handle NDP related events. type NDPDispatcher interface { // OnDuplicateAddressDetectionStatus will be called when the DAD process - // for an address (addr) on a NIC (with ID nicid) completes. resolved + // for an address (addr) on a NIC (with ID nicID) completes. resolved // will be set to true if DAD completed successfully (no duplicate addr // detected); false otherwise (addr was detected to be a duplicate on // the link the NIC is a part of, or it was stopped for some other @@ -87,7 +87,7 @@ type NDPDispatcher interface { // // This function is permitted to block indefinitely without interfering // with the stack's operation. - OnDuplicateAddressDetectionStatus(nicid tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) + OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) // OnDefaultRouterDiscovered will be called when a new default router is // discovered. Implementations must return true along with a new valid @@ -97,7 +97,7 @@ type NDPDispatcher interface { // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. - OnDefaultRouterDiscovered(nicid tcpip.NICID, addr tcpip.Address) (bool, []tcpip.Route) + OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) (bool, []tcpip.Route) // OnDefaultRouterInvalidated will be called when a discovered default // router is invalidated. Implementers must return a new valid route @@ -105,7 +105,7 @@ type NDPDispatcher interface { // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. - OnDefaultRouterInvalidated(nicid tcpip.NICID, addr tcpip.Address) []tcpip.Route + OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) []tcpip.Route } // NDPConfigurations is the NDP configurations for the netstack. diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 0dbe4da9d..50ce1bbfa 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -83,14 +83,14 @@ func TestDADDisabled(t *testing.T) { // ndpDADEvent is a set of parameters that was passed to // ndpDispatcher.OnDuplicateAddressDetectionStatus. type ndpDADEvent struct { - nicid tcpip.NICID + nicID tcpip.NICID addr tcpip.Address resolved bool err *tcpip.Error } type ndpRouterEvent struct { - nicid tcpip.NICID + nicID tcpip.NICID addr tcpip.Address // true if router was discovered, false if invalidated. discovered bool @@ -108,10 +108,10 @@ type ndpDispatcher struct { } // Implements stack.NDPDispatcher.OnDuplicateAddressDetectionStatus. -func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicid tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) { +func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) { if n.dadC != nil { n.dadC <- ndpDADEvent{ - nicid, + nicID, addr, resolved, err, @@ -120,10 +120,10 @@ func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicid tcpip.NICID, add } // Implements stack.NDPDispatcher.OnDefaultRouterDiscovered. -func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicid tcpip.NICID, addr tcpip.Address) (bool, []tcpip.Route) { +func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) (bool, []tcpip.Route) { if n.routerC != nil { n.routerC <- ndpRouterEvent{ - nicid, + nicID, addr, true, } @@ -137,17 +137,17 @@ func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicid tcpip.NICID, addr tcpip. rt = append(rt, tcpip.Route{ Destination: header.IPv6EmptySubnet, Gateway: addr, - NIC: nicid, + NIC: nicID, }) n.routeTable = rt return true, rt } // Implements stack.NDPDispatcher.OnDefaultRouterInvalidated. -func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicid tcpip.NICID, addr tcpip.Address) []tcpip.Route { +func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) []tcpip.Route { if n.routerC != nil { n.routerC <- ndpRouterEvent{ - nicid, + nicID, addr, false, } @@ -157,7 +157,7 @@ func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicid tcpip.NICID, addr tcpip exclude := tcpip.Route{ Destination: header.IPv6EmptySubnet, Gateway: addr, - NIC: nicid, + NIC: nicID, } for _, r := range n.routeTable { @@ -254,8 +254,8 @@ func TestDADResolve(t *testing.T) { if e.err != nil { t.Fatal("got DAD error: ", e.err) } - if e.nicid != 1 { - t.Fatalf("got DAD event w/ nicid = %d, want = 1", e.nicid) + if e.nicID != 1 { + t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID) } if e.addr != addr1 { t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1) @@ -421,8 +421,8 @@ func TestDADFail(t *testing.T) { if e.err != nil { t.Fatal("got DAD error: ", e.err) } - if e.nicid != 1 { - t.Fatalf("got DAD event w/ nicid = %d, want = 1", e.nicid) + if e.nicID != 1 { + t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID) } if e.addr != addr1 { t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1) @@ -492,8 +492,8 @@ func TestDADStop(t *testing.T) { if e.err != nil { t.Fatal("got DAD error: ", e.err) } - if e.nicid != 1 { - t.Fatalf("got DAD event w/ nicid = %d, want = 1", e.nicid) + if e.nicID != 1 { + t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID) } if e.addr != addr1 { t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1) @@ -661,8 +661,8 @@ func TestSetNDPConfigurations(t *testing.T) { if e.err != nil { t.Fatal("got DAD error: ", e.err) } - if e.nicid != 1 { - t.Fatalf("got DAD event w/ nicid = %d, want = 1", e.nicid) + if e.nicID != 1 { + t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID) } if e.addr != addr1 { t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1) @@ -786,8 +786,8 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, uint16(lifetime))) select { case r := <-ndpDisp.routerC: - if r.nicid != 1 { - t.Fatalf("got r.nicid = %d, want = 1", r.nicid) + if r.nicID != 1 { + t.Fatalf("got r.nicID = %d, want = 1", r.nicID) } if r.addr != llAddr2 { t.Fatalf("got r.addr = %s, want = %s", r.addr, llAddr2) @@ -839,8 +839,8 @@ func TestRouterDiscovery(t *testing.T) { select { case r := <-ndpDisp.routerC: - if r.nicid != 1 { - t.Fatalf("got r.nicid = %d, want = 1", r.nicid) + if r.nicID != 1 { + t.Fatalf("got r.nicID = %d, want = 1", r.nicID) } if r.addr != addr { t.Fatalf("got r.addr = %s, want = %s", r.addr, addr) @@ -983,8 +983,8 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { expectedRt[i-1] = tcpip.Route{header.IPv6EmptySubnet, llAddr, 1} select { case r := <-ndpDisp.routerC: - if r.nicid != 1 { - t.Fatalf("got r.nicid = %d, want = 1", r.nicid) + if r.nicID != 1 { + t.Fatalf("got r.nicID = %d, want = 1", r.nicID) } if r.addr != llAddr { t.Fatalf("got r.addr = %s, want = %s", r.addr, llAddr) diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 5806d294c..c0026f5a3 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -117,7 +117,7 @@ type PacketEndpoint interface { // should construct its own ethernet header for applications. // // HandlePacket takes ownership of pkt. - HandlePacket(nicid tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) + HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) } // TransportProtocol is the interface that needs to be implemented by transport @@ -281,7 +281,7 @@ type NetworkProtocol interface { ParseAddresses(v buffer.View) (src, dst tcpip.Address) // NewEndpoint creates a new endpoint of this protocol. - NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint) (NetworkEndpoint, *tcpip.Error) + NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint) (NetworkEndpoint, *tcpip.Error) // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the @@ -440,10 +440,10 @@ type LinkAddressResolver interface { type LinkAddressCache interface { // CheckLocalAddress determines if the given local address exists, and if it // does not exist. - CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID + CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID // AddLinkAddress adds a link address to the cache. - AddLinkAddress(nicid tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) + AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) // GetLinkAddress looks up the cache to translate address to link address (e.g. IP -> MAC). // If the LinkEndpoint requests address resolution and there is a LinkAddressResolver @@ -454,10 +454,10 @@ type LinkAddressCache interface { // If address resolution is required, ErrNoLinkAddress and a notification channel is // returned for the top level caller to block. Channel is closed once address resolution // is complete (success or not). - GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) + GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) // RemoveWaker removes a waker that has been added in GetLinkAddress(). - RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) + RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) } // RawFactory produces endpoints for writing various types of raw packets. diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 08599d765..99809df75 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1055,13 +1055,13 @@ func (s *Stack) CheckNetworkProtocol(protocol tcpip.NetworkProtocolNumber) bool // CheckLocalAddress determines if the given local address exists, and if it // does, returns the id of the NIC it's bound to. Returns 0 if the address // does not exist. -func (s *Stack) CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID { +func (s *Stack) CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID { s.mu.RLock() defer s.mu.RUnlock() // If a NIC is specified, we try to find the address there only. - if nicid != 0 { - nic := s.nics[nicid] + if nicID != 0 { + nic := s.nics[nicID] if nic == nil { return 0 } @@ -1120,35 +1120,35 @@ func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) *tcpip.Error { } // AddLinkAddress adds a link address to the stack link cache. -func (s *Stack) AddLinkAddress(nicid tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) { - fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr} +func (s *Stack) AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) { + fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr} s.linkAddrCache.add(fullAddr, linkAddr) // TODO: provide a way for a transport endpoint to receive a signal // that AddLinkAddress for a particular address has been called. } // GetLinkAddress implements LinkAddressCache.GetLinkAddress. -func (s *Stack) GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { +func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { s.mu.RLock() - nic := s.nics[nicid] + nic := s.nics[nicID] if nic == nil { s.mu.RUnlock() return "", nil, tcpip.ErrUnknownNICID } s.mu.RUnlock() - fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr} + fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr} linkRes := s.linkAddrResolvers[protocol] return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.linkEP, waker) } // RemoveWaker implements LinkAddressCache.RemoveWaker. -func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) { +func (s *Stack) RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) { s.mu.RLock() defer s.mu.RUnlock() - if nic := s.nics[nicid]; nic == nil { - fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr} + if nic := s.nics[nicID]; nic == nil { + fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr} s.linkAddrCache.removeWaker(fullAddr, waker) } } @@ -1344,9 +1344,9 @@ func (s *Stack) unregisterPacketEndpointLocked(nicID tcpip.NICID, netProto tcpip // WritePacket writes data directly to the specified NIC. It adds an ethernet // header based on the arguments. -func (s *Stack) WritePacket(nicid tcpip.NICID, dst tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error { +func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error { s.mu.Lock() - nic, ok := s.nics[nicid] + nic, ok := s.nics[nicID] s.mu.Unlock() if !ok { return tcpip.ErrUnknownDevice @@ -1372,9 +1372,9 @@ func (s *Stack) WritePacket(nicid tcpip.NICID, dst tcpip.LinkAddress, netProto t // WriteRawPacket writes data directly to the specified NIC without adding any // headers. -func (s *Stack) WriteRawPacket(nicid tcpip.NICID, payload buffer.VectorisedView) *tcpip.Error { +func (s *Stack) WriteRawPacket(nicID tcpip.NICID, payload buffer.VectorisedView) *tcpip.Error { s.mu.Lock() - nic, ok := s.nics[nicid] + nic, ok := s.nics[nicID] s.mu.Unlock() if !ok { return tcpip.ErrUnknownDevice diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 1fac5477f..bf1d6974c 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -58,7 +58,7 @@ const ( // use the first three: destination address, source address, and transport // protocol. They're all one byte fields to simplify parsing. type fakeNetworkEndpoint struct { - nicid tcpip.NICID + nicID tcpip.NICID id stack.NetworkEndpointID prefixLen int proto *fakeNetworkProtocol @@ -71,7 +71,7 @@ func (f *fakeNetworkEndpoint) MTU() uint32 { } func (f *fakeNetworkEndpoint) NICID() tcpip.NICID { - return f.nicid + return f.nicID } func (f *fakeNetworkEndpoint) PrefixLen() int { @@ -199,9 +199,9 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres return tcpip.Address(v[1:2]), tcpip.Address(v[0:1]) } -func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { return &fakeNetworkEndpoint{ - nicid: nicid, + nicID: nicID, id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, prefixLen: addrWithPrefix.PrefixLen, proto: f, @@ -682,11 +682,11 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { } } -func verifyAddress(t *testing.T, s *stack.Stack, nicid tcpip.NICID, addr tcpip.Address) { +func verifyAddress(t *testing.T, s *stack.Stack, nicID tcpip.NICID, addr tcpip.Address) { t.Helper() - info, ok := s.NICInfo()[nicid] + info, ok := s.NICInfo()[nicID] if !ok { - t.Fatalf("NICInfo() failed to find nicid=%d", nicid) + t.Fatalf("NICInfo() failed to find nicID=%d", nicID) } if len(addr) == 0 { // No address given, verify that there is no address assigned to the NIC. @@ -719,7 +719,7 @@ func TestEndpointExpiration(t *testing.T) { localAddrByte byte = 0x01 remoteAddr tcpip.Address = "\x03" noAddr tcpip.Address = "" - nicid tcpip.NICID = 1 + nicID tcpip.NICID = 1 ) localAddr := tcpip.Address([]byte{localAddrByte}) @@ -731,7 +731,7 @@ func TestEndpointExpiration(t *testing.T) { }) ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, ep); err != nil { + if err := s.CreateNIC(nicID, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -748,13 +748,13 @@ func TestEndpointExpiration(t *testing.T) { buf[0] = localAddrByte if promiscuous { - if err := s.SetPromiscuousMode(nicid, true); err != nil { + if err := s.SetPromiscuousMode(nicID, true); err != nil { t.Fatal("SetPromiscuousMode failed:", err) } } if spoofing { - if err := s.SetSpoofing(nicid, true); err != nil { + if err := s.SetSpoofing(nicID, true); err != nil { t.Fatal("SetSpoofing failed:", err) } } @@ -762,7 +762,7 @@ func TestEndpointExpiration(t *testing.T) { // 1. No Address yet, send should only work for spoofing, receive for // promiscuous mode. //----------------------- - verifyAddress(t, s, nicid, noAddr) + verifyAddress(t, s, nicID, noAddr) if promiscuous { testRecv(t, fakeNet, localAddrByte, ep, buf) } else { @@ -777,20 +777,20 @@ func TestEndpointExpiration(t *testing.T) { // 2. Add Address, everything should work. //----------------------- - if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil { + if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { t.Fatal("AddAddress failed:", err) } - verifyAddress(t, s, nicid, localAddr) + verifyAddress(t, s, nicID, localAddr) testRecv(t, fakeNet, localAddrByte, ep, buf) testSendTo(t, s, remoteAddr, ep, nil) // 3. Remove the address, send should only work for spoofing, receive // for promiscuous mode. //----------------------- - if err := s.RemoveAddress(nicid, localAddr); err != nil { + if err := s.RemoveAddress(nicID, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - verifyAddress(t, s, nicid, noAddr) + verifyAddress(t, s, nicID, noAddr) if promiscuous { testRecv(t, fakeNet, localAddrByte, ep, buf) } else { @@ -805,10 +805,10 @@ func TestEndpointExpiration(t *testing.T) { // 4. Add Address back, everything should work again. //----------------------- - if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil { + if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { t.Fatal("AddAddress failed:", err) } - verifyAddress(t, s, nicid, localAddr) + verifyAddress(t, s, nicID, localAddr) testRecv(t, fakeNet, localAddrByte, ep, buf) testSendTo(t, s, remoteAddr, ep, nil) @@ -826,10 +826,10 @@ func TestEndpointExpiration(t *testing.T) { // 6. Remove the address. Send should only work for spoofing, receive // for promiscuous mode. //----------------------- - if err := s.RemoveAddress(nicid, localAddr); err != nil { + if err := s.RemoveAddress(nicID, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - verifyAddress(t, s, nicid, noAddr) + verifyAddress(t, s, nicID, noAddr) if promiscuous { testRecv(t, fakeNet, localAddrByte, ep, buf) } else { @@ -845,10 +845,10 @@ func TestEndpointExpiration(t *testing.T) { // 7. Add Address back, everything should work again. //----------------------- - if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil { + if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { t.Fatal("AddAddress failed:", err) } - verifyAddress(t, s, nicid, localAddr) + verifyAddress(t, s, nicID, localAddr) testRecv(t, fakeNet, localAddrByte, ep, buf) testSendTo(t, s, remoteAddr, ep, nil) testSend(t, r, ep, nil) @@ -856,17 +856,17 @@ func TestEndpointExpiration(t *testing.T) { // 8. Remove the route, sendTo/recv should still work. //----------------------- r.Release() - verifyAddress(t, s, nicid, localAddr) + verifyAddress(t, s, nicID, localAddr) testRecv(t, fakeNet, localAddrByte, ep, buf) testSendTo(t, s, remoteAddr, ep, nil) // 9. Remove the address. Send should only work for spoofing, receive // for promiscuous mode. //----------------------- - if err := s.RemoveAddress(nicid, localAddr); err != nil { + if err := s.RemoveAddress(nicID, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - verifyAddress(t, s, nicid, noAddr) + verifyAddress(t, s, nicID, noAddr) if promiscuous { testRecv(t, fakeNet, localAddrByte, ep, buf) } else { @@ -1659,12 +1659,12 @@ func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.Proto } func TestAddAddress(t *testing.T) { - const nicid = 1 + const nicID = 1 s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, }) ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, ep); err != nil { + if err := s.CreateNIC(nicID, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1672,7 +1672,7 @@ func TestAddAddress(t *testing.T) { expectedAddresses := make([]tcpip.ProtocolAddress, 0, 2) for _, addrLen := range []int{4, 16} { address := addrGen.next(addrLen) - if err := s.AddAddress(nicid, fakeNetNumber, address); err != nil { + if err := s.AddAddress(nicID, fakeNetNumber, address); err != nil { t.Fatalf("AddAddress(address=%s) failed: %s", address, err) } expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ @@ -1681,17 +1681,17 @@ func TestAddAddress(t *testing.T) { }) } - gotAddresses := s.AllAddresses()[nicid] + gotAddresses := s.AllAddresses()[nicID] verifyAddresses(t, expectedAddresses, gotAddresses) } func TestAddProtocolAddress(t *testing.T) { - const nicid = 1 + const nicID = 1 s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, }) ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, ep); err != nil { + if err := s.CreateNIC(nicID, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1708,24 +1708,24 @@ func TestAddProtocolAddress(t *testing.T) { PrefixLen: prefixLen, }, } - if err := s.AddProtocolAddress(nicid, protocolAddress); err != nil { + if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil { t.Errorf("AddProtocolAddress(%+v) failed: %s", protocolAddress, err) } expectedAddresses = append(expectedAddresses, protocolAddress) } } - gotAddresses := s.AllAddresses()[nicid] + gotAddresses := s.AllAddresses()[nicID] verifyAddresses(t, expectedAddresses, gotAddresses) } func TestAddAddressWithOptions(t *testing.T) { - const nicid = 1 + const nicID = 1 s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, }) ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, ep); err != nil { + if err := s.CreateNIC(nicID, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1736,7 +1736,7 @@ func TestAddAddressWithOptions(t *testing.T) { for _, addrLen := range addrLenRange { for _, behavior := range behaviorRange { address := addrGen.next(addrLen) - if err := s.AddAddressWithOptions(nicid, fakeNetNumber, address, behavior); err != nil { + if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil { t.Fatalf("AddAddressWithOptions(address=%s, behavior=%d) failed: %s", address, behavior, err) } expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ @@ -1746,17 +1746,17 @@ func TestAddAddressWithOptions(t *testing.T) { } } - gotAddresses := s.AllAddresses()[nicid] + gotAddresses := s.AllAddresses()[nicID] verifyAddresses(t, expectedAddresses, gotAddresses) } func TestAddProtocolAddressWithOptions(t *testing.T) { - const nicid = 1 + const nicID = 1 s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, }) ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, ep); err != nil { + if err := s.CreateNIC(nicID, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1775,7 +1775,7 @@ func TestAddProtocolAddressWithOptions(t *testing.T) { PrefixLen: prefixLen, }, } - if err := s.AddProtocolAddressWithOptions(nicid, protocolAddress, behavior); err != nil { + if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil { t.Fatalf("AddProtocolAddressWithOptions(%+v, %d) failed: %s", protocolAddress, behavior, err) } expectedAddresses = append(expectedAddresses, protocolAddress) @@ -1783,7 +1783,7 @@ func TestAddProtocolAddressWithOptions(t *testing.T) { } } - gotAddresses := s.AllAddresses()[nicid] + gotAddresses := s.AllAddresses()[nicID] verifyAddresses(t, expectedAddresses, gotAddresses) } @@ -2030,8 +2030,8 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { if e.err != nil { t.Fatal("got DAD error: ", e.err) } - if e.nicid != 1 { - t.Fatalf("got DAD event w/ nicid = %d, want = 1", e.nicid) + if e.nicID != 1 { + t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID) } if e.addr != linkLocalAddr { t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, linkLocalAddr) diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index f54117c4e..3b28b06d0 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -79,17 +79,17 @@ func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string) linkEPs := make(map[string]*channel.Endpoint) for i, linkEpName := range linkEpNames { channelEP := channel.New(256, mtu, "") - nicid := tcpip.NICID(i + 1) - if err := s.CreateNamedNIC(nicid, linkEpName, channelEP); err != nil { + nicID := tcpip.NICID(i + 1) + if err := s.CreateNamedNIC(nicID, linkEpName, channelEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) } linkEPs[linkEpName] = channelEP - if err := s.AddAddress(nicid, ipv4.ProtocolNumber, stackAddr); err != nil { + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { t.Fatalf("AddAddress IPv4 failed: %v", err) } - if err := s.AddAddress(nicid, ipv6.ProtocolNumber, stackV6Addr); err != nil { + if err := s.AddAddress(nicID, ipv6.ProtocolNumber, stackV6Addr); err != nil { t.Fatalf("AddAddress IPv6 failed: %v", err) } } diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 0092d0ea9..70e008d36 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -278,13 +278,13 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } else { // Reject destination address if it goes through a different // NIC than the endpoint was bound to. - nicid := to.NIC + nicID := to.NIC if e.BindNICID != 0 { - if nicid != 0 && nicid != e.BindNICID { + if nicID != 0 && nicID != e.BindNICID { return 0, nil, tcpip.ErrNoRoute } - nicid = e.BindNICID + nicID = e.BindNICID } toCopy := *to @@ -295,7 +295,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } // Find the enpoint. - r, err := e.stack.FindRoute(nicid, e.BindAddr, to.Addr, netProto, false /* multicastLoop */) + r, err := e.stack.FindRoute(nicID, e.BindAddr, to.Addr, netProto, false /* multicastLoop */) if err != nil { return 0, nil, err } @@ -483,7 +483,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - nicid := addr.NIC + nicID := addr.NIC localPort := uint16(0) switch e.state { case stateBound, stateConnected: @@ -492,11 +492,11 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { break } - if nicid != 0 && nicid != e.BindNICID { + if nicID != 0 && nicID != e.BindNICID { return tcpip.ErrInvalidEndpointState } - nicid = e.BindNICID + nicID = e.BindNICID default: return tcpip.ErrInvalidEndpointState } @@ -507,7 +507,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicid, e.BindAddr, addr.Addr, netProto, false /* multicastLoop */) + r, err := e.stack.FindRoute(nicID, e.BindAddr, addr.Addr, netProto, false /* multicastLoop */) if err != nil { return err } @@ -524,14 +524,14 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // v6only is set to false and this is an ipv6 endpoint. netProtos := []tcpip.NetworkProtocolNumber{netProto} - id, err = e.registerWithStack(nicid, netProtos, id) + id, err = e.registerWithStack(nicID, netProtos, id) if err != nil { return err } e.ID = id e.route = r.Clone() - e.RegisterNICID = nicid + e.RegisterNICID = nicID e.state = stateConnected @@ -582,18 +582,18 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrNotSupported } -func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { +func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { if id.LocalPort != 0 { // The endpoint already has a local port, just attempt to // register it. - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindToDevice */) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindToDevice */) return id, err } // We need to find a port for the endpoint. _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { id.LocalPort = p - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindtodevice */) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindtodevice */) switch err { case nil: return true, nil diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 26335094e..0010b5e5f 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -266,7 +266,7 @@ func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } // HandlePacket implements stack.PacketEndpoint.HandlePacket. -func (ep *endpoint) HandlePacket(nicid tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { +func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { ep.rcvMu.Lock() // Drop the packet if our buffer is currently full. @@ -293,13 +293,13 @@ func (ep *endpoint) HandlePacket(nicid tcpip.NICID, localAddr tcpip.LinkAddress, // Get info directly from the ethernet header. hdr := header.Ethernet(pkt.LinkHeader) packet.senderAddr = tcpip.FullAddress{ - NIC: nicid, + NIC: nicID, Addr: tcpip.Address(hdr.SourceAddress()), } } else { // Guess the would-be ethernet header. packet.senderAddr = tcpip.FullAddress{ - NIC: nicid, + NIC: nicID, Addr: tcpip.Address(localAddr), } } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index d29f0f81b..79fec6b77 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1214,9 +1214,9 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.bindToDevice = 0 return nil } - for nicid, nic := range e.stack.NICInfo() { + for nicID, nic := range e.stack.NICInfo() { if nic.Name == string(v) { - e.bindToDevice = nicid + e.bindToDevice = nicID return nil } } @@ -1634,7 +1634,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc return tcpip.ErrAlreadyConnected } - nicid := addr.NIC + nicID := addr.NIC switch e.state { case StateBound: // If we're already bound to a NIC but the caller is requesting @@ -1643,11 +1643,11 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc break } - if nicid != 0 && nicid != e.boundNICID { + if nicID != 0 && nicID != e.boundNICID { return tcpip.ErrNoRoute } - nicid = e.boundNICID + nicID = e.boundNICID case StateInitial: // Nothing to do. We'll eventually fill-in the gaps in the ID (if any) @@ -1666,7 +1666,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicid, e.ID.LocalAddress, addr.Addr, netProto, false /* multicastLoop */) + r, err := e.stack.FindRoute(nicID, e.ID.LocalAddress, addr.Addr, netProto, false /* multicastLoop */) if err != nil { return err } @@ -1681,7 +1681,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc if e.ID.LocalPort != 0 { // The endpoint is bound to a port, attempt to register it. - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.ID, e, e.reusePort, e.bindToDevice) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, e.reusePort, e.bindToDevice) if err != nil { return err } @@ -1716,7 +1716,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc id := e.ID id.LocalPort = p - switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) { + switch e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) { case nil: e.ID = id return true, nil @@ -1741,7 +1741,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc e.isRegistered = true e.state = StateConnecting e.route = r.Clone() - e.boundNICID = nicid + e.boundNICID = nicID e.effectiveNetProtos = netProtos e.connectingAddress = connectingAddr diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 4e11de9db..5270f24df 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -282,7 +282,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // connectRoute establishes a route to the specified interface or the // configured multicast interface if no interface is specified and the // specified address is a multicast address. -func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) { +func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) { localAddr := e.ID.LocalAddress if isBroadcastOrMulticast(localAddr) { // A packet can only originate from a unicast address (i.e., an interface). @@ -290,20 +290,20 @@ func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netPr } if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) { - if nicid == 0 { - nicid = e.multicastNICID + if nicID == 0 { + nicID = e.multicastNICID } - if localAddr == "" && nicid == 0 { + if localAddr == "" && nicID == 0 { localAddr = e.multicastAddr } } // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicid, localAddr, addr.Addr, netProto, e.multicastLoop) + r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.multicastLoop) if err != nil { return stack.Route{}, 0, err } - return r, nicid, nil + return r, nicID, nil } // Write writes data to the endpoint's peer. This method does not block @@ -382,13 +382,13 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } else { // Reject destination address if it goes through a different // NIC than the endpoint was bound to. - nicid := to.NIC + nicID := to.NIC if e.BindNICID != 0 { - if nicid != 0 && nicid != e.BindNICID { + if nicID != 0 && nicID != e.BindNICID { return 0, nil, tcpip.ErrNoRoute } - nicid = e.BindNICID + nicID = e.BindNICID } if to.Addr == header.IPv4Broadcast && !e.broadcast { @@ -400,7 +400,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return 0, nil, err } - r, _, err := e.connectRoute(nicid, *to, netProto) + r, _, err := e.connectRoute(nicID, *to, netProto) if err != nil { return 0, nil, err } @@ -622,9 +622,9 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.bindToDevice = 0 return nil } - for nicid, nic := range e.stack.NICInfo() { + for nicID, nic := range e.stack.NICInfo() { if nic.Name == string(v) { - e.bindToDevice = nicid + e.bindToDevice = nicID return nil } } @@ -907,7 +907,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - nicid := addr.NIC + nicID := addr.NIC var localPort uint16 switch e.state { case StateInitial: @@ -917,16 +917,16 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { break } - if nicid != 0 && nicid != e.BindNICID { + if nicID != 0 && nicID != e.BindNICID { return tcpip.ErrInvalidEndpointState } - nicid = e.BindNICID + nicID = e.BindNICID default: return tcpip.ErrInvalidEndpointState } - r, nicid, err := e.connectRoute(nicid, addr, netProto) + r, nicID, err := e.connectRoute(nicID, addr, netProto) if err != nil { return err } @@ -954,7 +954,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } } - id, err = e.registerWithStack(nicid, netProtos, id) + id, err = e.registerWithStack(nicID, netProtos, id) if err != nil { return err } @@ -967,7 +967,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { e.ID = id e.route = r.Clone() e.dstPort = addr.Port - e.RegisterNICID = nicid + e.RegisterNICID = nicID e.effectiveNetProtos = netProtos e.state = StateConnected @@ -1022,7 +1022,7 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrNotSupported } -func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { +func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { if e.ID.LocalPort == 0 { port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort, e.bindToDevice) if err != nil { @@ -1031,7 +1031,7 @@ func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.Networ id.LocalPort = port } - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) if err != nil { e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.bindToDevice) } @@ -1061,11 +1061,11 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { } } - nicid := addr.NIC + nicID := addr.NIC if len(addr.Addr) != 0 && !isBroadcastOrMulticast(addr.Addr) { // A local unicast address was specified, verify that it's valid. - nicid = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) - if nicid == 0 { + nicID = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) + if nicID == 0 { return tcpip.ErrBadLocalAddress } } @@ -1074,13 +1074,13 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { LocalPort: addr.Port, LocalAddress: addr.Addr, } - id, err = e.registerWithStack(nicid, netProtos, id) + id, err = e.registerWithStack(nicID, netProtos, id) if err != nil { return err } e.ID = id - e.RegisterNICID = nicid + e.RegisterNICID = nicID e.effectiveNetProtos = netProtos // Mark endpoint as bound. -- cgit v1.2.3 From 66ebb6575f929a389d3c929977ed5e31d706fcfe Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Thu, 7 Nov 2019 09:45:26 -0800 Subject: Add support for TIME_WAIT timeout. This change adds explicit support for honoring the 2MSL timeout for sockets in TIME_WAIT state. It also adds support for the TCP_LINGER2 option that allows modification of the FIN_WAIT2 state timeout duration for a given socket. It also adds an option to modify the Stack wide TIME_WAIT timeout but this is only for testing. On Linux this is fixed at 60s. Further, we also now correctly process RST's in CLOSE_WAIT and close the socket similar to linux without moving it to error state. We also now handle SYN in ESTABLISHED state as per RFC5961#section-4.1. Earlier we would just drop these SYNs. Which can result in some tests that pass on linux to fail on gVisor. Netstack now honors TIME_WAIT correctly as well as handles the following cases correctly. - TCP RSTs in TIME_WAIT are ignored. - A duplicate TCP FIN during TIME_WAIT extends the TIME_WAIT and a dup ACK is sent in response to the FIN as the dup FIN indicates potential loss of the original final ACK. - An out of order segment during TIME_WAIT generates a dup ACK. - A new SYN w/ a sequence number > the highest sequence number in the previous connection closes the TIME_WAIT early and opens a new connection. Further to make the SYN case work correctly the ISN (Initial Sequence Number) generation for Netstack has been updated to be as per RFC. Its not a pure random number anymore and follows the recommendation in https://tools.ietf.org/html/rfc6528#page-3. The current hash used is not a cryptographically secure hash function. A separate change will update the hash function used to Siphash similar to what is used in Linux. PiperOrigin-RevId: 279106406 --- pkg/sentry/socket/netstack/netstack.go | 20 + pkg/tcpip/adapters/gonet/gonet_test.go | 12 +- pkg/tcpip/stack/stack.go | 20 +- pkg/tcpip/stack/transport_demuxer.go | 33 +- pkg/tcpip/tcpip.go | 12 +- pkg/tcpip/transport/tcp/BUILD | 2 +- pkg/tcpip/transport/tcp/accept.go | 17 +- pkg/tcpip/transport/tcp/connect.go | 322 ++++++++++++-- pkg/tcpip/transport/tcp/endpoint.go | 101 ++++- pkg/tcpip/transport/tcp/endpoint_state.go | 26 +- pkg/tcpip/transport/tcp/protocol.go | 43 ++ pkg/tcpip/transport/tcp/rcv.go | 167 ++++++- pkg/tcpip/transport/tcp/tcp_test.go | 622 ++++++++++++++++++++++++++- test/syscalls/BUILD | 22 +- test/syscalls/linux/BUILD | 1 + test/syscalls/linux/socket_inet_loopback.cc | 336 +++++++++++++++ test/syscalls/linux/socket_ip_tcp_generic.cc | 93 ++++ 17 files changed, 1736 insertions(+), 113 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 27c6692c4..d92399efd 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1173,6 +1173,18 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa copy(b, v) return b, nil + case linux.TCP_LINGER2: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.TCPLingerTimeoutOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + return int32(time.Duration(v) / time.Second), nil + default: emitUnimplementedEventTCP(t, name) } @@ -1556,6 +1568,14 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } return nil + case linux.TCP_LINGER2: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + + v := usermem.ByteOrder.Uint32(optVal) + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPLingerTimeoutOption(time.Second * time.Duration(v)))) + case linux.TCP_REPAIR_OPTIONS: t.Kernel().EmitUnimplementedEvent(t) diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index 8ced960bb..ee077ae83 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -151,10 +151,8 @@ func TestCloseReader(t *testing.T) { buf := make([]byte, 256) n, err := c.Read(buf) - got, ok := err.(*net.OpError) - want := tcpip.ErrConnectionAborted - if n != 0 || !ok || got.Err.Error() != want.String() { - t.Errorf("c.Read() = (%d, %v), want (0, OpError(%v))", n, err, want) + if n != 0 || err != io.EOF { + t.Errorf("c.Read() = (%d, %v), want (0, EOF)", n, err) } }() sender, err := connect(s, addr) @@ -203,10 +201,8 @@ func TestCloseReaderWithForwarder(t *testing.T) { buf := make([]byte, 256) n, e := c.Read(buf) - got, ok := e.(*net.OpError) - want := tcpip.ErrConnectionAborted - if n != 0 || !ok || got.Err.Error() != want.String() { - t.Errorf("c.Read() = (%d, %v), want (0, OpError(%v))", n, e, want) + if n != 0 || e != io.EOF { + t.Errorf("c.Read() = (%d, %v), want (0, EOF)", n, e) } }) s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 99809df75..2f8d8e822 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -402,11 +402,11 @@ type Stack struct { // by the stack. icmpRateLimiter *ICMPRateLimiter - // portSeed is a one-time random value initialized at stack startup + // seed is a one-time random value initialized at stack startup // and is used to seed the TCP port picking on active connections // // TODO(gvisor.dev/issue/940): S/R this field. - portSeed uint32 + seed uint32 // ndpConfigs is the default NDP configurations used by interfaces. ndpConfigs NDPConfigurations @@ -544,7 +544,7 @@ func New(opts Options) *Stack { stats: opts.Stats.FillIn(), handleLocal: opts.HandleLocal, icmpRateLimiter: NewICMPRateLimiter(), - portSeed: generateRandUint32(), + seed: generateRandUint32(), ndpConfigs: opts.NDPConfigs, autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal, uniqueIDGenerator: opts.UniqueID, @@ -1186,6 +1186,12 @@ func (s *Stack) CompleteTransportEndpointCleanup(ep TransportEndpoint) { s.mu.Unlock() } +// FindTransportEndpoint finds an endpoint that most closely matches the provided +// id. If no endpoint is found it returns nil. +func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint { + return s.demux.findTransportEndpoint(netProto, transProto, id, r) +} + // RegisterRawTransportEndpoint registers the given endpoint with the stack // transport dispatcher. Received packets that match the provided transport // protocol will be delivered to the given endpoint. @@ -1573,12 +1579,12 @@ func (s *Stack) HandleNDPRA(id tcpip.NICID, ip tcpip.Address, ra header.NDPRoute return nil } -// PortSeed returns a 32 bit value that can be used as a seed value for port -// picking. +// Seed returns a 32 bit value that can be used as a seed value for port +// picking, ISN generation etc. // // NOTE: The seed is generated once during stack initialization only. -func (s *Stack) PortSeed() uint32 { - return s.portSeed +func (s *Stack) Seed() uint32 { + return s.seed } func generateRandUint32() uint32 { diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 594570216..cb805522b 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -103,7 +103,6 @@ func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, p epsByNic.mu.RUnlock() // Don't use defer for performance reasons. return } - // multiPortEndpoints are guaranteed to have at least one element. selectEndpoint(id, mpep, epsByNic.seed).HandlePacket(r, id, pkt) epsByNic.mu.RUnlock() // Don't use defer for performance reasons. @@ -507,10 +506,40 @@ func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id Tr if ep, ok := eps.endpoints[nid]; ok { matchedEPs = append(matchedEPs, ep) } - return matchedEPs } +// findTransportEndpoint find a single endpoint that most closely matches the provided id. +func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint { + eps, ok := d.protocol[protocolIDs{netProto, transProto}] + if !ok { + return nil + } + // Try to find the endpoint. + eps.mu.RLock() + epsByNic := d.findEndpointLocked(eps, id) + // Fail if we didn't find one. + if epsByNic == nil { + eps.mu.RUnlock() + return nil + } + + epsByNic.mu.RLock() + eps.mu.RUnlock() + + mpep, ok := epsByNic.endpoints[r.ref.nic.ID()] + if !ok { + if mpep, ok = epsByNic.endpoints[0]; !ok { + epsByNic.mu.RUnlock() // Don't use defer for performance reasons. + return nil + } + } + + ep := selectEndpoint(id, mpep, epsByNic.seed) + epsByNic.mu.RUnlock() + return ep +} + // findEndpointLocked returns the endpoint that most closely matches the given // id. func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, id TransportEndpointID) *endpointsByNic { diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 3edb513d4..bd5eb89ca 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -586,6 +586,16 @@ type MaxSegOption int // A zero value indicates the default. type TTLOption uint8 +// TCPLingerTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the +// maximum duration for which a socket lingers in the TCP_FIN_WAIT_2 state +// before being marked closed. +type TCPLingerTimeoutOption time.Duration + +// TCPTimeWaitTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the +// maximum duration for which a socket lingers in the TIME_WAIT state +// before being marked closed. +type TCPTimeWaitTimeoutOption time.Duration + // MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default // TTL value for multicast messages. The default is 1. type MulticastTTLOption uint8 @@ -1329,8 +1339,8 @@ var ( // GetDanglingEndpoints returns all dangling endpoints. func GetDanglingEndpoints() []Endpoint { - es := make([]Endpoint, 0, len(danglingEndpoints)) danglingEndpointsMu.Lock() + es := make([]Endpoint, 0, len(danglingEndpoints)) for e := range danglingEndpoints { es = append(es, e) } diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index f1dbc6f91..3f47b328d 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -71,7 +71,7 @@ filegroup( go_test( name = "tcp_test", - size = "small", + size = "medium", srcs = [ "dual_stack_test.go", "sack_scoreboard_test.go", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index cb0e13ebc..0e8e0a2b4 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -269,8 +269,8 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions) (*endpoint, *tcpip.Error) { // Create new endpoint. irs := s.sequenceNumber - cookie := l.createCookie(s.id, irs, encodeMSS(opts.MSS)) - ep, err := l.createConnectingEndpoint(s, cookie, irs, opts) + isn := generateSecureISN(s.id, l.stack.Seed()) + ep, err := l.createConnectingEndpoint(s, isn, irs, opts) if err != nil { return nil, err } @@ -289,7 +289,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head // Perform the 3-way handshake. h := newHandshake(ep, seqnum.Size(ep.initialReceiveWindow())) - h.resetToSynRcvd(cookie, irs, opts) + h.resetToSynRcvd(isn, irs, opts) if err := h.execute(); err != nil { ep.Close() if l.listenEP != nil { @@ -361,6 +361,7 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header defer decSynRcvdCount() defer e.decSynRcvdCount() defer s.decRef() + n, err := ctx.createEndpointAndPerformHandshake(s, opts) if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() @@ -368,6 +369,11 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header return } ctx.removePendingEndpoint(n) + // Start the protocol goroutine. + wq := &waiter.Queue{} + n.startAcceptedLoop(wq) + e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() + e.deliverAccepted(n) } @@ -543,6 +549,11 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // number of goroutines as we do check before // entering here that there was at least some // space available in the backlog. + + // Start the protocol goroutine. + wq := &waiter.Queue{} + n.startAcceptedLoop(wq) + e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() go e.deliverAccepted(n) } } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index ca982c451..a114c06c1 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -15,6 +15,7 @@ package tcp import ( + "encoding/binary" "sync" "time" @@ -22,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -139,7 +141,32 @@ func (h *handshake) resetState() { h.flags = header.TCPFlagSyn h.ackNum = 0 h.mss = 0 - h.iss = seqnum.Value(uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24) + h.iss = generateSecureISN(h.ep.ID, h.ep.stack.Seed()) +} + +// generateSecureISN generates a secure Initial Sequence number based on the +// recommendation here https://tools.ietf.org/html/rfc6528#page-3. +func generateSecureISN(id stack.TransportEndpointID, seed uint32) seqnum.Value { + isnHasher := jenkins.Sum32(seed) + isnHasher.Write([]byte(id.LocalAddress)) + isnHasher.Write([]byte(id.RemoteAddress)) + portBuf := make([]byte, 2) + binary.LittleEndian.PutUint16(portBuf, id.LocalPort) + isnHasher.Write(portBuf) + binary.LittleEndian.PutUint16(portBuf, id.RemotePort) + isnHasher.Write(portBuf) + // The time period here is 64ns. This is similar to what linux uses + // generate a sequence number that overlaps less than one + // time per MSL (2 minutes). + // + // A 64ns clock ticks 10^9/64 = 15625000) times in a second. + // To wrap the whole 32 bit space would require + // 2^32/1562500 ~ 274 seconds. + // + // Which sort of guarantees that we won't reuse the ISN for a new + // connection for the same tuple for at least 274s. + isn := isnHasher.Sum32() + uint32(time.Now().UnixNano()>>6) + return seqnum.Value(isn) } // effectiveRcvWndScale returns the effective receive window scale to be used. @@ -809,7 +836,19 @@ func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { e.state = StateError e.HardError = err if err != tcpip.ErrConnectionReset { - e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, e.snd.sndUna, e.rcv.rcvNxt, 0) + // The exact sequence number to be used for the RST is the same as the + // one used by Linux. We need to handle the case of window being shrunk + // which can cause sndNxt to be outside the acceptable window on the + // receiver. + // + // See: https://www.snellman.net/blog/archive/2016-02-01-tcp-rst/ for more + // information. + sndWndEnd := e.snd.sndUna.Add(e.snd.sndWnd) + resetSeqNum := sndWndEnd + if !sndWndEnd.LessThan(e.snd.sndNxt) || e.snd.sndNxt.Size(sndWndEnd) < (1< + // + // After sending the acknowledgment, TCP MUST drop the unacceptable + // segment and stop processing further. + // + // By sending an ACK, the remote peer is challenged to confirm the loss + // of the previous connection and the request to start a new connection. + // A legitimate peer, after restart, would not have a TCB in the + // synchronized state. Thus, when the ACK arrives, the peer should send + // a RST segment back with the sequence number derived from the ACK + // field that caused the RST. + + // This RST will confirm that the remote peer has indeed closed the + // previous connection. Upon receipt of a valid RST, the local TCP + // endpoint MUST terminate its connection. The local TCP endpoint + // should then rely on SYN retransmission from the remote end to + // re-establish the connection. + + e.snd.sendAck() } else if s.flagIsSet(header.TCPFlagAck) { // Patch the window size in the segment according to the // send window scale. @@ -856,7 +960,15 @@ func (e *endpoint) handleSegments() *tcpip.Error { // RFC 793, page 41 states that "once in the ESTABLISHED // state all segments must carry current acknowledgment // information." - e.rcv.handleRcvdSegment(s) + drop, err := e.rcv.handleRcvdSegment(s) + if err != nil { + s.decRef() + return err + } + if drop { + s.decRef() + continue + } e.snd.handleRcvdSegment(s) } s.decRef() @@ -955,7 +1067,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { } e.mu.Unlock() - // When the protocol loop exits we should wake up our waiters. e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) } @@ -1001,6 +1112,10 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { // RTT itself. e.rcvAutoParams.prevCopied = initialRcvWnd e.rcvListMu.Unlock() + e.stack.Stats().TCP.CurrentEstablished.Increment() + e.mu.Lock() + e.state = StateEstablished + e.mu.Unlock() } e.keepalive.timer.init(&e.keepalive.waker) @@ -1008,10 +1123,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { // Tell waiters that the endpoint is connected and writable. e.mu.Lock() - if e.state != StateEstablished { - e.stack.Stats().TCP.CurrentEstablished.Increment() - e.state = StateEstablished - } drained := e.drainDone != nil e.mu.Unlock() if drained { @@ -1042,7 +1153,13 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { { w: &closeWaker, f: func() *tcpip.Error { - return tcpip.ErrConnectionAborted + // This means the socket is being closed due + // to the TCP_FIN_WAIT2 timeout was hit. Just + // mark the socket as closed. + e.mu.Lock() + e.state = StateClose + e.mu.Unlock() + return nil }, }, { @@ -1085,17 +1202,18 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { e.resetConnectionLocked(tcpip.ErrConnectionAborted) e.mu.Unlock() } + if n¬ifyClose != 0 && closeTimer == nil { - // Reset the connection 3 seconds after - // the endpoint has been closed. - // - // The timer could fire in background - // when the endpoint is drained. That's - // OK as the loop here will not honor - // the firing until the undrain arrives. - closeTimer = time.AfterFunc(3*time.Second, func() { - closeWaker.Assert() - }) + e.mu.Lock() + if e.state == StateFinWait2 && e.closed { + // The socket has been closed and we are in FIN_WAIT2 + // so start the FIN_WAIT2 timer. + closeTimer = time.AfterFunc(e.tcpLingerTimeout, func() { + closeWaker.Assert() + }) + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) + } + e.mu.Unlock() } if n¬ifyKeepaliveChanged != 0 { @@ -1117,6 +1235,12 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { } } + if n¬ifyTickleWorker != 0 { + // Just a tickle notification. No need to do + // anything. + return nil + } + return nil }, }, @@ -1143,15 +1267,16 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { } e.rcvListMu.Unlock() - e.mu.RLock() + e.mu.Lock() if e.workerCleanup { e.notifyProtocolGoroutine(notifyClose) } - e.mu.RUnlock() // Main loop. Handle segments until both send and receive ends of the // connection have completed. - for !e.rcv.closed || !e.snd.closed || e.snd.sndUna != e.snd.sndNxtList { + + for e.state != StateTimeWait && e.state != StateClose && e.state != StateError { + e.mu.Unlock() e.workMu.Unlock() v, _ := s.Fetch(true) e.workMu.Lock() @@ -1167,6 +1292,23 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { return nil } + e.mu.Lock() + } + + state := e.state + e.mu.Unlock() + var reuseTW func() + if state == StateTimeWait { + // Disable close timer as we now entering real TIME_WAIT. + if closeTimer != nil { + closeTimer.Stop() + } + // Mark the current sleeper done so as to free all associated + // wakers. + s.Done() + // Wake up any waiters before we enter TIME_WAIT. + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) + reuseTW = e.doTimeWait() } // Mark endpoint as closed. @@ -1176,8 +1318,130 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { e.stack.Stats().TCP.CurrentEstablished.Decrement() e.state = StateClose } + // Lock released below. epilogue() + // A new SYN was received during TIME_WAIT and we need to abort + // the timewait and redirect the segment to the listener queue + if reuseTW != nil { + reuseTW() + } + return nil } + +// handleTimeWaitSegments processes segments received during TIME_WAIT +// state. +func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func()) { + checkRequeue := true + for i := 0; i < maxSegmentsPerWake; i++ { + s := e.segmentQueue.dequeue() + if s == nil { + checkRequeue = false + break + } + extTW, newSyn := e.rcv.handleTimeWaitSegment(s) + if newSyn { + info := e.EndpointInfo.TransportEndpointInfo + newID := info.ID + newID.RemoteAddress = "" + newID.RemotePort = 0 + netProtos := []tcpip.NetworkProtocolNumber{info.NetProto} + // If the local address is an IPv4 address then also + // look for IPv6 dual stack endpoints that might be + // listening on the local address. + if newID.LocalAddress.To4() != "" { + netProtos = []tcpip.NetworkProtocolNumber{header.IPv4ProtocolNumber, header.IPv6ProtocolNumber} + } + for _, netProto := range netProtos { + if listenEP := e.stack.FindTransportEndpoint(netProto, info.TransProto, newID, &s.route); listenEP != nil { + tcpEP := listenEP.(*endpoint) + if EndpointState(tcpEP.State()) == StateListen { + reuseTW = func() { + tcpEP.enqueueSegment(s) + } + // We explicitly do not decRef + // the segment as it's still + // valid and being reflected to + // a listening endpoint. + return false, reuseTW + } + } + } + } + if extTW { + extendTimeWait = true + } + s.decRef() + } + if checkRequeue && !e.segmentQueue.empty() { + e.newSegmentWaker.Assert() + } + return extendTimeWait, nil +} + +// doTimeWait is responsible for handling the TCP behaviour once a socket +// enters the TIME_WAIT state. Optionally it can return a closure that +// should be executed after releasing the endpoint registrations. This is +// done in cases where a new SYN is received during TIME_WAIT that carries +// a sequence number larger than one see on the connection. +func (e *endpoint) doTimeWait() (twReuse func()) { + // Trigger a 2 * MSL time wait state. During this period + // we will drop all incoming segments. + // NOTE: On Linux this is not configurable and is fixed at 60 seconds. + timeWaitDuration := DefaultTCPTimeWaitTimeout + + // Get the stack wide configuration. + var tcpTW tcpip.TCPTimeWaitTimeoutOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &tcpTW); err == nil { + timeWaitDuration = time.Duration(tcpTW) + } + + const newSegment = 1 + const notification = 2 + const timeWaitDone = 3 + + s := sleep.Sleeper{} + s.AddWaker(&e.newSegmentWaker, newSegment) + s.AddWaker(&e.notificationWaker, notification) + + var timeWaitWaker sleep.Waker + s.AddWaker(&timeWaitWaker, timeWaitDone) + timeWaitTimer := time.AfterFunc(timeWaitDuration, timeWaitWaker.Assert) + defer timeWaitTimer.Stop() + + for { + e.workMu.Unlock() + v, _ := s.Fetch(true) + e.workMu.Lock() + switch v { + case newSegment: + extendTimeWait, reuseTW := e.handleTimeWaitSegments() + if reuseTW != nil { + return reuseTW + } + if extendTimeWait { + timeWaitTimer.Reset(timeWaitDuration) + } + case notification: + n := e.fetchNotifications() + if n¬ifyClose != 0 { + return nil + } + if n¬ifyDrain != 0 { + for !e.segmentQueue.empty() { + // Ignore extending TIME_WAIT during a + // save. For sockets in TIME_WAIT we just + // terminate the TIME_WAIT early. + e.handleTimeWaitSegments() + } + close(e.drainDone) + <-e.undrain + return nil + } + case timeWaitDone: + return nil + } + } +} diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 79fec6b77..04c92c04c 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -121,6 +121,11 @@ const ( notifyReset notifyKeepaliveChanged notifyMSSChanged + // notifyTickleWorker is used to tickle the protocol main loop during a + // restore after we update the endpoint state to the correct one. This + // ensures the loop terminates if the final state of the endpoint is + // say TIME_WAIT. + notifyTickleWorker ) // SACKInfo holds TCP SACK related information for a given endpoint. @@ -320,6 +325,11 @@ type endpoint struct { state EndpointState `state:".(EndpointState)"` + // origEndpointState is only used during a restore phase to save the + // endpoint state at restore time as the socket is moved to it's correct + // state. + origEndpointState EndpointState `state:"nosave"` + isPortReserved bool `state:"manual"` isRegistered bool boundNICID tcpip.NICID `state:"manual"` @@ -503,6 +513,16 @@ type endpoint struct { // TODO(b/142022063): Add ability to save and restore per endpoint stats. stats Stats `state:"nosave"` + + // tcpLingerTimeout is the maximum amount of a time a socket + // a socket stays in TIME_WAIT state before being marked + // closed. + tcpLingerTimeout time.Duration + + // closed indicates that the user has called closed on the + // endpoint and at this point the endpoint is only around + // to complete the TCP shutdown. + closed bool } // UniqueID implements stack.TransportEndpoint.UniqueID. @@ -599,6 +619,11 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue e.SetSockOptInt(tcpip.DelayOption, 1) } + var tcpLT tcpip.TCPLingerTimeoutOption + if err := s.TransportProtocolOption(ProtocolNumber, &tcpLT); err == nil { + e.tcpLingerTimeout = time.Duration(tcpLT) + } + if p := s.GetTCPProbe(); p != nil { e.probe = p } @@ -686,6 +711,13 @@ func (e *endpoint) notifyProtocolGoroutine(n uint32) { // with it. It must be called only once and with no other concurrent calls to // the endpoint. func (e *endpoint) Close() { + e.mu.Lock() + closed := e.closed + e.mu.Unlock() + if closed { + return + } + // Issue a shutdown so that the peer knows we won't send any more data // if we're connected, or stop accepting if we're listening. e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead) @@ -706,6 +738,8 @@ func (e *endpoint) Close() { e.isPortReserved = false } + // Mark endpoint as closed. + e.closed = true // Either perform the local cleanup or kick the worker to make sure it // knows it needs to cleanup. tcpip.AddDanglingEndpoint(e) @@ -731,9 +765,7 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() { go func() { defer close(done) for n := range e.acceptedChan { - n.mu.Lock() - n.resetConnectionLocked(tcpip.ErrConnectionAborted) - n.mu.Unlock() + n.notifyProtocolGoroutine(notifyReset) n.Close() } }() @@ -1349,6 +1381,28 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.mu.Unlock() return nil + case tcpip.TCPLingerTimeoutOption: + e.mu.Lock() + if v < 0 { + // Same as effectively disabling TCPLinger timeout. + v = 0 + } + var stkTCPLingerTimeout tcpip.TCPLingerTimeoutOption + if err := e.stack.TransportProtocolOption(header.TCPProtocolNumber, &stkTCPLingerTimeout); err != nil { + // We were unable to retrieve a stack config, just use + // the DefaultTCPLingerTimeout. + if v > tcpip.TCPLingerTimeoutOption(DefaultTCPLingerTimeout) { + stkTCPLingerTimeout = tcpip.TCPLingerTimeoutOption(DefaultTCPLingerTimeout) + } + } + // Cap it to the stack wide TCPLinger timeout. + if v > stkTCPLingerTimeout { + v = stkTCPLingerTimeout + } + e.tcpLingerTimeout = time.Duration(v) + e.mu.Unlock() + return nil + default: return nil } @@ -1562,6 +1616,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.mu.RUnlock() return nil + case *tcpip.TCPLingerTimeoutOption: + e.mu.Lock() + *o = tcpip.TCPLingerTimeoutOption(e.tcpLingerTimeout) + e.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -1696,7 +1756,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc // src IP to ensure that for a given tuple (srcIP, destIP, // destPort) the offset used as a starting point is the same to // ensure that we can cycle through the port space effectively. - h := jenkins.Sum32(e.stack.PortSeed()) + h := jenkins.Sum32(e.stack.Seed()) h.Write([]byte(e.ID.LocalAddress)) h.Write([]byte(e.ID.RemoteAddress)) portBuf := make([]byte, 2) @@ -1782,9 +1842,8 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error { // peer. func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { e.mu.Lock() - defer e.mu.Unlock() e.shutdownFlags |= flags - + finQueued := false switch { case e.state.connected(): // Close for read. @@ -1799,6 +1858,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { // the connection with a RST. if (e.shutdownFlags&tcpip.ShutdownWrite) != 0 && rcvBufUsed > 0 { e.notifyProtocolGoroutine(notifyReset) + e.mu.Unlock() return nil } } @@ -1817,14 +1877,11 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { s := newSegmentFromView(&e.route, e.ID, nil) e.sndQueue.PushBack(s) e.sndBufInQueue++ - + finQueued = true // Mark endpoint as closed. e.sndClosed = true e.sndBufMu.Unlock() - - // Tell protocol goroutine to close. - e.sndCloseWaker.Assert() } case e.state == StateListen: @@ -1832,11 +1889,20 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { if flags&tcpip.ShutdownRead != 0 { e.notifyProtocolGoroutine(notifyClose) } - default: + e.mu.Unlock() return tcpip.ErrNotConnected } - + e.mu.Unlock() + if finQueued { + if e.workMu.TryLock() { + e.handleClose() + e.workMu.Unlock() + } else { + // Tell protocol goroutine to close. + e.sndCloseWaker.Assert() + } + } return nil } @@ -1928,12 +1994,7 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrWouldBlock } - // Start the protocol goroutine. - wq := &waiter.Queue{} - n.startAcceptedLoop(wq) - e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() - - return n, wq, nil + return n, n.waiterQueue, nil } // Bind binds the endpoint to a specific local port and optionally address. @@ -2058,6 +2119,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk e.stack.Stats().TCP.ResetsReceived.Increment() } + e.enqueueSegment(s) +} + +func (e *endpoint) enqueueSegment(s *segment) { // Send packet to worker goroutine. if e.segmentQueue.enqueue(s) { e.newSegmentWaker.Assert() diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 19f003b6b..7aa4c3f0e 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -78,7 +78,7 @@ func (e *endpoint) beforeSave() { } fallthrough case StateError, StateClose: - for e.state == StateError && e.workerRunning { + for (e.state == StateError || e.state == StateClose) && e.workerRunning { e.mu.Unlock() time.Sleep(100 * time.Millisecond) e.mu.Lock() @@ -165,6 +165,12 @@ func (e *endpoint) loadState(state EndpointState) { // afterLoad is invoked by stateify. func (e *endpoint) afterLoad() { + // Freeze segment queue before registering to prevent any segments + // from being delivered while it is being restored. + e.origEndpointState = e.state + // Restore the endpoint to InitialState as it will be moved to + // its origEndpointState during Resume. + e.state = StateInitial stack.StackFromEnv.RegisterRestoredEndpoint(e) } @@ -173,8 +179,8 @@ func (e *endpoint) Resume(s *stack.Stack) { e.stack = s e.segmentQueue.setLimit(MaxUnprocessedSegments) e.workMu.Init() + state := e.origEndpointState - state := e.state switch state { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: var ss SendBufferSizeOption @@ -189,7 +195,6 @@ func (e *endpoint) Resume(s *stack.Stack) { } bind := func() { - e.state = StateInitial if len(e.BindAddr) == 0 { e.BindAddr = e.ID.LocalAddress } @@ -219,6 +224,16 @@ func (e *endpoint) Resume(s *stack.Stack) { if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted { panic("endpoint connecting failed: " + err.String()) } + e.mu.Lock() + e.state = e.origEndpointState + closed := e.closed + e.mu.Unlock() + e.notifyProtocolGoroutine(notifyTickleWorker) + if state == StateFinWait2 && closed { + // If the endpoint has been closed then make sure we notify so + // that the FIN_WAIT2 timer is started after a restore. + e.notifyProtocolGoroutine(notifyClose) + } connectedLoading.Done() case StateListen: tcpip.AsyncLoading.Add(1) @@ -265,8 +280,11 @@ func (e *endpoint) Resume(s *stack.Stack) { tcpip.AsyncLoading.Done() }() } - fallthrough + e.state = StateClose + e.stack.CompleteTransportEndpointCleanup(e) + tcpip.DeleteDanglingEndpoint(e) case StateError: + e.state = StateError e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) } diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index c8e4a0d7e..89b965c23 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -23,6 +23,7 @@ package tcp import ( "strings" "sync" + "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -54,6 +55,14 @@ const ( // MaxUnprocessedSegments is the maximum number of unprocessed segments // that can be queued for a given endpoint. MaxUnprocessedSegments = 300 + + // DefaultTCPLingerTimeout is the amount of time that sockets linger in + // FIN_WAIT_2 state before being marked closed. + DefaultTCPLingerTimeout = 60 * time.Second + + // DefaultTCPTimeWaitTimeout is the amount of time that sockets linger + // in TIME_WAIT state before being marked closed. + DefaultTCPTimeWaitTimeout = 60 * time.Second ) // SACKEnabled option can be used to enable SACK support in the TCP @@ -93,6 +102,8 @@ type protocol struct { congestionControl string availableCongestionControl []string moderateReceiveBuffer bool + tcpLingerTimeout time.Duration + tcpTimeWaitTimeout time.Duration } // Number returns the tcp protocol number. @@ -212,6 +223,24 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { p.mu.Unlock() return nil + case tcpip.TCPLingerTimeoutOption: + if v < 0 { + v = 0 + } + p.mu.Lock() + p.tcpLingerTimeout = time.Duration(v) + p.mu.Unlock() + return nil + + case tcpip.TCPTimeWaitTimeoutOption: + if v < 0 { + v = 0 + } + p.mu.Lock() + p.tcpTimeWaitTimeout = time.Duration(v) + p.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -262,6 +291,18 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { p.mu.Unlock() return nil + case *tcpip.TCPLingerTimeoutOption: + p.mu.Lock() + *v = tcpip.TCPLingerTimeoutOption(p.tcpLingerTimeout) + p.mu.Unlock() + return nil + + case *tcpip.TCPTimeWaitTimeoutOption: + p.mu.Lock() + *v = tcpip.TCPTimeWaitTimeoutOption(p.tcpTimeWaitTimeout) + p.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -274,5 +315,7 @@ func NewProtocol() stack.TransportProtocol { recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize}, congestionControl: ccReno, availableCongestionControl: []string{ccReno, ccCubic}, + tcpLingerTimeout: DefaultTCPLingerTimeout, + tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout, } } diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index e90f9a7d9..068b90fb6 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -18,6 +18,7 @@ import ( "container/heap" "time" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" ) @@ -209,6 +210,11 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum switch r.ep.state { case StateFinWait1: r.ep.state = StateFinWait2 + // Notify protocol goroutine that we have received an + // ACK to our FIN so that it can start the FIN_WAIT2 + // timer to abort connection if the other side does + // not close within 2MSL. + r.ep.notifyProtocolGoroutine(notifyClose) case StateClosing: r.ep.state = StateTimeWait case StateLastAck: @@ -253,23 +259,105 @@ func (r *receiver) updateRTT() { r.ep.rcvListMu.Unlock() } -// handleRcvdSegment handles TCP segments directed at the connection managed by -// r as they arrive. It is called by the protocol main loop. -func (r *receiver) handleRcvdSegment(s *segment) { +func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, closed bool) (drop bool, err *tcpip.Error) { + r.ep.rcvListMu.Lock() + rcvClosed := r.ep.rcvClosed || r.closed + r.ep.rcvListMu.Unlock() + + // If we are in one of the shutdown states then we need to do + // additional checks before we try and process the segment. + switch state { + case StateCloseWait, StateClosing, StateLastAck: + if !s.sequenceNumber.LessThanEq(r.rcvNxt) { + s.decRef() + // Just drop the segment as we have + // already received a FIN and this + // segment is after the sequence number + // for the FIN. + return true, nil + } + fallthrough + case StateFinWait1: + fallthrough + case StateFinWait2: + // If we are closed for reads (either due to an + // incoming FIN or the user calling shutdown(.., + // SHUT_RD) then any data past the rcvNxt should + // trigger a RST. + endDataSeq := s.sequenceNumber.Add(seqnum.Size(s.data.Size())) + if rcvClosed && r.rcvNxt.LessThan(endDataSeq) { + s.decRef() + return true, tcpip.ErrConnectionAborted + } + if state == StateFinWait1 { + break + } + + // If it's a retransmission of an old data segment + // or a pure ACK then allow it. + if s.sequenceNumber.Add(s.logicalLen()).LessThanEq(r.rcvNxt) || + s.logicalLen() == 0 { + break + } + + // In FIN-WAIT2 if the socket is fully + // closed(not owned by application on our end + // then the only acceptable segment is a + // FIN. Since FIN can technically also carry + // data we verify that the segment carrying a + // FIN ends at exactly e.rcvNxt+1. + // + // From RFC793 page 25. + // + // For sequence number purposes, the SYN is + // considered to occur before the first actual + // data octet of the segment in which it occurs, + // while the FIN is considered to occur after + // the last actual data octet in a segment in + // which it occurs. + if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.rcvNxt+1) { + s.decRef() + return true, tcpip.ErrConnectionAborted + } + } + // We don't care about receive processing anymore if the receive side // is closed. - if r.closed { - return + // + // NOTE: We still want to permit a FIN as it's possible only our + // end has closed and the peer is yet to send a FIN. Hence we + // compare only the payload. + segEnd := s.sequenceNumber.Add(seqnum.Size(s.data.Size())) + if rcvClosed && !segEnd.LessThanEq(r.rcvNxt) { + return true, nil + } + return false, nil +} + +// handleRcvdSegment handles TCP segments directed at the connection managed by +// r as they arrive. It is called by the protocol main loop. +func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { + r.ep.mu.RLock() + state := r.ep.state + closed := r.ep.closed + r.ep.mu.RUnlock() + + if state != StateEstablished { + drop, err := r.handleRcvdSegmentClosing(s, state, closed) + if drop || err != nil { + return drop, err + } } segLen := seqnum.Size(s.data.Size()) segSeq := s.sequenceNumber // If the sequence number range is outside the acceptable range, just - // send an ACK. This is according to RFC 793, page 37. + // send an ACK and stop further processing of the segment. + // This is according to RFC 793, page 68. if !r.acceptable(segSeq, segLen) { r.ep.snd.sendAck() - return + return true, nil } // Defer segment processing if it can't be consumed now. @@ -288,7 +376,7 @@ func (r *receiver) handleRcvdSegment(s *segment) { // have to retransmit. r.ep.snd.sendAck() } - return + return false, nil } // Since we consumed a segment update the receiver's RTT estimate @@ -315,4 +403,67 @@ func (r *receiver) handleRcvdSegment(s *segment) { r.pendingBufUsed -= s.logicalLen() s.decRef() } + return false, nil +} + +// handleTimeWaitSegment handles inbound segments received when the endpoint +// has entered the TIME_WAIT state. +func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn bool) { + segSeq := s.sequenceNumber + segLen := seqnum.Size(s.data.Size()) + + // Just silently drop any RST packets in TIME_WAIT. We do not support + // TIME_WAIT assasination as a result we confirm w/ fix 1 as described + // in https://tools.ietf.org/html/rfc1337#section-3. + if s.flagIsSet(header.TCPFlagRst) { + return false, false + } + + // If it's a SYN and the sequence number is higher than any seen before + // for this connection then try and redirect it to a listening endpoint + // if available. + // + // RFC 1122: + // "When a connection is [...] on TIME-WAIT state [...] + // [a TCP] MAY accept a new SYN from the remote TCP to + // reopen the connection directly, if it: + + // (1) assigns its initial sequence number for the new + // connection to be larger than the largest sequence + // number it used on the previous connection incarnation, + // and + + // (2) returns to TIME-WAIT state if the SYN turns out + // to be an old duplicate". + if s.flagIsSet(header.TCPFlagSyn) && r.rcvNxt.LessThan(segSeq) { + + return false, true + } + + // Drop the segment if it does not contain an ACK. + if !s.flagIsSet(header.TCPFlagAck) { + return false, false + } + + // Update Timestamp if required. See RFC7323, section-4.3. + if r.ep.sendTSOk && s.parsedOptions.TS { + r.ep.updateRecentTimestamp(s.parsedOptions.TSVal, r.ep.snd.maxSentAck, segSeq) + } + + if segSeq.Add(1) == r.rcvNxt && s.flagIsSet(header.TCPFlagFin) { + // If it's a FIN-ACK then resetTimeWait and send an ACK, as it + // indicates our final ACK could have been lost. + r.ep.snd.sendAck() + return true, false + } + + // If the sequence number range is outside the acceptable range or + // carries data then just send an ACK. This is according to RFC 793, + // page 37. + // + // NOTE: In TIME_WAIT the only acceptable sequence number is rcvNxt. + if segSeq != r.rcvNxt || segLen != 0 { + r.ep.snd.sendAck() + } + return false, false } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index f4ea5f091..0c1704d74 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -206,17 +206,18 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() + // Set TCPLingerTimeout to 5 seconds so that sockets are marked closed wq := &waiter.Queue{} ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { - t.Fatalf("Bind failed: %v", err) + t.Fatalf("Bind failed: %s", err) } if err := ep.Listen(10); err != nil { - t.Fatalf("Listen failed: %v", err) + t.Fatalf("Listen failed: %s", err) } // Send a SYN request. @@ -256,7 +257,7 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { case <-ch: c.EP, _, err = ep.Accept() if err != nil { - t.Fatalf("Accept failed: %v", err) + t.Fatalf("Accept failed: %s", err) } case <-time.After(1 * time.Second): @@ -264,6 +265,13 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { } } + // Lower stackwide TIME_WAIT timeout so that the reservations + // are released instantly on Close. + tcpTW := tcpip.TCPTimeWaitTimeoutOption(1 * time.Millisecond) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpTW); err != nil { + t.Fatalf("e.stack.SetTransportProtocolOption(%d, %s) = %s", tcp.ProtocolNumber, tcpTW, err) + } + c.EP.Close() checker.IPv4(t, c.GetPacket(), checker.TCP( checker.SrcPort(context.StackPort), @@ -285,6 +293,11 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { // Get the ACK to the FIN we just sent. c.GetPacket() + // Since an active close was done we need to wait for a little more than + // tcpLingerTimeout for the port reservations to be released and the + // socket to move to a CLOSED state. + time.Sleep(20 * time.Millisecond) + // Now resend the same ACK, this ACK should generate a RST as there // should be no endpoint in SYN-RCVD state and we are not using // syn-cookies yet. The reason we send the same ACK is we need a valid @@ -376,6 +389,13 @@ func TestConnectResetAfterClose(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() + // Set TCPLinger to 3 seconds so that sockets are marked closed + // after 3 second in FIN_WAIT2 state. + tcpLingerTimeout := 3 * time.Second + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPLingerTimeoutOption(tcpLingerTimeout)); err != nil { + t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpLingerTimeout, err) + } + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) ep := c.EP c.EP = nil @@ -396,12 +416,24 @@ func TestConnectResetAfterClose(t *testing.T) { DstPort: c.Port, Flags: header.TCPFlagAck, SeqNum: 790, - AckNum: c.IRS.Add(1), + AckNum: c.IRS.Add(2), + RcvWnd: 30000, + }) + + // Wait for the ep to give up waiting for a FIN. + time.Sleep(tcpLingerTimeout + 1*time.Second) + + // Now send an ACK and it should trigger a RST as the endpoint should + // not exist anymore. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(2), RcvWnd: 30000, }) - // Wait for the ep to give up waiting for a FIN, and send a RST. - time.Sleep(3 * time.Second) for { b := c.GetPacket() tcpHdr := header.TCP(header.IPv4(b).Payload()) @@ -413,7 +445,7 @@ func TestConnectResetAfterClose(t *testing.T) { checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), - checker.SeqNum(uint32(c.IRS)+1), + checker.SeqNum(uint32(c.IRS)+2), checker.AckNum(790), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), ), @@ -1110,8 +1142,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), - // We shouldn't consume a sequence number on RST. - checker.SeqNum(uint32(c.IRS)+1), + checker.SeqNum(uint32(c.IRS)+2), )) // The RST puts the endpoint into an error state. if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { @@ -3085,6 +3116,13 @@ func TestReadAfterClosedState(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() + // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed + // after 1 second in TIME_WAIT state. + tcpTimeWaitTimeout := 1 * time.Second + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { + t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeout(%d) failed: %s", tcpTimeWaitTimeout, err) + } + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) we, ch := waiter.NewChannelEntry(nil) @@ -3092,12 +3130,12 @@ func TestReadAfterClosedState(t *testing.T) { defer c.WQ.EventUnregister(&we) if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) + t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrWouldBlock) } // Shutdown immediately for write, check that we get a FIN. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) + t.Fatalf("Shutdown failed: %s", err) } checker.IPv4(t, c.GetPacket(), @@ -3135,10 +3173,9 @@ func TestReadAfterClosedState(t *testing.T) { ), ) - // Give the stack the chance to transition to closed state. Note that since - // both the sender and receiver are now closed, we effectively skip the - // TIME-WAIT state. - time.Sleep(1 * time.Second) + // Give the stack the chance to transition to closed state from + // TIME_WAIT. + time.Sleep(tcpTimeWaitTimeout * 2) if got, want := tcp.EndpointState(c.EP.State()), tcp.StateClose; got != want { t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) @@ -3155,7 +3192,7 @@ func TestReadAfterClosedState(t *testing.T) { peekBuf := make([]byte, 10) n, _, err := c.EP.Peek([][]byte{peekBuf}) if err != nil { - t.Fatalf("Peek failed: %v", err) + t.Fatalf("Peek failed: %s", err) } peekBuf = peekBuf[:n] @@ -3166,7 +3203,7 @@ func TestReadAfterClosedState(t *testing.T) { // Receive data. v, _, err := c.EP.Read(nil) if err != nil { - t.Fatalf("Read failed: %v", err) + t.Fatalf("Read failed: %s", err) } if !bytes.Equal(data, v) { @@ -3176,11 +3213,11 @@ func TestReadAfterClosedState(t *testing.T) { // Now that we drained the queue, check that functions fail with the // right error code. if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrClosedForReceive) + t.Fatalf("got c.EP.Read(nil) = %v, want = %s", err, tcpip.ErrClosedForReceive) } if _, _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive { - t.Fatalf("got c.EP.Peek(...) = %v, want = %v", err, tcpip.ErrClosedForReceive) + t.Fatalf("got c.EP.Peek(...) = %v, want = %s", err, tcpip.ErrClosedForReceive) } } @@ -4347,7 +4384,8 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { // Send a SYN request. irs := seqnum.Value(789) c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort, + // pick a different src port for new SYN. + SrcPort: context.TestPort + 1, DstPort: context.StackPort, Flags: header.TCPFlagSyn, SeqNum: irs, @@ -4893,3 +4931,545 @@ func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.Del t.Errorf("ep.GetSockOptInt(tcpip.DelayOption) got: %d, want: %d", gotDelayOption, wantDelayOption) } } + +func TestTCPLingerTimeout(t *testing.T) { + c := context.New(t, 1500 /* mtu */) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + + testCases := []struct { + name string + tcpLingerTimeout time.Duration + want time.Duration + }{ + {"NegativeLingerTimeout", -123123, 0}, + {"ZeroLingerTimeout", 0, 0}, + {"InRangeLingerTimeout", 10 * time.Second, 10 * time.Second}, + // Values > stack's TCPLingerTimeout are capped to the stack's + // value. Defaults to tcp.DefaultTCPLingerTimeout(60 seconds) + {"AboveMaxLingerTimeout", 65 * time.Second, 60 * time.Second}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if err := c.EP.SetSockOpt(tcpip.TCPLingerTimeoutOption(tc.tcpLingerTimeout)); err != nil { + t.Fatalf("SetSockOpt(%s) = %s", tc.tcpLingerTimeout, err) + } + var v tcpip.TCPLingerTimeoutOption + if err := c.EP.GetSockOpt(&v); err != nil { + t.Fatalf("GetSockOpt(tcpip.TCPLingerTimeoutOption) = %s", err) + } + if got, want := time.Duration(v), tc.want; got != want { + t.Fatalf("unexpected linger timeout got: %s, want: %s", got, want) + } + }) + } +} + +func TestTCPTimeWaitRSTIgnored(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + c.EP.Close() + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Now send a RST and this should be ignored and not + // generate an ACK. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagRst, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + }) + + c.CheckNoPacketTimeout("unexpected packet received in TIME_WAIT state", 1*time.Second) + + // Out of order ACK should generate an immediate ACK in + // TIME_WAIT. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 3, + }) + + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) +} + +func TestTCPTimeWaitOutOfOrder(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + c.EP.Close() + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Out of order ACK should generate an immediate ACK in + // TIME_WAIT. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 3, + }) + + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) +} + +func TestTCPTimeWaitNewSyn(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + c.EP.Close() + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Send a SYN request w/ sequence number lower than + // the highest sequence number sent. We just reuse + // the same number. + iss = seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + + c.CheckNoPacketTimeout("unexpected packet received in response to SYN", 1*time.Second) + + // Send a SYN request w/ sequence number higher than + // the highest sequence number sent. + iss = seqnum.Value(792) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + + // Receive the SYN-ACK reply. + b = c.GetPacket() + tcpHdr = header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders = &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } +} + +func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed + // after 5 seconds in TIME_WAIT state. + tcpTimeWaitTimeout := 5 * time.Second + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { + t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err) + } + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + c.EP.Close() + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + time.Sleep(2 * time.Second) + + // Now send a duplicate FIN. This should cause the TIME_WAIT to extend + // by another 5 seconds and also send us a duplicate ACK as it should + // indicate that the final ACK was potentially lost. + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Sleep for 4 seconds so at this point we are 1 second past the + // original tcpLingerTimeout of 5 seconds. + time.Sleep(4 * time.Second) + + // Send an ACK and it should not generate any packet as the socket + // should still be in TIME_WAIT for another another 5 seconds due + // to the duplicate FIN we sent earlier. + *ackHeaders = *finHeaders + ackHeaders.SeqNum = ackHeaders.SeqNum + 1 + ackHeaders.Flags = header.TCPFlagAck + c.SendPacket(nil, ackHeaders) + + c.CheckNoPacketTimeout("unexpected packet received from endpoint in TIME_WAIT", 1*time.Second) + // Now sleep for another 2 seconds so that we are past the + // extended TIME_WAIT of 7 seconds (2 + 5). + time.Sleep(2 * time.Second) + + // Resend the same ACK. + c.SendPacket(nil, ackHeaders) + + // Receive the RST that should be generated as there is no valid + // endpoint. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(ackHeaders.AckNum)), + checker.AckNum(uint32(ackHeaders.SeqNum)), + checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck))) +} diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index 3e5b6b3c3..722d14b53 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -9,7 +9,7 @@ syscall_test(test = "//test/syscalls/linux:accept_bind_stream_test") syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:accept_bind_test", ) @@ -434,7 +434,7 @@ syscall_test( syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_abstract_test", ) @@ -445,7 +445,7 @@ syscall_test( syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_domain_test", ) @@ -458,19 +458,19 @@ syscall_test( syscall_test( size = "large", add_overlay = True, - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_filesystem_test", ) syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_inet_loopback_test", ) syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_ip_tcp_generic_loopback_test", ) @@ -481,13 +481,13 @@ syscall_test( syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_ip_tcp_loopback_test", ) syscall_test( size = "medium", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_ip_tcp_udp_generic_loopback_test", ) @@ -498,7 +498,7 @@ syscall_test( syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_ip_udp_loopback_test", ) @@ -560,7 +560,7 @@ syscall_test( syscall_test( size = "large", add_overlay = True, - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_unix_pair_test", ) @@ -599,7 +599,7 @@ syscall_test( syscall_test( size = "large", - shard_count = 10, + shard_count = 50, test = "//test/syscalls/linux:socket_unix_unbound_stream_test", ) diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 93bff8299..f8b8cb724 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -2141,6 +2141,7 @@ cc_library( deps = [ ":socket_test_util", "//test/util:test_util", + "//test/util:thread_util", "@com_google_googletest//:gtest", ], alwayslink = 1, diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index ab375aaaf..2eeee352e 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -13,6 +13,7 @@ // limitations under the License. #include +#include #include #include #include @@ -31,6 +32,7 @@ #include "gtest/gtest.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" +#include "absl/time/clock.h" #include "absl/time/time.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/file_descriptor.h" @@ -267,6 +269,340 @@ TEST_P(SocketInetLoopbackTest, TCPbacklog) { } } +// TCPFinWait2Test creates a pair of connected sockets then closes one end to +// trigger FIN_WAIT2 state for the closed endpoint. Then it binds the same local +// IP/port on a new socket and tries to connect. The connect should fail w/ +// an EADDRINUSE. Then we wait till the FIN_WAIT2 timeout is over and try the +// connect again with a new socket and this time it should succeed. +// +// TCP timers are not S/R today, this can cause this test to be flaky when run +// under random S/R due to timer being reset on a restore. +TEST_P(SocketInetLoopbackTest, TCPFinWait2Test_NoRandomSave) { + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast(&listen_addr), &addrlen), + SyscallSucceeds()); + + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + // Lower FIN_WAIT2 state to 5 seconds for test. + constexpr int kTCPLingerTimeout = 5; + EXPECT_THAT(setsockopt(conn_fd.get(), IPPROTO_TCP, TCP_LINGER2, + &kTCPLingerTimeout, sizeof(kTCPLingerTimeout)), + SyscallSucceedsWithValue(0)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), + reinterpret_cast(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + // Accept the connection. + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + + // Get the address/port bound by the connecting socket. + sockaddr_storage conn_bound_addr; + socklen_t conn_addrlen = connector.addr_len; + ASSERT_THAT( + getsockname(conn_fd.get(), reinterpret_cast(&conn_bound_addr), + &conn_addrlen), + SyscallSucceeds()); + + // close the connecting FD to trigger FIN_WAIT2 on the connected fd. + conn_fd.reset(); + + // Now bind and connect a new socket. + const FileDescriptor conn_fd2 = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + // Disable cooperative saves after this point. As a save between the first + // bind/connect and the second one can cause the linger timeout timer to + // be restarted causing the final bind/connect to fail. + DisableSave ds; + + // TODO(gvisor.dev/issue/1030): Portmanager does not track all 5 tuple + // reservations which causes the bind() to succeed on gVisor but connect + // correctly fails. + if (IsRunningOnGvisor()) { + ASSERT_THAT( + bind(conn_fd2.get(), reinterpret_cast(&conn_bound_addr), + conn_addrlen), + SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(connect)(conn_fd2.get(), + reinterpret_cast(&conn_addr), + conn_addrlen), + SyscallFailsWithErrno(EADDRINUSE)); + } else { + ASSERT_THAT( + bind(conn_fd2.get(), reinterpret_cast(&conn_bound_addr), + conn_addrlen), + SyscallFailsWithErrno(EADDRINUSE)); + } + + // Sleep for a little over the linger timeout to reduce flakiness in + // save/restore tests. + absl::SleepFor(absl::Seconds(kTCPLingerTimeout + 1)); + + ds.reset(); + + if (!IsRunningOnGvisor()) { + ASSERT_THAT( + bind(conn_fd2.get(), reinterpret_cast(&conn_bound_addr), + conn_addrlen), + SyscallSucceeds()); + } + ASSERT_THAT(RetryEINTR(connect)(conn_fd2.get(), + reinterpret_cast(&conn_addr), + conn_addrlen), + SyscallSucceeds()); +} + +// TCPLinger2TimeoutAfterClose creates a pair of connected sockets +// then closes one end to trigger FIN_WAIT2 state for the closed endpont. +// It then sleeps for the TCP_LINGER2 timeout and verifies that bind/ +// connecting the same address succeeds. +// +// TCP timers are not S/R today, this can cause this test to be flaky when run +// under random S/R due to timer being reset on a restore. +TEST_P(SocketInetLoopbackTest, TCPLinger2TimeoutAfterClose_NoRandomSave) { + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast(&listen_addr), &addrlen), + SyscallSucceeds()); + + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), + reinterpret_cast(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + // Accept the connection. + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + + // Get the address/port bound by the connecting socket. + sockaddr_storage conn_bound_addr; + socklen_t conn_addrlen = connector.addr_len; + ASSERT_THAT( + getsockname(conn_fd.get(), reinterpret_cast(&conn_bound_addr), + &conn_addrlen), + SyscallSucceeds()); + + constexpr int kTCPLingerTimeout = 5; + EXPECT_THAT(setsockopt(conn_fd.get(), IPPROTO_TCP, TCP_LINGER2, + &kTCPLingerTimeout, sizeof(kTCPLingerTimeout)), + SyscallSucceedsWithValue(0)); + + // close the connecting FD to trigger FIN_WAIT2 on the connected fd. + conn_fd.reset(); + + absl::SleepFor(absl::Seconds(kTCPLingerTimeout + 1)); + + // Now bind and connect a new socket and verify that we can immediately + // rebind the address bound by the conn_fd as it never entered TIME_WAIT. + const FileDescriptor conn_fd2 = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + ASSERT_THAT(bind(conn_fd2.get(), + reinterpret_cast(&conn_bound_addr), conn_addrlen), + SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(connect)(conn_fd2.get(), + reinterpret_cast(&conn_addr), + conn_addrlen), + SyscallSucceeds()); +} + +// TCPResetAfterClose creates a pair of connected sockets then closes +// one end to trigger FIN_WAIT2 state for the closed endpoint verifies +// that we generate RSTs for any new data after the socket is fully +// closed. +TEST_P(SocketInetLoopbackTest, TCPResetAfterClose) { + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast(&listen_addr), &addrlen), + SyscallSucceeds()); + + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), + reinterpret_cast(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + // Accept the connection. + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + + // close the connecting FD to trigger FIN_WAIT2 on the connected fd. + conn_fd.reset(); + + int data = 1234; + + // Now send data which should trigger a RST as the other end should + // have timed out and closed the socket. + EXPECT_THAT(RetryEINTR(send)(accepted.get(), &data, sizeof(data), 0), + SyscallSucceeds()); + // Sleep for a shortwhile to get a RST back. + absl::SleepFor(absl::Seconds(1)); + + // Try writing again and we should get an EPIPE back. + EXPECT_THAT(RetryEINTR(send)(accepted.get(), &data, sizeof(data), 0), + SyscallFailsWithErrno(EPIPE)); + + // Trying to read should return zero as the other end did send + // us a FIN. We do it twice to verify that the RST does not cause an + // ECONNRESET on the read after EOF has been read by applicaiton. + EXPECT_THAT(RetryEINTR(recv)(accepted.get(), &data, sizeof(data), 0), + SyscallSucceedsWithValue(0)); + EXPECT_THAT(RetryEINTR(recv)(accepted.get(), &data, sizeof(data), 0), + SyscallSucceedsWithValue(0)); +} + +// This test is disabled under random save as the the restore run +// results in the stack.Seed() being different which can cause +// sequence number of final connect to be one that is considered +// old and can cause the test to be flaky. +TEST_P(SocketInetLoopbackTest, TCPTimeWaitTest_NoRandomSave) { + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast(&listen_addr), &addrlen), + SyscallSucceeds()); + + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + // We disable saves after this point as a S/R causes the netstack seed + // to be regenerated which changes what ports/ISN is picked for a given + // tuple (src ip,src port, dst ip, dst port). This can cause the final + // SYN to use a sequence number that looks like one from the current + // connection in TIME_WAIT and will not be accepted causing the test + // to timeout. + // + // TODO(gvisor.dev/issue/940): S/R portSeed/portHint + DisableSave ds; + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), + reinterpret_cast(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + // Accept the connection. + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + + // Get the address/port bound by the connecting socket. + sockaddr_storage conn_bound_addr; + socklen_t conn_addrlen = connector.addr_len; + ASSERT_THAT( + getsockname(conn_fd.get(), reinterpret_cast(&conn_bound_addr), + &conn_addrlen), + SyscallSucceeds()); + + // close the accept FD to trigger TIME_WAIT on the accepted socket which + // should cause the conn_fd to follow CLOSE_WAIT->LAST_ACK->CLOSED instead of + // TIME_WAIT. + accepted.reset(); + absl::SleepFor(absl::Seconds(1)); + conn_fd.reset(); + absl::SleepFor(absl::Seconds(1)); + + // Now bind and connect a new socket and verify that we can immediately + // rebind the address bound by the conn_fd as it never entered TIME_WAIT. + const FileDescriptor conn_fd2 = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + ASSERT_THAT(bind(conn_fd2.get(), + reinterpret_cast(&conn_bound_addr), conn_addrlen), + SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(connect)(conn_fd2.get(), + reinterpret_cast(&conn_addr), + conn_addrlen), + SyscallSucceeds()); +} + INSTANTIATE_TEST_SUITE_P( All, SocketInetLoopbackTest, ::testing::Values( diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc index 592448289..a37b49447 100644 --- a/test/syscalls/linux/socket_ip_tcp_generic.cc +++ b/test/syscalls/linux/socket_ip_tcp_generic.cc @@ -26,6 +26,7 @@ #include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/test_util.h" +#include "test/util/thread_util.h" namespace gvisor { namespace testing { @@ -243,6 +244,31 @@ TEST_P(TCPSocketPairTest, ShutdownRdAllowsReadOfReceivedDataBeforeEOF) { SyscallSucceedsWithValue(0)); } +// This test verifies that a shutdown(wr) by the server after sending +// data allows the client to still read() the queued data and a client +// close after sending response allows server to read the incoming +// response. +TEST_P(TCPSocketPairTest, ShutdownWrServerClientClose) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + char buf[10] = {}; + ScopedThread t([&]() { + ASSERT_THAT(RetryEINTR(read)(sockets->first_fd(), buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + ASSERT_THAT(RetryEINTR(write)(sockets->first_fd(), buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + ASSERT_THAT(close(sockets->release_first_fd()), + SyscallSucceedsWithValue(0)); + }); + ASSERT_THAT(RetryEINTR(write)(sockets->second_fd(), buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + ASSERT_THAT(RetryEINTR(shutdown)(sockets->second_fd(), SHUT_WR), + SyscallSucceedsWithValue(0)); + t.Join(); + + ASSERT_THAT(RetryEINTR(read)(sockets->second_fd(), buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); +} + TEST_P(TCPSocketPairTest, ClosedReadNonBlockingSocket) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); @@ -696,5 +722,72 @@ TEST_P(TCPSocketPairTest, SetCongestionControlFailsForUnsupported) { EXPECT_EQ(0, memcmp(got_cc, old_cc, sizeof(old_cc))); } +// Linux and Netstack both default to a 60s TCP_LINGER2 timeout. +constexpr int kDefaultTCPLingerTimeout = 60; + +TEST_P(TCPSocketPairTest, TCPLingerTimeoutDefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int get = -1; + socklen_t get_len = sizeof(get); + EXPECT_THAT( + getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kDefaultTCPLingerTimeout); +} + +TEST_P(TCPSocketPairTest, SetTCPLingerTimeoutZeroOrLess) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + constexpr int kZero = 0; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &kZero, + sizeof(kZero)), + SyscallSucceedsWithValue(0)); + + constexpr int kNegative = -1234; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, + &kNegative, sizeof(kNegative)), + SyscallSucceedsWithValue(0)); +} + +TEST_P(TCPSocketPairTest, SetTCPLingerTimeoutAboveDefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Values above the net.ipv4.tcp_fin_timeout are capped to tcp_fin_timeout + // on linux (defaults to 60 seconds on linux). + constexpr int kAboveDefault = kDefaultTCPLingerTimeout + 1; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, + &kAboveDefault, sizeof(kAboveDefault)), + SyscallSucceedsWithValue(0)); + + int get = -1; + socklen_t get_len = sizeof(get); + EXPECT_THAT( + getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kDefaultTCPLingerTimeout); +} + +TEST_P(TCPSocketPairTest, SetTCPLingerTimeout) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Values above the net.ipv4.tcp_fin_timeout are capped to tcp_fin_timeout + // on linux (defaults to 60 seconds on linux). + constexpr int kTCPLingerTimeout = kDefaultTCPLingerTimeout - 1; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, + &kTCPLingerTimeout, sizeof(kTCPLingerTimeout)), + SyscallSucceedsWithValue(0)); + + int get = -1; + socklen_t get_len = sizeof(get); + EXPECT_THAT( + getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kTCPLingerTimeout); +} + } // namespace testing } // namespace gvisor -- cgit v1.2.3 From 5398530e45634b6f5ea4344d1a34b41cc8123457 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Tue, 12 Nov 2019 14:02:53 -0800 Subject: Discover on-link prefixes from Router Advertisements' Prefix Information options This change allows the netstack to do NDP's Prefix Discovery as outlined by RFC 4861 section 6.3.4. If configured to do so, when a new on-link prefix is discovered, the routing table will be updated with a device route through the nic the RA arrived at. Likewise, when such a prefix gets invalidated, the device route will be removed. Note, this change will not break existing uses of netstack as the default configuration for the stack options is set in such a way that Prefix Discovery will not be performed. See `stack.Options` and `stack.NDPConfigurations` for more details. This change reuses 1 option and introduces a new one that is required to take advantage of Prefix Discovery, all available under NDPConfigurations: - HandleRAs: Whether or not NDP RAs are processes - DiscoverOnLinkPrefixes: Whether or not Prefix Discovery is performed (new) Another note: for a NIC to process Prefix Information options (in Router Advertisements), it must not be a router itself. Currently the netstack does not have per-interface routing configuration; the routing/forwarding configuration is controlled stack-wide. Therefore, if the stack is configured to enable forwarding/routing, no router Advertisements (and by extension the Prefix Information options) will be processed. Tests: Unittest to make sure that Prefix Discovery and updates to the routing table only occur if explicitly configured to do so. Unittest to make sure at max stack.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered. PiperOrigin-RevId: 280049278 --- pkg/tcpip/header/ndp_options.go | 25 +- pkg/tcpip/stack/ndp.go | 309 +++++++++++++++++++++-- pkg/tcpip/stack/ndp_test.go | 534 +++++++++++++++++++++++++++++++++++++++- pkg/tcpip/stack/nic.go | 1 + 4 files changed, 836 insertions(+), 33 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go index a2b9d7435..1ca6199ef 100644 --- a/pkg/tcpip/header/ndp_options.go +++ b/pkg/tcpip/header/ndp_options.go @@ -85,17 +85,22 @@ const ( // within an NDPPrefixInformation. ndpPrefixInformationPrefixOffset = 14 - // NDPPrefixInformationInfiniteLifetime is a value that represents - // infinity for the Valid and Preferred Lifetime fields in a NDP Prefix - // Information option. Its value is (2^32 - 1)s = 4294967295s - NDPPrefixInformationInfiniteLifetime = time.Second * 4294967295 - // lengthByteUnits is the multiplier factor for the Length field of an // NDP option. That is, the length field for NDP options is in units of // 8 octets, as per RFC 4861 section 4.6. lengthByteUnits = 8 ) +var ( + // NDPPrefixInformationInfiniteLifetime is a value that represents + // infinity for the Valid and Preferred Lifetime fields in a NDP Prefix + // Information option. Its value is (2^32 - 1)s = 4294967295s + // + // This is a variable instead of a constant so that tests can change + // this value to a smaller value. It should only be modified by tests. + NDPPrefixInformationInfiniteLifetime = time.Second * 4294967295 +) + // NDPOptionIterator is an iterator of NDPOption. // // Note, between when an NDPOptionIterator is obtained and last used, no changes @@ -461,3 +466,13 @@ func (o NDPPrefixInformation) PreferredLifetime() time.Duration { func (o NDPPrefixInformation) Prefix() tcpip.Address { return tcpip.Address(o[ndpPrefixInformationPrefixOffset:][:IPv6AddressSize]) } + +// Subnet returns the Prefix field and Prefix Length field represented in a +// tcpip.Subnet. +func (o NDPPrefixInformation) Subnet() tcpip.Subnet { + addrWithPrefix := tcpip.AddressWithPrefix{ + Address: o.Prefix(), + PrefixLen: int(o.PrefixLength()), + } + return addrWithPrefix.Subnet() +} diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 8e49f7a56..8357dca77 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -46,11 +46,18 @@ const ( // defaultDiscoverDefaultRouters is the default configuration for // whether or not to discover default routers from incoming Router - // Advertisements as a host. + // Advertisements, as a host. // // Default = true. defaultDiscoverDefaultRouters = true + // defaultDiscoverOnLinkPrefixes is the default configuration for + // whether or not to discover on-link prefixes from incoming Router + // Advertisements' Prefix Information option, as a host. + // + // Default = true. + defaultDiscoverOnLinkPrefixes = true + // minimumRetransmitTimer is the minimum amount of time to wait between // sending NDP Neighbor solicitation messages. Note, RFC 4861 does // not impose a minimum Retransmit Timer, but we do here to make sure @@ -72,6 +79,14 @@ const ( // // Max = 10. MaxDiscoveredDefaultRouters = 10 + + // MaxDiscoveredOnLinkPrefixes is the maximum number of discovered + // on-link prefixes. The stack should stop discovering new on-link + // prefixes after discovering MaxDiscoveredOnLinkPrefixes on-link + // prefixes. + // + // Max = 10. + MaxDiscoveredOnLinkPrefixes = 10 ) // NDPDispatcher is the interface integrators of netstack must implement to @@ -106,6 +121,24 @@ type NDPDispatcher interface { // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) []tcpip.Route + + // OnOnLinkPrefixDiscovered will be called when a new on-link prefix is + // discovered. Implementations must return true along with a new valid + // route table if the newly discovered on-link prefix should be + // remembered. If an implementation returns false, the second return + // value will be ignored. + // + // This function is not permitted to block indefinitely. This function + // is also not permitted to call into the stack. + OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) (bool, []tcpip.Route) + + // OnOnLinkPrefixInvalidated will be called when a discovered on-link + // prefix is invalidated. Implementers must return a new valid route + // table. + // + // This function is not permitted to block indefinitely. This function + // is also not permitted to call into the stack. + OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) []tcpip.Route } // NDPConfigurations is the NDP configurations for the netstack. @@ -130,6 +163,11 @@ type NDPConfigurations struct { // be discovered from Router Advertisements. This configuration is // ignored if HandleRAs is false. DiscoverDefaultRouters bool + + // DiscoverOnLinkPrefixes determines whether or not on-link prefixes + // will be discovered from Router Advertisements' Prefix Information + // option. This configuration is ignored if HandleRAs is false. + DiscoverOnLinkPrefixes bool } // DefaultNDPConfigurations returns an NDPConfigurations populated with @@ -140,6 +178,7 @@ func DefaultNDPConfigurations() NDPConfigurations { RetransmitTimer: defaultRetransmitTimer, HandleRAs: defaultHandleRAs, DiscoverDefaultRouters: defaultDiscoverDefaultRouters, + DiscoverOnLinkPrefixes: defaultDiscoverOnLinkPrefixes, } } @@ -167,6 +206,10 @@ type ndpState struct { // The default routers discovered through Router Advertisements. defaultRouters map[tcpip.Address]defaultRouterState + + // The on-link prefixes discovered through Router Advertisements' Prefix + // Information option. + onLinkPrefixes map[tcpip.Subnet]onLinkPrefixState } // dadState holds the Duplicate Address Detection timer and channel to signal @@ -183,11 +226,11 @@ type dadState struct { } // defaultRouterState holds data associated with a default router discovered by -// a Router Advertisement. +// a Router Advertisement (RA). type defaultRouterState struct { invalidationTimer *time.Timer - // Used to signal the timer not to invalidate the default router (R) in + // Used to inform the timer not to invalidate the default router (R) in // a race condition (T1 is a goroutine that handles an RA from R and T2 // is the goroutine that handles R's invalidation timer firing): // T1: Receive a new RA from R @@ -198,10 +241,33 @@ type defaultRouterState struct { // T2: Obtains NIC's lock & invalidates R immediately // // To resolve this, T1 will check to see if the timer already fired, and - // signal the timer using this channel to not invalidate R, so that once - // T2 obtains the lock, it will see that there is an event on this - // channel and do nothing further. - doNotInvalidateC chan struct{} + // inform the timer using doNotInvalidate to not invalidate R, so that + // once T2 obtains the lock, it will see that it is set to true and do + // nothing further. + doNotInvalidate *bool +} + +// onLinkPrefixState holds data associated with an on-link prefix discovered by +// a Router Advertisement's Prefix Information option (PI) when the NDP +// configurations was configured to do so. +type onLinkPrefixState struct { + invalidationTimer *time.Timer + + // Used to signal the timer not to invalidate the on-link prefix (P) in + // a race condition (T1 is a goroutine that handles a PI for P and T2 + // is the goroutine that handles P's invalidation timer firing): + // T1: Receive a new PI for P + // T1: Obtain the NIC's lock before processing the PI + // T2: P's invalidation timer fires, and gets blocked on obtaining the + // NIC's lock + // T1: Refreshes/extends P's lifetime & releases NIC's lock + // T2: Obtains NIC's lock & invalidates P immediately + // + // To resolve this, T1 will check to see if the timer already fired, and + // inform the timer using doNotInvalidate to not invalidate P, so that + // once T2 obtains the lock, it will see that it is set to true and do + // nothing further. + doNotInvalidate *bool } // startDuplicateAddressDetection performs Duplicate Address Detection. @@ -440,14 +506,13 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { if !timer.Stop() { // If we reach this point, then we know the // timer fired after we already took the NIC - // lock. Signal the timer so that once it - // obtains the lock, it doesn't actually - // invalidate the router as we just got a new - // RA that refreshes its lifetime to a non-zero - // value. See - // defaultRouterState.doNotInvalidateC for more + // lock. Inform the timer not to invalidate the + // router when it obtains the lock as we just + // got a new RA that refreshes its lifetime to a + // non-zero value. See + // defaultRouterState.doNotInvalidate for more // details. - rtr.doNotInvalidateC <- struct{}{} + *rtr.doNotInvalidate = true } timer.Reset(rl) @@ -459,8 +524,117 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { } } - // TODO(b/140948104): Do Prefix Discovery. - // TODO(b/141556115): Do Parameter Discovery. + // TODO(b/141556115): Do (RetransTimer, ReachableTime)) Parameter + // Discovery. + + // We know the options is valid as far as wire format is concerned since + // we got the Router Advertisement, as documented by this fn. Given this + // we do not check the iterator for errors on calls to Next. + it, _ := ra.Options().Iter(false) + for opt, done, _ := it.Next(); !done; opt, done, _ = it.Next() { + switch opt.Type() { + case header.NDPPrefixInformationType: + if !ndp.configs.DiscoverOnLinkPrefixes { + continue + } + + pi := opt.(header.NDPPrefixInformation) + + prefix := pi.Subnet() + + // Is the prefix a link-local? + if header.IsV6LinkLocalAddress(prefix.ID()) { + // ...Yes, skip as per RFC 4861 section 6.3.4. + continue + } + + // Is the Prefix Length 0? + if prefix.Prefix() == 0 { + // ...Yes, skip as this is an invalid prefix + // as all IPv6 addresses cannot be on-link. + continue + } + + if !pi.OnLinkFlag() { + // Not on-link so don't "discover" it as an + // on-link prefix. + continue + } + + prefixState, ok := ndp.onLinkPrefixes[prefix] + vl := pi.ValidLifetime() + switch { + case !ok && vl == 0: + // Don't know about this prefix but has a zero + // valid lifetime, so just ignore. + continue + + case !ok && vl != 0: + // This is a new on-link prefix we are + // discovering. + // + // Only remember it if we currently know about + // less than MaxDiscoveredOnLinkPrefixes on-link + // prefixes. + if len(ndp.onLinkPrefixes) < MaxDiscoveredOnLinkPrefixes { + ndp.rememberOnLinkPrefix(prefix, vl) + } + continue + + case ok && vl == 0: + // We know about the on-link prefix, but it is + // no longer to be considered on-link, so + // invalidate it. + ndp.invalidateOnLinkPrefix(prefix) + continue + } + + // This is an already discovered on-link prefix with a + // new non-zero valid lifetime. + // Update the invalidation timer. + timer := prefixState.invalidationTimer + + if timer == nil && vl >= header.NDPPrefixInformationInfiniteLifetime { + // Had infinite valid lifetime before and + // continues to have an invalid lifetime. Do + // nothing further. + continue + } + + if timer != nil && !timer.Stop() { + // If we reach this point, then we know the + // timer already fired after we took the NIC + // lock. Inform the timer to not invalidate + // the prefix once it obtains the lock as we + // just got a new PI that refeshes its lifetime + // to a non-zero value. See + // onLinkPrefixState.doNotInvalidate for more + // details. + *prefixState.doNotInvalidate = true + } + + if vl >= header.NDPPrefixInformationInfiniteLifetime { + // Prefix is now valid forever so we don't need + // an invalidation timer. + prefixState.invalidationTimer = nil + ndp.onLinkPrefixes[prefix] = prefixState + continue + } + + if timer != nil { + // We already have a timer so just reset it to + // expire after the new valid lifetime. + timer.Reset(vl) + continue + } + + // We do not have a timer so just create a new one. + prefixState.invalidationTimer = ndp.prefixInvalidationCallback(prefix, vl, prefixState.doNotInvalidate) + ndp.onLinkPrefixes[prefix] = prefixState + } + + // TODO(b/141556115): Do (MTU) Parameter Discovery. + } } // invalidateDefaultRouter invalidates a discovered default router. @@ -477,8 +651,8 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { rtr.invalidationTimer.Stop() rtr.invalidationTimer = nil - close(rtr.doNotInvalidateC) - rtr.doNotInvalidateC = nil + *rtr.doNotInvalidate = true + rtr.doNotInvalidate = nil delete(ndp.defaultRouters, ip) @@ -508,9 +682,9 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { } // Used to signal the timer not to invalidate the default router (R) in - // a race condition. See defaultRouterState.doNotInvalidateC for more + // a race condition. See defaultRouterState.doNotInvalidate for more // details. - doNotInvalidateC := make(chan struct{}, 1) + var doNotInvalidate bool ndp.defaultRouters[ip] = defaultRouterState{ invalidationTimer: time.AfterFunc(rl, func() { @@ -519,16 +693,103 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { ndp.nic.mu.Lock() defer ndp.nic.mu.Unlock() - select { - case <-doNotInvalidateC: + if doNotInvalidate { + doNotInvalidate = false return - default: } ndp.invalidateDefaultRouter(ip) }), - doNotInvalidateC: doNotInvalidateC, + doNotInvalidate: &doNotInvalidate, + } + + ndp.nic.stack.routeTable = routeTable +} + +// rememberOnLinkPrefix remembers a newly discovered on-link prefix with IPv6 +// address with prefix prefix with lifetime l. +// +// The prefix identified by prefix MUST NOT already be known. +// +// The NIC that ndp belongs to and its associated stack MUST be locked. +func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) { + if ndp.nic.stack.ndpDisp == nil { + return + } + + // Inform the integrator when we discovered an on-link prefix. + remember, routeTable := ndp.nic.stack.ndpDisp.OnOnLinkPrefixDiscovered(ndp.nic.ID(), prefix) + if !remember { + // Informed by the integrator to not remember the prefix, do + // nothing further. + return + } + + // Used to signal the timer not to invalidate the on-link prefix (P) in + // a race condition. See onLinkPrefixState.doNotInvalidate for more + // details. + var doNotInvalidate bool + var timer *time.Timer + + // Only create a timer if the lifetime is not infinite. + if l < header.NDPPrefixInformationInfiniteLifetime { + timer = ndp.prefixInvalidationCallback(prefix, l, &doNotInvalidate) + } + + ndp.onLinkPrefixes[prefix] = onLinkPrefixState{ + invalidationTimer: timer, + doNotInvalidate: &doNotInvalidate, } ndp.nic.stack.routeTable = routeTable } + +// invalidateOnLinkPrefix invalidates a discovered on-link prefix. +// +// The NIC that ndp belongs to and its associated stack MUST be locked. +func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) { + s, ok := ndp.onLinkPrefixes[prefix] + + // Is the on-link prefix still discovered? + if !ok { + // ...Nope, do nothing further. + return + } + + if s.invalidationTimer != nil { + s.invalidationTimer.Stop() + s.invalidationTimer = nil + *s.doNotInvalidate = true + } + + s.doNotInvalidate = nil + + delete(ndp.onLinkPrefixes, prefix) + + // Let the integrator know a discovered on-link prefix is invalidated. + if ndp.nic.stack.ndpDisp != nil { + ndp.nic.stack.routeTable = ndp.nic.stack.ndpDisp.OnOnLinkPrefixInvalidated(ndp.nic.ID(), prefix) + } +} + +// prefixInvalidationCallback returns a new on-link prefix invalidation timer +// for prefix that fires after vl. +// +// doNotInvalidate is used to signal the timer when it fires at the same time +// that a prefix's valid lifetime gets refreshed. See +// onLinkPrefixState.doNotInvalidate for more details. +func (ndp *ndpState) prefixInvalidationCallback(prefix tcpip.Subnet, vl time.Duration, doNotInvalidate *bool) *time.Timer { + return time.AfterFunc(vl, func() { + ndp.nic.stack.mu.Lock() + defer ndp.nic.stack.mu.Unlock() + ndp.nic.mu.Lock() + defer ndp.nic.mu.Unlock() + + if *doNotInvalidate { + *doNotInvalidate = false + return + } + + ndp.invalidateOnLinkPrefix(prefix) + }) +} diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 50ce1bbfa..494244368 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -96,6 +96,13 @@ type ndpRouterEvent struct { discovered bool } +type ndpPrefixEvent struct { + nicID tcpip.NICID + prefix tcpip.Subnet + // true if prefix was discovered, false if invalidated. + discovered bool +} + var _ stack.NDPDispatcher = (*ndpDispatcher)(nil) // ndpDispatcher implements NDPDispatcher so tests can know when various NDP @@ -104,6 +111,8 @@ type ndpDispatcher struct { dadC chan ndpDADEvent routerC chan ndpRouterEvent rememberRouter bool + prefixC chan ndpPrefixEvent + rememberPrefix bool routeTable []tcpip.Route } @@ -169,6 +178,54 @@ func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip return rt } +// Implements stack.NDPDispatcher.OnOnLinkPrefixDiscovered. +func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) (bool, []tcpip.Route) { + if n.prefixC != nil { + n.prefixC <- ndpPrefixEvent{ + nicID, + prefix, + true, + } + } + + if !n.rememberPrefix { + return false, nil + } + + rt := append([]tcpip.Route(nil), n.routeTable...) + rt = append(rt, tcpip.Route{ + Destination: prefix, + NIC: nicID, + }) + n.routeTable = rt + return true, rt +} + +// Implements stack.NDPDispatcher.OnOnLinkPrefixInvalidated. +func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) []tcpip.Route { + if n.prefixC != nil { + n.prefixC <- ndpPrefixEvent{ + nicID, + prefix, + false, + } + } + + rt := make([]tcpip.Route, 0) + exclude := tcpip.Route{ + Destination: prefix, + NIC: nicID, + } + + for _, r := range n.routeTable { + if r != exclude { + rt = append(rt, r) + } + } + n.routeTable = rt + return rt +} + // TestDADResolve tests that an address successfully resolves after performing // DAD for various values of DupAddrDetectTransmits and RetransmitTimer. // Included in the subtests is a test to make sure that an invalid @@ -682,16 +739,19 @@ func TestSetNDPConfigurations(t *testing.T) { } } -// raBuf returns a valid NDP Router Advertisement. +// raBufWithOpts returns a valid NDP Router Advertisement with options. // -// Note, raBuf does not populate any of the RA fields other than the +// Note, raBufWithOpts does not populate any of the RA fields other than the // Router Lifetime. -func raBuf(ip tcpip.Address, rl uint16) tcpip.PacketBuffer { - icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize +func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) tcpip.PacketBuffer { + icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + int(optSer.Length()) hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) pkt := header.ICMPv6(hdr.Prepend(icmpSize)) pkt.SetType(header.ICMPv6RouterAdvert) pkt.SetCode(0) + ra := header.NDPRouterAdvert(pkt.NDPPayload()) + opts := ra.Options() + opts.Serialize(optSer) // Populate the Router Lifetime. binary.BigEndian.PutUint16(pkt.NDPPayload()[2:], rl) pkt.SetChecksum(header.ICMPv6Checksum(pkt, ip, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) @@ -708,6 +768,35 @@ func raBuf(ip tcpip.Address, rl uint16) tcpip.PacketBuffer { return tcpip.PacketBuffer{Data: hdr.View().ToVectorisedView()} } +// raBuf returns a valid NDP Router Advertisement. +// +// Note, raBuf does not populate any of the RA fields other than the +// Router Lifetime. +func raBuf(ip tcpip.Address, rl uint16) tcpip.PacketBuffer { + return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{}) +} + +// raBufWithPI returns a valid NDP Router Advertisement with a single Prefix +// Information option. +// +// Note, raBufWithPI does not populate any of the RA fields other than the +// Router Lifetime. +func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink bool, vl uint32) tcpip.PacketBuffer { + flags := uint8(0) + if onLink { + flags |= 128 + } + + buf := [30]byte{} + buf[0] = uint8(prefix.PrefixLen) + buf[1] = flags + binary.BigEndian.PutUint32(buf[2:], vl) + copy(buf[14:], prefix.Address) + return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{ + header.NDPPrefixInformation(buf[:]), + }) +} + // TestNoRouterDiscovery tests that router discovery will not be performed if // configured not to. func TestNoRouterDiscovery(t *testing.T) { @@ -1011,3 +1100,440 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { t.Fatalf("got GetRouteTable = %v, want = %v", got, expectedRt) } } + +// TestNoPrefixDiscovery tests that prefix discovery will not be performed if +// configured not to. +func TestNoPrefixDiscovery(t *testing.T) { + prefix := tcpip.AddressWithPrefix{ + Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"), + PrefixLen: 64, + } + + // Being configured to discover prefixes means handle and + // discover are set to true and forwarding is set to false. + // This tests all possible combinations of the configurations, + // except for the configuration where handle = true, discover = + // true and forwarding = false (the required configuration to do + // prefix discovery) - that will done in other tests. + for i := 0; i < 7; i++ { + handle := i&1 != 0 + discover := i&2 != 0 + forwarding := i&4 == 0 + + t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverOnLinkPrefixes(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) { + ndpDisp := ndpDispatcher{ + prefixC: make(chan ndpPrefixEvent, 10), + } + e := channel.New(10, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: handle, + DiscoverOnLinkPrefixes: discover, + }, + NDPDisp: &ndpDisp, + }) + s.SetForwarding(forwarding) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Rx an RA with prefix with non-zero lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, 10)) + + select { + case <-ndpDisp.prefixC: + t.Fatal("unexpectedly discovered a prefix when configured not to") + case <-time.After(defaultTimeout): + } + }) + } +} + +// TestPrefixDiscoveryDispatcherNoRemember tests that the stack does not +// remember a discovered on-link prefix when the dispatcher asks it not to. +func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { + prefix := tcpip.AddressWithPrefix{ + Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"), + PrefixLen: 64, + } + subnet := prefix.Subnet() + + ndpDisp := ndpDispatcher{ + prefixC: make(chan ndpPrefixEvent, 10), + } + e := channel.New(10, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: false, + DiscoverOnLinkPrefixes: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + routeTable := []tcpip.Route{ + { + header.IPv6EmptySubnet, + llAddr3, + 1, + }, + } + s.SetRouteTable(routeTable) + + // Rx an RA with prefix with a short lifetime. + const lifetime = 1 + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, lifetime)) + select { + case r := <-ndpDisp.prefixC: + if r.nicID != 1 { + t.Fatalf("got r.nicID = %d, want = 1", r.nicID) + } + if r.prefix != subnet { + t.Fatalf("got r.prefix = %s, want = %s", r.prefix, subnet) + } + if !r.discovered { + t.Fatal("got r.discovered = false, want = true") + } + case <-time.After(defaultTimeout): + t.Fatal("timeout waiting for prefix discovery event") + } + + // Original route table should not have been modified. + if got := s.GetRouteTable(); !cmp.Equal(got, routeTable) { + t.Fatalf("got GetRouteTable = %v, want = %v", got, routeTable) + } + + // Wait for the normal invalidation time plus some buffer to + // make sure we do not actually receive any invalidation events as + // we should not have remembered the prefix in the first place. + select { + case <-ndpDisp.prefixC: + t.Fatal("should not have received any prefix events") + case <-time.After(lifetime*time.Second + defaultTimeout): + } + + // Original route table should not have been modified. + if got := s.GetRouteTable(); !cmp.Equal(got, routeTable) { + t.Fatalf("got GetRouteTable = %v, want = %v", got, routeTable) + } +} + +func TestPrefixDiscovery(t *testing.T) { + prefix1 := tcpip.AddressWithPrefix{ + Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"), + PrefixLen: 64, + } + prefix2 := tcpip.AddressWithPrefix{ + Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x09\x00\x00\x00\x00\x00\x00\x00\x00"), + PrefixLen: 64, + } + prefix3 := tcpip.AddressWithPrefix{ + Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x09\x0a\x00\x00\x00\x00\x00\x00\x00"), + PrefixLen: 72, + } + subnet1 := prefix1.Subnet() + subnet2 := prefix2.Subnet() + subnet3 := prefix3.Subnet() + + ndpDisp := ndpDispatcher{ + prefixC: make(chan ndpPrefixEvent, 10), + rememberPrefix: true, + } + e := channel.New(10, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + DiscoverOnLinkPrefixes: true, + }, + NDPDisp: &ndpDisp, + }) + + waitForEvent := func(subnet tcpip.Subnet, discovered bool, timeout time.Duration) { + t.Helper() + + select { + case r := <-ndpDisp.prefixC: + if r.nicID != 1 { + t.Fatalf("got r.nicID = %d, want = 1", r.nicID) + } + if r.prefix != subnet { + t.Fatalf("got r.prefix = %s, want = %s", r.prefix, subnet) + } + if r.discovered != discovered { + t.Fatalf("got r.discovered = %t, want = %t", r.discovered, discovered) + } + case <-time.After(timeout): + t.Fatal("timeout waiting for prefix discovery event") + } + } + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with zero valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, 0)) + select { + case <-ndpDisp.prefixC: + t.Fatal("unexpectedly discovered a prefix with 0 lifetime") + case <-time.After(defaultTimeout): + } + + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, 100)) + waitForEvent(subnet1, true, defaultTimeout) + + // Should have added a device route for subnet1 through the nic. + if got, want := s.GetRouteTable(), []tcpip.Route{{subnet1, tcpip.Address([]byte(nil)), 1}}; !cmp.Equal(got, want) { + t.Fatalf("got GetRouteTable = %v, want = %v", got, want) + } + + // Receive an RA with prefix2 in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, 100)) + waitForEvent(subnet2, true, defaultTimeout) + + // Should have added a device route for subnet2 through the nic. + if got, want := s.GetRouteTable(), []tcpip.Route{{subnet1, tcpip.Address([]byte(nil)), 1}, {subnet2, tcpip.Address([]byte(nil)), 1}}; !cmp.Equal(got, want) { + t.Fatalf("got GetRouteTable = %v, want = %v", got, want) + } + + // Receive an RA with prefix3 in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, 100)) + waitForEvent(subnet3, true, defaultTimeout) + + // Should have added a device route for subnet3 through the nic. + if got, want := s.GetRouteTable(), []tcpip.Route{{subnet1, tcpip.Address([]byte(nil)), 1}, {subnet2, tcpip.Address([]byte(nil)), 1}, {subnet3, tcpip.Address([]byte(nil)), 1}}; !cmp.Equal(got, want) { + t.Fatalf("got GetRouteTable = %v, want = %v", got, want) + } + + // Receive an RA with prefix1 in a PI with lifetime = 0. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, 0)) + waitForEvent(subnet1, false, defaultTimeout) + + // Should have removed the device route for subnet1 through the nic. + if got, want := s.GetRouteTable(), []tcpip.Route{{subnet2, tcpip.Address([]byte(nil)), 1}, {subnet3, tcpip.Address([]byte(nil)), 1}}; !cmp.Equal(got, want) { + t.Fatalf("got GetRouteTable = %v, want = %v", got, want) + } + + // Receive an RA with prefix2 in a PI with lesser lifetime. + lifetime := uint32(2) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, lifetime)) + select { + case <-ndpDisp.prefixC: + t.Fatal("unexpectedly received prefix event when updating lifetime") + case <-time.After(defaultTimeout): + } + + // Should not have updated route table. + if got, want := s.GetRouteTable(), []tcpip.Route{{subnet2, tcpip.Address([]byte(nil)), 1}, {subnet3, tcpip.Address([]byte(nil)), 1}}; !cmp.Equal(got, want) { + t.Fatalf("got GetRouteTable = %v, want = %v", got, want) + } + + // Wait for prefix2's most recent invalidation timer plus some buffer to + // expire. + waitForEvent(subnet2, false, time.Duration(lifetime)*time.Second+defaultTimeout) + + // Should have removed the device route for subnet2 through the nic. + if got, want := s.GetRouteTable(), []tcpip.Route{{subnet3, tcpip.Address([]byte(nil)), 1}}; !cmp.Equal(got, want) { + t.Fatalf("got GetRouteTable = %v, want = %v", got, want) + } + + // Receive RA to invalidate prefix3. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, 0)) + waitForEvent(subnet3, false, defaultTimeout) + + // Should not have any routes. + if got := len(s.GetRouteTable()); got != 0 { + t.Fatalf("got len(s.GetRouteTable()) = %d, want = 0", got) + } +} + +func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { + // Update the infinite lifetime value to a smaller value so we can test + // that when we receive a PI with such a lifetime value, we do not + // invalidate the prefix. + const testInfiniteLifetimeSeconds = 2 + const testInfiniteLifetime = testInfiniteLifetimeSeconds * time.Second + saved := header.NDPPrefixInformationInfiniteLifetime + header.NDPPrefixInformationInfiniteLifetime = testInfiniteLifetime + defer func() { + header.NDPPrefixInformationInfiniteLifetime = saved + }() + + prefix := tcpip.AddressWithPrefix{ + Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"), + PrefixLen: 64, + } + subnet := prefix.Subnet() + + ndpDisp := ndpDispatcher{ + prefixC: make(chan ndpPrefixEvent, 10), + rememberPrefix: true, + } + e := channel.New(10, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + DiscoverOnLinkPrefixes: true, + }, + NDPDisp: &ndpDisp, + }) + + waitForEvent := func(discovered bool, timeout time.Duration) { + t.Helper() + + select { + case r := <-ndpDisp.prefixC: + if r.nicID != 1 { + t.Errorf("got r.nicID = %d, want = 1", r.nicID) + } + if r.prefix != subnet { + t.Errorf("got r.prefix = %s, want = %s", r.prefix, subnet) + } + if r.discovered != discovered { + t.Errorf("got r.discovered = %t, want = %t", r.discovered, discovered) + } + case <-time.After(timeout): + t.Fatal("timeout waiting for prefix discovery event") + } + } + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Receive an RA with prefix in an NDP Prefix Information option (PI) + // with infinite valid lifetime which should not get invalidated. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, testInfiniteLifetimeSeconds)) + waitForEvent(true, defaultTimeout) + select { + case <-ndpDisp.prefixC: + t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") + case <-time.After(testInfiniteLifetime + defaultTimeout): + } + + // Receive an RA with finite lifetime. + // The prefix should get invalidated after 1s. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, testInfiniteLifetimeSeconds-1)) + waitForEvent(false, testInfiniteLifetime) + + // Receive an RA with finite lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, testInfiniteLifetimeSeconds-1)) + waitForEvent(true, defaultTimeout) + + // Receive an RA with prefix with an infinite lifetime. + // The prefix should not be invalidated. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, testInfiniteLifetimeSeconds)) + select { + case <-ndpDisp.prefixC: + t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") + case <-time.After(testInfiniteLifetime + defaultTimeout): + } + + // Receive an RA with a prefix with a lifetime value greater than the + // set infinite lifetime value. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, testInfiniteLifetimeSeconds+1)) + select { + case <-ndpDisp.prefixC: + t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") + case <-time.After((testInfiniteLifetimeSeconds+1)*time.Second + defaultTimeout): + } + + // Receive an RA with 0 lifetime. + // The prefix should get invalidated. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, 0)) + waitForEvent(false, defaultTimeout) +} + +// TestPrefixDiscoveryMaxRouters tests that only +// stack.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered. +func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { + ndpDisp := ndpDispatcher{ + prefixC: make(chan ndpPrefixEvent, stack.MaxDiscoveredOnLinkPrefixes+3), + rememberPrefix: true, + } + e := channel.New(10, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: false, + DiscoverOnLinkPrefixes: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + optSer := make(header.NDPOptionsSerializer, stack.MaxDiscoveredOnLinkPrefixes+2) + expectedRt := [stack.MaxDiscoveredOnLinkPrefixes]tcpip.Route{} + prefixes := [stack.MaxDiscoveredOnLinkPrefixes + 2]tcpip.Subnet{} + + // Receive an RA with 2 more than the max number of discovered on-link + // prefixes. + for i := 0; i < stack.MaxDiscoveredOnLinkPrefixes+2; i++ { + prefixAddr := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0} + prefixAddr[7] = byte(i) + prefix := tcpip.AddressWithPrefix{ + Address: tcpip.Address(prefixAddr[:]), + PrefixLen: 64, + } + prefixes[i] = prefix.Subnet() + buf := [30]byte{} + buf[0] = uint8(prefix.PrefixLen) + buf[1] = 128 + binary.BigEndian.PutUint32(buf[2:], 10) + copy(buf[14:], prefix.Address) + + optSer[i] = header.NDPPrefixInformation(buf[:]) + + if i < stack.MaxDiscoveredOnLinkPrefixes { + expectedRt[i] = tcpip.Route{prefixes[i], tcpip.Address([]byte(nil)), 1} + } + } + + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, optSer)) + for i := 0; i < stack.MaxDiscoveredOnLinkPrefixes+2; i++ { + if i < stack.MaxDiscoveredOnLinkPrefixes { + select { + case r := <-ndpDisp.prefixC: + if r.nicID != 1 { + t.Fatalf("got r.nicID = %d, want = 1", r.nicID) + } + if r.prefix != prefixes[i] { + t.Fatalf("got r.prefix = %s, want = %s", r.prefix, prefixes[i]) + } + if !r.discovered { + t.Fatal("got r.discovered = false, want = true") + } + case <-time.After(defaultTimeout): + t.Fatal("timeout waiting for prefix discovery event") + } + } else { + select { + case <-ndpDisp.prefixC: + t.Fatal("should not have discovered a new prefix after we already discovered the max number of prefixes") + case <-time.After(defaultTimeout): + } + } + } + + // Should only have device routes for the first + // stack.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes. + if got := s.GetRouteTable(); !cmp.Equal(got, expectedRt[:]) { + t.Fatalf("got GetRouteTable = %v, want = %v", got, expectedRt) + } +} diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 28a28ae6e..9ed9e1e7c 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -118,6 +118,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback configs: stack.ndpConfigs, dad: make(map[tcpip.Address]dadState), defaultRouters: make(map[tcpip.Address]defaultRouterState), + onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState), }, } nic.ndp.nic = nic -- cgit v1.2.3 From 3f51bef8cdad5f0555e7c6b05f777769d23aaf77 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Tue, 12 Nov 2019 15:48:34 -0800 Subject: Do not handle TCP packets that include a non-unicast IP address This change drops TCP packets with a non-unicast IP address as the source or destination address as TCP is meant for communication between two endpoints. Test: Make sure that if the source or destination address contains a non-unicast address, no TCP packet is sent in response and the packet is dropped. PiperOrigin-RevId: 280073731 --- pkg/tcpip/stack/transport_demuxer.go | 45 ++++- pkg/tcpip/transport/tcp/tcp_test.go | 204 +++++++++++++++++++++ pkg/tcpip/transport/tcp/testing/context/context.go | 34 +++- 3 files changed, 269 insertions(+), 14 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index cb805522b..67c21be42 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -389,8 +389,8 @@ var loopbackSubnet = func() tcpip.Subnet { }() // deliverPacket attempts to find one or more matching transport endpoints, and -// then, if matches are found, delivers the packet to them. Returns true if it -// found one or more endpoints, false otherwise. +// then, if matches are found, delivers the packet to them. Returns true if +// the packet no longer needs to be handled. func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] if !ok { @@ -400,13 +400,38 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto eps.mu.RLock() // Determine which transport endpoint or endpoints to deliver this packet to. - // If the packet is a broadcast or multicast, then find all matching - // transport endpoints. + // If the packet is a UDP broadcast or multicast, then find all matching + // transport endpoints. If the packet is a TCP packet with a non-unicast + // source or destination address, then do nothing further and instruct + // the caller to do the same. var destEps []*endpointsByNic - if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) { - destEps = d.findAllEndpointsLocked(eps, id) - } else if ep := d.findEndpointLocked(eps, id); ep != nil { - destEps = append(destEps, ep) + switch protocol { + case header.UDPProtocolNumber: + if isMulticastOrBroadcast(id.LocalAddress) { + destEps = d.findAllEndpointsLocked(eps, id) + break + } + + if ep := d.findEndpointLocked(eps, id); ep != nil { + destEps = append(destEps, ep) + } + + case header.TCPProtocolNumber: + if !(isUnicast(r.LocalAddress) && isUnicast(r.RemoteAddress)) { + // TCP can only be used to communicate between a single + // source and a single destination; the addresses must + // be unicast. + eps.mu.RUnlock() + r.Stats().TCP.InvalidSegmentsReceived.Increment() + return true + } + + fallthrough + + default: + if ep := d.findEndpointLocked(eps, id); ep != nil { + destEps = append(destEps, ep) + } } eps.mu.RUnlock() @@ -587,3 +612,7 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN func isMulticastOrBroadcast(addr tcpip.Address) bool { return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) } + +func isUnicast(addr tcpip.Address) bool { + return addr != header.IPv4Any && addr != header.IPv6Any && !isMulticastOrBroadcast(addr) +} diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 84579ce52..b443fe9dc 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -4242,6 +4242,210 @@ func TestListenBacklogFull(t *testing.T) { } } +// TestListenNoAcceptMulticastBroadcastV4 makes sure that TCP segments with a +// non unicast IPv4 address are not accepted. +func TestListenNoAcceptNonUnicastV4(t *testing.T) { + multicastAddr := tcpip.Address("\xe0\x00\x01\x02") + otherMulticastAddr := tcpip.Address("\xe0\x00\x01\x03") + + tests := []struct { + name string + srcAddr tcpip.Address + dstAddr tcpip.Address + }{ + { + "SourceUnspecified", + header.IPv4Any, + context.StackAddr, + }, + { + "SourceBroadcast", + header.IPv4Broadcast, + context.StackAddr, + }, + { + "SourceOurMulticast", + multicastAddr, + context.StackAddr, + }, + { + "SourceOtherMulticast", + otherMulticastAddr, + context.StackAddr, + }, + { + "DestUnspecified", + context.TestAddr, + header.IPv4Any, + }, + { + "DestBroadcast", + context.TestAddr, + header.IPv4Broadcast, + }, + { + "DestOurMulticast", + context.TestAddr, + multicastAddr, + }, + { + "DestOtherMulticast", + context.TestAddr, + otherMulticastAddr, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1) + + if err := c.Stack().JoinGroup(header.IPv4ProtocolNumber, 1, multicastAddr); err != nil { + t.Fatalf("JoinGroup failed: %s", err) + } + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := c.EP.Listen(1); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + irs := seqnum.Value(789) + c.SendPacketWithAddrs(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }, test.srcAddr, test.dstAddr) + c.CheckNoPacket("Should not have received a response") + + // Handle normal packet. + c.SendPacketWithAddrs(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }, context.TestAddr, context.StackAddr) + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), + checker.AckNum(uint32(irs)+1))) + }) + } +} + +// TestListenNoAcceptMulticastBroadcastV6 makes sure that TCP segments with a +// non unicast IPv6 address are not accepted. +func TestListenNoAcceptNonUnicastV6(t *testing.T) { + multicastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01") + otherMulticastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02") + + tests := []struct { + name string + srcAddr tcpip.Address + dstAddr tcpip.Address + }{ + { + "SourceUnspecified", + header.IPv6Any, + context.StackV6Addr, + }, + { + "SourceAllNodes", + header.IPv6AllNodesMulticastAddress, + context.StackV6Addr, + }, + { + "SourceOurMulticast", + multicastAddr, + context.StackV6Addr, + }, + { + "SourceOtherMulticast", + otherMulticastAddr, + context.StackV6Addr, + }, + { + "DestUnspecified", + context.TestV6Addr, + header.IPv6Any, + }, + { + "DestAllNodes", + context.TestV6Addr, + header.IPv6AllNodesMulticastAddress, + }, + { + "DestOurMulticast", + context.TestV6Addr, + multicastAddr, + }, + { + "DestOtherMulticast", + context.TestV6Addr, + otherMulticastAddr, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(true) + + if err := c.Stack().JoinGroup(header.IPv6ProtocolNumber, 1, multicastAddr); err != nil { + t.Fatalf("JoinGroup failed: %s", err) + } + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := c.EP.Listen(1); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + irs := seqnum.Value(789) + c.SendV6PacketWithAddrs(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }, test.srcAddr, test.dstAddr) + c.CheckNoPacket("Should not have received a response") + + // Handle normal packet. + c.SendV6PacketWithAddrs(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }, context.TestV6Addr, context.StackV6Addr) + checker.IPv6(t, c.GetV6Packet(), + checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), + checker.AckNum(uint32(irs)+1))) + }) + } +} + func TestListenSynRcvdQueueFull(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 4854e719d..0a733fa94 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -309,6 +309,12 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt // BuildSegment builds a TCP segment based on the given Headers and payload. func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView { + return c.BuildSegmentWithAddrs(payload, h, TestAddr, StackAddr) +} + +// BuildSegmentWithAddrs builds a TCP segment based on the given Headers, +// payload and source and destination IPv4 addresses. +func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) buffer.VectorisedView { // Allocate a buffer for data and headers. buf := buffer.NewView(header.TCPMinimumSize + header.IPv4MinimumSize + len(h.TCPOpts) + len(payload)) copy(buf[len(buf)-len(payload):], payload) @@ -321,8 +327,8 @@ func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView TotalLength: uint16(len(buf)), TTL: 65, Protocol: uint8(tcp.ProtocolNumber), - SrcAddr: TestAddr, - DstAddr: StackAddr, + SrcAddr: src, + DstAddr: dst, }) ip.SetChecksum(^ip.CalculateChecksum()) @@ -339,7 +345,7 @@ func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView }) // Calculate the TCP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, TestAddr, StackAddr, uint16(len(t))) + xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t))) // Calculate the TCP checksum and set it. xsum = header.Checksum(payload, xsum) @@ -365,6 +371,15 @@ func (c *Context) SendPacket(payload []byte, h *Headers) { }) } +// SendPacketWithAddrs builds and sends a TCP segment(with the provided payload +// & TCPheaders) in an IPv4 packet via the link layer endpoint using the +// provided source and destination IPv4 addresses. +func (c *Context) SendPacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) { + c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ + Data: c.BuildSegmentWithAddrs(payload, h, src, dst), + }) +} + // SendAck sends an ACK packet. func (c *Context) SendAck(seq seqnum.Value, bytesReceived int) { c.SendAckWithSACK(seq, bytesReceived, nil) @@ -490,6 +505,13 @@ func (c *Context) GetV6Packet() []byte { // SendV6Packet builds and sends an IPv6 Packet via the link layer endpoint of // the context. func (c *Context) SendV6Packet(payload []byte, h *Headers) { + c.SendV6PacketWithAddrs(payload, h, TestV6Addr, StackV6Addr) +} + +// SendV6PacketWithAddrs builds and sends an IPv6 Packet via the link layer +// endpoint of the context using the provided source and destination IPv6 +// addresses. +func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.TCPMinimumSize + header.IPv6MinimumSize + len(payload)) copy(buf[len(buf)-len(payload):], payload) @@ -500,8 +522,8 @@ func (c *Context) SendV6Packet(payload []byte, h *Headers) { PayloadLength: uint16(header.TCPMinimumSize + len(payload)), NextHeader: uint8(tcp.ProtocolNumber), HopLimit: 65, - SrcAddr: TestV6Addr, - DstAddr: StackV6Addr, + SrcAddr: src, + DstAddr: dst, }) // Initialize the TCP header. @@ -517,7 +539,7 @@ func (c *Context) SendV6Packet(payload []byte, h *Headers) { }) // Calculate the TCP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, TestV6Addr, StackV6Addr, uint16(len(t))) + xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t))) // Calculate the TCP checksum and set it. xsum = header.Checksum(payload, xsum) -- cgit v1.2.3 From 3f7d9370909a598cf83dfa07a1e87545a66e182f Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Thu, 14 Nov 2019 10:14:07 -0800 Subject: Use PacketBuffers for outgoing packets. PiperOrigin-RevId: 280455453 --- pkg/tcpip/buffer/prependable.go | 6 ++ pkg/tcpip/link/channel/channel.go | 42 ++++---- pkg/tcpip/link/fdbased/endpoint.go | 21 ++-- pkg/tcpip/link/fdbased/endpoint_test.go | 10 +- pkg/tcpip/link/loopback/loopback.go | 30 +++--- pkg/tcpip/link/muxed/injectable.go | 6 +- pkg/tcpip/link/muxed/injectable_test.go | 12 ++- pkg/tcpip/link/sharedmem/sharedmem.go | 13 +-- pkg/tcpip/link/sharedmem/sharedmem_test.go | 59 ++++++++--- pkg/tcpip/link/sniffer/sniffer.go | 33 +++--- pkg/tcpip/link/waitable/waitable.go | 8 +- pkg/tcpip/link/waitable/waitable_test.go | 10 +- pkg/tcpip/network/arp/arp.go | 26 +++-- pkg/tcpip/network/arp/arp_test.go | 10 +- pkg/tcpip/network/ip_test.go | 18 ++-- pkg/tcpip/network/ipv4/icmp.go | 6 +- pkg/tcpip/network/ipv4/ipv4.go | 116 ++++++++++++--------- pkg/tcpip/network/ipv4/ipv4_test.go | 41 ++++---- pkg/tcpip/network/ipv6/icmp.go | 13 ++- pkg/tcpip/network/ipv6/icmp_test.go | 16 +-- pkg/tcpip/network/ipv6/ipv6.go | 16 +-- pkg/tcpip/packet_buffer.go | 21 ++-- pkg/tcpip/packet_buffer_state.go | 1 + pkg/tcpip/stack/ndp.go | 4 +- pkg/tcpip/stack/ndp_test.go | 2 +- pkg/tcpip/stack/nic.go | 6 +- pkg/tcpip/stack/registration.go | 17 +-- pkg/tcpip/stack/route.go | 12 +-- pkg/tcpip/stack/stack.go | 6 +- pkg/tcpip/stack/stack_test.go | 22 ++-- pkg/tcpip/stack/transport_test.go | 9 +- pkg/tcpip/transport/icmp/endpoint.go | 12 ++- pkg/tcpip/transport/raw/endpoint.go | 9 +- pkg/tcpip/transport/tcp/connect.go | 5 +- pkg/tcpip/transport/tcp/testing/context/context.go | 18 ++-- pkg/tcpip/transport/udp/endpoint.go | 5 +- pkg/tcpip/transport/udp/protocol.go | 10 +- pkg/tcpip/transport/udp/udp_test.go | 14 +-- 38 files changed, 406 insertions(+), 279 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/buffer/prependable.go b/pkg/tcpip/buffer/prependable.go index 48a2a2713..ba21f4eca 100644 --- a/pkg/tcpip/buffer/prependable.go +++ b/pkg/tcpip/buffer/prependable.go @@ -77,3 +77,9 @@ func (p *Prependable) Prepend(size int) []byte { p.usedIdx -= size return p.View()[:size:size] } + +// DeepCopy copies p and the bytes backing it. +func (p Prependable) DeepCopy() Prependable { + p.buf = append(View(nil), p.buf...) + return p +} diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index 22eefb564..9fe8e9f9d 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -25,10 +25,9 @@ import ( // PacketInfo holds all the information about an outbound packet. type PacketInfo struct { - Header buffer.View - Payload buffer.View - Proto tcpip.NetworkProtocolNumber - GSO *stack.GSO + Pkt tcpip.PacketBuffer + Proto tcpip.NetworkProtocolNumber + GSO *stack.GSO } // Endpoint is link layer endpoint that stores outbound packets in a channel @@ -118,12 +117,11 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { } // WritePacket stores outbound packets into the channel. -func (e *Endpoint) WritePacket(_ *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (e *Endpoint) WritePacket(_ *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { p := PacketInfo{ - Header: hdr.View(), - Proto: protocol, - Payload: payload.ToView(), - GSO: gso, + Pkt: pkt, + Proto: protocol, + GSO: gso, } select { @@ -139,15 +137,16 @@ func (e *Endpoint) WritePackets(_ *stack.Route, gso *stack.GSO, hdrs []stack.Pac payloadView := payload.ToView() n := 0 packetLoop: - for i := range hdrs { - hdr := &hdrs[i].Hdr - off := hdrs[i].Off - size := hdrs[i].Size + for _, hdr := range hdrs { + off := hdr.Off + size := hdr.Size p := PacketInfo{ - Header: hdr.View(), - Proto: protocol, - Payload: buffer.NewViewFromBytes(payloadView[off : off+size]), - GSO: gso, + Pkt: tcpip.PacketBuffer{ + Header: hdr.Hdr, + Data: buffer.NewViewFromBytes(payloadView[off : off+size]).ToVectorisedView(), + }, + Proto: protocol, + GSO: gso, } select { @@ -162,12 +161,11 @@ packetLoop: } // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *Endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error { +func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { p := PacketInfo{ - Header: packet.ToView(), - Proto: 0, - Payload: buffer.View{}, - GSO: nil, + Pkt: tcpip.PacketBuffer{Data: vv}, + Proto: 0, + GSO: nil, } select { diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index edef7db26..98109c5dc 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -386,10 +386,11 @@ const ( // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { if e.hdrSize > 0 { // Add ethernet header if needed. - eth := header.Ethernet(hdr.Prepend(header.EthernetMinimumSize)) + eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) + pkt.LinkHeader = buffer.View(eth) ethHdr := &header.EthernetFields{ DstAddr: r.RemoteLinkAddress, Type: protocol, @@ -408,13 +409,13 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen vnetHdr := virtioNetHdr{} vnetHdrBuf := vnetHdrToByteSlice(&vnetHdr) if gso != nil { - vnetHdr.hdrLen = uint16(hdr.UsedLength()) + vnetHdr.hdrLen = uint16(pkt.Header.UsedLength()) if gso.NeedsCsum { vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM vnetHdr.csumStart = header.EthernetMinimumSize + gso.L3HdrLen vnetHdr.csumOffset = gso.CsumOffset } - if gso.Type != stack.GSONone && uint16(payload.Size()) > gso.MSS { + if gso.Type != stack.GSONone && uint16(pkt.Data.Size()) > gso.MSS { switch gso.Type { case stack.GSOTCPv4: vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4 @@ -427,14 +428,14 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen } } - return rawfile.NonBlockingWrite3(e.fds[0], vnetHdrBuf, hdr.View(), payload.ToView()) + return rawfile.NonBlockingWrite3(e.fds[0], vnetHdrBuf, pkt.Header.View(), pkt.Data.ToView()) } - if payload.Size() == 0 { - return rawfile.NonBlockingWrite(e.fds[0], hdr.View()) + if pkt.Data.Size() == 0 { + return rawfile.NonBlockingWrite(e.fds[0], pkt.Header.View()) } - return rawfile.NonBlockingWrite3(e.fds[0], hdr.View(), payload.ToView(), nil) + return rawfile.NonBlockingWrite3(e.fds[0], pkt.Header.View(), pkt.Data.ToView(), nil) } // WritePackets writes outbound packets to the file descriptor. If it is not @@ -555,8 +556,8 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.Pac } // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error { - return rawfile.NonBlockingWrite(e.fds[0], packet.ToView()) +func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + return rawfile.NonBlockingWrite(e.fds[0], vv.ToView()) } // InjectOutobund implements stack.InjectableEndpoint.InjectOutbound. diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index 7e08e033b..2066987eb 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -168,7 +168,10 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32) { L3HdrLen: header.IPv4MaximumHeaderSize, } } - if err := c.ep.WritePacket(r, gso, hdr, payload.ToVectorisedView(), proto); err != nil { + if err := c.ep.WritePacket(r, gso, proto, tcpip.PacketBuffer{ + Header: hdr, + Data: payload.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -258,7 +261,10 @@ func TestPreserveSrcAddress(t *testing.T) { // WritePacket panics given a prependable with anything less than // the minimum size of the ethernet header. hdr := buffer.NewPrependable(header.EthernetMinimumSize) - if err := c.ep.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, proto); err != nil { + if err := c.ep.WritePacket(r, nil /* gso */, proto, tcpip.PacketBuffer{ + Header: hdr, + Data: buffer.VectorisedView{}, + }); err != nil { t.Fatalf("WritePacket failed: %v", err) } diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index bc5d8a2f3..563a67188 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -76,16 +76,16 @@ func (*endpoint) Wait() {} // WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound // packets to the network-layer dispatcher. -func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { - views := make([]buffer.View, 1, 1+len(payload.Views())) - views[0] = hdr.View() - views = append(views, payload.Views()...) - - // Because we're immediately turning around and writing the packet back to the - // rx path, we intentionally don't preserve the remote and local link - // addresses from the stack.Route we're passed. +func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { + views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) + views[0] = pkt.Header.View() + views = append(views, pkt.Data.Views()...) + + // Because we're immediately turning around and writing the packet back + // to the rx path, we intentionally don't preserve the remote and local + // link addresses from the stack.Route we're passed. e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, tcpip.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+payload.Size(), views), + Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), }) return nil @@ -97,17 +97,17 @@ func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, hdrs []stack.Packe } // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error { +func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { // Reject the packet if it's shorter than an ethernet header. - if packet.Size() < header.EthernetMinimumSize { + if vv.Size() < header.EthernetMinimumSize { return tcpip.ErrBadAddress } - // There should be an ethernet header at the beginning of packet. - linkHeader := header.Ethernet(packet.First()[:header.EthernetMinimumSize]) - packet.TrimFront(len(linkHeader)) + // There should be an ethernet header at the beginning of vv. + linkHeader := header.Ethernet(vv.First()[:header.EthernetMinimumSize]) + vv.TrimFront(len(linkHeader)) e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), tcpip.PacketBuffer{ - Data: packet, + Data: vv, LinkHeader: buffer.View(linkHeader), }) diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index 9a8e8ebfe..55ed2a28e 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -98,15 +98,15 @@ func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs [ // WritePacket writes outbound packets to the appropriate LinkInjectableEndpoint // based on the RemoteAddress. HandleLocal only works if r.RemoteAddress has a // route registered in this endpoint. -func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { if endpoint, ok := m.routes[r.RemoteAddress]; ok { - return endpoint.WritePacket(r, gso, hdr, payload, protocol) + return endpoint.WritePacket(r, gso, protocol, pkt) } return tcpip.ErrNoRoute } // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (m *InjectableEndpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error { +func (m *InjectableEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { // WriteRawPacket doesn't get a route or network address, so there's // nowhere to write this. return tcpip.ErrNoRoute diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go index 9cd300af8..63b249837 100644 --- a/pkg/tcpip/link/muxed/injectable_test.go +++ b/pkg/tcpip/link/muxed/injectable_test.go @@ -50,8 +50,10 @@ func TestInjectableEndpointDispatch(t *testing.T) { hdr.Prepend(1)[0] = 0xFA packetRoute := stack.Route{RemoteAddress: dstIP} - endpoint.WritePacket(&packetRoute, nil /* gso */, hdr, - buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(), ipv4.ProtocolNumber) + endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(), + }) buf := make([]byte, 6500) bytesRead, err := sock.Read(buf) @@ -68,8 +70,10 @@ func TestInjectableEndpointDispatchHdrOnly(t *testing.T) { hdr := buffer.NewPrependable(1) hdr.Prepend(1)[0] = 0xFA packetRoute := stack.Route{RemoteAddress: dstIP} - endpoint.WritePacket(&packetRoute, nil /* gso */, hdr, - buffer.NewView(0).ToVectorisedView(), ipv4.ProtocolNumber) + endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buffer.NewView(0).ToVectorisedView(), + }) buf := make([]byte, 6500) bytesRead, err := sock.Read(buf) if err != nil { diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 2bace5298..88947a03a 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -185,9 +185,10 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress { // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { // Add the ethernet header here. - eth := header.Ethernet(hdr.Prepend(header.EthernetMinimumSize)) + eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) + pkt.LinkHeader = buffer.View(eth) ethHdr := &header.EthernetFields{ DstAddr: r.RemoteLinkAddress, Type: protocol, @@ -199,10 +200,10 @@ func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.Prependa } eth.Encode(ethHdr) - v := payload.ToView() + v := pkt.Data.ToView() // Transmit the packet. e.mu.Lock() - ok := e.tx.transmit(hdr.View(), v) + ok := e.tx.transmit(pkt.Header.View(), v) e.mu.Unlock() if !ok { @@ -218,8 +219,8 @@ func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, hdrs []stack.Packe } // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error { - v := packet.ToView() +func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + v := vv.ToView() // Transmit the packet. e.mu.Lock() ok := e.tx.transmit(v, buffer.View{}) diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 199406886..89603c48f 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -273,7 +273,10 @@ func TestSimpleSend(t *testing.T) { randomFill(buf) proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), proto); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, proto, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -342,7 +345,9 @@ func TestPreserveSrcAddressInSend(t *testing.T) { hdr := buffer.NewPrependable(header.EthernetMinimumSize) proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buffer.VectorisedView{}, proto); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, proto, tcpip.PacketBuffer{ + Header: hdr, + }); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -396,7 +401,10 @@ func TestFillTxQueue(t *testing.T) { for i := queuePipeSize / 40; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -411,7 +419,10 @@ func TestFillTxQueue(t *testing.T) { // Next attempt to write must fail. hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != want { + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } } @@ -436,7 +447,10 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { // Send two packets so that the id slice has at least two slots. for i := 2; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } } @@ -456,7 +470,10 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { ids := make(map[uint64]struct{}) for i := queuePipeSize / 40; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -471,7 +488,10 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { // Next attempt to write must fail. hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != want { + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } } @@ -494,7 +514,10 @@ func TestFillTxMemory(t *testing.T) { ids := make(map[uint64]struct{}) for i := queueDataSize / bufferSize; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -510,7 +533,10 @@ func TestFillTxMemory(t *testing.T) { // Next attempt to write must fail. hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber) + err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }) if want := tcpip.ErrWouldBlock; err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } @@ -535,7 +561,10 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { // until there is only one buffer left. for i := queueDataSize/bufferSize - 1; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -548,7 +577,10 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) uu := buffer.NewView(bufferSize).ToVectorisedView() - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, hdr, uu, header.IPv4ProtocolNumber); err != want { + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: uu, + }); err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } } @@ -556,7 +588,10 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { // Attempt to write the one-buffer packet again. It must succeed. { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + Data: buf.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } } diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index d71a03cd2..122680e10 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -193,19 +193,19 @@ func (e *endpoint) GSOMaxSize() uint32 { return 0 } -func (e *endpoint) dumpPacket(gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) { +func (e *endpoint) dumpPacket(gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil { - logPacket("send", protocol, hdr.View(), gso) + logPacket("send", protocol, pkt.Header.View(), gso) } if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 { - hdrBuf := hdr.View() - length := len(hdrBuf) + payload.Size() + hdrBuf := pkt.Header.View() + length := len(hdrBuf) + pkt.Data.Size() if length > int(e.maxPCAPLen) { length = int(e.maxPCAPLen) } buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length)) - if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(len(hdrBuf)+payload.Size()))); err != nil { + if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(len(hdrBuf)+pkt.Data.Size()))); err != nil { panic(err) } if len(hdrBuf) > length { @@ -215,7 +215,7 @@ func (e *endpoint) dumpPacket(gso *stack.GSO, hdr buffer.Prependable, payload bu panic(err) } length -= len(hdrBuf) - logVectorisedView(payload, length, buf) + logVectorisedView(pkt.Data, length, buf) if _, err := e.file.Write(buf.Bytes()); err != nil { panic(err) } @@ -225,9 +225,9 @@ func (e *endpoint) dumpPacket(gso *stack.GSO, hdr buffer.Prependable, payload bu // WritePacket implements the stack.LinkEndpoint interface. It is called by // higher-level protocols to write packets; it just logs the packet and // forwards the request to the lower endpoint. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { - e.dumpPacket(gso, hdr, payload, protocol) - return e.lower.WritePacket(r, gso, hdr, payload, protocol) +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { + e.dumpPacket(gso, protocol, pkt) + return e.lower.WritePacket(r, gso, protocol, pkt) } // WritePackets implements the stack.LinkEndpoint interface. It is called by @@ -236,32 +236,35 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { view := payload.ToView() for _, d := range hdrs { - e.dumpPacket(gso, d.Hdr, buffer.NewVectorisedView(d.Size, []buffer.View{view[d.Off:][:d.Size]}), protocol) + e.dumpPacket(gso, protocol, tcpip.PacketBuffer{ + Header: d.Hdr, + Data: view[d.Off:][:d.Size].ToVectorisedView(), + }) } return e.lower.WritePackets(r, gso, hdrs, payload, protocol) } // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error { +func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil { logPacket("send", 0, buffer.View("[raw packet, no header available]"), nil /* gso */) } if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 { - length := packet.Size() + length := vv.Size() if length > int(e.maxPCAPLen) { length = int(e.maxPCAPLen) } buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length)) - if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(packet.Size()))); err != nil { + if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(vv.Size()))); err != nil { panic(err) } - logVectorisedView(packet, length, buf) + logVectorisedView(vv, length, buf) if _, err := e.file.Write(buf.Bytes()); err != nil { panic(err) } } - return e.lower.WriteRawPacket(packet) + return e.lower.WriteRawPacket(vv) } func logVectorisedView(vv buffer.VectorisedView, length int, buf *bytes.Buffer) { diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index b440970e0..12e7c1932 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -99,12 +99,12 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { // WritePacket implements stack.LinkEndpoint.WritePacket. It is called by // higher-level protocols to write packets. It only forwards packets to the // lower endpoint if Wait or WaitWrite haven't been called. -func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { if !e.writeGate.Enter() { return nil } - err := e.lower.WritePacket(r, gso, hdr, payload, protocol) + err := e.lower.WritePacket(r, gso, protocol, pkt) e.writeGate.Leave() return err } @@ -123,12 +123,12 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.Pac } // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *Endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error { +func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { if !e.writeGate.Enter() { return nil } - err := e.lower.WriteRawPacket(packet) + err := e.lower.WriteRawPacket(vv) e.writeGate.Leave() return err } diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index df2e70e54..0fc0c2ebe 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -65,7 +65,7 @@ func (e *countedEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr } -func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { e.writeCount++ return nil } @@ -76,7 +76,7 @@ func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, hdrs []stac return len(hdrs), nil } -func (e *countedEndpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error { +func (e *countedEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { e.writeCount++ return nil } @@ -89,21 +89,21 @@ func TestWaitWrite(t *testing.T) { wep := New(ep) // Write and check that it goes through. - wep.WritePacket(nil, nil /* gso */, buffer.Prependable{}, buffer.VectorisedView{}, 0) + wep.WritePacket(nil, nil /* gso */, 0, tcpip.PacketBuffer{}) if want := 1; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } // Wait on dispatches, then try to write. It must go through. wep.WaitDispatch() - wep.WritePacket(nil, nil /* gso */, buffer.Prependable{}, buffer.VectorisedView{}, 0) + wep.WritePacket(nil, nil /* gso */, 0, tcpip.PacketBuffer{}) if want := 2; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } // Wait on writes, then try to write. It must not go through. wep.WaitWrite() - wep.WritePacket(nil, nil /* gso */, buffer.Prependable{}, buffer.VectorisedView{}, 0) + wep.WritePacket(nil, nil /* gso */, 0, tcpip.PacketBuffer{}) if want := 2; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 0ee509ebe..30aec9ba7 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -79,7 +79,7 @@ func (e *endpoint) MaxHeaderLength() uint16 { func (e *endpoint) Close() {} -func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, buffer.Prependable, buffer.VectorisedView, stack.NetworkHeaderParams, stack.PacketLooping) *tcpip.Error { +func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, stack.PacketLooping, tcpip.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } @@ -88,7 +88,7 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []stack.PacketDescript return 0, tcpip.ErrNotSupported } -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } @@ -106,14 +106,16 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { return // we have no useful answer, ignore the request } hdr := buffer.NewPrependable(int(e.linkEP.MaxHeaderLength()) + header.ARPSize) - pkt := header.ARP(hdr.Prepend(header.ARPSize)) - pkt.SetIPv4OverEthernet() - pkt.SetOp(header.ARPReply) - copy(pkt.HardwareAddressSender(), r.LocalLinkAddress[:]) - copy(pkt.ProtocolAddressSender(), h.ProtocolAddressTarget()) - copy(pkt.HardwareAddressTarget(), h.HardwareAddressSender()) - copy(pkt.ProtocolAddressTarget(), h.ProtocolAddressSender()) - e.linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber) + packet := header.ARP(hdr.Prepend(header.ARPSize)) + packet.SetIPv4OverEthernet() + packet.SetOp(header.ARPReply) + copy(packet.HardwareAddressSender(), r.LocalLinkAddress[:]) + copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget()) + copy(packet.HardwareAddressTarget(), h.HardwareAddressSender()) + copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender()) + e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + }) fallthrough // also fill the cache from requests case header.ARPReply: addr := tcpip.Address(h.ProtocolAddressSender()) @@ -165,7 +167,9 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. copy(h.ProtocolAddressSender(), localAddr) copy(h.ProtocolAddressTarget(), addr) - return linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber) + return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + }) } // ResolveStaticAddress implements stack.LinkAddressResolver. diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 47098bfdc..8e6048a21 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -110,13 +110,13 @@ func TestDirectRequest(t *testing.T) { for i, address := range []tcpip.Address{stackAddr1, stackAddr2} { t.Run(strconv.Itoa(i), func(t *testing.T) { inject(address) - pkt := <-c.linkEP.C - if pkt.Proto != arp.ProtocolNumber { - t.Fatalf("expected ARP response, got network protocol number %d", pkt.Proto) + pi := <-c.linkEP.C + if pi.Proto != arp.ProtocolNumber { + t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto) } - rep := header.ARP(pkt.Header) + rep := header.ARP(pi.Pkt.Header.View()) if !rep.IsValid() { - t.Fatalf("invalid ARP response len(pkt.Header)=%d", len(pkt.Header)) + t.Fatalf("invalid ARP response pi.Pkt.Header.UsedLength()=%d", pi.Pkt.Header.UsedLength()) } if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want { t.Errorf("got HardwareAddressSender = %s, want = %s", got, want) diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index fe499d47e..1de188738 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -150,24 +150,24 @@ func (*testObject) Wait() {} // WritePacket is called by network endpoints after producing a packet and // writing it to the link endpoint. This is used by the test object to verify // that the produced packet is as expected. -func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { +func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { var prot tcpip.TransportProtocolNumber var srcAddr tcpip.Address var dstAddr tcpip.Address if t.v4 { - h := header.IPv4(hdr.View()) + h := header.IPv4(pkt.Header.View()) prot = tcpip.TransportProtocolNumber(h.Protocol()) srcAddr = h.SourceAddress() dstAddr = h.DestinationAddress() } else { - h := header.IPv6(hdr.View()) + h := header.IPv6(pkt.Header.View()) prot = tcpip.TransportProtocolNumber(h.NextHeader()) srcAddr = h.SourceAddress() dstAddr = h.DestinationAddress() } - t.checkValues(prot, payload, srcAddr, dstAddr) + t.checkValues(prot, pkt.Data, srcAddr, dstAddr) return nil } @@ -239,7 +239,10 @@ func TestIPv4Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, hdr, payload.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut); err != nil { + if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut, tcpip.PacketBuffer{ + Header: hdr, + Data: payload.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed: %v", err) } } @@ -477,7 +480,10 @@ func TestIPv6Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, hdr, payload.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut); err != nil { + if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut, tcpip.PacketBuffer{ + Header: hdr, + Data: payload.ToVectorisedView(), + }); err != nil { t.Fatalf("WritePacket failed: %v", err) } } diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index ce771631c..32bf39e43 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -99,7 +99,11 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt tcpip.PacketBuffer) { pkt.SetChecksum(0) pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0))) sent := stats.ICMP.V4PacketsSent - if err := r.WritePacket(nil /* gso */, hdr, vv, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil { + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: vv, + TransportHeader: buffer.View(pkt), + }); err != nil { sent.Dropped.Increment() return } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index ac16c8add..040329a74 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -117,13 +117,14 @@ func (e *endpoint) GSOMaxSize() uint32 { } // writePacketFragments calls e.linkEP.WritePacket with each packet fragment to -// write. It assumes that the IP header is entirely in hdr but does not assume -// that only the IP header is in hdr. It assumes that the input packet's stated -// length matches the length of the hdr+payload. mtu includes the IP header and -// options. This does not support the DontFragment IP flag. -func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, mtu int) *tcpip.Error { +// write. It assumes that the IP header is entirely in pkt.Header but does not +// assume that only the IP header is in pkt.Header. It assumes that the input +// packet's stated length matches the length of the header+payload. mtu +// includes the IP header and options. This does not support the DontFragment +// IP flag. +func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt tcpip.PacketBuffer) *tcpip.Error { // This packet is too big, it needs to be fragmented. - ip := header.IPv4(hdr.View()) + ip := header.IPv4(pkt.Header.View()) flags := ip.Flags() // Update mtu to take into account the header, which will exist in all @@ -137,62 +138,77 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, hdr buff outerMTU := innerMTU + int(ip.HeaderLength()) offset := ip.FragmentOffset() - originalAvailableLength := hdr.AvailableLength() + originalAvailableLength := pkt.Header.AvailableLength() for i := 0; i < n; i++ { // Where possible, the first fragment that is sent has the same - // hdr.UsedLength() as the input packet. The link-layer endpoint may depends - // on this for looking at, eg, L4 headers. + // pkt.Header.UsedLength() as the input packet. The link-layer + // endpoint may depend on this for looking at, eg, L4 headers. h := ip if i > 0 { - hdr = buffer.NewPrependable(int(ip.HeaderLength()) + originalAvailableLength) - h = header.IPv4(hdr.Prepend(int(ip.HeaderLength()))) + pkt.Header = buffer.NewPrependable(int(ip.HeaderLength()) + originalAvailableLength) + h = header.IPv4(pkt.Header.Prepend(int(ip.HeaderLength()))) copy(h, ip[:ip.HeaderLength()]) } if i != n-1 { h.SetTotalLength(uint16(outerMTU)) h.SetFlagsFragmentOffset(flags|header.IPv4FlagMoreFragments, offset) } else { - h.SetTotalLength(uint16(h.HeaderLength()) + uint16(payload.Size())) + h.SetTotalLength(uint16(h.HeaderLength()) + uint16(pkt.Data.Size())) h.SetFlagsFragmentOffset(flags, offset) } h.SetChecksum(0) h.SetChecksum(^h.CalculateChecksum()) offset += uint16(innerMTU) if i > 0 { - newPayload := payload.Clone([]buffer.View{}) + newPayload := pkt.Data.Clone(nil) newPayload.CapLength(innerMTU) - if err := e.linkEP.WritePacket(r, gso, hdr, newPayload, ProtocolNumber); err != nil { + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, tcpip.PacketBuffer{ + Header: pkt.Header, + Data: newPayload, + NetworkHeader: buffer.View(h), + }); err != nil { return err } r.Stats().IP.PacketsSent.Increment() - payload.TrimFront(newPayload.Size()) + pkt.Data.TrimFront(newPayload.Size()) continue } - // Special handling for the first fragment because it comes from the hdr. - if outerMTU >= hdr.UsedLength() { - // This fragment can fit all of hdr and possibly some of payload, too. - newPayload := payload.Clone([]buffer.View{}) - newPayloadLength := outerMTU - hdr.UsedLength() + // Special handling for the first fragment because it comes + // from the header. + if outerMTU >= pkt.Header.UsedLength() { + // This fragment can fit all of pkt.Header and possibly + // some of pkt.Data, too. + newPayload := pkt.Data.Clone(nil) + newPayloadLength := outerMTU - pkt.Header.UsedLength() newPayload.CapLength(newPayloadLength) - if err := e.linkEP.WritePacket(r, gso, hdr, newPayload, ProtocolNumber); err != nil { + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, tcpip.PacketBuffer{ + Header: pkt.Header, + Data: newPayload, + NetworkHeader: buffer.View(h), + }); err != nil { return err } r.Stats().IP.PacketsSent.Increment() - payload.TrimFront(newPayloadLength) + pkt.Data.TrimFront(newPayloadLength) } else { - // The fragment is too small to fit all of hdr. - startOfHdr := hdr - startOfHdr.TrimBack(hdr.UsedLength() - outerMTU) + // The fragment is too small to fit all of pkt.Header. + startOfHdr := pkt.Header + startOfHdr.TrimBack(pkt.Header.UsedLength() - outerMTU) emptyVV := buffer.NewVectorisedView(0, []buffer.View{}) - if err := e.linkEP.WritePacket(r, gso, startOfHdr, emptyVV, ProtocolNumber); err != nil { + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, tcpip.PacketBuffer{ + Header: startOfHdr, + Data: emptyVV, + NetworkHeader: buffer.View(h), + }); err != nil { return err } r.Stats().IP.PacketsSent.Increment() - // Add the unused bytes of hdr into the payload that remains to be sent. - restOfHdr := hdr.View()[outerMTU:] + // Add the unused bytes of pkt.Header into the pkt.Data + // that remains to be sent. + restOfHdr := pkt.Header.View()[outerMTU:] tmp := buffer.NewVectorisedView(len(restOfHdr), []buffer.View{buffer.NewViewFromBytes(restOfHdr)}) - tmp.Append(payload) - payload = tmp + tmp.Append(pkt.Data) + pkt.Data = tmp } } return nil @@ -222,17 +238,17 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error { - ip := e.addIPHeader(r, &hdr, payload.Size(), params) +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { + ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) if loop&stack.PacketLoop != 0 { - views := make([]buffer.View, 1, 1+len(payload.Views())) - views[0] = hdr.View() - views = append(views, payload.Views()...) + views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) + views[0] = pkt.Header.View() + views = append(views, pkt.Data.Views()...) loopedR := r.MakeLoopedRoute() e.HandlePacket(&loopedR, tcpip.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+payload.Size(), views), + Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), NetworkHeader: buffer.View(ip), }) @@ -241,10 +257,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen if loop&stack.PacketOut == 0 { return nil } - if hdr.UsedLength()+payload.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) { - return e.writePacketFragments(r, gso, hdr, payload, int(e.linkEP.MTU())) + if pkt.Header.UsedLength()+pkt.Data.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) { + return e.writePacketFragments(r, gso, int(e.linkEP.MTU()), pkt) } - if err := e.linkEP.WritePacket(r, gso, hdr, payload, ProtocolNumber); err != nil { + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { return err } r.Stats().IP.PacketsSent.Increment() @@ -270,16 +286,16 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.Pac // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required // checks. - ip := header.IPv4(payload.First()) - if !ip.IsValid(payload.Size()) { + ip := header.IPv4(pkt.Data.First()) + if !ip.IsValid(pkt.Data.Size()) { return tcpip.ErrInvalidOptionValue } // Always set the total length. - ip.SetTotalLength(uint16(payload.Size())) + ip.SetTotalLength(uint16(pkt.Data.Size())) // Set the source address when zero. if ip.SourceAddress() == tcpip.Address(([]byte{0, 0, 0, 0})) { @@ -293,7 +309,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.Vect // Set the packet ID when zero. if ip.ID() == 0 { id := uint32(0) - if payload.Size() > header.IPv4MaximumHeaderSize+8 { + if pkt.Data.Size() > header.IPv4MaximumHeaderSize+8 { // Packets of 68 bytes or less are required by RFC 791 to not be // fragmented, so we only assign ids to larger packets. id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1) @@ -306,18 +322,18 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.Vect ip.SetChecksum(^ip.CalculateChecksum()) if loop&stack.PacketLoop != 0 { - e.HandlePacket(r, tcpip.PacketBuffer{ - Data: payload, - NetworkHeader: buffer.View(ip), - }) + e.HandlePacket(r, pkt.Clone()) } if loop&stack.PacketOut == 0 { return nil } - hdr := buffer.NewPrependableFromView(payload.ToView()) r.Stats().IP.PacketsSent.Increment() - return e.linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber) + + ip = ip[:ip.HeaderLength()] + pkt.Header = buffer.NewPrependableFromView(buffer.View(ip)) + pkt.Data.TrimFront(int(ip.HeaderLength())) + return e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) } // HandlePacket is called by the link layer when new ipv4 packets arrive for diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 01dfb5f20..e900f1b45 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -113,12 +113,12 @@ func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer. // comparePayloads compared the contents of all the packets against the contents // of the source packet. -func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packetInfo, mtu uint32) { +func compareFragments(t *testing.T, packets []tcpip.PacketBuffer, sourcePacketInfo tcpip.PacketBuffer, mtu uint32) { t.Helper() // Make a complete array of the sourcePacketInfo packet. source := header.IPv4(packets[0].Header.View()[:header.IPv4MinimumSize]) source = append(source, sourcePacketInfo.Header.View()...) - source = append(source, sourcePacketInfo.Payload.ToView()...) + source = append(source, sourcePacketInfo.Data.ToView()...) // Make a copy of the IP header, which will be modified in some fields to make // an expected header. @@ -132,7 +132,7 @@ func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packe for i, packet := range packets { // Confirm that the packet is valid. allBytes := packet.Header.View().ToVectorisedView() - allBytes.Append(packet.Payload) + allBytes.Append(packet.Data) ip := header.IPv4(allBytes.ToView()) if !ip.IsValid(len(ip)) { t.Errorf("IP packet is invalid:\n%s", hex.Dump(ip)) @@ -173,7 +173,7 @@ func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packe type errorChannel struct { *channel.Endpoint - Ch chan packetInfo + Ch chan tcpip.PacketBuffer packetCollectorErrors []*tcpip.Error } @@ -183,17 +183,11 @@ type errorChannel struct { func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel { return &errorChannel{ Endpoint: channel.New(size, mtu, linkAddr), - Ch: make(chan packetInfo, size), + Ch: make(chan tcpip.PacketBuffer, size), packetCollectorErrors: packetCollectorErrors, } } -// packetInfo holds all the information about an outbound packet. -type packetInfo struct { - Header buffer.Prependable - Payload buffer.VectorisedView -} - // Drain removes all outbound packets from the channel and counts them. func (e *errorChannel) Drain() int { c := 0 @@ -208,14 +202,9 @@ func (e *errorChannel) Drain() int { } // WritePacket stores outbound packets into the channel. -func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { - p := packetInfo{ - Header: hdr, - Payload: payload, - } - +func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { select { - case e.Ch <- p: + case e.Ch <- pkt: default: } @@ -292,18 +281,21 @@ func TestFragmentation(t *testing.T) { for _, ft := range fragTests { t.Run(ft.description, func(t *testing.T) { hdr, payload := makeHdrAndPayload(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes) - source := packetInfo{ + source := tcpip.PacketBuffer{ Header: hdr, // Save the source payload because WritePacket will modify it. - Payload: payload.Clone([]buffer.View{}), + Data: payload.Clone(nil), } c := buildContext(t, nil, ft.mtu) - err := c.Route.WritePacket(ft.gso, hdr, payload, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}) + err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: payload, + }) if err != nil { t.Errorf("err got %v, want %v", err, nil) } - var results []packetInfo + var results []tcpip.PacketBuffer L: for { select { @@ -345,7 +337,10 @@ func TestFragmentationErrors(t *testing.T) { t.Run(ft.description, func(t *testing.T) { hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes) c := buildContext(t, ft.packetCollectorErrors, ft.mtu) - err := c.Route.WritePacket(&stack.GSO{}, hdr, payload, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}) + err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: payload, + }) for i := 0; i < len(ft.packetCollectorErrors)-1; i++ { if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want { t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want) diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 6629951c6..1c3410618 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -226,7 +226,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P // // The IP Hop Limit field has a value of 255, i.e., the packet // could not possibly have been forwarded by a router. - if err := r.WritePacket(nil /* gso */, hdr, buffer.VectorisedView{}, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}); err != nil { + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + }); err != nil { sent.Dropped.Increment() return } @@ -291,7 +293,10 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P copy(packet, h) packet.SetType(header.ICMPv6EchoReply) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) - if err := r.WritePacket(nil /* gso */, hdr, pkt.Data, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil { + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: pkt.Data, + }); err != nil { sent.Dropped.Increment() return } @@ -417,7 +422,9 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. }) // TODO(stijlist): count this in ICMP stats. - return linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber) + return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, tcpip.PacketBuffer{ + Header: hdr, + }) } // ResolveStaticAddress implements stack.LinkAddressResolver. diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 6037a1ef8..335f634d5 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -55,7 +55,7 @@ func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" } -func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, buffer.Prependable, buffer.VectorisedView, tcpip.NetworkProtocolNumber) *tcpip.Error { +func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, tcpip.PacketBuffer) *tcpip.Error { return nil } @@ -276,22 +276,22 @@ type routeArgs struct { func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.ICMPv6)) { t.Helper() - pkt := <-args.src.C + pi := <-args.src.C { - views := []buffer.View{pkt.Header, pkt.Payload} - size := len(pkt.Header) + len(pkt.Payload) + views := []buffer.View{pi.Pkt.Header.View(), pi.Pkt.Data.ToView()} + size := pi.Pkt.Header.UsedLength() + pi.Pkt.Data.Size() vv := buffer.NewVectorisedView(size, views) - args.dst.InjectLinkAddr(pkt.Proto, args.dst.LinkAddress(), tcpip.PacketBuffer{ + args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), tcpip.PacketBuffer{ Data: vv, }) } - if pkt.Proto != ProtocolNumber { - t.Errorf("unexpected protocol number %d", pkt.Proto) + if pi.Proto != ProtocolNumber { + t.Errorf("unexpected protocol number %d", pi.Proto) return } - ipv6 := header.IPv6(pkt.Header) + ipv6 := header.IPv6(pi.Pkt.Header.View()) transProto := tcpip.TransportProtocolNumber(ipv6.NextHeader()) if transProto != header.ICMPv6ProtocolNumber { t.Errorf("unexpected transport protocol number %d", transProto) diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 4cee848a1..8d1578ed9 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -112,17 +112,17 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error { - ip := e.addIPHeader(r, &hdr, payload.Size(), params) +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { + ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) if loop&stack.PacketLoop != 0 { - views := make([]buffer.View, 1, 1+len(payload.Views())) - views[0] = hdr.View() - views = append(views, payload.Views()...) + views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) + views[0] = pkt.Header.View() + views = append(views, pkt.Data.Views()...) loopedR := r.MakeLoopedRoute() e.HandlePacket(&loopedR, tcpip.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+payload.Size(), views), + Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), NetworkHeader: buffer.View(ip), }) @@ -133,7 +133,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen } r.Stats().IP.PacketsSent.Increment() - return e.linkEP.WritePacket(r, gso, hdr, payload, ProtocolNumber) + return e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt) } // WritePackets implements stack.LinkEndpoint.WritePackets. @@ -158,7 +158,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.Pac // WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet // supported by IPv6. -func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { +func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { // TODO(b/119580726): Support IPv6 header-included packets. return tcpip.ErrNotSupported } diff --git a/pkg/tcpip/packet_buffer.go b/pkg/tcpip/packet_buffer.go index 10b04239d..695f7b188 100644 --- a/pkg/tcpip/packet_buffer.go +++ b/pkg/tcpip/packet_buffer.go @@ -31,12 +31,19 @@ type PacketBuffer struct { // or otherwise modified. Data buffer.VectorisedView + // Header holds the headers of outbound packets. As a packet is passed + // down the stack, each layer adds to Header. + Header buffer.Prependable + + // These fields are used by both inbound and outbound packets. They + // typically overlap with the Data and Header fields. + // // The bytes backing these views are immutable. Each field may be nil // if either it has not been set yet or no such header exists (e.g. // packets sent via loopback may not have a link header). // - // These fields may be Views into other Views. SR dosen't support this, - // so deep copies are necessary in some cases. + // These fields may be Views into other slices (either Data or Header). + // SR dosen't support this, so deep copies are necessary in some cases. LinkHeader buffer.View NetworkHeader buffer.View TransportHeader buffer.View @@ -44,11 +51,9 @@ type PacketBuffer struct { // Clone makes a copy of pk. It clones the Data field, which creates a new // VectorisedView but does not deep copy the underlying bytes. +// +// Clone also does not deep copy any of its other fields. func (pk PacketBuffer) Clone() PacketBuffer { - return PacketBuffer{ - Data: pk.Data.Clone(nil), - LinkHeader: pk.LinkHeader, - NetworkHeader: pk.NetworkHeader, - TransportHeader: pk.TransportHeader, - } + pk.Data = pk.Data.Clone(nil) + return pk } diff --git a/pkg/tcpip/packet_buffer_state.go b/pkg/tcpip/packet_buffer_state.go index 04c4cf136..ad3cc24fa 100644 --- a/pkg/tcpip/packet_buffer_state.go +++ b/pkg/tcpip/packet_buffer_state.go @@ -20,6 +20,7 @@ import "gvisor.dev/gvisor/pkg/tcpip/buffer" func (pk *PacketBuffer) beforeSave() { // Non-Data fields may be slices of the Data field. This causes // problems for SR, so during save we make each header independent. + pk.Header = pk.Header.DeepCopy() pk.LinkHeader = append(buffer.View(nil), pk.LinkHeader...) pk.NetworkHeader = append(buffer.View(nil), pk.NetworkHeader...) pk.TransportHeader = append(buffer.View(nil), pk.TransportHeader...) diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 8357dca77..cfdd0496e 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -421,7 +421,9 @@ func (ndp *ndpState) doDuplicateAddressDetection(addr tcpip.Address, remaining u pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) sent := r.Stats().ICMP.V6PacketsSent - if err := r.WritePacket(nil, hdr, buffer.VectorisedView{}, NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS}); err != nil { + if err := r.WritePacket(nil, NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + }); err != nil { sent.Dropped.Increment() return false, err } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 494244368..5b901f947 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -344,7 +344,7 @@ func TestDADResolve(t *testing.T) { } // Check NDP packet. - checker.IPv6(t, p.Header.ToVectorisedView().First(), + checker.IPv6(t, p.Pkt.Header.View().ToVectorisedView().First(), checker.TTL(header.NDPHopLimit), checker.NDPNS( checker.NDPNSTargetAddress(addr1))) diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 9ed9e1e7c..3f8d7312c 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -812,15 +812,15 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link } else { // n doesn't have a destination endpoint. // Send the packet out of n. - hdr := buffer.NewPrependableFromView(pkt.Data.First()) + pkt.Header = buffer.NewPrependableFromView(pkt.Data.First()) pkt.Data.RemoveFirst() // TODO(b/128629022): use route.WritePacket. - if err := n.linkEP.WritePacket(&r, nil /* gso */, hdr, pkt.Data, protocol); err != nil { + if err := n.linkEP.WritePacket(&r, nil /* gso */, protocol, pkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() } else { n.stats.Tx.Packets.Increment() - n.stats.Tx.Bytes.IncrementBy(uint64(hdr.UsedLength() + pkt.Data.Size())) + n.stats.Tx.Bytes.IncrementBy(uint64(pkt.Header.UsedLength() + pkt.Data.Size())) } } return diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index c0026f5a3..7fd4e4a65 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -232,8 +232,9 @@ type NetworkEndpoint interface { MaxHeaderLength() uint16 // WritePacket writes a packet to the given destination address and - // protocol. - WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params NetworkHeaderParams, loop PacketLooping) *tcpip.Error + // protocol. It sets pkt.NetworkHeader. pkt.TransportHeader must have + // already been set. + WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, loop PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error // WritePackets writes packets to the given destination address and // protocol. @@ -241,7 +242,7 @@ type NetworkEndpoint interface { // WriteHeaderIncludedPacket writes a packet that includes a network // header to the given destination address. - WriteHeaderIncludedPacket(r *Route, payload buffer.VectorisedView, loop PacketLooping) *tcpip.Error + WriteHeaderIncludedPacket(r *Route, loop PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error // ID returns the network protocol endpoint ID. ID() *NetworkEndpointID @@ -361,13 +362,15 @@ type LinkEndpoint interface { // link endpoint. LinkAddress() tcpip.LinkAddress - // WritePacket writes a packet with the given protocol through the given - // route. + // WritePacket writes a packet with the given protocol through the + // given route. It sets pkt.LinkHeader if a link layer header exists. + // pkt.NetworkHeader and pkt.TransportHeader must have already been + // set. // // To participate in transparent bridging, a LinkEndpoint implementation // should call eth.Encode with header.EthernetFields.SrcAddr set to // r.LocalLinkAddress if it is provided. - WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error + WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error // WritePackets writes packets with the given protocol through the // given route. @@ -379,7 +382,7 @@ type LinkEndpoint interface { // WriteRawPacket writes a packet directly to the link. The packet // should already have an ethernet header. - WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error + WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error // Attach attaches the data link layer endpoint to the network-layer // dispatcher of the stack. diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 1a0a51b57..617f5a57c 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -154,17 +154,17 @@ func (r *Route) IsResolutionRequired() bool { } // WritePacket writes the packet through the given route. -func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params NetworkHeaderParams) *tcpip.Error { +func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error { if !r.ref.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } - err := r.ref.ep.WritePacket(r, gso, hdr, payload, params, r.Loop) + err := r.ref.ep.WritePacket(r, gso, params, r.Loop, pkt) if err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() } else { r.ref.nic.stats.Tx.Packets.Increment() - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(hdr.UsedLength() + payload.Size())) + r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(pkt.Header.UsedLength() + pkt.Data.Size())) } return err } @@ -209,17 +209,17 @@ func (r *Route) WritePackets(gso *GSO, hdrs []PacketDescriptor, payload buffer.V // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. -func (r *Route) WriteHeaderIncludedPacket(payload buffer.VectorisedView) *tcpip.Error { +func (r *Route) WriteHeaderIncludedPacket(pkt tcpip.PacketBuffer) *tcpip.Error { if !r.ref.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } - if err := r.ref.ep.WriteHeaderIncludedPacket(r, payload, r.Loop); err != nil { + if err := r.ref.ep.WriteHeaderIncludedPacket(r, r.Loop, pkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return err } r.ref.nic.stats.Tx.Packets.Increment() - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(payload.Size())) + r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(pkt.Data.Size())) return nil } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 2f8d8e822..0e88643a4 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1366,10 +1366,10 @@ func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto t } fakeHeader := make(header.Ethernet, header.EthernetMinimumSize) fakeHeader.Encode(ðFields) - ethHeader := buffer.View(fakeHeader).ToVectorisedView() - ethHeader.Append(payload) + vv := buffer.View(fakeHeader).ToVectorisedView() + vv.Append(payload) - if err := nic.linkEP.WriteRawPacket(ethHeader); err != nil { + if err := nic.linkEP.WriteRawPacket(vv); err != nil { return err } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index bf1d6974c..f979e2b1a 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -122,31 +122,30 @@ func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities { return f.ep.Capabilities() } -func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error { +func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { // Increment the sent packet count in the protocol descriptor. f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++ // Add the protocol's header to the packet and send it to the link // endpoint. - b := hdr.Prepend(fakeNetHeaderLen) + b := pkt.Header.Prepend(fakeNetHeaderLen) b[0] = r.RemoteAddress[0] b[1] = f.id.LocalAddress[0] b[2] = byte(params.Protocol) if loop&stack.PacketLoop != 0 { - views := make([]buffer.View, 1, 1+len(payload.Views())) - views[0] = hdr.View() - views = append(views, payload.Views()...) - vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) + views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) + views[0] = pkt.Header.View() + views = append(views, pkt.Data.Views()...) f.HandlePacket(r, tcpip.PacketBuffer{ - Data: vv, + Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), }) } if loop&stack.PacketOut == 0 { return nil } - return f.ep.WritePacket(r, gso, hdr, payload, fakeNetNumber) + return f.ep.WritePacket(r, gso, fakeNetNumber, pkt) } // WritePackets implements stack.LinkEndpoint.WritePackets. @@ -154,7 +153,7 @@ func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs panic("not implemented") } -func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { +func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } @@ -330,7 +329,10 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro func send(r stack.Route, payload buffer.View) *tcpip.Error { hdr := buffer.NewPrependable(int(r.MaxHeaderLength())) - return r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}) + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: payload.ToVectorisedView(), + }) } func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View) { diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 2cacea99a..748ce4ea5 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -83,7 +83,10 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions if err != nil { return 0, nil, err } - if err := f.route.WritePacket(nil /* gso */, hdr, buffer.View(v).ToVectorisedView(), stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}); err != nil { + if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: buffer.View(v).ToVectorisedView(), + }); err != nil { return 0, nil, err } @@ -617,10 +620,10 @@ func TestTransportForwarding(t *testing.T) { t.Fatal("Response packet not forwarded") } - if dst := p.Header[0]; dst != 3 { + if dst := p.Pkt.Header.View()[0]; dst != 3 { t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) } - if src := p.Header[1]; src != 1 { + if src := p.Pkt.Header.View()[1]; src != 1 { t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) } } diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 70e008d36..9c40931b5 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -429,7 +429,11 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err if ttl == 0 { ttl = r.DefaultTTL() } - return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}) + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: data.ToVectorisedView(), + TransportHeader: buffer.View(icmpv4), + }) } func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error { @@ -455,7 +459,11 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err if ttl == 0 { ttl = r.DefaultTTL() } - return r.WritePacket(nil /* gso */, hdr, dataVV, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}) + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: dataVV, + TransportHeader: buffer.View(icmpv6), + }) } func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) { diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 230a1537a..5aafe2615 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -338,13 +338,18 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, switch e.NetProto { case header.IPv4ProtocolNumber: if !e.associated { - if err := route.WriteHeaderIncludedPacket(buffer.View(payloadBytes).ToVectorisedView()); err != nil { + if err := route.WriteHeaderIncludedPacket(tcpip.PacketBuffer{ + Data: buffer.View(payloadBytes).ToVectorisedView(), + }); err != nil { return 0, nil, err } break } hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength())) - if err := route.WritePacket(nil /* gso */, hdr, buffer.View(payloadBytes).ToVectorisedView(), stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil { + if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: buffer.View(payloadBytes).ToVectorisedView(), + }); err != nil { return 0, nil, err } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index be066d877..49f2b9685 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -723,7 +723,10 @@ func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.Vectorise if ttl == 0 { ttl = r.DefaultTTL() } - if err := r.WritePacket(gso, d.Hdr, data, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}); err != nil { + if err := r.WritePacket(gso, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}, tcpip.PacketBuffer{ + Header: d.Hdr, + Data: data, + }); err != nil { r.Stats().TCP.SegmentSendErrors.Increment() return err } diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 0a733fa94..04fdaaed1 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -236,9 +236,9 @@ func (c *Context) GetPacket() []byte { if p.Proto != ipv4.ProtocolNumber { c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) } - b := make([]byte, len(p.Header)+len(p.Payload)) - copy(b, p.Header) - copy(b[len(p.Header):], p.Payload) + + hdr := p.Pkt.Header.View() + b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize { c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize) @@ -264,9 +264,9 @@ func (c *Context) GetPacketNonBlocking() []byte { if p.Proto != ipv4.ProtocolNumber { c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) } - b := make([]byte, len(p.Header)+len(p.Payload)) - copy(b, p.Header) - copy(b[len(p.Header):], p.Payload) + + hdr := p.Pkt.Header.View() + b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) return b @@ -488,9 +488,9 @@ func (c *Context) GetV6Packet() []byte { if p.Proto != ipv6.ProtocolNumber { c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber) } - b := make([]byte, len(p.Header)+len(p.Payload)) - copy(b, p.Header) - copy(b[len(p.Header):], p.Payload) + b := make([]byte, p.Pkt.Header.UsedLength()+p.Pkt.Data.Size()) + copy(b, p.Pkt.Header.View()) + copy(b[p.Pkt.Header.UsedLength():], p.Pkt.Data.ToView()) checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr)) return b diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index dda7af910..2d97d1398 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -817,7 +817,10 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u if useDefaultTTL { ttl = r.DefaultTTL() } - if err := r.WritePacket(nil /* gso */, hdr, data, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}); err != nil { + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}, tcpip.PacketBuffer{ + Header: hdr, + Data: data, + }); err != nil { r.Stats().UDP.PacketSendErrors.Increment() return err } diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 43f11b700..259c3072a 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -135,7 +135,10 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans pkt.SetType(header.ICMPv4DstUnreachable) pkt.SetCode(header.ICMPv4PortUnreachable) pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload)) - r.WritePacket(nil /* gso */, hdr, payload, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}) + r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: payload, + }) case header.IPv6AddressSize: if !r.Stack().AllowICMPMessage() { @@ -169,7 +172,10 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans pkt.SetType(header.ICMPv6DstUnreachable) pkt.SetCode(header.ICMPv6PortUnreachable) pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload)) - r.WritePacket(nil /* gso */, hdr, payload, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}) + r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: payload, + }) } return true } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 30ee9801b..7051a7a9c 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -356,9 +356,9 @@ func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.Netw if p.Proto != flow.netProto() { c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto()) } - b := make([]byte, len(p.Header)+len(p.Payload)) - copy(b, p.Header) - copy(b[len(p.Header):], p.Payload) + + hdr := p.Pkt.Header.View() + b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) h := flow.header4Tuple(outgoing) checkers := append( @@ -1453,8 +1453,8 @@ func TestV4UnknownDestination(t *testing.T) { select { case p := <-c.linkEP.C: var pkt []byte - pkt = append(pkt, p.Header...) - pkt = append(pkt, p.Payload...) + pkt = append(pkt, p.Pkt.Header.View()...) + pkt = append(pkt, p.Pkt.Data.ToView()...) if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want { t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) } @@ -1527,8 +1527,8 @@ func TestV6UnknownDestination(t *testing.T) { select { case p := <-c.linkEP.C: var pkt []byte - pkt = append(pkt, p.Header...) - pkt = append(pkt, p.Payload...) + pkt = append(pkt, p.Pkt.Header.View()...) + pkt = append(pkt, p.Pkt.Data.ToView()...) if got, want := len(pkt), header.IPv6MinimumMTU; got > want { t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) } -- cgit v1.2.3 From 9db08c4e583e758e3eb1aed03875743ce02b8cff Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Fri, 22 Nov 2019 14:41:04 -0800 Subject: Use PacketBuffers with GSO. PiperOrigin-RevId: 282045221 --- pkg/tcpip/link/channel/channel.go | 12 ++++----- pkg/tcpip/link/fdbased/endpoint.go | 25 +++++++++++++------ pkg/tcpip/link/loopback/loopback.go | 2 +- pkg/tcpip/link/muxed/injectable.go | 4 +-- pkg/tcpip/link/sharedmem/sharedmem.go | 2 +- pkg/tcpip/link/sniffer/sniffer.go | 12 ++++----- pkg/tcpip/link/waitable/waitable.go | 6 ++--- pkg/tcpip/link/waitable/waitable_test.go | 6 ++--- pkg/tcpip/network/arp/arp.go | 2 +- pkg/tcpip/network/ip_test.go | 2 +- pkg/tcpip/network/ipv4/ipv4.go | 10 ++++---- pkg/tcpip/network/ipv6/ipv6.go | 12 ++++----- pkg/tcpip/packet_buffer.go | 8 ++++++ pkg/tcpip/stack/registration.go | 8 +++--- pkg/tcpip/stack/route.go | 29 ++++----------------- pkg/tcpip/stack/stack_test.go | 2 +- pkg/tcpip/transport/tcp/connect.go | 43 ++++++++++++++++++-------------- 17 files changed, 95 insertions(+), 90 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index 9fe8e9f9d..70188551f 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -133,16 +133,16 @@ func (e *Endpoint) WritePacket(_ *stack.Route, gso *stack.GSO, protocol tcpip.Ne } // WritePackets stores outbound packets into the channel. -func (e *Endpoint) WritePackets(_ *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - payloadView := payload.ToView() +func (e *Endpoint) WritePackets(_ *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + payloadView := pkts[0].Data.ToView() n := 0 packetLoop: - for _, hdr := range hdrs { - off := hdr.Off - size := hdr.Size + for _, pkt := range pkts { + off := pkt.DataOffset + size := pkt.DataSize p := PacketInfo{ Pkt: tcpip.PacketBuffer{ - Header: hdr.Hdr, + Header: pkt.Header, Data: buffer.NewViewFromBytes(payloadView[off : off+size]).ToVectorisedView(), }, Proto: protocol, diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index 98109c5dc..fa8a703d9 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -440,7 +440,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne // WritePackets writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { var ethHdrBuf []byte // hdr + data iovLen := 2 @@ -463,9 +463,9 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.Pac iovLen++ } - n := len(hdrs) + n := len(pkts) - views := payload.Views() + views := pkts[0].Data.Views() /* * Each bondary in views can add one more iovec. * @@ -483,14 +483,20 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.Pac viewOff := 0 off := 0 nextOff := 0 - for i := range hdrs { + for i := range pkts { + // TODO(b/134618279): Different packets may have different data + // in the future. We should handle this. + if !viewsEqual(pkts[i].Data.Views(), views) { + panic("All packets in pkts should have the same Data.") + } + prevIovecIdx := iovecIdx mmsgHdr := &mmsgHdrs[i] mmsgHdr.Msg.Iov = &iovec[iovecIdx] - packetSize := hdrs[i].Size - hdr := &hdrs[i].Hdr + packetSize := pkts[i].DataSize + hdr := &pkts[i].Header - off = hdrs[i].Off + off = pkts[i].DataOffset if off != nextOff { // We stop in a different point last time. size := packetSize @@ -555,6 +561,11 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.Pac return packets, nil } +// viewsEqual tests whether v1 and v2 refer to the same backing bytes. +func viewsEqual(vs1, vs2 []buffer.View) bool { + return len(vs1) == len(vs2) && (len(vs1) == 0 || &vs1[0] == &vs2[0]) +} + // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { return rawfile.NonBlockingWrite(e.fds[0], vv.ToView()) diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 563a67188..499cc608f 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -92,7 +92,7 @@ func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Netw } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []tcpip.PacketBuffer, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { panic("not implemented") } diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index 55ed2a28e..445b22c17 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -87,12 +87,12 @@ func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, // WritePackets writes outbound packets to the appropriate // LinkInjectableEndpoint based on the RemoteAddress. HandleLocal only works if // r.RemoteAddress has a route registered in this endpoint. -func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { endpoint, ok := m.routes[r.RemoteAddress] if !ok { return 0, tcpip.ErrNoRoute } - return endpoint.WritePackets(r, gso, hdrs, payload, protocol) + return endpoint.WritePackets(r, gso, pkts, protocol) } // WritePacket writes outbound packets to the appropriate LinkInjectableEndpoint diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 88947a03a..080f9d667 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -214,7 +214,7 @@ func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.Netw } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { panic("not implemented") } diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 122680e10..767f14303 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -233,15 +233,15 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne // WritePackets implements the stack.LinkEndpoint interface. It is called by // higher-level protocols to write packets; it just logs the packet and // forwards the request to the lower endpoint. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - view := payload.ToView() - for _, d := range hdrs { +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + view := pkts[0].Data.ToView() + for _, pkt := range pkts { e.dumpPacket(gso, protocol, tcpip.PacketBuffer{ - Header: d.Hdr, - Data: view[d.Off:][:d.Size].ToVectorisedView(), + Header: pkt.Header, + Data: view[pkt.DataOffset:][:pkt.DataSize].ToVectorisedView(), }) } - return e.lower.WritePackets(r, gso, hdrs, payload, protocol) + return e.lower.WritePackets(r, gso, pkts, protocol) } // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index 12e7c1932..a8de38979 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -112,12 +112,12 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne // WritePackets implements stack.LinkEndpoint.WritePackets. It is called by // higher-level protocols to write packets. It only forwards packets to the // lower endpoint if Wait or WaitWrite haven't been called. -func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { if !e.writeGate.Enter() { - return len(hdrs), nil + return len(pkts), nil } - n, err := e.lower.WritePackets(r, gso, hdrs, payload, protocol) + n, err := e.lower.WritePackets(r, gso, pkts, protocol) e.writeGate.Leave() return n, err } diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index 0fc0c2ebe..31b11a27a 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -71,9 +71,9 @@ func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcp } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - e.writeCount += len(hdrs) - return len(hdrs), nil +func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + e.writeCount += len(pkts) + return len(pkts), nil } func (e *countedEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 30aec9ba7..da8482509 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -84,7 +84,7 @@ func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderPara } // WritePackets implements stack.NetworkEndpoint.WritePackets. -func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []stack.PacketDescriptor, buffer.VectorisedView, stack.NetworkHeaderParams, stack.PacketLooping) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []tcpip.PacketBuffer, stack.NetworkHeaderParams, stack.PacketLooping) (int, *tcpip.Error) { return 0, tcpip.ErrNotSupported } diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 1de188738..4144a7837 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -172,7 +172,7 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Ne } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, hdr []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { panic("not implemented") } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 040329a74..7059600f5 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -268,18 +268,18 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw } // WritePackets implements stack.NetworkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) { if loop&stack.PacketLoop != 0 { panic("multiple packets in local loop") } if loop&stack.PacketOut == 0 { - return len(hdrs), nil + return len(pkts), nil } - for i := range hdrs { - e.addIPHeader(r, &hdrs[i].Hdr, hdrs[i].Size, params) + for i := range pkts { + e.addIPHeader(r, &pkts[i].Header, pkts[i].DataSize, params) } - n, err := e.linkEP.WritePackets(r, gso, hdrs, payload, ProtocolNumber) + n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) return n, err } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 8d1578ed9..c9087ffa7 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -137,21 +137,21 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) { if loop&stack.PacketLoop != 0 { panic("not implemented") } if loop&stack.PacketOut == 0 { - return len(hdrs), nil + return len(pkts), nil } - for i := range hdrs { - hdr := &hdrs[i].Hdr - size := hdrs[i].Size + for i := range pkts { + hdr := &pkts[i].Header + size := pkts[i].DataSize e.addIPHeader(r, hdr, size, params) } - n, err := e.linkEP.WritePackets(r, gso, hdrs, payload, ProtocolNumber) + n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) return n, err } diff --git a/pkg/tcpip/packet_buffer.go b/pkg/tcpip/packet_buffer.go index 695f7b188..ab24372e7 100644 --- a/pkg/tcpip/packet_buffer.go +++ b/pkg/tcpip/packet_buffer.go @@ -31,6 +31,14 @@ type PacketBuffer struct { // or otherwise modified. Data buffer.VectorisedView + // DataOffset is used for GSO output. It is the offset into the Data + // field where the payload of this packet starts. + DataOffset int + + // DataSize is used for GSO output. It is the size of this packet's + // payload. + DataSize int + // Header holds the headers of outbound packets. As a packet is passed // down the stack, each layer adds to Header. Header buffer.Prependable diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 7fd4e4a65..61fd46d66 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -237,8 +237,8 @@ type NetworkEndpoint interface { WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, loop PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error // WritePackets writes packets to the given destination address and - // protocol. - WritePackets(r *Route, gso *GSO, hdrs []PacketDescriptor, payload buffer.VectorisedView, params NetworkHeaderParams, loop PacketLooping) (int, *tcpip.Error) + // protocol. pkts must not be zero length. + WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, params NetworkHeaderParams, loop PacketLooping) (int, *tcpip.Error) // WriteHeaderIncludedPacket writes a packet that includes a network // header to the given destination address. @@ -373,12 +373,12 @@ type LinkEndpoint interface { WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error // WritePackets writes packets with the given protocol through the - // given route. + // given route. pkts must not be zero length. // // Right now, WritePackets is used only when the software segmentation // offload is enabled. If it will be used for something else, it may // require to change syscall filters. - WritePackets(r *Route, gso *GSO, hdrs []PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) + WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) // WriteRawPacket writes a packet directly to the link. The packet // should already have an ethernet header. diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 617f5a57c..34307ae07 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -17,7 +17,6 @@ package stack import ( "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -169,39 +168,21 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt tcpip.Pack return err } -// PacketDescriptor is a packet descriptor which contains a packet header and -// offset and size of packet data in a payload view. -type PacketDescriptor struct { - Hdr buffer.Prependable - Off int - Size int -} - -// NewPacketDescriptors allocates a set of packet descriptors. -func NewPacketDescriptors(n int, hdrSize int) []PacketDescriptor { - buf := make([]byte, n*hdrSize) - hdrs := make([]PacketDescriptor, n) - for i := range hdrs { - hdrs[i].Hdr = buffer.NewEmptyPrependableFromView(buf[i*hdrSize:][:hdrSize]) - } - return hdrs -} - // WritePackets writes the set of packets through the given route. -func (r *Route) WritePackets(gso *GSO, hdrs []PacketDescriptor, payload buffer.VectorisedView, params NetworkHeaderParams) (int, *tcpip.Error) { +func (r *Route) WritePackets(gso *GSO, pkts []tcpip.PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) { if !r.ref.isValidForOutgoing() { return 0, tcpip.ErrInvalidEndpointState } - n, err := r.ref.ep.WritePackets(r, gso, hdrs, payload, params, r.Loop) + n, err := r.ref.ep.WritePackets(r, gso, pkts, params, r.Loop) if err != nil { - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(len(hdrs) - n)) + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(len(pkts) - n)) } r.ref.nic.stats.Tx.Packets.IncrementBy(uint64(n)) payloadSize := 0 for i := 0; i < n; i++ { - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(hdrs[i].Hdr.UsedLength())) - payloadSize += hdrs[i].Size + r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(pkts[i].Header.UsedLength())) + payloadSize += pkts[i].DataSize } r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(payloadSize)) return n, err diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index f979e2b1a..8fc034ca1 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -149,7 +149,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) { +func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) { panic("not implemented") } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 75b7c0828..00c0c9a92 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -631,11 +631,11 @@ func (e *endpoint) sendTCP(r *stack.Route, id stack.TransportEndpointID, data bu return nil } -func buildTCPHdr(r *stack.Route, id stack.TransportEndpointID, d *stack.PacketDescriptor, data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) { +func buildTCPHdr(r *stack.Route, id stack.TransportEndpointID, pkt *tcpip.PacketBuffer, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) { optLen := len(opts) - hdr := &d.Hdr - packetSize := d.Size - off := d.Off + hdr := &pkt.Header + packetSize := pkt.DataSize + off := pkt.DataOffset // Initialize the header. tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen)) tcp.Encode(&header.TCPFields{ @@ -659,7 +659,7 @@ func buildTCPHdr(r *stack.Route, id stack.TransportEndpointID, d *stack.PacketDe // header and data and get the right sum of the TCP packet. tcp.SetChecksum(xsum) } else if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 { - xsum = header.ChecksumVVWithOffset(data, xsum, off, packetSize) + xsum = header.ChecksumVVWithOffset(pkt.Data, xsum, off, packetSize) tcp.SetChecksum(^tcp.CalculateChecksum(xsum)) } @@ -674,7 +674,13 @@ func sendTCPBatch(r *stack.Route, id stack.TransportEndpointID, data buffer.Vect mss := int(gso.MSS) n := (data.Size() + mss - 1) / mss - hdrs := stack.NewPacketDescriptors(n, header.TCPMinimumSize+int(r.MaxHeaderLength())+optLen) + // Allocate one big slice for all the headers. + hdrSize := header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen + buf := make([]byte, n*hdrSize) + pkts := make([]tcpip.PacketBuffer, n) + for i := range pkts { + pkts[i].Header = buffer.NewEmptyPrependableFromView(buf[i*hdrSize:][:hdrSize]) + } size := data.Size() off := 0 @@ -684,16 +690,17 @@ func sendTCPBatch(r *stack.Route, id stack.TransportEndpointID, data buffer.Vect packetSize = size } size -= packetSize - hdrs[i].Off = off - hdrs[i].Size = packetSize - buildTCPHdr(r, id, &hdrs[i], data, flags, seq, ack, rcvWnd, opts, gso) + pkts[i].DataOffset = off + pkts[i].DataSize = packetSize + pkts[i].Data = data + buildTCPHdr(r, id, &pkts[i], flags, seq, ack, rcvWnd, opts, gso) off += packetSize seq = seq.Add(seqnum.Size(packetSize)) } if ttl == 0 { ttl = r.DefaultTTL() } - sent, err := r.WritePackets(gso, hdrs, data, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}) + sent, err := r.WritePackets(gso, pkts, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}) if err != nil { r.Stats().TCP.SegmentSendErrors.IncrementBy(uint64(n - sent)) } @@ -713,20 +720,18 @@ func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.Vectorise return sendTCPBatch(r, id, data, ttl, tos, flags, seq, ack, rcvWnd, opts, gso) } - d := &stack.PacketDescriptor{ - Hdr: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen), - Off: 0, - Size: data.Size(), + pkt := tcpip.PacketBuffer{ + Header: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen), + DataOffset: 0, + DataSize: data.Size(), + Data: data, } - buildTCPHdr(r, id, d, data, flags, seq, ack, rcvWnd, opts, gso) + buildTCPHdr(r, id, &pkt, flags, seq, ack, rcvWnd, opts, gso) if ttl == 0 { ttl = r.DefaultTTL() } - if err := r.WritePacket(gso, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}, tcpip.PacketBuffer{ - Header: d.Hdr, - Data: data, - }); err != nil { + if err := r.WritePacket(gso, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}, pkt); err != nil { r.Stats().TCP.SegmentSendErrors.Increment() return err } -- cgit v1.2.3 From c3b93afeafeff4555b57aa22c2a91375f9e38e28 Mon Sep 17 00:00:00 2001 From: Adin Scannell Date: Sat, 23 Nov 2019 23:21:04 -0800 Subject: Cleanup visibility. PiperOrigin-RevId: 282194656 --- pkg/tcpip/hash/jenkins/BUILD | 4 +--- pkg/tcpip/link/channel/BUILD | 2 +- pkg/tcpip/link/fdbased/BUILD | 4 +--- pkg/tcpip/link/loopback/BUILD | 2 +- pkg/tcpip/link/muxed/BUILD | 4 +--- pkg/tcpip/link/rawfile/BUILD | 4 +--- pkg/tcpip/link/sharedmem/BUILD | 4 +--- pkg/tcpip/link/sharedmem/pipe/BUILD | 2 +- pkg/tcpip/link/sharedmem/queue/BUILD | 2 +- pkg/tcpip/link/sniffer/BUILD | 4 +--- pkg/tcpip/link/tun/BUILD | 4 +--- pkg/tcpip/link/waitable/BUILD | 4 +--- pkg/tcpip/network/arp/BUILD | 4 +--- pkg/tcpip/network/fragmentation/BUILD | 10 +--------- pkg/tcpip/network/ipv4/BUILD | 4 +--- pkg/tcpip/network/ipv6/BUILD | 4 +--- pkg/tcpip/ports/BUILD | 2 +- pkg/tcpip/seqnum/BUILD | 4 +--- pkg/tcpip/stack/BUILD | 12 +----------- pkg/tcpip/transport/icmp/BUILD | 8 -------- pkg/tcpip/transport/packet/BUILD | 8 -------- pkg/tcpip/transport/raw/BUILD | 8 -------- pkg/tcpip/transport/tcp/BUILD | 8 -------- pkg/tcpip/transport/tcp/testing/context/BUILD | 2 +- pkg/tcpip/transport/udp/BUILD | 8 -------- pkg/waiter/BUILD | 8 -------- 26 files changed, 20 insertions(+), 110 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/hash/jenkins/BUILD b/pkg/tcpip/hash/jenkins/BUILD index 0c5c20cea..e648efa71 100644 --- a/pkg/tcpip/hash/jenkins/BUILD +++ b/pkg/tcpip/hash/jenkins/BUILD @@ -7,9 +7,7 @@ go_library( name = "jenkins", srcs = ["jenkins.go"], importpath = "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], ) go_test( diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD index 97a794986..7dbc05754 100644 --- a/pkg/tcpip/link/channel/BUILD +++ b/pkg/tcpip/link/channel/BUILD @@ -6,7 +6,7 @@ go_library( name = "channel", srcs = ["channel.go"], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/channel", - visibility = ["//:sandbox"], + visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD index 8fa9e3984..897c94821 100644 --- a/pkg/tcpip/link/fdbased/BUILD +++ b/pkg/tcpip/link/fdbased/BUILD @@ -14,9 +14,7 @@ go_library( "packet_dispatchers.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/fdbased", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/link/loopback/BUILD b/pkg/tcpip/link/loopback/BUILD index 23e4d1418..f35fcdff4 100644 --- a/pkg/tcpip/link/loopback/BUILD +++ b/pkg/tcpip/link/loopback/BUILD @@ -6,7 +6,7 @@ go_library( name = "loopback", srcs = ["loopback.go"], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/loopback", - visibility = ["//:sandbox"], + visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD index 1bab380b0..1ac7948b6 100644 --- a/pkg/tcpip/link/muxed/BUILD +++ b/pkg/tcpip/link/muxed/BUILD @@ -7,9 +7,7 @@ go_library( name = "muxed", srcs = ["injectable.go"], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/muxed", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD index 05c7b8024..d8211e93d 100644 --- a/pkg/tcpip/link/rawfile/BUILD +++ b/pkg/tcpip/link/rawfile/BUILD @@ -13,9 +13,7 @@ go_library( "rawfile_unsafe.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/rawfile", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", "@org_golang_x_sys//unix:go_default_library", diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD index 0a5ea3dc4..a4f9cdd69 100644 --- a/pkg/tcpip/link/sharedmem/BUILD +++ b/pkg/tcpip/link/sharedmem/BUILD @@ -12,9 +12,7 @@ go_library( "tx.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem", - visibility = [ - "//:sandbox", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/log", "//pkg/tcpip", diff --git a/pkg/tcpip/link/sharedmem/pipe/BUILD b/pkg/tcpip/link/sharedmem/pipe/BUILD index 330ed5e94..6b5bc542c 100644 --- a/pkg/tcpip/link/sharedmem/pipe/BUILD +++ b/pkg/tcpip/link/sharedmem/pipe/BUILD @@ -12,7 +12,7 @@ go_library( "tx.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe", - visibility = ["//:sandbox"], + visibility = ["//visibility:public"], ) go_test( diff --git a/pkg/tcpip/link/sharedmem/queue/BUILD b/pkg/tcpip/link/sharedmem/queue/BUILD index de1ce043d..8c9234d54 100644 --- a/pkg/tcpip/link/sharedmem/queue/BUILD +++ b/pkg/tcpip/link/sharedmem/queue/BUILD @@ -10,7 +10,7 @@ go_library( "tx.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue", - visibility = ["//:sandbox"], + visibility = ["//visibility:public"], deps = [ "//pkg/log", "//pkg/tcpip/link/sharedmem/pipe", diff --git a/pkg/tcpip/link/sniffer/BUILD b/pkg/tcpip/link/sniffer/BUILD index 1756114e6..d6ae0368a 100644 --- a/pkg/tcpip/link/sniffer/BUILD +++ b/pkg/tcpip/link/sniffer/BUILD @@ -9,9 +9,7 @@ go_library( "sniffer.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sniffer", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/log", "//pkg/tcpip", diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD index 92dce8fac..a71a493fc 100644 --- a/pkg/tcpip/link/tun/BUILD +++ b/pkg/tcpip/link/tun/BUILD @@ -6,7 +6,5 @@ go_library( name = "tun", srcs = ["tun_unsafe.go"], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/tun", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], ) diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD index 0746dc8ec..134837943 100644 --- a/pkg/tcpip/link/waitable/BUILD +++ b/pkg/tcpip/link/waitable/BUILD @@ -9,9 +9,7 @@ go_library( "waitable.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/link/waitable", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/gate", "//pkg/tcpip", diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index df0d3a8c0..e7617229b 100644 --- a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -7,9 +7,7 @@ go_library( name = "arp", srcs = ["arp.go"], importpath = "gvisor.dev/gvisor/pkg/tcpip/network/arp", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD index 2cad0a0b6..acf1e022c 100644 --- a/pkg/tcpip/network/fragmentation/BUILD +++ b/pkg/tcpip/network/fragmentation/BUILD @@ -25,7 +25,7 @@ go_library( "reassembler_list.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation", - visibility = ["//:sandbox"], + visibility = ["//visibility:public"], deps = [ "//pkg/log", "//pkg/tcpip", @@ -44,11 +44,3 @@ go_test( embed = [":fragmentation"], deps = ["//pkg/tcpip/buffer"], ) - -filegroup( - name = "autogen", - srcs = [ - "reassembler_list.go", - ], - visibility = ["//:sandbox"], -) diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 58e537aad..aeddfcdd4 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -10,9 +10,7 @@ go_library( "ipv4.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/network/ipv4", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index f06622a8b..e4e273460 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -10,9 +10,7 @@ go_library( "ipv6.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/network/ipv6", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD index 11efb4e44..4839f0a65 100644 --- a/pkg/tcpip/ports/BUILD +++ b/pkg/tcpip/ports/BUILD @@ -7,7 +7,7 @@ go_library( name = "ports", srcs = ["ports.go"], importpath = "gvisor.dev/gvisor/pkg/tcpip/ports", - visibility = ["//:sandbox"], + visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", ], diff --git a/pkg/tcpip/seqnum/BUILD b/pkg/tcpip/seqnum/BUILD index 29b7d761c..b31ddba2f 100644 --- a/pkg/tcpip/seqnum/BUILD +++ b/pkg/tcpip/seqnum/BUILD @@ -6,7 +6,5 @@ go_library( name = "seqnum", srcs = ["seqnum.go"], importpath = "gvisor.dev/gvisor/pkg/tcpip/seqnum", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], ) diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 460db3cf8..69077669a 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -31,9 +31,7 @@ go_library( "transport_demuxer.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/stack", - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//pkg/ilist", "//pkg/rand", @@ -87,11 +85,3 @@ go_test( "//pkg/tcpip", ], ) - -filegroup( - name = "autogen", - srcs = [ - "linkaddrentry_list.go", - ], - visibility = ["//:sandbox"], -) diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD index 9254c3dea..d8c5b5058 100644 --- a/pkg/tcpip/transport/icmp/BUILD +++ b/pkg/tcpip/transport/icmp/BUILD @@ -38,11 +38,3 @@ go_library( "//pkg/waiter", ], ) - -filegroup( - name = "autogen", - srcs = [ - "icmp_packet_list.go", - ], - visibility = ["//:sandbox"], -) diff --git a/pkg/tcpip/transport/packet/BUILD b/pkg/tcpip/transport/packet/BUILD index 8ea2e6ee5..44b58ff6b 100644 --- a/pkg/tcpip/transport/packet/BUILD +++ b/pkg/tcpip/transport/packet/BUILD @@ -36,11 +36,3 @@ go_library( "//pkg/waiter", ], ) - -filegroup( - name = "autogen", - srcs = [ - "packet_list.go", - ], - visibility = ["//:sandbox"], -) diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD index 4af49218c..00991ac8e 100644 --- a/pkg/tcpip/transport/raw/BUILD +++ b/pkg/tcpip/transport/raw/BUILD @@ -38,11 +38,3 @@ go_library( "//pkg/waiter", ], ) - -filegroup( - name = "autogen", - srcs = [ - "raw_packet_list.go", - ], - visibility = ["//:sandbox"], -) diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 3f47b328d..dd1728f9c 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -61,14 +61,6 @@ go_library( ], ) -filegroup( - name = "autogen", - srcs = [ - "tcp_segment_list.go", - ], - visibility = ["//:sandbox"], -) - go_test( name = "tcp_test", size = "medium", diff --git a/pkg/tcpip/transport/tcp/testing/context/BUILD b/pkg/tcpip/transport/tcp/testing/context/BUILD index 19b0d31c5..b33ec2087 100644 --- a/pkg/tcpip/transport/tcp/testing/context/BUILD +++ b/pkg/tcpip/transport/tcp/testing/context/BUILD @@ -8,7 +8,7 @@ go_library( srcs = ["context.go"], importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context", visibility = [ - "//:sandbox", + "//visibility:public", ], deps = [ "//pkg/tcpip", diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index c9460aa0d..8d4c3808f 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -59,11 +59,3 @@ go_test( "//pkg/waiter", ], ) - -filegroup( - name = "autogen", - srcs = [ - "udp_packet_list.go", - ], - visibility = ["//:sandbox"], -) diff --git a/pkg/waiter/BUILD b/pkg/waiter/BUILD index 1f7efb064..0427bc41f 100644 --- a/pkg/waiter/BUILD +++ b/pkg/waiter/BUILD @@ -34,11 +34,3 @@ go_test( ], embed = [":waiter"], ) - -filegroup( - name = "autogen", - srcs = [ - "waiter_list.go", - ], - visibility = ["//:sandbox"], -) -- cgit v1.2.3 From 10f7b109ab98c95783357b82e1934586f338c2b3 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Thu, 5 Dec 2019 10:40:18 -0800 Subject: Add a type to represent the NDP Recursive DNS Server option This change adds a type to represent the NDP Recursive DNS Server option, as defined by RFC 8106 section 5.1. PiperOrigin-RevId: 284005493 --- pkg/tcpip/header/BUILD | 5 +- pkg/tcpip/header/ndp_options.go | 118 +++++++++++++++++++++- pkg/tcpip/header/ndp_test.go | 215 ++++++++++++++++++++++++++++++++++++++++ pkg/tcpip/stack/ndp.go | 6 +- pkg/tcpip/stack/ndp_test.go | 6 +- 5 files changed, 339 insertions(+), 11 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD index a3485b35c..8392cb9e5 100644 --- a/pkg/tcpip/header/BUILD +++ b/pkg/tcpip/header/BUILD @@ -55,5 +55,8 @@ go_test( "ndp_test.go", ], embed = [":header"], - deps = ["//pkg/tcpip"], + deps = [ + "//pkg/tcpip", + "@com_github_google_go-cmp//cmp:go_default_library", + ], ) diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go index 1ca6199ef..2652e7b67 100644 --- a/pkg/tcpip/header/ndp_options.go +++ b/pkg/tcpip/header/ndp_options.go @@ -85,6 +85,23 @@ const ( // within an NDPPrefixInformation. ndpPrefixInformationPrefixOffset = 14 + // NDPRecursiveDNSServerOptionType is the type of the Recursive DNS + // Server option, as per RFC 8106 section 5.1. + NDPRecursiveDNSServerOptionType = 25 + + // ndpRecursiveDNSServerLifetimeOffset is the start of the 4-byte + // Lifetime field within an NDPRecursiveDNSServer. + ndpRecursiveDNSServerLifetimeOffset = 2 + + // ndpRecursiveDNSServerAddressesOffset is the start of the addresses + // for IPv6 Recursive DNS Servers within an NDPRecursiveDNSServer. + ndpRecursiveDNSServerAddressesOffset = 6 + + // minNDPRecursiveDNSServerLength is the minimum NDP Recursive DNS + // Server option's length field value when it contains at least one + // IPv6 address. + minNDPRecursiveDNSServerLength = 3 + // lengthByteUnits is the multiplier factor for the Length field of an // NDP option. That is, the length field for NDP options is in units of // 8 octets, as per RFC 4861 section 4.6. @@ -92,13 +109,13 @@ const ( ) var ( - // NDPPrefixInformationInfiniteLifetime is a value that represents + // NDPInfiniteLifetime is a value that represents // infinity for the Valid and Preferred Lifetime fields in a NDP Prefix // Information option. Its value is (2^32 - 1)s = 4294967295s // // This is a variable instead of a constant so that tests can change // this value to a smaller value. It should only be modified by tests. - NDPPrefixInformationInfiniteLifetime = time.Second * 4294967295 + NDPInfiniteLifetime = time.Second * 4294967295 ) // NDPOptionIterator is an iterator of NDPOption. @@ -118,6 +135,7 @@ var ( ErrNDPOptBufExhausted = errors.New("Buffer unexpectedly exhausted") ErrNDPOptZeroLength = errors.New("NDP option has zero-valued Length field") ErrNDPOptMalformedBody = errors.New("NDP option has a malformed body") + ErrNDPInvalidLength = errors.New("NDP option's Length value is invalid as per relevant RFC") ) // Next returns the next element in the backing NDPOptions, or true if we are @@ -182,6 +200,22 @@ func (i *NDPOptionIterator) Next() (NDPOption, bool, error) { } return NDPPrefixInformation(body), false, nil + + case NDPRecursiveDNSServerOptionType: + // RFC 8106 section 5.3.1 outlines that the RDNSS option + // must have a minimum length of 3 so it contains at + // least one IPv6 address. + if l < minNDPRecursiveDNSServerLength { + return nil, true, ErrNDPInvalidLength + } + + opt := NDPRecursiveDNSServer(body) + if len(opt.Addresses()) == 0 { + return nil, true, ErrNDPOptMalformedBody + } + + return opt, false, nil + default: // We do not yet recognize the option, just skip for // now. This is okay because RFC 4861 allows us to @@ -434,7 +468,7 @@ func (o NDPPrefixInformation) AutonomousAddressConfigurationFlag() bool { // // Note, a value of 0 implies the prefix should not be considered as on-link, // and a value of infinity/forever is represented by -// NDPPrefixInformationInfiniteLifetime. +// NDPInfiniteLifetime. func (o NDPPrefixInformation) ValidLifetime() time.Duration { // The field is the time in seconds, as per RFC 4861 section 4.6.2. return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpPrefixInformationValidLifetimeOffset:])) @@ -447,7 +481,7 @@ func (o NDPPrefixInformation) ValidLifetime() time.Duration { // // Note, a value of 0 implies that addresses generated from the prefix should // no longer remain preferred, and a value of infinity is represented by -// NDPPrefixInformationInfiniteLifetime. +// NDPInfiniteLifetime. // // Also note that the value of this field MUST NOT exceed the Valid Lifetime // field to avoid preferring addresses that are no longer valid, for the @@ -476,3 +510,79 @@ func (o NDPPrefixInformation) Subnet() tcpip.Subnet { } return addrWithPrefix.Subnet() } + +// NDPRecursiveDNSServer is the NDP Recursive DNS Server option, as defined by +// RFC 8106 section 5.1. +// +// To make sure that the option meets its minimum length and does not end in the +// middle of a DNS server's IPv6 address, the length of a valid +// NDPRecursiveDNSServer must meet the following constraint: +// (Length - ndpRecursiveDNSServerAddressesOffset) % IPv6AddressSize == 0 +type NDPRecursiveDNSServer []byte + +// Type returns the type of an NDP Recursive DNS Server option. +// +// Type implements NDPOption.Type. +func (NDPRecursiveDNSServer) Type() uint8 { + return NDPRecursiveDNSServerOptionType +} + +// Length implements NDPOption.Length. +func (o NDPRecursiveDNSServer) Length() int { + return len(o) +} + +// serializeInto implements NDPOption.serializeInto. +func (o NDPRecursiveDNSServer) serializeInto(b []byte) int { + used := copy(b, o) + + // Zero out the reserved bytes that are before the Lifetime field. + for i := 0; i < ndpRecursiveDNSServerLifetimeOffset; i++ { + b[i] = 0 + } + + return used +} + +// Lifetime returns the length of time that the DNS server addresses +// in this option may be used for name resolution. +// +// Note, a value of 0 implies the addresses should no longer be used, +// and a value of infinity/forever is represented by NDPInfiniteLifetime. +// +// Lifetime may panic if o does not have enough bytes to hold the Lifetime +// field. +func (o NDPRecursiveDNSServer) Lifetime() time.Duration { + // The field is the time in seconds, as per RFC 8106 section 5.1. + return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpRecursiveDNSServerLifetimeOffset:])) +} + +// Addresses returns the recursive DNS server IPv6 addresses that may be +// used for name resolution. +// +// Note, some of the addresses returned MAY be link-local addresses. +// +// Addresses may panic if o does not hold valid IPv6 addresses. +func (o NDPRecursiveDNSServer) Addresses() []tcpip.Address { + l := len(o) + if l < ndpRecursiveDNSServerAddressesOffset { + return nil + } + + l -= ndpRecursiveDNSServerAddressesOffset + if l%IPv6AddressSize != 0 { + return nil + } + + buf := o[ndpRecursiveDNSServerAddressesOffset:] + var addrs []tcpip.Address + for len(buf) > 0 { + addr := tcpip.Address(buf[:IPv6AddressSize]) + if !IsV6UnicastAddress(addr) { + return nil + } + addrs = append(addrs, addr) + buf = buf[IPv6AddressSize:] + } + return addrs +} diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go index ad6daafcd..2c439d70c 100644 --- a/pkg/tcpip/header/ndp_test.go +++ b/pkg/tcpip/header/ndp_test.go @@ -19,6 +19,7 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -369,6 +370,175 @@ func TestNDPPrefixInformationOption(t *testing.T) { } } +func TestNDPRecursiveDNSServerOptionSerialize(t *testing.T) { + b := []byte{ + 9, 8, + 1, 2, 4, 8, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + } + targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1} + expected := []byte{ + 25, 3, 0, 0, + 1, 2, 4, 8, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + } + opts := NDPOptions(targetBuf) + serializer := NDPOptionsSerializer{ + NDPRecursiveDNSServer(b), + } + if got, want := opts.Serialize(serializer), len(expected); got != want { + t.Errorf("got Serialize = %d, want = %d", got, want) + } + if !bytes.Equal(targetBuf, expected) { + t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expected) + } + + it, err := opts.Iter(true) + if err != nil { + t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) + } + + next, done, err := it.Next() + if err != nil { + t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if done { + t.Fatal("got Next = (_, true, _), want = (_, false, _)") + } + if got := next.Type(); got != NDPRecursiveDNSServerOptionType { + t.Errorf("got Type = %d, want = %d", got, NDPRecursiveDNSServerOptionType) + } + + opt, ok := next.(NDPRecursiveDNSServer) + if !ok { + t.Fatalf("next (type = %T) cannot be casted to an NDPRecursiveDNSServer", next) + } + if got := opt.Type(); got != 25 { + t.Errorf("got Type = %d, want = 31", got) + } + if got := opt.Length(); got != 22 { + t.Errorf("got Length = %d, want = 22", got) + } + if got, want := opt.Lifetime(), 16909320*time.Second; got != want { + t.Errorf("got Lifetime = %s, want = %s", got, want) + } + want := []tcpip.Address{ + "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", + } + if got := opt.Addresses(); !cmp.Equal(got, want) { + t.Errorf("got Addresses = %v, want = %v", got, want) + } + + // Iterator should not return anything else. + next, done, err = it.Next() + if err != nil { + t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if !done { + t.Error("got Next = (_, false, _), want = (_, true, _)") + } + if next != nil { + t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) + } +} + +func TestNDPRecursiveDNSServerOption(t *testing.T) { + tests := []struct { + name string + buf []byte + lifetime time.Duration + addrs []tcpip.Address + }{ + { + "Valid1Addr", + []byte{ + 25, 3, 0, 0, + 0, 0, 0, 0, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + }, + 0, + []tcpip.Address{ + "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", + }, + }, + { + "Valid2Addr", + []byte{ + 25, 5, 0, 0, + 0, 0, 0, 0, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, + }, + 0, + []tcpip.Address{ + "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", + "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x10", + }, + }, + { + "Valid3Addr", + []byte{ + 25, 7, 0, 0, + 0, 0, 0, 0, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, + 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 17, + }, + 0, + []tcpip.Address{ + "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", + "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x10", + "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x11", + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + opts := NDPOptions(test.buf) + it, err := opts.Iter(true) + if err != nil { + t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) + } + + // Iterator should get our option. + next, done, err := it.Next() + if err != nil { + t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if done { + t.Fatal("got Next = (_, true, _), want = (_, false, _)") + } + if got := next.Type(); got != NDPRecursiveDNSServerOptionType { + t.Fatalf("got Type %= %d, want = %d", got, NDPRecursiveDNSServerOptionType) + } + + opt, ok := next.(NDPRecursiveDNSServer) + if !ok { + t.Fatalf("next (type = %T) cannot be casted to an NDPRecursiveDNSServer", next) + } + if got := opt.Lifetime(); got != test.lifetime { + t.Errorf("got Lifetime = %d, want = %d", got, test.lifetime) + } + if got := opt.Addresses(); !cmp.Equal(got, test.addrs) { + t.Errorf("got Addresses = %v, want = %v", got, test.addrs) + } + + // Iterator should not return anything else. + next, done, err = it.Next() + if err != nil { + t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if !done { + t.Error("got Next = (_, false, _), want = (_, true, _)") + } + if next != nil { + t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) + } + }) + } +} + // TestNDPOptionsIterCheck tests that Iter will return false if the NDPOptions // the iterator was returned for is malformed. func TestNDPOptionsIterCheck(t *testing.T) { @@ -473,6 +643,51 @@ func TestNDPOptionsIterCheck(t *testing.T) { }, nil, }, + { + "InvalidRecursiveDNSServerCutsOffAddress", + []byte{ + 25, 4, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + 0, 1, 2, 3, 4, 5, 6, 7, + }, + ErrNDPOptMalformedBody, + }, + { + "InvalidRecursiveDNSServerInvalidLengthField", + []byte{ + 25, 2, 0, 0, + 0, 0, 0, 0, + 0, 1, 2, 3, 4, 5, 6, 7, 8, + }, + ErrNDPInvalidLength, + }, + { + "RecursiveDNSServerTooSmall", + []byte{ + 25, 1, 0, 0, + 0, 0, 0, + }, + ErrNDPOptBufExhausted, + }, + { + "RecursiveDNSServerMulticast", + []byte{ + 25, 3, 0, 0, + 0, 0, 0, 0, + 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + }, + ErrNDPOptMalformedBody, + }, + { + "RecursiveDNSServerUnspecified", + []byte{ + 25, 3, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, + ErrNDPOptMalformedBody, + }, } for _, test := range tests { diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index cfdd0496e..1d202deb5 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -596,7 +596,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { // Update the invalidation timer. timer := prefixState.invalidationTimer - if timer == nil && vl >= header.NDPPrefixInformationInfiniteLifetime { + if timer == nil && vl >= header.NDPInfiniteLifetime { // Had infinite valid lifetime before and // continues to have an invalid lifetime. Do // nothing further. @@ -615,7 +615,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { *prefixState.doNotInvalidate = true } - if vl >= header.NDPPrefixInformationInfiniteLifetime { + if vl >= header.NDPInfiniteLifetime { // Prefix is now valid forever so we don't need // an invalidation timer. prefixState.invalidationTimer = nil @@ -734,7 +734,7 @@ func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) var timer *time.Timer // Only create a timer if the lifetime is not infinite. - if l < header.NDPPrefixInformationInfiniteLifetime { + if l < header.NDPInfiniteLifetime { timer = ndp.prefixInvalidationCallback(prefix, l, &doNotInvalidate) } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 5b901f947..b2af78212 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -1364,10 +1364,10 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { // invalidate the prefix. const testInfiniteLifetimeSeconds = 2 const testInfiniteLifetime = testInfiniteLifetimeSeconds * time.Second - saved := header.NDPPrefixInformationInfiniteLifetime - header.NDPPrefixInformationInfiniteLifetime = testInfiniteLifetime + saved := header.NDPInfiniteLifetime + header.NDPInfiniteLifetime = testInfiniteLifetime defer func() { - header.NDPPrefixInformationInfiniteLifetime = saved + header.NDPInfiniteLifetime = saved }() prefix := tcpip.AddressWithPrefix{ -- cgit v1.2.3 From ab3f7bc39392aaa7e7961ae6d82d94f2cae18adb Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Fri, 6 Dec 2019 14:40:10 -0800 Subject: Do IPv6 Stateless Address Auto-Configuration (SLAAC) This change allows the netstack to do SLAAC as outlined by RFC 4862 section 5.5. Note, this change will not break existing uses of netstack as the default configuration for the stack options is set in such a way that SLAAC will not be performed. See `stack.Options` and `stack.NDPConfigurations` for more details. This change reuses 1 option and introduces a new one that is required to take advantage of SLAAC, all available under NDPConfigurations: - HandleRAs: Whether or not NDP RAs are processes - AutoGenGlobalAddresses: Whether or not SLAAC is performed. Also note, this change does not deprecate SLAAC generated addresses after the preferred lifetime. That will come in a later change (b/143713887). Currently, only the valid lifetime is honoured. Tests: Unittest to make sure that SLAAC generates and adds addresses only when configured to do so. Tests also makes sure that conflicts with static addresses do not modify the static address. PiperOrigin-RevId: 284265317 --- pkg/tcpip/header/ipv6.go | 49 +++- pkg/tcpip/header/ndp_options.go | 9 +- pkg/tcpip/stack/ndp.go | 509 ++++++++++++++++++++++++++++----- pkg/tcpip/stack/ndp_test.go | 616 +++++++++++++++++++++++++++++++++++++--- pkg/tcpip/stack/nic.go | 74 +++-- 5 files changed, 1109 insertions(+), 148 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index 0caa51c1e..5275b34d4 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -90,6 +90,18 @@ const ( // IPv6Any is the non-routable IPv6 "any" meta address. It is also // known as the unspecified address. IPv6Any tcpip.Address = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + + // IIDSize is the size of an interface identifier (IID), in bytes, as + // defined by RFC 4291 section 2.5.1. + IIDSize = 8 + + // IIDOffsetInIPv6Address is the offset, in bytes, from the start + // of an IPv6 address to the beginning of the interface identifier + // (IID) for auto-generated addresses. That is, all bytes before + // the IIDOffsetInIPv6Address-th byte are the prefix bytes, and all + // bytes including and after the IIDOffsetInIPv6Address-th byte are + // for the IID. + IIDOffsetInIPv6Address = 8 ) // IPv6EmptySubnet is the empty IPv6 subnet. It may also be known as the @@ -266,6 +278,28 @@ func SolicitedNodeAddr(addr tcpip.Address) tcpip.Address { return solicitedNodeMulticastPrefix + addr[len(addr)-3:] } +// EthernetAdddressToEUI64IntoBuf populates buf with a EUI-64 from a 48-bit +// Ethernet/MAC address. +// +// buf MUST be at least 8 bytes. +func EthernetAdddressToEUI64IntoBuf(linkAddr tcpip.LinkAddress, buf []byte) { + buf[0] = linkAddr[0] ^ 2 + buf[1] = linkAddr[1] + buf[2] = linkAddr[2] + buf[3] = 0xFE + buf[4] = 0xFE + buf[5] = linkAddr[3] + buf[6] = linkAddr[4] + buf[7] = linkAddr[5] +} + +// EthernetAddressToEUI64 computes an EUI-64 from a 48-bit Ethernet/MAC address. +func EthernetAddressToEUI64(linkAddr tcpip.LinkAddress) [IIDSize]byte { + var buf [IIDSize]byte + EthernetAdddressToEUI64IntoBuf(linkAddr, buf[:]) + return buf +} + // LinkLocalAddr computes the default IPv6 link-local address from a link-layer // (MAC) address. func LinkLocalAddr(linkAddr tcpip.LinkAddress) tcpip.Address { @@ -275,18 +309,11 @@ func LinkLocalAddr(linkAddr tcpip.LinkAddress) tcpip.Address { // The conversion is very nearly: // aa:bb:cc:dd:ee:ff => FE80::Aabb:ccFF:FEdd:eeff // Note the capital A. The conversion aa->Aa involves a bit flip. - lladdrb := [16]byte{ - 0: 0xFE, - 1: 0x80, - 8: linkAddr[0] ^ 2, - 9: linkAddr[1], - 10: linkAddr[2], - 11: 0xFF, - 12: 0xFE, - 13: linkAddr[3], - 14: linkAddr[4], - 15: linkAddr[5], + lladdrb := [IPv6AddressSize]byte{ + 0: 0xFE, + 1: 0x80, } + EthernetAdddressToEUI64IntoBuf(linkAddr, lladdrb[IIDOffsetInIPv6Address:]) return tcpip.Address(lladdrb[:]) } diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go index 2652e7b67..06e0bace2 100644 --- a/pkg/tcpip/header/ndp_options.go +++ b/pkg/tcpip/header/ndp_options.go @@ -17,6 +17,7 @@ package header import ( "encoding/binary" "errors" + "math" "time" "gvisor.dev/gvisor/pkg/tcpip" @@ -109,13 +110,13 @@ const ( ) var ( - // NDPInfiniteLifetime is a value that represents - // infinity for the Valid and Preferred Lifetime fields in a NDP Prefix - // Information option. Its value is (2^32 - 1)s = 4294967295s + // NDPInfiniteLifetime is a value that represents infinity for the + // 4-byte lifetime fields found in various NDP options. Its value is + // (2^32 - 1)s = 4294967295s. // // This is a variable instead of a constant so that tests can change // this value to a smaller value. It should only be modified by tests. - NDPInfiniteLifetime = time.Second * 4294967295 + NDPInfiniteLifetime = time.Second * math.MaxUint32 ) // NDPOptionIterator is an iterator of NDPOption. diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 1d202deb5..75d3ecdac 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -58,6 +58,14 @@ const ( // Default = true. defaultDiscoverOnLinkPrefixes = true + // defaultAutoGenGlobalAddresses is the default configuration for + // whether or not to generate global IPv6 addresses in response to + // receiving a new Prefix Information option with its Autonomous + // Address AutoConfiguration flag set, as a host. + // + // Default = true. + defaultAutoGenGlobalAddresses = true + // minimumRetransmitTimer is the minimum amount of time to wait between // sending NDP Neighbor solicitation messages. Note, RFC 4861 does // not impose a minimum Retransmit Timer, but we do here to make sure @@ -87,6 +95,24 @@ const ( // // Max = 10. MaxDiscoveredOnLinkPrefixes = 10 + + // validPrefixLenForAutoGen is the expected prefix length that an + // address can be generated for. Must be 64 bits as the interface + // identifier (IID) is 64 bits and an IPv6 address is 128 bits, so + // 128 - 64 = 64. + validPrefixLenForAutoGen = 64 +) + +var ( + // MinPrefixInformationValidLifetimeForUpdate is the minimum Valid + // Lifetime to update the valid lifetime of a generated address by + // SLAAC. + // + // This is exported as a variable (instead of a constant) so tests + // can update it to a smaller value. + // + // Min = 2hrs. + MinPrefixInformationValidLifetimeForUpdate = 2 * time.Hour ) // NDPDispatcher is the interface integrators of netstack must implement to @@ -139,6 +165,22 @@ type NDPDispatcher interface { // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) []tcpip.Route + + // OnAutoGenAddress will be called when a new prefix with its + // autonomous address-configuration flag set has been received and SLAAC + // has been performed. Implementations may prevent the stack from + // assigning the address to the NIC by returning false. + // + // This function is not permitted to block indefinitely. It must not + // call functions on the stack itself. + OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool + + // OnAutoGenAddressInvalidated will be called when an auto-generated + // address (as part of SLAAC) has been invalidated. + // + // This function is not permitted to block indefinitely. It must not + // call functions on the stack itself. + OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix) } // NDPConfigurations is the NDP configurations for the netstack. @@ -168,6 +210,17 @@ type NDPConfigurations struct { // will be discovered from Router Advertisements' Prefix Information // option. This configuration is ignored if HandleRAs is false. DiscoverOnLinkPrefixes bool + + // AutoGenGlobalAddresses determines whether or not global IPv6 + // addresses will be generated for a NIC in response to receiving a new + // Prefix Information option with its Autonomous Address + // AutoConfiguration flag set, as a host, as per RFC 4862 (SLAAC). + // + // Note, if an address was already generated for some unique prefix, as + // part of SLAAC, this option does not affect whether or not the + // lifetime(s) of the generated address changes; this option only + // affects the generation of new addresses as part of SLAAC. + AutoGenGlobalAddresses bool } // DefaultNDPConfigurations returns an NDPConfigurations populated with @@ -179,6 +232,7 @@ func DefaultNDPConfigurations() NDPConfigurations { HandleRAs: defaultHandleRAs, DiscoverDefaultRouters: defaultDiscoverDefaultRouters, DiscoverOnLinkPrefixes: defaultDiscoverOnLinkPrefixes, + AutoGenGlobalAddresses: defaultAutoGenGlobalAddresses, } } @@ -210,6 +264,9 @@ type ndpState struct { // The on-link prefixes discovered through Router Advertisements' Prefix // Information option. onLinkPrefixes map[tcpip.Subnet]onLinkPrefixState + + // The addresses generated by SLAAC. + autoGenAddresses map[tcpip.Address]autoGenAddressState } // dadState holds the Duplicate Address Detection timer and channel to signal @@ -270,6 +327,32 @@ type onLinkPrefixState struct { doNotInvalidate *bool } +// autoGenAddressState holds data associated with an address generated via +// SLAAC. +type autoGenAddressState struct { + invalidationTimer *time.Timer + + // Used to signal the timer not to invalidate the SLAAC address (A) in + // a race condition (T1 is a goroutine that handles a PI for A and T2 + // is the goroutine that handles A's invalidation timer firing): + // T1: Receive a new PI for A + // T1: Obtain the NIC's lock before processing the PI + // T2: A's invalidation timer fires, and gets blocked on obtaining the + // NIC's lock + // T1: Refreshes/extends A's lifetime & releases NIC's lock + // T2: Obtains NIC's lock & invalidates A immediately + // + // To resolve this, T1 will check to see if the timer already fired, and + // inform the timer using doNotInvalidate to not invalidate A, so that + // once T2 obtains the lock, it will see that it is set to true and do + // nothing further. + doNotInvalidate *bool + + // Nonzero only when the address is not valid forever (invalidationTimer + // is not nil). + validUntil time.Time +} + // startDuplicateAddressDetection performs Duplicate Address Detection. // // This function must only be called by IPv6 addresses that are currently @@ -536,17 +619,14 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { for opt, done, _ := it.Next(); !done; opt, done, _ = it.Next() { switch opt.Type() { case header.NDPPrefixInformationType: - if !ndp.configs.DiscoverOnLinkPrefixes { - continue - } - pi := opt.(header.NDPPrefixInformation) prefix := pi.Subnet() // Is the prefix a link-local? if header.IsV6LinkLocalAddress(prefix.ID()) { - // ...Yes, skip as per RFC 4861 section 6.3.4. + // ...Yes, skip as per RFC 4861 section 6.3.4, + // and RFC 4862 section 5.5.3.b (for SLAAC). continue } @@ -557,82 +637,13 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { continue } - if !pi.OnLinkFlag() { - // Not on-link so don't "discover" it as an - // on-link prefix. - continue - } - - prefixState, ok := ndp.onLinkPrefixes[prefix] - vl := pi.ValidLifetime() - switch { - case !ok && vl == 0: - // Don't know about this prefix but has a zero - // valid lifetime, so just ignore. - continue - - case !ok && vl != 0: - // This is a new on-link prefix we are - // discovering. - // - // Only remember it if we currently know about - // less than MaxDiscoveredOnLinkPrefixes on-link - // prefixes. - if len(ndp.onLinkPrefixes) < MaxDiscoveredOnLinkPrefixes { - ndp.rememberOnLinkPrefix(prefix, vl) - } - continue - - case ok && vl == 0: - // We know about the on-link prefix, but it is - // no longer to be considered on-link, so - // invalidate it. - ndp.invalidateOnLinkPrefix(prefix) - continue - } - - // This is an already discovered on-link prefix with a - // new non-zero valid lifetime. - // Update the invalidation timer. - timer := prefixState.invalidationTimer - - if timer == nil && vl >= header.NDPInfiniteLifetime { - // Had infinite valid lifetime before and - // continues to have an invalid lifetime. Do - // nothing further. - continue - } - - if timer != nil && !timer.Stop() { - // If we reach this point, then we know the - // timer already fired after we took the NIC - // lock. Inform the timer to not invalidate - // the prefix once it obtains the lock as we - // just got a new PI that refeshes its lifetime - // to a non-zero value. See - // onLinkPrefixState.doNotInvalidate for more - // details. - *prefixState.doNotInvalidate = true + if pi.OnLinkFlag() { + ndp.handleOnLinkPrefixInformation(pi) } - if vl >= header.NDPInfiniteLifetime { - // Prefix is now valid forever so we don't need - // an invalidation timer. - prefixState.invalidationTimer = nil - ndp.onLinkPrefixes[prefix] = prefixState - continue + if pi.AutonomousAddressConfigurationFlag() { + ndp.handleAutonomousPrefixInformation(pi) } - - if timer != nil { - // We already have a timer so just reset it to - // expire after the new valid lifetime. - timer.Reset(vl) - continue - } - - // We do not have a timer so just create a new one. - prefixState.invalidationTimer = ndp.prefixInvalidationCallback(prefix, vl, prefixState.doNotInvalidate) - ndp.onLinkPrefixes[prefix] = prefixState } // TODO(b/141556115): Do (MTU) Parameter Discovery. @@ -795,3 +806,345 @@ func (ndp *ndpState) prefixInvalidationCallback(prefix tcpip.Subnet, vl time.Dur ndp.invalidateOnLinkPrefix(prefix) }) } + +// handleOnLinkPrefixInformation handles a Prefix Information option with +// its on-link flag set, as per RFC 4861 section 6.3.4. +// +// handleOnLinkPrefixInformation assumes that the prefix this pi is for is +// not the link-local prefix and the on-link flag is set. +// +// The NIC that ndp belongs to and its associated stack MUST be locked. +func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformation) { + prefix := pi.Subnet() + prefixState, ok := ndp.onLinkPrefixes[prefix] + vl := pi.ValidLifetime() + + if !ok && vl == 0 { + // Don't know about this prefix but it has a zero valid + // lifetime, so just ignore. + return + } + + if !ok && vl != 0 { + // This is a new on-link prefix we are discovering + // + // Only remember it if we currently know about less than + // MaxDiscoveredOnLinkPrefixes on-link prefixes. + if ndp.configs.DiscoverOnLinkPrefixes && len(ndp.onLinkPrefixes) < MaxDiscoveredOnLinkPrefixes { + ndp.rememberOnLinkPrefix(prefix, vl) + } + return + } + + if ok && vl == 0 { + // We know about the on-link prefix, but it is + // no longer to be considered on-link, so + // invalidate it. + ndp.invalidateOnLinkPrefix(prefix) + return + } + + // This is an already discovered on-link prefix with a + // new non-zero valid lifetime. + // Update the invalidation timer. + timer := prefixState.invalidationTimer + + if timer == nil && vl >= header.NDPInfiniteLifetime { + // Had infinite valid lifetime before and + // continues to have an invalid lifetime. Do + // nothing further. + return + } + + if timer != nil && !timer.Stop() { + // If we reach this point, then we know the timer alread fired + // after we took the NIC lock. Inform the timer to not + // invalidate the prefix once it obtains the lock as we just + // got a new PI that refreshes its lifetime to a non-zero value. + // See onLinkPrefixState.doNotInvalidate for more details. + *prefixState.doNotInvalidate = true + } + + if vl >= header.NDPInfiniteLifetime { + // Prefix is now valid forever so we don't need + // an invalidation timer. + prefixState.invalidationTimer = nil + ndp.onLinkPrefixes[prefix] = prefixState + return + } + + if timer != nil { + // We already have a timer so just reset it to + // expire after the new valid lifetime. + timer.Reset(vl) + return + } + + // We do not have a timer so just create a new one. + prefixState.invalidationTimer = ndp.prefixInvalidationCallback(prefix, vl, prefixState.doNotInvalidate) + ndp.onLinkPrefixes[prefix] = prefixState +} + +// handleAutonomousPrefixInformation handles a Prefix Information option with +// its autonomous flag set, as per RFC 4862 section 5.5.3. +// +// handleAutonomousPrefixInformation assumes that the prefix this pi is for is +// not the link-local prefix and the autonomous flag is set. +// +// The NIC that ndp belongs to and its associated stack MUST be locked. +func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInformation) { + vl := pi.ValidLifetime() + pl := pi.PreferredLifetime() + + // If the preferred lifetime is greater than the valid lifetime, + // silently ignore the Prefix Information option, as per RFC 4862 + // section 5.5.3.c. + if pl > vl { + return + } + + prefix := pi.Subnet() + + // Check if we already have an auto-generated address for prefix. + for _, ref := range ndp.nic.endpoints { + if ref.protocol != header.IPv6ProtocolNumber { + continue + } + + if ref.configType != slaac { + continue + } + + addr := ref.ep.ID().LocalAddress + refAddrWithPrefix := tcpip.AddressWithPrefix{Address: addr, PrefixLen: ref.ep.PrefixLen()} + if refAddrWithPrefix.Subnet() != prefix { + continue + } + + // + // At this point, we know we are refreshing a SLAAC generated + // IPv6 address with the prefix, prefix. Do the work as outlined + // by RFC 4862 section 5.5.3.e. + // + + addrState, ok := ndp.autoGenAddresses[addr] + if !ok { + panic(fmt.Sprintf("must have an autoGenAddressess entry for the SLAAC generated IPv6 address %s", addr)) + } + + // TODO(b/143713887): Handle deprecating auto-generated address + // after the preferred lifetime. + + // As per RFC 4862 section 5.5.3.e, the valid lifetime of the + // address generated by SLAAC is as follows: + // + // 1) If the received Valid Lifetime is greater than 2 hours or + // greater than RemainingLifetime, set the valid lifetime of + // the address to the advertised Valid Lifetime. + // + // 2) If RemainingLifetime is less than or equal to 2 hours, + // ignore the advertised Valid Lifetime. + // + // 3) Otherwise, reset the valid lifetime of the address to 2 + // hours. + + // Handle the infinite valid lifetime separately as we do not + // keep a timer in this case. + if vl >= header.NDPInfiniteLifetime { + if addrState.invalidationTimer != nil { + // Valid lifetime was finite before, but now it + // is valid forever. + if !addrState.invalidationTimer.Stop() { + *addrState.doNotInvalidate = true + } + addrState.invalidationTimer = nil + addrState.validUntil = time.Time{} + ndp.autoGenAddresses[addr] = addrState + } + + return + } + + var effectiveVl time.Duration + var rl time.Duration + + // If the address was originally set to be valid forever, + // assume the remaining time to be the maximum possible value. + if addrState.invalidationTimer == nil { + rl = header.NDPInfiniteLifetime + } else { + rl = time.Until(addrState.validUntil) + } + + if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl { + effectiveVl = vl + } else if rl <= MinPrefixInformationValidLifetimeForUpdate { + ndp.autoGenAddresses[addr] = addrState + return + } else { + effectiveVl = MinPrefixInformationValidLifetimeForUpdate + } + + if addrState.invalidationTimer == nil { + addrState.invalidationTimer = ndp.autoGenAddrInvalidationTimer(addr, effectiveVl, addrState.doNotInvalidate) + } else { + if !addrState.invalidationTimer.Stop() { + *addrState.doNotInvalidate = true + } + addrState.invalidationTimer.Reset(effectiveVl) + } + + addrState.validUntil = time.Now().Add(effectiveVl) + ndp.autoGenAddresses[addr] = addrState + return + } + + // We do not already have an address within the prefix, prefix. Do the + // work as outlined by RFC 4862 section 5.5.3.d if n is configured + // to auto-generated global addresses by SLAAC. + + // Are we configured to auto-generate new global addresses? + if !ndp.configs.AutoGenGlobalAddresses { + return + } + + // If we do not already have an address for this prefix and the valid + // lifetime is 0, no need to do anything further, as per RFC 4862 + // section 5.5.3.d. + if vl == 0 { + return + } + + // Make sure the prefix is valid (as far as its length is concerned) to + // generate a valid IPv6 address from an interface identifier (IID), as + // per RFC 4862 sectiion 5.5.3.d. + if prefix.Prefix() != validPrefixLenForAutoGen { + return + } + + // Only attempt to generate an interface-specific IID if we have a valid + // link address. + // + // TODO(b/141011931): Validate a LinkEndpoint's link address + // (provided by LinkEndpoint.LinkAddress) before reaching this + // point. + linkAddr := ndp.nic.linkEP.LinkAddress() + if !header.IsValidUnicastEthernetAddress(linkAddr) { + return + } + + // Generate an address within prefix from the EUI-64 of ndp's NIC's + // Ethernet MAC address. + addrBytes := make([]byte, header.IPv6AddressSize) + copy(addrBytes[:header.IIDOffsetInIPv6Address], prefix.ID()[:header.IIDOffsetInIPv6Address]) + header.EthernetAdddressToEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:]) + addr := tcpip.Address(addrBytes) + addrWithPrefix := tcpip.AddressWithPrefix{ + Address: addr, + PrefixLen: validPrefixLenForAutoGen, + } + + // If the nic already has this address, do nothing further. + if ndp.nic.hasPermanentAddrLocked(addr) { + return + } + + // Inform the integrator that we have a new SLAAC address. + if ndp.nic.stack.ndpDisp == nil { + return + } + if !ndp.nic.stack.ndpDisp.OnAutoGenAddress(ndp.nic.ID(), addrWithPrefix) { + // Informed by the integrator not to add the address. + return + } + + if _, err := ndp.nic.addAddressLocked(tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addrWithPrefix, + }, FirstPrimaryEndpoint, permanent, slaac); err != nil { + panic(err) + } + + // Setup the timers to deprecate and invalidate this newly generated + // address. + + // TODO(b/143713887): Handle deprecating auto-generated addresses + // after the preferred lifetime. + + var doNotInvalidate bool + var vTimer *time.Timer + if vl < header.NDPInfiniteLifetime { + vTimer = ndp.autoGenAddrInvalidationTimer(addr, vl, &doNotInvalidate) + } + + ndp.autoGenAddresses[addr] = autoGenAddressState{ + invalidationTimer: vTimer, + doNotInvalidate: &doNotInvalidate, + validUntil: time.Now().Add(vl), + } +} + +// invalidateAutoGenAddress invalidates an auto-generated address. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) invalidateAutoGenAddress(addr tcpip.Address) { + if !ndp.cleanupAutoGenAddrResourcesAndNotify(addr) { + return + } + + ndp.nic.removePermanentAddressLocked(addr) +} + +// cleanupAutoGenAddrResourcesAndNotify cleans up an invalidated auto-generated +// address's resources from ndp. If the stack has an NDP dispatcher, it will +// be notified that addr has been invalidated. +// +// Returns true if ndp had resources for addr to cleanup. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bool { + state, ok := ndp.autoGenAddresses[addr] + + if !ok { + return false + } + + if state.invalidationTimer != nil { + state.invalidationTimer.Stop() + state.invalidationTimer = nil + *state.doNotInvalidate = true + } + + state.doNotInvalidate = nil + + delete(ndp.autoGenAddresses, addr) + + if ndp.nic.stack.ndpDisp != nil { + ndp.nic.stack.ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), tcpip.AddressWithPrefix{ + Address: addr, + PrefixLen: validPrefixLenForAutoGen, + }) + } + + return true +} + +// autoGenAddrInvalidationTimer returns a new invalidation timer for an +// auto-generated address that fires after vl. +// +// doNotInvalidate is used to inform the timer when it fires at the same time +// that an auto-generated address's valid lifetime gets refreshed. See +// autoGenAddrState.doNotInvalidate for more details. +func (ndp *ndpState) autoGenAddrInvalidationTimer(addr tcpip.Address, vl time.Duration, doNotInvalidate *bool) *time.Timer { + return time.AfterFunc(vl, func() { + ndp.nic.mu.Lock() + defer ndp.nic.mu.Unlock() + + if *doNotInvalidate { + *doNotInvalidate = false + return + } + + ndp.invalidateAutoGenAddress(addr) + }) +} diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index b2af78212..e9aa20148 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -38,7 +38,7 @@ const ( linkAddr1 = "\x02\x02\x03\x04\x05\x06" linkAddr2 = "\x02\x02\x03\x04\x05\x07" linkAddr3 = "\x02\x02\x03\x04\x05\x08" - defaultTimeout = 250 * time.Millisecond + defaultTimeout = 100 * time.Millisecond ) var ( @@ -47,6 +47,31 @@ var ( llAddr3 = header.LinkLocalAddr(linkAddr3) ) +// prefixSubnetAddr returns a prefix (Address + Length), the prefix's equivalent +// tcpip.Subnet, and an address where the lower half of the address is composed +// of the EUI-64 of linkAddr if it is a valid unicast ethernet address. +func prefixSubnetAddr(offset uint8, linkAddr tcpip.LinkAddress) (tcpip.AddressWithPrefix, tcpip.Subnet, tcpip.AddressWithPrefix) { + prefixBytes := []byte{1, 2, 3, 4, 5, 6, 7, 8 + offset, 0, 0, 0, 0, 0, 0, 0, 0} + prefix := tcpip.AddressWithPrefix{ + Address: tcpip.Address(prefixBytes), + PrefixLen: 64, + } + + subnet := prefix.Subnet() + + var addr tcpip.AddressWithPrefix + if header.IsValidUnicastEthernetAddress(linkAddr) { + addrBytes := []byte(subnet.ID()) + header.EthernetAdddressToEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:]) + addr = tcpip.AddressWithPrefix{ + Address: tcpip.Address(addrBytes), + PrefixLen: 64, + } + } + + return prefix, subnet, addr +} + // TestDADDisabled tests that an address successfully resolves immediately // when DAD is not enabled (the default for an empty stack.Options). func TestDADDisabled(t *testing.T) { @@ -103,6 +128,19 @@ type ndpPrefixEvent struct { discovered bool } +type ndpAutoGenAddrEventType int + +const ( + newAddr ndpAutoGenAddrEventType = iota + invalidatedAddr +) + +type ndpAutoGenAddrEvent struct { + nicID tcpip.NICID + addr tcpip.AddressWithPrefix + eventType ndpAutoGenAddrEventType +} + var _ stack.NDPDispatcher = (*ndpDispatcher)(nil) // ndpDispatcher implements NDPDispatcher so tests can know when various NDP @@ -113,6 +151,7 @@ type ndpDispatcher struct { rememberRouter bool prefixC chan ndpPrefixEvent rememberPrefix bool + autoGenAddrC chan ndpAutoGenAddrEvent routeTable []tcpip.Route } @@ -211,7 +250,7 @@ func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpi } } - rt := make([]tcpip.Route, 0) + var rt []tcpip.Route exclude := tcpip.Route{ Destination: prefix, NIC: nicID, @@ -226,6 +265,27 @@ func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpi return rt } +func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) bool { + if n.autoGenAddrC != nil { + n.autoGenAddrC <- ndpAutoGenAddrEvent{ + nicID, + addr, + newAddr, + } + } + return true +} + +func (n *ndpDispatcher) OnAutoGenAddressInvalidated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) { + if n.autoGenAddrC != nil { + n.autoGenAddrC <- ndpAutoGenAddrEvent{ + nicID, + addr, + invalidatedAddr, + } + } +} + // TestDADResolve tests that an address successfully resolves after performing // DAD for various values of DupAddrDetectTransmits and RetransmitTimer. // Included in the subtests is a test to make sure that an invalid @@ -247,6 +307,8 @@ func TestDADResolve(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { + t.Parallel() + ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent), } @@ -781,16 +843,33 @@ func raBuf(ip tcpip.Address, rl uint16) tcpip.PacketBuffer { // // Note, raBufWithPI does not populate any of the RA fields other than the // Router Lifetime. -func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink bool, vl uint32) tcpip.PacketBuffer { +func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) tcpip.PacketBuffer { flags := uint8(0) if onLink { - flags |= 128 + // The OnLink flag is the 7th bit in the flags byte. + flags |= 1 << 7 + } + if auto { + // The Address Auto-Configuration flag is the 6th bit in the + // flags byte. + flags |= 1 << 6 } + // A valid header.NDPPrefixInformation must be 30 bytes. buf := [30]byte{} + // The first byte in a header.NDPPrefixInformation is the Prefix Length + // field. buf[0] = uint8(prefix.PrefixLen) + // The 2nd byte within a header.NDPPrefixInformation is the Flags field. buf[1] = flags + // The Valid Lifetime field starts after the 2nd byte within a + // header.NDPPrefixInformation. binary.BigEndian.PutUint32(buf[2:], vl) + // The Preferred Lifetime field starts after the 6th byte within a + // header.NDPPrefixInformation. + binary.BigEndian.PutUint32(buf[6:], pl) + // The Prefix Address field starts after the 14th byte within a + // header.NDPPrefixInformation. copy(buf[14:], prefix.Address) return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{ header.NDPPrefixInformation(buf[:]), @@ -800,6 +879,8 @@ func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, on // TestNoRouterDiscovery tests that router discovery will not be performed if // configured not to. func TestNoRouterDiscovery(t *testing.T) { + t.Parallel() + // Being configured to discover routers means handle and // discover are set to true and forwarding is set to false. // This tests all possible combinations of the configurations, @@ -812,6 +893,8 @@ func TestNoRouterDiscovery(t *testing.T) { forwarding := i&4 == 0 t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverDefaultRouters(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) { + t.Parallel() + ndpDisp := ndpDispatcher{ routerC: make(chan ndpRouterEvent, 10), } @@ -844,6 +927,8 @@ func TestNoRouterDiscovery(t *testing.T) { // TestRouterDiscoveryDispatcherNoRemember tests that the stack does not // remember a discovered router when the dispatcher asks it not to. func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { + t.Parallel() + ndpDisp := ndpDispatcher{ routerC: make(chan ndpRouterEvent, 10), } @@ -909,6 +994,8 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { } func TestRouterDiscovery(t *testing.T) { + t.Parallel() + ndpDisp := ndpDispatcher{ routerC: make(chan ndpRouterEvent, 10), rememberRouter: true, @@ -1040,6 +1127,8 @@ func TestRouterDiscovery(t *testing.T) { // TestRouterDiscoveryMaxRouters tests that only // stack.MaxDiscoveredDefaultRouters discovered routers are remembered. func TestRouterDiscoveryMaxRouters(t *testing.T) { + t.Parallel() + ndpDisp := ndpDispatcher{ routerC: make(chan ndpRouterEvent, 10), rememberRouter: true, @@ -1104,6 +1193,8 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { // TestNoPrefixDiscovery tests that prefix discovery will not be performed if // configured not to. func TestNoPrefixDiscovery(t *testing.T) { + t.Parallel() + prefix := tcpip.AddressWithPrefix{ Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"), PrefixLen: 64, @@ -1121,6 +1212,8 @@ func TestNoPrefixDiscovery(t *testing.T) { forwarding := i&4 == 0 t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverOnLinkPrefixes(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) { + t.Parallel() + ndpDisp := ndpDispatcher{ prefixC: make(chan ndpPrefixEvent, 10), } @@ -1140,7 +1233,7 @@ func TestNoPrefixDiscovery(t *testing.T) { } // Rx an RA with prefix with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, 10)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 10, 0)) select { case <-ndpDisp.prefixC: @@ -1154,11 +1247,9 @@ func TestNoPrefixDiscovery(t *testing.T) { // TestPrefixDiscoveryDispatcherNoRemember tests that the stack does not // remember a discovered on-link prefix when the dispatcher asks it not to. func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { - prefix := tcpip.AddressWithPrefix{ - Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"), - PrefixLen: 64, - } - subnet := prefix.Subnet() + t.Parallel() + + prefix, subnet, _ := prefixSubnetAddr(0, "") ndpDisp := ndpDispatcher{ prefixC: make(chan ndpPrefixEvent, 10), @@ -1189,7 +1280,7 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { // Rx an RA with prefix with a short lifetime. const lifetime = 1 - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, lifetime)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, lifetime, 0)) select { case r := <-ndpDisp.prefixC: if r.nicID != 1 { @@ -1226,21 +1317,11 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { } func TestPrefixDiscovery(t *testing.T) { - prefix1 := tcpip.AddressWithPrefix{ - Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"), - PrefixLen: 64, - } - prefix2 := tcpip.AddressWithPrefix{ - Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x09\x00\x00\x00\x00\x00\x00\x00\x00"), - PrefixLen: 64, - } - prefix3 := tcpip.AddressWithPrefix{ - Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x09\x0a\x00\x00\x00\x00\x00\x00\x00"), - PrefixLen: 72, - } - subnet1 := prefix1.Subnet() - subnet2 := prefix2.Subnet() - subnet3 := prefix3.Subnet() + t.Parallel() + + prefix1, subnet1, _ := prefixSubnetAddr(0, "") + prefix2, subnet2, _ := prefixSubnetAddr(1, "") + prefix3, subnet3, _ := prefixSubnetAddr(2, "") ndpDisp := ndpDispatcher{ prefixC: make(chan ndpPrefixEvent, 10), @@ -1281,7 +1362,7 @@ func TestPrefixDiscovery(t *testing.T) { // Receive an RA with prefix1 in an NDP Prefix Information option (PI) // with zero valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, 0)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly discovered a prefix with 0 lifetime") @@ -1290,7 +1371,7 @@ func TestPrefixDiscovery(t *testing.T) { // Receive an RA with prefix1 in an NDP Prefix Information option (PI) // with non-zero lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, 100)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 100, 0)) waitForEvent(subnet1, true, defaultTimeout) // Should have added a device route for subnet1 through the nic. @@ -1299,7 +1380,7 @@ func TestPrefixDiscovery(t *testing.T) { } // Receive an RA with prefix2 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, 100)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, 100, 0)) waitForEvent(subnet2, true, defaultTimeout) // Should have added a device route for subnet2 through the nic. @@ -1308,7 +1389,7 @@ func TestPrefixDiscovery(t *testing.T) { } // Receive an RA with prefix3 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, 100)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 100, 0)) waitForEvent(subnet3, true, defaultTimeout) // Should have added a device route for subnet3 through the nic. @@ -1317,7 +1398,7 @@ func TestPrefixDiscovery(t *testing.T) { } // Receive an RA with prefix1 in a PI with lifetime = 0. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, 0)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) waitForEvent(subnet1, false, defaultTimeout) // Should have removed the device route for subnet1 through the nic. @@ -1327,7 +1408,7 @@ func TestPrefixDiscovery(t *testing.T) { // Receive an RA with prefix2 in a PI with lesser lifetime. lifetime := uint32(2) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, lifetime)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, lifetime, 0)) select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly received prefix event when updating lifetime") @@ -1349,7 +1430,7 @@ func TestPrefixDiscovery(t *testing.T) { } // Receive RA to invalidate prefix3. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, 0)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 0, 0)) waitForEvent(subnet3, false, defaultTimeout) // Should not have any routes. @@ -1415,7 +1496,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { // Receive an RA with prefix in an NDP Prefix Information option (PI) // with infinite valid lifetime which should not get invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, testInfiniteLifetimeSeconds)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0)) waitForEvent(true, defaultTimeout) select { case <-ndpDisp.prefixC: @@ -1425,16 +1506,16 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { // Receive an RA with finite lifetime. // The prefix should get invalidated after 1s. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, testInfiniteLifetimeSeconds-1)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0)) waitForEvent(false, testInfiniteLifetime) // Receive an RA with finite lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, testInfiniteLifetimeSeconds-1)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0)) waitForEvent(true, defaultTimeout) // Receive an RA with prefix with an infinite lifetime. // The prefix should not be invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, testInfiniteLifetimeSeconds)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0)) select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") @@ -1443,7 +1524,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { // Receive an RA with a prefix with a lifetime value greater than the // set infinite lifetime value. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, testInfiniteLifetimeSeconds+1)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds+1, 0)) select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") @@ -1452,13 +1533,15 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { // Receive an RA with 0 lifetime. // The prefix should get invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, 0)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 0, 0)) waitForEvent(false, defaultTimeout) } // TestPrefixDiscoveryMaxRouters tests that only // stack.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered. func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { + t.Parallel() + ndpDisp := ndpDispatcher{ prefixC: make(chan ndpPrefixEvent, stack.MaxDiscoveredOnLinkPrefixes+3), rememberPrefix: true, @@ -1537,3 +1620,458 @@ func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { t.Fatalf("got GetRouteTable = %v, want = %v", got, expectedRt) } } + +// Checks to see if list contains an IPv6 address, item. +func contains(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) bool { + protocolAddress := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: item, + } + + for _, i := range list { + if i == protocolAddress { + return true + } + } + + return false +} + +// TestNoAutoGenAddr tests that SLAAC is not performed when configured not to. +func TestNoAutoGenAddr(t *testing.T) { + t.Parallel() + + prefix, _, _ := prefixSubnetAddr(0, "") + + // Being configured to auto-generate addresses means handle and + // autogen are set to true and forwarding is set to false. + // This tests all possible combinations of the configurations, + // except for the configuration where handle = true, autogen = + // true and forwarding = false (the required configuration to do + // SLAAC) - that will done in other tests. + for i := 0; i < 7; i++ { + handle := i&1 != 0 + autogen := i&2 != 0 + forwarding := i&4 == 0 + + t.Run(fmt.Sprintf("HandleRAs(%t), AutoGenAddr(%t), Forwarding(%t)", handle, autogen, forwarding), func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), + } + e := channel.New(10, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: handle, + AutoGenGlobalAddresses: autogen, + }, + NDPDisp: &ndpDisp, + }) + s.SetForwarding(forwarding) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Rx an RA with prefix with non-zero lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, false, true, 10, 0)) + + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address when configured not to") + case <-time.After(defaultTimeout): + } + }) + } +} + +// TestAutoGenAddr tests that an address is properly generated and invalidated +// when configured to do so. +func TestAutoGenAddr(t *testing.T) { + const newMinVL = 2 + newMinVLDuration := newMinVL * time.Second + saved := stack.MinPrefixInformationValidLifetimeForUpdate + defer func() { + stack.MinPrefixInformationValidLifetimeForUpdate = saved + }() + stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration + + prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), + } + e := channel.New(10, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + }) + + waitForEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { + t.Helper() + + select { + case r := <-ndpDisp.autoGenAddrC: + if r.nicID != 1 { + t.Fatalf("got r.nicID = %d, want = 1", r.nicID) + } + if r.addr != addr { + t.Fatalf("got r.addr = %s, want = %s", r.addr, addr) + } + if r.eventType != eventType { + t.Fatalf("got r.eventType = %v, want = %v", r.eventType, eventType) + } + case <-time.After(timeout): + t.Fatal("timeout waiting for addr auto gen event") + } + } + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with zero valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address with 0 lifetime") + case <-time.After(defaultTimeout): + } + + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) + waitForEvent(addr1, newAddr, defaultTimeout) + if !contains(s.NICInfo()[1].ProtocolAddresses, addr1) { + t.Fatalf("Should have %s in the list of addresses", addr1) + } + + // Receive an RA with prefix2 in an NDP Prefix Information option (PI) + // with preferred lifetime > valid lifetime + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address with preferred lifetime > valid lifetime") + case <-time.After(defaultTimeout): + } + + // Receive an RA with prefix2 in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) + waitForEvent(addr2, newAddr, defaultTimeout) + if !contains(s.NICInfo()[1].ProtocolAddresses, addr1) { + t.Fatalf("Should have %s in the list of addresses", addr1) + } + if !contains(s.NICInfo()[1].ProtocolAddresses, addr2) { + t.Fatalf("Should have %s in the list of addresses", addr2) + } + + // Refresh valid lifetime for addr of prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix") + case <-time.After(defaultTimeout): + } + + // Wait for addr of prefix1 to be invalidated. + waitForEvent(addr1, invalidatedAddr, newMinVLDuration+defaultTimeout) + if contains(s.NICInfo()[1].ProtocolAddresses, addr1) { + t.Fatalf("Should not have %s in the list of addresses", addr1) + } + if !contains(s.NICInfo()[1].ProtocolAddresses, addr2) { + t.Fatalf("Should have %s in the list of addresses", addr2) + } +} + +// TestAutoGenAddrValidLifetimeUpdates tests that the valid lifetime of an +// auto-generated address only gets updated when required to, as specified in +// RFC 4862 section 5.5.3.e. +func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { + const infiniteVL = 4294967295 + const newMinVL = 5 + saved := stack.MinPrefixInformationValidLifetimeForUpdate + defer func() { + stack.MinPrefixInformationValidLifetimeForUpdate = saved + }() + stack.MinPrefixInformationValidLifetimeForUpdate = newMinVL * time.Second + + prefix, _, addr := prefixSubnetAddr(0, linkAddr1) + + tests := []struct { + name string + ovl uint32 + nvl uint32 + evl uint32 + }{ + // Should update the VL to the minimum VL for updating if the + // new VL is less than newMinVL but was originally greater than + // it. + { + "LargeVLToVLLessThanMinVLForUpdate", + 9999, + 1, + newMinVL, + }, + { + "LargeVLTo0", + 9999, + 0, + newMinVL, + }, + { + "InfiniteVLToVLLessThanMinVLForUpdate", + infiniteVL, + 1, + newMinVL, + }, + { + "InfiniteVLTo0", + infiniteVL, + 0, + newMinVL, + }, + + // Should not update VL if original VL was less than newMinVL + // and the new VL is also less than newMinVL. + { + "ShouldNotUpdateWhenBothOldAndNewAreLessThanMinVLForUpdate", + newMinVL - 1, + newMinVL - 3, + newMinVL - 1, + }, + + // Should take the new VL if the new VL is greater than the + // remaining time or is greater than newMinVL. + { + "MorethanMinVLToLesserButStillMoreThanMinVLForUpdate", + newMinVL + 5, + newMinVL + 3, + newMinVL + 3, + }, + { + "SmallVLToGreaterVLButStillLessThanMinVLForUpdate", + newMinVL - 3, + newMinVL - 1, + newMinVL - 1, + }, + { + "SmallVLToGreaterVLThatIsMoreThaMinVLForUpdate", + newMinVL - 3, + newMinVL + 1, + newMinVL + 1, + }, + } + + const delta = 500 * time.Millisecond + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), + } + e := channel.New(10, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Receive an RA with prefix with initial VL, test.ovl. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.ovl, 0)) + select { + case r := <-ndpDisp.autoGenAddrC: + if r.nicID != 1 { + t.Fatalf("got r.nicID = %d, want = 1", r.nicID) + } + if r.addr != addr { + t.Fatalf("got r.addr = %s, want = %s", r.addr, addr) + } + if r.eventType != newAddr { + t.Fatalf("got r.eventType = %v, want = %v", r.eventType, newAddr) + } + case <-time.After(defaultTimeout): + t.Fatal("timeout waiting for addr auto gen event") + } + + // Receive an new RA with prefix with new VL, test.nvl. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.nvl, 0)) + + // + // Validate that the VL for the address got set to + // test.evl. + // + + // Make sure we do not get any invalidation events + // until atleast 500ms (delta) before test.evl. + select { + case <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly received an auto gen addr event") + case <-time.After(time.Duration(test.evl)*time.Second - delta): + } + + // Wait for another second (2x delta), but now we expect + // the invalidation event. + select { + case r := <-ndpDisp.autoGenAddrC: + if r.nicID != 1 { + t.Fatalf("got r.nicID = %d, want = 1", r.nicID) + } + if r.addr != addr { + t.Fatalf("got r.addr = %s, want = %s", r.addr, addr) + } + if r.eventType != invalidatedAddr { + t.Fatalf("got r.eventType = %v, want = %v", r.eventType, newAddr) + } + case <-time.After(2 * delta): + t.Fatal("timeout waiting for addr auto gen event") + } + }) + } +} + +// TestAutoGenAddrRemoval tests that when auto-generated addresses are removed +// by the user, its resources will be cleaned up and an invalidation event will +// be sent to the integrator. +func TestAutoGenAddrRemoval(t *testing.T) { + t.Parallel() + + prefix, _, addr := prefixSubnetAddr(0, linkAddr1) + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), + } + e := channel.New(10, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Receive an RA with prefix with its valid lifetime = lifetime. + const lifetime = 5 + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetime, 0)) + select { + case r := <-ndpDisp.autoGenAddrC: + if r.nicID != 1 { + t.Fatalf("got r.nicID = %d, want = 1", r.nicID) + } + if r.addr != addr { + t.Fatalf("got r.addr = %s, want = %s", r.addr, addr) + } + if r.eventType != newAddr { + t.Fatalf("got r.eventType = %v, want = %v", r.eventType, newAddr) + } + case <-time.After(defaultTimeout): + t.Fatal("timeout waiting for addr auto gen event") + } + + // Remove the address. + if err := s.RemoveAddress(1, addr.Address); err != nil { + t.Fatalf("RemoveAddress(_, %s) = %s", addr.Address, err) + } + + // Should get the invalidation event immediately. + select { + case r := <-ndpDisp.autoGenAddrC: + if r.nicID != 1 { + t.Fatalf("got r.nicID = %d, want = 1", r.nicID) + } + if r.addr != addr { + t.Fatalf("got r.addr = %s, want = %s", r.addr, addr) + } + if r.eventType != invalidatedAddr { + t.Fatalf("got r.eventType = %v, want = %v", r.eventType, newAddr) + } + case <-time.After(defaultTimeout): + t.Fatal("timeout waiting for addr auto gen event") + } + + // Wait for the original valid lifetime to make sure the original timer + // got stopped/cleaned up. + select { + case <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly received an auto gen addr event") + case <-time.After(lifetime*time.Second + defaultTimeout): + } +} + +// TestAutoGenAddrStaticConflict tests that if SLAAC generates an address that +// is already assigned to the NIC, the static address remains. +func TestAutoGenAddrStaticConflict(t *testing.T) { + t.Parallel() + + prefix, _, addr := prefixSubnetAddr(0, linkAddr1) + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), + } + e := channel.New(10, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Add the address as a static address before SLAAC tries to add it. + if err := s.AddProtocolAddress(1, tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr}); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr.Address, err) + } + if !contains(s.NICInfo()[1].ProtocolAddresses, addr) { + t.Fatalf("Should have %s in the list of addresses", addr1) + } + + // Receive a PI where the generated address will be the same as the one + // that we already have assigned statically. + const lifetime = 5 + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetime, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly received an auto gen addr event for an address we already have statically") + case <-time.After(defaultTimeout): + } + if !contains(s.NICInfo()[1].ProtocolAddresses, addr) { + t.Fatalf("Should have %s in the list of addresses", addr1) + } + + // Should not get an invalidation event after the PI's invalidation + // time. + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly received an auto gen addr event") + case <-time.After(lifetime*time.Second + defaultTimeout): + } + if !contains(s.NICInfo()[1].ProtocolAddresses, addr) { + t.Fatalf("Should have %s in the list of addresses", addr1) + } +} diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 3f8d7312c..e8401c673 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -115,10 +115,11 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback }, }, ndp: ndpState{ - configs: stack.ndpConfigs, - dad: make(map[tcpip.Address]dadState), - defaultRouters: make(map[tcpip.Address]defaultRouterState), - onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState), + configs: stack.ndpConfigs, + dad: make(map[tcpip.Address]dadState), + defaultRouters: make(map[tcpip.Address]defaultRouterState), + onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState), + autoGenAddresses: make(map[tcpip.Address]autoGenAddressState), }, } nic.ndp.nic = nic @@ -244,6 +245,20 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN return nil } +// hasPermanentAddrLocked returns true if n has a permanent (including currently +// tentative) address, addr. +func (n *NIC) hasPermanentAddrLocked(addr tcpip.Address) bool { + ref, ok := n.endpoints[NetworkEndpointID{addr}] + + if !ok { + return false + } + + kind := ref.getKind() + + return kind == permanent || kind == permanentTentative +} + func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint { return n.getRefOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, n.promiscuous) } @@ -335,7 +350,7 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t Address: address, PrefixLen: netProto.DefaultPrefixLen(), }, - }, peb, temporary) + }, peb, temporary, static) n.mu.Unlock() return ref @@ -384,10 +399,10 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p } } - return n.addAddressLocked(protocolAddress, peb, permanent) + return n.addAddressLocked(protocolAddress, peb, permanent, static) } -func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind) (*referencedNetworkEndpoint, *tcpip.Error) { +func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind, configType networkEndpointConfigType) (*referencedNetworkEndpoint, *tcpip.Error) { // TODO(b/141022673): Validate IP address before adding them. // Sanity check. @@ -417,11 +432,12 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar } ref := &referencedNetworkEndpoint{ - refs: 1, - ep: ep, - nic: n, - protocol: protocolAddress.Protocol, - kind: kind, + refs: 1, + ep: ep, + nic: n, + protocol: protocolAddress.Protocol, + kind: kind, + configType: configType, } // Set up cache if link address resolution exists for this protocol. @@ -624,9 +640,18 @@ func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { isIPv6Unicast := r.protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(addr) - // If we are removing a tentative IPv6 unicast address, stop DAD. - if isIPv6Unicast && kind == permanentTentative { - n.ndp.stopDuplicateAddressDetection(addr) + if isIPv6Unicast { + // If we are removing a tentative IPv6 unicast address, stop + // DAD. + if kind == permanentTentative { + n.ndp.stopDuplicateAddressDetection(addr) + } + + // If we are removing an address generated via SLAAC, cleanup + // its SLAAC resources and notify the integrator. + if r.configType == slaac { + n.ndp.cleanupAutoGenAddrResourcesAndNotify(addr) + } } r.setKind(permanentExpired) @@ -989,7 +1014,7 @@ const ( // removing the permanent address from the NIC. permanent - // An expired permanent endoint is a permanent endoint that had its address + // An expired permanent endpoint is a permanent endpoint that had its address // removed from the NIC, and it is waiting to be removed once no more routes // hold a reference to it. This is achieved by decreasing its reference count // by 1. If its address is re-added before the endpoint is removed, its type @@ -1035,6 +1060,19 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep } } +type networkEndpointConfigType int32 + +const ( + // A statically configured endpoint is an address that was added by + // some user-specified action (adding an explicit address, joining a + // multicast group). + static networkEndpointConfigType = iota + + // A slaac configured endpoint is an IPv6 endpoint that was + // added by SLAAC as per RFC 4862 section 5.5.3. + slaac +) + type referencedNetworkEndpoint struct { ep NetworkEndpoint nic *NIC @@ -1050,6 +1088,10 @@ type referencedNetworkEndpoint struct { // networkEndpointKind must only be accessed using {get,set}Kind(). kind networkEndpointKind + + // configType is the method that was used to configure this endpoint. + // This must never change after the endpoint is added to a NIC. + configType networkEndpointConfigType } func (r *referencedNetworkEndpoint) getKind() networkEndpointKind { -- cgit v1.2.3 From 4ff71b5be462f3f16808abea11de1f5e01567e5d Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Tue, 10 Dec 2019 18:03:43 -0800 Subject: Inform the integrator on receipt of an NDP Recursive DNS Server option This change adds support to let an integrator know when it receives an NDP Router Advertisement message with the NDP Recursive DNS Server option with at least one DNS server's address. The stack will not maintain any state related to the DNS servers - the integrator is expected to maintain any required state and invalidate the servers after its valid lifetime expires, or refresh the lifetime when a new one is received for a known DNS server. Test: Unittest to make sure that an event is sent to the integrator when an NDP Recursive DNS Server option is received with at least one address. PiperOrigin-RevId: 284890502 --- pkg/tcpip/stack/ndp.go | 32 ++++++--- pkg/tcpip/stack/ndp_test.go | 172 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 196 insertions(+), 8 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 75d3ecdac..060a2e7c6 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -181,6 +181,17 @@ type NDPDispatcher interface { // This function is not permitted to block indefinitely. It must not // call functions on the stack itself. OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix) + + // OnRecursiveDNSServerOption will be called when an NDP option with + // recursive DNS servers has been received. Note, addrs may contain + // link-local addresses. + // + // It is up to the caller to use the DNS Servers only for their valid + // lifetime. OnRecursiveDNSServerOption may be called for new or + // already known DNS servers. If called with known DNS servers, their + // valid lifetimes must be refreshed to lifetime (it may be increased, + // decreased, or completely invalidated when lifetime = 0). + OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) } // NDPConfigurations is the NDP configurations for the netstack. @@ -617,11 +628,16 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { // we do not check the iterator for errors on calls to Next. it, _ := ra.Options().Iter(false) for opt, done, _ := it.Next(); !done; opt, done, _ = it.Next() { - switch opt.Type() { - case header.NDPPrefixInformationType: - pi := opt.(header.NDPPrefixInformation) + switch opt := opt.(type) { + case header.NDPRecursiveDNSServer: + if ndp.nic.stack.ndpDisp == nil { + continue + } + + ndp.nic.stack.ndpDisp.OnRecursiveDNSServerOption(ndp.nic.ID(), opt.Addresses(), opt.Lifetime()) - prefix := pi.Subnet() + case header.NDPPrefixInformation: + prefix := opt.Subnet() // Is the prefix a link-local? if header.IsV6LinkLocalAddress(prefix.ID()) { @@ -637,12 +653,12 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { continue } - if pi.OnLinkFlag() { - ndp.handleOnLinkPrefixInformation(pi) + if opt.OnLinkFlag() { + ndp.handleOnLinkPrefixInformation(opt) } - if pi.AutonomousAddressConfigurationFlag() { - ndp.handleAutonomousPrefixInformation(pi) + if opt.AutonomousAddressConfigurationFlag() { + ndp.handleAutonomousPrefixInformation(opt) } } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index e9aa20148..8d811eb8e 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -141,6 +141,16 @@ type ndpAutoGenAddrEvent struct { eventType ndpAutoGenAddrEventType } +type ndpRDNSS struct { + addrs []tcpip.Address + lifetime time.Duration +} + +type ndpRDNSSEvent struct { + nicID tcpip.NICID + rdnss ndpRDNSS +} + var _ stack.NDPDispatcher = (*ndpDispatcher)(nil) // ndpDispatcher implements NDPDispatcher so tests can know when various NDP @@ -152,6 +162,7 @@ type ndpDispatcher struct { prefixC chan ndpPrefixEvent rememberPrefix bool autoGenAddrC chan ndpAutoGenAddrEvent + rdnssC chan ndpRDNSSEvent routeTable []tcpip.Route } @@ -286,6 +297,19 @@ func (n *ndpDispatcher) OnAutoGenAddressInvalidated(nicID tcpip.NICID, addr tcpi } } +// Implements stack.NDPDispatcher.OnRecursiveDNSServerOption. +func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) { + if n.rdnssC != nil { + n.rdnssC <- ndpRDNSSEvent{ + nicID, + ndpRDNSS{ + addrs, + lifetime, + }, + } + } +} + // TestDADResolve tests that an address successfully resolves after performing // DAD for various values of DupAddrDetectTransmits and RetransmitTimer. // Included in the subtests is a test to make sure that an invalid @@ -2075,3 +2099,151 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { t.Fatalf("Should have %s in the list of addresses", addr1) } } + +// TestNDPRecursiveDNSServerDispatch tests that we properly dispatch an event +// to the integrator when an RA is received with the NDP Recursive DNS Server +// option with at least one valid address. +func TestNDPRecursiveDNSServerDispatch(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + opt header.NDPRecursiveDNSServer + expected *ndpRDNSS + }{ + { + "Unspecified", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 2, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }), + nil, + }, + { + "Multicast", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 2, + 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + }), + nil, + }, + { + "OptionTooSmall", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 2, + 1, 2, 3, 4, 5, 6, 7, 8, + }), + nil, + }, + { + "0Addresses", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 2, + }), + nil, + }, + { + "Valid1Address", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 2, + 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1, + }), + &ndpRDNSS{ + []tcpip.Address{ + "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01", + }, + 2 * time.Second, + }, + }, + { + "Valid2Addresses", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 1, + 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1, + 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 2, + }), + &ndpRDNSS{ + []tcpip.Address{ + "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01", + "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x02", + }, + time.Second, + }, + }, + { + "Valid3Addresses", + header.NDPRecursiveDNSServer([]byte{ + 0, 0, + 0, 0, 0, 0, + 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1, + 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 2, + 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 3, + }), + &ndpRDNSS{ + []tcpip.Address{ + "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01", + "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x02", + "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x03", + }, + 0, + }, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + // We do not expect more than a single RDNSS + // event at any time for this test. + rdnssC: make(chan ndpRDNSSEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + }, + NDPDisp: &ndpDisp, + }) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, header.NDPOptionsSerializer{test.opt})) + + if test.expected != nil { + select { + case e := <-ndpDisp.rdnssC: + if e.nicID != 1 { + t.Errorf("got rdnss nicID = %d, want = 1", e.nicID) + } + if diff := cmp.Diff(e.rdnss.addrs, test.expected.addrs); diff != "" { + t.Errorf("rdnss addrs mismatch (-want +got):\n%s", diff) + } + if e.rdnss.lifetime != test.expected.lifetime { + t.Errorf("got rdnss lifetime = %s, want = %s", e.rdnss.lifetime, test.expected.lifetime) + } + default: + t.Fatal("expected an RDNSS option event") + } + } + + // Should have no more RDNSS options. + select { + case e := <-ndpDisp.rdnssC: + t.Fatalf("unexpectedly got a new RDNSS option event: %+v", e) + default: + } + }) + } +} -- cgit v1.2.3 From ad80dcf47077a1938631fe36f6b406256f3f3f4f Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Fri, 13 Dec 2019 16:26:06 -0800 Subject: Properly generate the EUI64 interface identifier from an Ethernet address Fixed a bug where the interface identifier was not properly generated from an Ethernet address. Tests: Unittests to make sure the functions generating the EUI64 interface identifier are correct. PiperOrigin-RevId: 285494562 --- pkg/tcpip/header/BUILD | 3 +++ pkg/tcpip/header/ipv6.go | 21 ++++++++++---------- pkg/tcpip/header/ipv6_test.go | 45 +++++++++++++++++++++++++++++++++++++++++++ pkg/tcpip/stack/ndp.go | 6 +++--- pkg/tcpip/stack/ndp_test.go | 2 +- 5 files changed, 63 insertions(+), 14 deletions(-) create mode 100644 pkg/tcpip/header/ipv6_test.go (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD index 8392cb9e5..f1d837196 100644 --- a/pkg/tcpip/header/BUILD +++ b/pkg/tcpip/header/BUILD @@ -38,12 +38,15 @@ go_test( size = "small", srcs = [ "checksum_test.go", + "ipv6_test.go", "ipversion_test.go", "tcp_test.go", ], deps = [ ":header", + "//pkg/tcpip", "//pkg/tcpip/buffer", + "@com_github_google_go-cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index 5275b34d4..fc671e439 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -278,33 +278,34 @@ func SolicitedNodeAddr(addr tcpip.Address) tcpip.Address { return solicitedNodeMulticastPrefix + addr[len(addr)-3:] } -// EthernetAdddressToEUI64IntoBuf populates buf with a EUI-64 from a 48-bit -// Ethernet/MAC address. +// EthernetAdddressToModifiedEUI64IntoBuf populates buf with a modified EUI-64 +// from a 48-bit Ethernet/MAC address, as per RFC 4291 section 2.5.1. // // buf MUST be at least 8 bytes. -func EthernetAdddressToEUI64IntoBuf(linkAddr tcpip.LinkAddress, buf []byte) { +func EthernetAdddressToModifiedEUI64IntoBuf(linkAddr tcpip.LinkAddress, buf []byte) { buf[0] = linkAddr[0] ^ 2 buf[1] = linkAddr[1] buf[2] = linkAddr[2] - buf[3] = 0xFE + buf[3] = 0xFF buf[4] = 0xFE buf[5] = linkAddr[3] buf[6] = linkAddr[4] buf[7] = linkAddr[5] } -// EthernetAddressToEUI64 computes an EUI-64 from a 48-bit Ethernet/MAC address. -func EthernetAddressToEUI64(linkAddr tcpip.LinkAddress) [IIDSize]byte { +// EthernetAddressToModifiedEUI64 computes a modified EUI-64 from a 48-bit +// Ethernet/MAC address, as per RFC 4291 section 2.5.1. +func EthernetAddressToModifiedEUI64(linkAddr tcpip.LinkAddress) [IIDSize]byte { var buf [IIDSize]byte - EthernetAdddressToEUI64IntoBuf(linkAddr, buf[:]) + EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, buf[:]) return buf } // LinkLocalAddr computes the default IPv6 link-local address from a link-layer // (MAC) address. func LinkLocalAddr(linkAddr tcpip.LinkAddress) tcpip.Address { - // Convert a 48-bit MAC to an EUI-64 and then prepend the link-local - // header, FE80::. + // Convert a 48-bit MAC to a modified EUI-64 and then prepend the + // link-local header, FE80::. // // The conversion is very nearly: // aa:bb:cc:dd:ee:ff => FE80::Aabb:ccFF:FEdd:eeff @@ -313,7 +314,7 @@ func LinkLocalAddr(linkAddr tcpip.LinkAddress) tcpip.Address { 0: 0xFE, 1: 0x80, } - EthernetAdddressToEUI64IntoBuf(linkAddr, lladdrb[IIDOffsetInIPv6Address:]) + EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, lladdrb[IIDOffsetInIPv6Address:]) return tcpip.Address(lladdrb[:]) } diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go new file mode 100644 index 000000000..42c5c6fc1 --- /dev/null +++ b/pkg/tcpip/header/ipv6_test.go @@ -0,0 +1,45 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +const linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") + +func TestEthernetAdddressToModifiedEUI64(t *testing.T) { + expectedIID := [header.IIDSize]byte{0, 2, 3, 255, 254, 4, 5, 6} + + if diff := cmp.Diff(expectedIID, header.EthernetAddressToModifiedEUI64(linkAddr)); diff != "" { + t.Errorf("EthernetAddressToModifiedEUI64(%s) mismatch (-want +got):\n%s", linkAddr, diff) + } + + var buf [header.IIDSize]byte + header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, buf[:]) + if diff := cmp.Diff(expectedIID, buf); diff != "" { + t.Errorf("EthernetAddressToModifiedEUI64IntoBuf(%s, _) mismatch (-want +got):\n%s", linkAddr, diff) + } +} + +func TestLinkLocalAddr(t *testing.T) { + if got, want := header.LinkLocalAddr(linkAddr), tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x02\x03\xff\xfe\x04\x05\x06"); got != want { + t.Errorf("got LinkLocalAddr(%s) = %s, want = %s", linkAddr, got, want) + } +} diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 060a2e7c6..27bd02e76 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -1049,11 +1049,11 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform return } - // Generate an address within prefix from the EUI-64 of ndp's NIC's - // Ethernet MAC address. + // Generate an address within prefix from the modified EUI-64 of ndp's + // NIC's Ethernet MAC address. addrBytes := make([]byte, header.IPv6AddressSize) copy(addrBytes[:header.IIDOffsetInIPv6Address], prefix.ID()[:header.IIDOffsetInIPv6Address]) - header.EthernetAdddressToEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:]) + header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:]) addr := tcpip.Address(addrBytes) addrWithPrefix := tcpip.AddressWithPrefix{ Address: addr, diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 8d811eb8e..d8e7ce67e 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -62,7 +62,7 @@ func prefixSubnetAddr(offset uint8, linkAddr tcpip.LinkAddress) (tcpip.AddressWi var addr tcpip.AddressWithPrefix if header.IsValidUnicastEthernetAddress(linkAddr) { addrBytes := []byte(subnet.ID()) - header.EthernetAdddressToEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:]) + header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:]) addr = tcpip.AddressWithPrefix{ Address: tcpip.Address(addrBytes), PrefixLen: 64, -- cgit v1.2.3 From 628948b1e197010177acc08ddc9f93d9925fba6b Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Wed, 18 Dec 2019 13:35:52 -0800 Subject: Cleanup NDP Tests This change makes sure that test variables are captured before running tests in parallel, and removes unneeded buffered channel allocations. This change also removes unnecessary timeouts. PiperOrigin-RevId: 286255066 --- pkg/tcpip/stack/ndp_test.go | 625 ++++++++++++++++++++++---------------------- 1 file changed, 308 insertions(+), 317 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index d8e7ce67e..9f589a471 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -79,7 +79,7 @@ func TestDADDisabled(t *testing.T) { NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(opts) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(_) = %s", err) @@ -330,6 +330,8 @@ func TestDADResolve(t *testing.T) { } for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { t.Parallel() @@ -520,7 +522,7 @@ func TestDADFail(t *testing.T) { } opts.NDPConfigs.RetransmitTimer = time.Second * 2 - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(opts) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(_) = %s", err) @@ -601,7 +603,7 @@ func TestDADStop(t *testing.T) { NDPConfigs: ndpConfigs, } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(opts) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(_) = %s", err) @@ -702,7 +704,7 @@ func TestSetNDPConfigurations(t *testing.T) { ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent), } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NDPDisp: &ndpDisp, @@ -903,8 +905,6 @@ func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, on // TestNoRouterDiscovery tests that router discovery will not be performed if // configured not to. func TestNoRouterDiscovery(t *testing.T) { - t.Parallel() - // Being configured to discover routers means handle and // discover are set to true and forwarding is set to false. // This tests all possible combinations of the configurations, @@ -920,9 +920,9 @@ func TestNoRouterDiscovery(t *testing.T) { t.Parallel() ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 10), + routerC: make(chan ndpRouterEvent, 1), } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NDPConfigs: stack.NDPConfigurations{ @@ -942,21 +942,27 @@ func TestNoRouterDiscovery(t *testing.T) { select { case <-ndpDisp.routerC: t.Fatal("unexpectedly discovered a router when configured not to") - case <-time.After(defaultTimeout): + default: } }) } } +// Check e to make sure that the event is for addr on nic with ID 1, and the +// discovered flag set to discovered. +func checkRouterEvent(e ndpRouterEvent, addr tcpip.Address, discovered bool) string { + return cmp.Diff(ndpRouterEvent{nicID: 1, addr: addr, discovered: discovered}, e, cmp.AllowUnexported(e)) +} + // TestRouterDiscoveryDispatcherNoRemember tests that the stack does not // remember a discovered router when the dispatcher asks it not to. func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { t.Parallel() ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 10), + routerC: make(chan ndpRouterEvent, 1), } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NDPConfigs: stack.NDPConfigurations{ @@ -979,41 +985,35 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { } s.SetRouteTable(routeTable) - // Rx an RA with short lifetime. - lifetime := time.Duration(1) - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, uint16(lifetime))) + // Receive an RA for a router we should not remember. + const lifetimeSeconds = 1 + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, lifetimeSeconds)) select { - case r := <-ndpDisp.routerC: - if r.nicID != 1 { - t.Fatalf("got r.nicID = %d, want = 1", r.nicID) - } - if r.addr != llAddr2 { - t.Fatalf("got r.addr = %s, want = %s", r.addr, llAddr2) - } - if !r.discovered { - t.Fatal("got r.discovered = false, want = true") + case e := <-ndpDisp.routerC: + if diff := checkRouterEvent(e, llAddr2, true); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultTimeout): - t.Fatal("timeout waiting for router discovery event") + default: + t.Fatal("expected router discovery event") } // Original route table should not have been modified. - if got := s.GetRouteTable(); !cmp.Equal(got, routeTable) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, routeTable) + if diff := cmp.Diff(routeTable, s.GetRouteTable()); diff != "" { + t.Fatalf("GetRouteTable() mismatch (-want +got):\n%s", diff) } - // Wait for the normal invalidation time plus an extra second to - // make sure we do not actually receive any invalidation events as - // we should not have remembered the router in the first place. + // Wait for the invalidation time plus some buffer to make sure we do + // not actually receive any invalidation events as we should not have + // remembered the router in the first place. select { case <-ndpDisp.routerC: t.Fatal("should not have received any router events") - case <-time.After(lifetime*time.Second + defaultTimeout): + case <-time.After(lifetimeSeconds*time.Second + defaultTimeout): } // Original route table should not have been modified. - if got := s.GetRouteTable(); !cmp.Equal(got, routeTable) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, routeTable) + if diff := cmp.Diff(routeTable, s.GetRouteTable()); diff != "" { + t.Fatalf("GetRouteTable() mismatch (-want +got):\n%s", diff) } } @@ -1021,10 +1021,10 @@ func TestRouterDiscovery(t *testing.T) { t.Parallel() ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 10), + routerC: make(chan ndpRouterEvent, 1), rememberRouter: true, } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NDPConfigs: stack.NDPConfigurations{ @@ -1034,22 +1034,29 @@ func TestRouterDiscovery(t *testing.T) { NDPDisp: &ndpDisp, }) - waitForEvent := func(addr tcpip.Address, discovered bool, timeout time.Duration) { + expectRouterEvent := func(addr tcpip.Address, discovered bool) { t.Helper() select { - case r := <-ndpDisp.routerC: - if r.nicID != 1 { - t.Fatalf("got r.nicID = %d, want = 1", r.nicID) + case e := <-ndpDisp.routerC: + if diff := checkRouterEvent(e, addr, discovered); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) } - if r.addr != addr { - t.Fatalf("got r.addr = %s, want = %s", r.addr, addr) - } - if r.discovered != discovered { - t.Fatalf("got r.discovered = %t, want = %t", r.discovered, discovered) + default: + t.Fatal("expected router discovery event") + } + } + + expectAsyncRouterInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) { + t.Helper() + + select { + case e := <-ndpDisp.routerC: + if diff := checkRouterEvent(e, addr, false); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) } case <-time.After(timeout): - t.Fatal("timeout waiting for router discovery event") + t.Fatal("timed out waiting for router discovery event") } } @@ -1063,26 +1070,27 @@ func TestRouterDiscovery(t *testing.T) { select { case <-ndpDisp.routerC: t.Fatal("unexpectedly discovered a router with 0 lifetime") - case <-time.After(defaultTimeout): + default: } // Rx an RA from lladdr2 with a huge lifetime. e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - waitForEvent(llAddr2, true, defaultTimeout) + expectRouterEvent(llAddr2, true) // Should have a default route through the discovered router. - if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr2, 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) + if diff := cmp.Diff([]tcpip.Route{{header.IPv6EmptySubnet, llAddr2, 1}}, s.GetRouteTable()); diff != "" { + t.Fatalf("GetRouteTable() mismatch (-want +got):\n%s", diff) } // Rx an RA from another router (lladdr3) with non-zero lifetime. l3Lifetime := time.Duration(6) e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, uint16(l3Lifetime))) - waitForEvent(llAddr3, true, defaultTimeout) + expectRouterEvent(llAddr3, true) // Should have default routes through the discovered routers. - if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr2, 1}, {header.IPv6EmptySubnet, llAddr3, 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) + want := []tcpip.Route{{header.IPv6EmptySubnet, llAddr2, 1}, {header.IPv6EmptySubnet, llAddr3, 1}} + if diff := cmp.Diff(want, s.GetRouteTable()); diff != "" { + t.Fatalf("GetRouteTable() mismatch (-want +got):\n%s", diff) } // Rx an RA from lladdr2 with lesser lifetime. @@ -1091,12 +1099,12 @@ func TestRouterDiscovery(t *testing.T) { select { case <-ndpDisp.routerC: t.Fatal("Should not receive a router event when updating lifetimes for known routers") - case <-time.After(defaultTimeout): + default: } // Should still have a default route through the discovered routers. - if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr2, 1}, {header.IPv6EmptySubnet, llAddr3, 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) + if diff := cmp.Diff(want, s.GetRouteTable()); diff != "" { + t.Fatalf("GetRouteTable() mismatch (-want +got):\n%s", diff) } // Wait for lladdr2's router invalidation timer to fire. The lifetime @@ -1106,30 +1114,30 @@ func TestRouterDiscovery(t *testing.T) { // Wait for the normal lifetime plus an extra bit for the // router to get invalidated. If we don't get an invalidation // event after this time, then something is wrong. - waitForEvent(llAddr2, false, l2Lifetime*time.Second+defaultTimeout) + expectAsyncRouterInvalidationEvent(llAddr2, l2Lifetime*time.Second+defaultTimeout) // Should no longer have the default route through lladdr2. - if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr3, 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) + if diff := cmp.Diff([]tcpip.Route{{header.IPv6EmptySubnet, llAddr3, 1}}, s.GetRouteTable()); diff != "" { + t.Fatalf("GetRouteTable() mismatch (-want +got):\n%s", diff) } // Rx an RA from lladdr2 with huge lifetime. e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - waitForEvent(llAddr2, true, defaultTimeout) + expectRouterEvent(llAddr2, true) // Should have a default route through the discovered routers. - if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr3, 1}, {header.IPv6EmptySubnet, llAddr2, 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) + if diff := cmp.Diff([]tcpip.Route{{header.IPv6EmptySubnet, llAddr3, 1}, {header.IPv6EmptySubnet, llAddr2, 1}}, s.GetRouteTable()); diff != "" { + t.Fatalf("GetRouteTable() mismatch (-want +got):\n%s", diff) } // Rx an RA from lladdr2 with zero lifetime. It should be invalidated. e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) - waitForEvent(llAddr2, false, defaultTimeout) + expectRouterEvent(llAddr2, false) // Should have deleted the default route through the router that just // got invalidated. - if got, want := s.GetRouteTable(), []tcpip.Route{{header.IPv6EmptySubnet, llAddr3, 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) + if diff := cmp.Diff([]tcpip.Route{{header.IPv6EmptySubnet, llAddr3, 1}}, s.GetRouteTable()); diff != "" { + t.Fatalf("GetRouteTable() mismatch (-want +got):\n%s", diff) } // Wait for lladdr3's router invalidation timer to fire. The lifetime @@ -1139,7 +1147,7 @@ func TestRouterDiscovery(t *testing.T) { // Wait for the normal lifetime plus an extra bit for the // router to get invalidated. If we don't get an invalidation // event after this time, then something is wrong. - waitForEvent(llAddr3, false, l3Lifetime*time.Second+defaultTimeout) + expectAsyncRouterInvalidationEvent(llAddr3, l3Lifetime*time.Second+defaultTimeout) // Should not have any routes now that all discovered routers have been // invalidated. @@ -1154,10 +1162,10 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { t.Parallel() ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 10), + routerC: make(chan ndpRouterEvent, 1), rememberRouter: true, } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NDPConfigs: stack.NDPConfigurations{ @@ -1184,41 +1192,33 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { if i <= stack.MaxDiscoveredDefaultRouters { expectedRt[i-1] = tcpip.Route{header.IPv6EmptySubnet, llAddr, 1} select { - case r := <-ndpDisp.routerC: - if r.nicID != 1 { - t.Fatalf("got r.nicID = %d, want = 1", r.nicID) - } - if r.addr != llAddr { - t.Fatalf("got r.addr = %s, want = %s", r.addr, llAddr) - } - if !r.discovered { - t.Fatal("got r.discovered = false, want = true") + case e := <-ndpDisp.routerC: + if diff := checkRouterEvent(e, llAddr, true); diff != "" { + t.Errorf("router event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultTimeout): - t.Fatal("timeout waiting for router discovery event") + default: + t.Fatal("expected router discovery event") } } else { select { case <-ndpDisp.routerC: t.Fatal("should not have discovered a new router after we already discovered the max number of routers") - case <-time.After(defaultTimeout): + default: } } } // Should only have default routes for the first // stack.MaxDiscoveredDefaultRouters discovered routers. - if got := s.GetRouteTable(); !cmp.Equal(got, expectedRt[:]) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, expectedRt) + if diff := cmp.Diff(expectedRt[:], s.GetRouteTable()); diff != "" { + t.Fatalf("GetRouteTable() mismatch (-want +got):\n%s", diff) } } // TestNoPrefixDiscovery tests that prefix discovery will not be performed if // configured not to. func TestNoPrefixDiscovery(t *testing.T) { - t.Parallel() - prefix := tcpip.AddressWithPrefix{ Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"), PrefixLen: 64, @@ -1239,9 +1239,9 @@ func TestNoPrefixDiscovery(t *testing.T) { t.Parallel() ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 10), + prefixC: make(chan ndpPrefixEvent, 1), } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NDPConfigs: stack.NDPConfigurations{ @@ -1262,12 +1262,18 @@ func TestNoPrefixDiscovery(t *testing.T) { select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly discovered a prefix when configured not to") - case <-time.After(defaultTimeout): + default: } }) } } +// Check e to make sure that the event is for prefix on nic with ID 1, and the +// discovered flag set to discovered. +func checkPrefixEvent(e ndpPrefixEvent, prefix tcpip.Subnet, discovered bool) string { + return cmp.Diff(ndpPrefixEvent{nicID: 1, prefix: prefix, discovered: discovered}, e, cmp.AllowUnexported(e)) +} + // TestPrefixDiscoveryDispatcherNoRemember tests that the stack does not // remember a discovered on-link prefix when the dispatcher asks it not to. func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { @@ -1276,9 +1282,9 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { prefix, subnet, _ := prefixSubnetAddr(0, "") ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 10), + prefixC: make(chan ndpPrefixEvent, 1), } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NDPConfigs: stack.NDPConfigurations{ @@ -1302,41 +1308,35 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { } s.SetRouteTable(routeTable) - // Rx an RA with prefix with a short lifetime. - const lifetime = 1 - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, lifetime, 0)) + // Receive an RA with prefix that we should not remember. + const lifetimeSeconds = 1 + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, lifetimeSeconds, 0)) select { - case r := <-ndpDisp.prefixC: - if r.nicID != 1 { - t.Fatalf("got r.nicID = %d, want = 1", r.nicID) + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, subnet, true); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) } - if r.prefix != subnet { - t.Fatalf("got r.prefix = %s, want = %s", r.prefix, subnet) - } - if !r.discovered { - t.Fatal("got r.discovered = false, want = true") - } - case <-time.After(defaultTimeout): - t.Fatal("timeout waiting for prefix discovery event") + default: + t.Fatal("expected prefix discovery event") } // Original route table should not have been modified. - if got := s.GetRouteTable(); !cmp.Equal(got, routeTable) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, routeTable) + if diff := cmp.Diff(routeTable, s.GetRouteTable()); diff != "" { + t.Fatalf("GetRouteTable() mismatch (-want +got):\n%s", diff) } - // Wait for the normal invalidation time plus some buffer to - // make sure we do not actually receive any invalidation events as - // we should not have remembered the prefix in the first place. + // Wait for the invalidation time plus some buffer to make sure we do + // not actually receive any invalidation events as we should not have + // remembered the prefix in the first place. select { case <-ndpDisp.prefixC: t.Fatal("should not have received any prefix events") - case <-time.After(lifetime*time.Second + defaultTimeout): + case <-time.After(lifetimeSeconds*time.Second + defaultTimeout): } // Original route table should not have been modified. - if got := s.GetRouteTable(); !cmp.Equal(got, routeTable) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, routeTable) + if diff := cmp.Diff(routeTable, s.GetRouteTable()); diff != "" { + t.Fatalf("GetRouteTable() mismatch (-want +got):\n%s", diff) } } @@ -1348,10 +1348,10 @@ func TestPrefixDiscovery(t *testing.T) { prefix3, subnet3, _ := prefixSubnetAddr(2, "") ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 10), + prefixC: make(chan ndpPrefixEvent, 1), rememberPrefix: true, } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NDPConfigs: stack.NDPConfigurations{ @@ -1361,73 +1361,68 @@ func TestPrefixDiscovery(t *testing.T) { NDPDisp: &ndpDisp, }) - waitForEvent := func(subnet tcpip.Subnet, discovered bool, timeout time.Duration) { + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) { t.Helper() select { - case r := <-ndpDisp.prefixC: - if r.nicID != 1 { - t.Fatalf("got r.nicID = %d, want = 1", r.nicID) - } - if r.prefix != subnet { - t.Fatalf("got r.prefix = %s, want = %s", r.prefix, subnet) + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, prefix, discovered); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) } - if r.discovered != discovered { - t.Fatalf("got r.discovered = %t, want = %t", r.discovered, discovered) - } - case <-time.After(timeout): - t.Fatal("timeout waiting for prefix discovery event") + default: + t.Fatal("expected prefix discovery event") } } - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) // with zero valid lifetime. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly discovered a prefix with 0 lifetime") - case <-time.After(defaultTimeout): + default: } // Receive an RA with prefix1 in an NDP Prefix Information option (PI) // with non-zero lifetime. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 100, 0)) - waitForEvent(subnet1, true, defaultTimeout) + expectPrefixEvent(subnet1, true) // Should have added a device route for subnet1 through the nic. - if got, want := s.GetRouteTable(), []tcpip.Route{{subnet1, tcpip.Address([]byte(nil)), 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) + if diff := cmp.Diff([]tcpip.Route{{subnet1, tcpip.Address([]byte(nil)), 1}}, s.GetRouteTable()); diff != "" { + t.Fatalf("GetRouteTable() mismatch (-want +got):\n%s", diff) } // Receive an RA with prefix2 in a PI. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, 100, 0)) - waitForEvent(subnet2, true, defaultTimeout) + expectPrefixEvent(subnet2, true) // Should have added a device route for subnet2 through the nic. - if got, want := s.GetRouteTable(), []tcpip.Route{{subnet1, tcpip.Address([]byte(nil)), 1}, {subnet2, tcpip.Address([]byte(nil)), 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) + if diff := cmp.Diff([]tcpip.Route{{subnet1, tcpip.Address([]byte(nil)), 1}, {subnet2, tcpip.Address([]byte(nil)), 1}}, s.GetRouteTable()); diff != "" { + t.Fatalf("GetRouteTable() mismatch (-want +got):\n%s", diff) } // Receive an RA with prefix3 in a PI. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 100, 0)) - waitForEvent(subnet3, true, defaultTimeout) + expectPrefixEvent(subnet3, true) // Should have added a device route for subnet3 through the nic. - if got, want := s.GetRouteTable(), []tcpip.Route{{subnet1, tcpip.Address([]byte(nil)), 1}, {subnet2, tcpip.Address([]byte(nil)), 1}, {subnet3, tcpip.Address([]byte(nil)), 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) + if diff := cmp.Diff([]tcpip.Route{{subnet1, tcpip.Address([]byte(nil)), 1}, {subnet2, tcpip.Address([]byte(nil)), 1}, {subnet3, tcpip.Address([]byte(nil)), 1}}, s.GetRouteTable()); diff != "" { + t.Fatalf("GetRouteTable() mismatch (-want +got):\n%s", diff) } // Receive an RA with prefix1 in a PI with lifetime = 0. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) - waitForEvent(subnet1, false, defaultTimeout) + expectPrefixEvent(subnet1, false) // Should have removed the device route for subnet1 through the nic. - if got, want := s.GetRouteTable(), []tcpip.Route{{subnet2, tcpip.Address([]byte(nil)), 1}, {subnet3, tcpip.Address([]byte(nil)), 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) + want := []tcpip.Route{{subnet2, tcpip.Address([]byte(nil)), 1}, {subnet3, tcpip.Address([]byte(nil)), 1}} + if diff := cmp.Diff(want, s.GetRouteTable()); diff != "" { + t.Fatalf("GetRouteTable() mismatch (-want +got):\n%s", diff) } // Receive an RA with prefix2 in a PI with lesser lifetime. @@ -1436,26 +1431,33 @@ func TestPrefixDiscovery(t *testing.T) { select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly received prefix event when updating lifetime") - case <-time.After(defaultTimeout): + default: } // Should not have updated route table. - if got, want := s.GetRouteTable(), []tcpip.Route{{subnet2, tcpip.Address([]byte(nil)), 1}, {subnet3, tcpip.Address([]byte(nil)), 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) + if diff := cmp.Diff(want, s.GetRouteTable()); diff != "" { + t.Fatalf("GetRouteTable() mismatch (-want +got):\n%s", diff) } // Wait for prefix2's most recent invalidation timer plus some buffer to // expire. - waitForEvent(subnet2, false, time.Duration(lifetime)*time.Second+defaultTimeout) + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, subnet2, false); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + } + case <-time.After(time.Duration(lifetime)*time.Second + defaultTimeout): + t.Fatal("timed out waiting for prefix discovery event") + } // Should have removed the device route for subnet2 through the nic. - if got, want := s.GetRouteTable(), []tcpip.Route{{subnet3, tcpip.Address([]byte(nil)), 1}}; !cmp.Equal(got, want) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, want) + if diff := cmp.Diff([]tcpip.Route{{subnet3, tcpip.Address([]byte(nil)), 1}}, s.GetRouteTable()); diff != "" { + t.Fatalf("GetRouteTable() mismatch (-want +got):\n%s", diff) } // Receive RA to invalidate prefix3. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 0, 0)) - waitForEvent(subnet3, false, defaultTimeout) + expectPrefixEvent(subnet3, false) // Should not have any routes. if got := len(s.GetRouteTable()); got != 0 { @@ -1482,10 +1484,10 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { subnet := prefix.Subnet() ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 10), + prefixC: make(chan ndpPrefixEvent, 1), rememberPrefix: true, } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NDPConfigs: stack.NDPConfigurations{ @@ -1495,33 +1497,27 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { NDPDisp: &ndpDisp, }) - waitForEvent := func(discovered bool, timeout time.Duration) { + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) { t.Helper() select { - case r := <-ndpDisp.prefixC: - if r.nicID != 1 { - t.Errorf("got r.nicID = %d, want = 1", r.nicID) - } - if r.prefix != subnet { - t.Errorf("got r.prefix = %s, want = %s", r.prefix, subnet) - } - if r.discovered != discovered { - t.Errorf("got r.discovered = %t, want = %t", r.discovered, discovered) + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, prefix, discovered); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) } - case <-time.After(timeout): - t.Fatal("timeout waiting for prefix discovery event") + default: + t.Fatal("expected prefix discovery event") } } - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - // Receive an RA with prefix in an NDP Prefix Information option (PI) // with infinite valid lifetime which should not get invalidated. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0)) - waitForEvent(true, defaultTimeout) + expectPrefixEvent(subnet, true) select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") @@ -1531,11 +1527,18 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { // Receive an RA with finite lifetime. // The prefix should get invalidated after 1s. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0)) - waitForEvent(false, testInfiniteLifetime) + select { + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, subnet, false); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) + } + case <-time.After(testInfiniteLifetime): + t.Fatal("timed out waiting for prefix discovery event") + } // Receive an RA with finite lifetime. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0)) - waitForEvent(true, defaultTimeout) + expectPrefixEvent(subnet, true) // Receive an RA with prefix with an infinite lifetime. // The prefix should not be invalidated. @@ -1558,7 +1561,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { // Receive an RA with 0 lifetime. // The prefix should get invalidated. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 0, 0)) - waitForEvent(false, defaultTimeout) + expectPrefixEvent(subnet, false) } // TestPrefixDiscoveryMaxRouters tests that only @@ -1570,7 +1573,7 @@ func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { prefixC: make(chan ndpPrefixEvent, stack.MaxDiscoveredOnLinkPrefixes+3), rememberPrefix: true, } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NDPConfigs: stack.NDPConfigurations{ @@ -1616,32 +1619,26 @@ func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { for i := 0; i < stack.MaxDiscoveredOnLinkPrefixes+2; i++ { if i < stack.MaxDiscoveredOnLinkPrefixes { select { - case r := <-ndpDisp.prefixC: - if r.nicID != 1 { - t.Fatalf("got r.nicID = %d, want = 1", r.nicID) - } - if r.prefix != prefixes[i] { - t.Fatalf("got r.prefix = %s, want = %s", r.prefix, prefixes[i]) + case e := <-ndpDisp.prefixC: + if diff := checkPrefixEvent(e, prefixes[i], true); diff != "" { + t.Errorf("prefix event mismatch (-want +got):\n%s", diff) } - if !r.discovered { - t.Fatal("got r.discovered = false, want = true") - } - case <-time.After(defaultTimeout): - t.Fatal("timeout waiting for prefix discovery event") + default: + t.Fatal("expected prefix discovery event") } } else { select { case <-ndpDisp.prefixC: t.Fatal("should not have discovered a new prefix after we already discovered the max number of prefixes") - case <-time.After(defaultTimeout): + default: } } } // Should only have device routes for the first // stack.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes. - if got := s.GetRouteTable(); !cmp.Equal(got, expectedRt[:]) { - t.Fatalf("got GetRouteTable = %v, want = %v", got, expectedRt) + if diff := cmp.Diff(expectedRt[:], s.GetRouteTable()); diff != "" { + t.Fatalf("GetRouteTable() mismatch (-want +got):\n%s", diff) } } @@ -1663,8 +1660,6 @@ func contains(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) bool { // TestNoAutoGenAddr tests that SLAAC is not performed when configured not to. func TestNoAutoGenAddr(t *testing.T) { - t.Parallel() - prefix, _, _ := prefixSubnetAddr(0, "") // Being configured to auto-generate addresses means handle and @@ -1682,9 +1677,9 @@ func TestNoAutoGenAddr(t *testing.T) { t.Parallel() ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NDPConfigs: stack.NDPConfigurations{ @@ -1705,12 +1700,18 @@ func TestNoAutoGenAddr(t *testing.T) { select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly auto-generated an address when configured not to") - case <-time.After(defaultTimeout): + default: } }) } } +// Check e to make sure that the event is for addr on nic with ID 1, and the +// event type is set to eventType. +func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) string { + return cmp.Diff(ndpAutoGenAddrEvent{nicID: 1, addr: addr, eventType: eventType}, e, cmp.AllowUnexported(e)) +} + // TestAutoGenAddr tests that an address is properly generated and invalidated // when configured to do so. func TestAutoGenAddr(t *testing.T) { @@ -1726,9 +1727,9 @@ func TestAutoGenAddr(t *testing.T) { prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NDPConfigs: stack.NDPConfigurations{ @@ -1738,42 +1739,36 @@ func TestAutoGenAddr(t *testing.T) { NDPDisp: &ndpDisp, }) - waitForEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { t.Helper() select { - case r := <-ndpDisp.autoGenAddrC: - if r.nicID != 1 { - t.Fatalf("got r.nicID = %d, want = 1", r.nicID) + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - if r.addr != addr { - t.Fatalf("got r.addr = %s, want = %s", r.addr, addr) - } - if r.eventType != eventType { - t.Fatalf("got r.eventType = %v, want = %v", r.eventType, eventType) - } - case <-time.After(timeout): - t.Fatal("timeout waiting for addr auto gen event") + default: + t.Fatal("expected addr auto gen event") } } - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) // with zero valid lifetime. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0)) select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly auto-generated an address with 0 lifetime") - case <-time.After(defaultTimeout): + default: } // Receive an RA with prefix1 in an NDP Prefix Information option (PI) // with non-zero lifetime. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) - waitForEvent(addr1, newAddr, defaultTimeout) + expectAutoGenAddrEvent(addr1, newAddr) if !contains(s.NICInfo()[1].ProtocolAddresses, addr1) { t.Fatalf("Should have %s in the list of addresses", addr1) } @@ -1784,12 +1779,12 @@ func TestAutoGenAddr(t *testing.T) { select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly auto-generated an address with preferred lifetime > valid lifetime") - case <-time.After(defaultTimeout): + default: } // Receive an RA with prefix2 in a PI. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) - waitForEvent(addr2, newAddr, defaultTimeout) + expectAutoGenAddrEvent(addr2, newAddr) if !contains(s.NICInfo()[1].ProtocolAddresses, addr1) { t.Fatalf("Should have %s in the list of addresses", addr1) } @@ -1802,11 +1797,18 @@ func TestAutoGenAddr(t *testing.T) { select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix") - case <-time.After(defaultTimeout): + default: } // Wait for addr of prefix1 to be invalidated. - waitForEvent(addr1, invalidatedAddr, newMinVLDuration+defaultTimeout) + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(newMinVLDuration + defaultTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } if contains(s.NICInfo()[1].ProtocolAddresses, addr1) { t.Fatalf("Should not have %s in the list of addresses", addr1) } @@ -1896,78 +1898,81 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { const delta = 500 * time.Millisecond - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), - } - e := channel.New(10, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - }) + // This Run will not return until the parallel tests finish. + // + // We need this because we need to do some teardown work after the + // parallel tests complete. + // + // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for + // more details. + t.Run("group", func(t *testing.T) { + for _, test := range tests { + test := test - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } + t.Run(test.name, func(t *testing.T) { + t.Parallel() - // Receive an RA with prefix with initial VL, test.ovl. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.ovl, 0)) - select { - case r := <-ndpDisp.autoGenAddrC: - if r.nicID != 1 { - t.Fatalf("got r.nicID = %d, want = 1", r.nicID) - } - if r.addr != addr { - t.Fatalf("got r.addr = %s, want = %s", r.addr, addr) + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), } - if r.eventType != newAddr { - t.Fatalf("got r.eventType = %v, want = %v", r.eventType, newAddr) + e := channel.New(10, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) } - case <-time.After(defaultTimeout): - t.Fatal("timeout waiting for addr auto gen event") - } - // Receive an new RA with prefix with new VL, test.nvl. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.nvl, 0)) + // Receive an RA with prefix with initial VL, + // test.ovl. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.ovl, 0)) + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } - // - // Validate that the VL for the address got set to - // test.evl. - // + // Receive an new RA with prefix with new VL, + // test.nvl. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.nvl, 0)) - // Make sure we do not get any invalidation events - // until atleast 500ms (delta) before test.evl. - select { - case <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly received an auto gen addr event") - case <-time.After(time.Duration(test.evl)*time.Second - delta): - } + // + // Validate that the VL for the address got set + // to test.evl. + // - // Wait for another second (2x delta), but now we expect - // the invalidation event. - select { - case r := <-ndpDisp.autoGenAddrC: - if r.nicID != 1 { - t.Fatalf("got r.nicID = %d, want = 1", r.nicID) - } - if r.addr != addr { - t.Fatalf("got r.addr = %s, want = %s", r.addr, addr) + // Make sure we do not get any invalidation + // events until atleast 500ms (delta) before + // test.evl. + select { + case <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly received an auto gen addr event") + case <-time.After(time.Duration(test.evl)*time.Second - delta): } - if r.eventType != invalidatedAddr { - t.Fatalf("got r.eventType = %v, want = %v", r.eventType, newAddr) + + // Wait for another second (2x delta), but now + // we expect the invalidation event. + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + + case <-time.After(2 * delta): + t.Fatal("timeout waiting for addr auto gen event") } - case <-time.After(2 * delta): - t.Fatal("timeout waiting for addr auto gen event") - } - }) - } + }) + } + }) } // TestAutoGenAddrRemoval tests that when auto-generated addresses are removed @@ -1979,9 +1984,9 @@ func TestAutoGenAddrRemoval(t *testing.T) { prefix, _, addr := prefixSubnetAddr(0, linkAddr1) ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NDPConfigs: stack.NDPConfigurations{ @@ -1995,51 +2000,37 @@ func TestAutoGenAddrRemoval(t *testing.T) { t.Fatalf("CreateNIC(1) = %s", err) } - // Receive an RA with prefix with its valid lifetime = lifetime. - const lifetime = 5 - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetime, 0)) - select { - case r := <-ndpDisp.autoGenAddrC: - if r.nicID != 1 { - t.Fatalf("got r.nicID = %d, want = 1", r.nicID) - } - if r.addr != addr { - t.Fatalf("got r.addr = %s, want = %s", r.addr, addr) - } - if r.eventType != newAddr { - t.Fatalf("got r.eventType = %v, want = %v", r.eventType, newAddr) + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") } - case <-time.After(defaultTimeout): - t.Fatal("timeout waiting for addr auto gen event") } - // Remove the address. + // Receive a PI to auto-generate an address. + const lifetimeSeconds = 1 + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, 0)) + expectAutoGenAddrEvent(addr, newAddr) + + // Removing the address should result in an invalidation event + // immediately. if err := s.RemoveAddress(1, addr.Address); err != nil { t.Fatalf("RemoveAddress(_, %s) = %s", addr.Address, err) } - - // Should get the invalidation event immediately. - select { - case r := <-ndpDisp.autoGenAddrC: - if r.nicID != 1 { - t.Fatalf("got r.nicID = %d, want = 1", r.nicID) - } - if r.addr != addr { - t.Fatalf("got r.addr = %s, want = %s", r.addr, addr) - } - if r.eventType != invalidatedAddr { - t.Fatalf("got r.eventType = %v, want = %v", r.eventType, newAddr) - } - case <-time.After(defaultTimeout): - t.Fatal("timeout waiting for addr auto gen event") - } + expectAutoGenAddrEvent(addr, invalidatedAddr) // Wait for the original valid lifetime to make sure the original timer // got stopped/cleaned up. select { case <-ndpDisp.autoGenAddrC: t.Fatalf("unexpectedly received an auto gen addr event") - case <-time.After(lifetime*time.Second + defaultTimeout): + case <-time.After(lifetimeSeconds*time.Second + defaultTimeout): } } @@ -2051,9 +2042,9 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { prefix, _, addr := prefixSubnetAddr(0, linkAddr1) ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NDPConfigs: stack.NDPConfigurations{ @@ -2077,12 +2068,12 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { // Receive a PI where the generated address will be the same as the one // that we already have assigned statically. - const lifetime = 5 - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetime, 0)) + const lifetimeSeconds = 1 + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, 0)) select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly received an auto gen addr event for an address we already have statically") - case <-time.After(defaultTimeout): + default: } if !contains(s.NICInfo()[1].ProtocolAddresses, addr) { t.Fatalf("Should have %s in the list of addresses", addr1) @@ -2093,7 +2084,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly received an auto gen addr event") - case <-time.After(lifetime*time.Second + defaultTimeout): + case <-time.After(lifetimeSeconds*time.Second + defaultTimeout): } if !contains(s.NICInfo()[1].ProtocolAddresses, addr) { t.Fatalf("Should have %s in the list of addresses", addr1) -- cgit v1.2.3 From 8e6e87f8e8885eeadb8b3d891e24137f11ebdf31 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Wed, 18 Dec 2019 15:12:33 -0800 Subject: Allow 'out-of-line' routing table updates for Router and Prefix discovery events This change removes the requirement that a new routing table be provided when a router or prefix discovery event happens so that an updated routing table may be provided to the stack at a later time from the event. This change is to address the use case where the netstack integrator may need to obtain a lock before providing updated routes in response to the events above. As an example, say we have an integrator that performs the below two operations operations as described: A. Normal route update: 1. Obtain integrator lock 2. Update routes in the integrator 3. Call Stack.SetRouteTable with the updated routes 3.1. Obtain Stack lock 3.2. Update routes in Stack 3.3. Release Stack lock 4. Release integrator lock B. NDP event triggered route update: 1. Obtain Stack lock 2. Call event handler 2.1. Obtain integrator lock 2.2. Update routes in the integrator 2.3. Release integrator lock 2.4. Return updated routes to update Stack 3. Update routes in Stack 4. Release Stack lock A deadlock may occur if a Normal route update was attemped at the same time an NDP event triggered route update was attempted. With threads T1 and T2: 1) T1 -> A.1, A.2 2) T2 -> B.1 3) T1 -> A.3 (hangs at A.3.1 since Stack lock is taken in step 2) 4) T2 -> B.2 (hangs at B.2.1 since integrator lock is taken in step 1) Test: Existing tests were modified to not provide or expect routing table changes in response to Router and Prefix discovery events. PiperOrigin-RevId: 286274712 --- pkg/tcpip/stack/ndp.go | 69 ++++++-------- pkg/tcpip/stack/ndp_test.go | 223 ++++---------------------------------------- 2 files changed, 50 insertions(+), 242 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 27bd02e76..90664ba8a 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -131,40 +131,34 @@ type NDPDispatcher interface { OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) // OnDefaultRouterDiscovered will be called when a new default router is - // discovered. Implementations must return true along with a new valid - // route table if the newly discovered router should be remembered. If - // an implementation returns false, the second return value will be - // ignored. + // discovered. Implementations must return true if the newly discovered + // router should be remembered. // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. - OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) (bool, []tcpip.Route) + OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool // OnDefaultRouterInvalidated will be called when a discovered default - // router is invalidated. Implementers must return a new valid route - // table. + // router that was remembered is invalidated. // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. - OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) []tcpip.Route + OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) // OnOnLinkPrefixDiscovered will be called when a new on-link prefix is - // discovered. Implementations must return true along with a new valid - // route table if the newly discovered on-link prefix should be - // remembered. If an implementation returns false, the second return - // value will be ignored. + // discovered. Implementations must return true if the newly discovered + // on-link prefix should be remembered. // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. - OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) (bool, []tcpip.Route) + OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool // OnOnLinkPrefixInvalidated will be called when a discovered on-link - // prefix is invalidated. Implementers must return a new valid route - // table. + // prefix that was remembered is invalidated. // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. - OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) []tcpip.Route + OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) // OnAutoGenAddress will be called when a new prefix with its // autonomous address-configuration flag set has been received and SLAAC @@ -668,7 +662,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { // invalidateDefaultRouter invalidates a discovered default router. // -// The NIC that ndp belongs to and its associated stack MUST be locked. +// The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { rtr, ok := ndp.defaultRouters[ip] @@ -686,8 +680,8 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { delete(ndp.defaultRouters, ip) // Let the integrator know a discovered default router is invalidated. - if ndp.nic.stack.ndpDisp != nil { - ndp.nic.stack.routeTable = ndp.nic.stack.ndpDisp.OnDefaultRouterInvalidated(ndp.nic.ID(), ip) + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnDefaultRouterInvalidated(ndp.nic.ID(), ip) } } @@ -696,15 +690,15 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { // // The router identified by ip MUST NOT already be known by the NIC. // -// The NIC that ndp belongs to and its associated stack MUST be locked. +// The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { - if ndp.nic.stack.ndpDisp == nil { + ndpDisp := ndp.nic.stack.ndpDisp + if ndpDisp == nil { return } // Inform the integrator when we discovered a default router. - remember, routeTable := ndp.nic.stack.ndpDisp.OnDefaultRouterDiscovered(ndp.nic.ID(), ip) - if !remember { + if !ndpDisp.OnDefaultRouterDiscovered(ndp.nic.ID(), ip) { // Informed by the integrator to not remember the router, do // nothing further. return @@ -731,8 +725,6 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { }), doNotInvalidate: &doNotInvalidate, } - - ndp.nic.stack.routeTable = routeTable } // rememberOnLinkPrefix remembers a newly discovered on-link prefix with IPv6 @@ -740,15 +732,15 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { // // The prefix identified by prefix MUST NOT already be known. // -// The NIC that ndp belongs to and its associated stack MUST be locked. +// The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) { - if ndp.nic.stack.ndpDisp == nil { + ndpDisp := ndp.nic.stack.ndpDisp + if ndpDisp == nil { return } // Inform the integrator when we discovered an on-link prefix. - remember, routeTable := ndp.nic.stack.ndpDisp.OnOnLinkPrefixDiscovered(ndp.nic.ID(), prefix) - if !remember { + if !ndpDisp.OnOnLinkPrefixDiscovered(ndp.nic.ID(), prefix) { // Informed by the integrator to not remember the prefix, do // nothing further. return @@ -769,13 +761,11 @@ func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) invalidationTimer: timer, doNotInvalidate: &doNotInvalidate, } - - ndp.nic.stack.routeTable = routeTable } // invalidateOnLinkPrefix invalidates a discovered on-link prefix. // -// The NIC that ndp belongs to and its associated stack MUST be locked. +// The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) { s, ok := ndp.onLinkPrefixes[prefix] @@ -796,8 +786,8 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) { delete(ndp.onLinkPrefixes, prefix) // Let the integrator know a discovered on-link prefix is invalidated. - if ndp.nic.stack.ndpDisp != nil { - ndp.nic.stack.routeTable = ndp.nic.stack.ndpDisp.OnOnLinkPrefixInvalidated(ndp.nic.ID(), prefix) + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnOnLinkPrefixInvalidated(ndp.nic.ID(), prefix) } } @@ -829,7 +819,7 @@ func (ndp *ndpState) prefixInvalidationCallback(prefix tcpip.Subnet, vl time.Dur // handleOnLinkPrefixInformation assumes that the prefix this pi is for is // not the link-local prefix and the on-link flag is set. // -// The NIC that ndp belongs to and its associated stack MUST be locked. +// The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformation) { prefix := pi.Subnet() prefixState, ok := ndp.onLinkPrefixes[prefix] @@ -1066,10 +1056,11 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform } // Inform the integrator that we have a new SLAAC address. - if ndp.nic.stack.ndpDisp == nil { + ndpDisp := ndp.nic.stack.ndpDisp + if ndpDisp == nil { return } - if !ndp.nic.stack.ndpDisp.OnAutoGenAddress(ndp.nic.ID(), addrWithPrefix) { + if !ndpDisp.OnAutoGenAddress(ndp.nic.ID(), addrWithPrefix) { // Informed by the integrator not to add the address. return } @@ -1135,8 +1126,8 @@ func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bo delete(ndp.autoGenAddresses, addr) - if ndp.nic.stack.ndpDisp != nil { - ndp.nic.stack.ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), tcpip.AddressWithPrefix{ + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), tcpip.AddressWithPrefix{ Address: addr, PrefixLen: validPrefixLenForAutoGen, }) diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 9f589a471..666f86c33 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -163,7 +163,6 @@ type ndpDispatcher struct { rememberPrefix bool autoGenAddrC chan ndpAutoGenAddrEvent rdnssC chan ndpRDNSSEvent - routeTable []tcpip.Route } // Implements stack.NDPDispatcher.OnDuplicateAddressDetectionStatus. @@ -179,106 +178,56 @@ func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, add } // Implements stack.NDPDispatcher.OnDefaultRouterDiscovered. -func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) (bool, []tcpip.Route) { - if n.routerC != nil { - n.routerC <- ndpRouterEvent{ +func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool { + if c := n.routerC; c != nil { + c <- ndpRouterEvent{ nicID, addr, true, } } - if !n.rememberRouter { - return false, nil - } - - rt := append([]tcpip.Route(nil), n.routeTable...) - rt = append(rt, tcpip.Route{ - Destination: header.IPv6EmptySubnet, - Gateway: addr, - NIC: nicID, - }) - n.routeTable = rt - return true, rt + return n.rememberRouter } // Implements stack.NDPDispatcher.OnDefaultRouterInvalidated. -func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) []tcpip.Route { - if n.routerC != nil { - n.routerC <- ndpRouterEvent{ +func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) { + if c := n.routerC; c != nil { + c <- ndpRouterEvent{ nicID, addr, false, } } - - var rt []tcpip.Route - exclude := tcpip.Route{ - Destination: header.IPv6EmptySubnet, - Gateway: addr, - NIC: nicID, - } - - for _, r := range n.routeTable { - if r != exclude { - rt = append(rt, r) - } - } - n.routeTable = rt - return rt } // Implements stack.NDPDispatcher.OnOnLinkPrefixDiscovered. -func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) (bool, []tcpip.Route) { - if n.prefixC != nil { - n.prefixC <- ndpPrefixEvent{ +func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool { + if c := n.prefixC; c != nil { + c <- ndpPrefixEvent{ nicID, prefix, true, } } - if !n.rememberPrefix { - return false, nil - } - - rt := append([]tcpip.Route(nil), n.routeTable...) - rt = append(rt, tcpip.Route{ - Destination: prefix, - NIC: nicID, - }) - n.routeTable = rt - return true, rt + return n.rememberPrefix } // Implements stack.NDPDispatcher.OnOnLinkPrefixInvalidated. -func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) []tcpip.Route { - if n.prefixC != nil { - n.prefixC <- ndpPrefixEvent{ +func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) { + if c := n.prefixC; c != nil { + c <- ndpPrefixEvent{ nicID, prefix, false, } } - - var rt []tcpip.Route - exclude := tcpip.Route{ - Destination: prefix, - NIC: nicID, - } - - for _, r := range n.routeTable { - if r != exclude { - rt = append(rt, r) - } - } - n.routeTable = rt - return rt } func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) bool { - if n.autoGenAddrC != nil { - n.autoGenAddrC <- ndpAutoGenAddrEvent{ + if c := n.autoGenAddrC; c != nil { + c <- ndpAutoGenAddrEvent{ nicID, addr, newAddr, @@ -288,8 +237,8 @@ func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWi } func (n *ndpDispatcher) OnAutoGenAddressInvalidated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) { - if n.autoGenAddrC != nil { - n.autoGenAddrC <- ndpAutoGenAddrEvent{ + if c := n.autoGenAddrC; c != nil { + c <- ndpAutoGenAddrEvent{ nicID, addr, invalidatedAddr, @@ -299,8 +248,8 @@ func (n *ndpDispatcher) OnAutoGenAddressInvalidated(nicID tcpip.NICID, addr tcpi // Implements stack.NDPDispatcher.OnRecursiveDNSServerOption. func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) { - if n.rdnssC != nil { - n.rdnssC <- ndpRDNSSEvent{ + if c := n.rdnssC; c != nil { + c <- ndpRDNSSEvent{ nicID, ndpRDNSS{ addrs, @@ -976,15 +925,6 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { t.Fatalf("CreateNIC(1) = %s", err) } - routeTable := []tcpip.Route{ - { - header.IPv6EmptySubnet, - llAddr3, - 1, - }, - } - s.SetRouteTable(routeTable) - // Receive an RA for a router we should not remember. const lifetimeSeconds = 1 e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, lifetimeSeconds)) @@ -997,11 +937,6 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { t.Fatal("expected router discovery event") } - // Original route table should not have been modified. - if diff := cmp.Diff(routeTable, s.GetRouteTable()); diff != "" { - t.Fatalf("GetRouteTable() mismatch (-want +got):\n%s", diff) - } - // Wait for the invalidation time plus some buffer to make sure we do // not actually receive any invalidation events as we should not have // remembered the router in the first place. @@ -1010,11 +945,6 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { t.Fatal("should not have received any router events") case <-time.After(lifetimeSeconds*time.Second + defaultTimeout): } - - // Original route table should not have been modified. - if diff := cmp.Diff(routeTable, s.GetRouteTable()); diff != "" { - t.Fatalf("GetRouteTable() mismatch (-want +got):\n%s", diff) - } } func TestRouterDiscovery(t *testing.T) { @@ -1077,22 +1007,11 @@ func TestRouterDiscovery(t *testing.T) { e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) expectRouterEvent(llAddr2, true) - // Should have a default route through the discovered router. - if diff := cmp.Diff([]tcpip.Route{{header.IPv6EmptySubnet, llAddr2, 1}}, s.GetRouteTable()); diff != "" { - t.Fatalf("GetRouteTable() mismatch (-want +got):\n%s", diff) - } - // Rx an RA from another router (lladdr3) with non-zero lifetime. l3Lifetime := time.Duration(6) e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, uint16(l3Lifetime))) expectRouterEvent(llAddr3, true) - // Should have default routes through the discovered routers. - want := []tcpip.Route{{header.IPv6EmptySubnet, llAddr2, 1}, {header.IPv6EmptySubnet, llAddr3, 1}} - if diff := cmp.Diff(want, s.GetRouteTable()); diff != "" { - t.Fatalf("GetRouteTable() mismatch (-want +got):\n%s", diff) - } - // Rx an RA from lladdr2 with lesser lifetime. l2Lifetime := time.Duration(2) e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, uint16(l2Lifetime))) @@ -1102,11 +1021,6 @@ func TestRouterDiscovery(t *testing.T) { default: } - // Should still have a default route through the discovered routers. - if diff := cmp.Diff(want, s.GetRouteTable()); diff != "" { - t.Fatalf("GetRouteTable() mismatch (-want +got):\n%s", diff) - } - // Wait for lladdr2's router invalidation timer to fire. The lifetime // of the router should have been updated to the most recent (smaller) // lifetime. @@ -1116,30 +1030,14 @@ func TestRouterDiscovery(t *testing.T) { // event after this time, then something is wrong. expectAsyncRouterInvalidationEvent(llAddr2, l2Lifetime*time.Second+defaultTimeout) - // Should no longer have the default route through lladdr2. - if diff := cmp.Diff([]tcpip.Route{{header.IPv6EmptySubnet, llAddr3, 1}}, s.GetRouteTable()); diff != "" { - t.Fatalf("GetRouteTable() mismatch (-want +got):\n%s", diff) - } - // Rx an RA from lladdr2 with huge lifetime. e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) expectRouterEvent(llAddr2, true) - // Should have a default route through the discovered routers. - if diff := cmp.Diff([]tcpip.Route{{header.IPv6EmptySubnet, llAddr3, 1}, {header.IPv6EmptySubnet, llAddr2, 1}}, s.GetRouteTable()); diff != "" { - t.Fatalf("GetRouteTable() mismatch (-want +got):\n%s", diff) - } - // Rx an RA from lladdr2 with zero lifetime. It should be invalidated. e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) expectRouterEvent(llAddr2, false) - // Should have deleted the default route through the router that just - // got invalidated. - if diff := cmp.Diff([]tcpip.Route{{header.IPv6EmptySubnet, llAddr3, 1}}, s.GetRouteTable()); diff != "" { - t.Fatalf("GetRouteTable() mismatch (-want +got):\n%s", diff) - } - // Wait for lladdr3's router invalidation timer to fire. The lifetime // of the router should have been updated to the most recent (smaller) // lifetime. @@ -1148,12 +1046,6 @@ func TestRouterDiscovery(t *testing.T) { // router to get invalidated. If we don't get an invalidation // event after this time, then something is wrong. expectAsyncRouterInvalidationEvent(llAddr3, l3Lifetime*time.Second+defaultTimeout) - - // Should not have any routes now that all discovered routers have been - // invalidated. - if got := len(s.GetRouteTable()); got != 0 { - t.Fatalf("got len(s.GetRouteTable()) = %d, want = 0", got) - } } // TestRouterDiscoveryMaxRouters tests that only @@ -1179,8 +1071,6 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { t.Fatalf("CreateNIC(1) = %s", err) } - expectedRt := [stack.MaxDiscoveredDefaultRouters]tcpip.Route{} - // Receive an RA from 2 more than the max number of discovered routers. for i := 1; i <= stack.MaxDiscoveredDefaultRouters+2; i++ { linkAddr := []byte{2, 2, 3, 4, 5, 0} @@ -1190,7 +1080,6 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr, 5)) if i <= stack.MaxDiscoveredDefaultRouters { - expectedRt[i-1] = tcpip.Route{header.IPv6EmptySubnet, llAddr, 1} select { case e := <-ndpDisp.routerC: if diff := checkRouterEvent(e, llAddr, true); diff != "" { @@ -1208,12 +1097,6 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { } } } - - // Should only have default routes for the first - // stack.MaxDiscoveredDefaultRouters discovered routers. - if diff := cmp.Diff(expectedRt[:], s.GetRouteTable()); diff != "" { - t.Fatalf("GetRouteTable() mismatch (-want +got):\n%s", diff) - } } // TestNoPrefixDiscovery tests that prefix discovery will not be performed if @@ -1299,15 +1182,6 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { t.Fatalf("CreateNIC(1) = %s", err) } - routeTable := []tcpip.Route{ - { - header.IPv6EmptySubnet, - llAddr3, - 1, - }, - } - s.SetRouteTable(routeTable) - // Receive an RA with prefix that we should not remember. const lifetimeSeconds = 1 e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, lifetimeSeconds, 0)) @@ -1320,11 +1194,6 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { t.Fatal("expected prefix discovery event") } - // Original route table should not have been modified. - if diff := cmp.Diff(routeTable, s.GetRouteTable()); diff != "" { - t.Fatalf("GetRouteTable() mismatch (-want +got):\n%s", diff) - } - // Wait for the invalidation time plus some buffer to make sure we do // not actually receive any invalidation events as we should not have // remembered the prefix in the first place. @@ -1333,11 +1202,6 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { t.Fatal("should not have received any prefix events") case <-time.After(lifetimeSeconds*time.Second + defaultTimeout): } - - // Original route table should not have been modified. - if diff := cmp.Diff(routeTable, s.GetRouteTable()); diff != "" { - t.Fatalf("GetRouteTable() mismatch (-want +got):\n%s", diff) - } } func TestPrefixDiscovery(t *testing.T) { @@ -1392,39 +1256,18 @@ func TestPrefixDiscovery(t *testing.T) { e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 100, 0)) expectPrefixEvent(subnet1, true) - // Should have added a device route for subnet1 through the nic. - if diff := cmp.Diff([]tcpip.Route{{subnet1, tcpip.Address([]byte(nil)), 1}}, s.GetRouteTable()); diff != "" { - t.Fatalf("GetRouteTable() mismatch (-want +got):\n%s", diff) - } - // Receive an RA with prefix2 in a PI. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, 100, 0)) expectPrefixEvent(subnet2, true) - // Should have added a device route for subnet2 through the nic. - if diff := cmp.Diff([]tcpip.Route{{subnet1, tcpip.Address([]byte(nil)), 1}, {subnet2, tcpip.Address([]byte(nil)), 1}}, s.GetRouteTable()); diff != "" { - t.Fatalf("GetRouteTable() mismatch (-want +got):\n%s", diff) - } - // Receive an RA with prefix3 in a PI. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 100, 0)) expectPrefixEvent(subnet3, true) - // Should have added a device route for subnet3 through the nic. - if diff := cmp.Diff([]tcpip.Route{{subnet1, tcpip.Address([]byte(nil)), 1}, {subnet2, tcpip.Address([]byte(nil)), 1}, {subnet3, tcpip.Address([]byte(nil)), 1}}, s.GetRouteTable()); diff != "" { - t.Fatalf("GetRouteTable() mismatch (-want +got):\n%s", diff) - } - // Receive an RA with prefix1 in a PI with lifetime = 0. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0)) expectPrefixEvent(subnet1, false) - // Should have removed the device route for subnet1 through the nic. - want := []tcpip.Route{{subnet2, tcpip.Address([]byte(nil)), 1}, {subnet3, tcpip.Address([]byte(nil)), 1}} - if diff := cmp.Diff(want, s.GetRouteTable()); diff != "" { - t.Fatalf("GetRouteTable() mismatch (-want +got):\n%s", diff) - } - // Receive an RA with prefix2 in a PI with lesser lifetime. lifetime := uint32(2) e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, lifetime, 0)) @@ -1434,11 +1277,6 @@ func TestPrefixDiscovery(t *testing.T) { default: } - // Should not have updated route table. - if diff := cmp.Diff(want, s.GetRouteTable()); diff != "" { - t.Fatalf("GetRouteTable() mismatch (-want +got):\n%s", diff) - } - // Wait for prefix2's most recent invalidation timer plus some buffer to // expire. select { @@ -1450,19 +1288,9 @@ func TestPrefixDiscovery(t *testing.T) { t.Fatal("timed out waiting for prefix discovery event") } - // Should have removed the device route for subnet2 through the nic. - if diff := cmp.Diff([]tcpip.Route{{subnet3, tcpip.Address([]byte(nil)), 1}}, s.GetRouteTable()); diff != "" { - t.Fatalf("GetRouteTable() mismatch (-want +got):\n%s", diff) - } - // Receive RA to invalidate prefix3. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 0, 0)) expectPrefixEvent(subnet3, false) - - // Should not have any routes. - if got := len(s.GetRouteTable()); got != 0 { - t.Fatalf("got len(s.GetRouteTable()) = %d, want = 0", got) - } } func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { @@ -1589,7 +1417,6 @@ func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { } optSer := make(header.NDPOptionsSerializer, stack.MaxDiscoveredOnLinkPrefixes+2) - expectedRt := [stack.MaxDiscoveredOnLinkPrefixes]tcpip.Route{} prefixes := [stack.MaxDiscoveredOnLinkPrefixes + 2]tcpip.Subnet{} // Receive an RA with 2 more than the max number of discovered on-link @@ -1609,10 +1436,6 @@ func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { copy(buf[14:], prefix.Address) optSer[i] = header.NDPPrefixInformation(buf[:]) - - if i < stack.MaxDiscoveredOnLinkPrefixes { - expectedRt[i] = tcpip.Route{prefixes[i], tcpip.Address([]byte(nil)), 1} - } } e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, optSer)) @@ -1634,12 +1457,6 @@ func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { } } } - - // Should only have device routes for the first - // stack.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes. - if diff := cmp.Diff(expectedRt[:], s.GetRouteTable()); diff != "" { - t.Fatalf("GetRouteTable() mismatch (-want +got):\n%s", diff) - } } // Checks to see if list contains an IPv6 address, item. -- cgit v1.2.3 From 5bc4ae9d5746e65909a0bdab60e7bd598d4401c7 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Mon, 23 Dec 2019 08:53:57 -0800 Subject: Clear any host-specific NDP state when becoming a router This change supports clearing all host-only NDP state when NICs become routers. All discovered routers, discovered on-link prefixes and auto-generated addresses will be invalidated when becoming a router. This is because normally, routers do not process Router Advertisements to discover routers or on-link prefixes, and do not do SLAAC. Tests: Unittest to make sure that all discovered routers, discovered prefixes and auto-generated addresses get invalidated when transitioning from a host to a router. PiperOrigin-RevId: 286902309 --- pkg/tcpip/stack/ndp.go | 35 ++++++ pkg/tcpip/stack/ndp_test.go | 283 ++++++++++++++++++++++++++++++++++++++++++-- pkg/tcpip/stack/nic.go | 12 ++ pkg/tcpip/stack/stack.go | 22 +++- 4 files changed, 338 insertions(+), 14 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 90664ba8a..d9ab59336 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -1155,3 +1155,38 @@ func (ndp *ndpState) autoGenAddrInvalidationTimer(addr tcpip.Address, vl time.Du ndp.invalidateAutoGenAddress(addr) }) } + +// cleanupHostOnlyState cleans up any state that is only useful for hosts. +// +// cleanupHostOnlyState MUST be called when ndp's NIC is transitioning from a +// host to a router. This function will invalidate all discovered on-link +// prefixes, discovered routers, and auto-generated addresses as routers do not +// normally process Router Advertisements to discover default routers and +// on-link prefixes, and auto-generate addresses via SLAAC. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) cleanupHostOnlyState() { + for addr, _ := range ndp.autoGenAddresses { + ndp.invalidateAutoGenAddress(addr) + } + + if got := len(ndp.autoGenAddresses); got != 0 { + log.Fatalf("ndp: still have auto-generated addresses after cleaning up, found = %d", got) + } + + for prefix, _ := range ndp.onLinkPrefixes { + ndp.invalidateOnLinkPrefix(prefix) + } + + if got := len(ndp.onLinkPrefixes); got != 0 { + log.Fatalf("ndp: still have discovered on-link prefixes after cleaning up, found = %d", got) + } + + for router, _ := range ndp.defaultRouters { + ndp.invalidateDefaultRouter(router) + } + + if got := len(ndp.defaultRouters); got != 0 { + log.Fatalf("ndp: still have discovered default routers after cleaning up, found = %d", got) + } +} diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 666f86c33..64a9a2b20 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -47,6 +47,19 @@ var ( llAddr3 = header.LinkLocalAddr(linkAddr3) ) +func addrForSubnet(subnet tcpip.Subnet, linkAddr tcpip.LinkAddress) tcpip.AddressWithPrefix { + if !header.IsValidUnicastEthernetAddress(linkAddr) { + return tcpip.AddressWithPrefix{} + } + + addrBytes := []byte(subnet.ID()) + header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:]) + return tcpip.AddressWithPrefix{ + Address: tcpip.Address(addrBytes), + PrefixLen: 64, + } +} + // prefixSubnetAddr returns a prefix (Address + Length), the prefix's equivalent // tcpip.Subnet, and an address where the lower half of the address is composed // of the EUI-64 of linkAddr if it is a valid unicast ethernet address. @@ -59,17 +72,7 @@ func prefixSubnetAddr(offset uint8, linkAddr tcpip.LinkAddress) (tcpip.AddressWi subnet := prefix.Subnet() - var addr tcpip.AddressWithPrefix - if header.IsValidUnicastEthernetAddress(linkAddr) { - addrBytes := []byte(subnet.ID()) - header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:]) - addr = tcpip.AddressWithPrefix{ - Address: tcpip.Address(addrBytes), - PrefixLen: 64, - } - } - - return prefix, subnet, addr + return prefix, subnet, addrForSubnet(subnet, linkAddr) } // TestDADDisabled tests that an address successfully resolves immediately @@ -1772,7 +1775,7 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { // test.evl. select { case <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly received an auto gen addr event") + t.Fatal("unexpectedly received an auto gen addr event") case <-time.After(time.Duration(test.evl)*time.Second - delta): } @@ -1846,7 +1849,7 @@ func TestAutoGenAddrRemoval(t *testing.T) { // got stopped/cleaned up. select { case <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly received an auto gen addr event") + t.Fatal("unexpectedly received an auto gen addr event") case <-time.After(lifetimeSeconds*time.Second + defaultTimeout): } } @@ -2055,3 +2058,257 @@ func TestNDPRecursiveDNSServerDispatch(t *testing.T) { }) } } + +// TestCleanupHostOnlyStateOnBecomingRouter tests that all discovered routers +// and prefixes, and auto-generated addresses get invalidated when a NIC +// becomes a router. +func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) { + t.Parallel() + + const ( + lifetimeSeconds = 5 + maxEvents = 4 + nicID1 = 1 + nicID2 = 2 + ) + + prefix1, subnet1, e1Addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, subnet2, e1Addr2 := prefixSubnetAddr(1, linkAddr1) + e2Addr1 := addrForSubnet(subnet1, linkAddr2) + e2Addr2 := addrForSubnet(subnet2, linkAddr2) + + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, maxEvents), + rememberRouter: true, + prefixC: make(chan ndpPrefixEvent, maxEvents), + rememberPrefix: true, + autoGenAddrC: make(chan ndpAutoGenAddrEvent, maxEvents), + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: true, + DiscoverOnLinkPrefixes: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + }) + + e1 := channel.New(0, 1280, linkAddr1) + if err := s.CreateNIC(nicID1, e1); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err) + } + + e2 := channel.New(0, 1280, linkAddr2) + if err := s.CreateNIC(nicID2, e2); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err) + } + + expectRouterEvent := func() (bool, ndpRouterEvent) { + select { + case e := <-ndpDisp.routerC: + return true, e + default: + } + + return false, ndpRouterEvent{} + } + + expectPrefixEvent := func() (bool, ndpPrefixEvent) { + select { + case e := <-ndpDisp.prefixC: + return true, e + default: + } + + return false, ndpPrefixEvent{} + } + + expectAutoGenAddrEvent := func() (bool, ndpAutoGenAddrEvent) { + select { + case e := <-ndpDisp.autoGenAddrC: + return true, e + default: + } + + return false, ndpAutoGenAddrEvent{} + } + + // Receive RAs on NIC(1) and NIC(2) from default routers (llAddr1 and + // llAddr2) w/ PI (for prefix1 in RA from llAddr1 and prefix2 in RA from + // llAddr2) to discover multiple routers and prefixes, and auto-gen + // multiple addresses. + + e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr1, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) + // We have other tests that make sure we receive the *correct* events + // on normal discovery of routers/prefixes, and auto-generated + // addresses. Here we just make sure we get an event and let other tests + // handle the correctness check. + if ok, _ := expectRouterEvent(); !ok { + t.Errorf("expected router event for %s on NIC(%d)", llAddr1, nicID1) + } + if ok, _ := expectPrefixEvent(); !ok { + t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID1) + } + if ok, _ := expectAutoGenAddrEvent(); !ok { + t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr1, nicID1) + } + + e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) + if ok, _ := expectRouterEvent(); !ok { + t.Errorf("expected router event for %s on NIC(%d)", llAddr2, nicID1) + } + if ok, _ := expectPrefixEvent(); !ok { + t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID1) + } + if ok, _ := expectAutoGenAddrEvent(); !ok { + t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID1) + } + + e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr1, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) + if ok, _ := expectRouterEvent(); !ok { + t.Errorf("expected router event for %s on NIC(%d)", llAddr1, nicID2) + } + if ok, _ := expectPrefixEvent(); !ok { + t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID2) + } + if ok, _ := expectAutoGenAddrEvent(); !ok { + t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID2) + } + + e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) + if ok, _ := expectRouterEvent(); !ok { + t.Errorf("expected router event for %s on NIC(%d)", llAddr2, nicID2) + } + if ok, _ := expectPrefixEvent(); !ok { + t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID2) + } + if ok, _ := expectAutoGenAddrEvent(); !ok { + t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e2Addr2, nicID2) + } + + // We should have the auto-generated addresses added. + nicinfo := s.NICInfo() + nic1Addrs := nicinfo[nicID1].ProtocolAddresses + nic2Addrs := nicinfo[nicID2].ProtocolAddresses + if !contains(nic1Addrs, e1Addr1) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) + } + if !contains(nic1Addrs, e1Addr2) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) + } + if !contains(nic2Addrs, e2Addr1) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) + } + if !contains(nic2Addrs, e2Addr2) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs) + } + + // We can't proceed any further if we already failed the test (missing + // some discovery/auto-generated address events or addresses). + if t.Failed() { + t.FailNow() + } + + s.SetForwarding(true) + + // Collect invalidation events after becoming a router + gotRouterEvents := make(map[ndpRouterEvent]int) + for i := 0; i < maxEvents; i++ { + ok, e := expectRouterEvent() + if !ok { + t.Errorf("expected %d router events after becoming a router; got = %d", maxEvents, i) + break + } + gotRouterEvents[e]++ + } + gotPrefixEvents := make(map[ndpPrefixEvent]int) + for i := 0; i < maxEvents; i++ { + ok, e := expectPrefixEvent() + if !ok { + t.Errorf("expected %d prefix events after becoming a router; got = %d", maxEvents, i) + break + } + gotPrefixEvents[e]++ + } + gotAutoGenAddrEvents := make(map[ndpAutoGenAddrEvent]int) + for i := 0; i < maxEvents; i++ { + ok, e := expectAutoGenAddrEvent() + if !ok { + t.Errorf("expected %d auto-generated address events after becoming a router; got = %d", maxEvents, i) + break + } + gotAutoGenAddrEvents[e]++ + } + + // No need to proceed any further if we already failed the test (missing + // some invalidation events). + if t.Failed() { + t.FailNow() + } + + expectedRouterEvents := map[ndpRouterEvent]int{ + {nicID: nicID1, addr: llAddr1, discovered: false}: 1, + {nicID: nicID1, addr: llAddr2, discovered: false}: 1, + {nicID: nicID2, addr: llAddr1, discovered: false}: 1, + {nicID: nicID2, addr: llAddr2, discovered: false}: 1, + } + if diff := cmp.Diff(expectedRouterEvents, gotRouterEvents); diff != "" { + t.Errorf("router events mismatch (-want +got):\n%s", diff) + } + expectedPrefixEvents := map[ndpPrefixEvent]int{ + {nicID: nicID1, prefix: subnet1, discovered: false}: 1, + {nicID: nicID1, prefix: subnet2, discovered: false}: 1, + {nicID: nicID2, prefix: subnet1, discovered: false}: 1, + {nicID: nicID2, prefix: subnet2, discovered: false}: 1, + } + if diff := cmp.Diff(expectedPrefixEvents, gotPrefixEvents); diff != "" { + t.Errorf("prefix events mismatch (-want +got):\n%s", diff) + } + expectedAutoGenAddrEvents := map[ndpAutoGenAddrEvent]int{ + {nicID: nicID1, addr: e1Addr1, eventType: invalidatedAddr}: 1, + {nicID: nicID1, addr: e1Addr2, eventType: invalidatedAddr}: 1, + {nicID: nicID2, addr: e2Addr1, eventType: invalidatedAddr}: 1, + {nicID: nicID2, addr: e2Addr2, eventType: invalidatedAddr}: 1, + } + if diff := cmp.Diff(expectedAutoGenAddrEvents, gotAutoGenAddrEvents); diff != "" { + t.Errorf("auto-generated address events mismatch (-want +got):\n%s", diff) + } + + // Make sure the auto-generated addresses got removed. + nicinfo = s.NICInfo() + nic1Addrs = nicinfo[nicID1].ProtocolAddresses + nic2Addrs = nicinfo[nicID2].ProtocolAddresses + if contains(nic1Addrs, e1Addr1) { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) + } + if contains(nic1Addrs, e1Addr2) { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) + } + if contains(nic2Addrs, e2Addr1) { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) + } + if contains(nic2Addrs, e2Addr2) { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs) + } + + // Should not get any more events (invalidation timers should have been + // cancelled when we transitioned into a router). + time.Sleep(lifetimeSeconds*time.Second + defaultTimeout) + select { + case <-ndpDisp.routerC: + t.Error("unexpected router event") + default: + } + select { + case <-ndpDisp.prefixC: + t.Error("unexpected prefix event") + default: + } + select { + case <-ndpDisp.autoGenAddrC: + t.Error("unexpected auto-generated address event") + default: + } +} diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index e8401c673..ddd014658 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -203,6 +203,18 @@ func (n *NIC) enable() *tcpip.Error { return err } +// becomeIPv6Router transitions n into an IPv6 router. +// +// When transitioning into an IPv6 router, host-only state (NDP discovered +// routers, discovered on-link prefixes, and auto-generated addresses) will +// be cleaned up/invalidated. +func (n *NIC) becomeIPv6Router() { + n.mu.Lock() + defer n.mu.Unlock() + + n.ndp.cleanupHostOnlyState() +} + // attachLinkEndpoint attaches the NIC to the endpoint, which will enable it // to start delivering packets. func (n *NIC) attachLinkEndpoint() { diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 0e88643a4..7a9600679 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -662,11 +662,31 @@ func (s *Stack) Stats() tcpip.Stats { } // SetForwarding enables or disables the packet forwarding between NICs. +// +// When forwarding becomes enabled, any host-only state on all NICs will be +// cleaned up. func (s *Stack) SetForwarding(enable bool) { // TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward. s.mu.Lock() + defer s.mu.Unlock() + + // If forwarding status didn't change, do nothing further. + if s.forwarding == enable { + return + } + s.forwarding = enable - s.mu.Unlock() + + // If this stack does not support IPv6, do nothing further. + if _, ok := s.networkProtocols[header.IPv6ProtocolNumber]; !ok { + return + } + + if enable { + for _, nic := range s.nics { + nic.becomeIPv6Router() + } + } } // Forwarding returns if the packet forwarding between NICs is enabled. -- cgit v1.2.3 From e013c48c78c9a7daf245b7de9563e3a0bd8a1e97 Mon Sep 17 00:00:00 2001 From: Ryan Heacock Date: Tue, 24 Dec 2019 08:48:14 -0800 Subject: Enable IP_RECVTOS socket option for datagram sockets Added the ability to get/set the IP_RECVTOS socket option on UDP endpoints. If enabled, TOS from the incoming Network Header passed as ancillary data in the ControlMessages. Test: * Added unit test to udp_test.go that tests getting/setting as well as verifying that we receive expected TOS from incoming packet. * Added a syscall test PiperOrigin-RevId: 287029703 --- pkg/sentry/socket/control/control.go | 2 +- pkg/sentry/socket/netstack/netstack.go | 42 ++++++++++++++++- pkg/tcpip/checker/checker.go | 16 +++++++ pkg/tcpip/stack/nic.go | 2 +- pkg/tcpip/stack/stack.go | 2 +- pkg/tcpip/tcpip.go | 6 ++- pkg/tcpip/transport/raw/endpoint.go | 2 +- pkg/tcpip/transport/udp/endpoint.go | 31 ++++++++++++- pkg/tcpip/transport/udp/udp_test.go | 69 ++++++++++++++++++++++++---- test/syscalls/linux/socket_ip_udp_generic.cc | 40 ++++++++++++++++ test/syscalls/linux/udp_socket_test_cases.cc | 8 ++-- 11 files changed, 201 insertions(+), 19 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index af1a4e95f..b649dd021 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -327,7 +327,7 @@ func PackInq(t *kernel.Task, inq int32, buf []byte) []byte { } // PackTOS packs an IP_TOS socket control message. -func PackTOS(t *kernel.Task, tos int8, buf []byte) []byte { +func PackTOS(t *kernel.Task, tos uint8, buf []byte) []byte { return putCmsgStruct( buf, linux.SOL_IP, diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 140851c17..d2f263402 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1323,6 +1323,21 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in } return int32(v), nil + case linux.IP_RECVTOS: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.ReceiveTOSOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + if v { + return int32(1), nil + } + return int32(0), nil + default: emitUnimplementedEventIP(t, name) } @@ -1808,6 +1823,16 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s } return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv4TOSOption(v))) + case linux.IP_RECVTOS: + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + + return syserr.TranslateNetstackError(ep.SetSockOpt( + tcpip.ReceiveTOSOption(v != 0), + )) + case linux.IP_ADD_SOURCE_MEMBERSHIP, linux.IP_BIND_ADDRESS_NO_PORT, linux.IP_BLOCK_SOURCE, @@ -1828,7 +1853,6 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s linux.IP_RECVFRAGSIZE, linux.IP_RECVOPTS, linux.IP_RECVORIGDSTADDR, - linux.IP_RECVTOS, linux.IP_RECVTTL, linux.IP_RETOPTS, linux.IP_TRANSPARENT, @@ -2139,6 +2163,21 @@ func (s *SocketOperations) fillCmsgInq(cmsg *socket.ControlMessages) { cmsg.IP.Inq = int32(len(s.readView) + rcvBufUsed) } +func (s *SocketOperations) fillCmsgTOS(cmsg *socket.ControlMessages) { + if s.skType != linux.SOCK_DGRAM { + return + } + var receiveTOS tcpip.ReceiveTOSOption + if err := s.Endpoint.GetSockOpt(&receiveTOS); err != nil { + return + } + if !receiveTOS { + return + } + cmsg.IP.HasTOS = s.readCM.HasTOS + cmsg.IP.TOS = s.readCM.TOS +} + // nonBlockingRead issues a non-blocking read. // // TODO(b/78348848): Support timestamps for stream sockets. @@ -2244,6 +2283,7 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe cmsg := s.controlMessages() s.fillCmsgInq(&cmsg) + s.fillCmsgTOS(&cmsg) return n, flags, addr, addrLen, cmsg, syserr.FromError(err) } diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 2f15bf1f1..542abc99d 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -33,6 +33,9 @@ type NetworkChecker func(*testing.T, []header.Network) // TransportChecker is a function to check a property of a transport packet. type TransportChecker func(*testing.T, header.Transport) +// ControlMessagesChecker is a function to check a property of ancillary data. +type ControlMessagesChecker func(*testing.T, tcpip.ControlMessages) + // IPv4 checks the validity and properties of the given IPv4 packet. It is // expected to be used in conjunction with other network checkers for specific // properties. For example, to check the source and destination address, one @@ -158,6 +161,19 @@ func FragmentFlags(flags uint8) NetworkChecker { } } +// ReceiveTOS creates a checker that checks the TOS field in ControlMessages. +func ReceiveTOS(want uint8) ControlMessagesChecker { + return func(t *testing.T, cm tcpip.ControlMessages) { + t.Helper() + if !cm.HasTOS { + t.Fatalf("got cm.HasTOS = %t, want cm.TOS = %d", cm.HasTOS, want) + } + if got := cm.TOS; got != want { + t.Fatalf("got cm.TOS = %d, want %d", got, want) + } + } +} + // TOS creates a checker that checks the TOS field. func TOS(tos uint8, label uint32) NetworkChecker { return func(t *testing.T, h []header.Network) { diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index ddd014658..a4556674b 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -575,7 +575,7 @@ func (n *NIC) RemoveAddressRange(subnet tcpip.Subnet) { n.mu.Unlock() } -// Subnets returns the Subnets associated with this NIC. +// AddressRanges returns the Subnets associated with this NIC. func (n *NIC) AddressRanges() []tcpip.Subnet { n.mu.RLock() defer n.mu.RUnlock() diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 7a9600679..251336224 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -829,7 +829,7 @@ func (s *Stack) CheckNIC(id tcpip.NICID) bool { return false } -// NICSubnets returns a map of NICIDs to their associated subnets. +// NICAddressRanges returns a map of NICIDs to their associated subnets. func (s *Stack) NICAddressRanges() map[tcpip.NICID][]tcpip.Subnet { s.mu.RLock() defer s.mu.RUnlock() diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index f62fd729f..5c7b2af88 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -322,7 +322,7 @@ type ControlMessages struct { HasTOS bool // TOS is the IPv4 type of service of the associated packet. - TOS int8 + TOS uint8 // HasTClass indicates whether Tclass is valid/set. HasTClass bool @@ -666,6 +666,10 @@ type IPv4TOSOption uint8 // for all subsequent outgoing IPv6 packets from the endpoint. type IPv6TrafficClassOption uint8 +// ReceiveTOSOption is used by SetSockOpt/GetSockOpt to specify if the TOS +// ancillary message is passed with incoming packets. +type ReceiveTOSOption bool + // Route is a row in the routing table. It specifies through which NIC (and // gateway) sets of packets should be routed. A row is considered viable if the // masked target address matches the destination address in the row. diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 5aafe2615..6d23ab5a1 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -510,7 +510,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. -func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { +func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 1ac4705af..269470ed4 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -32,6 +32,7 @@ type udpPacket struct { senderAddress tcpip.FullAddress data buffer.VectorisedView `state:".(buffer.VectorisedView)"` timestamp int64 + tos uint8 } // EndpointState represents the state of a UDP endpoint. @@ -114,6 +115,10 @@ type endpoint struct { // applied while sending packets. Defaults to 0 as on Linux. sendTOS uint8 + // receiveTOS determines if the incoming IPv4 TOS header field is passed + // as ancillary data to ControlMessages on Read. + receiveTOS bool + // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags @@ -244,7 +249,12 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess *addr = p.senderAddress } - return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: p.timestamp}, nil + return p.data.ToView(), tcpip.ControlMessages{ + HasTimestamp: true, + Timestamp: p.timestamp, + HasTOS: e.receiveTOS, + TOS: p.tos, + }, nil } // prepareForWrite prepares the endpoint for sending data. In particular, it @@ -656,6 +666,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.sendTOS = uint8(v) e.mu.Unlock() return nil + + case tcpip.ReceiveTOSOption: + e.mu.Lock() + e.receiveTOS = bool(v) + e.mu.Unlock() + return nil } return nil } @@ -792,6 +808,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.mu.RUnlock() return nil + case *tcpip.ReceiveTOSOption: + e.mu.RLock() + *o = tcpip.ReceiveTOSOption(e.receiveTOS) + e.mu.RUnlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -1238,6 +1260,13 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk e.rcvList.PushBack(packet) e.rcvBufSize += pkt.Data.Size() + // Save any useful information from the NetworkHeader to the packet. + switch r.NetProto { + case header.IPv4ProtocolNumber: + // This packet has already been validated before being passed up the stack. + packet.tos, _ = header.IPv4(pkt.NetworkHeader).TOS() + } + packet.timestamp = e.stack.NowNanoseconds() e.rcvMu.Unlock() diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 7051a7a9c..43b8b35ba 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -56,6 +56,7 @@ const ( multicastAddr = "\xe8\x2b\xd3\xea" multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" broadcastAddr = header.IPv4Broadcast + testTOS = 0x80 // defaultMTU is the MTU, in bytes, used throughout the tests, except // where another value is explicitly used. It is chosen to match the MTU @@ -453,6 +454,7 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool ip := header.IPv4(buf) ip.Encode(&header.IPv4Fields{ IHL: header.IPv4MinimumSize, + TOS: testTOS, TotalLength: uint16(len(buf)), TTL: 65, Protocol: uint8(udp.ProtocolNumber), @@ -556,8 +558,8 @@ func TestBindToDeviceOption(t *testing.T) { // testReadInternal sends a packet of the given test flow into the stack by // injecting it into the link endpoint. It then attempts to read it from the // UDP endpoint and depending on if this was expected to succeed verifies its -// correctness. -func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool) { +// correctness including any additional checker functions provided. +func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool, checkers ...checker.ControlMessagesChecker) { c.t.Helper() payload := newPayload() @@ -572,12 +574,12 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() var addr tcpip.FullAddress - v, _, err := c.ep.Read(&addr) + v, cm, err := c.ep.Read(&addr) if err == tcpip.ErrWouldBlock { // Wait for data to become available. select { case <-ch: - v, _, err = c.ep.Read(&addr) + v, cm, err = c.ep.Read(&addr) case <-time.After(300 * time.Millisecond): if packetShouldBeDropped { @@ -610,15 +612,21 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe if !bytes.Equal(payload, v) { c.t.Fatalf("bad payload: got %x, want %x", v, payload) } + + // Run any checkers against the ControlMessages. + for _, f := range checkers { + f(c.t, cm) + } + c.checkEndpointReadStats(1, epstats, err) } // testRead sends a packet of the given test flow into the stack by injecting it // into the link endpoint. It then reads it from the UDP endpoint and verifies -// its correctness. -func testRead(c *testContext, flow testFlow) { +// its correctness including any additional checker functions provided. +func testRead(c *testContext, flow testFlow, checkers ...checker.ControlMessagesChecker) { c.t.Helper() - testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */) + testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */, checkers...) } // testFailingRead sends a packet of the given test flow into the stack by @@ -1286,7 +1294,7 @@ func TestTOSV4(t *testing.T) { c.createEndpointForFlow(flow) - const tos = 0xC0 + const tos = testTOS var v tcpip.IPv4TOSOption if err := c.ep.GetSockOpt(&v); err != nil { c.t.Errorf("GetSockopt failed: %s", err) @@ -1321,7 +1329,7 @@ func TestTOSV6(t *testing.T) { c.createEndpointForFlow(flow) - const tos = 0xC0 + const tos = testTOS var v tcpip.IPv6TrafficClassOption if err := c.ep.GetSockOpt(&v); err != nil { c.t.Errorf("GetSockopt failed: %s", err) @@ -1348,6 +1356,49 @@ func TestTOSV6(t *testing.T) { } } +func TestReceiveTOSV4(t *testing.T) { + for _, flow := range []testFlow{unicastV4, broadcast} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Verify that setting and reading the option works. + const recvTos = true + var v tcpip.ReceiveTOSOption + if err := c.ep.GetSockOpt(&v); err != nil { + c.t.Errorf("GetSockopt failed: %s", err) + } + // Test for expected default value. + if v != false { + c.t.Errorf("got GetSockOpt(...) = %t, want = %t", v, false) + } + + if err := c.ep.SetSockOpt(tcpip.ReceiveTOSOption(recvTos)); err != nil { + c.t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.ReceiveTOSOption(recvTos), err) + } + + if err := c.ep.GetSockOpt(&v); err != nil { + c.t.Errorf("GetSockopt failed: %s", err) + } + + if want := tcpip.ReceiveTOSOption(recvTos); v != want { + c.t.Errorf("got GetSockOpt(...) = %t, want = %t", v, want) + } + + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + // Verify that the correct received TOS is actually handed through as + // ancillary data to the ControlMessages struct. + testRead(c, flow, checker.ReceiveTOS(testTOS)) + }) + } +} + func TestMulticastInterfaceOption(t *testing.T) { for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { diff --git a/test/syscalls/linux/socket_ip_udp_generic.cc b/test/syscalls/linux/socket_ip_udp_generic.cc index 66eb68857..53290bed7 100644 --- a/test/syscalls/linux/socket_ip_udp_generic.cc +++ b/test/syscalls/linux/socket_ip_udp_generic.cc @@ -209,6 +209,46 @@ TEST_P(UDPSocketPairTest, SetMulticastLoopChar) { EXPECT_EQ(get, kSockOptOn); } +// Ensure that Receiving TOS is off by default. +TEST_P(UDPSocketPairTest, RecvTosDefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOff); +} + +// Test that setting and getting IP_RECVTOS works as expected. +TEST_P(UDPSocketPairTest, SetRecvTos) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, + &kSockOptOff, sizeof(kSockOptOff)), + SyscallSucceeds()); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOff); + + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + + ASSERT_THAT( + getsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOn); +} + TEST_P(UDPSocketPairTest, ReuseAddrDefault) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); diff --git a/test/syscalls/linux/udp_socket_test_cases.cc b/test/syscalls/linux/udp_socket_test_cases.cc index dc35c2f50..68e0a8109 100644 --- a/test/syscalls/linux/udp_socket_test_cases.cc +++ b/test/syscalls/linux/udp_socket_test_cases.cc @@ -1349,8 +1349,9 @@ TEST_P(UdpSocketTest, TimestampIoctlPersistence) { // outgoing packets, and that a receiving socket with IP_RECVTOS or // IPV6_RECVTCLASS will create the corresponding control message. TEST_P(UdpSocketTest, SetAndReceiveTOS) { - // TODO(b/68320120): IP_RECVTOS/IPV6_RECVTCLASS not supported for netstack. - SKIP_IF(IsRunningOnGvisor() && !IsRunningWithHostinet()); + // TODO(b/68320120): IPV6_RECVTCLASS not supported for netstack. + SKIP_IF((GetParam() != AddressFamily::kIpv4) && IsRunningOnGvisor() && + !IsRunningWithHostinet()); ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); @@ -1421,7 +1422,8 @@ TEST_P(UdpSocketTest, SetAndReceiveTOS) { // TOS byte on outgoing packets, and that a receiving socket with IP_RECVTOS or // IPV6_RECVTCLASS will create the corresponding control message. TEST_P(UdpSocketTest, SendAndReceiveTOS) { - // TODO(b/68320120): IP_RECVTOS/IPV6_RECVTCLASS not supported for netstack. + // TODO(b/68320120): IPV6_RECVTCLASS not supported for netstack. + // TODO(b/146661005): Setting TOS via cmsg not supported for netstack. SKIP_IF(IsRunningOnGvisor() && !IsRunningWithHostinet()); ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); -- cgit v1.2.3 From 87e4d03fdf576348ac7023c599e0fc66ad4cccbd Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Thu, 26 Dec 2019 13:04:14 -0800 Subject: Automated rollback of changelist 287029703 PiperOrigin-RevId: 287217899 --- pkg/sentry/socket/control/control.go | 2 +- pkg/sentry/socket/netstack/netstack.go | 42 +---------------- pkg/tcpip/checker/checker.go | 16 ------- pkg/tcpip/stack/nic.go | 2 +- pkg/tcpip/stack/stack.go | 2 +- pkg/tcpip/tcpip.go | 6 +-- pkg/tcpip/transport/raw/endpoint.go | 2 +- pkg/tcpip/transport/udp/endpoint.go | 31 +------------ pkg/tcpip/transport/udp/udp_test.go | 69 ++++------------------------ test/syscalls/linux/socket_ip_udp_generic.cc | 40 ---------------- test/syscalls/linux/udp_socket_test_cases.cc | 8 ++-- 11 files changed, 19 insertions(+), 201 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index b649dd021..af1a4e95f 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -327,7 +327,7 @@ func PackInq(t *kernel.Task, inq int32, buf []byte) []byte { } // PackTOS packs an IP_TOS socket control message. -func PackTOS(t *kernel.Task, tos uint8, buf []byte) []byte { +func PackTOS(t *kernel.Task, tos int8, buf []byte) []byte { return putCmsgStruct( buf, linux.SOL_IP, diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index d2f263402..140851c17 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1323,21 +1323,6 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in } return int32(v), nil - case linux.IP_RECVTOS: - if outLen < sizeOfInt32 { - return nil, syserr.ErrInvalidArgument - } - - var v tcpip.ReceiveTOSOption - if err := ep.GetSockOpt(&v); err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - if v { - return int32(1), nil - } - return int32(0), nil - default: emitUnimplementedEventIP(t, name) } @@ -1823,16 +1808,6 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s } return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv4TOSOption(v))) - case linux.IP_RECVTOS: - v, err := parseIntOrChar(optVal) - if err != nil { - return err - } - - return syserr.TranslateNetstackError(ep.SetSockOpt( - tcpip.ReceiveTOSOption(v != 0), - )) - case linux.IP_ADD_SOURCE_MEMBERSHIP, linux.IP_BIND_ADDRESS_NO_PORT, linux.IP_BLOCK_SOURCE, @@ -1853,6 +1828,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s linux.IP_RECVFRAGSIZE, linux.IP_RECVOPTS, linux.IP_RECVORIGDSTADDR, + linux.IP_RECVTOS, linux.IP_RECVTTL, linux.IP_RETOPTS, linux.IP_TRANSPARENT, @@ -2163,21 +2139,6 @@ func (s *SocketOperations) fillCmsgInq(cmsg *socket.ControlMessages) { cmsg.IP.Inq = int32(len(s.readView) + rcvBufUsed) } -func (s *SocketOperations) fillCmsgTOS(cmsg *socket.ControlMessages) { - if s.skType != linux.SOCK_DGRAM { - return - } - var receiveTOS tcpip.ReceiveTOSOption - if err := s.Endpoint.GetSockOpt(&receiveTOS); err != nil { - return - } - if !receiveTOS { - return - } - cmsg.IP.HasTOS = s.readCM.HasTOS - cmsg.IP.TOS = s.readCM.TOS -} - // nonBlockingRead issues a non-blocking read. // // TODO(b/78348848): Support timestamps for stream sockets. @@ -2283,7 +2244,6 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe cmsg := s.controlMessages() s.fillCmsgInq(&cmsg) - s.fillCmsgTOS(&cmsg) return n, flags, addr, addrLen, cmsg, syserr.FromError(err) } diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 542abc99d..2f15bf1f1 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -33,9 +33,6 @@ type NetworkChecker func(*testing.T, []header.Network) // TransportChecker is a function to check a property of a transport packet. type TransportChecker func(*testing.T, header.Transport) -// ControlMessagesChecker is a function to check a property of ancillary data. -type ControlMessagesChecker func(*testing.T, tcpip.ControlMessages) - // IPv4 checks the validity and properties of the given IPv4 packet. It is // expected to be used in conjunction with other network checkers for specific // properties. For example, to check the source and destination address, one @@ -161,19 +158,6 @@ func FragmentFlags(flags uint8) NetworkChecker { } } -// ReceiveTOS creates a checker that checks the TOS field in ControlMessages. -func ReceiveTOS(want uint8) ControlMessagesChecker { - return func(t *testing.T, cm tcpip.ControlMessages) { - t.Helper() - if !cm.HasTOS { - t.Fatalf("got cm.HasTOS = %t, want cm.TOS = %d", cm.HasTOS, want) - } - if got := cm.TOS; got != want { - t.Fatalf("got cm.TOS = %d, want %d", got, want) - } - } -} - // TOS creates a checker that checks the TOS field. func TOS(tos uint8, label uint32) NetworkChecker { return func(t *testing.T, h []header.Network) { diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index a4556674b..ddd014658 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -575,7 +575,7 @@ func (n *NIC) RemoveAddressRange(subnet tcpip.Subnet) { n.mu.Unlock() } -// AddressRanges returns the Subnets associated with this NIC. +// Subnets returns the Subnets associated with this NIC. func (n *NIC) AddressRanges() []tcpip.Subnet { n.mu.RLock() defer n.mu.RUnlock() diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 251336224..7a9600679 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -829,7 +829,7 @@ func (s *Stack) CheckNIC(id tcpip.NICID) bool { return false } -// NICAddressRanges returns a map of NICIDs to their associated subnets. +// NICSubnets returns a map of NICIDs to their associated subnets. func (s *Stack) NICAddressRanges() map[tcpip.NICID][]tcpip.Subnet { s.mu.RLock() defer s.mu.RUnlock() diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 5c7b2af88..f62fd729f 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -322,7 +322,7 @@ type ControlMessages struct { HasTOS bool // TOS is the IPv4 type of service of the associated packet. - TOS uint8 + TOS int8 // HasTClass indicates whether Tclass is valid/set. HasTClass bool @@ -666,10 +666,6 @@ type IPv4TOSOption uint8 // for all subsequent outgoing IPv6 packets from the endpoint. type IPv6TrafficClassOption uint8 -// ReceiveTOSOption is used by SetSockOpt/GetSockOpt to specify if the TOS -// ancillary message is passed with incoming packets. -type ReceiveTOSOption bool - // Route is a row in the routing table. It specifies through which NIC (and // gateway) sets of packets should be routed. A row is considered viable if the // masked target address matches the destination address in the row. diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 6d23ab5a1..5aafe2615 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -510,7 +510,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. -func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { +func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 269470ed4..1ac4705af 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -32,7 +32,6 @@ type udpPacket struct { senderAddress tcpip.FullAddress data buffer.VectorisedView `state:".(buffer.VectorisedView)"` timestamp int64 - tos uint8 } // EndpointState represents the state of a UDP endpoint. @@ -115,10 +114,6 @@ type endpoint struct { // applied while sending packets. Defaults to 0 as on Linux. sendTOS uint8 - // receiveTOS determines if the incoming IPv4 TOS header field is passed - // as ancillary data to ControlMessages on Read. - receiveTOS bool - // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags @@ -249,12 +244,7 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess *addr = p.senderAddress } - return p.data.ToView(), tcpip.ControlMessages{ - HasTimestamp: true, - Timestamp: p.timestamp, - HasTOS: e.receiveTOS, - TOS: p.tos, - }, nil + return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: p.timestamp}, nil } // prepareForWrite prepares the endpoint for sending data. In particular, it @@ -666,12 +656,6 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.sendTOS = uint8(v) e.mu.Unlock() return nil - - case tcpip.ReceiveTOSOption: - e.mu.Lock() - e.receiveTOS = bool(v) - e.mu.Unlock() - return nil } return nil } @@ -808,12 +792,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.mu.RUnlock() return nil - case *tcpip.ReceiveTOSOption: - e.mu.RLock() - *o = tcpip.ReceiveTOSOption(e.receiveTOS) - e.mu.RUnlock() - return nil - default: return tcpip.ErrUnknownProtocolOption } @@ -1260,13 +1238,6 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk e.rcvList.PushBack(packet) e.rcvBufSize += pkt.Data.Size() - // Save any useful information from the NetworkHeader to the packet. - switch r.NetProto { - case header.IPv4ProtocolNumber: - // This packet has already been validated before being passed up the stack. - packet.tos, _ = header.IPv4(pkt.NetworkHeader).TOS() - } - packet.timestamp = e.stack.NowNanoseconds() e.rcvMu.Unlock() diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 43b8b35ba..7051a7a9c 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -56,7 +56,6 @@ const ( multicastAddr = "\xe8\x2b\xd3\xea" multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" broadcastAddr = header.IPv4Broadcast - testTOS = 0x80 // defaultMTU is the MTU, in bytes, used throughout the tests, except // where another value is explicitly used. It is chosen to match the MTU @@ -454,7 +453,6 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool ip := header.IPv4(buf) ip.Encode(&header.IPv4Fields{ IHL: header.IPv4MinimumSize, - TOS: testTOS, TotalLength: uint16(len(buf)), TTL: 65, Protocol: uint8(udp.ProtocolNumber), @@ -558,8 +556,8 @@ func TestBindToDeviceOption(t *testing.T) { // testReadInternal sends a packet of the given test flow into the stack by // injecting it into the link endpoint. It then attempts to read it from the // UDP endpoint and depending on if this was expected to succeed verifies its -// correctness including any additional checker functions provided. -func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool, checkers ...checker.ControlMessagesChecker) { +// correctness. +func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool) { c.t.Helper() payload := newPayload() @@ -574,12 +572,12 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() var addr tcpip.FullAddress - v, cm, err := c.ep.Read(&addr) + v, _, err := c.ep.Read(&addr) if err == tcpip.ErrWouldBlock { // Wait for data to become available. select { case <-ch: - v, cm, err = c.ep.Read(&addr) + v, _, err = c.ep.Read(&addr) case <-time.After(300 * time.Millisecond): if packetShouldBeDropped { @@ -612,21 +610,15 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe if !bytes.Equal(payload, v) { c.t.Fatalf("bad payload: got %x, want %x", v, payload) } - - // Run any checkers against the ControlMessages. - for _, f := range checkers { - f(c.t, cm) - } - c.checkEndpointReadStats(1, epstats, err) } // testRead sends a packet of the given test flow into the stack by injecting it // into the link endpoint. It then reads it from the UDP endpoint and verifies -// its correctness including any additional checker functions provided. -func testRead(c *testContext, flow testFlow, checkers ...checker.ControlMessagesChecker) { +// its correctness. +func testRead(c *testContext, flow testFlow) { c.t.Helper() - testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */, checkers...) + testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */) } // testFailingRead sends a packet of the given test flow into the stack by @@ -1294,7 +1286,7 @@ func TestTOSV4(t *testing.T) { c.createEndpointForFlow(flow) - const tos = testTOS + const tos = 0xC0 var v tcpip.IPv4TOSOption if err := c.ep.GetSockOpt(&v); err != nil { c.t.Errorf("GetSockopt failed: %s", err) @@ -1329,7 +1321,7 @@ func TestTOSV6(t *testing.T) { c.createEndpointForFlow(flow) - const tos = testTOS + const tos = 0xC0 var v tcpip.IPv6TrafficClassOption if err := c.ep.GetSockOpt(&v); err != nil { c.t.Errorf("GetSockopt failed: %s", err) @@ -1356,49 +1348,6 @@ func TestTOSV6(t *testing.T) { } } -func TestReceiveTOSV4(t *testing.T) { - for _, flow := range []testFlow{unicastV4, broadcast} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpointForFlow(flow) - - // Verify that setting and reading the option works. - const recvTos = true - var v tcpip.ReceiveTOSOption - if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt failed: %s", err) - } - // Test for expected default value. - if v != false { - c.t.Errorf("got GetSockOpt(...) = %t, want = %t", v, false) - } - - if err := c.ep.SetSockOpt(tcpip.ReceiveTOSOption(recvTos)); err != nil { - c.t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.ReceiveTOSOption(recvTos), err) - } - - if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt failed: %s", err) - } - - if want := tcpip.ReceiveTOSOption(recvTos); v != want { - c.t.Errorf("got GetSockOpt(...) = %t, want = %t", v, want) - } - - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - // Verify that the correct received TOS is actually handed through as - // ancillary data to the ControlMessages struct. - testRead(c, flow, checker.ReceiveTOS(testTOS)) - }) - } -} - func TestMulticastInterfaceOption(t *testing.T) { for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { diff --git a/test/syscalls/linux/socket_ip_udp_generic.cc b/test/syscalls/linux/socket_ip_udp_generic.cc index 53290bed7..66eb68857 100644 --- a/test/syscalls/linux/socket_ip_udp_generic.cc +++ b/test/syscalls/linux/socket_ip_udp_generic.cc @@ -209,46 +209,6 @@ TEST_P(UDPSocketPairTest, SetMulticastLoopChar) { EXPECT_EQ(get, kSockOptOn); } -// Ensure that Receiving TOS is off by default. -TEST_P(UDPSocketPairTest, RecvTosDefault) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - int get = -1; - socklen_t get_len = sizeof(get); - ASSERT_THAT( - getsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOff); -} - -// Test that setting and getting IP_RECVTOS works as expected. -TEST_P(UDPSocketPairTest, SetRecvTos) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, - &kSockOptOff, sizeof(kSockOptOff)), - SyscallSucceeds()); - - int get = -1; - socklen_t get_len = sizeof(get); - ASSERT_THAT( - getsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOff); - - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, - &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - - ASSERT_THAT( - getsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOn); -} - TEST_P(UDPSocketPairTest, ReuseAddrDefault) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); diff --git a/test/syscalls/linux/udp_socket_test_cases.cc b/test/syscalls/linux/udp_socket_test_cases.cc index 68e0a8109..dc35c2f50 100644 --- a/test/syscalls/linux/udp_socket_test_cases.cc +++ b/test/syscalls/linux/udp_socket_test_cases.cc @@ -1349,9 +1349,8 @@ TEST_P(UdpSocketTest, TimestampIoctlPersistence) { // outgoing packets, and that a receiving socket with IP_RECVTOS or // IPV6_RECVTCLASS will create the corresponding control message. TEST_P(UdpSocketTest, SetAndReceiveTOS) { - // TODO(b/68320120): IPV6_RECVTCLASS not supported for netstack. - SKIP_IF((GetParam() != AddressFamily::kIpv4) && IsRunningOnGvisor() && - !IsRunningWithHostinet()); + // TODO(b/68320120): IP_RECVTOS/IPV6_RECVTCLASS not supported for netstack. + SKIP_IF(IsRunningOnGvisor() && !IsRunningWithHostinet()); ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); @@ -1422,8 +1421,7 @@ TEST_P(UdpSocketTest, SetAndReceiveTOS) { // TOS byte on outgoing packets, and that a receiving socket with IP_RECVTOS or // IPV6_RECVTCLASS will create the corresponding control message. TEST_P(UdpSocketTest, SendAndReceiveTOS) { - // TODO(b/68320120): IPV6_RECVTCLASS not supported for netstack. - // TODO(b/146661005): Setting TOS via cmsg not supported for netstack. + // TODO(b/68320120): IP_RECVTOS/IPV6_RECVTCLASS not supported for netstack. SKIP_IF(IsRunningOnGvisor() && !IsRunningWithHostinet()); ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); -- cgit v1.2.3 From d1d878a801e066d6a54838ac3b2cdb43d65743e1 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Fri, 3 Jan 2020 12:58:40 -0800 Subject: Support generating opaque interface identifiers as defined by RFC 7217 Support generating opaque interface identifiers as defined by RFC 7217 for auto-generated IPv6 link-local addresses. Opaque interface identifiers will also be used for IPv6 addresses auto-generated via SLAAC in a later change. Note, this change does not handle retries in response to DAD conflicts yet. That will also come in a later change. Tests: Test that when configured to generated opaque IIDs, they are properly generated as outlined by RFC 7217. PiperOrigin-RevId: 288035349 --- pkg/tcpip/header/BUILD | 1 + pkg/tcpip/header/ipv6.go | 45 ++++++++++++ pkg/tcpip/header/ipv6_test.go | 163 ++++++++++++++++++++++++++++++++++++++++++ pkg/tcpip/stack/BUILD | 1 + pkg/tcpip/stack/nic.go | 26 ++++--- pkg/tcpip/stack/stack.go | 36 ++++++++++ pkg/tcpip/stack/stack_test.go | 127 ++++++++++++++++++++++++++++++-- 7 files changed, 384 insertions(+), 15 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD index f1d837196..f2061c778 100644 --- a/pkg/tcpip/header/BUILD +++ b/pkg/tcpip/header/BUILD @@ -44,6 +44,7 @@ go_test( ], deps = [ ":header", + "//pkg/rand", "//pkg/tcpip", "//pkg/tcpip/buffer", "@com_github_google_go-cmp//cmp:go_default_library", diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index fc671e439..135a60b12 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -15,6 +15,7 @@ package header import ( + "crypto/sha256" "encoding/binary" "strings" @@ -102,6 +103,11 @@ const ( // bytes including and after the IIDOffsetInIPv6Address-th byte are // for the IID. IIDOffsetInIPv6Address = 8 + + // OpaqueIIDSecretKeyMinBytes is the recommended minimum number of bytes + // for the secret key used to generate an opaque interface identifier as + // outlined by RFC 7217. + OpaqueIIDSecretKeyMinBytes = 16 ) // IPv6EmptySubnet is the empty IPv6 subnet. It may also be known as the @@ -326,3 +332,42 @@ func IsV6LinkLocalAddress(addr tcpip.Address) bool { } return addr[0] == 0xfe && (addr[1]&0xc0) == 0x80 } + +// AppendOpaqueInterfaceIdentifier appends a 64 bit opaque interface identifier +// (IID) to buf as outlined by RFC 7217 and returns the extended buffer. +// +// The opaque IID is generated from the cryptographic hash of the concatenation +// of the prefix, NIC's name, DAD counter (DAD retry counter) and the secret +// key. The secret key SHOULD be at least OpaqueIIDSecretKeyMinBytes bytes and +// MUST be generated to a pseudo-random number. See RFC 4086 for randomness +// requirements for security. +// +// If buf has enough capacity for the IID (IIDSize bytes), a new underlying +// array for the buffer will not be allocated. +func AppendOpaqueInterfaceIdentifier(buf []byte, prefix tcpip.Subnet, nicName string, dadCounter uint8, secretKey []byte) []byte { + // As per RFC 7217 section 5, the opaque identifier can be generated as a + // cryptographic hash of the concatenation of each of the function parameters. + // Note, we omit the optional Network_ID field. + h := sha256.New() + // h.Write never returns an error. + h.Write([]byte(prefix.ID()[:IIDOffsetInIPv6Address])) + h.Write([]byte(nicName)) + h.Write([]byte{dadCounter}) + h.Write(secretKey) + + var sumBuf [sha256.Size]byte + sum := h.Sum(sumBuf[:0]) + + return append(buf, sum[:IIDSize]...) +} + +// LinkLocalAddrWithOpaqueIID computes the default IPv6 link-local address with +// an opaque IID. +func LinkLocalAddrWithOpaqueIID(nicName string, dadCounter uint8, secretKey []byte) tcpip.Address { + lladdrb := [IPv6AddressSize]byte{ + 0: 0xFE, + 1: 0x80, + } + + return tcpip.Address(AppendOpaqueInterfaceIdentifier(lladdrb[:IIDOffsetInIPv6Address], IPv6LinkLocalPrefix.Subnet(), nicName, dadCounter, secretKey)) +} diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go index 42c5c6fc1..cd1862e42 100644 --- a/pkg/tcpip/header/ipv6_test.go +++ b/pkg/tcpip/header/ipv6_test.go @@ -15,9 +15,12 @@ package header_test import ( + "bytes" + "crypto/sha256" "testing" "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -43,3 +46,163 @@ func TestLinkLocalAddr(t *testing.T) { t.Errorf("got LinkLocalAddr(%s) = %s, want = %s", linkAddr, got, want) } } + +func TestAppendOpaqueInterfaceIdentifier(t *testing.T) { + var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes * 2]byte + if n, err := rand.Read(secretKeyBuf[:]); err != nil { + t.Fatalf("rand.Read(_): %s", err) + } else if want := header.OpaqueIIDSecretKeyMinBytes * 2; n != want { + t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", want, n) + } + + tests := []struct { + name string + prefix tcpip.Subnet + nicName string + dadCounter uint8 + secretKey []byte + }{ + { + name: "SecretKey of minimum size", + prefix: header.IPv6LinkLocalPrefix.Subnet(), + nicName: "eth0", + dadCounter: 0, + secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes], + }, + { + name: "SecretKey of less than minimum size", + prefix: func() tcpip.Subnet { + addrWithPrefix := tcpip.AddressWithPrefix{ + Address: "\x01\x02\x03\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + PrefixLen: header.IIDOffsetInIPv6Address * 8, + } + return addrWithPrefix.Subnet() + }(), + nicName: "eth10", + dadCounter: 1, + secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes/2], + }, + { + name: "SecretKey of more than minimum size", + prefix: func() tcpip.Subnet { + addrWithPrefix := tcpip.AddressWithPrefix{ + Address: "\x01\x02\x03\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + PrefixLen: header.IIDOffsetInIPv6Address * 8, + } + return addrWithPrefix.Subnet() + }(), + nicName: "eth11", + dadCounter: 2, + secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes*2], + }, + { + name: "Nil SecretKey", + prefix: func() tcpip.Subnet { + addrWithPrefix := tcpip.AddressWithPrefix{ + Address: "\x01\x02\x03\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + PrefixLen: header.IIDOffsetInIPv6Address * 8, + } + return addrWithPrefix.Subnet() + }(), + nicName: "eth12", + dadCounter: 3, + secretKey: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + h := sha256.New() + h.Write([]byte(test.prefix.ID()[:header.IIDOffsetInIPv6Address])) + h.Write([]byte(test.nicName)) + h.Write([]byte{test.dadCounter}) + if k := test.secretKey; k != nil { + h.Write(k) + } + var hashSum [sha256.Size]byte + h.Sum(hashSum[:0]) + want := hashSum[:header.IIDSize] + + // Passing a nil buffer should result in a new buffer returned with the + // IID. + if got := header.AppendOpaqueInterfaceIdentifier(nil, test.prefix, test.nicName, test.dadCounter, test.secretKey); !bytes.Equal(got, want) { + t.Errorf("got AppendOpaqueInterfaceIdentifier(nil, %s, %s, %d, %x) = %x, want = %x", test.prefix, test.nicName, test.dadCounter, test.secretKey, got, want) + } + + // Passing a buffer with sufficient capacity for the IID should populate + // the buffer provided. + var iidBuf [header.IIDSize]byte + if got := header.AppendOpaqueInterfaceIdentifier(iidBuf[:0], test.prefix, test.nicName, test.dadCounter, test.secretKey); !bytes.Equal(got, want) { + t.Errorf("got AppendOpaqueInterfaceIdentifier(iidBuf[:0], %s, %s, %d, %x) = %x, want = %x", test.prefix, test.nicName, test.dadCounter, test.secretKey, got, want) + } + if got := iidBuf[:]; !bytes.Equal(got, want) { + t.Errorf("got iidBuf = %x, want = %x", got, want) + } + }) + } +} + +func TestLinkLocalAddrWithOpaqueIID(t *testing.T) { + var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes * 2]byte + if n, err := rand.Read(secretKeyBuf[:]); err != nil { + t.Fatalf("rand.Read(_): %s", err) + } else if want := header.OpaqueIIDSecretKeyMinBytes * 2; n != want { + t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", want, n) + } + + prefix := header.IPv6LinkLocalPrefix.Subnet() + + tests := []struct { + name string + prefix tcpip.Subnet + nicName string + dadCounter uint8 + secretKey []byte + }{ + { + name: "SecretKey of minimum size", + nicName: "eth0", + dadCounter: 0, + secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes], + }, + { + name: "SecretKey of less than minimum size", + nicName: "eth10", + dadCounter: 1, + secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes/2], + }, + { + name: "SecretKey of more than minimum size", + nicName: "eth11", + dadCounter: 2, + secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes*2], + }, + { + name: "Nil SecretKey", + nicName: "eth12", + dadCounter: 3, + secretKey: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + addrBytes := [header.IPv6AddressSize]byte{ + 0: 0xFE, + 1: 0x80, + } + + want := tcpip.Address(header.AppendOpaqueInterfaceIdentifier( + addrBytes[:header.IIDOffsetInIPv6Address], + prefix, + test.nicName, + test.dadCounter, + test.secretKey, + )) + + if got := header.LinkLocalAddrWithOpaqueIID(test.nicName, test.dadCounter, test.secretKey); got != want { + t.Errorf("got LinkLocalAddrWithOpaqueIID(%s, %d, %x) = %s, want = %s", test.nicName, test.dadCounter, test.secretKey, got, want) + } + }) + } +} diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 69077669a..b8f9517d0 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -59,6 +59,7 @@ go_test( ], deps = [ ":stack", + "//pkg/rand", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index ddd014658..3bed0af3c 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -178,20 +178,24 @@ func (n *NIC) enable() *tcpip.Error { return nil } - l2addr := n.linkEP.LinkAddress() + var addr tcpip.Address + if oIID := n.stack.opaqueIIDOpts; oIID.NICNameFromID != nil { + addr = header.LinkLocalAddrWithOpaqueIID(oIID.NICNameFromID(n.ID()), 0, oIID.SecretKey) + } else { + l2addr := n.linkEP.LinkAddress() + + // Only attempt to generate the link-local address if we have a valid MAC + // address. + // + // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by + // LinkEndpoint.LinkAddress) before reaching this point. + if !header.IsValidUnicastEthernetAddress(l2addr) { + return nil + } - // Only attempt to generate the link-local address if we have a - // valid MAC address. - // - // TODO(b/141011931): Validate a LinkEndpoint's link address - // (provided by LinkEndpoint.LinkAddress) before reaching this - // point. - if !header.IsValidUnicastEthernetAddress(l2addr) { - return nil + addr = header.LinkLocalAddr(l2addr) } - addr := header.LinkLocalAddr(l2addr) - _, err := n.addPermanentAddressLocked(tcpip.ProtocolAddress{ Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 7a9600679..c6e6becf3 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -352,6 +352,33 @@ func (u *uniqueIDGenerator) UniqueID() uint64 { return atomic.AddUint64((*uint64)(u), 1) } +// NICNameFromID is a function that returns a stable name for the specified NIC, +// even if the NIC ID changes over time. +type NICNameFromID func(tcpip.NICID) string + +// OpaqueInterfaceIdentifierOptions holds the options related to the generation +// of opaque interface indentifiers (IIDs) as defined by RFC 7217. +type OpaqueInterfaceIdentifierOptions struct { + // NICNameFromID is a function that returns a stable name for a specified NIC, + // even if the NIC ID changes over time. + // + // Must be specified to generate the opaque IID. + NICNameFromID NICNameFromID + + // SecretKey is a pseudo-random number used as the secret key when generating + // opaque IIDs as defined by RFC 7217. The key SHOULD be at least + // header.OpaqueIIDSecretKeyMinBytes bytes and MUST follow minimum randomness + // requirements for security as outlined by RFC 4086. SecretKey MUST NOT + // change between program runs, unless explicitly changed. + // + // OpaqueInterfaceIdentifierOptions takes ownership of SecretKey. SecretKey + // MUST NOT be modified after Stack is created. + // + // May be nil, but a nil value is highly discouraged to maintain + // some level of randomness between nodes. + SecretKey []byte +} + // Stack is a networking stack, with all supported protocols, NICs, and route // table. type Stack struct { @@ -422,6 +449,10 @@ type Stack struct { // uniqueIDGenerator is a generator of unique identifiers. uniqueIDGenerator UniqueID + + // opaqueIIDOpts hold the options for generating opaque interface identifiers + // (IIDs) as outlined by RFC 7217. + opaqueIIDOpts OpaqueInterfaceIdentifierOptions } // UniqueID is an abstract generator of unique identifiers. @@ -479,6 +510,10 @@ type Options struct { // RawFactory produces raw endpoints. Raw endpoints are enabled only if // this is non-nil. RawFactory RawFactory + + // OpaqueIIDOpts hold the options for generating opaque interface identifiers + // (IIDs) as outlined by RFC 7217. + OpaqueIIDOpts OpaqueInterfaceIdentifierOptions } // TransportEndpointInfo holds useful information about a transport endpoint @@ -549,6 +584,7 @@ func New(opts Options) *Stack { autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal, uniqueIDGenerator: opts.UniqueID, ndpDisp: opts.NDPDisp, + opaqueIIDOpts: opts.OpaqueIIDOpts, } // Add specified network protocols. diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 8fc034ca1..e18dfea83 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -27,6 +27,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -1894,55 +1895,67 @@ func TestNICForwarding(t *testing.T) { } // TestNICAutoGenAddr tests the auto-generation of IPv6 link-local addresses -// (or lack there-of if disabled (default)). Note, DAD will be disabled in -// these tests. +// using the modified EUI-64 of the NIC's MAC address (or lack there-of if +// disabled (default)). Note, DAD will be disabled in these tests. func TestNICAutoGenAddr(t *testing.T) { tests := []struct { name string autoGen bool linkAddr tcpip.LinkAddress + iidOpts stack.OpaqueInterfaceIdentifierOptions shouldGen bool }{ { "Disabled", false, linkAddr1, + stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(nicID tcpip.NICID) string { + return fmt.Sprintf("nic%d", nicID) + }, + }, false, }, { "Enabled", true, linkAddr1, + stack.OpaqueInterfaceIdentifierOptions{}, true, }, { "Nil MAC", true, tcpip.LinkAddress([]byte(nil)), + stack.OpaqueInterfaceIdentifierOptions{}, false, }, { "Empty MAC", true, tcpip.LinkAddress(""), + stack.OpaqueInterfaceIdentifierOptions{}, false, }, { "Invalid MAC", true, tcpip.LinkAddress("\x01\x02\x03"), + stack.OpaqueInterfaceIdentifierOptions{}, false, }, { "Multicast MAC", true, tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), + stack.OpaqueInterfaceIdentifierOptions{}, false, }, { "Unspecified MAC", true, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"), + stack.OpaqueInterfaceIdentifierOptions{}, false, }, } @@ -1951,6 +1964,112 @@ func TestNICAutoGenAddr(t *testing.T) { t.Run(test.name, func(t *testing.T) { opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + OpaqueIIDOpts: test.iidOpts, + } + + if test.autoGen { + // Only set opts.AutoGenIPv6LinkLocal when test.autoGen is true because + // opts.AutoGenIPv6LinkLocal should be false by default. + opts.AutoGenIPv6LinkLocal = true + } + + e := channel.New(10, 1280, test.linkAddr) + s := stack.New(opts) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) + } + + if test.shouldGen { + // Should have auto-generated an address and resolved immediately (DAD + // is disabled). + if want := (tcpip.AddressWithPrefix{Address: header.LinkLocalAddr(test.linkAddr), PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want) + } + } else { + // Should not have auto-generated an address. + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + } + } + }) + } +} + +// TestNICAutoGenAddrWithOpaque tests the auto-generation of IPv6 link-local +// addresses with opaque interface identifiers. Link Local addresses should +// always be generated with opaque IIDs if configured to use them, even if the +// NIC has an invalid MAC address. +func TestNICAutoGenAddrWithOpaque(t *testing.T) { + var secretKey [header.OpaqueIIDSecretKeyMinBytes]byte + n, err := rand.Read(secretKey[:]) + if err != nil { + t.Fatalf("rand.Read(_): %s", err) + } + if n != header.OpaqueIIDSecretKeyMinBytes { + t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", header.OpaqueIIDSecretKeyMinBytes, n) + } + + iidOpts := stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(nicID tcpip.NICID) string { + return fmt.Sprintf("nic%d", nicID) + }, + SecretKey: secretKey[:], + } + + tests := []struct { + name string + autoGen bool + linkAddr tcpip.LinkAddress + }{ + { + "Disabled", + false, + linkAddr1, + }, + { + "Enabled", + true, + linkAddr1, + }, + // These are all cases where we would not have generated a + // link-local address if opaque IIDs were disabled. + { + "Nil MAC", + true, + tcpip.LinkAddress([]byte(nil)), + }, + { + "Empty MAC", + true, + tcpip.LinkAddress(""), + }, + { + "Invalid MAC", + true, + tcpip.LinkAddress("\x01\x02\x03"), + }, + { + "Multicast MAC", + true, + tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), + }, + { + "Unspecified MAC", + true, + tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + OpaqueIIDOpts: iidOpts, } if test.autoGen { @@ -1972,10 +2091,10 @@ func TestNICAutoGenAddr(t *testing.T) { t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) } - if test.shouldGen { + if test.autoGen { // Should have auto-generated an address and // resolved immediately (DAD is disabled). - if want := (tcpip.AddressWithPrefix{Address: header.LinkLocalAddr(test.linkAddr), PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { + if want := (tcpip.AddressWithPrefix{Address: header.LinkLocalAddrWithOpaqueIID("nic1", 0, secretKey[:]), PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want) } } else { -- cgit v1.2.3 From 83ab47e87badd8b46f784739903361d9f824fa2c Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Fri, 3 Jan 2020 18:27:04 -0800 Subject: Use opaque interface identifiers when generating IPv6 addresses via SLAAC Support using opaque interface identifiers when generating IPv6 addresses via SLAAC when configured to do so. Note, this change does not handle retries in response to DAD conflicts yet. That will also come in a later change. Test: Test that when SLAAC addresses are generated, they use opaque interface identifiers when configured to do so. PiperOrigin-RevId: 288078605 --- pkg/tcpip/stack/ndp.go | 32 +++++++------- pkg/tcpip/stack/ndp_test.go | 104 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 121 insertions(+), 15 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index d9ab59336..ba6a57e6f 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -1028,22 +1028,24 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform return } - // Only attempt to generate an interface-specific IID if we have a valid - // link address. - // - // TODO(b/141011931): Validate a LinkEndpoint's link address - // (provided by LinkEndpoint.LinkAddress) before reaching this - // point. - linkAddr := ndp.nic.linkEP.LinkAddress() - if !header.IsValidUnicastEthernetAddress(linkAddr) { - return - } + addrBytes := []byte(prefix.ID()) + if oIID := ndp.nic.stack.opaqueIIDOpts; oIID.NICNameFromID != nil { + addrBytes = header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], prefix, oIID.NICNameFromID(ndp.nic.ID()), 0 /* dadCounter */, oIID.SecretKey) + } else { + // Only attempt to generate an interface-specific IID if we have a valid + // link address. + // + // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by + // LinkEndpoint.LinkAddress) before reaching this point. + linkAddr := ndp.nic.linkEP.LinkAddress() + if !header.IsValidUnicastEthernetAddress(linkAddr) { + return + } - // Generate an address within prefix from the modified EUI-64 of ndp's - // NIC's Ethernet MAC address. - addrBytes := make([]byte, header.IPv6AddressSize) - copy(addrBytes[:header.IIDOffsetInIPv6Address], prefix.ID()[:header.IIDOffsetInIPv6Address]) - header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:]) + // Generate an address within prefix from the modified EUI-64 of ndp's NIC's + // Ethernet MAC address. + header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:]) + } addr := tcpip.Address(addrBytes) addrWithPrefix := tcpip.AddressWithPrefix{ Address: addr, diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 64a9a2b20..8e817e730 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -21,6 +21,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" @@ -1911,6 +1912,109 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { } } +// TestAutoGenAddrWithOpaqueIID tests that SLAAC generated addresses will use +// opaque interface identifiers when configured to do so. +func TestAutoGenAddrWithOpaqueIID(t *testing.T) { + t.Parallel() + + const nicID = 1 + var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte + secretKey := secretKeyBuf[:] + n, err := rand.Read(secretKey) + if err != nil { + t.Fatalf("rand.Read(_): %s", err) + } + if n != header.OpaqueIIDSecretKeyMinBytes { + t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes) + } + + prefix1, subnet1, _ := prefixSubnetAddr(0, linkAddr1) + prefix2, subnet2, _ := prefixSubnetAddr(1, linkAddr1) + // addr1 and addr2 are the addresses that are expected to be generated when + // stack.Stack is configured to generate opaque interface identifiers as + // defined by RFC 7217. + addrBytes := []byte(subnet1.ID()) + addr1 := tcpip.AddressWithPrefix{ + Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet1, "nic1", 0, secretKey)), + PrefixLen: 64, + } + addrBytes = []byte(subnet2.ID()) + addr2 := tcpip.AddressWithPrefix{ + Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet2, "nic1", 0, secretKey)), + PrefixLen: 64, + } + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(nicID tcpip.NICID) string { + return fmt.Sprintf("nic%d", nicID) + }, + SecretKey: secretKey, + }, + }) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + // Receive an RA with prefix1 in a PI. + const validLifetimeSecondPrefix1 = 1 + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, validLifetimeSecondPrefix1, 0)) + expectAutoGenAddrEvent(addr1, newAddr) + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should have %s in the list of addresses", addr1) + } + + // Receive an RA with prefix2 in a PI with a large valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) + expectAutoGenAddrEvent(addr2, newAddr) + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should have %s in the list of addresses", addr1) + } + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + + // Wait for addr of prefix1 to be invalidated. + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } +} + // TestNDPRecursiveDNSServerDispatch tests that we properly dispatch an event // to the integrator when an RA is received with the NDP Recursive DNS Server // option with at least one valid address. -- cgit v1.2.3 From 8dfd92284016f7c719b5766506cf3d6ab9c39c0e Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Mon, 6 Jan 2020 16:04:19 -0800 Subject: Pass the NIC-internal name to the NIC name function when generating opaque IIDs Pass the NIC-internal name to the NIC name function when generating opaque IIDs so implementations can use the name that was provided when the NIC was created. Previously, explicit NICID to NIC name resolution was required from the netstack integrator. Tests: Test that the name provided when creating a NIC is passed to the NIC name function when generating opaque IIDs. PiperOrigin-RevId: 288395359 --- pkg/tcpip/header/ipv6_test.go | 8 ++-- pkg/tcpip/stack/ndp.go | 2 +- pkg/tcpip/stack/ndp_test.go | 13 ++++--- pkg/tcpip/stack/nic.go | 2 +- pkg/tcpip/stack/stack.go | 9 ++++- pkg/tcpip/stack/stack_test.go | 90 +++++++++++++++++++++++++------------------ 6 files changed, 72 insertions(+), 52 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go index cd1862e42..1994003ed 100644 --- a/pkg/tcpip/header/ipv6_test.go +++ b/pkg/tcpip/header/ipv6_test.go @@ -96,7 +96,7 @@ func TestAppendOpaqueInterfaceIdentifier(t *testing.T) { secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes*2], }, { - name: "Nil SecretKey", + name: "Nil SecretKey and empty nicName", prefix: func() tcpip.Subnet { addrWithPrefix := tcpip.AddressWithPrefix{ Address: "\x01\x02\x03\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", @@ -104,7 +104,7 @@ func TestAppendOpaqueInterfaceIdentifier(t *testing.T) { } return addrWithPrefix.Subnet() }(), - nicName: "eth12", + nicName: "", dadCounter: 3, secretKey: nil, }, @@ -178,8 +178,8 @@ func TestLinkLocalAddrWithOpaqueIID(t *testing.T) { secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes*2], }, { - name: "Nil SecretKey", - nicName: "eth12", + name: "Nil SecretKey and empty nicName", + nicName: "", dadCounter: 3, secretKey: nil, }, diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index ba6a57e6f..238bc27dc 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -1030,7 +1030,7 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform addrBytes := []byte(prefix.ID()) if oIID := ndp.nic.stack.opaqueIIDOpts; oIID.NICNameFromID != nil { - addrBytes = header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], prefix, oIID.NICNameFromID(ndp.nic.ID()), 0 /* dadCounter */, oIID.SecretKey) + addrBytes = header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], prefix, oIID.NICNameFromID(ndp.nic.ID(), ndp.nic.name), 0 /* dadCounter */, oIID.SecretKey) } else { // Only attempt to generate an interface-specific IID if we have a valid // link address. diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 8e817e730..9430844d3 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -1918,6 +1918,7 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { t.Parallel() const nicID = 1 + const nicName = "nic1" var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte secretKey := secretKeyBuf[:] n, err := rand.Read(secretKey) @@ -1935,12 +1936,12 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { // defined by RFC 7217. addrBytes := []byte(subnet1.ID()) addr1 := tcpip.AddressWithPrefix{ - Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet1, "nic1", 0, secretKey)), + Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet1, nicName, 0, secretKey)), PrefixLen: 64, } addrBytes = []byte(subnet2.ID()) addr2 := tcpip.AddressWithPrefix{ - Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet2, "nic1", 0, secretKey)), + Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet2, nicName, 0, secretKey)), PrefixLen: 64, } @@ -1956,15 +1957,15 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { }, NDPDisp: &ndpDisp, OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: func(nicID tcpip.NICID) string { - return fmt.Sprintf("nic%d", nicID) + NICNameFromID: func(_ tcpip.NICID, nicName string) string { + return nicName }, SecretKey: secretKey, }, }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + if err := s.CreateNamedNIC(nicID, nicName, e); err != nil { + t.Fatalf("CreateNamedNIC(%d, %q, _) = %s", nicID, nicName, err) } expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 3bed0af3c..044fe5298 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -180,7 +180,7 @@ func (n *NIC) enable() *tcpip.Error { var addr tcpip.Address if oIID := n.stack.opaqueIIDOpts; oIID.NICNameFromID != nil { - addr = header.LinkLocalAddrWithOpaqueIID(oIID.NICNameFromID(n.ID()), 0, oIID.SecretKey) + addr = header.LinkLocalAddrWithOpaqueIID(oIID.NICNameFromID(n.ID(), n.name), 0, oIID.SecretKey) } else { l2addr := n.linkEP.LinkAddress() diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index c6e6becf3..ffb379363 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -353,8 +353,13 @@ func (u *uniqueIDGenerator) UniqueID() uint64 { } // NICNameFromID is a function that returns a stable name for the specified NIC, -// even if the NIC ID changes over time. -type NICNameFromID func(tcpip.NICID) string +// even if different NIC IDs are used to refer to the same NIC in different +// program runs. It is used when generating opaque interface identifiers (IIDs). +// If the NIC was created with a name, it will be passed to NICNameFromID. +// +// NICNameFromID SHOULD return unique NIC names so unique opaque IIDs are +// generated for the same prefix on differnt NICs. +type NICNameFromID func(tcpip.NICID, string) string // OpaqueInterfaceIdentifierOptions holds the options related to the generation // of opaque interface indentifiers (IIDs) as defined by RFC 7217. diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index e18dfea83..f533949c0 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -1910,7 +1910,7 @@ func TestNICAutoGenAddr(t *testing.T) { false, linkAddr1, stack.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: func(nicID tcpip.NICID) string { + NICNameFromID: func(nicID tcpip.NICID, _ string) string { return fmt.Sprintf("nic%d", nicID) }, }, @@ -2005,6 +2005,8 @@ func TestNICAutoGenAddr(t *testing.T) { // always be generated with opaque IIDs if configured to use them, even if the // NIC has an invalid MAC address. func TestNICAutoGenAddrWithOpaque(t *testing.T) { + const nicID = 1 + var secretKey [header.OpaqueIIDSecretKeyMinBytes]byte n, err := rand.Read(secretKey[:]) if err != nil { @@ -2014,54 +2016,61 @@ func TestNICAutoGenAddrWithOpaque(t *testing.T) { t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", header.OpaqueIIDSecretKeyMinBytes, n) } - iidOpts := stack.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: func(nicID tcpip.NICID) string { - return fmt.Sprintf("nic%d", nicID) - }, - SecretKey: secretKey[:], - } - tests := []struct { - name string - autoGen bool - linkAddr tcpip.LinkAddress + name string + nicName string + autoGen bool + linkAddr tcpip.LinkAddress + secretKey []byte }{ { - "Disabled", - false, - linkAddr1, + name: "Disabled", + nicName: "nic1", + autoGen: false, + linkAddr: linkAddr1, + secretKey: secretKey[:], }, { - "Enabled", - true, - linkAddr1, + name: "Enabled", + nicName: "nic1", + autoGen: true, + linkAddr: linkAddr1, + secretKey: secretKey[:], }, // These are all cases where we would not have generated a // link-local address if opaque IIDs were disabled. { - "Nil MAC", - true, - tcpip.LinkAddress([]byte(nil)), + name: "Nil MAC and empty nicName", + nicName: "", + autoGen: true, + linkAddr: tcpip.LinkAddress([]byte(nil)), + secretKey: secretKey[:1], }, { - "Empty MAC", - true, - tcpip.LinkAddress(""), + name: "Empty MAC and empty nicName", + autoGen: true, + linkAddr: tcpip.LinkAddress(""), + secretKey: secretKey[:2], }, { - "Invalid MAC", - true, - tcpip.LinkAddress("\x01\x02\x03"), + name: "Invalid MAC", + nicName: "test", + autoGen: true, + linkAddr: tcpip.LinkAddress("\x01\x02\x03"), + secretKey: secretKey[:3], }, { - "Multicast MAC", - true, - tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), + name: "Multicast MAC", + nicName: "test2", + autoGen: true, + linkAddr: tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), + secretKey: secretKey[:4], }, { - "Unspecified MAC", - true, - tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"), + name: "Unspecified MAC and nil SecretKey", + nicName: "test3", + autoGen: true, + linkAddr: tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"), }, } @@ -2069,7 +2078,12 @@ func TestNICAutoGenAddrWithOpaque(t *testing.T) { t.Run(test.name, func(t *testing.T) { opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - OpaqueIIDOpts: iidOpts, + OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(_ tcpip.NICID, nicName string) string { + return nicName + }, + SecretKey: test.secretKey, + }, } if test.autoGen { @@ -2082,19 +2096,19 @@ func TestNICAutoGenAddrWithOpaque(t *testing.T) { e := channel.New(10, 1280, test.linkAddr) s := stack.New(opts) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) + if err := s.CreateNamedNIC(nicID, test.nicName, e); err != nil { + t.Fatalf("CreateNamedNIC(%d, %q, _) = %s", nicID, test.nicName, err) } - addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) if err != nil { - t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) + t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err) } if test.autoGen { // Should have auto-generated an address and // resolved immediately (DAD is disabled). - if want := (tcpip.AddressWithPrefix{Address: header.LinkLocalAddrWithOpaqueIID("nic1", 0, secretKey[:]), PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { + if want := (tcpip.AddressWithPrefix{Address: header.LinkLocalAddrWithOpaqueIID(test.nicName, 0, test.secretKey), PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want) } } else { -- cgit v1.2.3 From 2031cc4701d5bfd21b34d7b0f7dc86920a553385 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Tue, 7 Jan 2020 10:07:59 -0800 Subject: Disable auto-generation of IPv6 link-local addresses for loopback NICs Test: Test that an IPv6 link-local address is not auto-generated for loopback NICs, even when it is enabled for non-loopback NICS. PiperOrigin-RevId: 288519591 --- pkg/tcpip/stack/nic.go | 3 ++- pkg/tcpip/stack/stack.go | 18 +++++++++------- pkg/tcpip/stack/stack_test.go | 49 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 61 insertions(+), 9 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 044fe5298..523c2a699 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -174,7 +174,8 @@ func (n *NIC) enable() *tcpip.Error { return err } - if !n.stack.autoGenIPv6LinkLocal { + // Do not auto-generate an IPv6 link-local address for loopback devices. + if !n.stack.autoGenIPv6LinkLocal || n.loopback { return nil } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index ffb379363..d4e98f277 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -444,8 +444,8 @@ type Stack struct { ndpConfigs NDPConfigurations // autoGenIPv6LinkLocal determines whether or not the stack will attempt - // to auto-generate an IPv6 link-local address for newly enabled NICs. - // See the AutoGenIPv6LinkLocal field of Options for more details. + // to auto-generate an IPv6 link-local address for newly enabled non-loopback + // NICs. See the AutoGenIPv6LinkLocal field of Options for more details. autoGenIPv6LinkLocal bool // ndpDisp is the NDP event dispatcher that is used to send the netstack @@ -496,13 +496,15 @@ type Options struct { // before assigning an address to a NIC. NDPConfigs NDPConfigurations - // AutoGenIPv6LinkLocal determins whether or not the stack will attempt - // to auto-generate an IPv6 link-local address for newly enabled NICs. + // AutoGenIPv6LinkLocal determines whether or not the stack will attempt to + // auto-generate an IPv6 link-local address for newly enabled non-loopback + // NICs. + // // Note, setting this to true does not mean that a link-local address - // will be assigned right away, or at all. If Duplicate Address - // Detection is enabled, an address will only be assigned if it - // successfully resolves. If it fails, no further attempt will be made - // to auto-generate an IPv6 link-local address. + // will be assigned right away, or at all. If Duplicate Address Detection + // is enabled, an address will only be assigned if it successfully resolves. + // If it fails, no further attempt will be made to auto-generate an IPv6 + // link-local address. // // The generated link-local address will follow RFC 4291 Appendix A // guidelines. diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index f533949c0..d970a4abb 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -2121,6 +2121,55 @@ func TestNICAutoGenAddrWithOpaque(t *testing.T) { } } +// TestNoLinkLocalAutoGenForLoopbackNIC tests that IPv6 link-local addresses are +// not auto-generated for loopback NICs. +func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) { + const nicID = 1 + const nicName = "nicName" + + tests := []struct { + name string + opaqueIIDOpts stack.OpaqueInterfaceIdentifierOptions + }{ + { + name: "IID From MAC", + opaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{}, + }, + { + name: "Opaque IID", + opaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(_ tcpip.NICID, nicName string) string { + return nicName + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + AutoGenIPv6LinkLocal: true, + OpaqueIIDOpts: test.opaqueIIDOpts, + } + + e := channel.New(0, 1280, linkAddr1) + s := stack.New(opts) + if err := s.CreateNamedLoopbackNIC(nicID, nicName, e); err != nil { + t.Fatalf("CreateNamedLoopbackNIC(%d, %q, _) = %s", nicID, nicName, err) + } + + addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Errorf("got stack.GetMainNICAddress(%d, _) = %s, want = %s", nicID, addr, want) + } + }) + } +} + // TestNICAutoGenAddrDoesDAD tests that the successful auto-generation of IPv6 // link-local addresses will only be assigned after the DAD process resolves. func TestNICAutoGenAddrDoesDAD(t *testing.T) { -- cgit v1.2.3 From 4e19d165ccc8035cd23eb31f34af82f1d6389907 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Tue, 7 Jan 2020 13:40:43 -0800 Subject: Support deprecating SLAAC addresses after the preferred lifetime Support deprecating network endpoints on a NIC. If an endpoint is deprecated, it should not be used for new connections unless a more preferred endpoint is not available, or unless the deprecated endpoint was explicitly requested. Test: Test that deprecated endpoints are only returned when more preferred endpoints are not available and SLAAC addresses are deprecated after its preferred lifetime PiperOrigin-RevId: 288562705 --- pkg/tcpip/stack/ndp.go | 289 ++++++++++++++++-------- pkg/tcpip/stack/ndp_test.go | 539 +++++++++++++++++++++++++++++++++++++++++++- pkg/tcpip/stack/nic.go | 93 +++++++- pkg/tcpip/stack/stack.go | 15 +- 4 files changed, 827 insertions(+), 109 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 238bc27dc..4722ec9ce 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -169,6 +169,15 @@ type NDPDispatcher interface { // call functions on the stack itself. OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool + // OnAutoGenAddressDeprecated will be called when an auto-generated + // address (as part of SLAAC) has been deprecated, but is still + // considered valid. Note, if an address is invalidated at the same + // time it is deprecated, the deprecation event MAY be omitted. + // + // This function is not permitted to block indefinitely. It must not + // call functions on the stack itself. + OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) + // OnAutoGenAddressInvalidated will be called when an auto-generated // address (as part of SLAAC) has been invalidated. // @@ -335,6 +344,17 @@ type onLinkPrefixState struct { // autoGenAddressState holds data associated with an address generated via // SLAAC. type autoGenAddressState struct { + // A reference to the referencedNetworkEndpoint that this autoGenAddressState + // is holding state for. + ref *referencedNetworkEndpoint + + deprecationTimer *time.Timer + + // Used to signal the timer not to deprecate the SLAAC address in a race + // condition. Used for the same reason as doNotInvalidate, but for deprecating + // an address. + doNotDeprecate *bool + invalidationTimer *time.Timer // Used to signal the timer not to invalidate the SLAAC address (A) in @@ -912,103 +932,30 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform prefix := pi.Subnet() // Check if we already have an auto-generated address for prefix. - for _, ref := range ndp.nic.endpoints { - if ref.protocol != header.IPv6ProtocolNumber { - continue - } - - if ref.configType != slaac { - continue - } - - addr := ref.ep.ID().LocalAddress - refAddrWithPrefix := tcpip.AddressWithPrefix{Address: addr, PrefixLen: ref.ep.PrefixLen()} + for addr, addrState := range ndp.autoGenAddresses { + refAddrWithPrefix := tcpip.AddressWithPrefix{Address: addr, PrefixLen: addrState.ref.ep.PrefixLen()} if refAddrWithPrefix.Subnet() != prefix { continue } - // - // At this point, we know we are refreshing a SLAAC generated - // IPv6 address with the prefix, prefix. Do the work as outlined - // by RFC 4862 section 5.5.3.e. - // - - addrState, ok := ndp.autoGenAddresses[addr] - if !ok { - panic(fmt.Sprintf("must have an autoGenAddressess entry for the SLAAC generated IPv6 address %s", addr)) - } - - // TODO(b/143713887): Handle deprecating auto-generated address - // after the preferred lifetime. - - // As per RFC 4862 section 5.5.3.e, the valid lifetime of the - // address generated by SLAAC is as follows: - // - // 1) If the received Valid Lifetime is greater than 2 hours or - // greater than RemainingLifetime, set the valid lifetime of - // the address to the advertised Valid Lifetime. - // - // 2) If RemainingLifetime is less than or equal to 2 hours, - // ignore the advertised Valid Lifetime. - // - // 3) Otherwise, reset the valid lifetime of the address to 2 - // hours. - - // Handle the infinite valid lifetime separately as we do not - // keep a timer in this case. - if vl >= header.NDPInfiniteLifetime { - if addrState.invalidationTimer != nil { - // Valid lifetime was finite before, but now it - // is valid forever. - if !addrState.invalidationTimer.Stop() { - *addrState.doNotInvalidate = true - } - addrState.invalidationTimer = nil - addrState.validUntil = time.Time{} - ndp.autoGenAddresses[addr] = addrState - } - - return - } - - var effectiveVl time.Duration - var rl time.Duration - - // If the address was originally set to be valid forever, - // assume the remaining time to be the maximum possible value. - if addrState.invalidationTimer == nil { - rl = header.NDPInfiniteLifetime - } else { - rl = time.Until(addrState.validUntil) - } - - if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl { - effectiveVl = vl - } else if rl <= MinPrefixInformationValidLifetimeForUpdate { - ndp.autoGenAddresses[addr] = addrState - return - } else { - effectiveVl = MinPrefixInformationValidLifetimeForUpdate - } - - if addrState.invalidationTimer == nil { - addrState.invalidationTimer = ndp.autoGenAddrInvalidationTimer(addr, effectiveVl, addrState.doNotInvalidate) - } else { - if !addrState.invalidationTimer.Stop() { - *addrState.doNotInvalidate = true - } - addrState.invalidationTimer.Reset(effectiveVl) - } - - addrState.validUntil = time.Now().Add(effectiveVl) - ndp.autoGenAddresses[addr] = addrState + // At this point, we know we are refreshing a SLAAC generated IPv6 address + // with the prefix prefix. Do the work as outlined by RFC 4862 section + // 5.5.3.e. + ndp.refreshAutoGenAddressLifetimes(addr, pl, vl) return } // We do not already have an address within the prefix, prefix. Do the // work as outlined by RFC 4862 section 5.5.3.d if n is configured // to auto-generated global addresses by SLAAC. + ndp.newAutoGenAddress(prefix, pl, vl) +} +// newAutoGenAddress generates a new SLAAC address with the provided lifetimes +// for prefix. +// +// pl is the new preferred lifetime. vl is the new valid lifetime. +func (ndp *ndpState) newAutoGenAddress(prefix tcpip.Subnet, pl, vl time.Duration) { // Are we configured to auto-generate new global addresses? if !ndp.configs.AutoGenGlobalAddresses { return @@ -1067,18 +1014,25 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform return } - if _, err := ndp.nic.addAddressLocked(tcpip.ProtocolAddress{ + protocolAddr := tcpip.ProtocolAddress{ Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addrWithPrefix, - }, FirstPrimaryEndpoint, permanent, slaac); err != nil { - panic(err) + } + // If the preferred lifetime is zero, then the address should be considered + // deprecated. + deprecated := pl == 0 + ref, err := ndp.nic.addAddressLocked(protocolAddr, FirstPrimaryEndpoint, permanent, slaac, deprecated) + if err != nil { + log.Fatalf("ndp: error when adding address %s: %s", protocolAddr, err) } - // Setup the timers to deprecate and invalidate this newly generated - // address. + // Setup the timers to deprecate and invalidate this newly generated address. - // TODO(b/143713887): Handle deprecating auto-generated addresses - // after the preferred lifetime. + var doNotDeprecate bool + var pTimer *time.Timer + if !deprecated && pl < header.NDPInfiniteLifetime { + pTimer = ndp.autoGenAddrDeprecationTimer(addr, pl, &doNotDeprecate) + } var doNotInvalidate bool var vTimer *time.Timer @@ -1087,12 +1041,126 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform } ndp.autoGenAddresses[addr] = autoGenAddressState{ + ref: ref, + deprecationTimer: pTimer, + doNotDeprecate: &doNotDeprecate, invalidationTimer: vTimer, doNotInvalidate: &doNotInvalidate, validUntil: time.Now().Add(vl), } } +// refreshAutoGenAddressLifetimes refreshes the lifetime of a SLAAC generated +// address addr. +// +// pl is the new preferred lifetime. vl is the new valid lifetime. +func (ndp *ndpState) refreshAutoGenAddressLifetimes(addr tcpip.Address, pl, vl time.Duration) { + addrState, ok := ndp.autoGenAddresses[addr] + if !ok { + log.Fatalf("ndp: SLAAC state not found to refresh lifetimes for %s", addr) + } + defer func() { ndp.autoGenAddresses[addr] = addrState }() + + // If the preferred lifetime is zero, then the address should be considered + // deprecated. + deprecated := pl == 0 + wasDeprecated := addrState.ref.deprecated + addrState.ref.deprecated = deprecated + + // Only send the deprecation event if the deprecated status for addr just + // changed from non-deprecated to deprecated. + if !wasDeprecated && deprecated { + ndp.notifyAutoGenAddressDeprecated(addr) + } + + // If addr was preferred for some finite lifetime before, stop the deprecation + // timer so it can be reset. + if t := addrState.deprecationTimer; t != nil && !t.Stop() { + *addrState.doNotDeprecate = true + } + + // Reset the deprecation timer. + if pl >= header.NDPInfiniteLifetime || deprecated { + // If addr is preferred forever or it has been deprecated already, there is + // no need for a deprecation timer. + addrState.deprecationTimer = nil + } else if addrState.deprecationTimer == nil { + // addr is now preferred for a finite lifetime. + addrState.deprecationTimer = ndp.autoGenAddrDeprecationTimer(addr, pl, addrState.doNotDeprecate) + } else { + // addr continues to be preferred for a finite lifetime. + addrState.deprecationTimer.Reset(pl) + } + + // As per RFC 4862 section 5.5.3.e, the valid lifetime of the address + // + // + // 1) If the received Valid Lifetime is greater than 2 hours or greater than + // RemainingLifetime, set the valid lifetime of the address to the + // advertised Valid Lifetime. + // + // 2) If RemainingLifetime is less than or equal to 2 hours, ignore the + // advertised Valid Lifetime. + // + // 3) Otherwise, reset the valid lifetime of the address to 2 hours. + + // Handle the infinite valid lifetime separately as we do not keep a timer in + // this case. + if vl >= header.NDPInfiniteLifetime { + if addrState.invalidationTimer != nil { + // Valid lifetime was finite before, but now it is valid forever. + if !addrState.invalidationTimer.Stop() { + *addrState.doNotInvalidate = true + } + addrState.invalidationTimer = nil + addrState.validUntil = time.Time{} + } + + return + } + + var effectiveVl time.Duration + var rl time.Duration + + // If the address was originally set to be valid forever, assume the remaining + // time to be the maximum possible value. + if addrState.invalidationTimer == nil { + rl = header.NDPInfiniteLifetime + } else { + rl = time.Until(addrState.validUntil) + } + + if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl { + effectiveVl = vl + } else if rl <= MinPrefixInformationValidLifetimeForUpdate { + return + } else { + effectiveVl = MinPrefixInformationValidLifetimeForUpdate + } + + if addrState.invalidationTimer == nil { + addrState.invalidationTimer = ndp.autoGenAddrInvalidationTimer(addr, effectiveVl, addrState.doNotInvalidate) + } else { + if !addrState.invalidationTimer.Stop() { + *addrState.doNotInvalidate = true + } + addrState.invalidationTimer.Reset(effectiveVl) + } + + addrState.validUntil = time.Now().Add(effectiveVl) +} + +// notifyAutoGenAddressDeprecated notifies the stack's NDP dispatcher that addr +// has been deprecated. +func (ndp *ndpState) notifyAutoGenAddressDeprecated(addr tcpip.Address) { + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnAutoGenAddressDeprecated(ndp.nic.ID(), tcpip.AddressWithPrefix{ + Address: addr, + PrefixLen: validPrefixLenForAutoGen, + }) + } +} + // invalidateAutoGenAddress invalidates an auto-generated address. // // The NIC that ndp belongs to MUST be locked. @@ -1118,6 +1186,14 @@ func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bo return false } + if state.deprecationTimer != nil { + state.deprecationTimer.Stop() + state.deprecationTimer = nil + *state.doNotDeprecate = true + } + + state.doNotDeprecate = nil + if state.invalidationTimer != nil { state.invalidationTimer.Stop() state.invalidationTimer = nil @@ -1138,6 +1214,33 @@ func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bo return true } +// autoGenAddrDeprecationTimer returns a new deprecation timer for an +// auto-generated address that fires after pl. +// +// doNotDeprecate is used to inform the timer when it fires at the same time +// that an auto-generated address's preferred lifetime gets refreshed. See +// autoGenAddrState.doNotDeprecate for more details. +func (ndp *ndpState) autoGenAddrDeprecationTimer(addr tcpip.Address, pl time.Duration, doNotDeprecate *bool) *time.Timer { + return time.AfterFunc(pl, func() { + ndp.nic.mu.Lock() + defer ndp.nic.mu.Unlock() + + if *doNotDeprecate { + *doNotDeprecate = false + return + } + + addrState, ok := ndp.autoGenAddresses[addr] + if !ok { + log.Fatalf("ndp: must have an autoGenAddressess entry for the SLAAC generated IPv6 address %s", addr) + } + addrState.ref.deprecated = true + ndp.notifyAutoGenAddressDeprecated(addr) + addrState.deprecationTimer = nil + ndp.autoGenAddresses[addr] = addrState + }) +} + // autoGenAddrInvalidationTimer returns a new invalidation timer for an // auto-generated address that fires after vl. // diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 9430844d3..8d89859ba 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -30,6 +30,8 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" ) const ( @@ -46,6 +48,10 @@ var ( llAddr1 = header.LinkLocalAddr(linkAddr1) llAddr2 = header.LinkLocalAddr(linkAddr2) llAddr3 = header.LinkLocalAddr(linkAddr3) + dstAddr = tcpip.FullAddress{ + Addr: "\x0a\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + Port: 25, + } ) func addrForSubnet(subnet tcpip.Subnet, linkAddr tcpip.LinkAddress) tcpip.AddressWithPrefix { @@ -136,6 +142,7 @@ type ndpAutoGenAddrEventType int const ( newAddr ndpAutoGenAddrEventType = iota + deprecatedAddr invalidatedAddr ) @@ -240,6 +247,16 @@ func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWi return true } +func (n *ndpDispatcher) OnAutoGenAddressDeprecated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) { + if c := n.autoGenAddrC; c != nil { + c <- ndpAutoGenAddrEvent{ + nicID, + addr, + deprecatedAddr, + } + } +} + func (n *ndpDispatcher) OnAutoGenAddressInvalidated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) { if c := n.autoGenAddrC; c != nil { c <- ndpAutoGenAddrEvent{ @@ -1638,12 +1655,532 @@ func TestAutoGenAddr(t *testing.T) { } } +// stackAndNdpDispatcherWithDefaultRoute returns an ndpDispatcher, +// channel.Endpoint and stack.Stack. +// +// stack.Stack will have a default route through the router (llAddr3) installed +// and a static link-address (linkAddr3) added to the link address cache for the +// router. +func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*ndpDispatcher, *channel.Endpoint, *stack.Stack) { + t.Helper() + ndpDisp := &ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: ndpDisp, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + s.SetRouteTable([]tcpip.Route{{ + Destination: header.IPv6EmptySubnet, + Gateway: llAddr3, + NIC: nicID, + }}) + s.AddLinkAddress(nicID, llAddr3, linkAddr3) + return ndpDisp, e, s +} + +// addrForNewConnection returns the local address used when creating a new +// connection. +func addrForNewConnection(t *testing.T, s *stack.Stack) tcpip.Address { + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) + } + defer ep.Close() + v := tcpip.V6OnlyOption(1) + if err := ep.SetSockOpt(v); err != nil { + t.Fatalf("SetSockOpt(%+v): %s", v, err) + } + if err := ep.Connect(dstAddr); err != nil { + t.Fatalf("ep.Connect(%+v): %s", dstAddr, err) + } + got, err := ep.GetLocalAddress() + if err != nil { + t.Fatalf("ep.GetLocalAddress(): %s", err) + } + return got.Addr +} + +// addrForNewConnectionWithAddr returns the local address used when creating a +// new connection with a specific local address. +func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullAddress) tcpip.Address { + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) + } + defer ep.Close() + v := tcpip.V6OnlyOption(1) + if err := ep.SetSockOpt(v); err != nil { + t.Fatalf("SetSockOpt(%+v): %s", v, err) + } + if err := ep.Bind(addr); err != nil { + t.Fatalf("ep.Bind(%+v): %s", addr, err) + } + if err := ep.Connect(dstAddr); err != nil { + t.Fatalf("ep.Connect(%+v): %s", dstAddr, err) + } + got, err := ep.GetLocalAddress() + if err != nil { + t.Fatalf("ep.GetLocalAddress(): %s", err) + } + return got.Addr +} + +// TestAutoGenAddrDeprecateFromPI tests deprecating a SLAAC address when +// receiving a PI with 0 preferred lifetime. +func TestAutoGenAddrDeprecateFromPI(t *testing.T) { + const nicID = 1 + + prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) + + ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { + t.Helper() + + if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) + } else if got != addr { + t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) + } + + if got := addrForNewConnection(t, s); got != addr.Address { + t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) + } + } + + // Receive PI for prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) + expectAutoGenAddrEvent(addr1, newAddr) + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should have %s in the list of addresses", addr1) + } + expectPrimaryAddr(addr1) + + // Deprecate addr for prefix1 immedaitely. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) + expectAutoGenAddrEvent(addr1, deprecatedAddr) + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should have %s in the list of addresses", addr1) + } + // addr should still be the primary endpoint as there are no other addresses. + expectPrimaryAddr(addr1) + + // Refresh lifetimes of addr generated from prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr1) + + // Receive PI for prefix2. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) + expectAutoGenAddrEvent(addr2, newAddr) + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + expectPrimaryAddr(addr2) + + // Deprecate addr for prefix2 immedaitely. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) + expectAutoGenAddrEvent(addr2, deprecatedAddr) + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + // addr1 should be the primary endpoint now since addr2 is deprecated but + // addr1 is not. + expectPrimaryAddr(addr1) + // addr2 is deprecated but if explicitly requested, it should be used. + fullAddr2 := tcpip.FullAddress{Addr: addr2.Address, NIC: nicID} + if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr2.Address) + } + + // Another PI w/ 0 preferred lifetime should not result in a deprecation + // event. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr1) + if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr2.Address) + } + + // Refresh lifetimes of addr generated from prefix2. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr2) +} + +// TestAutoGenAddrTimerDeprecation tests that an address is properly deprecated +// when its preferred lifetime expires. +func TestAutoGenAddrTimerDeprecation(t *testing.T) { + const nicID = 1 + const newMinVL = 2 + newMinVLDuration := newMinVL * time.Second + saved := stack.MinPrefixInformationValidLifetimeForUpdate + defer func() { + stack.MinPrefixInformationValidLifetimeForUpdate = saved + }() + stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration + + prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) + + ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(timeout): + t.Fatal("timed out waiting for addr auto gen event") + } + } + + expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { + t.Helper() + + if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) + } else if got != addr { + t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) + } + + if got := addrForNewConnection(t, s); got != addr.Address { + t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) + } + } + + // Receive PI for prefix2. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) + expectAutoGenAddrEvent(addr2, newAddr) + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + expectPrimaryAddr(addr2) + + // Receive a PI for prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 90)) + expectAutoGenAddrEvent(addr1, newAddr) + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should have %s in the list of addresses", addr1) + } + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + expectPrimaryAddr(addr1) + + // Refresh lifetime for addr of prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr1) + + // Wait for addr of prefix1 to be deprecated. + expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultTimeout) + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + // addr2 should be the primary endpoint now since addr1 is deprecated but + // addr2 is not. + expectPrimaryAddr(addr2) + // addr1 is deprecated but if explicitly requested, it should be used. + fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID} + if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr1.Address) + } + + // Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make + // sure we do not get a deprecation event again. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr2) + if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr1.Address) + } + + // Refresh lifetimes for addr of prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + // addr1 is the primary endpoint again since it is non-deprecated now. + expectPrimaryAddr(addr1) + + // Wait for addr of prefix1 to be deprecated. + expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultTimeout) + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + // addr2 should be the primary endpoint now since it is not deprecated. + expectPrimaryAddr(addr2) + if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr1.Address) + } + + // Wait for addr of prefix1 to be invalidated. + expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultTimeout) + if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + expectPrimaryAddr(addr2) + + // Refresh both lifetimes for addr of prefix2 to the same value. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, newMinVL, newMinVL)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + + // Wait for a deprecation then invalidation events, or just an invalidation + // event. We need to cover both cases but cannot deterministically hit both + // cases because the deprecation and invalidation handlers could be handled in + // either deprecation then invalidation, or invalidation then deprecation + // (which should be cancelled by the invalidation handler). + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr2, deprecatedAddr); diff == "" { + // If we get a deprecation event first, we should get an invalidation + // event almost immediately after. + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(defaultTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" { + // If we get an invalidation event first, we should not get a deprecation + // event after. + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + case <-time.After(defaultTimeout): + } + } else { + t.Fatalf("got unexpected auto-generated event") + } + + case <-time.After(newMinVLDuration + defaultTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should not have %s in the list of addresses", addr2) + } + // Should not have any primary endpoints. + if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) + } else if want := (tcpip.AddressWithPrefix{}); got != want { + t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, want) + } + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) + } + defer ep.Close() + v := tcpip.V6OnlyOption(1) + if err := ep.SetSockOpt(v); err != nil { + t.Fatalf("SetSockOpt(%+v): %s", v, err) + } + + if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute { + t.Errorf("got ep.Connect(%+v) = %v, want = %s", dstAddr, err, tcpip.ErrNoRoute) + } +} + +// Tests transitioning a SLAAC address's valid lifetime between finite and +// infinite values. +func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { + const infiniteVLSeconds = 2 + const minVLSeconds = 1 + savedIL := header.NDPInfiniteLifetime + savedMinVL := stack.MinPrefixInformationValidLifetimeForUpdate + defer func() { + stack.MinPrefixInformationValidLifetimeForUpdate = savedMinVL + header.NDPInfiniteLifetime = savedIL + }() + stack.MinPrefixInformationValidLifetimeForUpdate = minVLSeconds * time.Second + header.NDPInfiniteLifetime = infiniteVLSeconds * time.Second + + prefix, _, addr := prefixSubnetAddr(0, linkAddr1) + + tests := []struct { + name string + infiniteVL uint32 + }{ + { + name: "EqualToInfiniteVL", + infiniteVL: infiniteVLSeconds, + }, + // Our implementation supports changing header.NDPInfiniteLifetime for tests + // such that a packet can be received where the lifetime field has a value + // greater than header.NDPInfiniteLifetime. Because of this, we test to make + // sure that receiving a value greater than header.NDPInfiniteLifetime is + // handled the same as when receiving a value equal to + // header.NDPInfiniteLifetime. + { + name: "MoreThanInfiniteVL", + infiniteVL: infiniteVLSeconds + 1, + }, + } + + // This Run will not return until the parallel tests finish. + // + // We need this because we need to do some teardown work after the + // parallel tests complete. + // + // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for + // more details. + t.Run("group", func(t *testing.T) { + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Receive an RA with finite prefix. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0)) + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + + default: + t.Fatal("expected addr auto gen event") + } + + // Receive an new RA with prefix with infinite VL. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.infiniteVL, 0)) + + // Receive a new RA with prefix with finite VL. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0)) + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + + case <-time.After(minVLSeconds*time.Second + defaultTimeout): + t.Fatal("timeout waiting for addr auto gen event") + } + }) + } + }) +} + // TestAutoGenAddrValidLifetimeUpdates tests that the valid lifetime of an // auto-generated address only gets updated when required to, as specified in // RFC 4862 section 5.5.3.e. func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { const infiniteVL = 4294967295 - const newMinVL = 5 + const newMinVL = 4 saved := stack.MinPrefixInformationValidLifetimeForUpdate defer func() { stack.MinPrefixInformationValidLifetimeForUpdate = saved diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 523c2a699..5726c3642 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -249,17 +249,47 @@ func (n *NIC) setSpoofing(enable bool) { // primaryEndpoint returns the primary endpoint of n for the given network // protocol. +// +// primaryEndpoint will return the first non-deprecated endpoint if such an +// endpoint exists. If no non-deprecated endpoint exists, the first deprecated +// endpoint will be returned. func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedNetworkEndpoint { n.mu.RLock() defer n.mu.RUnlock() + var deprecatedEndpoint *referencedNetworkEndpoint for _, r := range n.primary[protocol] { - if r.isValidForOutgoing() && r.tryIncRef() { - return r + if !r.isValidForOutgoing() { + continue + } + + if !r.deprecated { + if r.tryIncRef() { + // r is not deprecated, so return it immediately. + // + // If we kept track of a deprecated endpoint, decrement its reference + // count since it was incremented when we decided to keep track of it. + if deprecatedEndpoint != nil { + deprecatedEndpoint.decRefLocked() + deprecatedEndpoint = nil + } + + return r + } + } else if deprecatedEndpoint == nil && r.tryIncRef() { + // We prefer an endpoint that is not deprecated, but we keep track of r in + // case n doesn't have any non-deprecated endpoints. + // + // If we end up finding a more preferred endpoint, r's reference count + // will be decremented when such an endpoint is found. + deprecatedEndpoint = r } } - return nil + // n doesn't have any valid non-deprecated endpoints, so return + // deprecatedEndpoint (which may be nil if n doesn't have any valid deprecated + // endpoints either). + return deprecatedEndpoint } // hasPermanentAddrLocked returns true if n has a permanent (including currently @@ -367,7 +397,7 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t Address: address, PrefixLen: netProto.DefaultPrefixLen(), }, - }, peb, temporary, static) + }, peb, temporary, static, false) n.mu.Unlock() return ref @@ -416,10 +446,10 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p } } - return n.addAddressLocked(protocolAddress, peb, permanent, static) + return n.addAddressLocked(protocolAddress, peb, permanent, static, false) } -func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind, configType networkEndpointConfigType) (*referencedNetworkEndpoint, *tcpip.Error) { +func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind, configType networkEndpointConfigType, deprecated bool) (*referencedNetworkEndpoint, *tcpip.Error) { // TODO(b/141022673): Validate IP address before adding them. // Sanity check. @@ -455,6 +485,7 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar protocol: protocolAddress.Protocol, kind: kind, configType: configType, + deprecated: deprecated, } // Set up cache if link address resolution exists for this protocol. @@ -553,6 +584,51 @@ func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress { return addrs } +// primaryAddress returns the primary address associated with this NIC. +// +// primaryAddress will return the first non-deprecated address if such an +// address exists. If no non-deprecated address exists, the first deprecated +// address will be returned. +func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWithPrefix { + n.mu.RLock() + defer n.mu.RUnlock() + + list, ok := n.primary[proto] + if !ok { + return tcpip.AddressWithPrefix{} + } + + var deprecatedEndpoint *referencedNetworkEndpoint + for _, ref := range list { + // Don't include tentative, expired or tempory endpoints to avoid confusion + // and prevent the caller from using those. + switch ref.getKind() { + case permanentTentative, permanentExpired, temporary: + continue + } + + if !ref.deprecated { + return tcpip.AddressWithPrefix{ + Address: ref.ep.ID().LocalAddress, + PrefixLen: ref.ep.PrefixLen(), + } + } + + if deprecatedEndpoint == nil { + deprecatedEndpoint = ref + } + } + + if deprecatedEndpoint != nil { + return tcpip.AddressWithPrefix{ + Address: deprecatedEndpoint.ep.ID().LocalAddress, + PrefixLen: deprecatedEndpoint.ep.PrefixLen(), + } + } + + return tcpip.AddressWithPrefix{} +} + // AddAddressRange adds a range of addresses to n, so that it starts accepting // packets targeted at the given addresses and network protocol. The range is // given by a subnet address, and all addresses contained in the subnet are @@ -1109,6 +1185,11 @@ type referencedNetworkEndpoint struct { // configType is the method that was used to configure this endpoint. // This must never change after the endpoint is added to a NIC. configType networkEndpointConfigType + + // deprecated indicates whether or not the endpoint should be considered + // deprecated. That is, when deprecated is true, other endpoints that are not + // deprecated should be preferred. + deprecated bool } func (r *referencedNetworkEndpoint) getKind() networkEndpointKind { diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index d4e98f277..583ede3e5 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1036,9 +1036,11 @@ func (s *Stack) AllAddresses() map[tcpip.NICID][]tcpip.ProtocolAddress { return nics } -// GetMainNICAddress returns the first primary address and prefix for the given -// NIC and protocol. Returns an error if the NIC doesn't exist and an empty -// value if the NIC doesn't have a primary address for the given protocol. +// GetMainNICAddress returns the first non-deprecated primary address and prefix +// for the given NIC and protocol. If no non-deprecated primary address exists, +// a deprecated primary address and prefix will be returned. Returns an error if +// the NIC doesn't exist and an empty value if the NIC doesn't have a primary +// address for the given protocol. func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, *tcpip.Error) { s.mu.RLock() defer s.mu.RUnlock() @@ -1048,12 +1050,7 @@ func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocol return tcpip.AddressWithPrefix{}, tcpip.ErrUnknownNICID } - for _, a := range nic.PrimaryAddresses() { - if a.Protocol == protocol { - return a.AddressWithPrefix, nil - } - } - return tcpip.AddressWithPrefix{}, nil + return nic.primaryAddress(protocol), nil } func (s *Stack) getRefEP(nic *NIC, localAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (ref *referencedNetworkEndpoint) { -- cgit v1.2.3 From 0cc1e74b57e539e66c1a421c047a08635c0008e8 Mon Sep 17 00:00:00 2001 From: Bert Muthalaly Date: Wed, 8 Jan 2020 09:28:53 -0800 Subject: Add NIC.isLoopback() ...enabling us to remove the "CreateNamedLoopbackNIC" variant of CreateNIC and all the plumbing to connect it through to where the value is read in FindRoute. PiperOrigin-RevId: 288713093 --- pkg/tcpip/stack/nic.go | 18 ++++++++++-------- pkg/tcpip/stack/stack.go | 24 +++++++++--------------- pkg/tcpip/stack/stack_test.go | 7 ++++--- runsc/boot/network.go | 16 +++++----------- 4 files changed, 28 insertions(+), 37 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 5726c3642..4144d5d0f 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -27,11 +27,10 @@ import ( // NIC represents a "network interface card" to which the networking stack is // attached. type NIC struct { - stack *Stack - id tcpip.NICID - name string - linkEP LinkEndpoint - loopback bool + stack *Stack + id tcpip.NICID + name string + linkEP LinkEndpoint mu sync.RWMutex spoofing bool @@ -85,7 +84,7 @@ const ( ) // newNIC returns a new NIC using the default NDP configurations from stack. -func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback bool) *NIC { +func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint) *NIC { // TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For // example, make sure that the link address it provides is a valid // unicast ethernet address. @@ -99,7 +98,6 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback id: id, name: name, linkEP: ep, - loopback: loopback, primary: make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint), endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint), mcastJoins: make(map[NetworkEndpointID]int32), @@ -175,7 +173,7 @@ func (n *NIC) enable() *tcpip.Error { } // Do not auto-generate an IPv6 link-local address for loopback devices. - if !n.stack.autoGenIPv6LinkLocal || n.loopback { + if !n.stack.autoGenIPv6LinkLocal || n.isLoopback() { return nil } @@ -240,6 +238,10 @@ func (n *NIC) isPromiscuousMode() bool { return rv } +func (n *NIC) isLoopback() bool { + return n.linkEP.Capabilities()&CapabilityLoopback != 0 +} + // setSpoofing enables or disables address spoofing. func (n *NIC) setSpoofing(enable bool) { n.mu.Lock() diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 583ede3e5..807f910f6 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -798,7 +798,7 @@ func (s *Stack) NewPacketEndpoint(cooked bool, netProto tcpip.NetworkProtocolNum // createNIC creates a NIC with the provided id and link-layer endpoint, and // optionally enable it. -func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled, loopback bool) *tcpip.Error { +func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled bool) *tcpip.Error { s.mu.Lock() defer s.mu.Unlock() @@ -807,7 +807,7 @@ func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled, return tcpip.ErrDuplicateNICID } - n := newNIC(s, id, name, ep, loopback) + n := newNIC(s, id, name, ep) s.nics[id] = n if enabled { @@ -819,32 +819,26 @@ func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled, // CreateNIC creates a NIC with the provided id and link-layer endpoint. func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error { - return s.createNIC(id, "", ep, true, false) + return s.createNIC(id, "", ep, true) } // CreateNamedNIC creates a NIC with the provided id and link-layer endpoint, // and a human-readable name. func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error { - return s.createNIC(id, name, ep, true, false) -} - -// CreateNamedLoopbackNIC creates a NIC with the provided id and link-layer -// endpoint, and a human-readable name. -func (s *Stack) CreateNamedLoopbackNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error { - return s.createNIC(id, name, ep, true, true) + return s.createNIC(id, name, ep, true) } // CreateDisabledNIC creates a NIC with the provided id and link-layer endpoint, // but leave it disable. Stack.EnableNIC must be called before the link-layer // endpoint starts delivering packets to it. func (s *Stack) CreateDisabledNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error { - return s.createNIC(id, "", ep, false, false) + return s.createNIC(id, "", ep, false) } // CreateDisabledNamedNIC is a combination of CreateNamedNIC and // CreateDisabledNIC. func (s *Stack) CreateDisabledNamedNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error { - return s.createNIC(id, name, ep, false, false) + return s.createNIC(id, name, ep, false) } // EnableNIC enables the given NIC so that the link-layer endpoint can start @@ -911,7 +905,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { Up: true, // Netstack interfaces are always up. Running: nic.linkEP.IsAttached(), Promiscuous: nic.isPromiscuousMode(), - Loopback: nic.linkEP.Capabilities()&CapabilityLoopback != 0, + Loopback: nic.isLoopback(), } nics[id] = NICInfo{ Name: nic.name, @@ -1072,7 +1066,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n if id != 0 && !needRoute { if nic, ok := s.nics[id]; ok { if ref := s.getRefEP(nic, localAddr, netProto); ref != nil { - return makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.loopback, multicastLoop && !nic.loopback), nil + return makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()), nil } } } else { @@ -1088,7 +1082,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n remoteAddr = ref.ep.ID().LocalAddress } - r := makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.loopback, multicastLoop && !nic.loopback) + r := makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()) if needRoute { r.NextHop = route.Gateway } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index d970a4abb..bf057745e 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -32,6 +32,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -2153,10 +2154,10 @@ func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) { OpaqueIIDOpts: test.opaqueIIDOpts, } - e := channel.New(0, 1280, linkAddr1) + e := loopback.New() s := stack.New(opts) - if err := s.CreateNamedLoopbackNIC(nicID, nicName, e); err != nil { - t.Fatalf("CreateNamedLoopbackNIC(%d, %q, _) = %s", nicID, nicName, err) + if err := s.CreateNamedNIC(nicID, nicName, e); err != nil { + t.Fatalf("CreateNamedNIC(%d, %q, _) = %s", nicID, nicName, err) } addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) diff --git a/runsc/boot/network.go b/runsc/boot/network.go index dd4926bb9..0240fe323 100644 --- a/runsc/boot/network.go +++ b/runsc/boot/network.go @@ -126,7 +126,7 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct linkEP := loopback.New() log.Infof("Enabling loopback interface %q with id %d on addresses %+v", link.Name, nicID, link.Addresses) - if err := n.createNICWithAddrs(nicID, link.Name, linkEP, link.Addresses, true /* loopback */); err != nil { + if err := n.createNICWithAddrs(nicID, link.Name, linkEP, link.Addresses); err != nil { return err } @@ -173,7 +173,7 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct } log.Infof("Enabling interface %q with id %d on addresses %+v (%v) w/ %d channels", link.Name, nicID, link.Addresses, mac, link.NumChannels) - if err := n.createNICWithAddrs(nicID, link.Name, linkEP, link.Addresses, false /* loopback */); err != nil { + if err := n.createNICWithAddrs(nicID, link.Name, linkEP, link.Addresses); err != nil { return err } @@ -218,15 +218,9 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct // createNICWithAddrs creates a NIC in the network stack and adds the given // addresses. -func (n *Network) createNICWithAddrs(id tcpip.NICID, name string, ep stack.LinkEndpoint, addrs []net.IP, loopback bool) error { - if loopback { - if err := n.Stack.CreateNamedLoopbackNIC(id, name, sniffer.New(ep)); err != nil { - return fmt.Errorf("CreateNamedLoopbackNIC(%v, %v, %v) failed: %v", id, name, ep, err) - } - } else { - if err := n.Stack.CreateNamedNIC(id, name, sniffer.New(ep)); err != nil { - return fmt.Errorf("CreateNamedNIC(%v, %v, %v) failed: %v", id, name, ep, err) - } +func (n *Network) createNICWithAddrs(id tcpip.NICID, name string, ep stack.LinkEndpoint, addrs []net.IP) error { + if err := n.Stack.CreateNamedNIC(id, name, sniffer.New(ep)); err != nil { + return fmt.Errorf("CreateNamedNIC(%v, %v, %v) failed: %v", id, name, ep, err) } // Always start with an arp address for the NIC. -- cgit v1.2.3 From 9df018767cdfee5d837746b6dce6dafd9b9fcfce Mon Sep 17 00:00:00 2001 From: Tamir Duberstein Date: Wed, 8 Jan 2020 10:10:57 -0800 Subject: Remove redundant function argument PacketLooping is already a member on the passed Route. PiperOrigin-RevId: 288721500 --- pkg/tcpip/network/arp/arp.go | 6 +++--- pkg/tcpip/network/ip_test.go | 4 ++-- pkg/tcpip/network/ipv4/ipv4.go | 18 +++++++++--------- pkg/tcpip/network/ipv6/ipv6.go | 14 +++++++------- pkg/tcpip/stack/registration.go | 6 +++--- pkg/tcpip/stack/route.go | 6 +++--- pkg/tcpip/stack/stack_test.go | 10 +++++----- 7 files changed, 32 insertions(+), 32 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index da8482509..42cacb8a6 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -79,16 +79,16 @@ func (e *endpoint) MaxHeaderLength() uint16 { func (e *endpoint) Close() {} -func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, stack.PacketLooping, tcpip.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, tcpip.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } // WritePackets implements stack.NetworkEndpoint.WritePackets. -func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []tcpip.PacketBuffer, stack.NetworkHeaderParams, stack.PacketLooping) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []tcpip.PacketBuffer, stack.NetworkHeaderParams) (int, *tcpip.Error) { return 0, tcpip.ErrNotSupported } -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 4144a7837..f1bc33adf 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -239,7 +239,7 @@ func TestIPv4Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut, tcpip.PacketBuffer{ + if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ Header: hdr, Data: payload.ToVectorisedView(), }); err != nil { @@ -480,7 +480,7 @@ func TestIPv6Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut, tcpip.PacketBuffer{ + if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ Header: hdr, Data: payload.ToVectorisedView(), }); err != nil { diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index e645cf62c..4ee3d5b45 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -238,11 +238,11 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error { ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) pkt.NetworkHeader = buffer.View(ip) - if loop&stack.PacketLoop != 0 { + if r.Loop&stack.PacketLoop != 0 { // The inbound path expects the network header to still be in // the PacketBuffer's Data field. views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) @@ -256,7 +256,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw loopedR.Release() } - if loop&stack.PacketOut == 0 { + if r.Loop&stack.PacketOut == 0 { return nil } if pkt.Header.UsedLength()+pkt.Data.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) { @@ -270,11 +270,11 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw } // WritePackets implements stack.NetworkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) { - if loop&stack.PacketLoop != 0 { +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) { + if r.Loop&stack.PacketLoop != 0 { panic("multiple packets in local loop") } - if loop&stack.PacketOut == 0 { + if r.Loop&stack.PacketOut == 0 { return len(pkts), nil } @@ -289,7 +289,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.Pac // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required // checks. ip := header.IPv4(pkt.Data.First()) @@ -324,10 +324,10 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLo ip.SetChecksum(0) ip.SetChecksum(^ip.CalculateChecksum()) - if loop&stack.PacketLoop != 0 { + if r.Loop&stack.PacketLoop != 0 { e.HandlePacket(r, pkt.Clone()) } - if loop&stack.PacketOut == 0 { + if r.Loop&stack.PacketOut == 0 { return nil } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index e13f1fabf..58c3c79b9 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -112,11 +112,11 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error { ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) pkt.NetworkHeader = buffer.View(ip) - if loop&stack.PacketLoop != 0 { + if r.Loop&stack.PacketLoop != 0 { // The inbound path expects the network header to still be in // the PacketBuffer's Data field. views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) @@ -130,7 +130,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw loopedR.Release() } - if loop&stack.PacketOut == 0 { + if r.Loop&stack.PacketOut == 0 { return nil } @@ -139,11 +139,11 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) { - if loop&stack.PacketLoop != 0 { +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) { + if r.Loop&stack.PacketLoop != 0 { panic("not implemented") } - if loop&stack.PacketOut == 0 { + if r.Loop&stack.PacketOut == 0 { return len(pkts), nil } @@ -161,7 +161,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.Pac // WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet // supported by IPv6. -func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { +func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error { // TODO(b/146666412): Support IPv6 header-included packets. return tcpip.ErrNotSupported } diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 61fd46d66..2b8751d49 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -234,15 +234,15 @@ type NetworkEndpoint interface { // WritePacket writes a packet to the given destination address and // protocol. It sets pkt.NetworkHeader. pkt.TransportHeader must have // already been set. - WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, loop PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error + WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error // WritePackets writes packets to the given destination address and // protocol. pkts must not be zero length. - WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, params NetworkHeaderParams, loop PacketLooping) (int, *tcpip.Error) + WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) // WriteHeaderIncludedPacket writes a packet that includes a network // header to the given destination address. - WriteHeaderIncludedPacket(r *Route, loop PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error + WriteHeaderIncludedPacket(r *Route, pkt tcpip.PacketBuffer) *tcpip.Error // ID returns the network protocol endpoint ID. ID() *NetworkEndpointID diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 34307ae07..517f4b941 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -158,7 +158,7 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt tcpip.Pack return tcpip.ErrInvalidEndpointState } - err := r.ref.ep.WritePacket(r, gso, params, r.Loop, pkt) + err := r.ref.ep.WritePacket(r, gso, params, pkt) if err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() } else { @@ -174,7 +174,7 @@ func (r *Route) WritePackets(gso *GSO, pkts []tcpip.PacketBuffer, params Network return 0, tcpip.ErrInvalidEndpointState } - n, err := r.ref.ep.WritePackets(r, gso, pkts, params, r.Loop) + n, err := r.ref.ep.WritePackets(r, gso, pkts, params) if err != nil { r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(len(pkts) - n)) } @@ -195,7 +195,7 @@ func (r *Route) WriteHeaderIncludedPacket(pkt tcpip.PacketBuffer) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - if err := r.ref.ep.WriteHeaderIncludedPacket(r, r.Loop, pkt); err != nil { + if err := r.ref.ep.WriteHeaderIncludedPacket(r, pkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return err } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index bf057745e..33f20579f 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -124,7 +124,7 @@ func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities { return f.ep.Capabilities() } -func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { +func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error { // Increment the sent packet count in the protocol descriptor. f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++ @@ -135,7 +135,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params b[1] = f.id.LocalAddress[0] b[2] = byte(params.Protocol) - if loop&stack.PacketLoop != 0 { + if r.Loop&stack.PacketLoop != 0 { views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) views[0] = pkt.Header.View() views = append(views, pkt.Data.Views()...) @@ -143,7 +143,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), }) } - if loop&stack.PacketOut == 0 { + if r.Loop&stack.PacketOut == 0 { return nil } @@ -151,11 +151,11 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) { +func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) { panic("not implemented") } -func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, loop stack.PacketLooping, pkt tcpip.PacketBuffer) *tcpip.Error { +func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } -- cgit v1.2.3 From a271bccfc61390be64ca0175b8fc7d20e66d05b6 Mon Sep 17 00:00:00 2001 From: Tamir Duberstein Date: Wed, 8 Jan 2020 14:16:38 -0800 Subject: Rename tcpip.SockOpt{,Int} PiperOrigin-RevId: 288772878 --- pkg/sentry/socket/netstack/netstack.go | 4 ++-- pkg/sentry/socket/unix/transport/unix.go | 8 ++++---- pkg/tcpip/stack/transport_test.go | 4 ++-- pkg/tcpip/tcpip.go | 10 +++++----- pkg/tcpip/transport/icmp/endpoint.go | 4 ++-- pkg/tcpip/transport/packet/endpoint.go | 4 ++-- pkg/tcpip/transport/raw/endpoint.go | 4 ++-- pkg/tcpip/transport/tcp/endpoint.go | 4 ++-- pkg/tcpip/transport/udp/endpoint.go | 4 ++-- 9 files changed, 23 insertions(+), 23 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 140851c17..5f91a0d1a 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -224,7 +224,7 @@ type commonEndpoint interface { // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt and // transport.Endpoint.SetSockOptInt. - SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error + SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error // GetSockOpt implements tcpip.Endpoint.GetSockOpt and // transport.Endpoint.GetSockOpt. @@ -232,7 +232,7 @@ type commonEndpoint interface { // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt and // transport.Endpoint.GetSockOpt. - GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) + GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) } // SocketOperations encapsulates all the state needed to represent a network stack diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 529a7a7a9..fcba49435 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -177,7 +177,7 @@ type Endpoint interface { // SetSockOptInt sets a socket option for simple cases when a value has // the int type. - SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error + SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error // GetSockOpt gets a socket option. opt should be a pointer to one of the // tcpip.*Option types. @@ -185,7 +185,7 @@ type Endpoint interface { // GetSockOptInt gets a socket option for simple cases when a return // value has the int type. - GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) + GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) // State returns the current state of the socket, as represented by Linux in // procfs. @@ -851,11 +851,11 @@ func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error { return nil } -func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { +func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil } -func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { +func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: v := 0 diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 748ce4ea5..095346f0b 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -103,12 +103,12 @@ func (*fakeTransportEndpoint) SetSockOpt(interface{}) *tcpip.Error { } // SetSockOptInt sets a socket option. Currently not supported. -func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOpt, int) *tcpip.Error { +func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOptInt, int) *tcpip.Error { return tcpip.ErrInvalidEndpointState } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { +func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { return -1, tcpip.ErrUnknownProtocolOption } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index f62fd729f..b172d71b0 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -425,7 +425,7 @@ type Endpoint interface { // SetSockOptInt sets a socket option, for simple cases where a value // has the int type. - SetSockOptInt(opt SockOpt, v int) *Error + SetSockOptInt(opt SockOptInt, v int) *Error // GetSockOpt gets a socket option. opt should be a pointer to one of the // *Option types. @@ -433,7 +433,7 @@ type Endpoint interface { // GetSockOptInt gets a socket option for simple cases where a return // value has the int type. - GetSockOptInt(SockOpt) (int, *Error) + GetSockOptInt(SockOptInt) (int, *Error) // State returns a socket's lifecycle state. The returned value is // protocol-specific and is primarily used for diagnostics. @@ -488,13 +488,13 @@ type WriteOptions struct { Atomic bool } -// SockOpt represents socket options which values have the int type. -type SockOpt int +// SockOptInt represents socket options which values have the int type. +type SockOptInt int const ( // ReceiveQueueSizeOption is used in GetSockOptInt to specify that the // number of unread bytes in the input buffer should be returned. - ReceiveQueueSizeOption SockOpt = iota + ReceiveQueueSizeOption SockOptInt = iota // SendBufferSizeOption is used by SetSockOptInt/GetSockOptInt to // specify the send buffer size option. diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 9c40931b5..5816ce49a 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -351,12 +351,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } // SetSockOptInt sets a socket option. Currently not supported. -func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { +func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: v := 0 diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 0010b5e5f..6360ce880 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -251,12 +251,12 @@ func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. -func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { +func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { +func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { return 0, tcpip.ErrNotSupported } diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 5aafe2615..0fd9c456a 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -510,12 +510,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. -func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { +func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: v := 0 diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index fe629aa40..f79154b95 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1146,7 +1146,7 @@ func (e *endpoint) zeroReceiveWindow(scale uint8) bool { } // SetSockOptInt sets a socket option. -func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { switch opt { case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max @@ -1447,7 +1447,7 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) { } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { +func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: return e.readyReceiveSize() diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 1ac4705af..dae373ea7 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -457,7 +457,7 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { } // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. -func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil } @@ -661,7 +661,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { +func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: v := 0 -- cgit v1.2.3 From e21c5840569155d39e8e11ac18cee99bc6d67469 Mon Sep 17 00:00:00 2001 From: Bert Muthalaly Date: Wed, 8 Jan 2020 14:49:12 -0800 Subject: Combine various Create*NIC methods into CreateNICWithOptions. PiperOrigin-RevId: 288779416 --- pkg/tcpip/stack/ndp_test.go | 6 +-- pkg/tcpip/stack/stack.go | 46 ++++++++++------------ pkg/tcpip/stack/stack_test.go | 10 +++-- pkg/tcpip/stack/transport_demuxer_test.go | 5 ++- pkg/tcpip/transport/tcp/tcp_test.go | 5 ++- pkg/tcpip/transport/tcp/testing/context/context.go | 10 +++-- pkg/tcpip/transport/udp/udp_test.go | 5 ++- runsc/boot/network.go | 5 ++- 8 files changed, 47 insertions(+), 45 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 8d89859ba..070d80c8d 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -2500,9 +2500,9 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { SecretKey: secretKey, }, }) - - if err := s.CreateNamedNIC(nicID, nicName, e); err != nil { - t.Fatalf("CreateNamedNIC(%d, %q, _) = %s", nicID, nicName, err) + opts := stack.NICOptions{Name: nicName} + if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %+v, _) = %s", nicID, opts, err) } expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 807f910f6..fb7ac409e 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -796,9 +796,21 @@ func (s *Stack) NewPacketEndpoint(cooked bool, netProto tcpip.NetworkProtocolNum return s.rawFactory.NewPacketEndpoint(s, cooked, netProto, waiterQueue) } -// createNIC creates a NIC with the provided id and link-layer endpoint, and -// optionally enable it. -func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled bool) *tcpip.Error { +// NICOptions specifies the configuration of a NIC as it is being created. +// The zero value creates an enabled, unnamed NIC. +type NICOptions struct { + // Name specifies the name of the NIC. + Name string + + // Disabled specifies whether to avoid calling Attach on the passed + // LinkEndpoint. + Disabled bool +} + +// CreateNICWithOptions creates a NIC with the provided id, LinkEndpoint, and +// NICOptions. See the documentation on type NICOptions for details on how +// NICs can be configured. +func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOptions) *tcpip.Error { s.mu.Lock() defer s.mu.Unlock() @@ -807,38 +819,20 @@ func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled return tcpip.ErrDuplicateNICID } - n := newNIC(s, id, name, ep) + n := newNIC(s, id, opts.Name, ep) s.nics[id] = n - if enabled { + if !opts.Disabled { return n.enable() } return nil } -// CreateNIC creates a NIC with the provided id and link-layer endpoint. +// CreateNIC creates a NIC with the provided id and LinkEndpoint and calls +// `LinkEndpoint.Attach` to start delivering packets to it. func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error { - return s.createNIC(id, "", ep, true) -} - -// CreateNamedNIC creates a NIC with the provided id and link-layer endpoint, -// and a human-readable name. -func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error { - return s.createNIC(id, name, ep, true) -} - -// CreateDisabledNIC creates a NIC with the provided id and link-layer endpoint, -// but leave it disable. Stack.EnableNIC must be called before the link-layer -// endpoint starts delivering packets to it. -func (s *Stack) CreateDisabledNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error { - return s.createNIC(id, "", ep, false) -} - -// CreateDisabledNamedNIC is a combination of CreateNamedNIC and -// CreateDisabledNIC. -func (s *Stack) CreateDisabledNamedNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error { - return s.createNIC(id, name, ep, false) + return s.CreateNICWithOptions(id, ep, NICOptions{}) } // EnableNIC enables the given NIC so that the link-layer endpoint can start diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 33f20579f..9ac50bb23 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -2097,8 +2097,9 @@ func TestNICAutoGenAddrWithOpaque(t *testing.T) { e := channel.New(10, 1280, test.linkAddr) s := stack.New(opts) - if err := s.CreateNamedNIC(nicID, test.nicName, e); err != nil { - t.Fatalf("CreateNamedNIC(%d, %q, _) = %s", nicID, test.nicName, err) + nicOpts := stack.NICOptions{Name: test.nicName} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err) } addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) @@ -2156,8 +2157,9 @@ func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) { e := loopback.New() s := stack.New(opts) - if err := s.CreateNamedNIC(nicID, nicName, e); err != nil { - t.Fatalf("CreateNamedNIC(%d, %q, _) = %s", nicID, nicName, err) + nicOpts := stack.NICOptions{Name: nicName} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, nicOpts, err) } addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 3b28b06d0..33dbc0536 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -80,8 +80,9 @@ func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string) for i, linkEpName := range linkEpNames { channelEP := channel.New(256, mtu, "") nicID := tcpip.NICID(i + 1) - if err := s.CreateNamedNIC(nicID, linkEpName, channelEP); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + opts := stack.NICOptions{Name: linkEpName} + if err := s.CreateNICWithOptions(nicID, channelEP, opts); err != nil { + t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err) } linkEPs[linkEpName] = channelEP diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index e8fe4dab5..9d7b0910d 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -3794,8 +3794,9 @@ func TestBindToDeviceOption(t *testing.T) { } defer ep.Close() - if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil { - t.Errorf("CreateNamedNIC failed: %v", err) + opts := stack.NICOptions{Name: "my_device"} + if err := s.CreateNICWithOptions(321, loopback.New(), opts); err != nil { + t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err) } // Make an nameless NIC. diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index b0a376eba..50c81aa65 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -158,15 +158,17 @@ func New(t *testing.T, mtu uint32) *Context { if testing.Verbose() { wep = sniffer.New(ep) } - if err := s.CreateNamedNIC(1, "nic1", wep); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + opts := stack.NICOptions{Name: "nic1"} + if err := s.CreateNICWithOptions(1, wep, opts); err != nil { + t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err) } wep2 := stack.LinkEndpoint(channel.New(1000, mtu, "")) if testing.Verbose() { wep2 = sniffer.New(channel.New(1000, mtu, "")) } - if err := s.CreateNamedNIC(2, "nic2", wep2); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + opts2 := stack.NICOptions{Name: "nic2"} + if err := s.CreateNICWithOptions(2, wep2, opts2); err != nil { + t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts2, err) } if err := s.AddAddress(1, ipv4.ProtocolNumber, StackAddr); err != nil { diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 7051a7a9c..65382b7f1 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -508,8 +508,9 @@ func TestBindToDeviceOption(t *testing.T) { } defer ep.Close() - if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil { - t.Errorf("CreateNamedNIC failed: %v", err) + opts := stack.NICOptions{Name: "my_device"} + if err := s.CreateNICWithOptions(321, loopback.New(), opts); err != nil { + t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err) } // Make an nameless NIC. diff --git a/runsc/boot/network.go b/runsc/boot/network.go index 0240fe323..6a8765ec8 100644 --- a/runsc/boot/network.go +++ b/runsc/boot/network.go @@ -219,8 +219,9 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct // createNICWithAddrs creates a NIC in the network stack and adds the given // addresses. func (n *Network) createNICWithAddrs(id tcpip.NICID, name string, ep stack.LinkEndpoint, addrs []net.IP) error { - if err := n.Stack.CreateNamedNIC(id, name, sniffer.New(ep)); err != nil { - return fmt.Errorf("CreateNamedNIC(%v, %v, %v) failed: %v", id, name, ep, err) + opts := stack.NICOptions{Name: name} + if err := n.Stack.CreateNICWithOptions(id, sniffer.New(ep), opts); err != nil { + return fmt.Errorf("CreateNICWithOptions(%d, _, %+v) failed: %v", id, opts, err) } // Always start with an arp address for the NIC. -- cgit v1.2.3 From d530df2f95c3f75488ecc56b8fd205c3ee0966f8 Mon Sep 17 00:00:00 2001 From: Tamir Duberstein Date: Wed, 8 Jan 2020 15:39:22 -0800 Subject: Introduce tcpip.SockOptBool ...and port V6OnlyOption to it. PiperOrigin-RevId: 288789451 --- pkg/sentry/socket/netstack/netstack.go | 21 ++++-- pkg/sentry/socket/unix/transport/unix.go | 16 +++++ pkg/tcpip/stack/ndp_test.go | 15 ++--- pkg/tcpip/stack/transport_demuxer_test.go | 48 +++++++------- pkg/tcpip/stack/transport_test.go | 10 +++ pkg/tcpip/tcpip.go | 21 ++++-- pkg/tcpip/transport/icmp/endpoint.go | 10 +++ pkg/tcpip/transport/packet/endpoint.go | 22 +++++-- pkg/tcpip/transport/raw/endpoint.go | 40 +++++++----- pkg/tcpip/transport/tcp/dual_stack_test.go | 8 +-- pkg/tcpip/transport/tcp/endpoint.go | 75 ++++++++++++---------- pkg/tcpip/transport/tcp/tcp_test.go | 8 +-- pkg/tcpip/transport/tcp/testing/context/context.go | 6 +- pkg/tcpip/transport/udp/endpoint.go | 60 +++++++++-------- pkg/tcpip/transport/udp/udp_test.go | 2 +- 15 files changed, 224 insertions(+), 138 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 5f91a0d1a..9e0d69046 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -222,6 +222,10 @@ type commonEndpoint interface { // transport.Endpoint.SetSockOpt. SetSockOpt(interface{}) *tcpip.Error + // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool and + // transport.Endpoint.SetSockOptBool. + SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error + // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt and // transport.Endpoint.SetSockOptInt. SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error @@ -230,6 +234,10 @@ type commonEndpoint interface { // transport.Endpoint.GetSockOpt. GetSockOpt(interface{}) *tcpip.Error + // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool and + // transport.Endpoint.GetSockOpt. + GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt and // transport.Endpoint.GetSockOpt. GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) @@ -1213,12 +1221,15 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf return nil, syserr.ErrInvalidArgument } - var v tcpip.V6OnlyOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.V6OnlyOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + var o uint32 + if v { + o = 1 + } + return int32(o), nil case linux.IPV6_PATHMTU: t.Kernel().EmitUnimplementedEvent(t) @@ -1621,7 +1632,7 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.V6OnlyOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.V6OnlyOption, v != 0)) case linux.IPV6_ADD_MEMBERSHIP, linux.IPV6_DROP_MEMBERSHIP, diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index fcba49435..37c7ac3c1 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -175,6 +175,10 @@ type Endpoint interface { // types. SetSockOpt(opt interface{}) *tcpip.Error + // SetSockOptBool sets a socket option for simple cases when a value has + // the int type. + SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error + // SetSockOptInt sets a socket option for simple cases when a value has // the int type. SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error @@ -183,6 +187,10 @@ type Endpoint interface { // tcpip.*Option types. GetSockOpt(opt interface{}) *tcpip.Error + // GetSockOptBool gets a socket option for simple cases when a return + // value has the int type. + GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) + // GetSockOptInt gets a socket option for simple cases when a return // value has the int type. GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) @@ -851,10 +859,18 @@ func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error { return nil } +func (e *baseEndpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { + return nil +} + func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil } +func (e *baseEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { + return false, tcpip.ErrUnknownProtocolOption +} + func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 070d80c8d..e51462a55 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -1701,9 +1701,8 @@ func addrForNewConnection(t *testing.T, s *stack.Stack) tcpip.Address { t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) } defer ep.Close() - v := tcpip.V6OnlyOption(1) - if err := ep.SetSockOpt(v); err != nil { - t.Fatalf("SetSockOpt(%+v): %s", v, err) + if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { + t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err) } if err := ep.Connect(dstAddr); err != nil { t.Fatalf("ep.Connect(%+v): %s", dstAddr, err) @@ -1728,9 +1727,8 @@ func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullA t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) } defer ep.Close() - v := tcpip.V6OnlyOption(1) - if err := ep.SetSockOpt(v); err != nil { - t.Fatalf("SetSockOpt(%+v): %s", v, err) + if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { + t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err) } if err := ep.Bind(addr); err != nil { t.Fatalf("ep.Bind(%+v): %s", addr, err) @@ -2066,9 +2064,8 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) } defer ep.Close() - v := tcpip.V6OnlyOption(1) - if err := ep.SetSockOpt(v); err != nil { - t.Fatalf("SetSockOpt(%+v): %s", v, err) + if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { + t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err) } if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute { diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 33dbc0536..df5ced887 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -61,11 +61,7 @@ func (c *testContext) createV6Endpoint(v6only bool) { c.t.Fatalf("NewEndpoint failed: %v", err) } - var v tcpip.V6OnlyOption - if v6only { - v = 1 - } - if err := c.ep.SetSockOpt(v); err != nil { + if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, v6only); err != nil { c.t.Fatalf("SetSockOpt failed: %v", err) } } @@ -201,54 +197,54 @@ func TestDistribution(t *testing.T) { "BindPortReuse", // 5 endpoints that all have reuse set. []endpointSockopts{ - endpointSockopts{1, ""}, - endpointSockopts{1, ""}, - endpointSockopts{1, ""}, - endpointSockopts{1, ""}, - endpointSockopts{1, ""}, + {1, ""}, + {1, ""}, + {1, ""}, + {1, ""}, + {1, ""}, }, map[string][]float64{ // Injected packets on dev0 get distributed evenly. - "dev0": []float64{0.2, 0.2, 0.2, 0.2, 0.2}, + "dev0": {0.2, 0.2, 0.2, 0.2, 0.2}, }, }, { "BindToDevice", // 3 endpoints with various bindings. []endpointSockopts{ - endpointSockopts{0, "dev0"}, - endpointSockopts{0, "dev1"}, - endpointSockopts{0, "dev2"}, + {0, "dev0"}, + {0, "dev1"}, + {0, "dev2"}, }, map[string][]float64{ // Injected packets on dev0 go only to the endpoint bound to dev0. - "dev0": []float64{1, 0, 0}, + "dev0": {1, 0, 0}, // Injected packets on dev1 go only to the endpoint bound to dev1. - "dev1": []float64{0, 1, 0}, + "dev1": {0, 1, 0}, // Injected packets on dev2 go only to the endpoint bound to dev2. - "dev2": []float64{0, 0, 1}, + "dev2": {0, 0, 1}, }, }, { "ReuseAndBindToDevice", // 6 endpoints with various bindings. []endpointSockopts{ - endpointSockopts{1, "dev0"}, - endpointSockopts{1, "dev0"}, - endpointSockopts{1, "dev1"}, - endpointSockopts{1, "dev1"}, - endpointSockopts{1, "dev1"}, - endpointSockopts{1, ""}, + {1, "dev0"}, + {1, "dev0"}, + {1, "dev1"}, + {1, "dev1"}, + {1, "dev1"}, + {1, ""}, }, map[string][]float64{ // Injected packets on dev0 get distributed among endpoints bound to // dev0. - "dev0": []float64{0.5, 0.5, 0, 0, 0, 0}, + "dev0": {0.5, 0.5, 0, 0, 0, 0}, // Injected packets on dev1 get distributed among endpoints bound to // dev1 or unbound. - "dev1": []float64{0, 0, 1. / 3, 1. / 3, 1. / 3, 0}, + "dev1": {0, 0, 1. / 3, 1. / 3, 1. / 3, 0}, // Injected packets on dev999 go only to the unbound. - "dev999": []float64{0, 0, 0, 0, 0, 1}, + "dev999": {0, 0, 0, 0, 0, 1}, }, }, } { diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 095346f0b..f50604a8a 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -102,11 +102,21 @@ func (*fakeTransportEndpoint) SetSockOpt(interface{}) *tcpip.Error { return tcpip.ErrInvalidEndpointState } +// SetSockOptBool sets a socket option. Currently not supported. +func (*fakeTransportEndpoint) SetSockOptBool(tcpip.SockOptBool, bool) *tcpip.Error { + return tcpip.ErrInvalidEndpointState +} + // SetSockOptInt sets a socket option. Currently not supported. func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOptInt, int) *tcpip.Error { return tcpip.ErrInvalidEndpointState } +// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. +func (*fakeTransportEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { + return false, tcpip.ErrUnknownProtocolOption +} + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { return -1, tcpip.ErrUnknownProtocolOption diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index b172d71b0..1eca76c30 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -423,6 +423,10 @@ type Endpoint interface { // SetSockOpt sets a socket option. opt should be one of the *Option types. SetSockOpt(opt interface{}) *Error + // SetSockOptBool sets a socket option, for simple cases where a value + // has the bool type. + SetSockOptBool(opt SockOptBool, v bool) *Error + // SetSockOptInt sets a socket option, for simple cases where a value // has the int type. SetSockOptInt(opt SockOptInt, v int) *Error @@ -431,6 +435,10 @@ type Endpoint interface { // *Option types. GetSockOpt(opt interface{}) *Error + // GetSockOptBool gets a socket option for simple cases where a return + // value has the bool type. + GetSockOptBool(SockOptBool) (bool, *Error) + // GetSockOptInt gets a socket option for simple cases where a return // value has the int type. GetSockOptInt(SockOptInt) (int, *Error) @@ -488,6 +496,15 @@ type WriteOptions struct { Atomic bool } +// SockOptBool represents socket options which values have the bool type. +type SockOptBool int + +const ( + // V6OnlyOption is used by {G,S}etSockOptBool to specify whether an IPv6 + // socket is to be restricted to sending and receiving IPv6 packets only. + V6OnlyOption SockOptBool = iota +) + // SockOptInt represents socket options which values have the int type. type SockOptInt int @@ -521,10 +538,6 @@ const ( // the endpoint should be cleared and returned. type ErrorOption struct{} -// V6OnlyOption is used by SetSockOpt/GetSockOpt to specify whether an IPv6 -// socket is to be restricted to sending and receiving IPv6 packets only. -type V6OnlyOption int - // CorkOption is used by SetSockOpt/GetSockOpt to specify if data should be // held until segments are full by the TCP transport protocol. type CorkOption int diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 5816ce49a..c7ce74cdd 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -350,11 +350,21 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return nil } +// SetSockOptBool sets a socket option. Currently not supported. +func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { + return nil +} + // SetSockOptInt sets a socket option. Currently not supported. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil } +// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. +func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { + return false, tcpip.ErrUnknownProtocolOption +} + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 6360ce880..07ffa8aba 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -247,17 +247,17 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // used with SetSockOpt, and this function always returns // tcpip.ErrNotSupported. func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - return tcpip.ErrNotSupported + return tcpip.ErrUnknownProtocolOption } -// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. -func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { +// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. +func (ep *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } -// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. -func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { - return 0, tcpip.ErrNotSupported +// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. +func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. @@ -265,6 +265,16 @@ func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrNotSupported } +// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. +func (ep *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { + return false, tcpip.ErrNotSupported +} + +// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. +func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { + return 0, tcpip.ErrNotSupported +} + // HandlePacket implements stack.PacketEndpoint.HandlePacket. func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { ep.rcvMu.Lock() diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 0fd9c456a..85f7eb76b 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -509,11 +509,36 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } +// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. +func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } +// GetSockOpt implements tcpip.Endpoint.GetSockOpt. +func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { + switch o := opt.(type) { + case tcpip.ErrorOption: + return nil + + case *tcpip.KeepaliveEnabledOption: + *o = 0 + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } +} + +// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. +func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { + return false, tcpip.ErrUnknownProtocolOption +} + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { @@ -544,21 +569,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { return -1, tcpip.ErrUnknownProtocolOption } -// GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch o := opt.(type) { - case tcpip.ErrorOption: - return nil - - case *tcpip.KeepaliveEnabledOption: - *o = 0 - return nil - - default: - return tcpip.ErrUnknownProtocolOption - } -} - // HandlePacket implements stack.RawTransportEndpoint.HandlePacket. func (e *endpoint) HandlePacket(route *stack.Route, pkt tcpip.PacketBuffer) { e.rcvMu.Lock() diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index dfaa4a559..4f361b226 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -391,9 +391,8 @@ func testV4Accept(t *testing.T, c *context.Context) { // Make sure we get the same error when calling the original ep and the // new one. This validates that v4-mapped endpoints are still able to // query the V6Only flag, whereas pure v4 endpoints are not. - var v tcpip.V6OnlyOption - expected := c.EP.GetSockOpt(&v) - if err := nep.GetSockOpt(&v); err != expected { + _, expected := c.EP.GetSockOptBool(tcpip.V6OnlyOption) + if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != expected { t.Fatalf("GetSockOpt returned unexpected value: got %v, want %v", err, expected) } @@ -531,8 +530,7 @@ func TestV6AcceptOnV6(t *testing.T) { // Make sure we can still query the v6 only status of the new endpoint, // that is, that it is in fact a v6 socket. - var v tcpip.V6OnlyOption - if err := nep.GetSockOpt(&v); err != nil { + if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != nil { t.Fatalf("GetSockOpt failed failed: %v", err) } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index f79154b95..2ac1b6877 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1145,6 +1145,29 @@ func (e *endpoint) zeroReceiveWindow(scale uint8) bool { return ((e.rcvBufSize - e.rcvBufUsed) >> scale) == 0 } +// SetSockOptBool sets a socket option. +func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { + switch opt { + case tcpip.V6OnlyOption: + // We only recognize this option on v6 endpoints. + if e.NetProto != header.IPv6ProtocolNumber { + return tcpip.ErrInvalidEndpointState + } + + e.mu.Lock() + defer e.mu.Unlock() + + // We only allow this to be set when we're in the initial state. + if e.state != StateInitial { + return tcpip.ErrInvalidEndpointState + } + + e.v6only = v + } + + return nil +} + // SetSockOptInt sets a socket option. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { switch opt { @@ -1289,23 +1312,6 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.notifyProtocolGoroutine(notifyMSSChanged) return nil - case tcpip.V6OnlyOption: - // We only recognize this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return tcpip.ErrInvalidEndpointState - } - - e.mu.Lock() - defer e.mu.Unlock() - - // We only allow this to be set when we're in the initial state. - if e.state != StateInitial { - return tcpip.ErrInvalidEndpointState - } - - e.v6only = v != 0 - return nil - case tcpip.TTLOption: e.mu.Lock() e.ttl = uint8(v) @@ -1446,6 +1452,25 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) { return e.rcvBufUsed, nil } +// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. +func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { + switch opt { + case tcpip.V6OnlyOption: + // We only recognize this option on v6 endpoints. + if e.NetProto != header.IPv6ProtocolNumber { + return false, tcpip.ErrUnknownProtocolOption + } + + e.mu.Lock() + v := e.v6only + e.mu.Unlock() + + return v, nil + } + + return false, tcpip.ErrUnknownProtocolOption +} + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { @@ -1540,22 +1565,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } return nil - case *tcpip.V6OnlyOption: - // We only recognize this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return tcpip.ErrUnknownProtocolOption - } - - e.mu.Lock() - v := e.v6only - e.mu.Unlock() - - *o = 0 - if v { - *o = 1 - } - return nil - case *tcpip.TTLOption: e.mu.Lock() *o = tcpip.TTLOption(e.ttl) diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 9d7b0910d..15745ebd4 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -4028,12 +4028,12 @@ func TestConnectAvoidsBoundPorts(t *testing.T) { switch network { case "ipv4": case "ipv6": - if err := ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil { - t.Fatalf("SetSockOpt(V6OnlyOption(1)) failed: %v", err) + if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { + t.Fatalf("SetSockOpt(V6OnlyOption(true)) failed: %v", err) } case "dual": - if err := ep.SetSockOpt(tcpip.V6OnlyOption(0)); err != nil { - t.Fatalf("SetSockOpt(V6OnlyOption(0)) failed: %v", err) + if err := ep.SetSockOptBool(tcpip.V6OnlyOption, false); err != nil { + t.Fatalf("SetSockOpt(V6OnlyOption(false)) failed: %v", err) } default: t.Fatalf("unknown network: '%s'", network) diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 50c81aa65..822907998 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -475,11 +475,7 @@ func (c *Context) CreateV6Endpoint(v6only bool) { c.t.Fatalf("NewEndpoint failed: %v", err) } - var v tcpip.V6OnlyOption - if v6only { - v = 1 - } - if err := c.EP.SetSockOpt(v); err != nil { + if err := c.EP.SetSockOptBool(tcpip.V6OnlyOption, v6only); err != nil { c.t.Fatalf("SetSockOpt failed failed: %v", err) } } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index dae373ea7..1a5ee6317 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -456,14 +456,9 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { return 0, tcpip.ControlMessages{}, nil } -// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. -func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { - return nil -} - -// SetSockOpt implements tcpip.Endpoint.SetSockOpt. -func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - switch v := opt.(type) { +// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. +func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { + switch opt { case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. if e.NetProto != header.IPv6ProtocolNumber { @@ -478,8 +473,20 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - e.v6only = v != 0 + e.v6only = v + } + + return nil +} +// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { + return nil +} + +// SetSockOpt implements tcpip.Endpoint.SetSockOpt. +func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + switch v := opt.(type) { case tcpip.TTLOption: e.mu.Lock() e.ttl = uint8(v) @@ -660,6 +667,25 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return nil } +// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. +func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { + switch opt { + case tcpip.V6OnlyOption: + // We only recognize this option on v6 endpoints. + if e.NetProto != header.IPv6ProtocolNumber { + return false, tcpip.ErrUnknownProtocolOption + } + + e.mu.Lock() + v := e.v6only + e.mu.Unlock() + + return v, nil + } + + return false, tcpip.ErrUnknownProtocolOption +} + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { @@ -695,22 +721,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { case tcpip.ErrorOption: return nil - case *tcpip.V6OnlyOption: - // We only recognize this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return tcpip.ErrUnknownProtocolOption - } - - e.mu.Lock() - v := e.v6only - e.mu.Unlock() - - *o = 0 - if v { - *o = 1 - } - return nil - case *tcpip.TTLOption: e.mu.Lock() *o = tcpip.TTLOption(e.ttl) diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 65382b7f1..149fff999 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -335,7 +335,7 @@ func (c *testContext) createEndpointForFlow(flow testFlow) { c.createEndpoint(flow.sockProto()) if flow.isV6Only() { - if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil { + if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { c.t.Fatalf("SetSockOpt failed: %v", err) } } else if flow.isBroadcast() { -- cgit v1.2.3 From 0999ae8b34d83a4b2ea8342d0459c8131c35d6e1 Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Wed, 8 Jan 2020 15:57:25 -0800 Subject: Getting a panic when running tests. For some reason the filter table is ending up with the wrong chains and is indexing -1 into rules. --- pkg/sentry/socket/netfilter/netfilter.go | 17 ++++++---------- pkg/sentry/socket/netstack/netstack.go | 12 +++++++++-- pkg/tcpip/BUILD | 1 - pkg/tcpip/iptables/BUILD | 1 + pkg/tcpip/iptables/iptables.go | 35 +++++++++++++++++++++++++------- pkg/tcpip/iptables/targets.go | 8 ++++---- pkg/tcpip/iptables/types.go | 8 +++----- pkg/tcpip/network/arp/arp.go | 2 +- pkg/tcpip/network/ipv4/BUILD | 1 + pkg/tcpip/network/ipv4/ipv4.go | 8 ++++++-- pkg/tcpip/network/ipv6/ipv6.go | 2 +- pkg/tcpip/stack/nic.go | 2 +- pkg/tcpip/stack/registration.go | 2 +- pkg/tcpip/tcpip.go | 4 ---- 14 files changed, 63 insertions(+), 40 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 57785220e..3a857ef6d 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -25,7 +25,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserr" - "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -45,7 +44,7 @@ type metadata struct { } // GetInfo returns information about iptables. -func GetInfo(t *kernel.Task, ep tcpip.Endpoint, outPtr usermem.Addr) (linux.IPTGetinfo, *syserr.Error) { +func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr) (linux.IPTGetinfo, *syserr.Error) { // Read in the struct and table name. var info linux.IPTGetinfo if _, err := t.CopyIn(outPtr, &info); err != nil { @@ -53,7 +52,7 @@ func GetInfo(t *kernel.Task, ep tcpip.Endpoint, outPtr usermem.Addr) (linux.IPTG } // Find the appropriate table. - table, err := findTable(ep, info.Name.String()) + table, err := findTable(stack, info.Name.String()) if err != nil { return linux.IPTGetinfo{}, err } @@ -76,7 +75,7 @@ func GetInfo(t *kernel.Task, ep tcpip.Endpoint, outPtr usermem.Addr) (linux.IPTG } // GetEntries returns netstack's iptables rules encoded for the iptables tool. -func GetEntries(t *kernel.Task, ep tcpip.Endpoint, outPtr usermem.Addr, outLen int) (linux.KernelIPTGetEntries, *syserr.Error) { +func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen int) (linux.KernelIPTGetEntries, *syserr.Error) { // Read in the struct and table name. var userEntries linux.IPTGetEntries if _, err := t.CopyIn(outPtr, &userEntries); err != nil { @@ -84,7 +83,7 @@ func GetEntries(t *kernel.Task, ep tcpip.Endpoint, outPtr usermem.Addr, outLen i } // Find the appropriate table. - table, err := findTable(ep, userEntries.Name.String()) + table, err := findTable(stack, userEntries.Name.String()) if err != nil { return linux.KernelIPTGetEntries{}, err } @@ -103,12 +102,8 @@ func GetEntries(t *kernel.Task, ep tcpip.Endpoint, outPtr usermem.Addr, outLen i return entries, nil } -func findTable(ep tcpip.Endpoint, tableName string) (iptables.Table, *syserr.Error) { - ipt, err := ep.IPTables() - if err != nil { - return iptables.Table{}, syserr.FromError(err) - } - table, ok := ipt.Tables[tableName] +func findTable(stack *stack.Stack, tableName string) (iptables.Table, *syserr.Error) { + table, ok := stack.IPTables().Tables[tableName] if !ok { return iptables.Table{}, syserr.ErrInvalidArgument } diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 8c07eef4b..86a8104df 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -826,7 +826,11 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us return nil, syserr.ErrInvalidArgument } - info, err := netfilter.GetInfo(t, s.Endpoint, outPtr) + stack := inet.StackFromContext(t) + if stack == nil { + return nil, syserr.ErrNoDevice + } + info, err := netfilter.GetInfo(t, stack.(*Stack).Stack, outPtr) if err != nil { return nil, err } @@ -837,7 +841,11 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us return nil, syserr.ErrInvalidArgument } - entries, err := netfilter.GetEntries(t, s.Endpoint, outPtr, outLen) + stack := inet.StackFromContext(t) + if stack == nil { + return nil, syserr.ErrNoDevice + } + entries, err := netfilter.GetEntries(t, stack.(*Stack).Stack, outPtr, outLen) if err != nil { return nil, err } diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index 65d4d0cd8..36bc3a63b 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -15,7 +15,6 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/tcpip/buffer", - "//pkg/tcpip/iptables", "//pkg/waiter", ], ) diff --git a/pkg/tcpip/iptables/BUILD b/pkg/tcpip/iptables/BUILD index 6ed7c6da0..2893c80cd 100644 --- a/pkg/tcpip/iptables/BUILD +++ b/pkg/tcpip/iptables/BUILD @@ -12,6 +12,7 @@ go_library( importpath = "gvisor.dev/gvisor/pkg/tcpip/iptables", visibility = ["//visibility:public"], deps = [ + "//pkg/log", "//pkg/tcpip", ], ) diff --git a/pkg/tcpip/iptables/iptables.go b/pkg/tcpip/iptables/iptables.go index 025a4679d..aff8a680b 100644 --- a/pkg/tcpip/iptables/iptables.go +++ b/pkg/tcpip/iptables/iptables.go @@ -16,7 +16,12 @@ // tool. package iptables -import "github.com/google/netstack/tcpip" +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/tcpip" +) const ( TablenameNat = "nat" @@ -135,31 +140,47 @@ func (it *IPTables) Check(hook Hook, pkt tcpip.PacketBuffer) bool { // Go through each table containing the hook. for _, tablename := range it.Priorities[hook] { - verdict := it.checkTable(tablename) + verdict := it.checkTable(hook, pkt, tablename) switch verdict { - // TODO: We either got a final verdict or simply continue on. + // If the table returns Accept, move on to the next table. + case Accept: + continue + // The Drop verdict is final. + case Drop: + log.Infof("kevin: Packet dropped") + return false + case Stolen, Queue, Repeat, None, Jump, Return, Continue: + panic(fmt.Sprintf("Unimplemented verdict %v.", verdict)) } } + + // Every table returned Accept. + log.Infof("kevin: Packet accepted") + return true } -func (it *IPTables) checkTable(hook Hook, pkt tcpip.PacketBuffer, tablename string) bool { +func (it *IPTables) checkTable(hook Hook, pkt tcpip.PacketBuffer, tablename string) Verdict { log.Infof("kevin: iptables.IPTables: checking table %q", tablename) table := it.Tables[tablename] - ruleIdx := table.BuiltinChains[hook] + log.Infof("kevin: iptables.IPTables: table %+v", table) // Start from ruleIdx and go down until a rule gives us a verdict. for ruleIdx := table.BuiltinChains[hook]; ruleIdx < len(table.Rules); ruleIdx++ { - verdict := checkRule(hook, pkt, table, ruleIdx) + verdict := it.checkRule(hook, pkt, table, ruleIdx) switch verdict { + // For either of these cases, this table is done with the + // packet. case Accept, Drop: return verdict + // Continue traversing the rules of the table. case Continue: continue case Stolen, Queue, Repeat, None, Jump, Return: + panic(fmt.Sprintf("Unimplemented verdict %v.", verdict)) } } - panic("Traversed past the entire list of iptables rules.") + panic(fmt.Sprintf("Traversed past the entire list of iptables rules in table %q.", tablename)) } func (it *IPTables) checkRule(hook Hook, pkt tcpip.PacketBuffer, table Table, ruleIdx int) Verdict { diff --git a/pkg/tcpip/iptables/targets.go b/pkg/tcpip/iptables/targets.go index 2c3598e3d..cb3ac1aff 100644 --- a/pkg/tcpip/iptables/targets.go +++ b/pkg/tcpip/iptables/targets.go @@ -16,13 +16,13 @@ package iptables -import "gvisor.dev/gvisor/pkg/tcpip/buffer" +import "gvisor.dev/gvisor/pkg/tcpip" // UnconditionalAcceptTarget accepts all packets. type UnconditionalAcceptTarget struct{} // Action implements Target.Action. -func (UnconditionalAcceptTarget) Action(packet buffer.VectorisedView) (Verdict, string) { +func (UnconditionalAcceptTarget) Action(packet tcpip.PacketBuffer) (Verdict, string) { return Accept, "" } @@ -30,7 +30,7 @@ func (UnconditionalAcceptTarget) Action(packet buffer.VectorisedView) (Verdict, type UnconditionalDropTarget struct{} // Action implements Target.Action. -func (UnconditionalDropTarget) Action(packet buffer.VectorisedView) (Verdict, string) { +func (UnconditionalDropTarget) Action(packet tcpip.PacketBuffer) (Verdict, string) { return Drop, "" } @@ -38,6 +38,6 @@ func (UnconditionalDropTarget) Action(packet buffer.VectorisedView) (Verdict, st type PanicTarget struct{} // Actions implements Target.Action. -func (PanicTarget) Action(packet buffer.VectorisedView) (Verdict, string) { +func (PanicTarget) Action(packet tcpip.PacketBuffer) (Verdict, string) { panic("PanicTarget triggered.") } diff --git a/pkg/tcpip/iptables/types.go b/pkg/tcpip/iptables/types.go index 540f8c0b4..9f6906100 100644 --- a/pkg/tcpip/iptables/types.go +++ b/pkg/tcpip/iptables/types.go @@ -14,9 +14,7 @@ package iptables -import ( - "gvisor.dev/gvisor/pkg/tcpip/buffer" -) +import "gvisor.dev/gvisor/pkg/tcpip" // A Hook specifies one of the hooks built into the network stack. // @@ -165,7 +163,7 @@ type Matcher interface { // Match returns whether the packet matches and whether the packet // should be "hotdropped", i.e. dropped immediately. This is usually // used for suspicious packets. - Match(hook Hook, packet buffer.VectorisedView, interfaceName string) (matches bool, hotdrop bool) + Match(hook Hook, packet tcpip.PacketBuffer, interfaceName string) (matches bool, hotdrop bool) } // A Target is the interface for taking an action for a packet. @@ -173,5 +171,5 @@ type Target interface { // Action takes an action on the packet and returns a verdict on how // traversal should (or should not) continue. If the return value is // Jump, it also returns the name of the chain to jump to. - Action(packet buffer.VectorisedView) (Verdict, string) + Action(packet tcpip.PacketBuffer) (Verdict, string) } diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index da8482509..d88119f68 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -137,7 +137,7 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress } -func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint, st *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) { if addrWithPrefix.Address != ProtocolAddress { return nil, tcpip.ErrBadLocalAddress } diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index aeddfcdd4..4e2aae9a3 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -15,6 +15,7 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/iptables", "//pkg/tcpip/network/fragmentation", "//pkg/tcpip/network/hash", "//pkg/tcpip/stack", diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index bbb5aafee..f856081e6 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -26,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation" "gvisor.dev/gvisor/pkg/tcpip/network/hash" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -54,10 +55,11 @@ type endpoint struct { dispatcher stack.TransportDispatcher fragmentation *fragmentation.Fragmentation protocol *protocol + stack *stack.Stack } // NewEndpoint creates a new ipv4 endpoint. -func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) { e := &endpoint{ nicID: nicID, id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, @@ -66,6 +68,7 @@ func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWi dispatcher: dispatcher, fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout), protocol: p, + stack: st, } return e, nil @@ -351,7 +354,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { pkt.NetworkHeader = headerView[:h.HeaderLength()] // iptables filtering. - if ok := iptables.Check(iptables.Input, pkt); !ok { + ipt := e.stack.IPTables() + if ok := ipt.Check(iptables.Input, pkt); !ok { // iptables is telling us to drop the packet. return } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index e13f1fabf..4c940e9e5 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -221,7 +221,7 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { } // NewEndpoint creates a new ipv6 endpoint. -func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) { return &endpoint{ nicID: nicID, id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 4144d5d0f..f2d338bd1 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -467,7 +467,7 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar } // Create the new network endpoint. - ep, err := netProto.NewEndpoint(n.id, protocolAddress.AddressWithPrefix, n.stack, n, n.linkEP) + ep, err := netProto.NewEndpoint(n.id, protocolAddress.AddressWithPrefix, n.stack, n, n.linkEP, n.stack) if err != nil { return nil, err } diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 61fd46d66..754323e82 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -282,7 +282,7 @@ type NetworkProtocol interface { ParseAddresses(v buffer.View) (src, dst tcpip.Address) // NewEndpoint creates a new endpoint of this protocol. - NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint) (NetworkEndpoint, *tcpip.Error) + NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint, st *Stack) (NetworkEndpoint, *tcpip.Error) // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index f62fd729f..d02950c7a 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -40,7 +40,6 @@ import ( "time" "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/waiter" ) @@ -446,9 +445,6 @@ type Endpoint interface { // NOTE: This method is a no-op for sockets other than TCP. ModerateRecvBuf(copied int) - // IPTables returns the iptables for this endpoint's stack. - IPTables() (iptables.IPTables, error) - // Info returns a copy to the transport endpoint info. Info() EndpointInfo -- cgit v1.2.3 From d057871f410088fe6825b1dde695f015e36abf73 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Wed, 8 Jan 2020 17:19:35 -0800 Subject: CancellableTimer to encapsulate the work of safely stopping timers Add a new CancellableTimer type to encapsulate the work of safely stopping timers when it fires at the same time some "related work" is being handled. The term "related work" is some work that needs to be done while having obtained some common lock (L). Example: Say we have an invalidation timer that may be extended or cancelled by some event. Creating a normal timer and simply cancelling may not be sufficient as the timer may have already fired when the event handler attemps to cancel it. Even if the timer and event handler obtains L before doing work, once the event handler releases L, the timer will eventually obtain L and do some unwanted work. To prevent the timer from doing unwanted work, it checks if it should early return instead of doing the normal work after obtaining L. When stopping the timer callers must have L locked so the timer can be safely informed that it should early return. Test: Tests that CancellableTimer fires and resets properly. Test to make sure the timer fn is not called after being stopped within the lock L. PiperOrigin-RevId: 288806984 --- pkg/tcpip/BUILD | 8 + pkg/tcpip/stack/ndp.go | 349 ++++++++------------------------------------ pkg/tcpip/stack/ndp_test.go | 12 +- pkg/tcpip/timer.go | 161 ++++++++++++++++++++ pkg/tcpip/timer_test.go | 236 ++++++++++++++++++++++++++++++ 5 files changed, 473 insertions(+), 293 deletions(-) create mode 100644 pkg/tcpip/timer.go create mode 100644 pkg/tcpip/timer_test.go (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index 65d4d0cd8..e07ebd153 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -10,6 +10,7 @@ go_library( "packet_buffer_state.go", "tcpip.go", "time_unsafe.go", + "timer.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip", visibility = ["//visibility:public"], @@ -26,3 +27,10 @@ go_test( srcs = ["tcpip_test.go"], embed = [":tcpip"], ) + +go_test( + name = "timer_test", + size = "small", + srcs = ["timer_test.go"], + deps = [":tcpip"], +) diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 4722ec9ce..35825ebf7 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -299,46 +299,14 @@ type dadState struct { // defaultRouterState holds data associated with a default router discovered by // a Router Advertisement (RA). type defaultRouterState struct { - invalidationTimer *time.Timer - - // Used to inform the timer not to invalidate the default router (R) in - // a race condition (T1 is a goroutine that handles an RA from R and T2 - // is the goroutine that handles R's invalidation timer firing): - // T1: Receive a new RA from R - // T1: Obtain the NIC's lock before processing the RA - // T2: R's invalidation timer fires, and gets blocked on obtaining the - // NIC's lock - // T1: Refreshes/extends R's lifetime & releases NIC's lock - // T2: Obtains NIC's lock & invalidates R immediately - // - // To resolve this, T1 will check to see if the timer already fired, and - // inform the timer using doNotInvalidate to not invalidate R, so that - // once T2 obtains the lock, it will see that it is set to true and do - // nothing further. - doNotInvalidate *bool + invalidationTimer tcpip.CancellableTimer } // onLinkPrefixState holds data associated with an on-link prefix discovered by // a Router Advertisement's Prefix Information option (PI) when the NDP // configurations was configured to do so. type onLinkPrefixState struct { - invalidationTimer *time.Timer - - // Used to signal the timer not to invalidate the on-link prefix (P) in - // a race condition (T1 is a goroutine that handles a PI for P and T2 - // is the goroutine that handles P's invalidation timer firing): - // T1: Receive a new PI for P - // T1: Obtain the NIC's lock before processing the PI - // T2: P's invalidation timer fires, and gets blocked on obtaining the - // NIC's lock - // T1: Refreshes/extends P's lifetime & releases NIC's lock - // T2: Obtains NIC's lock & invalidates P immediately - // - // To resolve this, T1 will check to see if the timer already fired, and - // inform the timer using doNotInvalidate to not invalidate P, so that - // once T2 obtains the lock, it will see that it is set to true and do - // nothing further. - doNotInvalidate *bool + invalidationTimer tcpip.CancellableTimer } // autoGenAddressState holds data associated with an address generated via @@ -348,33 +316,10 @@ type autoGenAddressState struct { // is holding state for. ref *referencedNetworkEndpoint - deprecationTimer *time.Timer - - // Used to signal the timer not to deprecate the SLAAC address in a race - // condition. Used for the same reason as doNotInvalidate, but for deprecating - // an address. - doNotDeprecate *bool + deprecationTimer tcpip.CancellableTimer + invalidationTimer tcpip.CancellableTimer - invalidationTimer *time.Timer - - // Used to signal the timer not to invalidate the SLAAC address (A) in - // a race condition (T1 is a goroutine that handles a PI for A and T2 - // is the goroutine that handles A's invalidation timer firing): - // T1: Receive a new PI for A - // T1: Obtain the NIC's lock before processing the PI - // T2: A's invalidation timer fires, and gets blocked on obtaining the - // NIC's lock - // T1: Refreshes/extends A's lifetime & releases NIC's lock - // T2: Obtains NIC's lock & invalidates A immediately - // - // To resolve this, T1 will check to see if the timer already fired, and - // inform the timer using doNotInvalidate to not invalidate A, so that - // once T2 obtains the lock, it will see that it is set to true and do - // nothing further. - doNotInvalidate *bool - - // Nonzero only when the address is not valid forever (invalidationTimer - // is not nil). + // Nonzero only when the address is not valid forever. validUntil time.Time } @@ -576,7 +521,7 @@ func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) { // handleRA handles a Router Advertisement message that arrived on the NIC // this ndp is for. Does nothing if the NIC is configured to not handle RAs. // -// The NIC that ndp belongs to and its associated stack MUST be locked. +// The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { // Is the NIC configured to handle RAs at all? // @@ -605,27 +550,9 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { case ok && rl != 0: // This is an already discovered default router. Update // the invalidation timer. - timer := rtr.invalidationTimer - - // We should ALWAYS have an invalidation timer for a - // discovered router. - if timer == nil { - panic("ndphandlera: RA invalidation timer should not be nil") - } - - if !timer.Stop() { - // If we reach this point, then we know the - // timer fired after we already took the NIC - // lock. Inform the timer not to invalidate the - // router when it obtains the lock as we just - // got a new RA that refreshes its lifetime to a - // non-zero value. See - // defaultRouterState.doNotInvalidate for more - // details. - *rtr.doNotInvalidate = true - } - - timer.Reset(rl) + rtr.invalidationTimer.StopLocked() + rtr.invalidationTimer.Reset(rl) + ndp.defaultRouters[ip] = rtr case ok && rl == 0: // We know about the router but it is no longer to be @@ -692,10 +619,7 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { return } - rtr.invalidationTimer.Stop() - rtr.invalidationTimer = nil - *rtr.doNotInvalidate = true - rtr.doNotInvalidate = nil + rtr.invalidationTimer.StopLocked() delete(ndp.defaultRouters, ip) @@ -724,27 +648,15 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { return } - // Used to signal the timer not to invalidate the default router (R) in - // a race condition. See defaultRouterState.doNotInvalidate for more - // details. - var doNotInvalidate bool - - ndp.defaultRouters[ip] = defaultRouterState{ - invalidationTimer: time.AfterFunc(rl, func() { - ndp.nic.stack.mu.Lock() - defer ndp.nic.stack.mu.Unlock() - ndp.nic.mu.Lock() - defer ndp.nic.mu.Unlock() - - if doNotInvalidate { - doNotInvalidate = false - return - } - + state := defaultRouterState{ + invalidationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() { ndp.invalidateDefaultRouter(ip) }), - doNotInvalidate: &doNotInvalidate, } + + state.invalidationTimer.Reset(rl) + + ndp.defaultRouters[ip] = state } // rememberOnLinkPrefix remembers a newly discovered on-link prefix with IPv6 @@ -766,21 +678,17 @@ func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) return } - // Used to signal the timer not to invalidate the on-link prefix (P) in - // a race condition. See onLinkPrefixState.doNotInvalidate for more - // details. - var doNotInvalidate bool - var timer *time.Timer + state := onLinkPrefixState{ + invalidationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() { + ndp.invalidateOnLinkPrefix(prefix) + }), + } - // Only create a timer if the lifetime is not infinite. if l < header.NDPInfiniteLifetime { - timer = ndp.prefixInvalidationCallback(prefix, l, &doNotInvalidate) + state.invalidationTimer.Reset(l) } - ndp.onLinkPrefixes[prefix] = onLinkPrefixState{ - invalidationTimer: timer, - doNotInvalidate: &doNotInvalidate, - } + ndp.onLinkPrefixes[prefix] = state } // invalidateOnLinkPrefix invalidates a discovered on-link prefix. @@ -795,13 +703,7 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) { return } - if s.invalidationTimer != nil { - s.invalidationTimer.Stop() - s.invalidationTimer = nil - *s.doNotInvalidate = true - } - - s.doNotInvalidate = nil + s.invalidationTimer.StopLocked() delete(ndp.onLinkPrefixes, prefix) @@ -811,28 +713,6 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) { } } -// prefixInvalidationCallback returns a new on-link prefix invalidation timer -// for prefix that fires after vl. -// -// doNotInvalidate is used to signal the timer when it fires at the same time -// that a prefix's valid lifetime gets refreshed. See -// onLinkPrefixState.doNotInvalidate for more details. -func (ndp *ndpState) prefixInvalidationCallback(prefix tcpip.Subnet, vl time.Duration, doNotInvalidate *bool) *time.Timer { - return time.AfterFunc(vl, func() { - ndp.nic.stack.mu.Lock() - defer ndp.nic.stack.mu.Unlock() - ndp.nic.mu.Lock() - defer ndp.nic.mu.Unlock() - - if *doNotInvalidate { - *doNotInvalidate = false - return - } - - ndp.invalidateOnLinkPrefix(prefix) - }) -} - // handleOnLinkPrefixInformation handles a Prefix Information option with // its on-link flag set, as per RFC 4861 section 6.3.4. // @@ -872,42 +752,17 @@ func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformatio // This is an already discovered on-link prefix with a // new non-zero valid lifetime. + // // Update the invalidation timer. - timer := prefixState.invalidationTimer - - if timer == nil && vl >= header.NDPInfiniteLifetime { - // Had infinite valid lifetime before and - // continues to have an invalid lifetime. Do - // nothing further. - return - } - if timer != nil && !timer.Stop() { - // If we reach this point, then we know the timer alread fired - // after we took the NIC lock. Inform the timer to not - // invalidate the prefix once it obtains the lock as we just - // got a new PI that refreshes its lifetime to a non-zero value. - // See onLinkPrefixState.doNotInvalidate for more details. - *prefixState.doNotInvalidate = true - } + prefixState.invalidationTimer.StopLocked() - if vl >= header.NDPInfiniteLifetime { - // Prefix is now valid forever so we don't need - // an invalidation timer. - prefixState.invalidationTimer = nil - ndp.onLinkPrefixes[prefix] = prefixState - return - } - - if timer != nil { - // We already have a timer so just reset it to - // expire after the new valid lifetime. - timer.Reset(vl) - return + if vl < header.NDPInfiniteLifetime { + // Prefix is valid for a finite lifetime, reset the timer to expire after + // the new valid lifetime. + prefixState.invalidationTimer.Reset(vl) } - // We do not have a timer so just create a new one. - prefixState.invalidationTimer = ndp.prefixInvalidationCallback(prefix, vl, prefixState.doNotInvalidate) ndp.onLinkPrefixes[prefix] = prefixState } @@ -917,7 +772,7 @@ func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformatio // handleAutonomousPrefixInformation assumes that the prefix this pi is for is // not the link-local prefix and the autonomous flag is set. // -// The NIC that ndp belongs to and its associated stack MUST be locked. +// The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInformation) { vl := pi.ValidLifetime() pl := pi.PreferredLifetime() @@ -1026,28 +881,34 @@ func (ndp *ndpState) newAutoGenAddress(prefix tcpip.Subnet, pl, vl time.Duration log.Fatalf("ndp: error when adding address %s: %s", protocolAddr, err) } - // Setup the timers to deprecate and invalidate this newly generated address. + state := autoGenAddressState{ + ref: ref, + deprecationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() { + addrState, ok := ndp.autoGenAddresses[addr] + if !ok { + log.Fatalf("ndp: must have an autoGenAddressess entry for the SLAAC generated IPv6 address %s", addr) + } + addrState.ref.deprecated = true + ndp.notifyAutoGenAddressDeprecated(addr) + }), + invalidationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() { + ndp.invalidateAutoGenAddress(addr) + }), + } + + // Setup the initial timers to deprecate and invalidate this newly generated + // address. - var doNotDeprecate bool - var pTimer *time.Timer if !deprecated && pl < header.NDPInfiniteLifetime { - pTimer = ndp.autoGenAddrDeprecationTimer(addr, pl, &doNotDeprecate) + state.deprecationTimer.Reset(pl) } - var doNotInvalidate bool - var vTimer *time.Timer if vl < header.NDPInfiniteLifetime { - vTimer = ndp.autoGenAddrInvalidationTimer(addr, vl, &doNotInvalidate) + state.invalidationTimer.Reset(vl) + state.validUntil = time.Now().Add(vl) } - ndp.autoGenAddresses[addr] = autoGenAddressState{ - ref: ref, - deprecationTimer: pTimer, - doNotDeprecate: &doNotDeprecate, - invalidationTimer: vTimer, - doNotInvalidate: &doNotInvalidate, - validUntil: time.Now().Add(vl), - } + ndp.autoGenAddresses[addr] = state } // refreshAutoGenAddressLifetimes refreshes the lifetime of a SLAAC generated @@ -1075,20 +936,10 @@ func (ndp *ndpState) refreshAutoGenAddressLifetimes(addr tcpip.Address, pl, vl t // If addr was preferred for some finite lifetime before, stop the deprecation // timer so it can be reset. - if t := addrState.deprecationTimer; t != nil && !t.Stop() { - *addrState.doNotDeprecate = true - } + addrState.deprecationTimer.StopLocked() - // Reset the deprecation timer. - if pl >= header.NDPInfiniteLifetime || deprecated { - // If addr is preferred forever or it has been deprecated already, there is - // no need for a deprecation timer. - addrState.deprecationTimer = nil - } else if addrState.deprecationTimer == nil { - // addr is now preferred for a finite lifetime. - addrState.deprecationTimer = ndp.autoGenAddrDeprecationTimer(addr, pl, addrState.doNotDeprecate) - } else { - // addr continues to be preferred for a finite lifetime. + // Reset the deprecation timer if addr has a finite preferred lifetime. + if !deprecated && pl < header.NDPInfiniteLifetime { addrState.deprecationTimer.Reset(pl) } @@ -1107,15 +958,8 @@ func (ndp *ndpState) refreshAutoGenAddressLifetimes(addr tcpip.Address, pl, vl t // Handle the infinite valid lifetime separately as we do not keep a timer in // this case. if vl >= header.NDPInfiniteLifetime { - if addrState.invalidationTimer != nil { - // Valid lifetime was finite before, but now it is valid forever. - if !addrState.invalidationTimer.Stop() { - *addrState.doNotInvalidate = true - } - addrState.invalidationTimer = nil - addrState.validUntil = time.Time{} - } - + addrState.invalidationTimer.StopLocked() + addrState.validUntil = time.Time{} return } @@ -1124,7 +968,7 @@ func (ndp *ndpState) refreshAutoGenAddressLifetimes(addr tcpip.Address, pl, vl t // If the address was originally set to be valid forever, assume the remaining // time to be the maximum possible value. - if addrState.invalidationTimer == nil { + if addrState.validUntil == (time.Time{}) { rl = header.NDPInfiniteLifetime } else { rl = time.Until(addrState.validUntil) @@ -1138,15 +982,8 @@ func (ndp *ndpState) refreshAutoGenAddressLifetimes(addr tcpip.Address, pl, vl t effectiveVl = MinPrefixInformationValidLifetimeForUpdate } - if addrState.invalidationTimer == nil { - addrState.invalidationTimer = ndp.autoGenAddrInvalidationTimer(addr, effectiveVl, addrState.doNotInvalidate) - } else { - if !addrState.invalidationTimer.Stop() { - *addrState.doNotInvalidate = true - } - addrState.invalidationTimer.Reset(effectiveVl) - } - + addrState.invalidationTimer.StopLocked() + addrState.invalidationTimer.Reset(effectiveVl) addrState.validUntil = time.Now().Add(effectiveVl) } @@ -1181,27 +1018,12 @@ func (ndp *ndpState) invalidateAutoGenAddress(addr tcpip.Address) { // The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bool { state, ok := ndp.autoGenAddresses[addr] - if !ok { return false } - if state.deprecationTimer != nil { - state.deprecationTimer.Stop() - state.deprecationTimer = nil - *state.doNotDeprecate = true - } - - state.doNotDeprecate = nil - - if state.invalidationTimer != nil { - state.invalidationTimer.Stop() - state.invalidationTimer = nil - *state.doNotInvalidate = true - } - - state.doNotInvalidate = nil - + state.deprecationTimer.StopLocked() + state.invalidationTimer.StopLocked() delete(ndp.autoGenAddresses, addr) if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { @@ -1214,53 +1036,6 @@ func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bo return true } -// autoGenAddrDeprecationTimer returns a new deprecation timer for an -// auto-generated address that fires after pl. -// -// doNotDeprecate is used to inform the timer when it fires at the same time -// that an auto-generated address's preferred lifetime gets refreshed. See -// autoGenAddrState.doNotDeprecate for more details. -func (ndp *ndpState) autoGenAddrDeprecationTimer(addr tcpip.Address, pl time.Duration, doNotDeprecate *bool) *time.Timer { - return time.AfterFunc(pl, func() { - ndp.nic.mu.Lock() - defer ndp.nic.mu.Unlock() - - if *doNotDeprecate { - *doNotDeprecate = false - return - } - - addrState, ok := ndp.autoGenAddresses[addr] - if !ok { - log.Fatalf("ndp: must have an autoGenAddressess entry for the SLAAC generated IPv6 address %s", addr) - } - addrState.ref.deprecated = true - ndp.notifyAutoGenAddressDeprecated(addr) - addrState.deprecationTimer = nil - ndp.autoGenAddresses[addr] = addrState - }) -} - -// autoGenAddrInvalidationTimer returns a new invalidation timer for an -// auto-generated address that fires after vl. -// -// doNotInvalidate is used to inform the timer when it fires at the same time -// that an auto-generated address's valid lifetime gets refreshed. See -// autoGenAddrState.doNotInvalidate for more details. -func (ndp *ndpState) autoGenAddrInvalidationTimer(addr tcpip.Address, vl time.Duration, doNotInvalidate *bool) *time.Timer { - return time.AfterFunc(vl, func() { - ndp.nic.mu.Lock() - defer ndp.nic.mu.Unlock() - - if *doNotInvalidate { - *doNotInvalidate = false - return - } - - ndp.invalidateAutoGenAddress(addr) - }) -} - // cleanupHostOnlyState cleans up any state that is only useful for hosts. // // cleanupHostOnlyState MUST be called when ndp's NIC is transitioning from a diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index e51462a55..d334af289 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -1029,13 +1029,13 @@ func TestRouterDiscovery(t *testing.T) { expectRouterEvent(llAddr2, true) // Rx an RA from another router (lladdr3) with non-zero lifetime. - l3Lifetime := time.Duration(6) - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, uint16(l3Lifetime))) + const l3LifetimeSeconds = 6 + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, l3LifetimeSeconds)) expectRouterEvent(llAddr3, true) // Rx an RA from lladdr2 with lesser lifetime. - l2Lifetime := time.Duration(2) - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, uint16(l2Lifetime))) + const l2LifetimeSeconds = 2 + e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, l2LifetimeSeconds)) select { case <-ndpDisp.routerC: t.Fatal("Should not receive a router event when updating lifetimes for known routers") @@ -1049,7 +1049,7 @@ func TestRouterDiscovery(t *testing.T) { // Wait for the normal lifetime plus an extra bit for the // router to get invalidated. If we don't get an invalidation // event after this time, then something is wrong. - expectAsyncRouterInvalidationEvent(llAddr2, l2Lifetime*time.Second+defaultTimeout) + expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultTimeout) // Rx an RA from lladdr2 with huge lifetime. e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) @@ -1066,7 +1066,7 @@ func TestRouterDiscovery(t *testing.T) { // Wait for the normal lifetime plus an extra bit for the // router to get invalidated. If we don't get an invalidation // event after this time, then something is wrong. - expectAsyncRouterInvalidationEvent(llAddr3, l3Lifetime*time.Second+defaultTimeout) + expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultTimeout) } // TestRouterDiscoveryMaxRouters tests that only diff --git a/pkg/tcpip/timer.go b/pkg/tcpip/timer.go new file mode 100644 index 000000000..f5f01f32f --- /dev/null +++ b/pkg/tcpip/timer.go @@ -0,0 +1,161 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcpip + +import ( + "sync" + "time" +) + +// cancellableTimerInstance is a specific instance of CancellableTimer. +// +// Different instances are created each time CancellableTimer is Reset so each +// timer has its own earlyReturn signal. This is to address a bug when a +// CancellableTimer is stopped and reset in quick succession resulting in a +// timer instance's earlyReturn signal being affected or seen by another timer +// instance. +// +// Consider the following sceneario where timer instances share a common +// earlyReturn signal (T1 creates, stops and resets a Cancellable timer under a +// lock L; T2, T3, T4 and T5 are goroutines that handle the first (A), second +// (B), third (C), and fourth (D) instance of the timer firing, respectively): +// T1: Obtain L +// T1: Create a new CancellableTimer w/ lock L (create instance A) +// T2: instance A fires, blocked trying to obtain L. +// T1: Attempt to stop instance A (set earlyReturn = true) +// T1: Reset timer (create instance B) +// T3: instance B fires, blocked trying to obtain L. +// T1: Attempt to stop instance B (set earlyReturn = true) +// T1: Reset timer (create instance C) +// T4: instance C fires, blocked trying to obtain L. +// T1: Attempt to stop instance C (set earlyReturn = true) +// T1: Reset timer (create instance D) +// T5: instance D fires, blocked trying to obtain L. +// T1: Release L +// +// Now that T1 has released L, any of the 4 timer instances can take L and check +// earlyReturn. If the timers simply check earlyReturn and then do nothing +// further, then instance D will never early return even though it was not +// requested to stop. If the timers reset earlyReturn before early returning, +// then all but one of the timers will do work when only one was expected to. +// If CancellableTimer resets earlyReturn when resetting, then all the timers +// will fire (again, when only one was expected to). +// +// To address the above concerns the simplest solution was to give each timer +// its own earlyReturn signal. +type cancellableTimerInstance struct { + timer *time.Timer + + // Used to inform the timer to early return when it gets stopped while the + // lock the timer tries to obtain when fired is held (T1 is a goroutine that + // tries to cancel the timer and T2 is the goroutine that handles the timer + // firing): + // T1: Obtain the lock, then call StopLocked() + // T2: timer fires, and gets blocked on obtaining the lock + // T1: Releases lock + // T2: Obtains lock does unintended work + // + // To resolve this, T1 will check to see if the timer already fired, and + // inform the timer using earlyReturn to return early so that once T2 obtains + // the lock, it will see that it is set to true and do nothing further. + earlyReturn *bool +} + +// stop stops the timer instance t from firing if it hasn't fired already. If it +// has fired and is blocked at obtaining the lock, earlyReturn will be set to +// true so that it will early return when it obtains the lock. +func (t *cancellableTimerInstance) stop() { + if t.timer != nil { + t.timer.Stop() + *t.earlyReturn = true + } +} + +// CancellableTimer is a timer that does some work and can be safely cancelled +// when it fires at the same time some "related work" is being done. +// +// The term "related work" is defined as some work that needs to be done while +// holding some lock that the timer must also hold while doing some work. +type CancellableTimer struct { + // The active instance of a cancellable timer. + instance cancellableTimerInstance + + // locker is the lock taken by the timer immediately after it fires and must + // be held when attempting to stop the timer. + // + // Must never change after being assigned. + locker sync.Locker + + // fn is the function that will be called when a timer fires and has not been + // signaled to early return. + // + // fn MUST NOT attempt to lock locker. + // + // Must never change after being assigned. + fn func() +} + +// StopLocked prevents the Timer from firing if it has not fired already. +// +// If the timer is blocked on obtaining the t.locker lock when StopLocked is +// called, it will early return instead of calling t.fn. +// +// Note, t will be modified. +// +// t.locker MUST be locked. +func (t *CancellableTimer) StopLocked() { + t.instance.stop() + + // Nothing to do with the stopped instance anymore. + t.instance = cancellableTimerInstance{} +} + +// Reset changes the timer to expire after duration d. +// +// Note, t will be modified. +// +// Reset should only be called on stopped or expired timers. To be safe, callers +// should always call StopLocked before calling Reset. +func (t *CancellableTimer) Reset(d time.Duration) { + // Create a new instance. + earlyReturn := false + t.instance = cancellableTimerInstance{ + timer: time.AfterFunc(d, func() { + t.locker.Lock() + defer t.locker.Unlock() + + if earlyReturn { + // If we reach this point, it means that the timer fired while another + // goroutine called StopLocked while it had the lock. Simply return + // here and do nothing further. + earlyReturn = false + return + } + + t.fn() + }), + earlyReturn: &earlyReturn, + } +} + +// MakeCancellableTimer returns an unscheduled CancellableTimer with the given +// locker and fn. +// +// fn MUST NOT attempt to lock locker. +// +// Callers must call Reset to schedule the timer to fire. +func MakeCancellableTimer(locker sync.Locker, fn func()) CancellableTimer { + return CancellableTimer{locker: locker, fn: fn} +} diff --git a/pkg/tcpip/timer_test.go b/pkg/tcpip/timer_test.go new file mode 100644 index 000000000..1f735d735 --- /dev/null +++ b/pkg/tcpip/timer_test.go @@ -0,0 +1,236 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package timer_test + +import ( + "sync" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +const ( + shortDuration = 1 * time.Nanosecond + middleDuration = 100 * time.Millisecond + longDuration = 1 * time.Second +) + +func TestCancellableTimerFire(t *testing.T) { + t.Parallel() + + ch := make(chan struct{}) + var lock sync.Mutex + + timer := tcpip.MakeCancellableTimer(&lock, func() { + ch <- struct{}{} + }) + timer.Reset(shortDuration) + + // Wait for timer to fire. + select { + case <-ch: + case <-time.After(middleDuration): + t.Fatal("timed out waiting for timer to fire") + } + + // The timer should have fired only once. + select { + case <-ch: + t.Fatal("no other timers should have fired") + case <-time.After(middleDuration): + } +} + +func TestCancellableTimerResetFromLongDuration(t *testing.T) { + t.Parallel() + + ch := make(chan struct{}) + var lock sync.Mutex + + timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} }) + timer.Reset(middleDuration) + + lock.Lock() + timer.StopLocked() + lock.Unlock() + + timer.Reset(shortDuration) + + // Wait for timer to fire. + select { + case <-ch: + case <-time.After(middleDuration): + t.Fatal("timed out waiting for timer to fire") + } + + // The timer should have fired only once. + select { + case <-ch: + t.Fatal("no other timers should have fired") + case <-time.After(middleDuration): + } +} + +func TestCancellableTimerResetFromShortDuration(t *testing.T) { + t.Parallel() + + ch := make(chan struct{}) + var lock sync.Mutex + + lock.Lock() + timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} }) + timer.Reset(shortDuration) + timer.StopLocked() + lock.Unlock() + + // Wait for timer to fire if it wasn't correctly stopped. + select { + case <-ch: + t.Fatal("timer fired after being stopped") + case <-time.After(middleDuration): + } + + timer.Reset(shortDuration) + + // Wait for timer to fire. + select { + case <-ch: + case <-time.After(middleDuration): + t.Fatal("timed out waiting for timer to fire") + } + + // The timer should have fired only once. + select { + case <-ch: + t.Fatal("no other timers should have fired") + case <-time.After(middleDuration): + } +} + +func TestCancellableTimerImmediatelyStop(t *testing.T) { + t.Parallel() + + ch := make(chan struct{}) + var lock sync.Mutex + + for i := 0; i < 1000; i++ { + lock.Lock() + timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} }) + timer.Reset(shortDuration) + timer.StopLocked() + lock.Unlock() + } + + // Wait for timer to fire if it wasn't correctly stopped. + select { + case <-ch: + t.Fatal("timer fired after being stopped") + case <-time.After(middleDuration): + } +} + +func TestCancellableTimerStoppedResetWithoutLock(t *testing.T) { + t.Parallel() + + ch := make(chan struct{}) + var lock sync.Mutex + + lock.Lock() + timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} }) + timer.Reset(shortDuration) + timer.StopLocked() + lock.Unlock() + + for i := 0; i < 10; i++ { + timer.Reset(middleDuration) + + lock.Lock() + // Sleep until the timer fires and gets blocked trying to take the lock. + time.Sleep(middleDuration * 2) + timer.StopLocked() + lock.Unlock() + } + + // Wait for double the duration so timers that weren't correctly stopped can + // fire. + select { + case <-ch: + t.Fatal("timer fired after being stopped") + case <-time.After(middleDuration * 2): + } +} + +func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) { + t.Parallel() + + ch := make(chan struct{}) + var lock sync.Mutex + + lock.Lock() + timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} }) + timer.Reset(shortDuration) + for i := 0; i < 10; i++ { + // Sleep until the timer fires and gets blocked trying to take the lock. + time.Sleep(middleDuration) + timer.StopLocked() + timer.Reset(shortDuration) + } + lock.Unlock() + + // Wait for double the duration for the last timer to fire. + select { + case <-ch: + case <-time.After(middleDuration): + t.Fatal("timed out waiting for timer to fire") + } + + // The timer should have fired only once. + select { + case <-ch: + t.Fatal("no other timers should have fired") + case <-time.After(middleDuration): + } +} + +func TestManyCancellableTimerResetUnderLock(t *testing.T) { + t.Parallel() + + ch := make(chan struct{}) + var lock sync.Mutex + + lock.Lock() + timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} }) + timer.Reset(shortDuration) + for i := 0; i < 10; i++ { + timer.StopLocked() + timer.Reset(shortDuration) + } + lock.Unlock() + + // Wait for double the duration for the last timer to fire. + select { + case <-ch: + case <-time.After(middleDuration): + t.Fatal("timed out waiting for timer to fire") + } + + // The timer should have fired only once. + select { + case <-ch: + t.Fatal("no other timers should have fired") + case <-time.After(middleDuration): + } +} -- cgit v1.2.3 From e752ddbb72d89b19863a6b50d99814149a08d5fe Mon Sep 17 00:00:00 2001 From: Bert Muthalaly Date: Thu, 9 Jan 2020 10:34:30 -0800 Subject: Allow clients to store an opaque NICContext with NICs ...retrievable later via stack.NICInfo(). Clients of this library can use it to add metadata that should be tracked alongside a NIC, to avoid having to keep a map[tcpip.NICID]metadata mirroring stack.Stack's nic map. PiperOrigin-RevId: 288924900 --- pkg/tcpip/stack/nic.go | 12 +++++++----- pkg/tcpip/stack/stack.go | 16 +++++++++++++++- pkg/tcpip/stack/stack_test.go | 40 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 62 insertions(+), 6 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 4144d5d0f..3810c6602 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -27,10 +27,11 @@ import ( // NIC represents a "network interface card" to which the networking stack is // attached. type NIC struct { - stack *Stack - id tcpip.NICID - name string - linkEP LinkEndpoint + stack *Stack + id tcpip.NICID + name string + linkEP LinkEndpoint + context NICContext mu sync.RWMutex spoofing bool @@ -84,7 +85,7 @@ const ( ) // newNIC returns a new NIC using the default NDP configurations from stack. -func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint) *NIC { +func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICContext) *NIC { // TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For // example, make sure that the link address it provides is a valid // unicast ethernet address. @@ -98,6 +99,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint) *NIC { id: id, name: name, linkEP: ep, + context: ctx, primary: make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint), endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint), mcastJoins: make(map[NetworkEndpointID]int32), diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index fb7ac409e..e2a2edb2c 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -796,6 +796,9 @@ func (s *Stack) NewPacketEndpoint(cooked bool, netProto tcpip.NetworkProtocolNum return s.rawFactory.NewPacketEndpoint(s, cooked, netProto, waiterQueue) } +// NICContext is an opaque pointer used to store client-supplied NIC metadata. +type NICContext interface{} + // NICOptions specifies the configuration of a NIC as it is being created. // The zero value creates an enabled, unnamed NIC. type NICOptions struct { @@ -805,6 +808,12 @@ type NICOptions struct { // Disabled specifies whether to avoid calling Attach on the passed // LinkEndpoint. Disabled bool + + // Context specifies user-defined data that will be returned in stack.NICInfo + // for the NIC. Clients of this library can use it to add metadata that + // should be tracked alongside a NIC, to avoid having to keep a + // map[tcpip.NICID]metadata mirroring stack.Stack's nic map. + Context NICContext } // CreateNICWithOptions creates a NIC with the provided id, LinkEndpoint, and @@ -819,7 +828,7 @@ func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOp return tcpip.ErrDuplicateNICID } - n := newNIC(s, id, opts.Name, ep) + n := newNIC(s, id, opts.Name, ep, opts.Context) s.nics[id] = n if !opts.Disabled { @@ -886,6 +895,10 @@ type NICInfo struct { MTU uint32 Stats NICStats + + // Context is user-supplied data optionally supplied in CreateNICWithOptions. + // See type NICOptions for more details. + Context NICContext } // NICInfo returns a map of NICIDs to their associated information. @@ -908,6 +921,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { Flags: flags, MTU: nic.linkEP.MTU(), Stats: nic.stats, + Context: nic.context, } } return nics diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 9ac50bb23..44e5229cc 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -2001,6 +2001,46 @@ func TestNICAutoGenAddr(t *testing.T) { } } +// TestNICContextPreservation tests that you can read out via stack.NICInfo the +// Context data you pass via NICContext.Context in stack.CreateNICWithOptions. +func TestNICContextPreservation(t *testing.T) { + var ctx *int + tests := []struct { + name string + opts stack.NICOptions + want stack.NICContext + }{ + { + "context_set", + stack.NICOptions{Context: ctx}, + ctx, + }, + { + "context_not_set", + stack.NICOptions{}, + nil, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{}) + id := tcpip.NICID(1) + ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")) + if err := s.CreateNICWithOptions(id, ep, test.opts); err != nil { + t.Fatalf("got stack.CreateNICWithOptions(%d, %+v, %+v) = %s, want nil", id, ep, test.opts, err) + } + nicinfos := s.NICInfo() + nicinfo, ok := nicinfos[id] + if !ok { + t.Fatalf("got nicinfos[%d] = _, %t, want _, true; nicinfos = %+v", id, ok, nicinfos) + } + if got, want := nicinfo.Context == test.want, true; got != want { + t.Fatal("got nicinfo.Context == ctx = %t, want %t; nicinfo.Context = %p, ctx = %p", got, want, nicinfo.Context, test.want) + } + }) + } +} + // TestNICAutoGenAddrWithOpaque tests the auto-generation of IPv6 link-local // addresses with opaque interface identifiers. Link Local addresses should // always be generated with opaque IIDs if configured to use them, even if the -- cgit v1.2.3 From 8643933d6e58492cbe9d5c78124873ab40f65feb Mon Sep 17 00:00:00 2001 From: Eyal Soha Date: Thu, 9 Jan 2020 13:06:24 -0800 Subject: Change BindToDeviceOption to store NICID This makes it possible to call the sockopt from go even when the NIC has no name. PiperOrigin-RevId: 288955236 --- pkg/sentry/socket/netstack/netstack.go | 29 ++++++++-- pkg/tcpip/stack/stack.go | 8 +++ pkg/tcpip/stack/transport_demuxer_test.go | 89 +++++++++++++++---------------- pkg/tcpip/tcpip.go | 2 +- pkg/tcpip/transport/tcp/endpoint.go | 27 ++++------ pkg/tcpip/transport/tcp/tcp_test.go | 42 +++++++-------- pkg/tcpip/transport/udp/endpoint.go | 27 ++++------ pkg/tcpip/transport/udp/udp_test.go | 31 +++++------ 8 files changed, 127 insertions(+), 128 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 9e0d69046..764f11a6b 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -985,13 +985,23 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family if err := ep.GetSockOpt(&v); err != nil { return nil, syserr.TranslateNetstackError(err) } - if len(v) == 0 { + if v == 0 { return []byte{}, nil } if outLen < linux.IFNAMSIZ { return nil, syserr.ErrInvalidArgument } - return append([]byte(v), 0), nil + s := t.NetworkContext() + if s == nil { + return nil, syserr.ErrNoDevice + } + nic, ok := s.Interfaces()[int32(v)] + if !ok { + // The NICID no longer indicates a valid interface, probably because that + // interface was removed. + return nil, syserr.ErrUnknownDevice + } + return append([]byte(nic.Name), 0), nil case linux.SO_BROADCAST: if outLen < sizeOfInt32 { @@ -1438,7 +1448,20 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i if n == -1 { n = len(optVal) } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(optVal[:n]))) + name := string(optVal[:n]) + if name == "" { + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(0))) + } + s := t.NetworkContext() + if s == nil { + return syserr.ErrNoDevice + } + for nicID, nic := range s.Interfaces() { + if nic.Name == name { + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(nicID))) + } + } + return syserr.ErrUnknownDevice case linux.SO_BROADCAST: if len(optVal) < sizeOfInt32 { diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index e2a2edb2c..41bf9fd9b 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -901,6 +901,14 @@ type NICInfo struct { Context NICContext } +// HasNIC returns true if the NICID is defined in the stack. +func (s *Stack) HasNIC(id tcpip.NICID) bool { + s.mu.RLock() + _, ok := s.nics[id] + s.mu.RUnlock() + return ok +} + // NICInfo returns a map of NICIDs to their associated information. func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { s.mu.RLock() diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index df5ced887..5e9237de9 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -41,7 +41,7 @@ const ( type testContext struct { t *testing.T - linkEPs map[string]*channel.Endpoint + linkEps map[tcpip.NICID]*channel.Endpoint s *stack.Stack ep tcpip.Endpoint @@ -66,27 +66,24 @@ func (c *testContext) createV6Endpoint(v6only bool) { } } -// newDualTestContextMultiNic creates the testing context and also linkEpNames -// named NICs. -func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string) *testContext { +// newDualTestContextMultiNIC creates the testing context and also linkEpIDs NICs. +func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICID) *testContext { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) - linkEPs := make(map[string]*channel.Endpoint) - for i, linkEpName := range linkEpNames { - channelEP := channel.New(256, mtu, "") - nicID := tcpip.NICID(i + 1) - opts := stack.NICOptions{Name: linkEpName} - if err := s.CreateNICWithOptions(nicID, channelEP, opts); err != nil { - t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err) + linkEps := make(map[tcpip.NICID]*channel.Endpoint) + for _, linkEpID := range linkEpIDs { + channelEp := channel.New(256, mtu, "") + if err := s.CreateNIC(linkEpID, channelEp); err != nil { + t.Fatalf("CreateNIC failed: %v", err) } - linkEPs[linkEpName] = channelEP + linkEps[linkEpID] = channelEp - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { + if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, stackAddr); err != nil { t.Fatalf("AddAddress IPv4 failed: %v", err) } - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, stackV6Addr); err != nil { + if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, stackV6Addr); err != nil { t.Fatalf("AddAddress IPv6 failed: %v", err) } } @@ -105,7 +102,7 @@ func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string) return &testContext{ t: t, s: s, - linkEPs: linkEPs, + linkEps: linkEps, } } @@ -122,7 +119,7 @@ func newPayload() []byte { return b } -func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string) { +func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NICID) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) copy(buf[len(buf)-len(payload):], payload) @@ -153,7 +150,7 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEPs[linkEpName].InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ + c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ Data: buf.ToVectorisedView(), }) } @@ -183,7 +180,7 @@ func TestTransportDemuxerRegister(t *testing.T) { func TestDistribution(t *testing.T) { type endpointSockopts struct { reuse int - bindToDevice string + bindToDevice tcpip.NICID } for _, test := range []struct { name string @@ -191,71 +188,71 @@ func TestDistribution(t *testing.T) { endpoints []endpointSockopts // wantedDistribution is the wanted ratio of packets received on each // endpoint for each NIC on which packets are injected. - wantedDistributions map[string][]float64 + wantedDistributions map[tcpip.NICID][]float64 }{ { "BindPortReuse", // 5 endpoints that all have reuse set. []endpointSockopts{ - {1, ""}, - {1, ""}, - {1, ""}, - {1, ""}, - {1, ""}, + {1, 0}, + {1, 0}, + {1, 0}, + {1, 0}, + {1, 0}, }, - map[string][]float64{ + map[tcpip.NICID][]float64{ // Injected packets on dev0 get distributed evenly. - "dev0": {0.2, 0.2, 0.2, 0.2, 0.2}, + 1: {0.2, 0.2, 0.2, 0.2, 0.2}, }, }, { "BindToDevice", // 3 endpoints with various bindings. []endpointSockopts{ - {0, "dev0"}, - {0, "dev1"}, - {0, "dev2"}, + {0, 1}, + {0, 2}, + {0, 3}, }, - map[string][]float64{ + map[tcpip.NICID][]float64{ // Injected packets on dev0 go only to the endpoint bound to dev0. - "dev0": {1, 0, 0}, + 1: {1, 0, 0}, // Injected packets on dev1 go only to the endpoint bound to dev1. - "dev1": {0, 1, 0}, + 2: {0, 1, 0}, // Injected packets on dev2 go only to the endpoint bound to dev2. - "dev2": {0, 0, 1}, + 3: {0, 0, 1}, }, }, { "ReuseAndBindToDevice", // 6 endpoints with various bindings. []endpointSockopts{ - {1, "dev0"}, - {1, "dev0"}, - {1, "dev1"}, - {1, "dev1"}, - {1, "dev1"}, - {1, ""}, + {1, 1}, + {1, 1}, + {1, 2}, + {1, 2}, + {1, 2}, + {1, 0}, }, - map[string][]float64{ + map[tcpip.NICID][]float64{ // Injected packets on dev0 get distributed among endpoints bound to // dev0. - "dev0": {0.5, 0.5, 0, 0, 0, 0}, + 1: {0.5, 0.5, 0, 0, 0, 0}, // Injected packets on dev1 get distributed among endpoints bound to // dev1 or unbound. - "dev1": {0, 0, 1. / 3, 1. / 3, 1. / 3, 0}, + 2: {0, 0, 1. / 3, 1. / 3, 1. / 3, 0}, // Injected packets on dev999 go only to the unbound. - "dev999": {0, 0, 0, 0, 0, 1}, + 1000: {0, 0, 0, 0, 0, 1}, }, }, } { t.Run(test.name, func(t *testing.T) { for device, wantedDistribution := range test.wantedDistributions { - t.Run(device, func(t *testing.T) { - var devices []string + t.Run(string(device), func(t *testing.T) { + var devices []tcpip.NICID for d := range test.wantedDistributions { devices = append(devices, d) } - c := newDualTestContextMultiNic(t, defaultMTU, devices) + c := newDualTestContextMultiNIC(t, defaultMTU, devices) defer c.cleanup() c.createV6Endpoint(false) diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 1eca76c30..72b5ce179 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -552,7 +552,7 @@ type ReusePortOption int // BindToDeviceOption is used by SetSockOpt/GetSockOpt to specify that sockets // should bind only on a specific NIC. -type BindToDeviceOption string +type BindToDeviceOption NICID // QuickAckOption is stubbed out in SetSockOpt/GetSockOpt. type QuickAckOption int diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 2ac1b6877..920b24975 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1279,19 +1279,14 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return nil case tcpip.BindToDeviceOption: - e.mu.Lock() - defer e.mu.Unlock() - if v == "" { - e.bindToDevice = 0 - return nil - } - for nicID, nic := range e.stack.NICInfo() { - if nic.Name == string(v) { - e.bindToDevice = nicID - return nil - } + id := tcpip.NICID(v) + if id != 0 && !e.stack.HasNIC(id) { + return tcpip.ErrUnknownDevice } - return tcpip.ErrUnknownDevice + e.mu.Lock() + e.bindToDevice = id + e.mu.Unlock() + return nil case tcpip.QuickAckOption: if v == 0 { @@ -1550,12 +1545,8 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { case *tcpip.BindToDeviceOption: e.mu.RLock() - defer e.mu.RUnlock() - if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok { - *o = tcpip.BindToDeviceOption(nic.Name) - return nil - } - *o = "" + *o = tcpip.BindToDeviceOption(e.bindToDevice) + e.mu.RUnlock() return nil case *tcpip.QuickAckOption: diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 15745ebd4..1aa0733d0 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -1083,12 +1083,12 @@ func TestTrafficClassV6(t *testing.T) { func TestConnectBindToDevice(t *testing.T) { for _, test := range []struct { name string - device string + device tcpip.NICID want tcp.EndpointState }{ - {"RightDevice", "nic1", tcp.StateEstablished}, - {"WrongDevice", "nic2", tcp.StateSynSent}, - {"AnyDevice", "", tcp.StateEstablished}, + {"RightDevice", 1, tcp.StateEstablished}, + {"WrongDevice", 2, tcp.StateSynSent}, + {"AnyDevice", 0, tcp.StateEstablished}, } { t.Run(test.name, func(t *testing.T) { c := context.New(t, defaultMTU) @@ -3794,47 +3794,41 @@ func TestBindToDeviceOption(t *testing.T) { } defer ep.Close() - opts := stack.NICOptions{Name: "my_device"} - if err := s.CreateNICWithOptions(321, loopback.New(), opts); err != nil { - t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err) - } - - // Make an nameless NIC. - if err := s.CreateNIC(54321, loopback.New()); err != nil { + if err := s.CreateNIC(321, loopback.New()); err != nil { t.Errorf("CreateNIC failed: %v", err) } - // strPtr is used instead of taking the address of string literals, which is + // nicIDPtr is used instead of taking the address of NICID literals, which is // a compiler error. - strPtr := func(s string) *string { + nicIDPtr := func(s tcpip.NICID) *tcpip.NICID { return &s } testActions := []struct { name string - setBindToDevice *string + setBindToDevice *tcpip.NICID setBindToDeviceError *tcpip.Error getBindToDevice tcpip.BindToDeviceOption }{ - {"GetDefaultValue", nil, nil, ""}, - {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""}, - {"BindToExistent", strPtr("my_device"), nil, "my_device"}, - {"UnbindToDevice", strPtr(""), nil, ""}, + {"GetDefaultValue", nil, nil, 0}, + {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0}, + {"BindToExistent", nicIDPtr(321), nil, 321}, + {"UnbindToDevice", nicIDPtr(0), nil, 0}, } for _, testAction := range testActions { t.Run(testAction.name, func(t *testing.T) { if testAction.setBindToDevice != nil { bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) - if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want { - t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want) + if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { + t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, gotErr, wantErr) } } - bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt") - if ep.GetSockOpt(&bindToDevice) != nil { - t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil) + bindToDevice := tcpip.BindToDeviceOption(88888) + if err := ep.GetSockOpt(&bindToDevice); err != nil { + t.Errorf("GetSockOpt got %v, want %v", err, nil) } if got, want := bindToDevice, testAction.getBindToDevice; got != want { - t.Errorf("bindToDevice got %q, want %q", got, want) + t.Errorf("bindToDevice got %d, want %d", got, want) } }) } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 1a5ee6317..864dc8733 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -631,19 +631,14 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.mu.Unlock() case tcpip.BindToDeviceOption: - e.mu.Lock() - defer e.mu.Unlock() - if v == "" { - e.bindToDevice = 0 - return nil - } - for nicID, nic := range e.stack.NICInfo() { - if nic.Name == string(v) { - e.bindToDevice = nicID - return nil - } + id := tcpip.NICID(v) + if id != 0 && !e.stack.HasNIC(id) { + return tcpip.ErrUnknownDevice } - return tcpip.ErrUnknownDevice + e.mu.Lock() + e.bindToDevice = id + e.mu.Unlock() + return nil case tcpip.BroadcastOption: e.mu.Lock() @@ -767,12 +762,8 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { case *tcpip.BindToDeviceOption: e.mu.RLock() - defer e.mu.RUnlock() - if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok { - *o = tcpip.BindToDeviceOption(nic.Name) - return nil - } - *o = tcpip.BindToDeviceOption("") + *o = tcpip.BindToDeviceOption(e.bindToDevice) + e.mu.RUnlock() return nil case *tcpip.KeepaliveEnabledOption: diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 149fff999..0a82bc4fa 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -513,42 +513,37 @@ func TestBindToDeviceOption(t *testing.T) { t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err) } - // Make an nameless NIC. - if err := s.CreateNIC(54321, loopback.New()); err != nil { - t.Errorf("CreateNIC failed: %v", err) - } - - // strPtr is used instead of taking the address of string literals, which is + // nicIDPtr is used instead of taking the address of NICID literals, which is // a compiler error. - strPtr := func(s string) *string { + nicIDPtr := func(s tcpip.NICID) *tcpip.NICID { return &s } testActions := []struct { name string - setBindToDevice *string + setBindToDevice *tcpip.NICID setBindToDeviceError *tcpip.Error getBindToDevice tcpip.BindToDeviceOption }{ - {"GetDefaultValue", nil, nil, ""}, - {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""}, - {"BindToExistent", strPtr("my_device"), nil, "my_device"}, - {"UnbindToDevice", strPtr(""), nil, ""}, + {"GetDefaultValue", nil, nil, 0}, + {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0}, + {"BindToExistent", nicIDPtr(321), nil, 321}, + {"UnbindToDevice", nicIDPtr(0), nil, 0}, } for _, testAction := range testActions { t.Run(testAction.name, func(t *testing.T) { if testAction.setBindToDevice != nil { bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) - if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want { - t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want) + if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { + t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, gotErr, wantErr) } } - bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt") - if ep.GetSockOpt(&bindToDevice) != nil { - t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil) + bindToDevice := tcpip.BindToDeviceOption(88888) + if err := ep.GetSockOpt(&bindToDevice); err != nil { + t.Errorf("GetSockOpt got %v, want %v", err, nil) } if got, want := bindToDevice, testAction.getBindToDevice; got != want { - t.Errorf("bindToDevice got %q, want %q", got, want) + t.Errorf("bindToDevice got %d, want %d", got, want) } }) } -- cgit v1.2.3 From 8fafd3142e85175fe56bc3333d859f1a8cfbb878 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Thu, 9 Jan 2020 15:55:24 -0800 Subject: Separate NDP tests into its own package Internal tools timeout after 60s during tests that are required to pass before changes can be submitted. Separate out NDP tests into its own package to help prevent timeouts when testing. PiperOrigin-RevId: 288990597 --- pkg/tcpip/stack/BUILD | 23 ++++++++++-- pkg/tcpip/stack/ndp_test.go | 87 +++++++++++++++++++++++++++++++++++++++---- pkg/tcpip/stack/stack_test.go | 74 +++--------------------------------- 3 files changed, 106 insertions(+), 78 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index b8f9517d0..826fca4de 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -52,7 +52,6 @@ go_test( name = "stack_x_test", size = "small", srcs = [ - "ndp_test.go", "stack_test.go", "transport_demuxer_test.go", "transport_test.go", @@ -62,14 +61,12 @@ go_test( "//pkg/rand", "//pkg/tcpip", "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", "//pkg/tcpip/header", "//pkg/tcpip/iptables", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/udp", "//pkg/waiter", "@com_github_google_go-cmp//cmp:go_default_library", @@ -86,3 +83,23 @@ go_test( "//pkg/tcpip", ], ) + +go_test( + name = "ndp_test", + size = "small", + srcs = ["ndp_test.go"], + deps = [ + ":stack", + "//pkg/rand", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/checker", + "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", + "//pkg/tcpip/network/ipv6", + "//pkg/tcpip/transport/icmp", + "//pkg/tcpip/transport/udp", + "//pkg/waiter", + "@com_github_google_go-cmp//cmp:go_default_library", + ], +) diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index d334af289..fa84c94a6 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package stack_test +package ndp_test import ( "encoding/binary" @@ -285,6 +285,8 @@ func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tc // Included in the subtests is a test to make sure that an invalid // RetransmitTimer (<1ms) values get fixed to the default RetransmitTimer of 1s. func TestDADResolve(t *testing.T) { + t.Parallel() + tests := []struct { name string dupAddrDetectTransmits uint8 @@ -417,6 +419,8 @@ func TestDADResolve(t *testing.T) { // a node doing DAD for the same address), or if another node is detected to own // the address already (receive an NA message for the tentative address). func TestDADFail(t *testing.T) { + t.Parallel() + tests := []struct { name string makeBuf func(tgt tcpip.Address) buffer.Prependable @@ -560,6 +564,8 @@ func TestDADFail(t *testing.T) { // TestDADStop tests to make sure that the DAD process stops when an address is // removed. func TestDADStop(t *testing.T) { + t.Parallel() + ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent), } @@ -632,6 +638,71 @@ func TestDADStop(t *testing.T) { } } +// TestNICAutoGenAddrDoesDAD tests that the successful auto-generation of IPv6 +// link-local addresses will only be assigned after the DAD process resolves. +func TestNICAutoGenAddrDoesDAD(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent), + } + ndpConfigs := stack.DefaultNDPConfigurations() + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: ndpConfigs, + AutoGenIPv6LinkLocal: true, + NDPDisp: &ndpDisp, + } + + e := channel.New(0, 1280, linkAddr1) + s := stack.New(opts) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + // Address should not be considered bound to the + // NIC yet (DAD ongoing). + addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + } + + linkLocalAddr := header.LinkLocalAddr(linkAddr1) + + // Wait for DAD to resolve. + select { + case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second): + // We should get a resolution event after 1s (default time to + // resolve as per default NDP configurations). Waiting for that + // resolution time + an extra 1s without a resolution event + // means something is wrong. + t.Fatal("timed out waiting for DAD resolution") + case e := <-ndpDisp.dadC: + if e.err != nil { + t.Fatal("got DAD error: ", e.err) + } + if e.nicID != 1 { + t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID) + } + if e.addr != linkLocalAddr { + t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, linkLocalAddr) + } + if !e.resolved { + t.Fatal("got DAD event w/ resolved = false, want = true") + } + } + addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) + } + if want := (tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want) + } +} + // TestSetNDPConfigurationFailsForBadNICID tests to make sure we get an error if // we attempt to update NDP configurations using an invalid NICID. func TestSetNDPConfigurationFailsForBadNICID(t *testing.T) { @@ -649,6 +720,8 @@ func TestSetNDPConfigurationFailsForBadNICID(t *testing.T) { // configurations without affecting the default NDP configurations or other // interfaces' configurations. func TestSetNDPConfigurations(t *testing.T) { + t.Parallel() + tests := []struct { name string dupAddrDetectTransmits uint8 @@ -875,6 +948,8 @@ func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, on // TestNoRouterDiscovery tests that router discovery will not be performed if // configured not to. func TestNoRouterDiscovery(t *testing.T) { + t.Parallel() + // Being configured to discover routers means handle and // discover are set to true and forwarding is set to false. // This tests all possible combinations of the configurations, @@ -887,8 +962,6 @@ func TestNoRouterDiscovery(t *testing.T) { forwarding := i&4 == 0 t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverDefaultRouters(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) { - t.Parallel() - ndpDisp := ndpDispatcher{ routerC: make(chan ndpRouterEvent, 1), } @@ -1123,6 +1196,8 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { // TestNoPrefixDiscovery tests that prefix discovery will not be performed if // configured not to. func TestNoPrefixDiscovery(t *testing.T) { + t.Parallel() + prefix := tcpip.AddressWithPrefix{ Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"), PrefixLen: 64, @@ -1140,8 +1215,6 @@ func TestNoPrefixDiscovery(t *testing.T) { forwarding := i&4 == 0 t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverOnLinkPrefixes(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) { - t.Parallel() - ndpDisp := ndpDispatcher{ prefixC: make(chan ndpPrefixEvent, 1), } @@ -1498,6 +1571,8 @@ func contains(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) bool { // TestNoAutoGenAddr tests that SLAAC is not performed when configured not to. func TestNoAutoGenAddr(t *testing.T) { + t.Parallel() + prefix, _, _ := prefixSubnetAddr(0, "") // Being configured to auto-generate addresses means handle and @@ -1512,8 +1587,6 @@ func TestNoAutoGenAddr(t *testing.T) { forwarding := i&4 == 0 t.Run(fmt.Sprintf("HandleRAs(%t), AutoGenAddr(%t), Forwarding(%t)", handle, autogen, forwarding), func(t *testing.T) { - t.Parallel() - ndpDisp := ndpDispatcher{ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 44e5229cc..e8de4e87d 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -24,7 +24,6 @@ import ( "sort" "strings" "testing" - "time" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/rand" @@ -50,6 +49,8 @@ const ( // where another value is explicitly used. It is chosen to match the MTU // of loopback interfaces on linux systems. defaultMTU = 65536 + + linkAddr = "\x02\x02\x03\x04\x05\x06" ) // fakeNetworkEndpoint is a network-layer protocol endpoint. It counts sent and @@ -1909,7 +1910,7 @@ func TestNICAutoGenAddr(t *testing.T) { { "Disabled", false, - linkAddr1, + linkAddr, stack.OpaqueInterfaceIdentifierOptions{ NICNameFromID: func(nicID tcpip.NICID, _ string) string { return fmt.Sprintf("nic%d", nicID) @@ -1920,7 +1921,7 @@ func TestNICAutoGenAddr(t *testing.T) { { "Enabled", true, - linkAddr1, + linkAddr, stack.OpaqueInterfaceIdentifierOptions{}, true, }, @@ -2068,14 +2069,14 @@ func TestNICAutoGenAddrWithOpaque(t *testing.T) { name: "Disabled", nicName: "nic1", autoGen: false, - linkAddr: linkAddr1, + linkAddr: linkAddr, secretKey: secretKey[:], }, { name: "Enabled", nicName: "nic1", autoGen: true, - linkAddr: linkAddr1, + linkAddr: linkAddr, secretKey: secretKey[:], }, // These are all cases where we would not have generated a @@ -2213,69 +2214,6 @@ func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) { } } -// TestNICAutoGenAddrDoesDAD tests that the successful auto-generation of IPv6 -// link-local addresses will only be assigned after the DAD process resolves. -func TestNICAutoGenAddrDoesDAD(t *testing.T) { - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent), - } - ndpConfigs := stack.DefaultNDPConfigurations() - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: ndpConfigs, - AutoGenIPv6LinkLocal: true, - NDPDisp: &ndpDisp, - } - - e := channel.New(10, 1280, linkAddr1) - s := stack.New(opts) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - - // Address should not be considered bound to the - // NIC yet (DAD ongoing). - addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) - } - - linkLocalAddr := header.LinkLocalAddr(linkAddr1) - - // Wait for DAD to resolve. - select { - case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second): - // We should get a resolution event after 1s (default time to - // resolve as per default NDP configurations). Waiting for that - // resolution time + an extra 1s without a resolution event - // means something is wrong. - t.Fatal("timed out waiting for DAD resolution") - case e := <-ndpDisp.dadC: - if e.err != nil { - t.Fatal("got DAD error: ", e.err) - } - if e.nicID != 1 { - t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID) - } - if e.addr != linkLocalAddr { - t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, linkLocalAddr) - } - if !e.resolved { - t.Fatal("got DAD event w/ resolved = false, want = true") - } - } - addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) - } - if want := (tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want) - } -} - // TestNewPEB tests that a new PrimaryEndpointBehavior value (peb) is respected // when an address's kind gets "promoted" to permanent from permanentExpired. func TestNewPEBOnPromotionToPermanent(t *testing.T) { -- cgit v1.2.3 From 26c5653bb547450e85666f345d542b516b3417fc Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Thu, 9 Jan 2020 16:55:01 -0800 Subject: Inform NDPDispatcher when Stack learns about available configurations via DHCPv6 Inform the Stack's NDPDispatcher when it receives an NDP Router Advertisement that updates the available configurations via DHCPv6. The Stack makes sure that its NDPDispatcher isn't informed unless the avaiable configurations via DHCPv6 for a NIC is updated. Tests: Test that a Stack's NDPDispatcher is informed when it receives an NDP Router Advertisement that informs it of new configurations available via DHCPv6. PiperOrigin-RevId: 289001283 --- pkg/tcpip/stack/ndp.go | 62 +++++++++++++++++ pkg/tcpip/stack/ndp_test.go | 165 ++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 212 insertions(+), 15 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 35825ebf7..a9dd322db 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -115,6 +115,30 @@ var ( MinPrefixInformationValidLifetimeForUpdate = 2 * time.Hour ) +// DHCPv6ConfigurationFromNDPRA is a configuration available via DHCPv6 that an +// NDP Router Advertisement informed the Stack about. +type DHCPv6ConfigurationFromNDPRA int + +const ( + // DHCPv6NoConfiguration indicates that no configurations are available via + // DHCPv6. + DHCPv6NoConfiguration DHCPv6ConfigurationFromNDPRA = iota + + // DHCPv6ManagedAddress indicates that addresses are available via DHCPv6. + // + // DHCPv6ManagedAddress also implies DHCPv6OtherConfigurations because DHCPv6 + // will return all available configuration information. + DHCPv6ManagedAddress + + // DHCPv6OtherConfigurations indicates that other configuration information is + // available via DHCPv6. + // + // Other configurations are configurations other than addresses. Examples of + // other configurations are recursive DNS server list, DNS search lists and + // default gateway. + DHCPv6OtherConfigurations +) + // NDPDispatcher is the interface integrators of netstack must implement to // receive and handle NDP related events. type NDPDispatcher interface { @@ -194,7 +218,20 @@ type NDPDispatcher interface { // already known DNS servers. If called with known DNS servers, their // valid lifetimes must be refreshed to lifetime (it may be increased, // decreased, or completely invalidated when lifetime = 0). + // + // This function is not permitted to block indefinitely. It must not + // call functions on the stack itself. OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) + + // OnDHCPv6Configuration will be called with an updated configuration that is + // available via DHCPv6 for a specified NIC. + // + // NDPDispatcher assumes that the initial configuration available by DHCPv6 is + // DHCPv6NoConfiguration. + // + // This function is not permitted to block indefinitely. It must not + // call functions on the stack itself. + OnDHCPv6Configuration(tcpip.NICID, DHCPv6ConfigurationFromNDPRA) } // NDPConfigurations is the NDP configurations for the netstack. @@ -281,6 +318,9 @@ type ndpState struct { // The addresses generated by SLAAC. autoGenAddresses map[tcpip.Address]autoGenAddressState + + // The last learned DHCPv6 configuration from an NDP RA. + dhcpv6Configuration DHCPv6ConfigurationFromNDPRA } // dadState holds the Duplicate Address Detection timer and channel to signal @@ -533,6 +573,28 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { return } + // Only worry about the DHCPv6 configuration if we have an NDPDispatcher as we + // only inform the dispatcher on configuration changes. We do nothing else + // with the information. + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + var configuration DHCPv6ConfigurationFromNDPRA + switch { + case ra.ManagedAddrConfFlag(): + configuration = DHCPv6ManagedAddress + + case ra.OtherConfFlag(): + configuration = DHCPv6OtherConfigurations + + default: + configuration = DHCPv6NoConfiguration + } + + if ndp.dhcpv6Configuration != configuration { + ndp.dhcpv6Configuration = configuration + ndpDisp.OnDHCPv6Configuration(ndp.nic.ID(), configuration) + } + } + // Is the NIC configured to discover default routers? if ndp.configs.DiscoverDefaultRouters { rtr, ok := ndp.defaultRouters[ip] diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index fa84c94a6..108762b6e 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -162,18 +162,24 @@ type ndpRDNSSEvent struct { rdnss ndpRDNSS } +type ndpDHCPv6Event struct { + nicID tcpip.NICID + configuration stack.DHCPv6ConfigurationFromNDPRA +} + var _ stack.NDPDispatcher = (*ndpDispatcher)(nil) // ndpDispatcher implements NDPDispatcher so tests can know when various NDP // related events happen for test purposes. type ndpDispatcher struct { - dadC chan ndpDADEvent - routerC chan ndpRouterEvent - rememberRouter bool - prefixC chan ndpPrefixEvent - rememberPrefix bool - autoGenAddrC chan ndpAutoGenAddrEvent - rdnssC chan ndpRDNSSEvent + dadC chan ndpDADEvent + routerC chan ndpRouterEvent + rememberRouter bool + prefixC chan ndpPrefixEvent + rememberPrefix bool + autoGenAddrC chan ndpAutoGenAddrEvent + rdnssC chan ndpRDNSSEvent + dhcpv6ConfigurationC chan ndpDHCPv6Event } // Implements stack.NDPDispatcher.OnDuplicateAddressDetectionStatus. @@ -280,6 +286,16 @@ func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tc } } +// Implements stack.NDPDispatcher.OnDHCPv6Configuration. +func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration stack.DHCPv6ConfigurationFromNDPRA) { + if c := n.dhcpv6ConfigurationC; c != nil { + c <- ndpDHCPv6Event{ + nicID, + configuration, + } + } +} + // TestDADResolve tests that an address successfully resolves after performing // DAD for various values of DupAddrDetectTransmits and RetransmitTimer. // Included in the subtests is a test to make sure that an invalid @@ -870,21 +886,32 @@ func TestSetNDPConfigurations(t *testing.T) { } } -// raBufWithOpts returns a valid NDP Router Advertisement with options. -// -// Note, raBufWithOpts does not populate any of the RA fields other than the -// Router Lifetime. -func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) tcpip.PacketBuffer { +// raBufWithOptsAndDHCPv6 returns a valid NDP Router Advertisement with options +// and DHCPv6 configurations specified. +func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) tcpip.PacketBuffer { icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + int(optSer.Length()) hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) pkt := header.ICMPv6(hdr.Prepend(icmpSize)) pkt.SetType(header.ICMPv6RouterAdvert) pkt.SetCode(0) - ra := header.NDPRouterAdvert(pkt.NDPPayload()) + raPayload := pkt.NDPPayload() + ra := header.NDPRouterAdvert(raPayload) + // Populate the Router Lifetime. + binary.BigEndian.PutUint16(raPayload[2:], rl) + // Populate the Managed Address flag field. + if managedAddress { + // The Managed Addresses flag field is the 7th bit of byte #1 (0-indexing) + // of the RA payload. + raPayload[1] |= (1 << 7) + } + // Populate the Other Configurations flag field. + if otherConfigurations { + // The Other Configurations flag field is the 6th bit of byte #1 + // (0-indexing) of the RA payload. + raPayload[1] |= (1 << 6) + } opts := ra.Options() opts.Serialize(optSer) - // Populate the Router Lifetime. - binary.BigEndian.PutUint16(pkt.NDPPayload()[2:], rl) pkt.SetChecksum(header.ICMPv6Checksum(pkt, ip, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) payloadLength := hdr.UsedLength() iph := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) @@ -899,6 +926,23 @@ func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializ return tcpip.PacketBuffer{Data: hdr.View().ToVectorisedView()} } +// raBufWithOpts returns a valid NDP Router Advertisement with options. +// +// Note, raBufWithOpts does not populate any of the RA fields other than the +// Router Lifetime. +func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) tcpip.PacketBuffer { + return raBufWithOptsAndDHCPv6(ip, rl, false, false, optSer) +} + +// raBufWithDHCPv6 returns a valid NDP Router Advertisement with DHCPv6 related +// fields set. +// +// Note, raBufWithDHCPv6 does not populate any of the RA fields other than the +// DHCPv6 related ones. +func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) tcpip.PacketBuffer { + return raBufWithOptsAndDHCPv6(ip, 0, managedAddresses, otherConfiguratiosns, header.NDPOptionsSerializer{}) +} + // raBuf returns a valid NDP Router Advertisement. // // Note, raBuf does not populate any of the RA fields other than the @@ -3024,3 +3068,94 @@ func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) { default: } } + +// TestDHCPv6ConfigurationFromNDPDA tests that the NDPDispatcher is properly +// informed when new information about what configurations are available via +// DHCPv6 is learned. +func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { + const nicID = 1 + + ndpDisp := ndpDispatcher{ + dhcpv6ConfigurationC: make(chan ndpDHCPv6Event, 1), + rememberRouter: true, + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + expectDHCPv6Event := func(configuration stack.DHCPv6ConfigurationFromNDPRA) { + t.Helper() + select { + case e := <-ndpDisp.dhcpv6ConfigurationC: + if diff := cmp.Diff(ndpDHCPv6Event{nicID: nicID, configuration: configuration}, e, cmp.AllowUnexported(e)); diff != "" { + t.Errorf("dhcpv6 event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected DHCPv6 configuration event") + } + } + + expectNoDHCPv6Event := func() { + t.Helper() + select { + case <-ndpDisp.dhcpv6ConfigurationC: + t.Fatal("unexpected DHCPv6 configuration event") + default: + } + } + + // The initial DHCPv6 configuration should be stack.DHCPv6NoConfiguration. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) + expectNoDHCPv6Event() + + // Receive an RA that updates the DHCPv6 configuration to Other + // Configurations. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) + expectDHCPv6Event(stack.DHCPv6OtherConfigurations) + // Receiving the same update again should not result in an event to the + // NDPDispatcher. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) + expectNoDHCPv6Event() + + // Receive an RA that updates the DHCPv6 configuration to Managed Address. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false)) + expectDHCPv6Event(stack.DHCPv6ManagedAddress) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false)) + expectNoDHCPv6Event() + + // Receive an RA that updates the DHCPv6 configuration to none. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) + expectDHCPv6Event(stack.DHCPv6NoConfiguration) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) + expectNoDHCPv6Event() + + // Receive an RA that updates the DHCPv6 configuration to Managed Address. + // + // Note, when the M flag is set, the O flag is redundant. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true)) + expectDHCPv6Event(stack.DHCPv6ManagedAddress) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true)) + expectNoDHCPv6Event() + // Even though the DHCPv6 flags are different, the effective configuration is + // the same so we should not receive a new event. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false)) + expectNoDHCPv6Event() + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true)) + expectNoDHCPv6Event() + + // Receive an RA that updates the DHCPv6 configuration to Other + // Configurations. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) + expectDHCPv6Event(stack.DHCPv6OtherConfigurations) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) + expectNoDHCPv6Event() +} -- cgit v1.2.3 From 27500d529f7fb87eef8812278fd1bbca67bcba72 Mon Sep 17 00:00:00 2001 From: Ian Gudger Date: Thu, 9 Jan 2020 22:00:42 -0800 Subject: New sync package. * Rename syncutil to sync. * Add aliases to sync types. * Replace existing usage of standard library sync package. This will make it easier to swap out synchronization primitives. For example, this will allow us to use primitives from github.com/sasha-s/go-deadlock to check for lock ordering violations. Updates #1472 PiperOrigin-RevId: 289033387 --- pkg/amutex/BUILD | 1 + pkg/amutex/amutex_test.go | 3 +- pkg/atomicbitops/BUILD | 1 + pkg/atomicbitops/atomic_bitops_test.go | 3 +- pkg/compressio/BUILD | 5 +- pkg/compressio/compressio.go | 2 +- pkg/control/server/BUILD | 1 + pkg/control/server/server.go | 2 +- pkg/eventchannel/BUILD | 2 + pkg/eventchannel/event.go | 2 +- pkg/eventchannel/event_test.go | 2 +- pkg/fdchannel/BUILD | 1 + pkg/fdchannel/fdchannel_test.go | 3 +- pkg/fdnotifier/BUILD | 1 + pkg/fdnotifier/fdnotifier.go | 2 +- pkg/flipcall/BUILD | 3 +- pkg/flipcall/flipcall_example_test.go | 3 +- pkg/flipcall/flipcall_test.go | 3 +- pkg/flipcall/flipcall_unsafe.go | 10 +- pkg/gate/BUILD | 1 + pkg/gate/gate_test.go | 2 +- pkg/linewriter/BUILD | 1 + pkg/linewriter/linewriter.go | 3 +- pkg/log/BUILD | 5 +- pkg/log/log.go | 2 +- pkg/metric/BUILD | 1 + pkg/metric/metric.go | 2 +- pkg/p9/BUILD | 1 + pkg/p9/client.go | 2 +- pkg/p9/p9test/BUILD | 2 + pkg/p9/p9test/client_test.go | 2 +- pkg/p9/p9test/p9test.go | 2 +- pkg/p9/path_tree.go | 3 +- pkg/p9/pool.go | 2 +- pkg/p9/server.go | 2 +- pkg/p9/transport.go | 2 +- pkg/procid/BUILD | 2 + pkg/procid/procid_test.go | 3 +- pkg/rand/BUILD | 5 +- pkg/rand/rand_linux.go | 2 +- pkg/refs/BUILD | 2 + pkg/refs/refcounter.go | 2 +- pkg/refs/refcounter_test.go | 3 +- pkg/sentry/arch/BUILD | 1 + pkg/sentry/arch/arch_x86.go | 2 +- pkg/sentry/control/BUILD | 1 + pkg/sentry/control/pprof.go | 2 +- pkg/sentry/device/BUILD | 5 +- pkg/sentry/device/device.go | 2 +- pkg/sentry/fs/BUILD | 3 +- pkg/sentry/fs/copy_up.go | 2 +- pkg/sentry/fs/copy_up_test.go | 2 +- pkg/sentry/fs/dirent.go | 2 +- pkg/sentry/fs/dirent_cache.go | 3 +- pkg/sentry/fs/dirent_cache_limiter.go | 3 +- pkg/sentry/fs/fdpipe/BUILD | 1 + pkg/sentry/fs/fdpipe/pipe.go | 2 +- pkg/sentry/fs/fdpipe/pipe_state.go | 2 +- pkg/sentry/fs/file.go | 2 +- pkg/sentry/fs/file_overlay.go | 2 +- pkg/sentry/fs/filesystems.go | 2 +- pkg/sentry/fs/fs.go | 3 +- pkg/sentry/fs/fsutil/BUILD | 1 + pkg/sentry/fs/fsutil/host_file_mapper.go | 2 +- pkg/sentry/fs/fsutil/host_mappable.go | 2 +- pkg/sentry/fs/fsutil/inode.go | 3 +- pkg/sentry/fs/fsutil/inode_cached.go | 2 +- pkg/sentry/fs/gofer/BUILD | 1 + pkg/sentry/fs/gofer/inode.go | 2 +- pkg/sentry/fs/gofer/session.go | 2 +- pkg/sentry/fs/host/BUILD | 1 + pkg/sentry/fs/host/inode.go | 2 +- pkg/sentry/fs/host/socket.go | 2 +- pkg/sentry/fs/host/tty.go | 3 +- pkg/sentry/fs/inode.go | 3 +- pkg/sentry/fs/inode_inotify.go | 3 +- pkg/sentry/fs/inotify.go | 2 +- pkg/sentry/fs/inotify_watch.go | 2 +- pkg/sentry/fs/lock/BUILD | 1 + pkg/sentry/fs/lock/lock.go | 2 +- pkg/sentry/fs/mounts.go | 2 +- pkg/sentry/fs/overlay.go | 5 +- pkg/sentry/fs/proc/BUILD | 1 + pkg/sentry/fs/proc/seqfile/BUILD | 1 + pkg/sentry/fs/proc/seqfile/seqfile.go | 2 +- pkg/sentry/fs/proc/sys_net.go | 2 +- pkg/sentry/fs/ramfs/BUILD | 1 + pkg/sentry/fs/ramfs/dir.go | 2 +- pkg/sentry/fs/restore.go | 2 +- pkg/sentry/fs/tmpfs/BUILD | 1 + pkg/sentry/fs/tmpfs/inode_file.go | 2 +- pkg/sentry/fs/tty/BUILD | 1 + pkg/sentry/fs/tty/dir.go | 2 +- pkg/sentry/fs/tty/line_discipline.go | 2 +- pkg/sentry/fs/tty/queue.go | 3 +- pkg/sentry/fsimpl/ext/BUILD | 1 + pkg/sentry/fsimpl/ext/directory.go | 3 +- pkg/sentry/fsimpl/ext/filesystem.go | 2 +- pkg/sentry/fsimpl/ext/regular_file.go | 2 +- pkg/sentry/fsimpl/kernfs/BUILD | 2 + pkg/sentry/fsimpl/kernfs/inode_impl_util.go | 2 +- pkg/sentry/fsimpl/kernfs/kernfs.go | 2 +- pkg/sentry/fsimpl/kernfs/kernfs_test.go | 2 +- pkg/sentry/fsimpl/tmpfs/BUILD | 1 + pkg/sentry/fsimpl/tmpfs/regular_file.go | 2 +- pkg/sentry/fsimpl/tmpfs/tmpfs.go | 2 +- pkg/sentry/kernel/BUILD | 5 +- pkg/sentry/kernel/abstract_socket_namespace.go | 2 +- pkg/sentry/kernel/auth/BUILD | 3 +- pkg/sentry/kernel/auth/user_namespace.go | 2 +- pkg/sentry/kernel/epoll/BUILD | 1 + pkg/sentry/kernel/epoll/epoll.go | 2 +- pkg/sentry/kernel/eventfd/BUILD | 1 + pkg/sentry/kernel/eventfd/eventfd.go | 2 +- pkg/sentry/kernel/fasync/BUILD | 1 + pkg/sentry/kernel/fasync/fasync.go | 3 +- pkg/sentry/kernel/fd_table.go | 2 +- pkg/sentry/kernel/fd_table_test.go | 2 +- pkg/sentry/kernel/fs_context.go | 2 +- pkg/sentry/kernel/futex/BUILD | 8 +- pkg/sentry/kernel/futex/futex.go | 3 +- pkg/sentry/kernel/futex/futex_test.go | 2 +- pkg/sentry/kernel/kernel.go | 2 +- pkg/sentry/kernel/memevent/BUILD | 1 + pkg/sentry/kernel/memevent/memory_events.go | 2 +- pkg/sentry/kernel/pipe/BUILD | 1 + pkg/sentry/kernel/pipe/buffer.go | 2 +- pkg/sentry/kernel/pipe/node.go | 3 +- pkg/sentry/kernel/pipe/pipe.go | 2 +- pkg/sentry/kernel/pipe/pipe_util.go | 2 +- pkg/sentry/kernel/pipe/vfs.go | 3 +- pkg/sentry/kernel/semaphore/BUILD | 1 + pkg/sentry/kernel/semaphore/semaphore.go | 2 +- pkg/sentry/kernel/shm/BUILD | 1 + pkg/sentry/kernel/shm/shm.go | 2 +- pkg/sentry/kernel/signal_handlers.go | 3 +- pkg/sentry/kernel/signalfd/BUILD | 1 + pkg/sentry/kernel/signalfd/signalfd.go | 3 +- pkg/sentry/kernel/syscalls.go | 2 +- pkg/sentry/kernel/syslog.go | 3 +- pkg/sentry/kernel/task.go | 5 +- pkg/sentry/kernel/thread_group.go | 2 +- pkg/sentry/kernel/threads.go | 2 +- pkg/sentry/kernel/time/BUILD | 1 + pkg/sentry/kernel/time/time.go | 2 +- pkg/sentry/kernel/timekeeper.go | 2 +- pkg/sentry/kernel/tty.go | 2 +- pkg/sentry/kernel/uts_namespace.go | 3 +- pkg/sentry/limits/BUILD | 1 + pkg/sentry/limits/limits.go | 3 +- pkg/sentry/mm/BUILD | 2 +- pkg/sentry/mm/aio_context.go | 3 +- pkg/sentry/mm/mm.go | 8 +- pkg/sentry/pgalloc/BUILD | 1 + pkg/sentry/pgalloc/pgalloc.go | 2 +- pkg/sentry/platform/interrupt/BUILD | 1 + pkg/sentry/platform/interrupt/interrupt.go | 3 +- pkg/sentry/platform/kvm/BUILD | 1 + pkg/sentry/platform/kvm/address_space.go | 2 +- pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go | 2 - pkg/sentry/platform/kvm/kvm.go | 2 +- pkg/sentry/platform/kvm/machine.go | 2 +- pkg/sentry/platform/ptrace/BUILD | 1 + pkg/sentry/platform/ptrace/ptrace.go | 2 +- pkg/sentry/platform/ptrace/subprocess.go | 2 +- .../platform/ptrace/subprocess_linux_unsafe.go | 2 +- pkg/sentry/platform/ring0/defs.go | 2 +- pkg/sentry/platform/ring0/defs_amd64.go | 1 + pkg/sentry/platform/ring0/defs_arm64.go | 1 + pkg/sentry/platform/ring0/pagetables/BUILD | 5 +- pkg/sentry/platform/ring0/pagetables/pcids_x86.go | 2 +- pkg/sentry/socket/netlink/BUILD | 1 + pkg/sentry/socket/netlink/port/BUILD | 1 + pkg/sentry/socket/netlink/port/port.go | 3 +- pkg/sentry/socket/netlink/socket.go | 2 +- pkg/sentry/socket/netstack/BUILD | 1 + pkg/sentry/socket/netstack/netstack.go | 2 +- pkg/sentry/socket/rpcinet/conn/BUILD | 1 + pkg/sentry/socket/rpcinet/conn/conn.go | 2 +- pkg/sentry/socket/rpcinet/notifier/BUILD | 1 + pkg/sentry/socket/rpcinet/notifier/notifier.go | 2 +- pkg/sentry/socket/unix/transport/BUILD | 1 + pkg/sentry/socket/unix/transport/connectioned.go | 3 +- pkg/sentry/socket/unix/transport/queue.go | 3 +- pkg/sentry/socket/unix/transport/unix.go | 2 +- pkg/sentry/syscalls/linux/BUILD | 1 + pkg/sentry/syscalls/linux/error.go | 2 +- pkg/sentry/time/BUILD | 4 +- pkg/sentry/time/calibrated_clock.go | 2 +- pkg/sentry/usage/BUILD | 1 + pkg/sentry/usage/memory.go | 2 +- pkg/sentry/vfs/BUILD | 3 +- pkg/sentry/vfs/dentry.go | 2 +- pkg/sentry/vfs/file_description_impl_util.go | 2 +- pkg/sentry/vfs/mount_test.go | 3 +- pkg/sentry/vfs/mount_unsafe.go | 4 +- pkg/sentry/vfs/pathname.go | 3 +- pkg/sentry/vfs/resolving_path.go | 2 +- pkg/sentry/vfs/vfs.go | 2 +- pkg/sentry/watchdog/BUILD | 1 + pkg/sentry/watchdog/watchdog.go | 2 +- pkg/sync/BUILD | 53 +++++++ pkg/sync/LICENSE | 27 ++++ pkg/sync/README.md | 5 + pkg/sync/aliases.go | 37 +++++ pkg/sync/atomicptr_unsafe.go | 47 +++++++ pkg/sync/atomicptrtest/BUILD | 29 ++++ pkg/sync/atomicptrtest/atomicptr_test.go | 31 +++++ pkg/sync/downgradable_rwmutex_test.go | 150 ++++++++++++++++++++ pkg/sync/downgradable_rwmutex_unsafe.go | 146 ++++++++++++++++++++ pkg/sync/memmove_unsafe.go | 28 ++++ pkg/sync/norace_unsafe.go | 35 +++++ pkg/sync/race_unsafe.go | 41 ++++++ pkg/sync/seqatomic_unsafe.go | 72 ++++++++++ pkg/sync/seqatomictest/BUILD | 33 +++++ pkg/sync/seqatomictest/seqatomic_test.go | 132 ++++++++++++++++++ pkg/sync/seqcount.go | 149 ++++++++++++++++++++ pkg/sync/seqcount_test.go | 153 +++++++++++++++++++++ pkg/sync/syncutil.go | 7 + pkg/syncutil/BUILD | 52 ------- pkg/syncutil/LICENSE | 27 ---- pkg/syncutil/README.md | 5 - pkg/syncutil/atomicptr_unsafe.go | 47 ------- pkg/syncutil/atomicptrtest/BUILD | 29 ---- pkg/syncutil/atomicptrtest/atomicptr_test.go | 31 ----- pkg/syncutil/downgradable_rwmutex_test.go | 150 -------------------- pkg/syncutil/downgradable_rwmutex_unsafe.go | 146 -------------------- pkg/syncutil/memmove_unsafe.go | 28 ---- pkg/syncutil/norace_unsafe.go | 35 ----- pkg/syncutil/race_unsafe.go | 41 ------ pkg/syncutil/seqatomic_unsafe.go | 72 ---------- pkg/syncutil/seqatomictest/BUILD | 35 ----- pkg/syncutil/seqatomictest/seqatomic_test.go | 132 ------------------ pkg/syncutil/seqcount.go | 149 -------------------- pkg/syncutil/seqcount_test.go | 153 --------------------- pkg/syncutil/syncutil.go | 7 - pkg/tcpip/BUILD | 1 + pkg/tcpip/adapters/gonet/BUILD | 1 + pkg/tcpip/adapters/gonet/gonet.go | 2 +- pkg/tcpip/link/fdbased/BUILD | 1 + pkg/tcpip/link/fdbased/endpoint.go | 2 +- pkg/tcpip/link/sharedmem/BUILD | 2 + pkg/tcpip/link/sharedmem/pipe/BUILD | 1 + pkg/tcpip/link/sharedmem/pipe/pipe_test.go | 3 +- pkg/tcpip/link/sharedmem/sharedmem.go | 2 +- pkg/tcpip/link/sharedmem/sharedmem_test.go | 2 +- pkg/tcpip/network/fragmentation/BUILD | 1 + pkg/tcpip/network/fragmentation/fragmentation.go | 2 +- pkg/tcpip/network/fragmentation/reassembler.go | 2 +- pkg/tcpip/ports/BUILD | 1 + pkg/tcpip/ports/ports.go | 2 +- pkg/tcpip/stack/BUILD | 2 + pkg/tcpip/stack/linkaddrcache.go | 2 +- pkg/tcpip/stack/linkaddrcache_test.go | 2 +- pkg/tcpip/stack/nic.go | 2 +- pkg/tcpip/stack/stack.go | 2 +- pkg/tcpip/stack/transport_demuxer.go | 2 +- pkg/tcpip/tcpip.go | 2 +- pkg/tcpip/transport/icmp/BUILD | 1 + pkg/tcpip/transport/icmp/endpoint.go | 3 +- pkg/tcpip/transport/packet/BUILD | 1 + pkg/tcpip/transport/packet/endpoint.go | 3 +- pkg/tcpip/transport/raw/BUILD | 1 + pkg/tcpip/transport/raw/endpoint.go | 3 +- pkg/tcpip/transport/tcp/BUILD | 1 + pkg/tcpip/transport/tcp/accept.go | 2 +- pkg/tcpip/transport/tcp/connect.go | 2 +- pkg/tcpip/transport/tcp/endpoint.go | 2 +- pkg/tcpip/transport/tcp/endpoint_state.go | 2 +- pkg/tcpip/transport/tcp/forwarder.go | 3 +- pkg/tcpip/transport/tcp/protocol.go | 2 +- pkg/tcpip/transport/tcp/segment_queue.go | 2 +- pkg/tcpip/transport/tcp/snd.go | 2 +- pkg/tcpip/transport/udp/BUILD | 1 + pkg/tcpip/transport/udp/endpoint.go | 3 +- pkg/tmutex/BUILD | 1 + pkg/tmutex/tmutex_test.go | 3 +- pkg/unet/BUILD | 1 + pkg/unet/unet_test.go | 3 +- pkg/urpc/BUILD | 1 + pkg/urpc/urpc.go | 2 +- pkg/waiter/BUILD | 1 + pkg/waiter/waiter.go | 2 +- runsc/boot/BUILD | 2 + runsc/boot/compat.go | 2 +- runsc/boot/limits.go | 2 +- runsc/boot/loader.go | 2 +- runsc/boot/loader_test.go | 2 +- runsc/cmd/BUILD | 1 + runsc/cmd/create.go | 1 + runsc/cmd/gofer.go | 2 +- runsc/cmd/start.go | 1 + runsc/container/BUILD | 2 + runsc/container/console_test.go | 2 +- runsc/container/container_test.go | 2 +- runsc/container/multi_container_test.go | 2 +- runsc/container/state_file.go | 2 +- runsc/fsgofer/BUILD | 1 + runsc/fsgofer/fsgofer.go | 2 +- runsc/sandbox/BUILD | 1 + runsc/sandbox/sandbox.go | 2 +- runsc/testutil/BUILD | 1 + runsc/testutil/testutil.go | 2 +- 303 files changed, 1507 insertions(+), 1368 deletions(-) create mode 100644 pkg/sync/BUILD create mode 100644 pkg/sync/LICENSE create mode 100644 pkg/sync/README.md create mode 100644 pkg/sync/aliases.go create mode 100644 pkg/sync/atomicptr_unsafe.go create mode 100644 pkg/sync/atomicptrtest/BUILD create mode 100644 pkg/sync/atomicptrtest/atomicptr_test.go create mode 100644 pkg/sync/downgradable_rwmutex_test.go create mode 100644 pkg/sync/downgradable_rwmutex_unsafe.go create mode 100644 pkg/sync/memmove_unsafe.go create mode 100644 pkg/sync/norace_unsafe.go create mode 100644 pkg/sync/race_unsafe.go create mode 100644 pkg/sync/seqatomic_unsafe.go create mode 100644 pkg/sync/seqatomictest/BUILD create mode 100644 pkg/sync/seqatomictest/seqatomic_test.go create mode 100644 pkg/sync/seqcount.go create mode 100644 pkg/sync/seqcount_test.go create mode 100644 pkg/sync/syncutil.go delete mode 100644 pkg/syncutil/BUILD delete mode 100644 pkg/syncutil/LICENSE delete mode 100644 pkg/syncutil/README.md delete mode 100644 pkg/syncutil/atomicptr_unsafe.go delete mode 100644 pkg/syncutil/atomicptrtest/BUILD delete mode 100644 pkg/syncutil/atomicptrtest/atomicptr_test.go delete mode 100644 pkg/syncutil/downgradable_rwmutex_test.go delete mode 100644 pkg/syncutil/downgradable_rwmutex_unsafe.go delete mode 100644 pkg/syncutil/memmove_unsafe.go delete mode 100644 pkg/syncutil/norace_unsafe.go delete mode 100644 pkg/syncutil/race_unsafe.go delete mode 100644 pkg/syncutil/seqatomic_unsafe.go delete mode 100644 pkg/syncutil/seqatomictest/BUILD delete mode 100644 pkg/syncutil/seqatomictest/seqatomic_test.go delete mode 100644 pkg/syncutil/seqcount.go delete mode 100644 pkg/syncutil/seqcount_test.go delete mode 100644 pkg/syncutil/syncutil.go (limited to 'pkg/tcpip/stack') diff --git a/pkg/amutex/BUILD b/pkg/amutex/BUILD index 6bc486b62..d99e37b40 100644 --- a/pkg/amutex/BUILD +++ b/pkg/amutex/BUILD @@ -15,4 +15,5 @@ go_test( size = "small", srcs = ["amutex_test.go"], embed = [":amutex"], + deps = ["//pkg/sync"], ) diff --git a/pkg/amutex/amutex_test.go b/pkg/amutex/amutex_test.go index 1d7f45641..8a3952f2a 100644 --- a/pkg/amutex/amutex_test.go +++ b/pkg/amutex/amutex_test.go @@ -15,9 +15,10 @@ package amutex import ( - "sync" "testing" "time" + + "gvisor.dev/gvisor/pkg/sync" ) type sleeper struct { diff --git a/pkg/atomicbitops/BUILD b/pkg/atomicbitops/BUILD index 36beaade9..6403c60c2 100644 --- a/pkg/atomicbitops/BUILD +++ b/pkg/atomicbitops/BUILD @@ -20,4 +20,5 @@ go_test( size = "small", srcs = ["atomic_bitops_test.go"], embed = [":atomicbitops"], + deps = ["//pkg/sync"], ) diff --git a/pkg/atomicbitops/atomic_bitops_test.go b/pkg/atomicbitops/atomic_bitops_test.go index 965e9be79..9466d3e23 100644 --- a/pkg/atomicbitops/atomic_bitops_test.go +++ b/pkg/atomicbitops/atomic_bitops_test.go @@ -16,8 +16,9 @@ package atomicbitops import ( "runtime" - "sync" "testing" + + "gvisor.dev/gvisor/pkg/sync" ) const iterations = 100 diff --git a/pkg/compressio/BUILD b/pkg/compressio/BUILD index a0b21d4bd..2bb581b18 100644 --- a/pkg/compressio/BUILD +++ b/pkg/compressio/BUILD @@ -8,7 +8,10 @@ go_library( srcs = ["compressio.go"], importpath = "gvisor.dev/gvisor/pkg/compressio", visibility = ["//:sandbox"], - deps = ["//pkg/binary"], + deps = [ + "//pkg/binary", + "//pkg/sync", + ], ) go_test( diff --git a/pkg/compressio/compressio.go b/pkg/compressio/compressio.go index 3b0bb086e..5f52cbe74 100644 --- a/pkg/compressio/compressio.go +++ b/pkg/compressio/compressio.go @@ -52,9 +52,9 @@ import ( "hash" "io" "runtime" - "sync" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/sync" ) var bufPool = sync.Pool{ diff --git a/pkg/control/server/BUILD b/pkg/control/server/BUILD index 21adf3adf..adbd1e3f8 100644 --- a/pkg/control/server/BUILD +++ b/pkg/control/server/BUILD @@ -9,6 +9,7 @@ go_library( visibility = ["//:sandbox"], deps = [ "//pkg/log", + "//pkg/sync", "//pkg/unet", "//pkg/urpc", ], diff --git a/pkg/control/server/server.go b/pkg/control/server/server.go index a56152d10..41abe1f2d 100644 --- a/pkg/control/server/server.go +++ b/pkg/control/server/server.go @@ -22,9 +22,9 @@ package server import ( "os" - "sync" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" "gvisor.dev/gvisor/pkg/urpc" ) diff --git a/pkg/eventchannel/BUILD b/pkg/eventchannel/BUILD index 0b4b7cc44..9d68682c7 100644 --- a/pkg/eventchannel/BUILD +++ b/pkg/eventchannel/BUILD @@ -15,6 +15,7 @@ go_library( deps = [ ":eventchannel_go_proto", "//pkg/log", + "//pkg/sync", "//pkg/unet", "@com_github_golang_protobuf//proto:go_default_library", "@com_github_golang_protobuf//ptypes:go_default_library_gen", @@ -40,6 +41,7 @@ go_test( srcs = ["event_test.go"], embed = [":eventchannel"], deps = [ + "//pkg/sync", "@com_github_golang_protobuf//proto:go_default_library", ], ) diff --git a/pkg/eventchannel/event.go b/pkg/eventchannel/event.go index d37ad0428..9a29c58bd 100644 --- a/pkg/eventchannel/event.go +++ b/pkg/eventchannel/event.go @@ -22,13 +22,13 @@ package eventchannel import ( "encoding/binary" "fmt" - "sync" "syscall" "github.com/golang/protobuf/proto" "github.com/golang/protobuf/ptypes" pb "gvisor.dev/gvisor/pkg/eventchannel/eventchannel_go_proto" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" ) diff --git a/pkg/eventchannel/event_test.go b/pkg/eventchannel/event_test.go index 3649097d6..7f41b4a27 100644 --- a/pkg/eventchannel/event_test.go +++ b/pkg/eventchannel/event_test.go @@ -16,11 +16,11 @@ package eventchannel import ( "fmt" - "sync" "testing" "time" "github.com/golang/protobuf/proto" + "gvisor.dev/gvisor/pkg/sync" ) // testEmitter is an emitter that can be used in tests. It records all events diff --git a/pkg/fdchannel/BUILD b/pkg/fdchannel/BUILD index 56495cbd9..b0478c672 100644 --- a/pkg/fdchannel/BUILD +++ b/pkg/fdchannel/BUILD @@ -15,4 +15,5 @@ go_test( size = "small", srcs = ["fdchannel_test.go"], embed = [":fdchannel"], + deps = ["//pkg/sync"], ) diff --git a/pkg/fdchannel/fdchannel_test.go b/pkg/fdchannel/fdchannel_test.go index 5d01dc636..7a8a63a59 100644 --- a/pkg/fdchannel/fdchannel_test.go +++ b/pkg/fdchannel/fdchannel_test.go @@ -17,10 +17,11 @@ package fdchannel import ( "io/ioutil" "os" - "sync" "syscall" "testing" "time" + + "gvisor.dev/gvisor/pkg/sync" ) func TestSendRecvFD(t *testing.T) { diff --git a/pkg/fdnotifier/BUILD b/pkg/fdnotifier/BUILD index aca2d8a82..91a202a30 100644 --- a/pkg/fdnotifier/BUILD +++ b/pkg/fdnotifier/BUILD @@ -11,6 +11,7 @@ go_library( importpath = "gvisor.dev/gvisor/pkg/fdnotifier", visibility = ["//:sandbox"], deps = [ + "//pkg/sync", "//pkg/waiter", "@org_golang_x_sys//unix:go_default_library", ], diff --git a/pkg/fdnotifier/fdnotifier.go b/pkg/fdnotifier/fdnotifier.go index f4aae1953..a6b63c982 100644 --- a/pkg/fdnotifier/fdnotifier.go +++ b/pkg/fdnotifier/fdnotifier.go @@ -22,10 +22,10 @@ package fdnotifier import ( "fmt" - "sync" "syscall" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/flipcall/BUILD b/pkg/flipcall/BUILD index e590a71ba..85bd83af1 100644 --- a/pkg/flipcall/BUILD +++ b/pkg/flipcall/BUILD @@ -19,7 +19,7 @@ go_library( "//pkg/abi/linux", "//pkg/log", "//pkg/memutil", - "//pkg/syncutil", + "//pkg/sync", ], ) @@ -31,4 +31,5 @@ go_test( "flipcall_test.go", ], embed = [":flipcall"], + deps = ["//pkg/sync"], ) diff --git a/pkg/flipcall/flipcall_example_test.go b/pkg/flipcall/flipcall_example_test.go index 8d88b845d..2e28a149a 100644 --- a/pkg/flipcall/flipcall_example_test.go +++ b/pkg/flipcall/flipcall_example_test.go @@ -17,7 +17,8 @@ package flipcall import ( "bytes" "fmt" - "sync" + + "gvisor.dev/gvisor/pkg/sync" ) func Example() { diff --git a/pkg/flipcall/flipcall_test.go b/pkg/flipcall/flipcall_test.go index 168a487ec..33fd55a44 100644 --- a/pkg/flipcall/flipcall_test.go +++ b/pkg/flipcall/flipcall_test.go @@ -16,9 +16,10 @@ package flipcall import ( "runtime" - "sync" "testing" "time" + + "gvisor.dev/gvisor/pkg/sync" ) var testPacketWindowSize = pageSize diff --git a/pkg/flipcall/flipcall_unsafe.go b/pkg/flipcall/flipcall_unsafe.go index 27b8939fc..ac974b232 100644 --- a/pkg/flipcall/flipcall_unsafe.go +++ b/pkg/flipcall/flipcall_unsafe.go @@ -18,7 +18,7 @@ import ( "reflect" "unsafe" - "gvisor.dev/gvisor/pkg/syncutil" + "gvisor.dev/gvisor/pkg/sync" ) // Packets consist of a 16-byte header followed by an arbitrarily-sized @@ -75,13 +75,13 @@ func (ep *Endpoint) Data() []byte { var ioSync int64 func raceBecomeActive() { - if syncutil.RaceEnabled { - syncutil.RaceAcquire((unsafe.Pointer)(&ioSync)) + if sync.RaceEnabled { + sync.RaceAcquire((unsafe.Pointer)(&ioSync)) } } func raceBecomeInactive() { - if syncutil.RaceEnabled { - syncutil.RaceReleaseMerge((unsafe.Pointer)(&ioSync)) + if sync.RaceEnabled { + sync.RaceReleaseMerge((unsafe.Pointer)(&ioSync)) } } diff --git a/pkg/gate/BUILD b/pkg/gate/BUILD index 4b9321711..f22bd070d 100644 --- a/pkg/gate/BUILD +++ b/pkg/gate/BUILD @@ -19,5 +19,6 @@ go_test( ], deps = [ ":gate", + "//pkg/sync", ], ) diff --git a/pkg/gate/gate_test.go b/pkg/gate/gate_test.go index 5dbd8d712..850693df8 100644 --- a/pkg/gate/gate_test.go +++ b/pkg/gate/gate_test.go @@ -15,11 +15,11 @@ package gate_test import ( - "sync" "testing" "time" "gvisor.dev/gvisor/pkg/gate" + "gvisor.dev/gvisor/pkg/sync" ) func TestBasicEnter(t *testing.T) { diff --git a/pkg/linewriter/BUILD b/pkg/linewriter/BUILD index a5d980d14..bcde6d308 100644 --- a/pkg/linewriter/BUILD +++ b/pkg/linewriter/BUILD @@ -8,6 +8,7 @@ go_library( srcs = ["linewriter.go"], importpath = "gvisor.dev/gvisor/pkg/linewriter", visibility = ["//visibility:public"], + deps = ["//pkg/sync"], ) go_test( diff --git a/pkg/linewriter/linewriter.go b/pkg/linewriter/linewriter.go index cd6e4e2ce..a1b1285d4 100644 --- a/pkg/linewriter/linewriter.go +++ b/pkg/linewriter/linewriter.go @@ -17,7 +17,8 @@ package linewriter import ( "bytes" - "sync" + + "gvisor.dev/gvisor/pkg/sync" ) // Writer is an io.Writer which buffers input, flushing diff --git a/pkg/log/BUILD b/pkg/log/BUILD index fc5f5779b..0df0f2849 100644 --- a/pkg/log/BUILD +++ b/pkg/log/BUILD @@ -16,7 +16,10 @@ go_library( visibility = [ "//visibility:public", ], - deps = ["//pkg/linewriter"], + deps = [ + "//pkg/linewriter", + "//pkg/sync", + ], ) go_test( diff --git a/pkg/log/log.go b/pkg/log/log.go index 9387586e6..91a81b288 100644 --- a/pkg/log/log.go +++ b/pkg/log/log.go @@ -25,12 +25,12 @@ import ( stdlog "log" "os" "runtime" - "sync" "sync/atomic" "syscall" "time" "gvisor.dev/gvisor/pkg/linewriter" + "gvisor.dev/gvisor/pkg/sync" ) // Level is the log level. diff --git a/pkg/metric/BUILD b/pkg/metric/BUILD index dd6ca6d39..9145f3233 100644 --- a/pkg/metric/BUILD +++ b/pkg/metric/BUILD @@ -14,6 +14,7 @@ go_library( ":metric_go_proto", "//pkg/eventchannel", "//pkg/log", + "//pkg/sync", ], ) diff --git a/pkg/metric/metric.go b/pkg/metric/metric.go index eadde06e4..93d4f2b8c 100644 --- a/pkg/metric/metric.go +++ b/pkg/metric/metric.go @@ -18,12 +18,12 @@ package metric import ( "errors" "fmt" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/eventchannel" "gvisor.dev/gvisor/pkg/log" pb "gvisor.dev/gvisor/pkg/metric/metric_go_proto" + "gvisor.dev/gvisor/pkg/sync" ) var ( diff --git a/pkg/p9/BUILD b/pkg/p9/BUILD index f32244c69..a3e05c96d 100644 --- a/pkg/p9/BUILD +++ b/pkg/p9/BUILD @@ -29,6 +29,7 @@ go_library( "//pkg/fdchannel", "//pkg/flipcall", "//pkg/log", + "//pkg/sync", "//pkg/unet", "@org_golang_x_sys//unix:go_default_library", ], diff --git a/pkg/p9/client.go b/pkg/p9/client.go index 221516c6c..4045e41fa 100644 --- a/pkg/p9/client.go +++ b/pkg/p9/client.go @@ -17,12 +17,12 @@ package p9 import ( "errors" "fmt" - "sync" "syscall" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/flipcall" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" ) diff --git a/pkg/p9/p9test/BUILD b/pkg/p9/p9test/BUILD index 28707c0ca..f4edd68b2 100644 --- a/pkg/p9/p9test/BUILD +++ b/pkg/p9/p9test/BUILD @@ -70,6 +70,7 @@ go_library( "//pkg/fd", "//pkg/log", "//pkg/p9", + "//pkg/sync", "//pkg/unet", "@com_github_golang_mock//gomock:go_default_library", ], @@ -83,6 +84,7 @@ go_test( deps = [ "//pkg/fd", "//pkg/p9", + "//pkg/sync", "@com_github_golang_mock//gomock:go_default_library", ], ) diff --git a/pkg/p9/p9test/client_test.go b/pkg/p9/p9test/client_test.go index 6e758148d..6e7bb3db2 100644 --- a/pkg/p9/p9test/client_test.go +++ b/pkg/p9/p9test/client_test.go @@ -22,7 +22,6 @@ import ( "os" "reflect" "strings" - "sync" "syscall" "testing" "time" @@ -30,6 +29,7 @@ import ( "github.com/golang/mock/gomock" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/p9" + "gvisor.dev/gvisor/pkg/sync" ) func TestPanic(t *testing.T) { diff --git a/pkg/p9/p9test/p9test.go b/pkg/p9/p9test/p9test.go index 4d3271b37..dd8b01b6d 100644 --- a/pkg/p9/p9test/p9test.go +++ b/pkg/p9/p9test/p9test.go @@ -17,13 +17,13 @@ package p9test import ( "fmt" - "sync" "sync/atomic" "syscall" "testing" "github.com/golang/mock/gomock" "gvisor.dev/gvisor/pkg/p9" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" ) diff --git a/pkg/p9/path_tree.go b/pkg/p9/path_tree.go index 865459411..72ef53313 100644 --- a/pkg/p9/path_tree.go +++ b/pkg/p9/path_tree.go @@ -16,7 +16,8 @@ package p9 import ( "fmt" - "sync" + + "gvisor.dev/gvisor/pkg/sync" ) // pathNode is a single node in a path traversal. diff --git a/pkg/p9/pool.go b/pkg/p9/pool.go index 52de889e1..2b14a5ce3 100644 --- a/pkg/p9/pool.go +++ b/pkg/p9/pool.go @@ -15,7 +15,7 @@ package p9 import ( - "sync" + "gvisor.dev/gvisor/pkg/sync" ) // pool is a simple allocator. diff --git a/pkg/p9/server.go b/pkg/p9/server.go index 40b8fa023..fdfa83648 100644 --- a/pkg/p9/server.go +++ b/pkg/p9/server.go @@ -17,7 +17,6 @@ package p9 import ( "io" "runtime/debug" - "sync" "sync/atomic" "syscall" @@ -25,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/fdchannel" "gvisor.dev/gvisor/pkg/flipcall" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" ) diff --git a/pkg/p9/transport.go b/pkg/p9/transport.go index 6e8b4bbcd..9c11e28ce 100644 --- a/pkg/p9/transport.go +++ b/pkg/p9/transport.go @@ -19,11 +19,11 @@ import ( "fmt" "io" "io/ioutil" - "sync" "syscall" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" ) diff --git a/pkg/procid/BUILD b/pkg/procid/BUILD index 078f084b2..b506813f0 100644 --- a/pkg/procid/BUILD +++ b/pkg/procid/BUILD @@ -21,6 +21,7 @@ go_test( "procid_test.go", ], embed = [":procid"], + deps = ["//pkg/sync"], ) go_test( @@ -31,4 +32,5 @@ go_test( "procid_test.go", ], embed = [":procid"], + deps = ["//pkg/sync"], ) diff --git a/pkg/procid/procid_test.go b/pkg/procid/procid_test.go index 88dd0b3ae..9ec08c3d6 100644 --- a/pkg/procid/procid_test.go +++ b/pkg/procid/procid_test.go @@ -17,9 +17,10 @@ package procid import ( "os" "runtime" - "sync" "syscall" "testing" + + "gvisor.dev/gvisor/pkg/sync" ) // runOnMain is used to send functions to run on the main (initial) thread. diff --git a/pkg/rand/BUILD b/pkg/rand/BUILD index f4f2001f3..9d5b4859b 100644 --- a/pkg/rand/BUILD +++ b/pkg/rand/BUILD @@ -10,5 +10,8 @@ go_library( ], importpath = "gvisor.dev/gvisor/pkg/rand", visibility = ["//:sandbox"], - deps = ["@org_golang_x_sys//unix:go_default_library"], + deps = [ + "//pkg/sync", + "@org_golang_x_sys//unix:go_default_library", + ], ) diff --git a/pkg/rand/rand_linux.go b/pkg/rand/rand_linux.go index 2b92db3e6..0bdad5fad 100644 --- a/pkg/rand/rand_linux.go +++ b/pkg/rand/rand_linux.go @@ -19,9 +19,9 @@ package rand import ( "crypto/rand" "io" - "sync" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/sync" ) // reader implements an io.Reader that returns pseudorandom bytes. diff --git a/pkg/refs/BUILD b/pkg/refs/BUILD index 7ad59dfd7..974d9af9b 100644 --- a/pkg/refs/BUILD +++ b/pkg/refs/BUILD @@ -27,6 +27,7 @@ go_library( visibility = ["//:sandbox"], deps = [ "//pkg/log", + "//pkg/sync", ], ) @@ -35,4 +36,5 @@ go_test( size = "small", srcs = ["refcounter_test.go"], embed = [":refs"], + deps = ["//pkg/sync"], ) diff --git a/pkg/refs/refcounter.go b/pkg/refs/refcounter.go index ad69e0757..c45ba8200 100644 --- a/pkg/refs/refcounter.go +++ b/pkg/refs/refcounter.go @@ -21,10 +21,10 @@ import ( "fmt" "reflect" "runtime" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" ) // RefCounter is the interface to be implemented by objects that are reference diff --git a/pkg/refs/refcounter_test.go b/pkg/refs/refcounter_test.go index ffd3d3f07..1ab4a4440 100644 --- a/pkg/refs/refcounter_test.go +++ b/pkg/refs/refcounter_test.go @@ -16,8 +16,9 @@ package refs import ( "reflect" - "sync" "testing" + + "gvisor.dev/gvisor/pkg/sync" ) type testCounter struct { diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD index 18c73cc24..ae3e364cd 100644 --- a/pkg/sentry/arch/BUILD +++ b/pkg/sentry/arch/BUILD @@ -32,6 +32,7 @@ go_library( "//pkg/sentry/context", "//pkg/sentry/limits", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", ], ) diff --git a/pkg/sentry/arch/arch_x86.go b/pkg/sentry/arch/arch_x86.go index 9294ac773..9f41e566f 100644 --- a/pkg/sentry/arch/arch_x86.go +++ b/pkg/sentry/arch/arch_x86.go @@ -19,7 +19,6 @@ package arch import ( "fmt" "io" - "sync" "syscall" "gvisor.dev/gvisor/pkg/binary" @@ -27,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/log" rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/control/BUILD b/pkg/sentry/control/BUILD index 5522cecd0..2561a6109 100644 --- a/pkg/sentry/control/BUILD +++ b/pkg/sentry/control/BUILD @@ -30,6 +30,7 @@ go_library( "//pkg/sentry/strace", "//pkg/sentry/usage", "//pkg/sentry/watchdog", + "//pkg/sync", "//pkg/tcpip/link/sniffer", "//pkg/urpc", ], diff --git a/pkg/sentry/control/pprof.go b/pkg/sentry/control/pprof.go index e1f2fea60..151808911 100644 --- a/pkg/sentry/control/pprof.go +++ b/pkg/sentry/control/pprof.go @@ -19,10 +19,10 @@ import ( "runtime" "runtime/pprof" "runtime/trace" - "sync" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/urpc" ) diff --git a/pkg/sentry/device/BUILD b/pkg/sentry/device/BUILD index 1098ed777..97fa1512c 100644 --- a/pkg/sentry/device/BUILD +++ b/pkg/sentry/device/BUILD @@ -8,7 +8,10 @@ go_library( srcs = ["device.go"], importpath = "gvisor.dev/gvisor/pkg/sentry/device", visibility = ["//pkg/sentry:internal"], - deps = ["//pkg/abi/linux"], + deps = [ + "//pkg/abi/linux", + "//pkg/sync", + ], ) go_test( diff --git a/pkg/sentry/device/device.go b/pkg/sentry/device/device.go index 47945d1a7..69e71e322 100644 --- a/pkg/sentry/device/device.go +++ b/pkg/sentry/device/device.go @@ -19,10 +19,10 @@ package device import ( "bytes" "fmt" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sync" ) // Registry tracks all simple devices and related state on the system for diff --git a/pkg/sentry/fs/BUILD b/pkg/sentry/fs/BUILD index c035ffff7..7d5d72d5a 100644 --- a/pkg/sentry/fs/BUILD +++ b/pkg/sentry/fs/BUILD @@ -68,7 +68,7 @@ go_library( "//pkg/sentry/usage", "//pkg/sentry/usermem", "//pkg/state", - "//pkg/syncutil", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], @@ -115,6 +115,7 @@ go_test( "//pkg/sentry/fs/tmpfs", "//pkg/sentry/kernel/contexttest", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", ], ) diff --git a/pkg/sentry/fs/copy_up.go b/pkg/sentry/fs/copy_up.go index 9ac62c84d..734177e90 100644 --- a/pkg/sentry/fs/copy_up.go +++ b/pkg/sentry/fs/copy_up.go @@ -17,12 +17,12 @@ package fs import ( "fmt" "io" - "sync" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fs/copy_up_test.go b/pkg/sentry/fs/copy_up_test.go index 1d80bf15a..738580c5f 100644 --- a/pkg/sentry/fs/copy_up_test.go +++ b/pkg/sentry/fs/copy_up_test.go @@ -19,13 +19,13 @@ import ( "crypto/rand" "fmt" "io" - "sync" "testing" "gvisor.dev/gvisor/pkg/sentry/fs" _ "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs" "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" ) const ( diff --git a/pkg/sentry/fs/dirent.go b/pkg/sentry/fs/dirent.go index 3cb73bd78..31fc4d87b 100644 --- a/pkg/sentry/fs/dirent.go +++ b/pkg/sentry/fs/dirent.go @@ -18,7 +18,6 @@ import ( "fmt" "path" "sort" - "sync" "sync/atomic" "syscall" @@ -28,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/uniqueid" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fs/dirent_cache.go b/pkg/sentry/fs/dirent_cache.go index 60a15a275..25514ace4 100644 --- a/pkg/sentry/fs/dirent_cache.go +++ b/pkg/sentry/fs/dirent_cache.go @@ -16,7 +16,8 @@ package fs import ( "fmt" - "sync" + + "gvisor.dev/gvisor/pkg/sync" ) // DirentCache is an LRU cache of Dirents. The Dirent's refCount is diff --git a/pkg/sentry/fs/dirent_cache_limiter.go b/pkg/sentry/fs/dirent_cache_limiter.go index ebb80bd50..525ee25f9 100644 --- a/pkg/sentry/fs/dirent_cache_limiter.go +++ b/pkg/sentry/fs/dirent_cache_limiter.go @@ -16,7 +16,8 @@ package fs import ( "fmt" - "sync" + + "gvisor.dev/gvisor/pkg/sync" ) // DirentCacheLimiter acts as a global limit for all dirent caches in the diff --git a/pkg/sentry/fs/fdpipe/BUILD b/pkg/sentry/fs/fdpipe/BUILD index 277ee4c31..cc43de69d 100644 --- a/pkg/sentry/fs/fdpipe/BUILD +++ b/pkg/sentry/fs/fdpipe/BUILD @@ -23,6 +23,7 @@ go_library( "//pkg/sentry/fs/fsutil", "//pkg/sentry/safemem", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/fs/fdpipe/pipe.go b/pkg/sentry/fs/fdpipe/pipe.go index 669ffcb75..5b6cfeb0a 100644 --- a/pkg/sentry/fs/fdpipe/pipe.go +++ b/pkg/sentry/fs/fdpipe/pipe.go @@ -17,7 +17,6 @@ package fdpipe import ( "os" - "sync" "syscall" "gvisor.dev/gvisor/pkg/fd" @@ -29,6 +28,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/fdpipe/pipe_state.go b/pkg/sentry/fs/fdpipe/pipe_state.go index 29175fb3d..cee87f726 100644 --- a/pkg/sentry/fs/fdpipe/pipe_state.go +++ b/pkg/sentry/fs/fdpipe/pipe_state.go @@ -17,10 +17,10 @@ package fdpipe import ( "fmt" "io/ioutil" - "sync" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sync" ) // beforeSave is invoked by stateify. diff --git a/pkg/sentry/fs/file.go b/pkg/sentry/fs/file.go index a2f966cb6..7c4586296 100644 --- a/pkg/sentry/fs/file.go +++ b/pkg/sentry/fs/file.go @@ -16,7 +16,6 @@ package fs import ( "math" - "sync" "sync/atomic" "time" @@ -29,6 +28,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/uniqueid" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/file_overlay.go b/pkg/sentry/fs/file_overlay.go index 225e40186..8a633b1ba 100644 --- a/pkg/sentry/fs/file_overlay.go +++ b/pkg/sentry/fs/file_overlay.go @@ -16,13 +16,13 @@ package fs import ( "io" - "sync" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/filesystems.go b/pkg/sentry/fs/filesystems.go index b157fd228..c5b51620a 100644 --- a/pkg/sentry/fs/filesystems.go +++ b/pkg/sentry/fs/filesystems.go @@ -18,9 +18,9 @@ import ( "fmt" "sort" "strings" - "sync" "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sync" ) // FilesystemFlags matches include/linux/fs.h:file_system_type.fs_flags. diff --git a/pkg/sentry/fs/fs.go b/pkg/sentry/fs/fs.go index 8b2a5e6b2..26abf49e2 100644 --- a/pkg/sentry/fs/fs.go +++ b/pkg/sentry/fs/fs.go @@ -54,10 +54,9 @@ package fs import ( - "sync" - "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sync" ) var ( diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD index 9ca695a95..945b6270d 100644 --- a/pkg/sentry/fs/fsutil/BUILD +++ b/pkg/sentry/fs/fsutil/BUILD @@ -93,6 +93,7 @@ go_library( "//pkg/sentry/usage", "//pkg/sentry/usermem", "//pkg/state", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/fs/fsutil/host_file_mapper.go b/pkg/sentry/fs/fsutil/host_file_mapper.go index b06a71cc2..837fc70b5 100644 --- a/pkg/sentry/fs/fsutil/host_file_mapper.go +++ b/pkg/sentry/fs/fsutil/host_file_mapper.go @@ -16,7 +16,6 @@ package fsutil import ( "fmt" - "sync" "syscall" "gvisor.dev/gvisor/pkg/log" @@ -24,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" ) // HostFileMapper caches mappings of an arbitrary host file descriptor. It is diff --git a/pkg/sentry/fs/fsutil/host_mappable.go b/pkg/sentry/fs/fsutil/host_mappable.go index 30475f340..a625f0e26 100644 --- a/pkg/sentry/fs/fsutil/host_mappable.go +++ b/pkg/sentry/fs/fsutil/host_mappable.go @@ -16,7 +16,6 @@ package fsutil import ( "math" - "sync" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -24,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" ) // HostMappable implements memmap.Mappable and platform.File over a diff --git a/pkg/sentry/fs/fsutil/inode.go b/pkg/sentry/fs/fsutil/inode.go index 4e100a402..adf5ec69c 100644 --- a/pkg/sentry/fs/fsutil/inode.go +++ b/pkg/sentry/fs/fsutil/inode.go @@ -15,13 +15,12 @@ package fsutil import ( - "sync" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go index 798920d18..20a014402 100644 --- a/pkg/sentry/fs/fsutil/inode_cached.go +++ b/pkg/sentry/fs/fsutil/inode_cached.go @@ -17,7 +17,6 @@ package fsutil import ( "fmt" "io" - "sync" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/context" @@ -30,6 +29,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" ) // Lock order (compare the lock order model in mm/mm.go): diff --git a/pkg/sentry/fs/gofer/BUILD b/pkg/sentry/fs/gofer/BUILD index 4a005c605..fd870e8e1 100644 --- a/pkg/sentry/fs/gofer/BUILD +++ b/pkg/sentry/fs/gofer/BUILD @@ -44,6 +44,7 @@ go_library( "//pkg/sentry/safemem", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserr", "//pkg/syserror", "//pkg/unet", diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go index 91263ebdc..245fe2ef1 100644 --- a/pkg/sentry/fs/gofer/inode.go +++ b/pkg/sentry/fs/gofer/inode.go @@ -16,7 +16,6 @@ package gofer import ( "errors" - "sync" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" @@ -31,6 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/host" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/safemem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fs/gofer/session.go b/pkg/sentry/fs/gofer/session.go index 4e358a46a..edc796ce0 100644 --- a/pkg/sentry/fs/gofer/session.go +++ b/pkg/sentry/fs/gofer/session.go @@ -16,7 +16,6 @@ package gofer import ( "fmt" - "sync" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/refs" @@ -25,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" ) diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD index 23daeb528..2b581aa69 100644 --- a/pkg/sentry/fs/host/BUILD +++ b/pkg/sentry/fs/host/BUILD @@ -50,6 +50,7 @@ go_library( "//pkg/sentry/unimpl", "//pkg/sentry/uniqueid", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip", diff --git a/pkg/sentry/fs/host/inode.go b/pkg/sentry/fs/host/inode.go index a6e4a09e3..873a1c52d 100644 --- a/pkg/sentry/fs/host/inode.go +++ b/pkg/sentry/fs/host/inode.go @@ -15,7 +15,6 @@ package host import ( - "sync" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" @@ -28,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go index 107336a3e..c076d5bdd 100644 --- a/pkg/sentry/fs/host/socket.go +++ b/pkg/sentry/fs/host/socket.go @@ -16,7 +16,6 @@ package host import ( "fmt" - "sync" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" @@ -30,6 +29,7 @@ import ( unixsocket "gvisor.dev/gvisor/pkg/sentry/socket/unix" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/uniqueid" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" diff --git a/pkg/sentry/fs/host/tty.go b/pkg/sentry/fs/host/tty.go index 90331e3b2..753ef8cd6 100644 --- a/pkg/sentry/fs/host/tty.go +++ b/pkg/sentry/fs/host/tty.go @@ -15,8 +15,6 @@ package host import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" @@ -24,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/unimpl" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fs/inode.go b/pkg/sentry/fs/inode.go index 91e2fde2f..468043df0 100644 --- a/pkg/sentry/fs/inode.go +++ b/pkg/sentry/fs/inode.go @@ -15,8 +15,6 @@ package fs import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/metric" @@ -26,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fs/inode_inotify.go b/pkg/sentry/fs/inode_inotify.go index 0f2a66a79..efd3c962b 100644 --- a/pkg/sentry/fs/inode_inotify.go +++ b/pkg/sentry/fs/inode_inotify.go @@ -16,7 +16,8 @@ package fs import ( "fmt" - "sync" + + "gvisor.dev/gvisor/pkg/sync" ) // Watches is the collection of inotify watches on an inode. diff --git a/pkg/sentry/fs/inotify.go b/pkg/sentry/fs/inotify.go index ba3e0233d..cc7dd1c92 100644 --- a/pkg/sentry/fs/inotify.go +++ b/pkg/sentry/fs/inotify.go @@ -16,7 +16,6 @@ package fs import ( "io" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -25,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/uniqueid" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/inotify_watch.go b/pkg/sentry/fs/inotify_watch.go index 0aa0a5e9b..900cba3ca 100644 --- a/pkg/sentry/fs/inotify_watch.go +++ b/pkg/sentry/fs/inotify_watch.go @@ -15,10 +15,10 @@ package fs import ( - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sync" ) // Watch represent a particular inotify watch created by inotify_add_watch. diff --git a/pkg/sentry/fs/lock/BUILD b/pkg/sentry/fs/lock/BUILD index 8d62642e7..2c332a82a 100644 --- a/pkg/sentry/fs/lock/BUILD +++ b/pkg/sentry/fs/lock/BUILD @@ -44,6 +44,7 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/log", + "//pkg/sync", "//pkg/waiter", ], ) diff --git a/pkg/sentry/fs/lock/lock.go b/pkg/sentry/fs/lock/lock.go index 636484424..41b040818 100644 --- a/pkg/sentry/fs/lock/lock.go +++ b/pkg/sentry/fs/lock/lock.go @@ -52,9 +52,9 @@ package lock import ( "fmt" "math" - "sync" "syscall" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/mounts.go b/pkg/sentry/fs/mounts.go index ac0398bd9..db3dfd096 100644 --- a/pkg/sentry/fs/mounts.go +++ b/pkg/sentry/fs/mounts.go @@ -19,7 +19,6 @@ import ( "math" "path" "strings" - "sync" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" @@ -27,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fs/overlay.go b/pkg/sentry/fs/overlay.go index 25573e986..4cad55327 100644 --- a/pkg/sentry/fs/overlay.go +++ b/pkg/sentry/fs/overlay.go @@ -17,13 +17,12 @@ package fs import ( "fmt" "strings" - "sync" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/usermem" - "gvisor.dev/gvisor/pkg/syncutil" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) @@ -199,7 +198,7 @@ type overlayEntry struct { upper *Inode // dirCacheMu protects dirCache. - dirCacheMu syncutil.DowngradableRWMutex `state:"nosave"` + dirCacheMu sync.DowngradableRWMutex `state:"nosave"` // dirCache is cache of DentAttrs from upper and lower Inodes. dirCache *SortedDentryMap diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD index 75cbb0622..94d46ab1b 100644 --- a/pkg/sentry/fs/proc/BUILD +++ b/pkg/sentry/fs/proc/BUILD @@ -51,6 +51,7 @@ go_library( "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usage", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", "//pkg/tcpip/header", "//pkg/waiter", diff --git a/pkg/sentry/fs/proc/seqfile/BUILD b/pkg/sentry/fs/proc/seqfile/BUILD index fe7067be1..38b246dff 100644 --- a/pkg/sentry/fs/proc/seqfile/BUILD +++ b/pkg/sentry/fs/proc/seqfile/BUILD @@ -16,6 +16,7 @@ go_library( "//pkg/sentry/fs/proc/device", "//pkg/sentry/kernel/time", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/fs/proc/seqfile/seqfile.go b/pkg/sentry/fs/proc/seqfile/seqfile.go index 5fe823000..f9af191d5 100644 --- a/pkg/sentry/fs/proc/seqfile/seqfile.go +++ b/pkg/sentry/fs/proc/seqfile/seqfile.go @@ -17,7 +17,6 @@ package seqfile import ( "io" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/context" @@ -26,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/proc/device" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go index bd93f83fa..a37e1fa06 100644 --- a/pkg/sentry/fs/proc/sys_net.go +++ b/pkg/sentry/fs/proc/sys_net.go @@ -17,7 +17,6 @@ package proc import ( "fmt" "io" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/context" @@ -27,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/ramfs/BUILD b/pkg/sentry/fs/ramfs/BUILD index 012cb3e44..3fb7b0633 100644 --- a/pkg/sentry/fs/ramfs/BUILD +++ b/pkg/sentry/fs/ramfs/BUILD @@ -21,6 +21,7 @@ go_library( "//pkg/sentry/fs/fsutil", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/fs/ramfs/dir.go b/pkg/sentry/fs/ramfs/dir.go index 78e082b8e..dcbb8eb2e 100644 --- a/pkg/sentry/fs/ramfs/dir.go +++ b/pkg/sentry/fs/ramfs/dir.go @@ -17,7 +17,6 @@ package ramfs import ( "fmt" - "sync" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" @@ -25,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fs/restore.go b/pkg/sentry/fs/restore.go index f10168125..64c6a6ae9 100644 --- a/pkg/sentry/fs/restore.go +++ b/pkg/sentry/fs/restore.go @@ -15,7 +15,7 @@ package fs import ( - "sync" + "gvisor.dev/gvisor/pkg/sync" ) // RestoreEnvironment is the restore environment for file systems. It consists diff --git a/pkg/sentry/fs/tmpfs/BUILD b/pkg/sentry/fs/tmpfs/BUILD index 59ce400c2..3400b940c 100644 --- a/pkg/sentry/fs/tmpfs/BUILD +++ b/pkg/sentry/fs/tmpfs/BUILD @@ -31,6 +31,7 @@ go_library( "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usage", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/fs/tmpfs/inode_file.go b/pkg/sentry/fs/tmpfs/inode_file.go index f86dfaa36..f1c87fe41 100644 --- a/pkg/sentry/fs/tmpfs/inode_file.go +++ b/pkg/sentry/fs/tmpfs/inode_file.go @@ -17,7 +17,6 @@ package tmpfs import ( "fmt" "io" - "sync" "time" "gvisor.dev/gvisor/pkg/abi/linux" @@ -31,6 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fs/tty/BUILD b/pkg/sentry/fs/tty/BUILD index 95ad98cb0..f6f60d0cf 100644 --- a/pkg/sentry/fs/tty/BUILD +++ b/pkg/sentry/fs/tty/BUILD @@ -30,6 +30,7 @@ go_library( "//pkg/sentry/socket/unix/transport", "//pkg/sentry/unimpl", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/fs/tty/dir.go b/pkg/sentry/fs/tty/dir.go index 2f639c823..88aa66b24 100644 --- a/pkg/sentry/fs/tty/dir.go +++ b/pkg/sentry/fs/tty/dir.go @@ -19,7 +19,6 @@ import ( "fmt" "math" "strconv" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/context" @@ -28,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/tty/line_discipline.go b/pkg/sentry/fs/tty/line_discipline.go index 7cc0eb409..894964260 100644 --- a/pkg/sentry/fs/tty/line_discipline.go +++ b/pkg/sentry/fs/tty/line_discipline.go @@ -16,13 +16,13 @@ package tty import ( "bytes" - "sync" "unicode/utf8" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fs/tty/queue.go b/pkg/sentry/fs/tty/queue.go index 231e4e6eb..8b5d4699a 100644 --- a/pkg/sentry/fs/tty/queue.go +++ b/pkg/sentry/fs/tty/queue.go @@ -15,13 +15,12 @@ package tty import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD index bc90330bc..903874141 100644 --- a/pkg/sentry/fsimpl/ext/BUILD +++ b/pkg/sentry/fsimpl/ext/BUILD @@ -50,6 +50,7 @@ go_library( "//pkg/sentry/syscalls/linux", "//pkg/sentry/usermem", "//pkg/sentry/vfs", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go index 91802dc1e..8944171c8 100644 --- a/pkg/sentry/fsimpl/ext/directory.go +++ b/pkg/sentry/fsimpl/ext/directory.go @@ -15,8 +15,6 @@ package ext import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/log" @@ -25,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fsimpl/ext/filesystem.go b/pkg/sentry/fsimpl/ext/filesystem.go index 616fc002a..9afb1a84c 100644 --- a/pkg/sentry/fsimpl/ext/filesystem.go +++ b/pkg/sentry/fsimpl/ext/filesystem.go @@ -17,13 +17,13 @@ package ext import ( "errors" "io" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fsimpl/ext/regular_file.go b/pkg/sentry/fsimpl/ext/regular_file.go index aec33e00a..d11153c90 100644 --- a/pkg/sentry/fsimpl/ext/regular_file.go +++ b/pkg/sentry/fsimpl/ext/regular_file.go @@ -16,7 +16,6 @@ package ext import ( "io" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/context" @@ -24,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD index 39c03ee9d..809178250 100644 --- a/pkg/sentry/fsimpl/kernfs/BUILD +++ b/pkg/sentry/fsimpl/kernfs/BUILD @@ -39,6 +39,7 @@ go_library( "//pkg/sentry/memmap", "//pkg/sentry/usermem", "//pkg/sentry/vfs", + "//pkg/sync", "//pkg/syserror", ], ) @@ -56,6 +57,7 @@ go_test( "//pkg/sentry/kernel/auth", "//pkg/sentry/usermem", "//pkg/sentry/vfs", + "//pkg/sync", "//pkg/syserror", "@com_github_google_go-cmp//cmp:go_default_library", ], diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go index 752e0f659..1d469a0db 100644 --- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go @@ -16,7 +16,6 @@ package kernfs import ( "fmt" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -24,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go index d69b299ae..bb12f39a2 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs.go @@ -53,7 +53,6 @@ package kernfs import ( "fmt" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -61,6 +60,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" ) // FilesystemType implements vfs.FilesystemType. diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go index 4b6b95f5f..5c9d580e1 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go @@ -19,7 +19,6 @@ import ( "fmt" "io" "runtime" - "sync" "testing" "github.com/google/go-cmp/cmp" @@ -31,6 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD index a5b285987..82f5c2f41 100644 --- a/pkg/sentry/fsimpl/tmpfs/BUILD +++ b/pkg/sentry/fsimpl/tmpfs/BUILD @@ -47,6 +47,7 @@ go_library( "//pkg/sentry/usage", "//pkg/sentry/usermem", "//pkg/sentry/vfs", + "//pkg/sync", "//pkg/syserror", ], ) diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go index f51e247a7..f200e767d 100644 --- a/pkg/sentry/fsimpl/tmpfs/regular_file.go +++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go @@ -17,7 +17,6 @@ package tmpfs import ( "io" "math" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -30,6 +29,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go index 7be6faa5b..701826f90 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go @@ -26,7 +26,6 @@ package tmpfs import ( "fmt" "math" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -34,6 +33,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index 2706927ff..ac85ba0c8 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -35,7 +35,7 @@ go_template_instance( out = "seqatomic_taskgoroutineschedinfo_unsafe.go", package = "kernel", suffix = "TaskGoroutineSchedInfo", - template = "//pkg/syncutil:generic_seqatomic", + template = "//pkg/sync:generic_seqatomic", types = { "Value": "TaskGoroutineSchedInfo", }, @@ -209,7 +209,7 @@ go_library( "//pkg/sentry/usermem", "//pkg/state", "//pkg/state/statefile", - "//pkg/syncutil", + "//pkg/sync", "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip", @@ -241,6 +241,7 @@ go_test( "//pkg/sentry/time", "//pkg/sentry/usage", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", ], ) diff --git a/pkg/sentry/kernel/abstract_socket_namespace.go b/pkg/sentry/kernel/abstract_socket_namespace.go index 244655b5c..920fe4329 100644 --- a/pkg/sentry/kernel/abstract_socket_namespace.go +++ b/pkg/sentry/kernel/abstract_socket_namespace.go @@ -15,11 +15,11 @@ package kernel import ( - "sync" "syscall" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "gvisor.dev/gvisor/pkg/sync" ) // +stateify savable diff --git a/pkg/sentry/kernel/auth/BUILD b/pkg/sentry/kernel/auth/BUILD index 04c244447..1aa72fa47 100644 --- a/pkg/sentry/kernel/auth/BUILD +++ b/pkg/sentry/kernel/auth/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "atomicptr_credentials_unsafe.go", package = "auth", suffix = "Credentials", - template = "//pkg/syncutil:generic_atomicptr", + template = "//pkg/sync:generic_atomicptr", types = { "Value": "Credentials", }, @@ -64,6 +64,7 @@ go_library( "//pkg/bits", "//pkg/log", "//pkg/sentry/context", + "//pkg/sync", "//pkg/syserror", ], ) diff --git a/pkg/sentry/kernel/auth/user_namespace.go b/pkg/sentry/kernel/auth/user_namespace.go index af28ccc65..9dd52c860 100644 --- a/pkg/sentry/kernel/auth/user_namespace.go +++ b/pkg/sentry/kernel/auth/user_namespace.go @@ -16,8 +16,8 @@ package auth import ( "math" - "sync" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/kernel/epoll/BUILD b/pkg/sentry/kernel/epoll/BUILD index 3361e8b7d..c47f6b6fc 100644 --- a/pkg/sentry/kernel/epoll/BUILD +++ b/pkg/sentry/kernel/epoll/BUILD @@ -32,6 +32,7 @@ go_library( "//pkg/sentry/fs/anon", "//pkg/sentry/fs/fsutil", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/waiter", ], ) diff --git a/pkg/sentry/kernel/epoll/epoll.go b/pkg/sentry/kernel/epoll/epoll.go index 9c0a4e1b4..430311cc0 100644 --- a/pkg/sentry/kernel/epoll/epoll.go +++ b/pkg/sentry/kernel/epoll/epoll.go @@ -18,7 +18,6 @@ package epoll import ( "fmt" - "sync" "syscall" "gvisor.dev/gvisor/pkg/refs" @@ -27,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/anon" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/eventfd/BUILD b/pkg/sentry/kernel/eventfd/BUILD index e65b961e8..c831fbab2 100644 --- a/pkg/sentry/kernel/eventfd/BUILD +++ b/pkg/sentry/kernel/eventfd/BUILD @@ -16,6 +16,7 @@ go_library( "//pkg/sentry/fs/anon", "//pkg/sentry/fs/fsutil", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/kernel/eventfd/eventfd.go b/pkg/sentry/kernel/eventfd/eventfd.go index 12f0d429b..687690679 100644 --- a/pkg/sentry/kernel/eventfd/eventfd.go +++ b/pkg/sentry/kernel/eventfd/eventfd.go @@ -18,7 +18,6 @@ package eventfd import ( "math" - "sync" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" @@ -28,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/anon" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/fasync/BUILD b/pkg/sentry/kernel/fasync/BUILD index 49d81b712..6b36bc63e 100644 --- a/pkg/sentry/kernel/fasync/BUILD +++ b/pkg/sentry/kernel/fasync/BUILD @@ -12,6 +12,7 @@ go_library( "//pkg/sentry/fs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", + "//pkg/sync", "//pkg/waiter", ], ) diff --git a/pkg/sentry/kernel/fasync/fasync.go b/pkg/sentry/kernel/fasync/fasync.go index 6b0bb0324..d32c3e90a 100644 --- a/pkg/sentry/kernel/fasync/fasync.go +++ b/pkg/sentry/kernel/fasync/fasync.go @@ -16,12 +16,11 @@ package fasync import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go index 11f613a11..cd1501f85 100644 --- a/pkg/sentry/kernel/fd_table.go +++ b/pkg/sentry/kernel/fd_table.go @@ -18,7 +18,6 @@ import ( "bytes" "fmt" "math" - "sync" "sync/atomic" "syscall" @@ -28,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/limits" + "gvisor.dev/gvisor/pkg/sync" ) // FDFlags define flags for an individual descriptor. diff --git a/pkg/sentry/kernel/fd_table_test.go b/pkg/sentry/kernel/fd_table_test.go index 2bcb6216a..eccb7d1e7 100644 --- a/pkg/sentry/kernel/fd_table_test.go +++ b/pkg/sentry/kernel/fd_table_test.go @@ -16,7 +16,6 @@ package kernel import ( "runtime" - "sync" "testing" "gvisor.dev/gvisor/pkg/sentry/context" @@ -24,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/filetest" "gvisor.dev/gvisor/pkg/sentry/limits" + "gvisor.dev/gvisor/pkg/sync" ) const ( diff --git a/pkg/sentry/kernel/fs_context.go b/pkg/sentry/kernel/fs_context.go index ded27d668..2448c1d99 100644 --- a/pkg/sentry/kernel/fs_context.go +++ b/pkg/sentry/kernel/fs_context.go @@ -16,10 +16,10 @@ package kernel import ( "fmt" - "sync" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sync" ) // FSContext contains filesystem context. diff --git a/pkg/sentry/kernel/futex/BUILD b/pkg/sentry/kernel/futex/BUILD index 75ec31761..50db443ce 100644 --- a/pkg/sentry/kernel/futex/BUILD +++ b/pkg/sentry/kernel/futex/BUILD @@ -9,7 +9,7 @@ go_template_instance( out = "atomicptr_bucket_unsafe.go", package = "futex", suffix = "Bucket", - template = "//pkg/syncutil:generic_atomicptr", + template = "//pkg/sync:generic_atomicptr", types = { "Value": "bucket", }, @@ -42,6 +42,7 @@ go_library( "//pkg/sentry/context", "//pkg/sentry/memmap", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", ], ) @@ -51,5 +52,8 @@ go_test( size = "small", srcs = ["futex_test.go"], embed = [":futex"], - deps = ["//pkg/sentry/usermem"], + deps = [ + "//pkg/sentry/usermem", + "//pkg/sync", + ], ) diff --git a/pkg/sentry/kernel/futex/futex.go b/pkg/sentry/kernel/futex/futex.go index 278cc8143..d1931c8f4 100644 --- a/pkg/sentry/kernel/futex/futex.go +++ b/pkg/sentry/kernel/futex/futex.go @@ -18,11 +18,10 @@ package futex import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/kernel/futex/futex_test.go b/pkg/sentry/kernel/futex/futex_test.go index 65e5d1428..c23126ca5 100644 --- a/pkg/sentry/kernel/futex/futex_test.go +++ b/pkg/sentry/kernel/futex/futex_test.go @@ -17,13 +17,13 @@ package futex import ( "math" "runtime" - "sync" "sync/atomic" "syscall" "testing" "unsafe" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" ) // testData implements the Target interface, and allows us to diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 8653d2f63..c85e97fef 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -36,7 +36,6 @@ import ( "fmt" "io" "path/filepath" - "sync" "sync/atomic" "time" @@ -67,6 +66,7 @@ import ( uspb "gvisor.dev/gvisor/pkg/sentry/unimpl/unimplemented_syscall_go_proto" "gvisor.dev/gvisor/pkg/sentry/uniqueid" "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" ) diff --git a/pkg/sentry/kernel/memevent/BUILD b/pkg/sentry/kernel/memevent/BUILD index d7a7d1169..7f36252a9 100644 --- a/pkg/sentry/kernel/memevent/BUILD +++ b/pkg/sentry/kernel/memevent/BUILD @@ -16,6 +16,7 @@ go_library( "//pkg/metric", "//pkg/sentry/kernel", "//pkg/sentry/usage", + "//pkg/sync", ], ) diff --git a/pkg/sentry/kernel/memevent/memory_events.go b/pkg/sentry/kernel/memevent/memory_events.go index b0d98e7f0..200565bb8 100644 --- a/pkg/sentry/kernel/memevent/memory_events.go +++ b/pkg/sentry/kernel/memevent/memory_events.go @@ -17,7 +17,6 @@ package memevent import ( - "sync" "time" "gvisor.dev/gvisor/pkg/eventchannel" @@ -26,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" pb "gvisor.dev/gvisor/pkg/sentry/kernel/memevent/memory_events_go_proto" "gvisor.dev/gvisor/pkg/sentry/usage" + "gvisor.dev/gvisor/pkg/sync" ) var totalTicks = metric.MustCreateNewUint64Metric("/memory_events/ticks", false /*sync*/, "Total number of memory event periods that have elapsed since startup.") diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD index 9d34f6d4d..5eeaeff66 100644 --- a/pkg/sentry/kernel/pipe/BUILD +++ b/pkg/sentry/kernel/pipe/BUILD @@ -43,6 +43,7 @@ go_library( "//pkg/sentry/safemem", "//pkg/sentry/usermem", "//pkg/sentry/vfs", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/kernel/pipe/buffer.go b/pkg/sentry/kernel/pipe/buffer.go index 95bee2d37..1c0f34269 100644 --- a/pkg/sentry/kernel/pipe/buffer.go +++ b/pkg/sentry/kernel/pipe/buffer.go @@ -16,9 +16,9 @@ package pipe import ( "io" - "sync" "gvisor.dev/gvisor/pkg/sentry/safemem" + "gvisor.dev/gvisor/pkg/sync" ) // buffer encapsulates a queueable byte buffer. diff --git a/pkg/sentry/kernel/pipe/node.go b/pkg/sentry/kernel/pipe/node.go index 4a19ab7ce..716f589af 100644 --- a/pkg/sentry/kernel/pipe/node.go +++ b/pkg/sentry/kernel/pipe/node.go @@ -15,12 +15,11 @@ package pipe import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go index 1a1b38f83..e4fd7d420 100644 --- a/pkg/sentry/kernel/pipe/pipe.go +++ b/pkg/sentry/kernel/pipe/pipe.go @@ -17,12 +17,12 @@ package pipe import ( "fmt" - "sync" "sync/atomic" "syscall" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/pipe/pipe_util.go b/pkg/sentry/kernel/pipe/pipe_util.go index ef9641e6a..8394eb78b 100644 --- a/pkg/sentry/kernel/pipe/pipe_util.go +++ b/pkg/sentry/kernel/pipe/pipe_util.go @@ -17,7 +17,6 @@ package pipe import ( "io" "math" - "sync" "syscall" "gvisor.dev/gvisor/pkg/abi/linux" @@ -25,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go index 6416e0dd8..bf7461cbb 100644 --- a/pkg/sentry/kernel/pipe/vfs.go +++ b/pkg/sentry/kernel/pipe/vfs.go @@ -15,13 +15,12 @@ package pipe import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/semaphore/BUILD b/pkg/sentry/kernel/semaphore/BUILD index f4c00cd86..13a961594 100644 --- a/pkg/sentry/kernel/semaphore/BUILD +++ b/pkg/sentry/kernel/semaphore/BUILD @@ -31,6 +31,7 @@ go_library( "//pkg/sentry/fs", "//pkg/sentry/kernel/auth", "//pkg/sentry/kernel/time", + "//pkg/sync", "//pkg/syserror", ], ) diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go index de9617e9d..18299814e 100644 --- a/pkg/sentry/kernel/semaphore/semaphore.go +++ b/pkg/sentry/kernel/semaphore/semaphore.go @@ -17,7 +17,6 @@ package semaphore import ( "fmt" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/log" @@ -25,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD index cd48945e6..7321b22ed 100644 --- a/pkg/sentry/kernel/shm/BUILD +++ b/pkg/sentry/kernel/shm/BUILD @@ -24,6 +24,7 @@ go_library( "//pkg/sentry/platform", "//pkg/sentry/usage", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", ], ) diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go index 19034a21e..8ddef7eb8 100644 --- a/pkg/sentry/kernel/shm/shm.go +++ b/pkg/sentry/kernel/shm/shm.go @@ -35,7 +35,6 @@ package shm import ( "fmt" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/log" @@ -49,6 +48,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/kernel/signal_handlers.go b/pkg/sentry/kernel/signal_handlers.go index a16f3d57f..768fda220 100644 --- a/pkg/sentry/kernel/signal_handlers.go +++ b/pkg/sentry/kernel/signal_handlers.go @@ -15,10 +15,9 @@ package kernel import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sync" ) // SignalHandlers holds information about signal actions. diff --git a/pkg/sentry/kernel/signalfd/BUILD b/pkg/sentry/kernel/signalfd/BUILD index 9f7e19b4d..89e4d84b1 100644 --- a/pkg/sentry/kernel/signalfd/BUILD +++ b/pkg/sentry/kernel/signalfd/BUILD @@ -16,6 +16,7 @@ go_library( "//pkg/sentry/fs/fsutil", "//pkg/sentry/kernel", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/kernel/signalfd/signalfd.go b/pkg/sentry/kernel/signalfd/signalfd.go index 4b08d7d72..28be4a939 100644 --- a/pkg/sentry/kernel/signalfd/signalfd.go +++ b/pkg/sentry/kernel/signalfd/signalfd.go @@ -16,8 +16,6 @@ package signalfd import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/sentry/context" @@ -26,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/syscalls.go b/pkg/sentry/kernel/syscalls.go index 2fdee0282..d2d01add4 100644 --- a/pkg/sentry/kernel/syscalls.go +++ b/pkg/sentry/kernel/syscalls.go @@ -16,13 +16,13 @@ package kernel import ( "fmt" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi" "gvisor.dev/gvisor/pkg/bits" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" ) // maxSyscallNum is the highest supported syscall number. diff --git a/pkg/sentry/kernel/syslog.go b/pkg/sentry/kernel/syslog.go index 8227ecf1d..4607cde2f 100644 --- a/pkg/sentry/kernel/syslog.go +++ b/pkg/sentry/kernel/syslog.go @@ -17,7 +17,8 @@ package kernel import ( "fmt" "math/rand" - "sync" + + "gvisor.dev/gvisor/pkg/sync" ) // syslog represents a sentry-global kernel log. diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index d25a7903b..978d66da8 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -17,7 +17,6 @@ package kernel import ( gocontext "context" "runtime/trace" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -37,7 +36,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/uniqueid" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/usermem" - "gvisor.dev/gvisor/pkg/syncutil" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/waiter" ) @@ -85,7 +84,7 @@ type Task struct { // // gosched is protected by goschedSeq. gosched is owned by the task // goroutine. - goschedSeq syncutil.SeqCount `state:"nosave"` + goschedSeq sync.SeqCount `state:"nosave"` gosched TaskGoroutineSchedInfo // yieldCount is the number of times the task goroutine has called diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go index c0197a563..768e958d2 100644 --- a/pkg/sentry/kernel/thread_group.go +++ b/pkg/sentry/kernel/thread_group.go @@ -15,7 +15,6 @@ package kernel import ( - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -25,6 +24,7 @@ import ( ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/usage" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/kernel/threads.go b/pkg/sentry/kernel/threads.go index 8267929a6..bf2dabb6e 100644 --- a/pkg/sentry/kernel/threads.go +++ b/pkg/sentry/kernel/threads.go @@ -16,9 +16,9 @@ package kernel import ( "fmt" - "sync" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/time/BUILD b/pkg/sentry/kernel/time/BUILD index 31847e1df..4e4de0512 100644 --- a/pkg/sentry/kernel/time/BUILD +++ b/pkg/sentry/kernel/time/BUILD @@ -13,6 +13,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/sentry/context", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], diff --git a/pkg/sentry/kernel/time/time.go b/pkg/sentry/kernel/time/time.go index 107394183..706de83ef 100644 --- a/pkg/sentry/kernel/time/time.go +++ b/pkg/sentry/kernel/time/time.go @@ -19,10 +19,10 @@ package time import ( "fmt" "math" - "sync" "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/kernel/timekeeper.go b/pkg/sentry/kernel/timekeeper.go index 76417342a..dc99301de 100644 --- a/pkg/sentry/kernel/timekeeper.go +++ b/pkg/sentry/kernel/timekeeper.go @@ -16,7 +16,6 @@ package kernel import ( "fmt" - "sync" "time" "gvisor.dev/gvisor/pkg/log" @@ -24,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/platform" sentrytime "gvisor.dev/gvisor/pkg/sentry/time" + "gvisor.dev/gvisor/pkg/sync" ) // Timekeeper manages all of the kernel clocks. diff --git a/pkg/sentry/kernel/tty.go b/pkg/sentry/kernel/tty.go index 048de26dc..464d2306a 100644 --- a/pkg/sentry/kernel/tty.go +++ b/pkg/sentry/kernel/tty.go @@ -14,7 +14,7 @@ package kernel -import "sync" +import "gvisor.dev/gvisor/pkg/sync" // TTY defines the relationship between a thread group and its controlling // terminal. diff --git a/pkg/sentry/kernel/uts_namespace.go b/pkg/sentry/kernel/uts_namespace.go index 0a563e715..8ccf04bd1 100644 --- a/pkg/sentry/kernel/uts_namespace.go +++ b/pkg/sentry/kernel/uts_namespace.go @@ -15,9 +15,8 @@ package kernel import ( - "sync" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sync" ) // UTSNamespace represents a UTS namespace, a holder of two system identifiers: diff --git a/pkg/sentry/limits/BUILD b/pkg/sentry/limits/BUILD index 156e67bf8..9fa841e8b 100644 --- a/pkg/sentry/limits/BUILD +++ b/pkg/sentry/limits/BUILD @@ -15,6 +15,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/sentry/context", + "//pkg/sync", ], ) diff --git a/pkg/sentry/limits/limits.go b/pkg/sentry/limits/limits.go index b6c22656b..31b9e9ff6 100644 --- a/pkg/sentry/limits/limits.go +++ b/pkg/sentry/limits/limits.go @@ -16,8 +16,9 @@ package limits import ( - "sync" "syscall" + + "gvisor.dev/gvisor/pkg/sync" ) // LimitType defines a type of resource limit. diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD index 839931f67..83e248431 100644 --- a/pkg/sentry/mm/BUILD +++ b/pkg/sentry/mm/BUILD @@ -118,7 +118,7 @@ go_library( "//pkg/sentry/safemem", "//pkg/sentry/usage", "//pkg/sentry/usermem", - "//pkg/syncutil", + "//pkg/sync", "//pkg/syserror", "//pkg/tcpip/buffer", ], diff --git a/pkg/sentry/mm/aio_context.go b/pkg/sentry/mm/aio_context.go index 1b746d030..4b48866ad 100644 --- a/pkg/sentry/mm/aio_context.go +++ b/pkg/sentry/mm/aio_context.go @@ -15,8 +15,6 @@ package mm import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/context" @@ -25,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go index 58a5c186d..fa86ebced 100644 --- a/pkg/sentry/mm/mm.go +++ b/pkg/sentry/mm/mm.go @@ -35,8 +35,6 @@ package mm import ( - "sync" - "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/memmap" @@ -44,7 +42,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usermem" - "gvisor.dev/gvisor/pkg/syncutil" + "gvisor.dev/gvisor/pkg/sync" ) // MemoryManager implements a virtual address space. @@ -82,7 +80,7 @@ type MemoryManager struct { users int32 // mappingMu is analogous to Linux's struct mm_struct::mmap_sem. - mappingMu syncutil.DowngradableRWMutex `state:"nosave"` + mappingMu sync.DowngradableRWMutex `state:"nosave"` // vmas stores virtual memory areas. Since vmas are stored by value, // clients should usually use vmaIterator.ValuePtr() instead of @@ -125,7 +123,7 @@ type MemoryManager struct { // activeMu is loosely analogous to Linux's struct // mm_struct::page_table_lock. - activeMu syncutil.DowngradableRWMutex `state:"nosave"` + activeMu sync.DowngradableRWMutex `state:"nosave"` // pmas stores platform mapping areas used to implement vmas. Since pmas // are stored by value, clients should usually use pmaIterator.ValuePtr() diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD index f404107af..a9a2642c5 100644 --- a/pkg/sentry/pgalloc/BUILD +++ b/pkg/sentry/pgalloc/BUILD @@ -73,6 +73,7 @@ go_library( "//pkg/sentry/usage", "//pkg/sentry/usermem", "//pkg/state", + "//pkg/sync", "//pkg/syserror", ], ) diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go index f7f7298c4..c99e023d9 100644 --- a/pkg/sentry/pgalloc/pgalloc.go +++ b/pkg/sentry/pgalloc/pgalloc.go @@ -25,7 +25,6 @@ import ( "fmt" "math" "os" - "sync" "sync/atomic" "syscall" "time" @@ -37,6 +36,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/platform/interrupt/BUILD b/pkg/sentry/platform/interrupt/BUILD index b6d008dbe..85e882df9 100644 --- a/pkg/sentry/platform/interrupt/BUILD +++ b/pkg/sentry/platform/interrupt/BUILD @@ -10,6 +10,7 @@ go_library( ], importpath = "gvisor.dev/gvisor/pkg/sentry/platform/interrupt", visibility = ["//pkg/sentry:internal"], + deps = ["//pkg/sync"], ) go_test( diff --git a/pkg/sentry/platform/interrupt/interrupt.go b/pkg/sentry/platform/interrupt/interrupt.go index a4651f500..57be41647 100644 --- a/pkg/sentry/platform/interrupt/interrupt.go +++ b/pkg/sentry/platform/interrupt/interrupt.go @@ -17,7 +17,8 @@ package interrupt import ( "fmt" - "sync" + + "gvisor.dev/gvisor/pkg/sync" ) // Receiver receives interrupt notifications from a Forwarder. diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD index f3afd98da..6a358d1d4 100644 --- a/pkg/sentry/platform/kvm/BUILD +++ b/pkg/sentry/platform/kvm/BUILD @@ -55,6 +55,7 @@ go_library( "//pkg/sentry/platform/safecopy", "//pkg/sentry/time", "//pkg/sentry/usermem", + "//pkg/sync", ], ) diff --git a/pkg/sentry/platform/kvm/address_space.go b/pkg/sentry/platform/kvm/address_space.go index ea8b9632e..a25f3c449 100644 --- a/pkg/sentry/platform/kvm/address_space.go +++ b/pkg/sentry/platform/kvm/address_space.go @@ -15,13 +15,13 @@ package kvm import ( - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/atomicbitops" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" ) // dirtySet tracks vCPUs for invalidation. diff --git a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go index e5fac0d6a..2f02c03cf 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go @@ -17,8 +17,6 @@ package kvm import ( - "unsafe" - "gvisor.dev/gvisor/pkg/sentry/arch" ) diff --git a/pkg/sentry/platform/kvm/kvm.go b/pkg/sentry/platform/kvm/kvm.go index f2c2c059e..a7850faed 100644 --- a/pkg/sentry/platform/kvm/kvm.go +++ b/pkg/sentry/platform/kvm/kvm.go @@ -18,13 +18,13 @@ package kvm import ( "fmt" "os" - "sync" "syscall" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/platform/ring0" "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" ) // KVM represents a lightweight VM context. diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go index 7d02ebf19..e6d912168 100644 --- a/pkg/sentry/platform/kvm/machine.go +++ b/pkg/sentry/platform/kvm/machine.go @@ -17,7 +17,6 @@ package kvm import ( "fmt" "runtime" - "sync" "sync/atomic" "syscall" @@ -27,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/platform/ring0" "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" ) // machine contains state associated with the VM as a whole. diff --git a/pkg/sentry/platform/ptrace/BUILD b/pkg/sentry/platform/ptrace/BUILD index 0df8cfa0f..cd13390c3 100644 --- a/pkg/sentry/platform/ptrace/BUILD +++ b/pkg/sentry/platform/ptrace/BUILD @@ -33,6 +33,7 @@ go_library( "//pkg/sentry/platform/interrupt", "//pkg/sentry/platform/safecopy", "//pkg/sentry/usermem", + "//pkg/sync", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/platform/ptrace/ptrace.go b/pkg/sentry/platform/ptrace/ptrace.go index 7b120a15d..bb0e03880 100644 --- a/pkg/sentry/platform/ptrace/ptrace.go +++ b/pkg/sentry/platform/ptrace/ptrace.go @@ -46,13 +46,13 @@ package ptrace import ( "os" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/platform/interrupt" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" ) var ( diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go index 20244fd95..15dc46a5b 100644 --- a/pkg/sentry/platform/ptrace/subprocess.go +++ b/pkg/sentry/platform/ptrace/subprocess.go @@ -18,7 +18,6 @@ import ( "fmt" "os" "runtime" - "sync" "syscall" "golang.org/x/sys/unix" @@ -27,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" ) // Linux kernel errnos which "should never be seen by user programs", but will diff --git a/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go b/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go index 2e6fbe488..245b20722 100644 --- a/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go +++ b/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go @@ -18,7 +18,6 @@ package ptrace import ( - "sync" "sync/atomic" "syscall" "unsafe" @@ -26,6 +25,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/hostcpu" + "gvisor.dev/gvisor/pkg/sync" ) // maskPool contains reusable CPU masks for setting affinity. Unfortunately, diff --git a/pkg/sentry/platform/ring0/defs.go b/pkg/sentry/platform/ring0/defs.go index 3f094c2a7..86fd5ed58 100644 --- a/pkg/sentry/platform/ring0/defs.go +++ b/pkg/sentry/platform/ring0/defs.go @@ -17,7 +17,7 @@ package ring0 import ( "syscall" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" ) // Kernel is a global kernel object. diff --git a/pkg/sentry/platform/ring0/defs_amd64.go b/pkg/sentry/platform/ring0/defs_amd64.go index 10dbd381f..9dae0dccb 100644 --- a/pkg/sentry/platform/ring0/defs_amd64.go +++ b/pkg/sentry/platform/ring0/defs_amd64.go @@ -18,6 +18,7 @@ package ring0 import ( "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" + "gvisor.dev/gvisor/pkg/sentry/usermem" ) var ( diff --git a/pkg/sentry/platform/ring0/defs_arm64.go b/pkg/sentry/platform/ring0/defs_arm64.go index dc0eeec01..a850ce6cf 100644 --- a/pkg/sentry/platform/ring0/defs_arm64.go +++ b/pkg/sentry/platform/ring0/defs_arm64.go @@ -18,6 +18,7 @@ package ring0 import ( "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" + "gvisor.dev/gvisor/pkg/sentry/usermem" ) var ( diff --git a/pkg/sentry/platform/ring0/pagetables/BUILD b/pkg/sentry/platform/ring0/pagetables/BUILD index e2e15ba5c..387a7f6c3 100644 --- a/pkg/sentry/platform/ring0/pagetables/BUILD +++ b/pkg/sentry/platform/ring0/pagetables/BUILD @@ -96,7 +96,10 @@ go_library( "//pkg/sentry/platform/kvm:__subpackages__", "//pkg/sentry/platform/ring0:__subpackages__", ], - deps = ["//pkg/sentry/usermem"], + deps = [ + "//pkg/sentry/usermem", + "//pkg/sync", + ], ) go_test( diff --git a/pkg/sentry/platform/ring0/pagetables/pcids_x86.go b/pkg/sentry/platform/ring0/pagetables/pcids_x86.go index 0f029f25d..e199bae18 100644 --- a/pkg/sentry/platform/ring0/pagetables/pcids_x86.go +++ b/pkg/sentry/platform/ring0/pagetables/pcids_x86.go @@ -17,7 +17,7 @@ package pagetables import ( - "sync" + "gvisor.dev/gvisor/pkg/sync" ) // limitPCID is the number of valid PCIDs. diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index 136821963..103933144 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -27,6 +27,7 @@ go_library( "//pkg/sentry/socket/unix", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip", diff --git a/pkg/sentry/socket/netlink/port/BUILD b/pkg/sentry/socket/netlink/port/BUILD index 463544c1a..2d9f4ba9b 100644 --- a/pkg/sentry/socket/netlink/port/BUILD +++ b/pkg/sentry/socket/netlink/port/BUILD @@ -8,6 +8,7 @@ go_library( srcs = ["port.go"], importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netlink/port", visibility = ["//pkg/sentry:internal"], + deps = ["//pkg/sync"], ) go_test( diff --git a/pkg/sentry/socket/netlink/port/port.go b/pkg/sentry/socket/netlink/port/port.go index e9d3275b1..2cd3afc22 100644 --- a/pkg/sentry/socket/netlink/port/port.go +++ b/pkg/sentry/socket/netlink/port/port.go @@ -24,7 +24,8 @@ import ( "fmt" "math" "math/rand" - "sync" + + "gvisor.dev/gvisor/pkg/sync" ) // maxPorts is a sanity limit on the maximum number of ports to allocate per diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index d2e3644a6..cea56f4ed 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -17,7 +17,6 @@ package netlink import ( "math" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" @@ -34,6 +33,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/unix" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index e414d8055..f78784569 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -34,6 +34,7 @@ go_library( "//pkg/sentry/socket/netfilter", "//pkg/sentry/unimpl", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip", diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 764f11a6b..0affb8071 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -29,7 +29,6 @@ import ( "io" "math" "reflect" - "sync" "syscall" "time" @@ -49,6 +48,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/netfilter" "gvisor.dev/gvisor/pkg/sentry/unimpl" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" diff --git a/pkg/sentry/socket/rpcinet/conn/BUILD b/pkg/sentry/socket/rpcinet/conn/BUILD index 23eadcb1b..b2677c659 100644 --- a/pkg/sentry/socket/rpcinet/conn/BUILD +++ b/pkg/sentry/socket/rpcinet/conn/BUILD @@ -10,6 +10,7 @@ go_library( deps = [ "//pkg/binary", "//pkg/sentry/socket/rpcinet:syscall_rpc_go_proto", + "//pkg/sync", "//pkg/syserr", "//pkg/unet", "@com_github_golang_protobuf//proto:go_default_library", diff --git a/pkg/sentry/socket/rpcinet/conn/conn.go b/pkg/sentry/socket/rpcinet/conn/conn.go index 356adad99..02f39c767 100644 --- a/pkg/sentry/socket/rpcinet/conn/conn.go +++ b/pkg/sentry/socket/rpcinet/conn/conn.go @@ -17,12 +17,12 @@ package conn import ( "fmt" - "sync" "sync/atomic" "syscall" "github.com/golang/protobuf/proto" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/unet" diff --git a/pkg/sentry/socket/rpcinet/notifier/BUILD b/pkg/sentry/socket/rpcinet/notifier/BUILD index a3585e10d..a5954f22b 100644 --- a/pkg/sentry/socket/rpcinet/notifier/BUILD +++ b/pkg/sentry/socket/rpcinet/notifier/BUILD @@ -10,6 +10,7 @@ go_library( deps = [ "//pkg/sentry/socket/rpcinet:syscall_rpc_go_proto", "//pkg/sentry/socket/rpcinet/conn", + "//pkg/sync", "//pkg/waiter", "@org_golang_x_sys//unix:go_default_library", ], diff --git a/pkg/sentry/socket/rpcinet/notifier/notifier.go b/pkg/sentry/socket/rpcinet/notifier/notifier.go index 7efe4301f..82b75d6dd 100644 --- a/pkg/sentry/socket/rpcinet/notifier/notifier.go +++ b/pkg/sentry/socket/rpcinet/notifier/notifier.go @@ -17,12 +17,12 @@ package notifier import ( "fmt" - "sync" "syscall" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/conn" pb "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD index 788ad70d2..d7ba95dff 100644 --- a/pkg/sentry/socket/unix/transport/BUILD +++ b/pkg/sentry/socket/unix/transport/BUILD @@ -32,6 +32,7 @@ go_library( "//pkg/ilist", "//pkg/refs", "//pkg/sentry/context", + "//pkg/sync", "//pkg/syserr", "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index dea11e253..9e6fbc111 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -15,10 +15,9 @@ package transport import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/waiter" diff --git a/pkg/sentry/socket/unix/transport/queue.go b/pkg/sentry/socket/unix/transport/queue.go index e27b1c714..5dcd3d95e 100644 --- a/pkg/sentry/socket/unix/transport/queue.go +++ b/pkg/sentry/socket/unix/transport/queue.go @@ -15,9 +15,8 @@ package transport import ( - "sync" - "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 37c7ac3c1..fcc0da332 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -16,11 +16,11 @@ package transport import ( - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD index a76975cee..aa05e208a 100644 --- a/pkg/sentry/syscalls/linux/BUILD +++ b/pkg/sentry/syscalls/linux/BUILD @@ -91,6 +91,7 @@ go_library( "//pkg/sentry/syscalls", "//pkg/sentry/usage", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserr", "//pkg/syserror", "//pkg/waiter", diff --git a/pkg/sentry/syscalls/linux/error.go b/pkg/sentry/syscalls/linux/error.go index 1d9018c96..60469549d 100644 --- a/pkg/sentry/syscalls/linux/error.go +++ b/pkg/sentry/syscalls/linux/error.go @@ -16,13 +16,13 @@ package linux import ( "io" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/time/BUILD b/pkg/sentry/time/BUILD index 18e212dff..3cde3a0be 100644 --- a/pkg/sentry/time/BUILD +++ b/pkg/sentry/time/BUILD @@ -9,7 +9,7 @@ go_template_instance( out = "seqatomic_parameters_unsafe.go", package = "time", suffix = "Parameters", - template = "//pkg/syncutil:generic_seqatomic", + template = "//pkg/sync:generic_seqatomic", types = { "Value": "Parameters", }, @@ -36,7 +36,7 @@ go_library( deps = [ "//pkg/log", "//pkg/metric", - "//pkg/syncutil", + "//pkg/sync", "//pkg/syserror", ], ) diff --git a/pkg/sentry/time/calibrated_clock.go b/pkg/sentry/time/calibrated_clock.go index 318503277..f9a93115d 100644 --- a/pkg/sentry/time/calibrated_clock.go +++ b/pkg/sentry/time/calibrated_clock.go @@ -17,11 +17,11 @@ package time import ( - "sync" "time" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/metric" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/usage/BUILD b/pkg/sentry/usage/BUILD index c32fe3241..5518ac3d0 100644 --- a/pkg/sentry/usage/BUILD +++ b/pkg/sentry/usage/BUILD @@ -18,5 +18,6 @@ go_library( deps = [ "//pkg/bits", "//pkg/memutil", + "//pkg/sync", ], ) diff --git a/pkg/sentry/usage/memory.go b/pkg/sentry/usage/memory.go index d6ef644d8..538c645eb 100644 --- a/pkg/sentry/usage/memory.go +++ b/pkg/sentry/usage/memory.go @@ -17,12 +17,12 @@ package usage import ( "fmt" "os" - "sync" "sync/atomic" "syscall" "gvisor.dev/gvisor/pkg/bits" "gvisor.dev/gvisor/pkg/memutil" + "gvisor.dev/gvisor/pkg/sync" ) // MemoryKind represents a type of memory used by the application. diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD index 4c6aa04a1..35c7be259 100644 --- a/pkg/sentry/vfs/BUILD +++ b/pkg/sentry/vfs/BUILD @@ -34,7 +34,7 @@ go_library( "//pkg/sentry/kernel/auth", "//pkg/sentry/memmap", "//pkg/sentry/usermem", - "//pkg/syncutil", + "//pkg/sync", "//pkg/syserror", "//pkg/waiter", ], @@ -54,6 +54,7 @@ go_test( "//pkg/sentry/context/contexttest", "//pkg/sentry/kernel/auth", "//pkg/sentry/usermem", + "//pkg/sync", "//pkg/syserror", ], ) diff --git a/pkg/sentry/vfs/dentry.go b/pkg/sentry/vfs/dentry.go index 1bc9c4a38..486a76475 100644 --- a/pkg/sentry/vfs/dentry.go +++ b/pkg/sentry/vfs/dentry.go @@ -16,9 +16,9 @@ package vfs import ( "fmt" - "sync" "sync/atomic" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go index 66eb57bc2..c00b3c84b 100644 --- a/pkg/sentry/vfs/file_description_impl_util.go +++ b/pkg/sentry/vfs/file_description_impl_util.go @@ -17,13 +17,13 @@ package vfs import ( "bytes" "io" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) diff --git a/pkg/sentry/vfs/mount_test.go b/pkg/sentry/vfs/mount_test.go index adff0b94b..3b933468d 100644 --- a/pkg/sentry/vfs/mount_test.go +++ b/pkg/sentry/vfs/mount_test.go @@ -17,8 +17,9 @@ package vfs import ( "fmt" "runtime" - "sync" "testing" + + "gvisor.dev/gvisor/pkg/sync" ) func TestMountTableLookupEmpty(t *testing.T) { diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go index ab13fa461..bd90d36c4 100644 --- a/pkg/sentry/vfs/mount_unsafe.go +++ b/pkg/sentry/vfs/mount_unsafe.go @@ -26,7 +26,7 @@ import ( "sync/atomic" "unsafe" - "gvisor.dev/gvisor/pkg/syncutil" + "gvisor.dev/gvisor/pkg/sync" ) // mountKey represents the location at which a Mount is mounted. It is @@ -75,7 +75,7 @@ type mountTable struct { // intrinsics and inline assembly, limiting the performance of this // approach.) - seq syncutil.SeqCount + seq sync.SeqCount seed uint32 // for hashing keys // size holds both length (number of elements) and capacity (number of diff --git a/pkg/sentry/vfs/pathname.go b/pkg/sentry/vfs/pathname.go index 8e155654f..cf80df90e 100644 --- a/pkg/sentry/vfs/pathname.go +++ b/pkg/sentry/vfs/pathname.go @@ -15,10 +15,9 @@ package vfs import ( - "sync" - "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/vfs/resolving_path.go b/pkg/sentry/vfs/resolving_path.go index f0641d314..8a0b382f6 100644 --- a/pkg/sentry/vfs/resolving_path.go +++ b/pkg/sentry/vfs/resolving_path.go @@ -16,11 +16,11 @@ package vfs import ( "fmt" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go index ea2db7031..1f21b0b31 100644 --- a/pkg/sentry/vfs/vfs.go +++ b/pkg/sentry/vfs/vfs.go @@ -29,12 +29,12 @@ package vfs import ( "fmt" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" ) diff --git a/pkg/sentry/watchdog/BUILD b/pkg/sentry/watchdog/BUILD index 4d8435265..28f21f13d 100644 --- a/pkg/sentry/watchdog/BUILD +++ b/pkg/sentry/watchdog/BUILD @@ -13,5 +13,6 @@ go_library( "//pkg/metric", "//pkg/sentry/kernel", "//pkg/sentry/kernel/time", + "//pkg/sync", ], ) diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go index 5e4611333..bfb2fac26 100644 --- a/pkg/sentry/watchdog/watchdog.go +++ b/pkg/sentry/watchdog/watchdog.go @@ -32,7 +32,6 @@ package watchdog import ( "bytes" "fmt" - "sync" "time" "gvisor.dev/gvisor/pkg/abi/linux" @@ -40,6 +39,7 @@ import ( "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" + "gvisor.dev/gvisor/pkg/sync" ) // Opts configures the watchdog. diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD new file mode 100644 index 000000000..e8cd16b8f --- /dev/null +++ b/pkg/sync/BUILD @@ -0,0 +1,53 @@ +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_generics:defs.bzl", "go_template") + +package( + default_visibility = ["//:sandbox"], + licenses = ["notice"], +) + +exports_files(["LICENSE"]) + +go_template( + name = "generic_atomicptr", + srcs = ["atomicptr_unsafe.go"], + types = [ + "Value", + ], +) + +go_template( + name = "generic_seqatomic", + srcs = ["seqatomic_unsafe.go"], + types = [ + "Value", + ], + deps = [ + ":sync", + ], +) + +go_library( + name = "sync", + srcs = [ + "aliases.go", + "downgradable_rwmutex_unsafe.go", + "memmove_unsafe.go", + "norace_unsafe.go", + "race_unsafe.go", + "seqcount.go", + "syncutil.go", + ], + importpath = "gvisor.dev/gvisor/pkg/sync", +) + +go_test( + name = "sync_test", + size = "small", + srcs = [ + "downgradable_rwmutex_test.go", + "seqcount_test.go", + ], + embed = [":sync"], +) diff --git a/pkg/sync/LICENSE b/pkg/sync/LICENSE new file mode 100644 index 000000000..6a66aea5e --- /dev/null +++ b/pkg/sync/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2009 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/pkg/sync/README.md b/pkg/sync/README.md new file mode 100644 index 000000000..2183c4e20 --- /dev/null +++ b/pkg/sync/README.md @@ -0,0 +1,5 @@ +# Syncutil + +This package provides additional synchronization primitives not provided by the +Go stdlib 'sync' package. It is partially derived from the upstream 'sync' +package from go1.10. diff --git a/pkg/sync/aliases.go b/pkg/sync/aliases.go new file mode 100644 index 000000000..20c7ca041 --- /dev/null +++ b/pkg/sync/aliases.go @@ -0,0 +1,37 @@ +// Copyright 2020 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sync + +import ( + "sync" +) + +// Aliases of standard library types. +type ( + // Mutex is an alias of sync.Mutex. + Mutex = sync.Mutex + + // RWMutex is an alias of sync.RWMutex. + RWMutex = sync.RWMutex + + // Cond is an alias of sync.Cond. + Cond = sync.Cond + + // Locker is an alias of sync.Locker. + Locker = sync.Locker + + // Once is an alias of sync.Once. + Once = sync.Once + + // Pool is an alias of sync.Pool. + Pool = sync.Pool + + // WaitGroup is an alias of sync.WaitGroup. + WaitGroup = sync.WaitGroup + + // Map is an alias of sync.Map. + Map = sync.Map +) diff --git a/pkg/sync/atomicptr_unsafe.go b/pkg/sync/atomicptr_unsafe.go new file mode 100644 index 000000000..525c4beed --- /dev/null +++ b/pkg/sync/atomicptr_unsafe.go @@ -0,0 +1,47 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package template doesn't exist. This file must be instantiated using the +// go_template_instance rule in tools/go_generics/defs.bzl. +package template + +import ( + "sync/atomic" + "unsafe" +) + +// Value is a required type parameter. +type Value struct{} + +// An AtomicPtr is a pointer to a value of type Value that can be atomically +// loaded and stored. The zero value of an AtomicPtr represents nil. +// +// Note that copying AtomicPtr by value performs a non-atomic read of the +// stored pointer, which is unsafe if Store() can be called concurrently; in +// this case, do `dst.Store(src.Load())` instead. +// +// +stateify savable +type AtomicPtr struct { + ptr unsafe.Pointer `state:".(*Value)"` +} + +func (p *AtomicPtr) savePtr() *Value { + return p.Load() +} + +func (p *AtomicPtr) loadPtr(v *Value) { + p.Store(v) +} + +// Load returns the value set by the most recent Store. It returns nil if there +// has been no previous call to Store. +func (p *AtomicPtr) Load() *Value { + return (*Value)(atomic.LoadPointer(&p.ptr)) +} + +// Store sets the value returned by Load to x. +func (p *AtomicPtr) Store(x *Value) { + atomic.StorePointer(&p.ptr, (unsafe.Pointer)(x)) +} diff --git a/pkg/sync/atomicptrtest/BUILD b/pkg/sync/atomicptrtest/BUILD new file mode 100644 index 000000000..418eda29c --- /dev/null +++ b/pkg/sync/atomicptrtest/BUILD @@ -0,0 +1,29 @@ +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") + +package(licenses = ["notice"]) + +go_template_instance( + name = "atomicptr_int", + out = "atomicptr_int_unsafe.go", + package = "atomicptr", + suffix = "Int", + template = "//pkg/sync:generic_atomicptr", + types = { + "Value": "int", + }, +) + +go_library( + name = "atomicptr", + srcs = ["atomicptr_int_unsafe.go"], + importpath = "gvisor.dev/gvisor/pkg/sync/atomicptr", +) + +go_test( + name = "atomicptr_test", + size = "small", + srcs = ["atomicptr_test.go"], + embed = [":atomicptr"], +) diff --git a/pkg/sync/atomicptrtest/atomicptr_test.go b/pkg/sync/atomicptrtest/atomicptr_test.go new file mode 100644 index 000000000..8fdc5112e --- /dev/null +++ b/pkg/sync/atomicptrtest/atomicptr_test.go @@ -0,0 +1,31 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package atomicptr + +import ( + "testing" +) + +func newInt(val int) *int { + return &val +} + +func TestAtomicPtr(t *testing.T) { + var p AtomicPtrInt + if got := p.Load(); got != nil { + t.Errorf("initial value is %p (%v), wanted nil", got, got) + } + want := newInt(42) + p.Store(want) + if got := p.Load(); got != want { + t.Errorf("wrong value: got %p (%v), wanted %p (%v)", got, got, want, want) + } + want = newInt(100) + p.Store(want) + if got := p.Load(); got != want { + t.Errorf("wrong value: got %p (%v), wanted %p (%v)", got, got, want, want) + } +} diff --git a/pkg/sync/downgradable_rwmutex_test.go b/pkg/sync/downgradable_rwmutex_test.go new file mode 100644 index 000000000..f04496bc5 --- /dev/null +++ b/pkg/sync/downgradable_rwmutex_test.go @@ -0,0 +1,150 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Copyright 2019 The gVisor Authors. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// GOMAXPROCS=10 go test + +// Copy/pasted from the standard library's sync/rwmutex_test.go, except for the +// addition of downgradingWriter and the renaming of num_iterations to +// numIterations to shut up Golint. + +package sync + +import ( + "fmt" + "runtime" + "sync/atomic" + "testing" +) + +func parallelReader(m *DowngradableRWMutex, clocked, cunlock, cdone chan bool) { + m.RLock() + clocked <- true + <-cunlock + m.RUnlock() + cdone <- true +} + +func doTestParallelReaders(numReaders, gomaxprocs int) { + runtime.GOMAXPROCS(gomaxprocs) + var m DowngradableRWMutex + clocked := make(chan bool) + cunlock := make(chan bool) + cdone := make(chan bool) + for i := 0; i < numReaders; i++ { + go parallelReader(&m, clocked, cunlock, cdone) + } + // Wait for all parallel RLock()s to succeed. + for i := 0; i < numReaders; i++ { + <-clocked + } + for i := 0; i < numReaders; i++ { + cunlock <- true + } + // Wait for the goroutines to finish. + for i := 0; i < numReaders; i++ { + <-cdone + } +} + +func TestParallelReaders(t *testing.T) { + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(-1)) + doTestParallelReaders(1, 4) + doTestParallelReaders(3, 4) + doTestParallelReaders(4, 2) +} + +func reader(rwm *DowngradableRWMutex, numIterations int, activity *int32, cdone chan bool) { + for i := 0; i < numIterations; i++ { + rwm.RLock() + n := atomic.AddInt32(activity, 1) + if n < 1 || n >= 10000 { + panic(fmt.Sprintf("wlock(%d)\n", n)) + } + for i := 0; i < 100; i++ { + } + atomic.AddInt32(activity, -1) + rwm.RUnlock() + } + cdone <- true +} + +func writer(rwm *DowngradableRWMutex, numIterations int, activity *int32, cdone chan bool) { + for i := 0; i < numIterations; i++ { + rwm.Lock() + n := atomic.AddInt32(activity, 10000) + if n != 10000 { + panic(fmt.Sprintf("wlock(%d)\n", n)) + } + for i := 0; i < 100; i++ { + } + atomic.AddInt32(activity, -10000) + rwm.Unlock() + } + cdone <- true +} + +func downgradingWriter(rwm *DowngradableRWMutex, numIterations int, activity *int32, cdone chan bool) { + for i := 0; i < numIterations; i++ { + rwm.Lock() + n := atomic.AddInt32(activity, 10000) + if n != 10000 { + panic(fmt.Sprintf("wlock(%d)\n", n)) + } + for i := 0; i < 100; i++ { + } + atomic.AddInt32(activity, -10000) + rwm.DowngradeLock() + n = atomic.AddInt32(activity, 1) + if n < 1 || n >= 10000 { + panic(fmt.Sprintf("wlock(%d)\n", n)) + } + for i := 0; i < 100; i++ { + } + n = atomic.AddInt32(activity, -1) + rwm.RUnlock() + } + cdone <- true +} + +func HammerDowngradableRWMutex(gomaxprocs, numReaders, numIterations int) { + runtime.GOMAXPROCS(gomaxprocs) + // Number of active readers + 10000 * number of active writers. + var activity int32 + var rwm DowngradableRWMutex + cdone := make(chan bool) + go writer(&rwm, numIterations, &activity, cdone) + go downgradingWriter(&rwm, numIterations, &activity, cdone) + var i int + for i = 0; i < numReaders/2; i++ { + go reader(&rwm, numIterations, &activity, cdone) + } + go writer(&rwm, numIterations, &activity, cdone) + go downgradingWriter(&rwm, numIterations, &activity, cdone) + for ; i < numReaders; i++ { + go reader(&rwm, numIterations, &activity, cdone) + } + // Wait for the 4 writers and all readers to finish. + for i := 0; i < 4+numReaders; i++ { + <-cdone + } +} + +func TestDowngradableRWMutex(t *testing.T) { + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(-1)) + n := 1000 + if testing.Short() { + n = 5 + } + HammerDowngradableRWMutex(1, 1, n) + HammerDowngradableRWMutex(1, 3, n) + HammerDowngradableRWMutex(1, 10, n) + HammerDowngradableRWMutex(4, 1, n) + HammerDowngradableRWMutex(4, 3, n) + HammerDowngradableRWMutex(4, 10, n) + HammerDowngradableRWMutex(10, 1, n) + HammerDowngradableRWMutex(10, 3, n) + HammerDowngradableRWMutex(10, 10, n) + HammerDowngradableRWMutex(10, 5, n) +} diff --git a/pkg/sync/downgradable_rwmutex_unsafe.go b/pkg/sync/downgradable_rwmutex_unsafe.go new file mode 100644 index 000000000..9bb55cd3a --- /dev/null +++ b/pkg/sync/downgradable_rwmutex_unsafe.go @@ -0,0 +1,146 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Copyright 2019 The gVisor Authors. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.13 +// +build !go1.15 + +// Check go:linkname function signatures when updating Go version. + +// This is mostly copied from the standard library's sync/rwmutex.go. +// +// Happens-before relationships indicated to the race detector: +// - Unlock -> Lock (via writerSem) +// - Unlock -> RLock (via readerSem) +// - RUnlock -> Lock (via writerSem) +// - DowngradeLock -> RLock (via readerSem) + +package sync + +import ( + "sync" + "sync/atomic" + "unsafe" +) + +//go:linkname runtimeSemacquire sync.runtime_Semacquire +func runtimeSemacquire(s *uint32) + +//go:linkname runtimeSemrelease sync.runtime_Semrelease +func runtimeSemrelease(s *uint32, handoff bool, skipframes int) + +// DowngradableRWMutex is identical to sync.RWMutex, but adds the DowngradeLock +// method. +type DowngradableRWMutex struct { + w sync.Mutex // held if there are pending writers + writerSem uint32 // semaphore for writers to wait for completing readers + readerSem uint32 // semaphore for readers to wait for completing writers + readerCount int32 // number of pending readers + readerWait int32 // number of departing readers +} + +const rwmutexMaxReaders = 1 << 30 + +// RLock locks rw for reading. +func (rw *DowngradableRWMutex) RLock() { + if RaceEnabled { + RaceDisable() + } + if atomic.AddInt32(&rw.readerCount, 1) < 0 { + // A writer is pending, wait for it. + runtimeSemacquire(&rw.readerSem) + } + if RaceEnabled { + RaceEnable() + RaceAcquire(unsafe.Pointer(&rw.readerSem)) + } +} + +// RUnlock undoes a single RLock call. +func (rw *DowngradableRWMutex) RUnlock() { + if RaceEnabled { + RaceReleaseMerge(unsafe.Pointer(&rw.writerSem)) + RaceDisable() + } + if r := atomic.AddInt32(&rw.readerCount, -1); r < 0 { + if r+1 == 0 || r+1 == -rwmutexMaxReaders { + panic("RUnlock of unlocked DowngradableRWMutex") + } + // A writer is pending. + if atomic.AddInt32(&rw.readerWait, -1) == 0 { + // The last reader unblocks the writer. + runtimeSemrelease(&rw.writerSem, false, 0) + } + } + if RaceEnabled { + RaceEnable() + } +} + +// Lock locks rw for writing. +func (rw *DowngradableRWMutex) Lock() { + if RaceEnabled { + RaceDisable() + } + // First, resolve competition with other writers. + rw.w.Lock() + // Announce to readers there is a pending writer. + r := atomic.AddInt32(&rw.readerCount, -rwmutexMaxReaders) + rwmutexMaxReaders + // Wait for active readers. + if r != 0 && atomic.AddInt32(&rw.readerWait, r) != 0 { + runtimeSemacquire(&rw.writerSem) + } + if RaceEnabled { + RaceEnable() + RaceAcquire(unsafe.Pointer(&rw.writerSem)) + } +} + +// Unlock unlocks rw for writing. +func (rw *DowngradableRWMutex) Unlock() { + if RaceEnabled { + RaceRelease(unsafe.Pointer(&rw.writerSem)) + RaceRelease(unsafe.Pointer(&rw.readerSem)) + RaceDisable() + } + // Announce to readers there is no active writer. + r := atomic.AddInt32(&rw.readerCount, rwmutexMaxReaders) + if r >= rwmutexMaxReaders { + panic("Unlock of unlocked DowngradableRWMutex") + } + // Unblock blocked readers, if any. + for i := 0; i < int(r); i++ { + runtimeSemrelease(&rw.readerSem, false, 0) + } + // Allow other writers to proceed. + rw.w.Unlock() + if RaceEnabled { + RaceEnable() + } +} + +// DowngradeLock atomically unlocks rw for writing and locks it for reading. +func (rw *DowngradableRWMutex) DowngradeLock() { + if RaceEnabled { + RaceRelease(unsafe.Pointer(&rw.readerSem)) + RaceDisable() + } + // Announce to readers there is no active writer and one additional reader. + r := atomic.AddInt32(&rw.readerCount, rwmutexMaxReaders+1) + if r >= rwmutexMaxReaders+1 { + panic("DowngradeLock of unlocked DowngradableRWMutex") + } + // Unblock blocked readers, if any. Note that this loop starts as 1 since r + // includes this goroutine. + for i := 1; i < int(r); i++ { + runtimeSemrelease(&rw.readerSem, false, 0) + } + // Allow other writers to proceed to rw.w.Lock(). Note that they will still + // block on rw.writerSem since at least this reader exists, such that + // DowngradeLock() is atomic with the previous write lock. + rw.w.Unlock() + if RaceEnabled { + RaceEnable() + } +} diff --git a/pkg/sync/memmove_unsafe.go b/pkg/sync/memmove_unsafe.go new file mode 100644 index 000000000..ad4a3a37e --- /dev/null +++ b/pkg/sync/memmove_unsafe.go @@ -0,0 +1,28 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.12 +// +build !go1.15 + +// Check go:linkname function signatures when updating Go version. + +package sync + +import ( + "unsafe" +) + +//go:linkname memmove runtime.memmove +//go:noescape +func memmove(to, from unsafe.Pointer, n uintptr) + +// Memmove is exported for SeqAtomicLoad/SeqAtomicTryLoad, which can't +// define it because go_generics can't update the go:linkname annotation. +// Furthermore, go:linkname silently doesn't work if the local name is exported +// (this is of course undocumented), which is why this indirection is +// necessary. +func Memmove(to, from unsafe.Pointer, n uintptr) { + memmove(to, from, n) +} diff --git a/pkg/sync/norace_unsafe.go b/pkg/sync/norace_unsafe.go new file mode 100644 index 000000000..006055dd6 --- /dev/null +++ b/pkg/sync/norace_unsafe.go @@ -0,0 +1,35 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !race + +package sync + +import ( + "unsafe" +) + +// RaceEnabled is true if the Go data race detector is enabled. +const RaceEnabled = false + +// RaceDisable has the same semantics as runtime.RaceDisable. +func RaceDisable() { +} + +// RaceEnable has the same semantics as runtime.RaceEnable. +func RaceEnable() { +} + +// RaceAcquire has the same semantics as runtime.RaceAcquire. +func RaceAcquire(addr unsafe.Pointer) { +} + +// RaceRelease has the same semantics as runtime.RaceRelease. +func RaceRelease(addr unsafe.Pointer) { +} + +// RaceReleaseMerge has the same semantics as runtime.RaceReleaseMerge. +func RaceReleaseMerge(addr unsafe.Pointer) { +} diff --git a/pkg/sync/race_unsafe.go b/pkg/sync/race_unsafe.go new file mode 100644 index 000000000..31d8fa9a6 --- /dev/null +++ b/pkg/sync/race_unsafe.go @@ -0,0 +1,41 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build race + +package sync + +import ( + "runtime" + "unsafe" +) + +// RaceEnabled is true if the Go data race detector is enabled. +const RaceEnabled = true + +// RaceDisable has the same semantics as runtime.RaceDisable. +func RaceDisable() { + runtime.RaceDisable() +} + +// RaceEnable has the same semantics as runtime.RaceEnable. +func RaceEnable() { + runtime.RaceEnable() +} + +// RaceAcquire has the same semantics as runtime.RaceAcquire. +func RaceAcquire(addr unsafe.Pointer) { + runtime.RaceAcquire(addr) +} + +// RaceRelease has the same semantics as runtime.RaceRelease. +func RaceRelease(addr unsafe.Pointer) { + runtime.RaceRelease(addr) +} + +// RaceReleaseMerge has the same semantics as runtime.RaceReleaseMerge. +func RaceReleaseMerge(addr unsafe.Pointer) { + runtime.RaceReleaseMerge(addr) +} diff --git a/pkg/sync/seqatomic_unsafe.go b/pkg/sync/seqatomic_unsafe.go new file mode 100644 index 000000000..eda6fb131 --- /dev/null +++ b/pkg/sync/seqatomic_unsafe.go @@ -0,0 +1,72 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package template doesn't exist. This file must be instantiated using the +// go_template_instance rule in tools/go_generics/defs.bzl. +package template + +import ( + "fmt" + "reflect" + "strings" + "unsafe" + + "gvisor.dev/gvisor/pkg/sync" +) + +// Value is a required type parameter. +// +// Value must not contain any pointers, including interface objects, function +// objects, slices, maps, channels, unsafe.Pointer, and arrays or structs +// containing any of the above. An init() function will panic if this property +// does not hold. +type Value struct{} + +// SeqAtomicLoad returns a copy of *ptr, ensuring that the read does not race +// with any writer critical sections in sc. +func SeqAtomicLoad(sc *sync.SeqCount, ptr *Value) Value { + // This function doesn't use SeqAtomicTryLoad because doing so is + // measurably, significantly (~20%) slower; Go is awful at inlining. + var val Value + for { + epoch := sc.BeginRead() + if sync.RaceEnabled { + // runtime.RaceDisable() doesn't actually stop the race detector, + // so it can't help us here. Instead, call runtime.memmove + // directly, which is not instrumented by the race detector. + sync.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) + } else { + // This is ~40% faster for short reads than going through memmove. + val = *ptr + } + if sc.ReadOk(epoch) { + break + } + } + return val +} + +// SeqAtomicTryLoad returns a copy of *ptr while in a reader critical section +// in sc initiated by a call to sc.BeginRead() that returned epoch. If the read +// would race with a writer critical section, SeqAtomicTryLoad returns +// (unspecified, false). +func SeqAtomicTryLoad(sc *sync.SeqCount, epoch sync.SeqCountEpoch, ptr *Value) (Value, bool) { + var val Value + if sync.RaceEnabled { + sync.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) + } else { + val = *ptr + } + return val, sc.ReadOk(epoch) +} + +func init() { + var val Value + typ := reflect.TypeOf(val) + name := typ.Name() + if ptrs := sync.PointersInType(typ, name); len(ptrs) != 0 { + panic(fmt.Sprintf("SeqAtomicLoad<%s> is invalid since values %s of type %s contain pointers:\n%s", typ, name, typ, strings.Join(ptrs, "\n"))) + } +} diff --git a/pkg/sync/seqatomictest/BUILD b/pkg/sync/seqatomictest/BUILD new file mode 100644 index 000000000..eba21518d --- /dev/null +++ b/pkg/sync/seqatomictest/BUILD @@ -0,0 +1,33 @@ +load("//tools/go_stateify:defs.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") + +package(licenses = ["notice"]) + +go_template_instance( + name = "seqatomic_int", + out = "seqatomic_int_unsafe.go", + package = "seqatomic", + suffix = "Int", + template = "//pkg/sync:generic_seqatomic", + types = { + "Value": "int", + }, +) + +go_library( + name = "seqatomic", + srcs = ["seqatomic_int_unsafe.go"], + importpath = "gvisor.dev/gvisor/pkg/sync/seqatomic", + deps = [ + "//pkg/sync", + ], +) + +go_test( + name = "seqatomic_test", + size = "small", + srcs = ["seqatomic_test.go"], + embed = [":seqatomic"], + deps = ["//pkg/sync"], +) diff --git a/pkg/sync/seqatomictest/seqatomic_test.go b/pkg/sync/seqatomictest/seqatomic_test.go new file mode 100644 index 000000000..2c4568b07 --- /dev/null +++ b/pkg/sync/seqatomictest/seqatomic_test.go @@ -0,0 +1,132 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package seqatomic + +import ( + "sync/atomic" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/sync" +) + +func TestSeqAtomicLoadUncontended(t *testing.T) { + var seq sync.SeqCount + const want = 1 + data := want + if got := SeqAtomicLoadInt(&seq, &data); got != want { + t.Errorf("SeqAtomicLoadInt: got %v, wanted %v", got, want) + } +} + +func TestSeqAtomicLoadAfterWrite(t *testing.T) { + var seq sync.SeqCount + var data int + const want = 1 + seq.BeginWrite() + data = want + seq.EndWrite() + if got := SeqAtomicLoadInt(&seq, &data); got != want { + t.Errorf("SeqAtomicLoadInt: got %v, wanted %v", got, want) + } +} + +func TestSeqAtomicLoadDuringWrite(t *testing.T) { + var seq sync.SeqCount + var data int + const want = 1 + seq.BeginWrite() + go func() { + time.Sleep(time.Second) + data = want + seq.EndWrite() + }() + if got := SeqAtomicLoadInt(&seq, &data); got != want { + t.Errorf("SeqAtomicLoadInt: got %v, wanted %v", got, want) + } +} + +func TestSeqAtomicTryLoadUncontended(t *testing.T) { + var seq sync.SeqCount + const want = 1 + data := want + epoch := seq.BeginRead() + if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); !ok || got != want { + t.Errorf("SeqAtomicTryLoadInt: got (%v, %v), wanted (%v, true)", got, ok, want) + } +} + +func TestSeqAtomicTryLoadDuringWrite(t *testing.T) { + var seq sync.SeqCount + var data int + epoch := seq.BeginRead() + seq.BeginWrite() + if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); ok { + t.Errorf("SeqAtomicTryLoadInt: got (%v, true), wanted (_, false)", got) + } + seq.EndWrite() +} + +func TestSeqAtomicTryLoadAfterWrite(t *testing.T) { + var seq sync.SeqCount + var data int + epoch := seq.BeginRead() + seq.BeginWrite() + seq.EndWrite() + if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); ok { + t.Errorf("SeqAtomicTryLoadInt: got (%v, true), wanted (_, false)", got) + } +} + +func BenchmarkSeqAtomicLoadIntUncontended(b *testing.B) { + var seq sync.SeqCount + const want = 42 + data := want + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if got := SeqAtomicLoadInt(&seq, &data); got != want { + b.Fatalf("SeqAtomicLoadInt: got %v, wanted %v", got, want) + } + } + }) +} + +func BenchmarkSeqAtomicTryLoadIntUncontended(b *testing.B) { + var seq sync.SeqCount + const want = 42 + data := want + b.RunParallel(func(pb *testing.PB) { + epoch := seq.BeginRead() + for pb.Next() { + if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); !ok || got != want { + b.Fatalf("SeqAtomicTryLoadInt: got (%v, %v), wanted (%v, true)", got, ok, want) + } + } + }) +} + +// For comparison: +func BenchmarkAtomicValueLoadIntUncontended(b *testing.B) { + var a atomic.Value + const want = 42 + a.Store(int(want)) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if got := a.Load().(int); got != want { + b.Fatalf("atomic.Value.Load: got %v, wanted %v", got, want) + } + } + }) +} diff --git a/pkg/sync/seqcount.go b/pkg/sync/seqcount.go new file mode 100644 index 000000000..a1e895352 --- /dev/null +++ b/pkg/sync/seqcount.go @@ -0,0 +1,149 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sync + +import ( + "fmt" + "reflect" + "runtime" + "sync/atomic" +) + +// SeqCount is a synchronization primitive for optimistic reader/writer +// synchronization in cases where readers can work with stale data and +// therefore do not need to block writers. +// +// Compared to sync/atomic.Value: +// +// - Mutation of SeqCount-protected data does not require memory allocation, +// whereas atomic.Value generally does. This is a significant advantage when +// writes are common. +// +// - Atomic reads of SeqCount-protected data require copying. This is a +// disadvantage when atomic reads are common. +// +// - SeqCount may be more flexible: correct use of SeqCount.ReadOk allows other +// operations to be made atomic with reads of SeqCount-protected data. +// +// - SeqCount may be less flexible: as of this writing, SeqCount-protected data +// cannot include pointers. +// +// - SeqCount is more cumbersome to use; atomic reads of SeqCount-protected +// data require instantiating function templates using go_generics (see +// seqatomic.go). +type SeqCount struct { + // epoch is incremented by BeginWrite and EndWrite, such that epoch is odd + // if a writer critical section is active, and a read from data protected + // by this SeqCount is atomic iff epoch is the same even value before and + // after the read. + epoch uint32 +} + +// SeqCountEpoch tracks writer critical sections in a SeqCount. +type SeqCountEpoch struct { + val uint32 +} + +// We assume that: +// +// - All functions in sync/atomic that perform a memory read are at least a +// read fence: memory reads before calls to such functions cannot be reordered +// after the call, and memory reads after calls to such functions cannot be +// reordered before the call, even if those reads do not use sync/atomic. +// +// - All functions in sync/atomic that perform a memory write are at least a +// write fence: memory writes before calls to such functions cannot be +// reordered after the call, and memory writes after calls to such functions +// cannot be reordered before the call, even if those writes do not use +// sync/atomic. +// +// As of this writing, the Go memory model completely fails to describe +// sync/atomic, but these properties are implied by +// https://groups.google.com/forum/#!topic/golang-nuts/7EnEhM3U7B8. + +// BeginRead indicates the beginning of a reader critical section. Reader +// critical sections DO NOT BLOCK writer critical sections, so operations in a +// reader critical section MAY RACE with writer critical sections. Races are +// detected by ReadOk at the end of the reader critical section. Thus, the +// low-level structure of readers is generally: +// +// for { +// epoch := seq.BeginRead() +// // do something idempotent with seq-protected data +// if seq.ReadOk(epoch) { +// break +// } +// } +// +// However, since reader critical sections may race with writer critical +// sections, the Go race detector will (accurately) flag data races in readers +// using this pattern. Most users of SeqCount will need to use the +// SeqAtomicLoad function template in seqatomic.go. +func (s *SeqCount) BeginRead() SeqCountEpoch { + epoch := atomic.LoadUint32(&s.epoch) + for epoch&1 != 0 { + runtime.Gosched() + epoch = atomic.LoadUint32(&s.epoch) + } + return SeqCountEpoch{epoch} +} + +// ReadOk returns true if the reader critical section initiated by a previous +// call to BeginRead() that returned epoch did not race with any writer critical +// sections. +// +// ReadOk may be called any number of times during a reader critical section. +// Reader critical sections do not need to be explicitly terminated; the last +// call to ReadOk is implicitly the end of the reader critical section. +func (s *SeqCount) ReadOk(epoch SeqCountEpoch) bool { + return atomic.LoadUint32(&s.epoch) == epoch.val +} + +// BeginWrite indicates the beginning of a writer critical section. +// +// SeqCount does not support concurrent writer critical sections; clients with +// concurrent writers must synchronize them using e.g. sync.Mutex. +func (s *SeqCount) BeginWrite() { + if epoch := atomic.AddUint32(&s.epoch, 1); epoch&1 == 0 { + panic("SeqCount.BeginWrite during writer critical section") + } +} + +// EndWrite ends the effect of a preceding BeginWrite. +func (s *SeqCount) EndWrite() { + if epoch := atomic.AddUint32(&s.epoch, 1); epoch&1 != 0 { + panic("SeqCount.EndWrite outside writer critical section") + } +} + +// PointersInType returns a list of pointers reachable from values named +// valName of the given type. +// +// PointersInType is not exhaustive, but it is guaranteed that if typ contains +// at least one pointer, then PointersInTypeOf returns a non-empty list. +func PointersInType(typ reflect.Type, valName string) []string { + switch kind := typ.Kind(); kind { + case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: + return nil + + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice, reflect.String, reflect.UnsafePointer: + return []string{valName} + + case reflect.Array: + return PointersInType(typ.Elem(), valName+"[]") + + case reflect.Struct: + var ptrs []string + for i, n := 0, typ.NumField(); i < n; i++ { + field := typ.Field(i) + ptrs = append(ptrs, PointersInType(field.Type, fmt.Sprintf("%s.%s", valName, field.Name))...) + } + return ptrs + + default: + return []string{fmt.Sprintf("%s (of type %s with unknown kind %s)", valName, typ, kind)} + } +} diff --git a/pkg/sync/seqcount_test.go b/pkg/sync/seqcount_test.go new file mode 100644 index 000000000..6eb7b4b59 --- /dev/null +++ b/pkg/sync/seqcount_test.go @@ -0,0 +1,153 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sync + +import ( + "reflect" + "testing" + "time" +) + +func TestSeqCountWriteUncontended(t *testing.T) { + var seq SeqCount + seq.BeginWrite() + seq.EndWrite() +} + +func TestSeqCountReadUncontended(t *testing.T) { + var seq SeqCount + epoch := seq.BeginRead() + if !seq.ReadOk(epoch) { + t.Errorf("ReadOk: got false, wanted true") + } +} + +func TestSeqCountBeginReadAfterWrite(t *testing.T) { + var seq SeqCount + var data int32 + const want = 1 + seq.BeginWrite() + data = want + seq.EndWrite() + epoch := seq.BeginRead() + if data != want { + t.Errorf("Reader: got %v, wanted %v", data, want) + } + if !seq.ReadOk(epoch) { + t.Errorf("ReadOk: got false, wanted true") + } +} + +func TestSeqCountBeginReadDuringWrite(t *testing.T) { + var seq SeqCount + var data int + const want = 1 + seq.BeginWrite() + go func() { + time.Sleep(time.Second) + data = want + seq.EndWrite() + }() + epoch := seq.BeginRead() + if data != want { + t.Errorf("Reader: got %v, wanted %v", data, want) + } + if !seq.ReadOk(epoch) { + t.Errorf("ReadOk: got false, wanted true") + } +} + +func TestSeqCountReadOkAfterWrite(t *testing.T) { + var seq SeqCount + epoch := seq.BeginRead() + seq.BeginWrite() + seq.EndWrite() + if seq.ReadOk(epoch) { + t.Errorf("ReadOk: got true, wanted false") + } +} + +func TestSeqCountReadOkDuringWrite(t *testing.T) { + var seq SeqCount + epoch := seq.BeginRead() + seq.BeginWrite() + if seq.ReadOk(epoch) { + t.Errorf("ReadOk: got true, wanted false") + } + seq.EndWrite() +} + +func BenchmarkSeqCountWriteUncontended(b *testing.B) { + var seq SeqCount + for i := 0; i < b.N; i++ { + seq.BeginWrite() + seq.EndWrite() + } +} + +func BenchmarkSeqCountReadUncontended(b *testing.B) { + var seq SeqCount + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + epoch := seq.BeginRead() + if !seq.ReadOk(epoch) { + b.Fatalf("ReadOk: got false, wanted true") + } + } + }) +} + +func TestPointersInType(t *testing.T) { + for _, test := range []struct { + name string // used for both test and value name + val interface{} + ptrs []string + }{ + { + name: "EmptyStruct", + val: struct{}{}, + }, + { + name: "Int", + val: int(0), + }, + { + name: "MixedStruct", + val: struct { + b bool + I int + ExportedPtr *struct{} + unexportedPtr *struct{} + arr [2]int + ptrArr [2]*int + nestedStruct struct { + nestedNonptr int + nestedPtr *int + } + structArr [1]struct { + nonptr int + ptr *int + } + }{}, + ptrs: []string{ + "MixedStruct.ExportedPtr", + "MixedStruct.unexportedPtr", + "MixedStruct.ptrArr[]", + "MixedStruct.nestedStruct.nestedPtr", + "MixedStruct.structArr[].ptr", + }, + }, + } { + t.Run(test.name, func(t *testing.T) { + typ := reflect.TypeOf(test.val) + ptrs := PointersInType(typ, test.name) + t.Logf("Found pointers: %v", ptrs) + if (len(ptrs) != 0 || len(test.ptrs) != 0) && !reflect.DeepEqual(ptrs, test.ptrs) { + t.Errorf("Got %v, wanted %v", ptrs, test.ptrs) + } + }) + } +} diff --git a/pkg/sync/syncutil.go b/pkg/sync/syncutil.go new file mode 100644 index 000000000..b16cf5333 --- /dev/null +++ b/pkg/sync/syncutil.go @@ -0,0 +1,7 @@ +// Copyright 2019 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package sync provides synchronization primitives. +package sync diff --git a/pkg/syncutil/BUILD b/pkg/syncutil/BUILD deleted file mode 100644 index cb1f41628..000000000 --- a/pkg/syncutil/BUILD +++ /dev/null @@ -1,52 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_generics:defs.bzl", "go_template") - -package( - default_visibility = ["//:sandbox"], - licenses = ["notice"], -) - -exports_files(["LICENSE"]) - -go_template( - name = "generic_atomicptr", - srcs = ["atomicptr_unsafe.go"], - types = [ - "Value", - ], -) - -go_template( - name = "generic_seqatomic", - srcs = ["seqatomic_unsafe.go"], - types = [ - "Value", - ], - deps = [ - ":sync", - ], -) - -go_library( - name = "syncutil", - srcs = [ - "downgradable_rwmutex_unsafe.go", - "memmove_unsafe.go", - "norace_unsafe.go", - "race_unsafe.go", - "seqcount.go", - "syncutil.go", - ], - importpath = "gvisor.dev/gvisor/pkg/syncutil", -) - -go_test( - name = "syncutil_test", - size = "small", - srcs = [ - "downgradable_rwmutex_test.go", - "seqcount_test.go", - ], - embed = [":syncutil"], -) diff --git a/pkg/syncutil/LICENSE b/pkg/syncutil/LICENSE deleted file mode 100644 index 6a66aea5e..000000000 --- a/pkg/syncutil/LICENSE +++ /dev/null @@ -1,27 +0,0 @@ -Copyright (c) 2009 The Go Authors. All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - - * Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above -copyright notice, this list of conditions and the following disclaimer -in the documentation and/or other materials provided with the -distribution. - * Neither the name of Google Inc. nor the names of its -contributors may be used to endorse or promote products derived from -this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/pkg/syncutil/README.md b/pkg/syncutil/README.md deleted file mode 100644 index 2183c4e20..000000000 --- a/pkg/syncutil/README.md +++ /dev/null @@ -1,5 +0,0 @@ -# Syncutil - -This package provides additional synchronization primitives not provided by the -Go stdlib 'sync' package. It is partially derived from the upstream 'sync' -package from go1.10. diff --git a/pkg/syncutil/atomicptr_unsafe.go b/pkg/syncutil/atomicptr_unsafe.go deleted file mode 100644 index 525c4beed..000000000 --- a/pkg/syncutil/atomicptr_unsafe.go +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package template doesn't exist. This file must be instantiated using the -// go_template_instance rule in tools/go_generics/defs.bzl. -package template - -import ( - "sync/atomic" - "unsafe" -) - -// Value is a required type parameter. -type Value struct{} - -// An AtomicPtr is a pointer to a value of type Value that can be atomically -// loaded and stored. The zero value of an AtomicPtr represents nil. -// -// Note that copying AtomicPtr by value performs a non-atomic read of the -// stored pointer, which is unsafe if Store() can be called concurrently; in -// this case, do `dst.Store(src.Load())` instead. -// -// +stateify savable -type AtomicPtr struct { - ptr unsafe.Pointer `state:".(*Value)"` -} - -func (p *AtomicPtr) savePtr() *Value { - return p.Load() -} - -func (p *AtomicPtr) loadPtr(v *Value) { - p.Store(v) -} - -// Load returns the value set by the most recent Store. It returns nil if there -// has been no previous call to Store. -func (p *AtomicPtr) Load() *Value { - return (*Value)(atomic.LoadPointer(&p.ptr)) -} - -// Store sets the value returned by Load to x. -func (p *AtomicPtr) Store(x *Value) { - atomic.StorePointer(&p.ptr, (unsafe.Pointer)(x)) -} diff --git a/pkg/syncutil/atomicptrtest/BUILD b/pkg/syncutil/atomicptrtest/BUILD deleted file mode 100644 index 63f411a90..000000000 --- a/pkg/syncutil/atomicptrtest/BUILD +++ /dev/null @@ -1,29 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "atomicptr_int", - out = "atomicptr_int_unsafe.go", - package = "atomicptr", - suffix = "Int", - template = "//pkg/syncutil:generic_atomicptr", - types = { - "Value": "int", - }, -) - -go_library( - name = "atomicptr", - srcs = ["atomicptr_int_unsafe.go"], - importpath = "gvisor.dev/gvisor/pkg/syncutil/atomicptr", -) - -go_test( - name = "atomicptr_test", - size = "small", - srcs = ["atomicptr_test.go"], - embed = [":atomicptr"], -) diff --git a/pkg/syncutil/atomicptrtest/atomicptr_test.go b/pkg/syncutil/atomicptrtest/atomicptr_test.go deleted file mode 100644 index 8fdc5112e..000000000 --- a/pkg/syncutil/atomicptrtest/atomicptr_test.go +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package atomicptr - -import ( - "testing" -) - -func newInt(val int) *int { - return &val -} - -func TestAtomicPtr(t *testing.T) { - var p AtomicPtrInt - if got := p.Load(); got != nil { - t.Errorf("initial value is %p (%v), wanted nil", got, got) - } - want := newInt(42) - p.Store(want) - if got := p.Load(); got != want { - t.Errorf("wrong value: got %p (%v), wanted %p (%v)", got, got, want, want) - } - want = newInt(100) - p.Store(want) - if got := p.Load(); got != want { - t.Errorf("wrong value: got %p (%v), wanted %p (%v)", got, got, want, want) - } -} diff --git a/pkg/syncutil/downgradable_rwmutex_test.go b/pkg/syncutil/downgradable_rwmutex_test.go deleted file mode 100644 index ffaf7ecc7..000000000 --- a/pkg/syncutil/downgradable_rwmutex_test.go +++ /dev/null @@ -1,150 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Copyright 2019 The gVisor Authors. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// GOMAXPROCS=10 go test - -// Copy/pasted from the standard library's sync/rwmutex_test.go, except for the -// addition of downgradingWriter and the renaming of num_iterations to -// numIterations to shut up Golint. - -package syncutil - -import ( - "fmt" - "runtime" - "sync/atomic" - "testing" -) - -func parallelReader(m *DowngradableRWMutex, clocked, cunlock, cdone chan bool) { - m.RLock() - clocked <- true - <-cunlock - m.RUnlock() - cdone <- true -} - -func doTestParallelReaders(numReaders, gomaxprocs int) { - runtime.GOMAXPROCS(gomaxprocs) - var m DowngradableRWMutex - clocked := make(chan bool) - cunlock := make(chan bool) - cdone := make(chan bool) - for i := 0; i < numReaders; i++ { - go parallelReader(&m, clocked, cunlock, cdone) - } - // Wait for all parallel RLock()s to succeed. - for i := 0; i < numReaders; i++ { - <-clocked - } - for i := 0; i < numReaders; i++ { - cunlock <- true - } - // Wait for the goroutines to finish. - for i := 0; i < numReaders; i++ { - <-cdone - } -} - -func TestParallelReaders(t *testing.T) { - defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(-1)) - doTestParallelReaders(1, 4) - doTestParallelReaders(3, 4) - doTestParallelReaders(4, 2) -} - -func reader(rwm *DowngradableRWMutex, numIterations int, activity *int32, cdone chan bool) { - for i := 0; i < numIterations; i++ { - rwm.RLock() - n := atomic.AddInt32(activity, 1) - if n < 1 || n >= 10000 { - panic(fmt.Sprintf("wlock(%d)\n", n)) - } - for i := 0; i < 100; i++ { - } - atomic.AddInt32(activity, -1) - rwm.RUnlock() - } - cdone <- true -} - -func writer(rwm *DowngradableRWMutex, numIterations int, activity *int32, cdone chan bool) { - for i := 0; i < numIterations; i++ { - rwm.Lock() - n := atomic.AddInt32(activity, 10000) - if n != 10000 { - panic(fmt.Sprintf("wlock(%d)\n", n)) - } - for i := 0; i < 100; i++ { - } - atomic.AddInt32(activity, -10000) - rwm.Unlock() - } - cdone <- true -} - -func downgradingWriter(rwm *DowngradableRWMutex, numIterations int, activity *int32, cdone chan bool) { - for i := 0; i < numIterations; i++ { - rwm.Lock() - n := atomic.AddInt32(activity, 10000) - if n != 10000 { - panic(fmt.Sprintf("wlock(%d)\n", n)) - } - for i := 0; i < 100; i++ { - } - atomic.AddInt32(activity, -10000) - rwm.DowngradeLock() - n = atomic.AddInt32(activity, 1) - if n < 1 || n >= 10000 { - panic(fmt.Sprintf("wlock(%d)\n", n)) - } - for i := 0; i < 100; i++ { - } - n = atomic.AddInt32(activity, -1) - rwm.RUnlock() - } - cdone <- true -} - -func HammerDowngradableRWMutex(gomaxprocs, numReaders, numIterations int) { - runtime.GOMAXPROCS(gomaxprocs) - // Number of active readers + 10000 * number of active writers. - var activity int32 - var rwm DowngradableRWMutex - cdone := make(chan bool) - go writer(&rwm, numIterations, &activity, cdone) - go downgradingWriter(&rwm, numIterations, &activity, cdone) - var i int - for i = 0; i < numReaders/2; i++ { - go reader(&rwm, numIterations, &activity, cdone) - } - go writer(&rwm, numIterations, &activity, cdone) - go downgradingWriter(&rwm, numIterations, &activity, cdone) - for ; i < numReaders; i++ { - go reader(&rwm, numIterations, &activity, cdone) - } - // Wait for the 4 writers and all readers to finish. - for i := 0; i < 4+numReaders; i++ { - <-cdone - } -} - -func TestDowngradableRWMutex(t *testing.T) { - defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(-1)) - n := 1000 - if testing.Short() { - n = 5 - } - HammerDowngradableRWMutex(1, 1, n) - HammerDowngradableRWMutex(1, 3, n) - HammerDowngradableRWMutex(1, 10, n) - HammerDowngradableRWMutex(4, 1, n) - HammerDowngradableRWMutex(4, 3, n) - HammerDowngradableRWMutex(4, 10, n) - HammerDowngradableRWMutex(10, 1, n) - HammerDowngradableRWMutex(10, 3, n) - HammerDowngradableRWMutex(10, 10, n) - HammerDowngradableRWMutex(10, 5, n) -} diff --git a/pkg/syncutil/downgradable_rwmutex_unsafe.go b/pkg/syncutil/downgradable_rwmutex_unsafe.go deleted file mode 100644 index 51e11555d..000000000 --- a/pkg/syncutil/downgradable_rwmutex_unsafe.go +++ /dev/null @@ -1,146 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Copyright 2019 The gVisor Authors. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build go1.13 -// +build !go1.15 - -// Check go:linkname function signatures when updating Go version. - -// This is mostly copied from the standard library's sync/rwmutex.go. -// -// Happens-before relationships indicated to the race detector: -// - Unlock -> Lock (via writerSem) -// - Unlock -> RLock (via readerSem) -// - RUnlock -> Lock (via writerSem) -// - DowngradeLock -> RLock (via readerSem) - -package syncutil - -import ( - "sync" - "sync/atomic" - "unsafe" -) - -//go:linkname runtimeSemacquire sync.runtime_Semacquire -func runtimeSemacquire(s *uint32) - -//go:linkname runtimeSemrelease sync.runtime_Semrelease -func runtimeSemrelease(s *uint32, handoff bool, skipframes int) - -// DowngradableRWMutex is identical to sync.RWMutex, but adds the DowngradeLock -// method. -type DowngradableRWMutex struct { - w sync.Mutex // held if there are pending writers - writerSem uint32 // semaphore for writers to wait for completing readers - readerSem uint32 // semaphore for readers to wait for completing writers - readerCount int32 // number of pending readers - readerWait int32 // number of departing readers -} - -const rwmutexMaxReaders = 1 << 30 - -// RLock locks rw for reading. -func (rw *DowngradableRWMutex) RLock() { - if RaceEnabled { - RaceDisable() - } - if atomic.AddInt32(&rw.readerCount, 1) < 0 { - // A writer is pending, wait for it. - runtimeSemacquire(&rw.readerSem) - } - if RaceEnabled { - RaceEnable() - RaceAcquire(unsafe.Pointer(&rw.readerSem)) - } -} - -// RUnlock undoes a single RLock call. -func (rw *DowngradableRWMutex) RUnlock() { - if RaceEnabled { - RaceReleaseMerge(unsafe.Pointer(&rw.writerSem)) - RaceDisable() - } - if r := atomic.AddInt32(&rw.readerCount, -1); r < 0 { - if r+1 == 0 || r+1 == -rwmutexMaxReaders { - panic("RUnlock of unlocked DowngradableRWMutex") - } - // A writer is pending. - if atomic.AddInt32(&rw.readerWait, -1) == 0 { - // The last reader unblocks the writer. - runtimeSemrelease(&rw.writerSem, false, 0) - } - } - if RaceEnabled { - RaceEnable() - } -} - -// Lock locks rw for writing. -func (rw *DowngradableRWMutex) Lock() { - if RaceEnabled { - RaceDisable() - } - // First, resolve competition with other writers. - rw.w.Lock() - // Announce to readers there is a pending writer. - r := atomic.AddInt32(&rw.readerCount, -rwmutexMaxReaders) + rwmutexMaxReaders - // Wait for active readers. - if r != 0 && atomic.AddInt32(&rw.readerWait, r) != 0 { - runtimeSemacquire(&rw.writerSem) - } - if RaceEnabled { - RaceEnable() - RaceAcquire(unsafe.Pointer(&rw.writerSem)) - } -} - -// Unlock unlocks rw for writing. -func (rw *DowngradableRWMutex) Unlock() { - if RaceEnabled { - RaceRelease(unsafe.Pointer(&rw.writerSem)) - RaceRelease(unsafe.Pointer(&rw.readerSem)) - RaceDisable() - } - // Announce to readers there is no active writer. - r := atomic.AddInt32(&rw.readerCount, rwmutexMaxReaders) - if r >= rwmutexMaxReaders { - panic("Unlock of unlocked DowngradableRWMutex") - } - // Unblock blocked readers, if any. - for i := 0; i < int(r); i++ { - runtimeSemrelease(&rw.readerSem, false, 0) - } - // Allow other writers to proceed. - rw.w.Unlock() - if RaceEnabled { - RaceEnable() - } -} - -// DowngradeLock atomically unlocks rw for writing and locks it for reading. -func (rw *DowngradableRWMutex) DowngradeLock() { - if RaceEnabled { - RaceRelease(unsafe.Pointer(&rw.readerSem)) - RaceDisable() - } - // Announce to readers there is no active writer and one additional reader. - r := atomic.AddInt32(&rw.readerCount, rwmutexMaxReaders+1) - if r >= rwmutexMaxReaders+1 { - panic("DowngradeLock of unlocked DowngradableRWMutex") - } - // Unblock blocked readers, if any. Note that this loop starts as 1 since r - // includes this goroutine. - for i := 1; i < int(r); i++ { - runtimeSemrelease(&rw.readerSem, false, 0) - } - // Allow other writers to proceed to rw.w.Lock(). Note that they will still - // block on rw.writerSem since at least this reader exists, such that - // DowngradeLock() is atomic with the previous write lock. - rw.w.Unlock() - if RaceEnabled { - RaceEnable() - } -} diff --git a/pkg/syncutil/memmove_unsafe.go b/pkg/syncutil/memmove_unsafe.go deleted file mode 100644 index 348675baa..000000000 --- a/pkg/syncutil/memmove_unsafe.go +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build go1.12 -// +build !go1.15 - -// Check go:linkname function signatures when updating Go version. - -package syncutil - -import ( - "unsafe" -) - -//go:linkname memmove runtime.memmove -//go:noescape -func memmove(to, from unsafe.Pointer, n uintptr) - -// Memmove is exported for SeqAtomicLoad/SeqAtomicTryLoad, which can't -// define it because go_generics can't update the go:linkname annotation. -// Furthermore, go:linkname silently doesn't work if the local name is exported -// (this is of course undocumented), which is why this indirection is -// necessary. -func Memmove(to, from unsafe.Pointer, n uintptr) { - memmove(to, from, n) -} diff --git a/pkg/syncutil/norace_unsafe.go b/pkg/syncutil/norace_unsafe.go deleted file mode 100644 index 0a0a9deda..000000000 --- a/pkg/syncutil/norace_unsafe.go +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build !race - -package syncutil - -import ( - "unsafe" -) - -// RaceEnabled is true if the Go data race detector is enabled. -const RaceEnabled = false - -// RaceDisable has the same semantics as runtime.RaceDisable. -func RaceDisable() { -} - -// RaceEnable has the same semantics as runtime.RaceEnable. -func RaceEnable() { -} - -// RaceAcquire has the same semantics as runtime.RaceAcquire. -func RaceAcquire(addr unsafe.Pointer) { -} - -// RaceRelease has the same semantics as runtime.RaceRelease. -func RaceRelease(addr unsafe.Pointer) { -} - -// RaceReleaseMerge has the same semantics as runtime.RaceReleaseMerge. -func RaceReleaseMerge(addr unsafe.Pointer) { -} diff --git a/pkg/syncutil/race_unsafe.go b/pkg/syncutil/race_unsafe.go deleted file mode 100644 index 206067ec1..000000000 --- a/pkg/syncutil/race_unsafe.go +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build race - -package syncutil - -import ( - "runtime" - "unsafe" -) - -// RaceEnabled is true if the Go data race detector is enabled. -const RaceEnabled = true - -// RaceDisable has the same semantics as runtime.RaceDisable. -func RaceDisable() { - runtime.RaceDisable() -} - -// RaceEnable has the same semantics as runtime.RaceEnable. -func RaceEnable() { - runtime.RaceEnable() -} - -// RaceAcquire has the same semantics as runtime.RaceAcquire. -func RaceAcquire(addr unsafe.Pointer) { - runtime.RaceAcquire(addr) -} - -// RaceRelease has the same semantics as runtime.RaceRelease. -func RaceRelease(addr unsafe.Pointer) { - runtime.RaceRelease(addr) -} - -// RaceReleaseMerge has the same semantics as runtime.RaceReleaseMerge. -func RaceReleaseMerge(addr unsafe.Pointer) { - runtime.RaceReleaseMerge(addr) -} diff --git a/pkg/syncutil/seqatomic_unsafe.go b/pkg/syncutil/seqatomic_unsafe.go deleted file mode 100644 index cb6d2eb22..000000000 --- a/pkg/syncutil/seqatomic_unsafe.go +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package template doesn't exist. This file must be instantiated using the -// go_template_instance rule in tools/go_generics/defs.bzl. -package template - -import ( - "fmt" - "reflect" - "strings" - "unsafe" - - "gvisor.dev/gvisor/pkg/syncutil" -) - -// Value is a required type parameter. -// -// Value must not contain any pointers, including interface objects, function -// objects, slices, maps, channels, unsafe.Pointer, and arrays or structs -// containing any of the above. An init() function will panic if this property -// does not hold. -type Value struct{} - -// SeqAtomicLoad returns a copy of *ptr, ensuring that the read does not race -// with any writer critical sections in sc. -func SeqAtomicLoad(sc *syncutil.SeqCount, ptr *Value) Value { - // This function doesn't use SeqAtomicTryLoad because doing so is - // measurably, significantly (~20%) slower; Go is awful at inlining. - var val Value - for { - epoch := sc.BeginRead() - if syncutil.RaceEnabled { - // runtime.RaceDisable() doesn't actually stop the race detector, - // so it can't help us here. Instead, call runtime.memmove - // directly, which is not instrumented by the race detector. - syncutil.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) - } else { - // This is ~40% faster for short reads than going through memmove. - val = *ptr - } - if sc.ReadOk(epoch) { - break - } - } - return val -} - -// SeqAtomicTryLoad returns a copy of *ptr while in a reader critical section -// in sc initiated by a call to sc.BeginRead() that returned epoch. If the read -// would race with a writer critical section, SeqAtomicTryLoad returns -// (unspecified, false). -func SeqAtomicTryLoad(sc *syncutil.SeqCount, epoch syncutil.SeqCountEpoch, ptr *Value) (Value, bool) { - var val Value - if syncutil.RaceEnabled { - syncutil.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) - } else { - val = *ptr - } - return val, sc.ReadOk(epoch) -} - -func init() { - var val Value - typ := reflect.TypeOf(val) - name := typ.Name() - if ptrs := syncutil.PointersInType(typ, name); len(ptrs) != 0 { - panic(fmt.Sprintf("SeqAtomicLoad<%s> is invalid since values %s of type %s contain pointers:\n%s", typ, name, typ, strings.Join(ptrs, "\n"))) - } -} diff --git a/pkg/syncutil/seqatomictest/BUILD b/pkg/syncutil/seqatomictest/BUILD deleted file mode 100644 index ba18f3238..000000000 --- a/pkg/syncutil/seqatomictest/BUILD +++ /dev/null @@ -1,35 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -package(licenses = ["notice"]) - -go_template_instance( - name = "seqatomic_int", - out = "seqatomic_int_unsafe.go", - package = "seqatomic", - suffix = "Int", - template = "//pkg/syncutil:generic_seqatomic", - types = { - "Value": "int", - }, -) - -go_library( - name = "seqatomic", - srcs = ["seqatomic_int_unsafe.go"], - importpath = "gvisor.dev/gvisor/pkg/syncutil/seqatomic", - deps = [ - "//pkg/syncutil", - ], -) - -go_test( - name = "seqatomic_test", - size = "small", - srcs = ["seqatomic_test.go"], - embed = [":seqatomic"], - deps = [ - "//pkg/syncutil", - ], -) diff --git a/pkg/syncutil/seqatomictest/seqatomic_test.go b/pkg/syncutil/seqatomictest/seqatomic_test.go deleted file mode 100644 index b0db44999..000000000 --- a/pkg/syncutil/seqatomictest/seqatomic_test.go +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package seqatomic - -import ( - "sync/atomic" - "testing" - "time" - - "gvisor.dev/gvisor/pkg/syncutil" -) - -func TestSeqAtomicLoadUncontended(t *testing.T) { - var seq syncutil.SeqCount - const want = 1 - data := want - if got := SeqAtomicLoadInt(&seq, &data); got != want { - t.Errorf("SeqAtomicLoadInt: got %v, wanted %v", got, want) - } -} - -func TestSeqAtomicLoadAfterWrite(t *testing.T) { - var seq syncutil.SeqCount - var data int - const want = 1 - seq.BeginWrite() - data = want - seq.EndWrite() - if got := SeqAtomicLoadInt(&seq, &data); got != want { - t.Errorf("SeqAtomicLoadInt: got %v, wanted %v", got, want) - } -} - -func TestSeqAtomicLoadDuringWrite(t *testing.T) { - var seq syncutil.SeqCount - var data int - const want = 1 - seq.BeginWrite() - go func() { - time.Sleep(time.Second) - data = want - seq.EndWrite() - }() - if got := SeqAtomicLoadInt(&seq, &data); got != want { - t.Errorf("SeqAtomicLoadInt: got %v, wanted %v", got, want) - } -} - -func TestSeqAtomicTryLoadUncontended(t *testing.T) { - var seq syncutil.SeqCount - const want = 1 - data := want - epoch := seq.BeginRead() - if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); !ok || got != want { - t.Errorf("SeqAtomicTryLoadInt: got (%v, %v), wanted (%v, true)", got, ok, want) - } -} - -func TestSeqAtomicTryLoadDuringWrite(t *testing.T) { - var seq syncutil.SeqCount - var data int - epoch := seq.BeginRead() - seq.BeginWrite() - if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); ok { - t.Errorf("SeqAtomicTryLoadInt: got (%v, true), wanted (_, false)", got) - } - seq.EndWrite() -} - -func TestSeqAtomicTryLoadAfterWrite(t *testing.T) { - var seq syncutil.SeqCount - var data int - epoch := seq.BeginRead() - seq.BeginWrite() - seq.EndWrite() - if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); ok { - t.Errorf("SeqAtomicTryLoadInt: got (%v, true), wanted (_, false)", got) - } -} - -func BenchmarkSeqAtomicLoadIntUncontended(b *testing.B) { - var seq syncutil.SeqCount - const want = 42 - data := want - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - if got := SeqAtomicLoadInt(&seq, &data); got != want { - b.Fatalf("SeqAtomicLoadInt: got %v, wanted %v", got, want) - } - } - }) -} - -func BenchmarkSeqAtomicTryLoadIntUncontended(b *testing.B) { - var seq syncutil.SeqCount - const want = 42 - data := want - b.RunParallel(func(pb *testing.PB) { - epoch := seq.BeginRead() - for pb.Next() { - if got, ok := SeqAtomicTryLoadInt(&seq, epoch, &data); !ok || got != want { - b.Fatalf("SeqAtomicTryLoadInt: got (%v, %v), wanted (%v, true)", got, ok, want) - } - } - }) -} - -// For comparison: -func BenchmarkAtomicValueLoadIntUncontended(b *testing.B) { - var a atomic.Value - const want = 42 - a.Store(int(want)) - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - if got := a.Load().(int); got != want { - b.Fatalf("atomic.Value.Load: got %v, wanted %v", got, want) - } - } - }) -} diff --git a/pkg/syncutil/seqcount.go b/pkg/syncutil/seqcount.go deleted file mode 100644 index 11d8dbfaa..000000000 --- a/pkg/syncutil/seqcount.go +++ /dev/null @@ -1,149 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package syncutil - -import ( - "fmt" - "reflect" - "runtime" - "sync/atomic" -) - -// SeqCount is a synchronization primitive for optimistic reader/writer -// synchronization in cases where readers can work with stale data and -// therefore do not need to block writers. -// -// Compared to sync/atomic.Value: -// -// - Mutation of SeqCount-protected data does not require memory allocation, -// whereas atomic.Value generally does. This is a significant advantage when -// writes are common. -// -// - Atomic reads of SeqCount-protected data require copying. This is a -// disadvantage when atomic reads are common. -// -// - SeqCount may be more flexible: correct use of SeqCount.ReadOk allows other -// operations to be made atomic with reads of SeqCount-protected data. -// -// - SeqCount may be less flexible: as of this writing, SeqCount-protected data -// cannot include pointers. -// -// - SeqCount is more cumbersome to use; atomic reads of SeqCount-protected -// data require instantiating function templates using go_generics (see -// seqatomic.go). -type SeqCount struct { - // epoch is incremented by BeginWrite and EndWrite, such that epoch is odd - // if a writer critical section is active, and a read from data protected - // by this SeqCount is atomic iff epoch is the same even value before and - // after the read. - epoch uint32 -} - -// SeqCountEpoch tracks writer critical sections in a SeqCount. -type SeqCountEpoch struct { - val uint32 -} - -// We assume that: -// -// - All functions in sync/atomic that perform a memory read are at least a -// read fence: memory reads before calls to such functions cannot be reordered -// after the call, and memory reads after calls to such functions cannot be -// reordered before the call, even if those reads do not use sync/atomic. -// -// - All functions in sync/atomic that perform a memory write are at least a -// write fence: memory writes before calls to such functions cannot be -// reordered after the call, and memory writes after calls to such functions -// cannot be reordered before the call, even if those writes do not use -// sync/atomic. -// -// As of this writing, the Go memory model completely fails to describe -// sync/atomic, but these properties are implied by -// https://groups.google.com/forum/#!topic/golang-nuts/7EnEhM3U7B8. - -// BeginRead indicates the beginning of a reader critical section. Reader -// critical sections DO NOT BLOCK writer critical sections, so operations in a -// reader critical section MAY RACE with writer critical sections. Races are -// detected by ReadOk at the end of the reader critical section. Thus, the -// low-level structure of readers is generally: -// -// for { -// epoch := seq.BeginRead() -// // do something idempotent with seq-protected data -// if seq.ReadOk(epoch) { -// break -// } -// } -// -// However, since reader critical sections may race with writer critical -// sections, the Go race detector will (accurately) flag data races in readers -// using this pattern. Most users of SeqCount will need to use the -// SeqAtomicLoad function template in seqatomic.go. -func (s *SeqCount) BeginRead() SeqCountEpoch { - epoch := atomic.LoadUint32(&s.epoch) - for epoch&1 != 0 { - runtime.Gosched() - epoch = atomic.LoadUint32(&s.epoch) - } - return SeqCountEpoch{epoch} -} - -// ReadOk returns true if the reader critical section initiated by a previous -// call to BeginRead() that returned epoch did not race with any writer critical -// sections. -// -// ReadOk may be called any number of times during a reader critical section. -// Reader critical sections do not need to be explicitly terminated; the last -// call to ReadOk is implicitly the end of the reader critical section. -func (s *SeqCount) ReadOk(epoch SeqCountEpoch) bool { - return atomic.LoadUint32(&s.epoch) == epoch.val -} - -// BeginWrite indicates the beginning of a writer critical section. -// -// SeqCount does not support concurrent writer critical sections; clients with -// concurrent writers must synchronize them using e.g. sync.Mutex. -func (s *SeqCount) BeginWrite() { - if epoch := atomic.AddUint32(&s.epoch, 1); epoch&1 == 0 { - panic("SeqCount.BeginWrite during writer critical section") - } -} - -// EndWrite ends the effect of a preceding BeginWrite. -func (s *SeqCount) EndWrite() { - if epoch := atomic.AddUint32(&s.epoch, 1); epoch&1 != 0 { - panic("SeqCount.EndWrite outside writer critical section") - } -} - -// PointersInType returns a list of pointers reachable from values named -// valName of the given type. -// -// PointersInType is not exhaustive, but it is guaranteed that if typ contains -// at least one pointer, then PointersInTypeOf returns a non-empty list. -func PointersInType(typ reflect.Type, valName string) []string { - switch kind := typ.Kind(); kind { - case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: - return nil - - case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice, reflect.String, reflect.UnsafePointer: - return []string{valName} - - case reflect.Array: - return PointersInType(typ.Elem(), valName+"[]") - - case reflect.Struct: - var ptrs []string - for i, n := 0, typ.NumField(); i < n; i++ { - field := typ.Field(i) - ptrs = append(ptrs, PointersInType(field.Type, fmt.Sprintf("%s.%s", valName, field.Name))...) - } - return ptrs - - default: - return []string{fmt.Sprintf("%s (of type %s with unknown kind %s)", valName, typ, kind)} - } -} diff --git a/pkg/syncutil/seqcount_test.go b/pkg/syncutil/seqcount_test.go deleted file mode 100644 index 14d6aedea..000000000 --- a/pkg/syncutil/seqcount_test.go +++ /dev/null @@ -1,153 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package syncutil - -import ( - "reflect" - "testing" - "time" -) - -func TestSeqCountWriteUncontended(t *testing.T) { - var seq SeqCount - seq.BeginWrite() - seq.EndWrite() -} - -func TestSeqCountReadUncontended(t *testing.T) { - var seq SeqCount - epoch := seq.BeginRead() - if !seq.ReadOk(epoch) { - t.Errorf("ReadOk: got false, wanted true") - } -} - -func TestSeqCountBeginReadAfterWrite(t *testing.T) { - var seq SeqCount - var data int32 - const want = 1 - seq.BeginWrite() - data = want - seq.EndWrite() - epoch := seq.BeginRead() - if data != want { - t.Errorf("Reader: got %v, wanted %v", data, want) - } - if !seq.ReadOk(epoch) { - t.Errorf("ReadOk: got false, wanted true") - } -} - -func TestSeqCountBeginReadDuringWrite(t *testing.T) { - var seq SeqCount - var data int - const want = 1 - seq.BeginWrite() - go func() { - time.Sleep(time.Second) - data = want - seq.EndWrite() - }() - epoch := seq.BeginRead() - if data != want { - t.Errorf("Reader: got %v, wanted %v", data, want) - } - if !seq.ReadOk(epoch) { - t.Errorf("ReadOk: got false, wanted true") - } -} - -func TestSeqCountReadOkAfterWrite(t *testing.T) { - var seq SeqCount - epoch := seq.BeginRead() - seq.BeginWrite() - seq.EndWrite() - if seq.ReadOk(epoch) { - t.Errorf("ReadOk: got true, wanted false") - } -} - -func TestSeqCountReadOkDuringWrite(t *testing.T) { - var seq SeqCount - epoch := seq.BeginRead() - seq.BeginWrite() - if seq.ReadOk(epoch) { - t.Errorf("ReadOk: got true, wanted false") - } - seq.EndWrite() -} - -func BenchmarkSeqCountWriteUncontended(b *testing.B) { - var seq SeqCount - for i := 0; i < b.N; i++ { - seq.BeginWrite() - seq.EndWrite() - } -} - -func BenchmarkSeqCountReadUncontended(b *testing.B) { - var seq SeqCount - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - epoch := seq.BeginRead() - if !seq.ReadOk(epoch) { - b.Fatalf("ReadOk: got false, wanted true") - } - } - }) -} - -func TestPointersInType(t *testing.T) { - for _, test := range []struct { - name string // used for both test and value name - val interface{} - ptrs []string - }{ - { - name: "EmptyStruct", - val: struct{}{}, - }, - { - name: "Int", - val: int(0), - }, - { - name: "MixedStruct", - val: struct { - b bool - I int - ExportedPtr *struct{} - unexportedPtr *struct{} - arr [2]int - ptrArr [2]*int - nestedStruct struct { - nestedNonptr int - nestedPtr *int - } - structArr [1]struct { - nonptr int - ptr *int - } - }{}, - ptrs: []string{ - "MixedStruct.ExportedPtr", - "MixedStruct.unexportedPtr", - "MixedStruct.ptrArr[]", - "MixedStruct.nestedStruct.nestedPtr", - "MixedStruct.structArr[].ptr", - }, - }, - } { - t.Run(test.name, func(t *testing.T) { - typ := reflect.TypeOf(test.val) - ptrs := PointersInType(typ, test.name) - t.Logf("Found pointers: %v", ptrs) - if (len(ptrs) != 0 || len(test.ptrs) != 0) && !reflect.DeepEqual(ptrs, test.ptrs) { - t.Errorf("Got %v, wanted %v", ptrs, test.ptrs) - } - }) - } -} diff --git a/pkg/syncutil/syncutil.go b/pkg/syncutil/syncutil.go deleted file mode 100644 index 66e750d06..000000000 --- a/pkg/syncutil/syncutil.go +++ /dev/null @@ -1,7 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package syncutil provides synchronization primitives. -package syncutil diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index e07ebd153..db06d02c6 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -15,6 +15,7 @@ go_library( importpath = "gvisor.dev/gvisor/pkg/tcpip", visibility = ["//visibility:public"], deps = [ + "//pkg/sync", "//pkg/tcpip/buffer", "//pkg/tcpip/iptables", "//pkg/waiter", diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD index 78df5a0b1..3df7d18d3 100644 --- a/pkg/tcpip/adapters/gonet/BUILD +++ b/pkg/tcpip/adapters/gonet/BUILD @@ -9,6 +9,7 @@ go_library( importpath = "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet", visibility = ["//visibility:public"], deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/stack", diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index cd6ce930a..a2f44b496 100644 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -20,9 +20,9 @@ import ( "errors" "io" "net" - "sync" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD index 897c94821..66cc53ed4 100644 --- a/pkg/tcpip/link/fdbased/BUILD +++ b/pkg/tcpip/link/fdbased/BUILD @@ -16,6 +16,7 @@ go_library( importpath = "gvisor.dev/gvisor/pkg/tcpip/link/fdbased", visibility = ["//visibility:public"], deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index fa8a703d9..b7f60178e 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -41,10 +41,10 @@ package fdbased import ( "fmt" - "sync" "syscall" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD index a4f9cdd69..09165dd4c 100644 --- a/pkg/tcpip/link/sharedmem/BUILD +++ b/pkg/tcpip/link/sharedmem/BUILD @@ -15,6 +15,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/log", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", @@ -31,6 +32,7 @@ go_test( ], embed = [":sharedmem"], deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", diff --git a/pkg/tcpip/link/sharedmem/pipe/BUILD b/pkg/tcpip/link/sharedmem/pipe/BUILD index 6b5bc542c..a0d4ad0be 100644 --- a/pkg/tcpip/link/sharedmem/pipe/BUILD +++ b/pkg/tcpip/link/sharedmem/pipe/BUILD @@ -21,4 +21,5 @@ go_test( "pipe_test.go", ], embed = [":pipe"], + deps = ["//pkg/sync"], ) diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go index 59ef69a8b..dc239a0d0 100644 --- a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go +++ b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go @@ -18,8 +18,9 @@ import ( "math/rand" "reflect" "runtime" - "sync" "testing" + + "gvisor.dev/gvisor/pkg/sync" ) func TestSimpleReadWrite(t *testing.T) { diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 080f9d667..655e537c4 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -23,11 +23,11 @@ package sharedmem import ( - "sync" "sync/atomic" "syscall" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 89603c48f..5c729a439 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -22,11 +22,11 @@ import ( "math/rand" "os" "strings" - "sync" "syscall" "testing" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD index acf1e022c..ed16076fd 100644 --- a/pkg/tcpip/network/fragmentation/BUILD +++ b/pkg/tcpip/network/fragmentation/BUILD @@ -28,6 +28,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/log", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", ], diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go index 6da5238ec..92f2aa13a 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/fragmentation/fragmentation.go @@ -19,9 +19,9 @@ package fragmentation import ( "fmt" "log" - "sync" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip/buffer" ) diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go index 9e002e396..0a83d81f2 100644 --- a/pkg/tcpip/network/fragmentation/reassembler.go +++ b/pkg/tcpip/network/fragmentation/reassembler.go @@ -18,9 +18,9 @@ import ( "container/heap" "fmt" "math" - "sync" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip/buffer" ) diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD index e156b01f6..a6ef3bdcc 100644 --- a/pkg/tcpip/ports/BUILD +++ b/pkg/tcpip/ports/BUILD @@ -9,6 +9,7 @@ go_library( importpath = "gvisor.dev/gvisor/pkg/tcpip/ports", visibility = ["//visibility:public"], deps = [ + "//pkg/sync", "//pkg/tcpip", ], ) diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go index 6c5e19e8f..b937cb84b 100644 --- a/pkg/tcpip/ports/ports.go +++ b/pkg/tcpip/ports/ports.go @@ -18,9 +18,9 @@ package ports import ( "math" "math/rand" - "sync" "sync/atomic" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" ) diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 826fca4de..6a8654105 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -36,6 +36,7 @@ go_library( "//pkg/ilist", "//pkg/rand", "//pkg/sleep", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/hash/jenkins", @@ -80,6 +81,7 @@ go_test( embed = [":stack"], deps = [ "//pkg/sleep", + "//pkg/sync", "//pkg/tcpip", ], ) diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index 267df60d1..403557fd7 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -16,10 +16,10 @@ package stack import ( "fmt" - "sync" "time" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" ) diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index 9946b8fe8..1baa498d0 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -16,12 +16,12 @@ package stack import ( "fmt" - "sync" "sync/atomic" "testing" "time" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" ) diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 3810c6602..fe557ccbd 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -16,9 +16,9 @@ package stack import ( "strings" - "sync" "sync/atomic" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 41bf9fd9b..a47ceba54 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -21,13 +21,13 @@ package stack import ( "encoding/binary" - "sync" "sync/atomic" "time" "golang.org/x/time/rate" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 67c21be42..f384a91de 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -18,8 +18,8 @@ import ( "fmt" "math/rand" "sort" - "sync" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 72b5ce179..4a090ac86 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -35,10 +35,10 @@ import ( "reflect" "strconv" "strings" - "sync" "sync/atomic" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/waiter" diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD index d8c5b5058..3aa23d529 100644 --- a/pkg/tcpip/transport/icmp/BUILD +++ b/pkg/tcpip/transport/icmp/BUILD @@ -28,6 +28,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/sleep", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index c7ce74cdd..330786f4c 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -15,8 +15,7 @@ package icmp import ( - "sync" - + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/transport/packet/BUILD b/pkg/tcpip/transport/packet/BUILD index 44b58ff6b..4858d150c 100644 --- a/pkg/tcpip/transport/packet/BUILD +++ b/pkg/tcpip/transport/packet/BUILD @@ -28,6 +28,7 @@ go_library( deps = [ "//pkg/log", "//pkg/sleep", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 07ffa8aba..fc5bc69fa 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -25,8 +25,7 @@ package packet import ( - "sync" - + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD index 00991ac8e..2f2131ff7 100644 --- a/pkg/tcpip/transport/raw/BUILD +++ b/pkg/tcpip/transport/raw/BUILD @@ -29,6 +29,7 @@ go_library( deps = [ "//pkg/log", "//pkg/sleep", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 85f7eb76b..ee9c4c58b 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -26,8 +26,7 @@ package raw import ( - "sync" - + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 3b353d56c..353bd06f4 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -48,6 +48,7 @@ go_library( "//pkg/log", "//pkg/rand", "//pkg/sleep", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/hash/jenkins", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 5422ae80c..1ea996936 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -19,11 +19,11 @@ import ( "encoding/binary" "hash" "io" - "sync" "time" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index cdd69f360..613ec1775 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -16,11 +16,11 @@ package tcp import ( "encoding/binary" - "sync" "time" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 830bc1e3e..cca511fb9 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -19,12 +19,12 @@ import ( "fmt" "math" "strings" - "sync" "sync/atomic" "time" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 7aa4c3f0e..4b8d867bc 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -16,9 +16,9 @@ package tcp import ( "fmt" - "sync" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 4983bca81..7eb613be5 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -15,8 +15,7 @@ package tcp import ( - "sync" - + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index bc718064c..9a8f64aa6 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -22,9 +22,9 @@ package tcp import ( "strings" - "sync" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go index e0759225e..bd20a7ee9 100644 --- a/pkg/tcpip/transport/tcp/segment_queue.go +++ b/pkg/tcpip/transport/tcp/segment_queue.go @@ -15,7 +15,7 @@ package tcp import ( - "sync" + "gvisor.dev/gvisor/pkg/sync" ) // segmentQueue is a bounded, thread-safe queue of TCP segments. diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 8a947dc66..79f2d274b 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -16,11 +16,11 @@ package tcp import ( "math" - "sync" "sync/atomic" "time" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index 97e4d5825..57ff123e3 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -30,6 +30,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/sleep", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 864dc8733..a4ff29a7d 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -15,8 +15,7 @@ package udp import ( - "sync" - + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" diff --git a/pkg/tmutex/BUILD b/pkg/tmutex/BUILD index 6afdb29b7..07778e4f7 100644 --- a/pkg/tmutex/BUILD +++ b/pkg/tmutex/BUILD @@ -15,4 +15,5 @@ go_test( size = "medium", srcs = ["tmutex_test.go"], embed = [":tmutex"], + deps = ["//pkg/sync"], ) diff --git a/pkg/tmutex/tmutex_test.go b/pkg/tmutex/tmutex_test.go index ce34c7962..05540696a 100644 --- a/pkg/tmutex/tmutex_test.go +++ b/pkg/tmutex/tmutex_test.go @@ -17,10 +17,11 @@ package tmutex import ( "fmt" "runtime" - "sync" "sync/atomic" "testing" "time" + + "gvisor.dev/gvisor/pkg/sync" ) func TestBasicLock(t *testing.T) { diff --git a/pkg/unet/BUILD b/pkg/unet/BUILD index 8f6f180e5..d1885ae66 100644 --- a/pkg/unet/BUILD +++ b/pkg/unet/BUILD @@ -24,4 +24,5 @@ go_test( "unet_test.go", ], embed = [":unet"], + deps = ["//pkg/sync"], ) diff --git a/pkg/unet/unet_test.go b/pkg/unet/unet_test.go index a3cc6f5d3..5c4b9e8e9 100644 --- a/pkg/unet/unet_test.go +++ b/pkg/unet/unet_test.go @@ -19,10 +19,11 @@ import ( "os" "path/filepath" "reflect" - "sync" "syscall" "testing" "time" + + "gvisor.dev/gvisor/pkg/sync" ) func randomFilename() (string, error) { diff --git a/pkg/urpc/BUILD b/pkg/urpc/BUILD index b6bbb0ea2..b8fdc3125 100644 --- a/pkg/urpc/BUILD +++ b/pkg/urpc/BUILD @@ -11,6 +11,7 @@ go_library( deps = [ "//pkg/fd", "//pkg/log", + "//pkg/sync", "//pkg/unet", ], ) diff --git a/pkg/urpc/urpc.go b/pkg/urpc/urpc.go index df59ffab1..13b2ea314 100644 --- a/pkg/urpc/urpc.go +++ b/pkg/urpc/urpc.go @@ -27,10 +27,10 @@ import ( "os" "reflect" "runtime" - "sync" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" ) diff --git a/pkg/waiter/BUILD b/pkg/waiter/BUILD index 0427bc41f..1c6890e52 100644 --- a/pkg/waiter/BUILD +++ b/pkg/waiter/BUILD @@ -24,6 +24,7 @@ go_library( ], importpath = "gvisor.dev/gvisor/pkg/waiter", visibility = ["//visibility:public"], + deps = ["//pkg/sync"], ) go_test( diff --git a/pkg/waiter/waiter.go b/pkg/waiter/waiter.go index 8a65ed164..f708e95fa 100644 --- a/pkg/waiter/waiter.go +++ b/pkg/waiter/waiter.go @@ -58,7 +58,7 @@ package waiter import ( - "sync" + "gvisor.dev/gvisor/pkg/sync" ) // EventMask represents io events as used in the poll() syscall. diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD index 6226b63f8..3e20f8f2f 100644 --- a/runsc/boot/BUILD +++ b/runsc/boot/BUILD @@ -74,6 +74,7 @@ go_library( "//pkg/sentry/usage", "//pkg/sentry/usermem", "//pkg/sentry/watchdog", + "//pkg/sync", "//pkg/syserror", "//pkg/tcpip", "//pkg/tcpip/link/fdbased", @@ -114,6 +115,7 @@ go_test( "//pkg/sentry/context/contexttest", "//pkg/sentry/fs", "//pkg/sentry/kernel/auth", + "//pkg/sync", "//pkg/unet", "//runsc/fsgofer", "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", diff --git a/runsc/boot/compat.go b/runsc/boot/compat.go index 352e710d2..9c23b9553 100644 --- a/runsc/boot/compat.go +++ b/runsc/boot/compat.go @@ -17,7 +17,6 @@ package boot import ( "fmt" "os" - "sync" "syscall" "github.com/golang/protobuf/proto" @@ -27,6 +26,7 @@ import ( ucspb "gvisor.dev/gvisor/pkg/sentry/kernel/uncaught_signal_go_proto" "gvisor.dev/gvisor/pkg/sentry/strace" spb "gvisor.dev/gvisor/pkg/sentry/unimpl/unimplemented_syscall_go_proto" + "gvisor.dev/gvisor/pkg/sync" ) func initCompatLogs(fd int) error { diff --git a/runsc/boot/limits.go b/runsc/boot/limits.go index d1c0bb9b5..ce62236e5 100644 --- a/runsc/boot/limits.go +++ b/runsc/boot/limits.go @@ -16,12 +16,12 @@ package boot import ( "fmt" - "sync" "syscall" specs "github.com/opencontainers/runtime-spec/specs-go" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/limits" + "gvisor.dev/gvisor/pkg/sync" ) // Mapping from linux resource names to limits.LimitType. diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index bc1d0c1bb..fad72f4ab 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -20,7 +20,6 @@ import ( mrand "math/rand" "os" "runtime" - "sync" "sync/atomic" "syscall" gtime "time" @@ -46,6 +45,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/watchdog" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/network/arp" diff --git a/runsc/boot/loader_test.go b/runsc/boot/loader_test.go index 147ff7703..bec0dc292 100644 --- a/runsc/boot/loader_test.go +++ b/runsc/boot/loader_test.go @@ -19,7 +19,6 @@ import ( "math/rand" "os" "reflect" - "sync" "syscall" "testing" "time" @@ -30,6 +29,7 @@ import ( "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/sentry/context/contexttest" "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" "gvisor.dev/gvisor/runsc/fsgofer" ) diff --git a/runsc/cmd/BUILD b/runsc/cmd/BUILD index 250845ad7..b94bc4fa0 100644 --- a/runsc/cmd/BUILD +++ b/runsc/cmd/BUILD @@ -44,6 +44,7 @@ go_library( "//pkg/sentry/control", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", + "//pkg/sync", "//pkg/unet", "//pkg/urpc", "//runsc/boot", diff --git a/runsc/cmd/create.go b/runsc/cmd/create.go index a4e3071b3..1815c93b9 100644 --- a/runsc/cmd/create.go +++ b/runsc/cmd/create.go @@ -16,6 +16,7 @@ package cmd import ( "context" + "flag" "github.com/google/subcommands" "gvisor.dev/gvisor/runsc/boot" diff --git a/runsc/cmd/gofer.go b/runsc/cmd/gofer.go index 4831210c0..7df7995f0 100644 --- a/runsc/cmd/gofer.go +++ b/runsc/cmd/gofer.go @@ -21,7 +21,6 @@ import ( "os" "path/filepath" "strings" - "sync" "syscall" "flag" @@ -30,6 +29,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/p9" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" "gvisor.dev/gvisor/runsc/boot" "gvisor.dev/gvisor/runsc/fsgofer" diff --git a/runsc/cmd/start.go b/runsc/cmd/start.go index de2115dff..5e9bc53ab 100644 --- a/runsc/cmd/start.go +++ b/runsc/cmd/start.go @@ -16,6 +16,7 @@ package cmd import ( "context" + "flag" "github.com/google/subcommands" "gvisor.dev/gvisor/runsc/boot" diff --git a/runsc/container/BUILD b/runsc/container/BUILD index 2bd12120d..6dea179e4 100644 --- a/runsc/container/BUILD +++ b/runsc/container/BUILD @@ -18,6 +18,7 @@ go_library( deps = [ "//pkg/log", "//pkg/sentry/control", + "//pkg/sync", "//runsc/boot", "//runsc/cgroup", "//runsc/sandbox", @@ -53,6 +54,7 @@ go_test( "//pkg/sentry/control", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", + "//pkg/sync", "//pkg/unet", "//pkg/urpc", "//runsc/boot", diff --git a/runsc/container/console_test.go b/runsc/container/console_test.go index 5ed131a7f..060b63bf3 100644 --- a/runsc/container/console_test.go +++ b/runsc/container/console_test.go @@ -20,7 +20,6 @@ import ( "io" "os" "path/filepath" - "sync" "syscall" "testing" "time" @@ -29,6 +28,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/sentry/control" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" "gvisor.dev/gvisor/pkg/urpc" "gvisor.dev/gvisor/runsc/testutil" diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go index c10f85992..b54d8f712 100644 --- a/runsc/container/container_test.go +++ b/runsc/container/container_test.go @@ -26,7 +26,6 @@ import ( "reflect" "strconv" "strings" - "sync" "syscall" "testing" "time" @@ -39,6 +38,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/control" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/runsc/boot" "gvisor.dev/gvisor/runsc/boot/platforms" "gvisor.dev/gvisor/runsc/specutils" diff --git a/runsc/container/multi_container_test.go b/runsc/container/multi_container_test.go index 4ad09ceab..2da93ec5b 100644 --- a/runsc/container/multi_container_test.go +++ b/runsc/container/multi_container_test.go @@ -22,7 +22,6 @@ import ( "path" "path/filepath" "strings" - "sync" "syscall" "testing" "time" @@ -30,6 +29,7 @@ import ( specs "github.com/opencontainers/runtime-spec/specs-go" "gvisor.dev/gvisor/pkg/sentry/control" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/runsc/boot" "gvisor.dev/gvisor/runsc/specutils" "gvisor.dev/gvisor/runsc/testutil" diff --git a/runsc/container/state_file.go b/runsc/container/state_file.go index d95151ea5..17a251530 100644 --- a/runsc/container/state_file.go +++ b/runsc/container/state_file.go @@ -20,10 +20,10 @@ import ( "io/ioutil" "os" "path/filepath" - "sync" "github.com/gofrs/flock" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" ) const stateFileExtension = ".state" diff --git a/runsc/fsgofer/BUILD b/runsc/fsgofer/BUILD index afcb41801..a9582d92b 100644 --- a/runsc/fsgofer/BUILD +++ b/runsc/fsgofer/BUILD @@ -19,6 +19,7 @@ go_library( "//pkg/fd", "//pkg/log", "//pkg/p9", + "//pkg/sync", "//pkg/syserr", "//runsc/specutils", "@org_golang_x_sys//unix:go_default_library", diff --git a/runsc/fsgofer/fsgofer.go b/runsc/fsgofer/fsgofer.go index b59e1a70e..93606d051 100644 --- a/runsc/fsgofer/fsgofer.go +++ b/runsc/fsgofer/fsgofer.go @@ -29,7 +29,6 @@ import ( "path/filepath" "runtime" "strconv" - "sync" "syscall" "golang.org/x/sys/unix" @@ -37,6 +36,7 @@ import ( "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/p9" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/runsc/specutils" ) diff --git a/runsc/sandbox/BUILD b/runsc/sandbox/BUILD index 8001949d5..ddbc37456 100644 --- a/runsc/sandbox/BUILD +++ b/runsc/sandbox/BUILD @@ -19,6 +19,7 @@ go_library( "//pkg/log", "//pkg/sentry/control", "//pkg/sentry/platform", + "//pkg/sync", "//pkg/tcpip/header", "//pkg/tcpip/stack", "//pkg/urpc", diff --git a/runsc/sandbox/sandbox.go b/runsc/sandbox/sandbox.go index ce1452b87..ec72bdbfd 100644 --- a/runsc/sandbox/sandbox.go +++ b/runsc/sandbox/sandbox.go @@ -22,7 +22,6 @@ import ( "os" "os/exec" "strconv" - "sync" "syscall" "time" @@ -34,6 +33,7 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/control" "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/urpc" "gvisor.dev/gvisor/runsc/boot" "gvisor.dev/gvisor/runsc/boot/platforms" diff --git a/runsc/testutil/BUILD b/runsc/testutil/BUILD index c96ca2eb6..3c3027cb5 100644 --- a/runsc/testutil/BUILD +++ b/runsc/testutil/BUILD @@ -10,6 +10,7 @@ go_library( visibility = ["//:sandbox"], deps = [ "//pkg/log", + "//pkg/sync", "//runsc/boot", "//runsc/specutils", "@com_github_cenkalti_backoff//:go_default_library", diff --git a/runsc/testutil/testutil.go b/runsc/testutil/testutil.go index 9632776d2..fb22eae39 100644 --- a/runsc/testutil/testutil.go +++ b/runsc/testutil/testutil.go @@ -34,7 +34,6 @@ import ( "path/filepath" "strconv" "strings" - "sync" "sync/atomic" "syscall" "time" @@ -42,6 +41,7 @@ import ( "github.com/cenkalti/backoff" specs "github.com/opencontainers/runtime-spec/specs-go" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/runsc/boot" "gvisor.dev/gvisor/runsc/specutils" ) -- cgit v1.2.3 From d27208463e93c01d4e39c0450c3b27c00c466728 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Fri, 10 Jan 2020 14:47:08 -0800 Subject: Automated rollback of changelist 288990597 PiperOrigin-RevId: 289169518 --- pkg/tcpip/stack/BUILD | 23 ++---------- pkg/tcpip/stack/ndp_test.go | 87 ++++--------------------------------------- pkg/tcpip/stack/stack_test.go | 74 +++++++++++++++++++++++++++++++++--- 3 files changed, 78 insertions(+), 106 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 6a8654105..705e984c1 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -53,6 +53,7 @@ go_test( name = "stack_x_test", size = "small", srcs = [ + "ndp_test.go", "stack_test.go", "transport_demuxer_test.go", "transport_test.go", @@ -62,12 +63,14 @@ go_test( "//pkg/rand", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/checker", "//pkg/tcpip/header", "//pkg/tcpip/iptables", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", + "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/udp", "//pkg/waiter", "@com_github_google_go-cmp//cmp:go_default_library", @@ -85,23 +88,3 @@ go_test( "//pkg/tcpip", ], ) - -go_test( - name = "ndp_test", - size = "small", - srcs = ["ndp_test.go"], - deps = [ - ":stack", - "//pkg/rand", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/tcpip/checker", - "//pkg/tcpip/header", - "//pkg/tcpip/link/channel", - "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/transport/icmp", - "//pkg/tcpip/transport/udp", - "//pkg/waiter", - "@com_github_google_go-cmp//cmp:go_default_library", - ], -) diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 108762b6e..f9bc18c55 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package ndp_test +package stack_test import ( "encoding/binary" @@ -301,8 +301,6 @@ func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration s // Included in the subtests is a test to make sure that an invalid // RetransmitTimer (<1ms) values get fixed to the default RetransmitTimer of 1s. func TestDADResolve(t *testing.T) { - t.Parallel() - tests := []struct { name string dupAddrDetectTransmits uint8 @@ -435,8 +433,6 @@ func TestDADResolve(t *testing.T) { // a node doing DAD for the same address), or if another node is detected to own // the address already (receive an NA message for the tentative address). func TestDADFail(t *testing.T) { - t.Parallel() - tests := []struct { name string makeBuf func(tgt tcpip.Address) buffer.Prependable @@ -580,8 +576,6 @@ func TestDADFail(t *testing.T) { // TestDADStop tests to make sure that the DAD process stops when an address is // removed. func TestDADStop(t *testing.T) { - t.Parallel() - ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent), } @@ -654,71 +648,6 @@ func TestDADStop(t *testing.T) { } } -// TestNICAutoGenAddrDoesDAD tests that the successful auto-generation of IPv6 -// link-local addresses will only be assigned after the DAD process resolves. -func TestNICAutoGenAddrDoesDAD(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent), - } - ndpConfigs := stack.DefaultNDPConfigurations() - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: ndpConfigs, - AutoGenIPv6LinkLocal: true, - NDPDisp: &ndpDisp, - } - - e := channel.New(0, 1280, linkAddr1) - s := stack.New(opts) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - - // Address should not be considered bound to the - // NIC yet (DAD ongoing). - addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) - } - - linkLocalAddr := header.LinkLocalAddr(linkAddr1) - - // Wait for DAD to resolve. - select { - case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second): - // We should get a resolution event after 1s (default time to - // resolve as per default NDP configurations). Waiting for that - // resolution time + an extra 1s without a resolution event - // means something is wrong. - t.Fatal("timed out waiting for DAD resolution") - case e := <-ndpDisp.dadC: - if e.err != nil { - t.Fatal("got DAD error: ", e.err) - } - if e.nicID != 1 { - t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID) - } - if e.addr != linkLocalAddr { - t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, linkLocalAddr) - } - if !e.resolved { - t.Fatal("got DAD event w/ resolved = false, want = true") - } - } - addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) - } - if want := (tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want) - } -} - // TestSetNDPConfigurationFailsForBadNICID tests to make sure we get an error if // we attempt to update NDP configurations using an invalid NICID. func TestSetNDPConfigurationFailsForBadNICID(t *testing.T) { @@ -736,8 +665,6 @@ func TestSetNDPConfigurationFailsForBadNICID(t *testing.T) { // configurations without affecting the default NDP configurations or other // interfaces' configurations. func TestSetNDPConfigurations(t *testing.T) { - t.Parallel() - tests := []struct { name string dupAddrDetectTransmits uint8 @@ -992,8 +919,6 @@ func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, on // TestNoRouterDiscovery tests that router discovery will not be performed if // configured not to. func TestNoRouterDiscovery(t *testing.T) { - t.Parallel() - // Being configured to discover routers means handle and // discover are set to true and forwarding is set to false. // This tests all possible combinations of the configurations, @@ -1006,6 +931,8 @@ func TestNoRouterDiscovery(t *testing.T) { forwarding := i&4 == 0 t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverDefaultRouters(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) { + t.Parallel() + ndpDisp := ndpDispatcher{ routerC: make(chan ndpRouterEvent, 1), } @@ -1240,8 +1167,6 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { // TestNoPrefixDiscovery tests that prefix discovery will not be performed if // configured not to. func TestNoPrefixDiscovery(t *testing.T) { - t.Parallel() - prefix := tcpip.AddressWithPrefix{ Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"), PrefixLen: 64, @@ -1259,6 +1184,8 @@ func TestNoPrefixDiscovery(t *testing.T) { forwarding := i&4 == 0 t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverOnLinkPrefixes(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) { + t.Parallel() + ndpDisp := ndpDispatcher{ prefixC: make(chan ndpPrefixEvent, 1), } @@ -1615,8 +1542,6 @@ func contains(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) bool { // TestNoAutoGenAddr tests that SLAAC is not performed when configured not to. func TestNoAutoGenAddr(t *testing.T) { - t.Parallel() - prefix, _, _ := prefixSubnetAddr(0, "") // Being configured to auto-generate addresses means handle and @@ -1631,6 +1556,8 @@ func TestNoAutoGenAddr(t *testing.T) { forwarding := i&4 == 0 t.Run(fmt.Sprintf("HandleRAs(%t), AutoGenAddr(%t), Forwarding(%t)", handle, autogen, forwarding), func(t *testing.T) { + t.Parallel() + ndpDisp := ndpDispatcher{ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index e8de4e87d..44e5229cc 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -24,6 +24,7 @@ import ( "sort" "strings" "testing" + "time" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/rand" @@ -49,8 +50,6 @@ const ( // where another value is explicitly used. It is chosen to match the MTU // of loopback interfaces on linux systems. defaultMTU = 65536 - - linkAddr = "\x02\x02\x03\x04\x05\x06" ) // fakeNetworkEndpoint is a network-layer protocol endpoint. It counts sent and @@ -1910,7 +1909,7 @@ func TestNICAutoGenAddr(t *testing.T) { { "Disabled", false, - linkAddr, + linkAddr1, stack.OpaqueInterfaceIdentifierOptions{ NICNameFromID: func(nicID tcpip.NICID, _ string) string { return fmt.Sprintf("nic%d", nicID) @@ -1921,7 +1920,7 @@ func TestNICAutoGenAddr(t *testing.T) { { "Enabled", true, - linkAddr, + linkAddr1, stack.OpaqueInterfaceIdentifierOptions{}, true, }, @@ -2069,14 +2068,14 @@ func TestNICAutoGenAddrWithOpaque(t *testing.T) { name: "Disabled", nicName: "nic1", autoGen: false, - linkAddr: linkAddr, + linkAddr: linkAddr1, secretKey: secretKey[:], }, { name: "Enabled", nicName: "nic1", autoGen: true, - linkAddr: linkAddr, + linkAddr: linkAddr1, secretKey: secretKey[:], }, // These are all cases where we would not have generated a @@ -2214,6 +2213,69 @@ func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) { } } +// TestNICAutoGenAddrDoesDAD tests that the successful auto-generation of IPv6 +// link-local addresses will only be assigned after the DAD process resolves. +func TestNICAutoGenAddrDoesDAD(t *testing.T) { + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent), + } + ndpConfigs := stack.DefaultNDPConfigurations() + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: ndpConfigs, + AutoGenIPv6LinkLocal: true, + NDPDisp: &ndpDisp, + } + + e := channel.New(10, 1280, linkAddr1) + s := stack.New(opts) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } + + // Address should not be considered bound to the + // NIC yet (DAD ongoing). + addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + } + + linkLocalAddr := header.LinkLocalAddr(linkAddr1) + + // Wait for DAD to resolve. + select { + case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second): + // We should get a resolution event after 1s (default time to + // resolve as per default NDP configurations). Waiting for that + // resolution time + an extra 1s without a resolution event + // means something is wrong. + t.Fatal("timed out waiting for DAD resolution") + case e := <-ndpDisp.dadC: + if e.err != nil { + t.Fatal("got DAD error: ", e.err) + } + if e.nicID != 1 { + t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID) + } + if e.addr != linkLocalAddr { + t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, linkLocalAddr) + } + if !e.resolved { + t.Fatal("got DAD event w/ resolved = false, want = true") + } + } + addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) + } + if want := (tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want) + } +} + // TestNewPEB tests that a new PrimaryEndpointBehavior value (peb) is respected // when an address's kind gets "promoted" to permanent from permanentExpired. func TestNewPEBOnPromotionToPermanent(t *testing.T) { -- cgit v1.2.3 From debd213da61cf35d7c91346820e93fc87bfa5896 Mon Sep 17 00:00:00 2001 From: Tamir Duberstein Date: Mon, 13 Jan 2020 14:45:31 -0800 Subject: Allow dual stack sockets to operate on AF_INET Fixes #1490 Fixes #1495 PiperOrigin-RevId: 289523250 --- pkg/sentry/socket/netstack/netstack.go | 65 +++++++++--- pkg/sentry/socket/unix/unix.go | 5 +- pkg/sentry/strace/socket.go | 2 +- pkg/tcpip/stack/stack.go | 43 ++++++++ pkg/tcpip/transport/icmp/endpoint.go | 22 ++-- pkg/tcpip/transport/tcp/endpoint.go | 23 +--- pkg/tcpip/transport/udp/endpoint.go | 41 ++------ scripts/common.sh | 2 +- test/syscalls/linux/BUILD | 1 + test/syscalls/linux/socket_inet_loopback.cc | 156 ++++++++++++++++++++++++++++ 10 files changed, 278 insertions(+), 82 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 099319327..c020c11cb 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -324,22 +324,15 @@ func bytesToIPAddress(addr []byte) tcpip.Address { // converts it to the FullAddress format. It supports AF_UNIX, AF_INET, // AF_INET6, and AF_PACKET addresses. // -// strict indicates whether addresses with the AF_UNSPEC family are accepted of not. -// // AddressAndFamily returns an address and its family. -func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, uint16, *syserr.Error) { +func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { // Make sure we have at least 2 bytes for the address family. if len(addr) < 2 { return tcpip.FullAddress{}, 0, syserr.ErrInvalidArgument } - family := usermem.ByteOrder.Uint16(addr) - if family != uint16(sfamily) && (strict || family != linux.AF_UNSPEC) { - return tcpip.FullAddress{}, family, syserr.ErrAddressFamilyNotSupported - } - // Get the rest of the fields based on the address family. - switch family { + switch family := usermem.ByteOrder.Uint16(addr); family { case linux.AF_UNIX: path := addr[2:] if len(path) > linux.UnixPathMax { @@ -638,10 +631,40 @@ func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask { return r } +func (s *SocketOperations) checkFamily(family uint16, exact bool) *syserr.Error { + if family == uint16(s.family) { + return nil + } + if !exact && family == linux.AF_INET && s.family == linux.AF_INET6 { + v, err := s.Endpoint.GetSockOptBool(tcpip.V6OnlyOption) + if err != nil { + return syserr.TranslateNetstackError(err) + } + if !v { + return nil + } + } + return syserr.ErrInvalidArgument +} + +// mapFamily maps the AF_INET ANY address to the IPv4-mapped IPv6 ANY if the +// receiver's family is AF_INET6. +// +// This is a hack to work around the fact that both IPv4 and IPv6 ANY are +// represented by the empty string. +// +// TODO(gvisor.dev/issues/1556): remove this function. +func (s *SocketOperations) mapFamily(addr tcpip.FullAddress, family uint16) tcpip.FullAddress { + if len(addr.Addr) == 0 && s.family == linux.AF_INET6 && family == linux.AF_INET { + addr.Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00" + } + return addr +} + // Connect implements the linux syscall connect(2) for sockets backed by // tpcip.Endpoint. func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { - addr, family, err := AddressAndFamily(s.family, sockaddr, false /* strict */) + addr, family, err := AddressAndFamily(sockaddr) if err != nil { return err } @@ -653,6 +676,12 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo } return syserr.TranslateNetstackError(err) } + + if err := s.checkFamily(family, false /* exact */); err != nil { + return err + } + addr = s.mapFamily(addr, family) + // Always return right away in the non-blocking case. if !blocking { return syserr.TranslateNetstackError(s.Endpoint.Connect(addr)) @@ -681,10 +710,14 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo // Bind implements the linux syscall bind(2) for sockets backed by // tcpip.Endpoint. func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { - addr, _, err := AddressAndFamily(s.family, sockaddr, true /* strict */) + addr, family, err := AddressAndFamily(sockaddr) if err != nil { return err } + if err := s.checkFamily(family, true /* exact */); err != nil { + return err + } + addr = s.mapFamily(addr, family) // Issue the bind request to the endpoint. return syserr.TranslateNetstackError(s.Endpoint.Bind(addr)) @@ -2080,8 +2113,8 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) case linux.AF_INET6: var out linux.SockAddrInet6 - if len(addr.Addr) == 4 { - // Copy address is v4-mapped format. + if len(addr.Addr) == header.IPv4AddressSize { + // Copy address in v4-mapped format. copy(out.Addr[12:], addr.Addr) out.Addr[10] = 0xff out.Addr[11] = 0xff @@ -2395,10 +2428,14 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] var addr *tcpip.FullAddress if len(to) > 0 { - addrBuf, _, err := AddressAndFamily(s.family, to, true /* strict */) + addrBuf, family, err := AddressAndFamily(to) if err != nil { return 0, err } + if err := s.checkFamily(family, false /* exact */); err != nil { + return 0, err + } + addrBuf = s.mapFamily(addrBuf, family) addr = &addrBuf } diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 91effe89a..7f49ba864 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -116,13 +116,16 @@ func (s *SocketOperations) Endpoint() transport.Endpoint { // extractPath extracts and validates the address. func extractPath(sockaddr []byte) (string, *syserr.Error) { - addr, _, err := netstack.AddressAndFamily(linux.AF_UNIX, sockaddr, true /* strict */) + addr, family, err := netstack.AddressAndFamily(sockaddr) if err != nil { if err == syserr.ErrAddressFamilyNotSupported { err = syserr.ErrInvalidArgument } return "", err } + if family != linux.AF_UNIX { + return "", syserr.ErrInvalidArgument + } // The address is trimmed by GetAddress. p := string(addr.Addr) diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go index 51f2efb39..b6d7177f4 100644 --- a/pkg/sentry/strace/socket.go +++ b/pkg/sentry/strace/socket.go @@ -341,7 +341,7 @@ func sockAddr(t *kernel.Task, addr usermem.Addr, length uint32) string { switch family { case linux.AF_INET, linux.AF_INET6, linux.AF_UNIX: - fa, _, err := netstack.AddressAndFamily(int(family), b, true /* strict */) + fa, _, err := netstack.AddressAndFamily(b) if err != nil { return fmt.Sprintf("%#x {Family: %s, error extracting address: %v}", addr, familyStr, err) } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index a47ceba54..113b457fb 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -547,6 +547,49 @@ type TransportEndpointInfo struct { RegisterNICID tcpip.NICID } +// AddrNetProto unwraps the specified address if it is a V4-mapped V6 address +// and returns the network protocol number to be used to communicate with the +// specified address. It returns an error if the passed address is incompatible +// with the receiver. +func (e *TransportEndpointInfo) AddrNetProto(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { + netProto := e.NetProto + switch len(addr.Addr) { + case header.IPv4AddressSize: + netProto = header.IPv4ProtocolNumber + case header.IPv6AddressSize: + if header.IsV4MappedAddress(addr.Addr) { + netProto = header.IPv4ProtocolNumber + addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:] + if addr.Addr == header.IPv4Any { + addr.Addr = "" + } + } + } + + switch len(e.ID.LocalAddress) { + case header.IPv4AddressSize: + if len(addr.Addr) == header.IPv6AddressSize { + return tcpip.FullAddress{}, 0, tcpip.ErrInvalidEndpointState + } + case header.IPv6AddressSize: + if len(addr.Addr) == header.IPv4AddressSize { + return tcpip.FullAddress{}, 0, tcpip.ErrNetworkUnreachable + } + } + + switch { + case netProto == e.NetProto: + case netProto == header.IPv4ProtocolNumber && e.NetProto == header.IPv6ProtocolNumber: + if v6only { + return tcpip.FullAddress{}, 0, tcpip.ErrNoRoute + } + default: + return tcpip.FullAddress{}, 0, tcpip.ErrInvalidEndpointState + } + + return addr, netProto, nil +} + // IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo // marker interface. func (*TransportEndpointInfo) IsEndpointInfo() {} diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 330786f4c..42afb3f5b 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -288,7 +288,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c toCopy := *to to = &toCopy - netProto, err := e.checkV4Mapped(to, true) + netProto, err := e.checkV4Mapped(to) if err != nil { return 0, nil, err } @@ -475,18 +475,12 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err }) } -func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) { - netProto := e.NetProto - if header.IsV4MappedAddress(addr.Addr) { - return 0, tcpip.ErrNoRoute - } - - // Fail if we're bound to an address length different from the one we're - // checking. - if l := len(e.ID.LocalAddress); !allowMismatch && l != 0 && l != len(addr.Addr) { - return 0, tcpip.ErrInvalidEndpointState +func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) { + unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProto(*addr, false /* v6only */) + if err != nil { + return 0, err } - + *addr = unwrapped return netProto, nil } @@ -518,7 +512,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - netProto, err := e.checkV4Mapped(&addr, false) + netProto, err := e.checkV4Mapped(&addr) if err != nil { return err } @@ -631,7 +625,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - netProto, err := e.checkV4Mapped(&addr, false) + netProto, err := e.checkV4Mapped(&addr) if err != nil { return err } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index cca511fb9..cc8b533c8 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1691,26 +1691,11 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) { - netProto := e.NetProto - if header.IsV4MappedAddress(addr.Addr) { - // Fail if using a v4 mapped address on a v6only endpoint. - if e.v6only { - return 0, tcpip.ErrNoRoute - } - - netProto = header.IPv4ProtocolNumber - addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:] - if addr.Addr == header.IPv4Any { - addr.Addr = "" - } - } - - // Fail if we're bound to an address length different from the one we're - // checking. - if l := len(e.ID.LocalAddress); l != 0 && len(addr.Addr) != 0 && l != len(addr.Addr) { - return 0, tcpip.ErrInvalidEndpointState + unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProto(*addr, e.v6only) + if err != nil { + return 0, err } - + *addr = unwrapped return netProto, nil } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index a4ff29a7d..13446f5d9 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -402,7 +402,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return 0, nil, tcpip.ErrBroadcastDisabled } - netProto, err := e.checkV4Mapped(to, false) + netProto, err := e.checkV4Mapped(to) if err != nil { return 0, nil, err } @@ -501,7 +501,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { defer e.mu.Unlock() fa := tcpip.FullAddress{Addr: v.InterfaceAddr} - netProto, err := e.checkV4Mapped(&fa, false) + netProto, err := e.checkV4Mapped(&fa) if err != nil { return err } @@ -839,35 +839,12 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u return nil } -func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) { - netProto := e.NetProto - if len(addr.Addr) == 0 { - return netProto, nil - } - if header.IsV4MappedAddress(addr.Addr) { - // Fail if using a v4 mapped address on a v6only endpoint. - if e.v6only { - return 0, tcpip.ErrNoRoute - } - - netProto = header.IPv4ProtocolNumber - addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:] - if addr.Addr == header.IPv4Any { - addr.Addr = "" - } - - // Fail if we are bound to an IPv6 address. - if !allowMismatch && len(e.ID.LocalAddress) == 16 { - return 0, tcpip.ErrNetworkUnreachable - } - } - - // Fail if we're bound to an address length different from the one we're - // checking. - if l := len(e.ID.LocalAddress); l != 0 && l != len(addr.Addr) { - return 0, tcpip.ErrInvalidEndpointState +func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) { + unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProto(*addr, e.v6only) + if err != nil { + return 0, err } - + *addr = unwrapped return netProto, nil } @@ -916,7 +893,7 @@ func (e *endpoint) Disconnect() *tcpip.Error { // Connect connects the endpoint to its peer. Specifying a NIC is optional. func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { - netProto, err := e.checkV4Mapped(&addr, false) + netProto, err := e.checkV4Mapped(&addr) if err != nil { return err } @@ -1074,7 +1051,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - netProto, err := e.checkV4Mapped(&addr, true) + netProto, err := e.checkV4Mapped(&addr) if err != nil { return err } diff --git a/scripts/common.sh b/scripts/common.sh index 6dabad141..fdb1aa142 100755 --- a/scripts/common.sh +++ b/scripts/common.sh @@ -73,7 +73,7 @@ function install_runsc() { sudo "${RUNSC_BIN}" install --experimental=true --runtime="${runtime}" -- --debug-log "${RUNSC_LOGS}" "$@" # Clear old logs files that may exist. - sudo rm -f "${RUNSC_LOGS_DIR}"/* + sudo rm -f "${RUNSC_LOGS_DIR}"/'*' # Restart docker to pick up the new runtime configuration. sudo systemctl restart docker diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index ce8abe217..4c7ec3f06 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -2693,6 +2693,7 @@ cc_binary( srcs = ["socket_inet_loopback.cc"], linkstatic = 1, deps = [ + ":ip_socket_test_util", ":socket_test_util", "//test/util:file_descriptor", "//test/util:posix_error", diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index 138024d9e..5d114d460 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -32,6 +32,7 @@ #include "absl/strings/str_cat.h" #include "absl/time/clock.h" #include "absl/time/time.h" +#include "test/syscalls/linux/ip_socket_test_util.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/file_descriptor.h" #include "test/util/posix_error.h" @@ -102,6 +103,161 @@ TEST(BadSocketPairArgs, ValidateErrForBadCallsToSocketPair) { SyscallFailsWithErrno(EAFNOSUPPORT)); } +enum class Operation { + Bind, + Connect, + SendTo, +}; + +std::string OperationToString(Operation operation) { + switch (operation) { + case Operation::Bind: + return "Bind"; + case Operation::Connect: + return "Connect"; + case Operation::SendTo: + return "SendTo"; + } +} + +using OperationSequence = std::vector; + +using DualStackSocketTest = + ::testing::TestWithParam>; + +TEST_P(DualStackSocketTest, AddressOperations) { + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET6, SOCK_DGRAM, 0)); + + const TestAddress& addr = std::get<0>(GetParam()); + const OperationSequence& operations = std::get<1>(GetParam()); + + auto addr_in = reinterpret_cast(&addr.addr); + + // sockets may only be bound once. Both `connect` and `sendto` cause a socket + // to be bound. + bool bound = false; + for (const Operation& operation : operations) { + bool sockname = false; + bool peername = false; + switch (operation) { + case Operation::Bind: { + ASSERT_NO_ERRNO(SetAddrPort( + addr.family(), const_cast(&addr.addr), 0)); + + int bind_ret = bind(fd.get(), addr_in, addr.addr_len); + + // Dual stack sockets may only be bound to AF_INET6. + if (!bound && addr.family() == AF_INET6) { + EXPECT_THAT(bind_ret, SyscallSucceeds()); + bound = true; + + sockname = true; + } else { + EXPECT_THAT(bind_ret, SyscallFailsWithErrno(EINVAL)); + } + break; + } + case Operation::Connect: { + ASSERT_NO_ERRNO(SetAddrPort( + addr.family(), const_cast(&addr.addr), 1337)); + + EXPECT_THAT(connect(fd.get(), addr_in, addr.addr_len), + SyscallSucceeds()) + << GetAddrStr(addr_in); + bound = true; + + sockname = true; + peername = true; + + break; + } + case Operation::SendTo: { + const char payload[] = "hello"; + ASSERT_NO_ERRNO(SetAddrPort( + addr.family(), const_cast(&addr.addr), 1337)); + + ssize_t sendto_ret = sendto(fd.get(), &payload, sizeof(payload), 0, + addr_in, addr.addr_len); + + EXPECT_THAT(sendto_ret, SyscallSucceedsWithValue(sizeof(payload))); + sockname = !bound; + bound = true; + break; + } + } + + if (sockname) { + sockaddr_storage sock_addr; + socklen_t addrlen = sizeof(sock_addr); + ASSERT_THAT(getsockname(fd.get(), reinterpret_cast(&sock_addr), + &addrlen), + SyscallSucceeds()); + ASSERT_EQ(addrlen, sizeof(struct sockaddr_in6)); + + auto sock_addr_in6 = reinterpret_cast(&sock_addr); + + if (operation == Operation::SendTo) { + EXPECT_EQ(sock_addr_in6->sin6_family, AF_INET6); + EXPECT_TRUE(IN6_IS_ADDR_UNSPECIFIED(sock_addr_in6->sin6_addr.s6_addr32)) + << OperationToString(operation) << " getsocknam=" + << GetAddrStr(reinterpret_cast(&sock_addr)); + + EXPECT_NE(sock_addr_in6->sin6_port, 0); + } else if (IN6_IS_ADDR_V4MAPPED( + reinterpret_cast(addr_in) + ->sin6_addr.s6_addr32)) { + EXPECT_TRUE(IN6_IS_ADDR_V4MAPPED(sock_addr_in6->sin6_addr.s6_addr32)) + << OperationToString(operation) << " getsocknam=" + << GetAddrStr(reinterpret_cast(&sock_addr)); + } + } + + if (peername) { + sockaddr_storage peer_addr; + socklen_t addrlen = sizeof(peer_addr); + ASSERT_THAT(getpeername(fd.get(), reinterpret_cast(&peer_addr), + &addrlen), + SyscallSucceeds()); + ASSERT_EQ(addrlen, sizeof(struct sockaddr_in6)); + + if (addr.family() == AF_INET || + IN6_IS_ADDR_V4MAPPED(reinterpret_cast(addr_in) + ->sin6_addr.s6_addr32)) { + EXPECT_TRUE(IN6_IS_ADDR_V4MAPPED( + reinterpret_cast(&peer_addr) + ->sin6_addr.s6_addr32)) + << OperationToString(operation) << " getpeername=" + << GetAddrStr(reinterpret_cast(&peer_addr)); + } + } + } +} + +// TODO(gvisor.dev/issues/1556): uncomment V4MappedAny. +INSTANTIATE_TEST_SUITE_P( + All, DualStackSocketTest, + ::testing::Combine( + ::testing::Values(V4Any(), V4Loopback(), /*V4MappedAny(),*/ + V4MappedLoopback(), V6Any(), V6Loopback()), + ::testing::ValuesIn( + {{Operation::Bind, Operation::Connect, Operation::SendTo}, + {Operation::Bind, Operation::SendTo, Operation::Connect}, + {Operation::Connect, Operation::Bind, Operation::SendTo}, + {Operation::Connect, Operation::SendTo, Operation::Bind}, + {Operation::SendTo, Operation::Bind, Operation::Connect}, + {Operation::SendTo, Operation::Connect, Operation::Bind}})), + [](::testing::TestParamInfo< + std::tuple> const& info) { + const TestAddress& addr = std::get<0>(info.param); + const OperationSequence& operations = std::get<1>(info.param); + std::string s = addr.description; + for (const Operation& operation : operations) { + absl::StrAppend(&s, OperationToString(operation)); + } + return s; + }); + void tcpSimpleConnectTest(TestAddress const& listener, TestAddress const& connector, bool unbound) { // Create the listening socket. -- cgit v1.2.3 From 1c3d3c70b93d483894dd49fb444171347f0ca250 Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Mon, 13 Jan 2020 14:54:32 -0800 Subject: Fix test building. --- pkg/tcpip/network/ip_test.go | 21 ++++++++++++++------- pkg/tcpip/network/ipv6/icmp_test.go | 2 +- pkg/tcpip/network/ipv6/ndp_test.go | 2 +- pkg/tcpip/stack/stack_test.go | 2 +- pkg/tcpip/transport/udp/udp_test.go | 10 ++++++++-- 5 files changed, 25 insertions(+), 12 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index f1bc33adf..f4d78f8c6 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -212,10 +212,17 @@ func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { return s.FindRoute(1, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */) } +func buildDummyStack() *stack.Stack { + return stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()}, + }) +} + func TestIPv4Send(t *testing.T) { o := testObject{t: t, v4: true} proto := ipv4.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, nil, &o) + ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, nil, &o, buildDummyStack()) if err != nil { t.Fatalf("NewEndpoint failed: %v", err) } @@ -250,7 +257,7 @@ func TestIPv4Send(t *testing.T) { func TestIPv4Receive(t *testing.T) { o := testObject{t: t, v4: true} proto := ipv4.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil) + ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil, buildDummyStack()) if err != nil { t.Fatalf("NewEndpoint failed: %v", err) } @@ -318,7 +325,7 @@ func TestIPv4ReceiveControl(t *testing.T) { t.Run(c.name, func(t *testing.T) { o := testObject{t: t} proto := ipv4.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil) + ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil, buildDummyStack()) if err != nil { t.Fatalf("NewEndpoint failed: %v", err) } @@ -385,7 +392,7 @@ func TestIPv4ReceiveControl(t *testing.T) { func TestIPv4FragmentationReceive(t *testing.T) { o := testObject{t: t, v4: true} proto := ipv4.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil) + ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil, buildDummyStack()) if err != nil { t.Fatalf("NewEndpoint failed: %v", err) } @@ -456,7 +463,7 @@ func TestIPv4FragmentationReceive(t *testing.T) { func TestIPv6Send(t *testing.T) { o := testObject{t: t} proto := ipv6.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, nil, &o) + ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, nil, &o, buildDummyStack()) if err != nil { t.Fatalf("NewEndpoint failed: %v", err) } @@ -491,7 +498,7 @@ func TestIPv6Send(t *testing.T) { func TestIPv6Receive(t *testing.T) { o := testObject{t: t} proto := ipv6.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, &o, nil) + ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, &o, nil, buildDummyStack()) if err != nil { t.Fatalf("NewEndpoint failed: %v", err) } @@ -568,7 +575,7 @@ func TestIPv6ReceiveControl(t *testing.T) { t.Run(c.name, func(t *testing.T) { o := testObject{t: t} proto := ipv6.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, &o, nil) + ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, &o, nil, buildDummyStack()) if err != nil { t.Fatalf("NewEndpoint failed: %v", err) } diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 335f634d5..a2fdc5dcd 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -109,7 +109,7 @@ func TestICMPCounts(t *testing.T) { if netProto == nil { t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) } - ep, err := netProto.NewEndpoint(0, tcpip.AddressWithPrefix{lladdr1, netProto.DefaultPrefixLen()}, &stubLinkAddressCache{}, &stubDispatcher{}, nil) + ep, err := netProto.NewEndpoint(0, tcpip.AddressWithPrefix{lladdr1, netProto.DefaultPrefixLen()}, &stubLinkAddressCache{}, &stubDispatcher{}, nil, s) if err != nil { t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err) } diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 0dbce14a0..fe895b376 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -62,7 +62,7 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) } - ep, err := netProto.NewEndpoint(0, tcpip.AddressWithPrefix{rlladdr, netProto.DefaultPrefixLen()}, &stubLinkAddressCache{}, &stubDispatcher{}, nil) + ep, err := netProto.NewEndpoint(0, tcpip.AddressWithPrefix{rlladdr, netProto.DefaultPrefixLen()}, &stubLinkAddressCache{}, &stubDispatcher{}, nil, s) if err != nil { t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err) } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 44e5229cc..cf41e02eb 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -200,7 +200,7 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres return tcpip.Address(v[1:2]), tcpip.Address(v[0:1]) } -func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint, _ *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) { return &fakeNetworkEndpoint{ nicID: nicID, id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 0a82bc4fa..d33507156 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -1228,7 +1228,10 @@ func TestTTL(t *testing.T) { } else { p = ipv6.NewProtocol() } - ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil) + ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil, stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + })) if err != nil { t.Fatal(err) } @@ -1261,7 +1264,10 @@ func TestSetTTL(t *testing.T) { } else { p = ipv6.NewProtocol() } - ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil) + ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil, stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + })) if err != nil { t.Fatal(err) } -- cgit v1.2.3 From 1ad8381eac108304f7b96162674624b34b95ec7b Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Mon, 13 Jan 2020 17:56:44 -0800 Subject: Do Source Address Selection when choosing an IPv6 source address Do Source Address Selection when choosing an IPv6 source address as per RFC 6724 section 5 rules 1-3: 1) Prefer same address 2) Prefer appropriate scope 3) Avoid deprecated addresses. A later change will update Source Address Selection to follow rules 4-8. Tests: Rule 1 & 2: stack.TestIPv6SourceAddressSelectionScopeAndSameAddress, Rule 3: stack.TestAutoGenAddrTimerDeprecation, stack.TestAutoGenAddrDeprecateFromPI PiperOrigin-RevId: 289559373 --- pkg/tcpip/header/ipv6.go | 43 ++++++++++++ pkg/tcpip/header/ipv6_test.go | 96 +++++++++++++++++++++++++- pkg/tcpip/stack/ndp_test.go | 22 ++++-- pkg/tcpip/stack/nic.go | 115 ++++++++++++++++++++++++++++++-- pkg/tcpip/stack/stack.go | 8 +-- pkg/tcpip/stack/stack_test.go | 152 ++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 420 insertions(+), 16 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index 135a60b12..83425c614 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -333,6 +333,17 @@ func IsV6LinkLocalAddress(addr tcpip.Address) bool { return addr[0] == 0xfe && (addr[1]&0xc0) == 0x80 } +// IsV6UniqueLocalAddress determines if the provided address is an IPv6 +// unique-local address (within the prefix FC00::/7). +func IsV6UniqueLocalAddress(addr tcpip.Address) bool { + if len(addr) != IPv6AddressSize { + return false + } + // According to RFC 4193 section 3.1, a unique local address has the prefix + // FC00::/7. + return (addr[0] & 0xfe) == 0xfc +} + // AppendOpaqueInterfaceIdentifier appends a 64 bit opaque interface identifier // (IID) to buf as outlined by RFC 7217 and returns the extended buffer. // @@ -371,3 +382,35 @@ func LinkLocalAddrWithOpaqueIID(nicName string, dadCounter uint8, secretKey []by return tcpip.Address(AppendOpaqueInterfaceIdentifier(lladdrb[:IIDOffsetInIPv6Address], IPv6LinkLocalPrefix.Subnet(), nicName, dadCounter, secretKey)) } + +// IPv6AddressScope is the scope of an IPv6 address. +type IPv6AddressScope int + +const ( + // LinkLocalScope indicates a link-local address. + LinkLocalScope IPv6AddressScope = iota + + // UniqueLocalScope indicates a unique-local address. + UniqueLocalScope + + // GlobalScope indicates a global address. + GlobalScope +) + +// ScopeForIPv6Address returns the scope for an IPv6 address. +func ScopeForIPv6Address(addr tcpip.Address) (IPv6AddressScope, *tcpip.Error) { + if len(addr) != IPv6AddressSize { + return GlobalScope, tcpip.ErrBadAddress + } + + switch { + case IsV6LinkLocalAddress(addr): + return LinkLocalScope, nil + + case IsV6UniqueLocalAddress(addr): + return UniqueLocalScope, nil + + default: + return GlobalScope, nil + } +} diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go index 1994003ed..29f54bc57 100644 --- a/pkg/tcpip/header/ipv6_test.go +++ b/pkg/tcpip/header/ipv6_test.go @@ -25,7 +25,13 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" ) -const linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") +const ( + linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") + linkLocalAddr = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + globalAddr = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") +) func TestEthernetAdddressToModifiedEUI64(t *testing.T) { expectedIID := [header.IIDSize]byte{0, 2, 3, 255, 254, 4, 5, 6} @@ -206,3 +212,91 @@ func TestLinkLocalAddrWithOpaqueIID(t *testing.T) { }) } } + +func TestIsV6UniqueLocalAddress(t *testing.T) { + tests := []struct { + name string + addr tcpip.Address + expected bool + }{ + { + name: "Valid Unique 1", + addr: uniqueLocalAddr1, + expected: true, + }, + { + name: "Valid Unique 2", + addr: uniqueLocalAddr1, + expected: true, + }, + { + name: "Link Local", + addr: linkLocalAddr, + expected: false, + }, + { + name: "Global", + addr: globalAddr, + expected: false, + }, + { + name: "IPv4", + addr: "\x01\x02\x03\x04", + expected: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if got := header.IsV6UniqueLocalAddress(test.addr); got != test.expected { + t.Errorf("got header.IsV6UniqueLocalAddress(%s) = %t, want = %t", test.addr, got, test.expected) + } + }) + } +} + +func TestScopeForIPv6Address(t *testing.T) { + tests := []struct { + name string + addr tcpip.Address + scope header.IPv6AddressScope + err *tcpip.Error + }{ + { + name: "Unique Local", + addr: uniqueLocalAddr1, + scope: header.UniqueLocalScope, + err: nil, + }, + { + name: "Link Local", + addr: linkLocalAddr, + scope: header.LinkLocalScope, + err: nil, + }, + { + name: "Global", + addr: globalAddr, + scope: header.GlobalScope, + err: nil, + }, + { + name: "IPv4", + addr: "\x01\x02\x03\x04", + scope: header.GlobalScope, + err: tcpip.ErrBadAddress, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + got, err := header.ScopeForIPv6Address(test.addr) + if err != test.err { + t.Errorf("got header.IsV6UniqueLocalAddress(%s) = (_, %v), want = (_, %v)", test.addr, err, test.err) + } + if got != test.scope { + t.Errorf("got header.IsV6UniqueLocalAddress(%s) = (%d, _), want = (%d, _)", test.addr, got, test.scope) + } + }) + } +} diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index f9bc18c55..d390c6312 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -1732,9 +1732,11 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*nd return ndpDisp, e, s } -// addrForNewConnection returns the local address used when creating a new -// connection. -func addrForNewConnection(t *testing.T, s *stack.Stack) tcpip.Address { +// addrForNewConnectionTo returns the local address used when creating a new +// connection to addr. +func addrForNewConnectionTo(t *testing.T, s *stack.Stack, addr tcpip.FullAddress) tcpip.Address { + t.Helper() + wq := waiter.Queue{} we, ch := waiter.NewChannelEntry(nil) wq.EventRegister(&we, waiter.EventIn) @@ -1748,8 +1750,8 @@ func addrForNewConnection(t *testing.T, s *stack.Stack) tcpip.Address { if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err) } - if err := ep.Connect(dstAddr); err != nil { - t.Fatalf("ep.Connect(%+v): %s", dstAddr, err) + if err := ep.Connect(addr); err != nil { + t.Fatalf("ep.Connect(%+v): %s", addr, err) } got, err := ep.GetLocalAddress() if err != nil { @@ -1758,9 +1760,19 @@ func addrForNewConnection(t *testing.T, s *stack.Stack) tcpip.Address { return got.Addr } +// addrForNewConnection returns the local address used when creating a new +// connection. +func addrForNewConnection(t *testing.T, s *stack.Stack) tcpip.Address { + t.Helper() + + return addrForNewConnectionTo(t, s, dstAddr) +} + // addrForNewConnectionWithAddr returns the local address used when creating a // new connection with a specific local address. func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullAddress) tcpip.Address { + t.Helper() + wq := waiter.Queue{} we, ch := waiter.NewChannelEntry(nil) wq.EventRegister(&we, waiter.EventIn) diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index fe557ccbd..abf73fe33 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -15,6 +15,8 @@ package stack import ( + "log" + "sort" "strings" "sync/atomic" @@ -251,13 +253,17 @@ func (n *NIC) setSpoofing(enable bool) { n.mu.Unlock() } -// primaryEndpoint returns the primary endpoint of n for the given network -// protocol. -// // primaryEndpoint will return the first non-deprecated endpoint if such an -// endpoint exists. If no non-deprecated endpoint exists, the first deprecated -// endpoint will be returned. -func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedNetworkEndpoint { +// endpoint exists for the given protocol and remoteAddr. If no non-deprecated +// endpoint exists, the first deprecated endpoint will be returned. +// +// If an IPv6 primary endpoint is requested, Source Address Selection (as +// defined by RFC 6724 section 5) will be performed. +func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) *referencedNetworkEndpoint { + if protocol == header.IPv6ProtocolNumber && remoteAddr != "" { + return n.primaryIPv6Endpoint(remoteAddr) + } + n.mu.RLock() defer n.mu.RUnlock() @@ -296,6 +302,103 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN return deprecatedEndpoint } +// ipv6AddrCandidate is an IPv6 candidate for Source Address Selection (RFC +// 6724 section 5). +type ipv6AddrCandidate struct { + ref *referencedNetworkEndpoint + scope header.IPv6AddressScope +} + +// primaryIPv6Endpoint returns an IPv6 endpoint following Source Address +// Selection (RFC 6724 section 5). +// +// Note, only rules 1-3 are followed. +// +// remoteAddr must be a valid IPv6 address. +func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEndpoint { + n.mu.RLock() + defer n.mu.RUnlock() + + primaryAddrs := n.primary[header.IPv6ProtocolNumber] + + if len(primaryAddrs) == 0 { + return nil + } + + // Create a candidate set of available addresses we can potentially use as a + // source address. + cs := make([]ipv6AddrCandidate, 0, len(primaryAddrs)) + for _, r := range primaryAddrs { + // If r is not valid for outgoing connections, it is not a valid endpoint. + if !r.isValidForOutgoing() { + continue + } + + addr := r.ep.ID().LocalAddress + scope, err := header.ScopeForIPv6Address(addr) + if err != nil { + // Should never happen as we got r from the primary IPv6 endpoint list and + // ScopeForIPv6Address only returns an error if addr is not an IPv6 + // address. + log.Fatalf("header.ScopeForIPv6Address(%s): %s", addr, err) + } + + cs = append(cs, ipv6AddrCandidate{ + ref: r, + scope: scope, + }) + } + + remoteScope, err := header.ScopeForIPv6Address(remoteAddr) + if err != nil { + // primaryIPv6Endpoint should never be called with an invalid IPv6 address. + log.Fatalf("header.ScopeForIPv6Address(%s): %s", remoteAddr, err) + } + + // Sort the addresses as per RFC 6724 section 5 rules 1-3. + // + // TODO(b/146021396): Implement rules 4-8 of RFC 6724 section 5. + sort.Slice(cs, func(i, j int) bool { + sa := cs[i] + sb := cs[j] + + // Prefer same address as per RFC 6724 section 5 rule 1. + if sa.ref.ep.ID().LocalAddress == remoteAddr { + return true + } + if sb.ref.ep.ID().LocalAddress == remoteAddr { + return false + } + + // Prefer appropriate scope as per RFC 6724 section 5 rule 2. + if sa.scope < sb.scope { + return sa.scope >= remoteScope + } else if sb.scope < sa.scope { + return sb.scope < remoteScope + } + + // Avoid deprecated addresses as per RFC 6724 section 5 rule 3. + if saDep, sbDep := sa.ref.deprecated, sb.ref.deprecated; saDep != sbDep { + // If sa is not deprecated, it is preferred over sb. + return sbDep + } + + // sa and sb are equal, return the endpoint that is closest to the front of + // the primary endpoint list. + return i < j + }) + + // Return the most preferred address that can have its reference count + // incremented. + for _, c := range cs { + if r := c.ref; r.tryIncRef() { + return r + } + } + + return nil +} + // hasPermanentAddrLocked returns true if n has a permanent (including currently // tentative) address, addr. func (n *NIC) hasPermanentAddrLocked(addr tcpip.Address) bool { diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 113b457fb..f8d89248e 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1106,9 +1106,9 @@ func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocol return nic.primaryAddress(protocol), nil } -func (s *Stack) getRefEP(nic *NIC, localAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (ref *referencedNetworkEndpoint) { +func (s *Stack) getRefEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (ref *referencedNetworkEndpoint) { if len(localAddr) == 0 { - return nic.primaryEndpoint(netProto) + return nic.primaryEndpoint(netProto, remoteAddr) } return nic.findEndpoint(netProto, localAddr, CanBePrimaryEndpoint) } @@ -1124,7 +1124,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n needRoute := !(isBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr)) if id != 0 && !needRoute { if nic, ok := s.nics[id]; ok { - if ref := s.getRefEP(nic, localAddr, netProto); ref != nil { + if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil { return makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()), nil } } @@ -1134,7 +1134,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n continue } if nic, ok := s.nics[route.NIC]; ok { - if ref := s.getRefEP(nic, localAddr, netProto); ref != nil { + if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil { if len(remoteAddr) == 0 { // If no remote address was provided, then the route // provided will refer to the link local address. diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 44e5229cc..4b3d18f1b 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -35,6 +35,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" ) const ( @@ -2411,3 +2412,154 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { } } } + +func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { + const ( + linkLocalAddr1 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + linkLocalAddr2 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + globalAddr1 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + globalAddr2 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + nicID = 1 + ) + + // Rule 3 is not tested here, and is instead tested by NDP's AutoGenAddr test. + tests := []struct { + name string + nicAddrs []tcpip.Address + connectAddr tcpip.Address + expectedLocalAddr tcpip.Address + }{ + // Test Rule 1 of RFC 6724 section 5. + { + name: "Same Global most preferred (last address)", + nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, + connectAddr: globalAddr1, + expectedLocalAddr: globalAddr1, + }, + { + name: "Same Global most preferred (first address)", + nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1}, + connectAddr: globalAddr1, + expectedLocalAddr: globalAddr1, + }, + { + name: "Same Link Local most preferred (last address)", + nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1}, + connectAddr: linkLocalAddr1, + expectedLocalAddr: linkLocalAddr1, + }, + { + name: "Same Link Local most preferred (first address)", + nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, + connectAddr: linkLocalAddr1, + expectedLocalAddr: linkLocalAddr1, + }, + { + name: "Same Unique Local most preferred (last address)", + nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, linkLocalAddr1}, + connectAddr: uniqueLocalAddr1, + expectedLocalAddr: uniqueLocalAddr1, + }, + { + name: "Same Unique Local most preferred (first address)", + nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1}, + connectAddr: uniqueLocalAddr1, + expectedLocalAddr: uniqueLocalAddr1, + }, + + // Test Rule 2 of RFC 6724 section 5. + { + name: "Global most preferred (last address)", + nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, + connectAddr: globalAddr2, + expectedLocalAddr: globalAddr1, + }, + { + name: "Global most preferred (first address)", + nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1}, + connectAddr: globalAddr2, + expectedLocalAddr: globalAddr1, + }, + { + name: "Link Local most preferred (last address)", + nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1}, + connectAddr: linkLocalAddr2, + expectedLocalAddr: linkLocalAddr1, + }, + { + name: "Link Local most preferred (first address)", + nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, + connectAddr: linkLocalAddr2, + expectedLocalAddr: linkLocalAddr1, + }, + { + name: "Unique Local most preferred (last address)", + nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, linkLocalAddr1}, + connectAddr: uniqueLocalAddr2, + expectedLocalAddr: uniqueLocalAddr1, + }, + { + name: "Unique Local most preferred (first address)", + nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1}, + connectAddr: uniqueLocalAddr2, + expectedLocalAddr: uniqueLocalAddr1, + }, + + // Test returning the endpoint that is closest to the front when + // candidate addresses are "equal" from the perspective of RFC 6724 + // section 5. + { + name: "Unique Local for Global", + nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, uniqueLocalAddr2}, + connectAddr: globalAddr2, + expectedLocalAddr: uniqueLocalAddr1, + }, + { + name: "Link Local for Global", + nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2}, + connectAddr: globalAddr2, + expectedLocalAddr: linkLocalAddr1, + }, + { + name: "Link Local for Unique Local", + nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2}, + connectAddr: uniqueLocalAddr2, + expectedLocalAddr: linkLocalAddr1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + s.SetRouteTable([]tcpip.Route{{ + Destination: header.IPv6EmptySubnet, + Gateway: llAddr3, + NIC: nicID, + }}) + s.AddLinkAddress(nicID, llAddr3, linkAddr3) + + for _, a := range test.nicAddrs { + if err := s.AddAddress(nicID, ipv6.ProtocolNumber, a); err != nil { + t.Errorf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, a, err) + } + } + + if t.Failed() { + t.FailNow() + } + + if got := addrForNewConnectionTo(t, s, tcpip.FullAddress{Addr: test.connectAddr, NIC: nicID, Port: 1234}); got != test.expectedLocalAddr { + t.Errorf("got local address = %s, want = %s", got, test.expectedLocalAddr) + } + }) + } +} -- cgit v1.2.3 From 50625cee59aaff834c7968771ab385ad0e7b0e1f Mon Sep 17 00:00:00 2001 From: Tamir Duberstein Date: Tue, 14 Jan 2020 13:31:52 -0800 Subject: Implement {g,s}etsockopt(IP_RECVTOS) for UDP sockets PiperOrigin-RevId: 289718534 --- pkg/sentry/socket/control/control.go | 2 +- pkg/sentry/socket/netstack/netstack.go | 36 +++++++++++++-- pkg/tcpip/checker/checker.go | 16 +++++++ pkg/tcpip/stack/nic.go | 2 +- pkg/tcpip/stack/stack.go | 2 +- pkg/tcpip/tcpip.go | 8 +++- pkg/tcpip/transport/udp/endpoint.go | 40 +++++++++++++++-- pkg/tcpip/transport/udp/udp_test.go | 67 ++++++++++++++++++++++++---- test/syscalls/linux/socket_ip_udp_generic.cc | 40 +++++++++++++++++ test/syscalls/linux/udp_socket_test_cases.cc | 8 ++-- 10 files changed, 197 insertions(+), 24 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index 4301b697c..1684dfc24 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -327,7 +327,7 @@ func PackInq(t *kernel.Task, inq int32, buf []byte) []byte { } // PackTOS packs an IP_TOS socket control message. -func PackTOS(t *kernel.Task, tos int8, buf []byte) []byte { +func PackTOS(t *kernel.Task, tos uint8, buf []byte) []byte { return putCmsgStruct( buf, linux.SOL_IP, diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index c020c11cb..d2f7e987d 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1268,11 +1268,11 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf if err != nil { return nil, syserr.TranslateNetstackError(err) } - var o uint32 + var o int32 if v { o = 1 } - return int32(o), nil + return o, nil case linux.IPV6_PATHMTU: t.Kernel().EmitUnimplementedEvent(t) @@ -1377,6 +1377,21 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in } return int32(v), nil + case linux.IP_RECVTOS: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v, err := ep.GetSockOptBool(tcpip.ReceiveTOSOption) + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + var o int32 + if v { + o = 1 + } + return o, nil + default: emitUnimplementedEventIP(t, name) } @@ -1895,6 +1910,13 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s } return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv4TOSOption(v))) + case linux.IP_RECVTOS: + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTOSOption, v != 0)) + case linux.IP_ADD_SOURCE_MEMBERSHIP, linux.IP_BIND_ADDRESS_NO_PORT, linux.IP_BLOCK_SOURCE, @@ -1915,7 +1937,6 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s linux.IP_RECVFRAGSIZE, linux.IP_RECVOPTS, linux.IP_RECVORIGDSTADDR, - linux.IP_RECVTOS, linux.IP_RECVTTL, linux.IP_RETOPTS, linux.IP_TRANSPARENT, @@ -2335,7 +2356,14 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe } func (s *SocketOperations) controlMessages() socket.ControlMessages { - return socket.ControlMessages{IP: tcpip.ControlMessages{HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp, Timestamp: s.readCM.Timestamp}} + return socket.ControlMessages{ + IP: tcpip.ControlMessages{ + HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp, + Timestamp: s.readCM.Timestamp, + HasTOS: s.readCM.HasTOS, + TOS: s.readCM.TOS, + }, + } } // updateTimestamp sets the timestamp for SIOCGSTAMP. It should be called after diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 2f15bf1f1..542abc99d 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -33,6 +33,9 @@ type NetworkChecker func(*testing.T, []header.Network) // TransportChecker is a function to check a property of a transport packet. type TransportChecker func(*testing.T, header.Transport) +// ControlMessagesChecker is a function to check a property of ancillary data. +type ControlMessagesChecker func(*testing.T, tcpip.ControlMessages) + // IPv4 checks the validity and properties of the given IPv4 packet. It is // expected to be used in conjunction with other network checkers for specific // properties. For example, to check the source and destination address, one @@ -158,6 +161,19 @@ func FragmentFlags(flags uint8) NetworkChecker { } } +// ReceiveTOS creates a checker that checks the TOS field in ControlMessages. +func ReceiveTOS(want uint8) ControlMessagesChecker { + return func(t *testing.T, cm tcpip.ControlMessages) { + t.Helper() + if !cm.HasTOS { + t.Fatalf("got cm.HasTOS = %t, want cm.TOS = %d", cm.HasTOS, want) + } + if got := cm.TOS; got != want { + t.Fatalf("got cm.TOS = %d, want %d", got, want) + } + } +} + // TOS creates a checker that checks the TOS field. func TOS(tos uint8, label uint32) NetworkChecker { return func(t *testing.T, h []header.Network) { diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index abf73fe33..071221d5a 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -763,7 +763,7 @@ func (n *NIC) RemoveAddressRange(subnet tcpip.Subnet) { n.mu.Unlock() } -// Subnets returns the Subnets associated with this NIC. +// AddressRanges returns the Subnets associated with this NIC. func (n *NIC) AddressRanges() []tcpip.Subnet { n.mu.RLock() defer n.mu.RUnlock() diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index f8d89248e..386eb6eec 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -912,7 +912,7 @@ func (s *Stack) CheckNIC(id tcpip.NICID) bool { return false } -// NICSubnets returns a map of NICIDs to their associated subnets. +// NICAddressRanges returns a map of NICIDs to their associated subnets. func (s *Stack) NICAddressRanges() map[tcpip.NICID][]tcpip.Subnet { s.mu.RLock() defer s.mu.RUnlock() diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 4a090ac86..b7813cbc0 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -322,7 +322,7 @@ type ControlMessages struct { HasTOS bool // TOS is the IPv4 type of service of the associated packet. - TOS int8 + TOS uint8 // HasTClass indicates whether Tclass is valid/set. HasTClass bool @@ -500,9 +500,13 @@ type WriteOptions struct { type SockOptBool int const ( + // ReceiveTOSOption is used by SetSockOpt/GetSockOpt to specify if the TOS + // ancillary message is passed with incoming packets. + ReceiveTOSOption SockOptBool = iota + // V6OnlyOption is used by {G,S}etSockOptBool to specify whether an IPv6 // socket is to be restricted to sending and receiving IPv6 packets only. - V6OnlyOption SockOptBool = iota + V6OnlyOption ) // SockOptInt represents socket options which values have the int type. diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 13446f5d9..c9cbed8f4 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -31,6 +31,7 @@ type udpPacket struct { senderAddress tcpip.FullAddress data buffer.VectorisedView `state:".(buffer.VectorisedView)"` timestamp int64 + tos uint8 } // EndpointState represents the state of a UDP endpoint. @@ -113,6 +114,10 @@ type endpoint struct { // applied while sending packets. Defaults to 0 as on Linux. sendTOS uint8 + // receiveTOS determines if the incoming IPv4 TOS header field is passed + // as ancillary data to ControlMessages on Read. + receiveTOS bool + // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags @@ -243,7 +248,18 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess *addr = p.senderAddress } - return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: p.timestamp}, nil + cm := tcpip.ControlMessages{ + HasTimestamp: true, + Timestamp: p.timestamp, + } + e.mu.RLock() + receiveTOS := e.receiveTOS + e.mu.RUnlock() + if receiveTOS { + cm.HasTOS = true + cm.TOS = p.tos + } + return p.data.ToView(), cm, nil } // prepareForWrite prepares the endpoint for sending data. In particular, it @@ -458,6 +474,12 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { switch opt { + case tcpip.ReceiveTOSOption: + e.mu.Lock() + e.receiveTOS = v + e.mu.Unlock() + return nil + case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. if e.NetProto != header.IPv6ProtocolNumber { @@ -664,15 +686,21 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { switch opt { + case tcpip.ReceiveTOSOption: + e.mu.RLock() + v := e.receiveTOS + e.mu.RUnlock() + return v, nil + case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. if e.NetProto != header.IPv6ProtocolNumber { return false, tcpip.ErrUnknownProtocolOption } - e.mu.Lock() + e.mu.RLock() v := e.v6only - e.mu.Unlock() + e.mu.RUnlock() return v, nil } @@ -1215,6 +1243,12 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk e.rcvList.PushBack(packet) e.rcvBufSize += pkt.Data.Size() + // Save any useful information from the network header to the packet. + switch r.NetProto { + case header.IPv4ProtocolNumber: + packet.tos, _ = header.IPv4(pkt.NetworkHeader).TOS() + } + packet.timestamp = e.stack.NowNanoseconds() e.rcvMu.Unlock() diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 0a82bc4fa..ee9d10555 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -56,6 +56,7 @@ const ( multicastAddr = "\xe8\x2b\xd3\xea" multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" broadcastAddr = header.IPv4Broadcast + testTOS = 0x80 // defaultMTU is the MTU, in bytes, used throughout the tests, except // where another value is explicitly used. It is chosen to match the MTU @@ -453,6 +454,7 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool ip := header.IPv4(buf) ip.Encode(&header.IPv4Fields{ IHL: header.IPv4MinimumSize, + TOS: testTOS, TotalLength: uint16(len(buf)), TTL: 65, Protocol: uint8(udp.ProtocolNumber), @@ -552,8 +554,8 @@ func TestBindToDeviceOption(t *testing.T) { // testReadInternal sends a packet of the given test flow into the stack by // injecting it into the link endpoint. It then attempts to read it from the // UDP endpoint and depending on if this was expected to succeed verifies its -// correctness. -func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool) { +// correctness including any additional checker functions provided. +func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool, checkers ...checker.ControlMessagesChecker) { c.t.Helper() payload := newPayload() @@ -568,12 +570,12 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() var addr tcpip.FullAddress - v, _, err := c.ep.Read(&addr) + v, cm, err := c.ep.Read(&addr) if err == tcpip.ErrWouldBlock { // Wait for data to become available. select { case <-ch: - v, _, err = c.ep.Read(&addr) + v, cm, err = c.ep.Read(&addr) case <-time.After(300 * time.Millisecond): if packetShouldBeDropped { @@ -606,15 +608,21 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe if !bytes.Equal(payload, v) { c.t.Fatalf("bad payload: got %x, want %x", v, payload) } + + // Run any checkers against the ControlMessages. + for _, f := range checkers { + f(c.t, cm) + } + c.checkEndpointReadStats(1, epstats, err) } // testRead sends a packet of the given test flow into the stack by injecting it // into the link endpoint. It then reads it from the UDP endpoint and verifies -// its correctness. -func testRead(c *testContext, flow testFlow) { +// its correctness including any additional checker functions provided. +func testRead(c *testContext, flow testFlow, checkers ...checker.ControlMessagesChecker) { c.t.Helper() - testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */) + testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */, checkers...) } // testFailingRead sends a packet of the given test flow into the stack by @@ -1282,7 +1290,7 @@ func TestTOSV4(t *testing.T) { c.createEndpointForFlow(flow) - const tos = 0xC0 + const tos = testTOS var v tcpip.IPv4TOSOption if err := c.ep.GetSockOpt(&v); err != nil { c.t.Errorf("GetSockopt failed: %s", err) @@ -1317,7 +1325,7 @@ func TestTOSV6(t *testing.T) { c.createEndpointForFlow(flow) - const tos = 0xC0 + const tos = testTOS var v tcpip.IPv6TrafficClassOption if err := c.ep.GetSockOpt(&v); err != nil { c.t.Errorf("GetSockopt failed: %s", err) @@ -1344,6 +1352,47 @@ func TestTOSV6(t *testing.T) { } } +func TestReceiveTOSV4(t *testing.T) { + for _, flow := range []testFlow{unicastV4, broadcast} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Verify that setting and reading the option works. + v, err := c.ep.GetSockOptBool(tcpip.ReceiveTOSOption) + if err != nil { + c.t.Fatal("GetSockOptBool(tcpip.ReceiveTOSOption) failed:", err) + } + // Test for expected default value. + if v != false { + c.t.Errorf("got GetSockOptBool(tcpip.ReceiveTOSOption) = %t, want = %t", v, false) + } + + want := true + if err := c.ep.SetSockOptBool(tcpip.ReceiveTOSOption, want); err != nil { + c.t.Fatalf("SetSockOptBool(tcpip.ReceiveTOSOption, %t) failed: %s", want, err) + } + + got, err := c.ep.GetSockOptBool(tcpip.ReceiveTOSOption) + if err != nil { + c.t.Fatal("GetSockOptBool(tcpip.ReceiveTOSOption) failed:", err) + } + if got != want { + c.t.Fatalf("got GetSockOptBool(tcpip.ReceiveTOSOption) = %t, want = %t", got, want) + } + + // Verify that the correct received TOS is handed through as + // ancillary data to the ControlMessages struct. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatal("Bind failed:", err) + } + testRead(c, flow, checker.ReceiveTOS(testTOS)) + }) + } +} + func TestMulticastInterfaceOption(t *testing.T) { for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { diff --git a/test/syscalls/linux/socket_ip_udp_generic.cc b/test/syscalls/linux/socket_ip_udp_generic.cc index 66eb68857..53290bed7 100644 --- a/test/syscalls/linux/socket_ip_udp_generic.cc +++ b/test/syscalls/linux/socket_ip_udp_generic.cc @@ -209,6 +209,46 @@ TEST_P(UDPSocketPairTest, SetMulticastLoopChar) { EXPECT_EQ(get, kSockOptOn); } +// Ensure that Receiving TOS is off by default. +TEST_P(UDPSocketPairTest, RecvTosDefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOff); +} + +// Test that setting and getting IP_RECVTOS works as expected. +TEST_P(UDPSocketPairTest, SetRecvTos) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, + &kSockOptOff, sizeof(kSockOptOff)), + SyscallSucceeds()); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOff); + + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + + ASSERT_THAT( + getsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOn); +} + TEST_P(UDPSocketPairTest, ReuseAddrDefault) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); diff --git a/test/syscalls/linux/udp_socket_test_cases.cc b/test/syscalls/linux/udp_socket_test_cases.cc index dc35c2f50..68e0a8109 100644 --- a/test/syscalls/linux/udp_socket_test_cases.cc +++ b/test/syscalls/linux/udp_socket_test_cases.cc @@ -1349,8 +1349,9 @@ TEST_P(UdpSocketTest, TimestampIoctlPersistence) { // outgoing packets, and that a receiving socket with IP_RECVTOS or // IPV6_RECVTCLASS will create the corresponding control message. TEST_P(UdpSocketTest, SetAndReceiveTOS) { - // TODO(b/68320120): IP_RECVTOS/IPV6_RECVTCLASS not supported for netstack. - SKIP_IF(IsRunningOnGvisor() && !IsRunningWithHostinet()); + // TODO(b/68320120): IPV6_RECVTCLASS not supported for netstack. + SKIP_IF((GetParam() != AddressFamily::kIpv4) && IsRunningOnGvisor() && + !IsRunningWithHostinet()); ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); @@ -1421,7 +1422,8 @@ TEST_P(UdpSocketTest, SetAndReceiveTOS) { // TOS byte on outgoing packets, and that a receiving socket with IP_RECVTOS or // IPV6_RECVTCLASS will create the corresponding control message. TEST_P(UdpSocketTest, SendAndReceiveTOS) { - // TODO(b/68320120): IP_RECVTOS/IPV6_RECVTCLASS not supported for netstack. + // TODO(b/68320120): IPV6_RECVTCLASS not supported for netstack. + // TODO(b/146661005): Setting TOS via cmsg not supported for netstack. SKIP_IF(IsRunningOnGvisor() && !IsRunningWithHostinet()); ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); -- cgit v1.2.3 From a611fdaee3c14abe2222140ae0a8a742ebfd31ab Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Tue, 14 Jan 2020 14:14:17 -0800 Subject: Changes TCP packet dispatch to use a pool of goroutines. All inbound segments for connections in ESTABLISHED state are delivered to the endpoint's queue but for every segment delivered we also queue the endpoint for processing to a selected processor. This ensures that when there are a large number of connections in ESTABLISHED state the inbound packets are all handled by a small number of goroutines and significantly reduces the amount of work the goscheduler has to perform. We let connections in other states follow the current path where the endpoint's goroutine directly handles the segments. Updates #231 PiperOrigin-RevId: 289728325 --- benchmarks/tcp/tcp_proxy.go | 6 +- pkg/sleep/sleep_test.go | 31 +++ pkg/tcpip/stack/transport_demuxer.go | 54 ++++- pkg/tcpip/transport/tcp/BUILD | 15 +- pkg/tcpip/transport/tcp/accept.go | 9 +- pkg/tcpip/transport/tcp/connect.go | 310 ++++++++++++++++------------ pkg/tcpip/transport/tcp/dispatcher.go | 218 +++++++++++++++++++ pkg/tcpip/transport/tcp/endpoint.go | 303 ++++++++++++++++++--------- pkg/tcpip/transport/tcp/endpoint_state.go | 30 +-- pkg/tcpip/transport/tcp/protocol.go | 11 + pkg/tcpip/transport/tcp/rcv.go | 21 +- pkg/tcpip/transport/tcp/snd.go | 14 +- pkg/tcpip/transport/tcp/tcp_test.go | 11 +- test/syscalls/linux/socket_inet_loopback.cc | 2 +- test/syscalls/linux/tcp_socket.cc | 14 ++ 15 files changed, 769 insertions(+), 280 deletions(-) create mode 100644 pkg/tcpip/transport/tcp/dispatcher.go (limited to 'pkg/tcpip/stack') diff --git a/benchmarks/tcp/tcp_proxy.go b/benchmarks/tcp/tcp_proxy.go index be0d7bdd6..dc96add66 100644 --- a/benchmarks/tcp/tcp_proxy.go +++ b/benchmarks/tcp/tcp_proxy.go @@ -85,7 +85,7 @@ func (netImpl) printStats() { const ( nicID = 1 // Fixed. - rcvBufSize = 1 << 20 // 1MB. + rcvBufSize = 4 << 20 // 1MB. ) type netstackImpl struct { @@ -130,6 +130,10 @@ func setupNetwork(ifaceName string, numChannels int) (fds []int, err error) { return nil, fmt.Errorf("setsockopt(..., SO_RCVBUF, %v,..) = %v", rcvBufSize, err) } + if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, rcvBufSize); err != nil { + return nil, fmt.Errorf("setsockopt(..., SO_RCVBUF, %v,..) = %v", rcvBufSize, err) + } + if !*swgso && *gso != 0 { if err := syscall.SetsockoptInt(fd, syscall.SOL_PACKET, unix.PACKET_VNET_HDR, 1); err != nil { return nil, fmt.Errorf("unable to enable the PACKET_VNET_HDR option: %v", err) diff --git a/pkg/sleep/sleep_test.go b/pkg/sleep/sleep_test.go index 130806c86..af47e2ba1 100644 --- a/pkg/sleep/sleep_test.go +++ b/pkg/sleep/sleep_test.go @@ -376,6 +376,37 @@ func TestRace(t *testing.T) { } } +// TestRaceInOrder tests that multiple wakers can continuously send wake requests to +// the sleeper and that the wakers are retrieved in the order asserted. +func TestRaceInOrder(t *testing.T) { + const wakers = 100 + const wakeRequests = 10000 + + w := make([]Waker, wakers) + s := Sleeper{} + + // Associate each waker and start goroutines that will assert them. + for i := range w { + s.AddWaker(&w[i], i) + } + go func() { + n := 0 + for n < wakeRequests { + wk := w[n%len(w)] + wk.Assert() + n++ + } + }() + + // Wait for all wake up notifications from all wakers. + for i := 0; i < wakeRequests; i++ { + v, _ := s.Fetch(true) + if got, want := v, i%wakers; got != want { + t.Fatalf("got %d want %d", got, want) + } + } +} + // BenchmarkSleeperMultiSelect measures how long it takes to fetch a wake up // from 4 wakers when at least one is already asserted. func BenchmarkSleeperMultiSelect(b *testing.B) { diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index f384a91de..d686e6eb8 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -104,7 +104,14 @@ func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, p return } // multiPortEndpoints are guaranteed to have at least one element. - selectEndpoint(id, mpep, epsByNic.seed).HandlePacket(r, id, pkt) + transEP := selectEndpoint(id, mpep, epsByNic.seed) + if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue { + queuedProtocol.QueuePacket(r, transEP, id, pkt) + epsByNic.mu.RUnlock() + return + } + + transEP.HandlePacket(r, id, pkt) epsByNic.mu.RUnlock() // Don't use defer for performance reasons. } @@ -130,7 +137,7 @@ func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpoint // registerEndpoint returns true if it succeeds. It fails and returns // false if ep already has an element with the same key. -func (epsByNic *endpointsByNic) registerEndpoint(t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { +func (epsByNic *endpointsByNic) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { epsByNic.mu.Lock() defer epsByNic.mu.Unlock() @@ -140,7 +147,7 @@ func (epsByNic *endpointsByNic) registerEndpoint(t TransportEndpoint, reusePort } // This is a new binding. - multiPortEp := &multiPortEndpoint{} + multiPortEp := &multiPortEndpoint{demux: d, netProto: netProto, transProto: transProto} multiPortEp.endpointsMap = make(map[TransportEndpoint]int) multiPortEp.reuse = reusePort epsByNic.endpoints[bindToDevice] = multiPortEp @@ -168,18 +175,34 @@ func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t T // newTransportDemuxer. type transportDemuxer struct { // protocol is immutable. - protocol map[protocolIDs]*transportEndpoints + protocol map[protocolIDs]*transportEndpoints + queuedProtocols map[protocolIDs]queuedTransportProtocol +} + +// queuedTransportProtocol if supported by a protocol implementation will cause +// the dispatcher to delivery packets to the QueuePacket method instead of +// calling HandlePacket directly on the endpoint. +type queuedTransportProtocol interface { + QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt tcpip.PacketBuffer) } func newTransportDemuxer(stack *Stack) *transportDemuxer { - d := &transportDemuxer{protocol: make(map[protocolIDs]*transportEndpoints)} + d := &transportDemuxer{ + protocol: make(map[protocolIDs]*transportEndpoints), + queuedProtocols: make(map[protocolIDs]queuedTransportProtocol), + } // Add each network and transport pair to the demuxer. for netProto := range stack.networkProtocols { for proto := range stack.transportProtocols { - d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{ + protoIDs := protocolIDs{netProto, proto} + d.protocol[protoIDs] = &transportEndpoints{ endpoints: make(map[TransportEndpointID]*endpointsByNic), } + qTransProto, isQueued := (stack.transportProtocols[proto].proto).(queuedTransportProtocol) + if isQueued { + d.queuedProtocols[protoIDs] = qTransProto + } } } @@ -209,7 +232,11 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum // // +stateify savable type multiPortEndpoint struct { - mu sync.RWMutex `state:"nosave"` + mu sync.RWMutex `state:"nosave"` + demux *transportDemuxer + netProto tcpip.NetworkProtocolNumber + transProto tcpip.TransportProtocolNumber + endpointsArr []TransportEndpoint endpointsMap map[TransportEndpoint]int // reuse indicates if more than one endpoint is allowed. @@ -258,13 +285,22 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32 func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) { ep.mu.RLock() + queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}] for i, endpoint := range ep.endpointsArr { // HandlePacket takes ownership of pkt, so each endpoint needs // its own copy except for the final one. if i == len(ep.endpointsArr)-1 { + if mustQueue { + queuedProtocol.QueuePacket(r, endpoint, id, pkt) + break + } endpoint.HandlePacket(r, id, pkt) break } + if mustQueue { + queuedProtocol.QueuePacket(r, endpoint, id, pkt.Clone()) + continue + } endpoint.HandlePacket(r, id, pkt.Clone()) } ep.mu.RUnlock() // Don't use defer for performance reasons. @@ -357,7 +393,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol if epsByNic, ok := eps.endpoints[id]; ok { // There was already a binding. - return epsByNic.registerEndpoint(ep, reusePort, bindToDevice) + return epsByNic.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice) } // This is a new binding. @@ -367,7 +403,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol } eps.endpoints[id] = epsByNic - return epsByNic.registerEndpoint(ep, reusePort, bindToDevice) + return epsByNic.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice) } // unregisterEndpoint unregisters the endpoint with the given id such that it diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 353bd06f4..0e3ab05ad 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -16,6 +16,18 @@ go_template_instance( }, ) +go_template_instance( + name = "tcp_endpoint_list", + out = "tcp_endpoint_list.go", + package = "tcp", + prefix = "endpoint", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*endpoint", + "Linker": "*endpoint", + }, +) + go_library( name = "tcp", srcs = [ @@ -23,6 +35,7 @@ go_library( "connect.go", "cubic.go", "cubic_state.go", + "dispatcher.go", "endpoint.go", "endpoint_state.go", "forwarder.go", @@ -38,6 +51,7 @@ go_library( "segment_state.go", "snd.go", "snd_state.go", + "tcp_endpoint_list.go", "tcp_segment_list.go", "timer.go", ], @@ -45,7 +59,6 @@ go_library( imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], visibility = ["//visibility:public"], deps = [ - "//pkg/log", "//pkg/rand", "//pkg/sleep", "//pkg/sync", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 1ea996936..1a2e3efa9 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -285,7 +285,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head // listenEP is nil when listenContext is used by tcp.Forwarder. if l.listenEP != nil { l.listenEP.mu.Lock() - if l.listenEP.state != StateListen { + if l.listenEP.EndpointState() != StateListen { l.listenEP.mu.Unlock() return nil, tcpip.ErrConnectionAborted } @@ -344,11 +344,12 @@ func (l *listenContext) closeAllPendingEndpoints() { // instead. func (e *endpoint) deliverAccepted(n *endpoint) { e.mu.Lock() - state := e.state + state := e.EndpointState() e.pendingAccepted.Add(1) defer e.pendingAccepted.Done() acceptedChan := e.acceptedChan e.mu.Unlock() + if state == StateListen { acceptedChan <- n e.waiterQueue.Notify(waiter.EventIn) @@ -562,8 +563,8 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // We do not use transitionToStateEstablishedLocked here as there is // no handshake state available when doing a SYN cookie based accept. n.stack.Stats().TCP.CurrentEstablished.Increment() - n.state = StateEstablished n.isConnectNotified = true + n.setEndpointState(StateEstablished) // Do the delivery in a separate goroutine so // that we don't block the listen loop in case @@ -596,7 +597,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { // handleSynSegment() from attempting to queue new connections // to the endpoint. e.mu.Lock() - e.state = StateClose + e.setEndpointState(StateClose) // close any endpoints in SYN-RCVD state. ctx.closeAllPendingEndpoints() diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 613ec1775..f3896715b 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -190,7 +190,7 @@ func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *hea h.mss = opts.MSS h.sndWndScale = opts.WS h.ep.mu.Lock() - h.ep.state = StateSynRecv + h.ep.setEndpointState(StateSynRecv) h.ep.mu.Unlock() } @@ -274,14 +274,14 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { // SYN-RCVD state. h.state = handshakeSynRcvd h.ep.mu.Lock() - h.ep.state = StateSynRecv ttl := h.ep.ttl + h.ep.setEndpointState(StateSynRecv) h.ep.mu.Unlock() synOpts := header.TCPSynOptions{ WS: int(h.effectiveRcvWndScale()), TS: rcvSynOpts.TS, TSVal: h.ep.timestamp(), - TSEcr: h.ep.recentTS, + TSEcr: h.ep.recentTimestamp(), // We only send SACKPermitted if the other side indicated it // permits SACK. This is not explicitly defined in the RFC but @@ -341,7 +341,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { WS: h.rcvWndScale, TS: h.ep.sendTSOk, TSVal: h.ep.timestamp(), - TSEcr: h.ep.recentTS, + TSEcr: h.ep.recentTimestamp(), SACKPermitted: h.ep.sackPermitted, MSS: h.ep.amss, } @@ -501,7 +501,7 @@ func (h *handshake) execute() *tcpip.Error { WS: h.rcvWndScale, TS: true, TSVal: h.ep.timestamp(), - TSEcr: h.ep.recentTS, + TSEcr: h.ep.recentTimestamp(), SACKPermitted: bool(sackEnabled), MSS: h.ep.amss, } @@ -792,7 +792,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte { // Ref: https://tools.ietf.org/html/rfc7323#section-5.4. offset += header.EncodeNOP(options[offset:]) offset += header.EncodeNOP(options[offset:]) - offset += header.EncodeTSOption(e.timestamp(), uint32(e.recentTS), options[offset:]) + offset += header.EncodeTSOption(e.timestamp(), e.recentTimestamp(), options[offset:]) } if e.sackPermitted && len(sackBlocks) > 0 { offset += header.EncodeNOP(options[offset:]) @@ -811,7 +811,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte { // sendRaw sends a TCP segment to the endpoint's peer. func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error { var sackBlocks []header.SACKBlock - if e.state == StateEstablished && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) { + if e.EndpointState() == StateEstablished && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) { sackBlocks = e.sack.Blocks[:e.sack.NumBlocks] } options := e.makeOptions(sackBlocks) @@ -848,6 +848,9 @@ func (e *endpoint) handleWrite() *tcpip.Error { } func (e *endpoint) handleClose() *tcpip.Error { + if !e.EndpointState().connected() { + return nil + } // Drain the send queue. e.handleWrite() @@ -864,11 +867,7 @@ func (e *endpoint) handleClose() *tcpip.Error { func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { // Only send a reset if the connection is being aborted for a reason // other than receiving a reset. - if e.state == StateEstablished || e.state == StateCloseWait { - e.stack.Stats().TCP.EstablishedResets.Increment() - e.stack.Stats().TCP.CurrentEstablished.Decrement() - } - e.state = StateError + e.setEndpointState(StateError) e.HardError = err if err != tcpip.ErrConnectionReset && err != tcpip.ErrTimeout { // The exact sequence number to be used for the RST is the same as the @@ -888,9 +887,12 @@ func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { } // completeWorkerLocked is called by the worker goroutine when it's about to -// exit. It marks the worker as completed and performs cleanup work if requested -// by Close(). +// exit. func (e *endpoint) completeWorkerLocked() { + // Worker is terminating(either due to moving to + // CLOSED or ERROR state, ensure we release all + // registrations port reservations even if the socket + // itself is not yet closed by the application. e.workerRunning = false if e.workerCleanup { e.cleanupLocked() @@ -917,8 +919,7 @@ func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) { e.rcvAutoParams.prevCopied = int(h.rcvWnd) e.rcvListMu.Unlock() } - h.ep.stack.Stats().TCP.CurrentEstablished.Increment() - e.state = StateEstablished + e.setEndpointState(StateEstablished) } // transitionToStateCloseLocked ensures that the endpoint is @@ -927,11 +928,12 @@ func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) { // delivered to this endpoint from the demuxer when the endpoint // is transitioned to StateClose. func (e *endpoint) transitionToStateCloseLocked() { - if e.state == StateClose { + if e.EndpointState() == StateClose { return } + // Mark the endpoint as fully closed for reads/writes. e.cleanupLocked() - e.state = StateClose + e.setEndpointState(StateClose) e.stack.Stats().TCP.EstablishedClosed.Increment() } @@ -946,7 +948,9 @@ func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) { s.decRef() return } - ep.(*endpoint).enqueueSegment(s) + if ep.(*endpoint).enqueueSegment(s) { + ep.(*endpoint).newSegmentWaker.Assert() + } } func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { @@ -955,9 +959,8 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { // except SYN-SENT, all reset (RST) segments are // validated by checking their SEQ-fields." So // we only process it if it's acceptable. - s.decRef() e.mu.Lock() - switch e.state { + switch e.EndpointState() { // In case of a RST in CLOSE-WAIT linux moves // the socket to closed state with an error set // to indicate EPIPE. @@ -981,103 +984,57 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { e.transitionToStateCloseLocked() e.HardError = tcpip.ErrAborted e.mu.Unlock() + e.notifyProtocolGoroutine(notifyTickleWorker) return false, nil default: e.mu.Unlock() + // RFC 793, page 37 states that "in all states + // except SYN-SENT, all reset (RST) segments are + // validated by checking their SEQ-fields." So + // we only process it if it's acceptable. + + // Notify protocol goroutine. This is required when + // handleSegment is invoked from the processor goroutine + // rather than the worker goroutine. + e.notifyProtocolGoroutine(notifyResetByPeer) return false, tcpip.ErrConnectionReset } } return true, nil } -// handleSegments pulls segments from the queue and processes them. It returns -// no error if the protocol loop should continue, an error otherwise. -func (e *endpoint) handleSegments() *tcpip.Error { +// handleSegments processes all inbound segments. +func (e *endpoint) handleSegments(fastPath bool) *tcpip.Error { checkRequeue := true for i := 0; i < maxSegmentsPerWake; i++ { + if e.EndpointState() == StateClose || e.EndpointState() == StateError { + return nil + } s := e.segmentQueue.dequeue() if s == nil { checkRequeue = false break } - // Invoke the tcp probe if installed. - if e.probe != nil { - e.probe(e.completeState()) + cont, err := e.handleSegment(s) + if err != nil { + s.decRef() + e.mu.Lock() + e.setEndpointState(StateError) + e.HardError = err + e.mu.Unlock() + return err } - - if s.flagIsSet(header.TCPFlagRst) { - if ok, err := e.handleReset(s); !ok { - return err - } - } else if s.flagIsSet(header.TCPFlagSyn) { - // See: https://tools.ietf.org/html/rfc5961#section-4.1 - // 1) If the SYN bit is set, irrespective of the sequence number, TCP - // MUST send an ACK (also referred to as challenge ACK) to the remote - // peer: - // - // - // - // After sending the acknowledgment, TCP MUST drop the unacceptable - // segment and stop processing further. - // - // By sending an ACK, the remote peer is challenged to confirm the loss - // of the previous connection and the request to start a new connection. - // A legitimate peer, after restart, would not have a TCB in the - // synchronized state. Thus, when the ACK arrives, the peer should send - // a RST segment back with the sequence number derived from the ACK - // field that caused the RST. - - // This RST will confirm that the remote peer has indeed closed the - // previous connection. Upon receipt of a valid RST, the local TCP - // endpoint MUST terminate its connection. The local TCP endpoint - // should then rely on SYN retransmission from the remote end to - // re-establish the connection. - - e.snd.sendAck() - } else if s.flagIsSet(header.TCPFlagAck) { - // Patch the window size in the segment according to the - // send window scale. - s.window <<= e.snd.sndWndScale - - // RFC 793, page 41 states that "once in the ESTABLISHED - // state all segments must carry current acknowledgment - // information." - drop, err := e.rcv.handleRcvdSegment(s) - if err != nil { - s.decRef() - return err - } - if drop { - s.decRef() - continue - } - - // Now check if the received segment has caused us to transition - // to a CLOSED state, if yes then terminate processing and do - // not invoke the sender. - e.mu.RLock() - state := e.state - e.mu.RUnlock() - if state == StateClose { - // When we get into StateClose while processing from the queue, - // return immediately and let the protocolMainloop handle it. - // - // We can reach StateClose only while processing a previous segment - // or a notification from the protocolMainLoop (caller goroutine). - // This means that with this return, the segment dequeue below can - // never occur on a closed endpoint. - s.decRef() - return nil - } - e.snd.handleRcvdSegment(s) + if !cont { + s.decRef() + return nil } - s.decRef() } - // If the queue is not empty, make sure we'll wake up in the next - // iteration. - if checkRequeue && !e.segmentQueue.empty() { + // When fastPath is true we don't want to wake up the worker + // goroutine. If the endpoint has more segments to process the + // dispatcher will call handleSegments again anyway. + if !fastPath && checkRequeue && !e.segmentQueue.empty() { e.newSegmentWaker.Assert() } @@ -1086,11 +1043,88 @@ func (e *endpoint) handleSegments() *tcpip.Error { e.snd.sendAck() } - e.resetKeepaliveTimer(true) + e.resetKeepaliveTimer(true /* receivedData */) return nil } +// handleSegment handles a given segment and notifies the worker goroutine if +// if the connection should be terminated. +func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) { + // Invoke the tcp probe if installed. + if e.probe != nil { + e.probe(e.completeState()) + } + + if s.flagIsSet(header.TCPFlagRst) { + if ok, err := e.handleReset(s); !ok { + return false, err + } + } else if s.flagIsSet(header.TCPFlagSyn) { + // See: https://tools.ietf.org/html/rfc5961#section-4.1 + // 1) If the SYN bit is set, irrespective of the sequence number, TCP + // MUST send an ACK (also referred to as challenge ACK) to the remote + // peer: + // + // + // + // After sending the acknowledgment, TCP MUST drop the unacceptable + // segment and stop processing further. + // + // By sending an ACK, the remote peer is challenged to confirm the loss + // of the previous connection and the request to start a new connection. + // A legitimate peer, after restart, would not have a TCB in the + // synchronized state. Thus, when the ACK arrives, the peer should send + // a RST segment back with the sequence number derived from the ACK + // field that caused the RST. + + // This RST will confirm that the remote peer has indeed closed the + // previous connection. Upon receipt of a valid RST, the local TCP + // endpoint MUST terminate its connection. The local TCP endpoint + // should then rely on SYN retransmission from the remote end to + // re-establish the connection. + + e.snd.sendAck() + } else if s.flagIsSet(header.TCPFlagAck) { + // Patch the window size in the segment according to the + // send window scale. + s.window <<= e.snd.sndWndScale + + // RFC 793, page 41 states that "once in the ESTABLISHED + // state all segments must carry current acknowledgment + // information." + drop, err := e.rcv.handleRcvdSegment(s) + if err != nil { + return false, err + } + if drop { + return true, nil + } + + // Now check if the received segment has caused us to transition + // to a CLOSED state, if yes then terminate processing and do + // not invoke the sender. + e.mu.RLock() + state := e.state + e.mu.RUnlock() + if state == StateClose { + // When we get into StateClose while processing from the queue, + // return immediately and let the protocolMainloop handle it. + // + // We can reach StateClose only while processing a previous segment + // or a notification from the protocolMainLoop (caller goroutine). + // This means that with this return, the segment dequeue below can + // never occur on a closed endpoint. + s.decRef() + return false, nil + } + + e.snd.handleRcvdSegment(s) + } + + return true, nil +} + // keepaliveTimerExpired is called when the keepaliveTimer fires. We send TCP // keepalive packets periodically when the connection is idle. If we don't hear // from the other side after a number of tries, we terminate the connection. @@ -1160,7 +1194,7 @@ func (e *endpoint) disableKeepaliveTimer() { // protocolMainLoop is the main loop of the TCP protocol. It runs in its own // goroutine and is responsible for sending segments and handling received // segments. -func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { +func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) *tcpip.Error { var closeTimer *time.Timer var closeWaker sleep.Waker @@ -1182,6 +1216,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { } e.mu.Unlock() + e.workMu.Unlock() // When the protocol loop exits we should wake up our waiters. e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) } @@ -1193,7 +1228,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { initialRcvWnd := e.initialReceiveWindow() h := newHandshake(e, seqnum.Size(initialRcvWnd)) e.mu.Lock() - h.ep.state = StateSynSent + h.ep.setEndpointState(StateSynSent) e.mu.Unlock() if err := h.execute(); err != nil { @@ -1202,12 +1237,11 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { e.lastErrorMu.Unlock() e.mu.Lock() - e.state = StateError + e.setEndpointState(StateError) e.HardError = err // Lock released below. epilogue() - return err } } @@ -1215,7 +1249,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { e.keepalive.timer.init(&e.keepalive.waker) defer e.keepalive.timer.cleanup() - // Tell waiters that the endpoint is connected and writable. e.mu.Lock() drained := e.drainDone != nil e.mu.Unlock() @@ -1224,8 +1257,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { <-e.undrain } - e.waiterQueue.Notify(waiter.EventOut) - // Set up the functions that will be called when the main protocol loop // wakes up. funcs := []struct { @@ -1240,18 +1271,15 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { w: &e.sndCloseWaker, f: e.handleClose, }, - { - w: &e.newSegmentWaker, - f: e.handleSegments, - }, { w: &closeWaker, f: func() *tcpip.Error { // This means the socket is being closed due - // to the TCP_FIN_WAIT2 timeout was hit. Just + // to the TCP-FIN-WAIT2 timeout was hit. Just // mark the socket as closed. e.mu.Lock() e.transitionToStateCloseLocked() + e.workerCleanup = true e.mu.Unlock() return nil }, @@ -1266,6 +1294,12 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { return nil }, }, + { + w: &e.newSegmentWaker, + f: func() *tcpip.Error { + return e.handleSegments(false /* fastPath */) + }, + }, { w: &e.keepalive.waker, f: e.keepaliveTimerExpired, @@ -1293,14 +1327,16 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { } if n¬ifyReset != 0 { - e.mu.Lock() - e.resetConnectionLocked(tcpip.ErrConnectionAborted) - e.mu.Unlock() + return tcpip.ErrConnectionAborted + } + + if n¬ifyResetByPeer != 0 { + return tcpip.ErrConnectionReset } if n¬ifyClose != 0 && closeTimer == nil { e.mu.Lock() - if e.state == StateFinWait2 && e.closed { + if e.EndpointState() == StateFinWait2 && e.closed { // The socket has been closed and we are in FIN_WAIT2 // so start the FIN_WAIT2 timer. closeTimer = time.AfterFunc(e.tcpLingerTimeout, func() { @@ -1320,11 +1356,11 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { if n¬ifyDrain != 0 { for !e.segmentQueue.empty() { - if err := e.handleSegments(); err != nil { + if err := e.handleSegments(false /* fastPath */); err != nil { return err } } - if e.state != StateClose && e.state != StateError { + if e.EndpointState() != StateClose && e.EndpointState() != StateError { // Only block the worker if the endpoint // is not in closed state or error state. close(e.drainDone) @@ -1349,14 +1385,21 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { s.AddWaker(funcs[i].w, i) } + // Notify the caller that the waker initialization is complete and the + // endpoint is ready. + if wakerInitDone != nil { + close(wakerInitDone) + } + + // Tell waiters that the endpoint is connected and writable. + e.waiterQueue.Notify(waiter.EventOut) + // The following assertions and notifications are needed for restored // endpoints. Fresh newly created endpoints have empty states and should // not invoke any. - e.segmentQueue.mu.Lock() - if !e.segmentQueue.list.Empty() { + if !e.segmentQueue.empty() { e.newSegmentWaker.Assert() } - e.segmentQueue.mu.Unlock() e.rcvListMu.Lock() if !e.rcvList.Empty() { @@ -1372,27 +1415,32 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { // Main loop. Handle segments until both send and receive ends of the // connection have completed. - for e.state != StateTimeWait && e.state != StateClose && e.state != StateError { + for e.EndpointState() != StateTimeWait && e.EndpointState() != StateClose && e.EndpointState() != StateError { e.mu.Unlock() e.workMu.Unlock() v, _ := s.Fetch(true) e.workMu.Lock() + // We need to double check here because the notification + // maybe stale by the time we got around to processing it. + // NOTE: since we now hold the workMu the processors cannot + // change the state of the endpoint so it' safe to proceed + // after this check. + if e.EndpointState() == StateTimeWait || e.EndpointState() == StateClose || e.EndpointState() == StateError { + e.mu.Lock() + break + } if err := funcs[v].f(); err != nil { e.mu.Lock() - // Ensure we release all endpoint registration and route - // references as the connection is now in an error - // state. e.workerCleanup = true e.resetConnectionLocked(err) // Lock released below. epilogue() - return nil } e.mu.Lock() } - state := e.state + state := e.EndpointState() e.mu.Unlock() var reuseTW func() if state == StateTimeWait { @@ -1405,13 +1453,15 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { s.Done() // Wake up any waiters before we enter TIME_WAIT. e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) + e.mu.Lock() + e.workerCleanup = true + e.mu.Unlock() reuseTW = e.doTimeWait() } // Mark endpoint as closed. e.mu.Lock() - if e.state != StateError { - e.stack.Stats().TCP.CurrentEstablished.Decrement() + if e.EndpointState() != StateError { e.transitionToStateCloseLocked() } @@ -1468,7 +1518,11 @@ func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func() tcpEP := listenEP.(*endpoint) if EndpointState(tcpEP.State()) == StateListen { reuseTW = func() { - tcpEP.enqueueSegment(s) + if !tcpEP.enqueueSegment(s) { + s.decRef() + return + } + tcpEP.newSegmentWaker.Assert() } // We explicitly do not decRef // the segment as it's still diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go new file mode 100644 index 000000000..a72f0c379 --- /dev/null +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -0,0 +1,218 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "gvisor.dev/gvisor/pkg/rand" + "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// epQueue is a queue of endpoints. +type epQueue struct { + mu sync.Mutex + list endpointList +} + +// enqueue adds e to the queue if the endpoint is not already on the queue. +func (q *epQueue) enqueue(e *endpoint) { + q.mu.Lock() + if e.pendingProcessing { + q.mu.Unlock() + return + } + q.list.PushBack(e) + e.pendingProcessing = true + q.mu.Unlock() +} + +// dequeue removes and returns the first element from the queue if available, +// returns nil otherwise. +func (q *epQueue) dequeue() *endpoint { + q.mu.Lock() + if e := q.list.Front(); e != nil { + q.list.Remove(e) + e.pendingProcessing = false + q.mu.Unlock() + return e + } + q.mu.Unlock() + return nil +} + +// empty returns true if the queue is empty, false otherwise. +func (q *epQueue) empty() bool { + q.mu.Lock() + v := q.list.Empty() + q.mu.Unlock() + return v +} + +// processor is responsible for processing packets queued to a tcp endpoint. +type processor struct { + epQ epQueue + newEndpointWaker sleep.Waker + id int +} + +func newProcessor(id int) *processor { + p := &processor{ + id: id, + } + go p.handleSegments() + return p +} + +func (p *processor) queueEndpoint(ep *endpoint) { + // Queue an endpoint for processing by the processor goroutine. + p.epQ.enqueue(ep) + p.newEndpointWaker.Assert() +} + +func (p *processor) handleSegments() { + const newEndpointWaker = 1 + s := sleep.Sleeper{} + s.AddWaker(&p.newEndpointWaker, newEndpointWaker) + defer s.Done() + for { + s.Fetch(true) + for ep := p.epQ.dequeue(); ep != nil; ep = p.epQ.dequeue() { + if ep.segmentQueue.empty() { + continue + } + + // If socket has transitioned out of connected state + // then just let the worker handle the packet. + // + // NOTE: We read this outside of e.mu lock which means + // that by the time we get to handleSegments the + // endpoint may not be in ESTABLISHED. But this should + // be fine as all normal shutdown states are handled by + // handleSegments and if the endpoint moves to a + // CLOSED/ERROR state then handleSegments is a noop. + if ep.EndpointState() != StateEstablished { + ep.newSegmentWaker.Assert() + continue + } + + if !ep.workMu.TryLock() { + ep.newSegmentWaker.Assert() + continue + } + // If the endpoint is in a connected state then we do + // direct delivery to ensure low latency and avoid + // scheduler interactions. + if err := ep.handleSegments(true /* fastPath */); err != nil || ep.EndpointState() == StateClose { + ep.notifyProtocolGoroutine(notifyTickleWorker) + ep.workMu.Unlock() + continue + } + + if !ep.segmentQueue.empty() { + p.epQ.enqueue(ep) + } + + ep.workMu.Unlock() + } + } +} + +// dispatcher manages a pool of TCP endpoint processors which are responsible +// for the processing of inbound segments. This fixed pool of processor +// goroutines do full tcp processing. The processor is selected based on the +// hash of the endpoint id to ensure that delivery for the same endpoint happens +// in-order. +type dispatcher struct { + processors []*processor + seed uint32 +} + +func newDispatcher(nProcessors int) *dispatcher { + processors := []*processor{} + for i := 0; i < nProcessors; i++ { + processors = append(processors, newProcessor(i)) + } + return &dispatcher{ + processors: processors, + seed: generateRandUint32(), + } +} + +func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { + ep := stackEP.(*endpoint) + s := newSegment(r, id, pkt) + if !s.parse() { + ep.stack.Stats().MalformedRcvdPackets.Increment() + ep.stack.Stats().TCP.InvalidSegmentsReceived.Increment() + ep.stats.ReceiveErrors.MalformedPacketsReceived.Increment() + s.decRef() + return + } + + if !s.csumValid { + ep.stack.Stats().MalformedRcvdPackets.Increment() + ep.stack.Stats().TCP.ChecksumErrors.Increment() + ep.stats.ReceiveErrors.ChecksumErrors.Increment() + s.decRef() + return + } + + ep.stack.Stats().TCP.ValidSegmentsReceived.Increment() + ep.stats.SegmentsReceived.Increment() + if (s.flags & header.TCPFlagRst) != 0 { + ep.stack.Stats().TCP.ResetsReceived.Increment() + } + + if !ep.enqueueSegment(s) { + s.decRef() + return + } + + // For sockets not in established state let the worker goroutine + // handle the packets. + if ep.EndpointState() != StateEstablished { + ep.newSegmentWaker.Assert() + return + } + + d.selectProcessor(id).queueEndpoint(ep) +} + +func generateRandUint32() uint32 { + b := make([]byte, 4) + if _, err := rand.Read(b); err != nil { + panic(err) + } + return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24 +} + +func (d *dispatcher) selectProcessor(id stack.TransportEndpointID) *processor { + payload := []byte{ + byte(id.LocalPort), + byte(id.LocalPort >> 8), + byte(id.RemotePort), + byte(id.RemotePort >> 8)} + + h := jenkins.Sum32(d.seed) + h.Write(payload) + h.Write([]byte(id.LocalAddress)) + h.Write([]byte(id.RemoteAddress)) + + return d.processors[h.Sum32()%uint32(len(d.processors))] +} diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index cc8b533c8..1799c6e10 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -120,6 +120,7 @@ const ( notifyMTUChanged notifyDrain notifyReset + notifyResetByPeer notifyKeepaliveChanged notifyMSSChanged // notifyTickleWorker is used to tickle the protocol main loop during a @@ -127,6 +128,7 @@ const ( // ensures the loop terminates if the final state of the endpoint is // say TIME_WAIT. notifyTickleWorker + notifyError ) // SACKInfo holds TCP SACK related information for a given endpoint. @@ -283,6 +285,18 @@ func (*EndpointInfo) IsEndpointInfo() {} type endpoint struct { EndpointInfo + // endpointEntry is used to queue endpoints for processing to the + // a given tcp processor goroutine. + // + // Precondition: epQueue.mu must be held to read/write this field.. + endpointEntry `state:"nosave"` + + // pendingProcessing is true if this endpoint is queued for processing + // to a TCP processor. + // + // Precondition: epQueue.mu must be held to read/write this field.. + pendingProcessing bool `state:"nosave"` + // workMu is used to arbitrate which goroutine may perform protocol // work. Only the main protocol goroutine is expected to call Lock() on // it, but other goroutines (e.g., send) may call TryLock() to eagerly @@ -324,6 +338,7 @@ type endpoint struct { // The following fields are protected by the mutex. mu sync.RWMutex `state:"nosave"` + // state must be read/set using the EndpointState()/setEndpointState() methods. state EndpointState `state:".(EndpointState)"` // origEndpointState is only used during a restore phase to save the @@ -359,7 +374,7 @@ type endpoint struct { workerRunning bool // workerCleanup specifies if the worker goroutine must perform cleanup - // before exitting. This can only be set to true when workerRunning is + // before exiting. This can only be set to true when workerRunning is // also true, and they're both protected by the mutex. workerCleanup bool @@ -371,6 +386,8 @@ type endpoint struct { // recentTS is the timestamp that should be sent in the TSEcr field of // the timestamp for future segments sent by the endpoint. This field is // updated if required when a new segment is received by this endpoint. + // + // recentTS must be read/written atomically. recentTS uint32 // tsOffset is a randomized offset added to the value of the @@ -567,6 +584,47 @@ func (e *endpoint) ResumeWork() { e.workMu.Unlock() } +// setEndpointState updates the state of the endpoint to state atomically. This +// method is unexported as the only place we should update the state is in this +// package but we allow the state to be read freely without holding e.mu. +// +// Precondition: e.mu must be held to call this method. +func (e *endpoint) setEndpointState(state EndpointState) { + oldstate := EndpointState(atomic.LoadUint32((*uint32)(&e.state))) + switch state { + case StateEstablished: + e.stack.Stats().TCP.CurrentEstablished.Increment() + case StateError: + fallthrough + case StateClose: + if oldstate == StateCloseWait || oldstate == StateEstablished { + e.stack.Stats().TCP.EstablishedResets.Increment() + } + fallthrough + default: + if oldstate == StateEstablished { + e.stack.Stats().TCP.CurrentEstablished.Decrement() + } + } + atomic.StoreUint32((*uint32)(&e.state), uint32(state)) +} + +// EndpointState returns the current state of the endpoint. +func (e *endpoint) EndpointState() EndpointState { + return EndpointState(atomic.LoadUint32((*uint32)(&e.state))) +} + +// setRecentTimestamp atomically sets the recentTS field to the +// provided value. +func (e *endpoint) setRecentTimestamp(recentTS uint32) { + atomic.StoreUint32(&e.recentTS, recentTS) +} + +// recentTimestamp atomically reads and returns the value of the recentTS field. +func (e *endpoint) recentTimestamp() uint32 { + return atomic.LoadUint32(&e.recentTS) +} + // keepalive is a synchronization wrapper used to appease stateify. See the // comment in endpoint, where it is used. // @@ -656,7 +714,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { e.mu.RLock() defer e.mu.RUnlock() - switch e.state { + switch e.EndpointState() { case StateInitial, StateBound, StateConnecting, StateSynSent, StateSynRecv: // Ready for nothing. @@ -672,7 +730,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { } } } - if e.state.connected() { + if e.EndpointState().connected() { // Determine if the endpoint is writable if requested. if (mask & waiter.EventOut) != 0 { e.sndBufMu.Lock() @@ -733,14 +791,20 @@ func (e *endpoint) Close() { // Issue a shutdown so that the peer knows we won't send any more data // if we're connected, or stop accepting if we're listening. e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead) + e.closeNoShutdown() +} +// closeNoShutdown closes the endpoint without doing a full shutdown. This is +// used when a connection needs to be aborted with a RST and we want to skip +// a full 4 way TCP shutdown. +func (e *endpoint) closeNoShutdown() { e.mu.Lock() // For listening sockets, we always release ports inline so that they // are immediately available for reuse after Close() is called. If also // registered, we unregister as well otherwise the next user would fail // in Listen() when trying to register. - if e.state == StateListen && e.isPortReserved { + if e.EndpointState() == StateListen && e.isPortReserved { if e.isRegistered { e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice) e.isRegistered = false @@ -780,6 +844,8 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() { defer close(done) for n := range e.acceptedChan { n.notifyProtocolGoroutine(notifyReset) + // close all connections that have completed but + // not accepted by the application. n.Close() } }() @@ -797,11 +863,13 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() { // after Close() is called and the worker goroutine (if any) is done with its // work. func (e *endpoint) cleanupLocked() { + // Close all endpoints that might have been accepted by TCP but not by // the client. if e.acceptedChan != nil { e.closePendingAcceptableConnectionsLocked() } + e.workerCleanup = false if e.isRegistered { @@ -920,7 +988,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, // reads to proceed before returning a ECONNRESET. e.rcvListMu.Lock() bufUsed := e.rcvBufUsed - if s := e.state; !s.connected() && s != StateClose && bufUsed == 0 { + if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 { e.rcvListMu.Unlock() he := e.HardError e.mu.RUnlock() @@ -944,7 +1012,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { if e.rcvBufUsed == 0 { - if e.rcvClosed || !e.state.connected() { + if e.rcvClosed || !e.EndpointState().connected() { return buffer.View{}, tcpip.ErrClosedForReceive } return buffer.View{}, tcpip.ErrWouldBlock @@ -980,8 +1048,8 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { // Caller must hold e.mu and e.sndBufMu func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) { // The endpoint cannot be written to if it's not connected. - if !e.state.connected() { - switch e.state { + if !e.EndpointState().connected() { + switch e.EndpointState() { case StateError: return 0, e.HardError default: @@ -1039,42 +1107,86 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return 0, nil, perr } - if !opts.Atomic { // See above. - e.mu.RLock() - e.sndBufMu.Lock() + if opts.Atomic { + // Add data to the send queue. + s := newSegmentFromView(&e.route, e.ID, v) + e.sndBufUsed += len(v) + e.sndBufInQueue += seqnum.Size(len(v)) + e.sndQueue.PushBack(s) + e.sndBufMu.Unlock() + // Release the endpoint lock to prevent deadlocks due to lock + // order inversion when acquiring workMu. + e.mu.RUnlock() + } - // Because we released the lock before copying, check state again - // to make sure the endpoint is still in a valid state for a write. - avail, err = e.isEndpointWritableLocked() - if err != nil { + if e.workMu.TryLock() { + // Since we released locks in between it's possible that the + // endpoint transitioned to a CLOSED/ERROR states so make + // sure endpoint is still writable before trying to write. + if !opts.Atomic { // See above. + e.mu.RLock() + e.sndBufMu.Lock() + + // Because we released the lock before copying, check state again + // to make sure the endpoint is still in a valid state for a write. + avail, err = e.isEndpointWritableLocked() + if err != nil { + e.sndBufMu.Unlock() + e.mu.RUnlock() + e.stats.WriteErrors.WriteClosed.Increment() + return 0, nil, err + } + + // Discard any excess data copied in due to avail being reduced due + // to a simultaneous write call to the socket. + if avail < len(v) { + v = v[:avail] + } + // Add data to the send queue. + s := newSegmentFromView(&e.route, e.ID, v) + e.sndBufUsed += len(v) + e.sndBufInQueue += seqnum.Size(len(v)) + e.sndQueue.PushBack(s) e.sndBufMu.Unlock() + // Release the endpoint lock to prevent deadlocks due to lock + // order inversion when acquiring workMu. e.mu.RUnlock() - e.stats.WriteErrors.WriteClosed.Increment() - return 0, nil, err - } - // Discard any excess data copied in due to avail being reduced due - // to a simultaneous write call to the socket. - if avail < len(v) { - v = v[:avail] } - } - - // Add data to the send queue. - s := newSegmentFromView(&e.route, e.ID, v) - e.sndBufUsed += len(v) - e.sndBufInQueue += seqnum.Size(len(v)) - e.sndQueue.PushBack(s) - e.sndBufMu.Unlock() - // Release the endpoint lock to prevent deadlocks due to lock - // order inversion when acquiring workMu. - e.mu.RUnlock() - - if e.workMu.TryLock() { // Do the work inline. e.handleWrite() e.workMu.Unlock() } else { + if !opts.Atomic { // See above. + e.mu.RLock() + e.sndBufMu.Lock() + + // Because we released the lock before copying, check state again + // to make sure the endpoint is still in a valid state for a write. + avail, err = e.isEndpointWritableLocked() + if err != nil { + e.sndBufMu.Unlock() + e.mu.RUnlock() + e.stats.WriteErrors.WriteClosed.Increment() + return 0, nil, err + } + + // Discard any excess data copied in due to avail being reduced due + // to a simultaneous write call to the socket. + if avail < len(v) { + v = v[:avail] + } + // Add data to the send queue. + s := newSegmentFromView(&e.route, e.ID, v) + e.sndBufUsed += len(v) + e.sndBufInQueue += seqnum.Size(len(v)) + e.sndQueue.PushBack(s) + e.sndBufMu.Unlock() + // Release the endpoint lock to prevent deadlocks due to lock + // order inversion when acquiring workMu. + e.mu.RUnlock() + + } // Let the protocol goroutine do the work. e.sndWaker.Assert() } @@ -1091,7 +1203,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro // The endpoint can be read if it's connected, or if it's already closed // but has some pending unread data. - if s := e.state; !s.connected() && s != StateClose { + if s := e.EndpointState(); !s.connected() && s != StateClose { if s == StateError { return 0, tcpip.ControlMessages{}, e.HardError } @@ -1103,7 +1215,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro defer e.rcvListMu.Unlock() if e.rcvBufUsed == 0 { - if e.rcvClosed || !e.state.connected() { + if e.rcvClosed || !e.EndpointState().connected() { e.stats.ReadErrors.ReadClosed.Increment() return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive } @@ -1187,7 +1299,7 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { defer e.mu.Unlock() // We only allow this to be set when we're in the initial state. - if e.state != StateInitial { + if e.EndpointState() != StateInitial { return tcpip.ErrInvalidEndpointState } @@ -1402,14 +1514,14 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { // Acquire the work mutex as we may need to // reinitialize the congestion control state. e.mu.Lock() - state := e.state + state := e.EndpointState() e.cc = v e.mu.Unlock() switch state { case StateEstablished: e.workMu.Lock() e.mu.Lock() - if e.state == state { + if e.EndpointState() == state { e.snd.cc = e.snd.initCongestionControl(e.cc) } e.mu.Unlock() @@ -1472,7 +1584,7 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) { defer e.mu.RUnlock() // The endpoint cannot be in listen state. - if e.state == StateListen { + if e.EndpointState() == StateListen { return 0, tcpip.ErrInvalidEndpointState } @@ -1731,7 +1843,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc return err } - if e.state.connected() { + if e.EndpointState().connected() { // The endpoint is already connected. If caller hasn't been // notified yet, return success. if !e.isConnectNotified { @@ -1743,7 +1855,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } nicID := addr.NIC - switch e.state { + switch e.EndpointState() { case StateBound: // If we're already bound to a NIC but the caller is requesting // that we use a different one now, we cannot proceed. @@ -1850,7 +1962,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } e.isRegistered = true - e.state = StateConnecting + e.setEndpointState(StateConnecting) e.route = r.Clone() e.boundNICID = nicID e.effectiveNetProtos = netProtos @@ -1871,14 +1983,13 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } e.segmentQueue.mu.Unlock() e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0) - e.state = StateEstablished - e.stack.Stats().TCP.CurrentEstablished.Increment() + e.setEndpointState(StateEstablished) } if run { e.workerRunning = true e.stack.Stats().TCP.ActiveConnectionOpenings.Increment() - go e.protocolMainLoop(handshake) // S/R-SAFE: will be drained before save. + go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save. } return tcpip.ErrConnectStarted @@ -1896,7 +2007,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { e.shutdownFlags |= flags finQueued := false switch { - case e.state.connected(): + case e.EndpointState().connected(): // Close for read. if (e.shutdownFlags & tcpip.ShutdownRead) != 0 { // Mark read side as closed. @@ -1908,8 +2019,23 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { // If we're fully closed and we have unread data we need to abort // the connection with a RST. if (e.shutdownFlags&tcpip.ShutdownWrite) != 0 && rcvBufUsed > 0 { - e.notifyProtocolGoroutine(notifyReset) + // Move the socket to error state immediately. + // This is done redundantly because in case of + // save/restore on a Shutdown/Close() the socket + // state needs to indicate the error otherwise + // save file will show the socket in established + // state even though snd/rcv are closed. e.mu.Unlock() + // Try to send an active reset immediately if the + // work mutex is available. + if e.workMu.TryLock() { + e.mu.Lock() + e.resetConnectionLocked(tcpip.ErrConnectionAborted) + e.mu.Unlock() + e.workMu.Unlock() + } else { + e.notifyProtocolGoroutine(notifyReset) + } return nil } } @@ -1931,11 +2057,10 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { finQueued = true // Mark endpoint as closed. e.sndClosed = true - e.sndBufMu.Unlock() } - case e.state == StateListen: + case e.EndpointState() == StateListen: // Tell protocolListenLoop to stop. if flags&tcpip.ShutdownRead != 0 { e.notifyProtocolGoroutine(notifyClose) @@ -1976,7 +2101,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { // When the endpoint shuts down, it sets workerCleanup to true, and from // that point onward, acceptedChan is the responsibility of the cleanup() // method (and should not be touched anywhere else, including here). - if e.state == StateListen && !e.workerCleanup { + if e.EndpointState() == StateListen && !e.workerCleanup { // Adjust the size of the channel iff we can fix existing // pending connections into the new one. if len(e.acceptedChan) > backlog { @@ -1994,7 +2119,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { return nil } - if e.state == StateInitial { + if e.EndpointState() == StateInitial { // The listen is called on an unbound socket, the socket is // automatically bound to a random free port with the local // address set to INADDR_ANY. @@ -2004,7 +2129,7 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { } // Endpoint must be bound before it can transition to listen mode. - if e.state != StateBound { + if e.EndpointState() != StateBound { e.stats.ReadErrors.InvalidEndpointState.Increment() return tcpip.ErrInvalidEndpointState } @@ -2015,24 +2140,27 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { } e.isRegistered = true - e.state = StateListen + e.setEndpointState(StateListen) + if e.acceptedChan == nil { e.acceptedChan = make(chan *endpoint, backlog) } e.workerRunning = true - go e.protocolListenLoop( // S/R-SAFE: drained on save. seqnum.Size(e.receiveBufferAvailable())) - return nil } // startAcceptedLoop sets up required state and starts a goroutine with the // main loop for accepted connections. func (e *endpoint) startAcceptedLoop(waiterQueue *waiter.Queue) { + e.mu.Lock() e.waiterQueue = waiterQueue e.workerRunning = true - go e.protocolMainLoop(false) // S/R-SAFE: drained on save. + e.mu.Unlock() + wakerInitDone := make(chan struct{}) + go e.protocolMainLoop(false, wakerInitDone) // S/R-SAFE: drained on save. + <-wakerInitDone } // Accept returns a new endpoint if a peer has established a connection @@ -2042,7 +2170,7 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { defer e.mu.RUnlock() // Endpoint must be in listen state before it can accept connections. - if e.state != StateListen { + if e.EndpointState() != StateListen { return nil, nil, tcpip.ErrInvalidEndpointState } @@ -2069,7 +2197,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { // Don't allow binding once endpoint is not in the initial state // anymore. This is because once the endpoint goes into a connected or // listen state, it is already bound. - if e.state != StateInitial { + if e.EndpointState() != StateInitial { return tcpip.ErrAlreadyBound } @@ -2131,7 +2259,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { } // Mark endpoint as bound. - e.state = StateBound + e.setEndpointState(StateBound) return nil } @@ -2153,7 +2281,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - if !e.state.connected() { + if !e.EndpointState().connected() { return tcpip.FullAddress{}, tcpip.ErrNotConnected } @@ -2164,45 +2292,22 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { }, nil } -// HandlePacket is called by the stack when new packets arrive to this transport -// endpoint. func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { - s := newSegment(r, id, pkt) - if !s.parse() { - e.stack.Stats().MalformedRcvdPackets.Increment() - e.stack.Stats().TCP.InvalidSegmentsReceived.Increment() - e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() - s.decRef() - return - } - - if !s.csumValid { - e.stack.Stats().MalformedRcvdPackets.Increment() - e.stack.Stats().TCP.ChecksumErrors.Increment() - e.stats.ReceiveErrors.ChecksumErrors.Increment() - s.decRef() - return - } - - e.stack.Stats().TCP.ValidSegmentsReceived.Increment() - e.stats.SegmentsReceived.Increment() - if (s.flags & header.TCPFlagRst) != 0 { - e.stack.Stats().TCP.ResetsReceived.Increment() - } - - e.enqueueSegment(s) + // TCP HandlePacket is not required anymore as inbound packets first + // land at the Dispatcher which then can either delivery using the + // worker go routine or directly do the invoke the tcp processing inline + // based on the state of the endpoint. } -func (e *endpoint) enqueueSegment(s *segment) { +func (e *endpoint) enqueueSegment(s *segment) bool { // Send packet to worker goroutine. - if e.segmentQueue.enqueue(s) { - e.newSegmentWaker.Assert() - } else { + if !e.segmentQueue.enqueue(s) { // The queue is full, so we drop the segment. e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.SegmentQueueDropped.Increment() - s.decRef() + return false } + return true } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. @@ -2319,8 +2424,8 @@ func (e *endpoint) rcvWndScaleForHandshake() int { // updateRecentTimestamp updates the recent timestamp using the algorithm // described in https://tools.ietf.org/html/rfc7323#section-4.3 func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value, segSeq seqnum.Value) { - if e.sendTSOk && seqnum.Value(e.recentTS).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) { - e.recentTS = tsVal + if e.sendTSOk && seqnum.Value(e.recentTimestamp()).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) { + e.setRecentTimestamp(tsVal) } } @@ -2330,7 +2435,7 @@ func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value, func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) { if synOpts.TS { e.sendTSOk = true - e.recentTS = synOpts.TSVal + e.setRecentTimestamp(synOpts.TSVal) } } @@ -2419,7 +2524,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState { // Endpoint TCP Option state. s.SendTSOk = e.sendTSOk - s.RecentTS = e.recentTS + s.RecentTS = e.recentTimestamp() s.TSOffset = e.tsOffset s.SACKPermitted = e.sackPermitted s.SACK.Blocks = make([]header.SACKBlock, e.sack.NumBlocks) @@ -2526,9 +2631,7 @@ func (e *endpoint) initGSO() { // State implements tcpip.Endpoint.State. It exports the endpoint's protocol // state for diagnostics. func (e *endpoint) State() uint32 { - e.mu.Lock() - defer e.mu.Unlock() - return uint32(e.state) + return uint32(e.EndpointState()) } // Info returns a copy of the endpoint info. diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 4b8d867bc..4a46f0ec5 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -16,6 +16,7 @@ package tcp import ( "fmt" + "sync/atomic" "time" "gvisor.dev/gvisor/pkg/sync" @@ -48,7 +49,7 @@ func (e *endpoint) beforeSave() { e.mu.Lock() defer e.mu.Unlock() - switch e.state { + switch e.EndpointState() { case StateInitial, StateBound: // TODO(b/138137272): this enumeration duplicates // EndpointState.connected. remove it. @@ -70,31 +71,30 @@ func (e *endpoint) beforeSave() { fallthrough case StateListen, StateConnecting: e.drainSegmentLocked() - if e.state != StateClose && e.state != StateError { + if e.EndpointState() != StateClose && e.EndpointState() != StateError { if !e.workerRunning { panic("endpoint has no worker running in listen, connecting, or connected state") } break } - fallthrough case StateError, StateClose: - for (e.state == StateError || e.state == StateClose) && e.workerRunning { + for e.workerRunning { e.mu.Unlock() time.Sleep(100 * time.Millisecond) e.mu.Lock() } if e.workerRunning { - panic("endpoint still has worker running in closed or error state") + panic(fmt.Sprintf("endpoint: %+v still has worker running in closed or error state", e.ID)) } default: - panic(fmt.Sprintf("endpoint in unknown state %v", e.state)) + panic(fmt.Sprintf("endpoint in unknown state %v", e.EndpointState())) } if e.waiterQueue != nil && !e.waiterQueue.IsEmpty() { panic("endpoint still has waiters upon save") } - if e.state != StateClose && !((e.state == StateBound || e.state == StateListen) == e.isPortReserved) { + if e.EndpointState() != StateClose && !((e.EndpointState() == StateBound || e.EndpointState() == StateListen) == e.isPortReserved) { panic("endpoints which are not in the closed state must have a reserved port IFF they are in bound or listen state") } } @@ -135,7 +135,7 @@ func (e *endpoint) loadAcceptedChan(acceptedEndpoints []*endpoint) { // saveState is invoked by stateify. func (e *endpoint) saveState() EndpointState { - return e.state + return e.EndpointState() } // Endpoint loading must be done in the following ordering by their state, to @@ -151,7 +151,8 @@ var connectingLoading sync.WaitGroup func (e *endpoint) loadState(state EndpointState) { // This is to ensure that the loading wait groups include all applicable // endpoints before any asynchronous calls to the Wait() methods. - if state.connected() { + // For restore purposes we treat TimeWait like a connected endpoint. + if state.connected() || state == StateTimeWait { connectedLoading.Add(1) } switch state { @@ -160,13 +161,14 @@ func (e *endpoint) loadState(state EndpointState) { case StateConnecting, StateSynSent, StateSynRecv: connectingLoading.Add(1) } - e.state = state + // Directly update the state here rather than using e.setEndpointState + // as the endpoint is still being loaded and the stack reference to increment + // metrics is not yet initialized. + atomic.StoreUint32((*uint32)(&e.state), uint32(state)) } // afterLoad is invoked by stateify. func (e *endpoint) afterLoad() { - // Freeze segment queue before registering to prevent any segments - // from being delivered while it is being restored. e.origEndpointState = e.state // Restore the endpoint to InitialState as it will be moved to // its origEndpointState during Resume. @@ -180,7 +182,6 @@ func (e *endpoint) Resume(s *stack.Stack) { e.segmentQueue.setLimit(MaxUnprocessedSegments) e.workMu.Init() state := e.origEndpointState - switch state { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: var ss SendBufferSizeOption @@ -276,7 +277,7 @@ func (e *endpoint) Resume(s *stack.Stack) { listenLoading.Wait() connectingLoading.Wait() bind() - e.state = StateClose + e.setEndpointState(StateClose) tcpip.AsyncLoading.Done() }() } @@ -288,6 +289,7 @@ func (e *endpoint) Resume(s *stack.Stack) { e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) } + } // saveLastError is invoked by stateify. diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 9a8f64aa6..958c06fa7 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -21,6 +21,7 @@ package tcp import ( + "runtime" "strings" "time" @@ -104,6 +105,7 @@ type protocol struct { moderateReceiveBuffer bool tcpLingerTimeout time.Duration tcpTimeWaitTimeout time.Duration + dispatcher *dispatcher } // Number returns the tcp protocol number. @@ -134,6 +136,14 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { return h.SourcePort(), h.DestinationPort(), nil } +// QueuePacket queues packets targeted at an endpoint after hashing the packet +// to a specific processing queue. Each queue is serviced by its own processor +// goroutine which is responsible for dequeuing and doing full TCP dispatch of +// the packet. +func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { + p.dispatcher.queuePacket(r, ep, id, pkt) +} + // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. // @@ -330,5 +340,6 @@ func NewProtocol() stack.TransportProtocol { availableCongestionControl: []string{ccReno, ccCubic}, tcpLingerTimeout: DefaultTCPLingerTimeout, tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout, + dispatcher: newDispatcher(runtime.GOMAXPROCS(0)), } } diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index 05c8488f8..958f03ac1 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -169,19 +169,19 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // We just received a FIN, our next state depends on whether we sent a // FIN already or not. r.ep.mu.Lock() - switch r.ep.state { + switch r.ep.EndpointState() { case StateEstablished: - r.ep.state = StateCloseWait + r.ep.setEndpointState(StateCloseWait) case StateFinWait1: if s.flagIsSet(header.TCPFlagAck) { // FIN-ACK, transition to TIME-WAIT. - r.ep.state = StateTimeWait + r.ep.setEndpointState(StateTimeWait) } else { // Simultaneous close, expecting a final ACK. - r.ep.state = StateClosing + r.ep.setEndpointState(StateClosing) } case StateFinWait2: - r.ep.state = StateTimeWait + r.ep.setEndpointState(StateTimeWait) } r.ep.mu.Unlock() @@ -205,16 +205,16 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // shutdown states. if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt { r.ep.mu.Lock() - switch r.ep.state { + switch r.ep.EndpointState() { case StateFinWait1: - r.ep.state = StateFinWait2 + r.ep.setEndpointState(StateFinWait2) // Notify protocol goroutine that we have received an // ACK to our FIN so that it can start the FIN_WAIT2 // timer to abort connection if the other side does // not close within 2MSL. r.ep.notifyProtocolGoroutine(notifyClose) case StateClosing: - r.ep.state = StateTimeWait + r.ep.setEndpointState(StateTimeWait) case StateLastAck: r.ep.transitionToStateCloseLocked() } @@ -267,7 +267,6 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo switch state { case StateCloseWait, StateClosing, StateLastAck: if !s.sequenceNumber.LessThanEq(r.rcvNxt) { - s.decRef() // Just drop the segment as we have // already received a FIN and this // segment is after the sequence number @@ -284,7 +283,6 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // trigger a RST. endDataSeq := s.sequenceNumber.Add(seqnum.Size(s.data.Size())) if rcvClosed && r.rcvNxt.LessThan(endDataSeq) { - s.decRef() return true, tcpip.ErrConnectionAborted } if state == StateFinWait1 { @@ -314,7 +312,6 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // the last actual data octet in a segment in // which it occurs. if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.rcvNxt+1) { - s.decRef() return true, tcpip.ErrConnectionAborted } } @@ -336,7 +333,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // r as they arrive. It is called by the protocol main loop. func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { r.ep.mu.RLock() - state := r.ep.state + state := r.ep.EndpointState() closed := r.ep.closed r.ep.mu.RUnlock() diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index fdff7ed81..b74b61e7d 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -705,17 +705,15 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se } seg.flags = header.TCPFlagAck | header.TCPFlagFin segEnd = seg.sequenceNumber.Add(1) - // Transition to FIN-WAIT1 state since we're initiating an active close. - s.ep.mu.Lock() - switch s.ep.state { + // Update the state to reflect that we have now + // queued a FIN. + switch s.ep.EndpointState() { case StateCloseWait: - // We've already received a FIN and are now sending our own. The - // sender is now awaiting a final ACK for this FIN. - s.ep.state = StateLastAck + s.ep.setEndpointState(StateLastAck) default: - s.ep.state = StateFinWait1 + s.ep.setEndpointState(StateFinWait1) } - s.ep.mu.Unlock() + } else { // We're sending a non-FIN segment. if seg.flags&header.TCPFlagFin != 0 { diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 6edfa8dce..a9dfbe857 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -293,7 +293,6 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { checker.SeqNum(uint32(c.IRS+1)), checker.AckNum(uint32(iss)+1), checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) - finHeaders := &context.Headers{ SrcPort: context.TestPort, DstPort: context.StackPort, @@ -459,6 +458,9 @@ func TestConnectResetAfterClose(t *testing.T) { checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), + // RST is always generated with sndNxt which if the FIN + // has been sent will be 1 higher than the sequence number + // of the FIN itself. checker.SeqNum(uint32(c.IRS)+2), checker.AckNum(0), checker.TCPFlags(header.TCPFlagRst), @@ -1500,6 +1502,9 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + // RST is always generated with sndNxt which if the FIN + // has been sent will be 1 higher than the sequence + // number of the FIN itself. checker.SeqNum(uint32(c.IRS)+2), )) // The RST puts the endpoint into an error state. @@ -5441,6 +5446,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { rawEP.SendPacketWithTS(b[start:start+mss], tsVal) packetsSent++ } + // Resume the worker so that it only sees the packets once all of them // are waiting to be read. worker.ResumeWork() @@ -5508,7 +5514,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) { stk := c.Stack() // Set lower limits for auto-tuning tests. This is required because the // test stops the worker which can cause packets to be dropped because - // the segment queue holding unprocessed packets is limited to 500. + // the segment queue holding unprocessed packets is limited to 300. const receiveBufferSize = 80 << 10 // 80KB. const maxReceiveBufferSize = receiveBufferSize * 10 if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, receiveBufferSize, maxReceiveBufferSize}); err != nil { @@ -5563,6 +5569,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) { totalSent += mss packetsSent++ } + // Resume it so that it only sees the packets once all of them // are waiting to be read. worker.ResumeWork() diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index 5d114d460..2f9821555 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -533,7 +533,7 @@ TEST_P(SocketInetLoopbackTest, TCPFinWait2Test_NoRandomSave) { // Sleep for a little over the linger timeout to reduce flakiness in // save/restore tests. - absl::SleepFor(absl::Seconds(kTCPLingerTimeout + 1)); + absl::SleepFor(absl::Seconds(kTCPLingerTimeout + 2)); ds.reset(); diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index 6b99c021d..33a5ac66c 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -814,6 +814,20 @@ TEST_P(TcpSocketTest, FullBuffer) { t_ = -1; } +TEST_P(TcpSocketTest, PollAfterShutdown) { + ScopedThread client_thread([this]() { + EXPECT_THAT(shutdown(s_, SHUT_WR), SyscallSucceedsWithValue(0)); + struct pollfd poll_fd = {s_, POLLIN | POLLERR | POLLHUP, 0}; + EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10000), + SyscallSucceedsWithValue(1)); + }); + + EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceedsWithValue(0)); + struct pollfd poll_fd = {t_, POLLIN | POLLERR | POLLHUP, 0}; + EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10000), + SyscallSucceedsWithValue(1)); +} + TEST_P(SimpleTcpSocketTest, NonBlockingConnectNoListener) { // Initialize address to the loopback one. sockaddr_storage addr = -- cgit v1.2.3 From 815df2959a76e4a19f5882e40402b9bbca9e70be Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Wed, 15 Jan 2020 16:43:36 -0800 Subject: Solicit IPv6 routers when a NIC becomes enabled as a host This change adds support to send NDP Router Solicitation messages when a NIC becomes enabled as a host, as per RFC 4861 section 6.3.7. Note, Router Solicitations will only be sent when the stack has forwarding disabled. Tests: Unittests to make sure that the initial Router Solicitations are sent as configured. The tests also validate the sent Router Solicitations' fields. PiperOrigin-RevId: 289964095 --- pkg/tcpip/checker/checker.go | 6 + pkg/tcpip/header/BUILD | 1 + pkg/tcpip/header/ipv6.go | 7 ++ pkg/tcpip/header/ndp_router_solicit.go | 36 ++++++ pkg/tcpip/stack/BUILD | 2 +- pkg/tcpip/stack/ndp.go | 175 +++++++++++++++++++++++--- pkg/tcpip/stack/ndp_test.go | 224 +++++++++++++++++++++++++++++++++ pkg/tcpip/stack/nic.go | 73 +++++++---- pkg/tcpip/stack/stack.go | 8 +- 9 files changed, 486 insertions(+), 46 deletions(-) create mode 100644 pkg/tcpip/header/ndp_router_solicit.go (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 542abc99d..885d773b0 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -770,3 +770,9 @@ func NDPNSTargetAddress(want tcpip.Address) TransportChecker { } } } + +// NDPRS creates a checker that checks that the packet contains a valid NDP +// Router Solicitation message (as per the raw wire format). +func NDPRS() NetworkChecker { + return NDP(header.ICMPv6RouterSolicit, header.NDPRSMinimumSize) +} diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD index f2061c778..cd747d100 100644 --- a/pkg/tcpip/header/BUILD +++ b/pkg/tcpip/header/BUILD @@ -20,6 +20,7 @@ go_library( "ndp_neighbor_solicit.go", "ndp_options.go", "ndp_router_advert.go", + "ndp_router_solicit.go", "tcp.go", "udp.go", ], diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index 83425c614..70e6ce095 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -84,6 +84,13 @@ const ( // The address is ff02::1. IPv6AllNodesMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + // IPv6AllRoutersMulticastAddress is a link-local multicast group that + // all IPv6 routers MUST join, as per RFC 4291, section 2.8. Packets + // destined to this address will reach all routers on a link. + // + // The address is ff02::2. + IPv6AllRoutersMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + // IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 2460, // section 5. IPv6MinimumMTU = 1280 diff --git a/pkg/tcpip/header/ndp_router_solicit.go b/pkg/tcpip/header/ndp_router_solicit.go new file mode 100644 index 000000000..9e67ba95d --- /dev/null +++ b/pkg/tcpip/header/ndp_router_solicit.go @@ -0,0 +1,36 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +// NDPRouterSolicit is an NDP Router Solicitation message. It will only contain +// the body of an ICMPv6 packet. +// +// See RFC 4861 section 4.1 for more details. +type NDPRouterSolicit []byte + +const ( + // NDPRSMinimumSize is the minimum size of a valid NDP Router + // Solicitation message (body of an ICMPv6 packet). + NDPRSMinimumSize = 4 + + // ndpRSOptionsOffset is the start of the NDP options in an + // NDPRouterSolicit. + ndpRSOptionsOffset = 4 +) + +// Options returns an NDPOptions of the the options body. +func (b NDPRouterSolicit) Options() NDPOptions { + return NDPOptions(b[ndpRSOptionsOffset:]) +} diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 705e984c1..783351a69 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -51,7 +51,7 @@ go_library( go_test( name = "stack_x_test", - size = "small", + size = "medium", srcs = [ "ndp_test.go", "stack_test.go", diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index a9dd322db..acefc356a 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -17,6 +17,7 @@ package stack import ( "fmt" "log" + "math/rand" "time" "gvisor.dev/gvisor/pkg/tcpip" @@ -38,24 +39,36 @@ const ( // Default = 1s (from RFC 4861 section 10). defaultRetransmitTimer = time.Second + // defaultMaxRtrSolicitations is the default number of Router + // Solicitation messages to send when a NIC becomes enabled. + // + // Default = 3 (from RFC 4861 section 10). + defaultMaxRtrSolicitations = 3 + + // defaultRtrSolicitationInterval is the default amount of time between + // sending Router Solicitation messages. + // + // Default = 4s (from 4861 section 10). + defaultRtrSolicitationInterval = 4 * time.Second + + // defaultMaxRtrSolicitationDelay is the default maximum amount of time + // to wait before sending the first Router Solicitation message. + // + // Default = 1s (from 4861 section 10). + defaultMaxRtrSolicitationDelay = time.Second + // defaultHandleRAs is the default configuration for whether or not to // handle incoming Router Advertisements as a host. - // - // Default = true. defaultHandleRAs = true // defaultDiscoverDefaultRouters is the default configuration for // whether or not to discover default routers from incoming Router // Advertisements, as a host. - // - // Default = true. defaultDiscoverDefaultRouters = true // defaultDiscoverOnLinkPrefixes is the default configuration for // whether or not to discover on-link prefixes from incoming Router // Advertisements' Prefix Information option, as a host. - // - // Default = true. defaultDiscoverOnLinkPrefixes = true // defaultAutoGenGlobalAddresses is the default configuration for @@ -74,26 +87,31 @@ const ( // value of 0 means unspecified, so the smallest valid value is 1. // Note, the unit of the RetransmitTimer field in the Router // Advertisement is milliseconds. - // - // Min = 1ms. minimumRetransmitTimer = time.Millisecond + // minimumRtrSolicitationInterval is the minimum amount of time to wait + // between sending Router Solicitation messages. This limit is imposed + // to make sure that Router Solicitation messages are not sent all at + // once, defeating the purpose of sending the initial few messages. + minimumRtrSolicitationInterval = 500 * time.Millisecond + + // minimumMaxRtrSolicitationDelay is the minimum amount of time to wait + // before sending the first Router Solicitation message. It is 0 because + // we cannot have a negative delay. + minimumMaxRtrSolicitationDelay = 0 + // MaxDiscoveredDefaultRouters is the maximum number of discovered // default routers. The stack should stop discovering new routers after // discovering MaxDiscoveredDefaultRouters routers. // // This value MUST be at minimum 2 as per RFC 4861 section 6.3.4, and // SHOULD be more. - // - // Max = 10. MaxDiscoveredDefaultRouters = 10 // MaxDiscoveredOnLinkPrefixes is the maximum number of discovered // on-link prefixes. The stack should stop discovering new on-link // prefixes after discovering MaxDiscoveredOnLinkPrefixes on-link // prefixes. - // - // Max = 10. MaxDiscoveredOnLinkPrefixes = 10 // validPrefixLenForAutoGen is the expected prefix length that an @@ -245,9 +263,24 @@ type NDPConfigurations struct { // The amount of time to wait between sending Neighbor solicitation // messages. // - // Must be greater than 0.5s. + // Must be greater than or equal to 1ms. RetransmitTimer time.Duration + // The number of Router Solicitation messages to send when the NIC + // becomes enabled. + MaxRtrSolicitations uint8 + + // The amount of time between transmitting Router Solicitation messages. + // + // Must be greater than or equal to 0.5s. + RtrSolicitationInterval time.Duration + + // The maximum amount of time before transmitting the first Router + // Solicitation message. + // + // Must be greater than or equal to 0s. + MaxRtrSolicitationDelay time.Duration + // HandleRAs determines whether or not Router Advertisements will be // processed. HandleRAs bool @@ -278,12 +311,15 @@ type NDPConfigurations struct { // default values. func DefaultNDPConfigurations() NDPConfigurations { return NDPConfigurations{ - DupAddrDetectTransmits: defaultDupAddrDetectTransmits, - RetransmitTimer: defaultRetransmitTimer, - HandleRAs: defaultHandleRAs, - DiscoverDefaultRouters: defaultDiscoverDefaultRouters, - DiscoverOnLinkPrefixes: defaultDiscoverOnLinkPrefixes, - AutoGenGlobalAddresses: defaultAutoGenGlobalAddresses, + DupAddrDetectTransmits: defaultDupAddrDetectTransmits, + RetransmitTimer: defaultRetransmitTimer, + MaxRtrSolicitations: defaultMaxRtrSolicitations, + RtrSolicitationInterval: defaultRtrSolicitationInterval, + MaxRtrSolicitationDelay: defaultMaxRtrSolicitationDelay, + HandleRAs: defaultHandleRAs, + DiscoverDefaultRouters: defaultDiscoverDefaultRouters, + DiscoverOnLinkPrefixes: defaultDiscoverOnLinkPrefixes, + AutoGenGlobalAddresses: defaultAutoGenGlobalAddresses, } } @@ -292,10 +328,24 @@ func DefaultNDPConfigurations() NDPConfigurations { // // If RetransmitTimer is less than minimumRetransmitTimer, then a value of // defaultRetransmitTimer will be used. +// +// If RtrSolicitationInterval is less than minimumRtrSolicitationInterval, then +// a value of defaultRtrSolicitationInterval will be used. +// +// If MaxRtrSolicitationDelay is less than minimumMaxRtrSolicitationDelay, then +// a value of defaultMaxRtrSolicitationDelay will be used. func (c *NDPConfigurations) validate() { if c.RetransmitTimer < minimumRetransmitTimer { c.RetransmitTimer = defaultRetransmitTimer } + + if c.RtrSolicitationInterval < minimumRtrSolicitationInterval { + c.RtrSolicitationInterval = defaultRtrSolicitationInterval + } + + if c.MaxRtrSolicitationDelay < minimumMaxRtrSolicitationDelay { + c.MaxRtrSolicitationDelay = defaultMaxRtrSolicitationDelay + } } // ndpState is the per-interface NDP state. @@ -316,6 +366,10 @@ type ndpState struct { // Information option. onLinkPrefixes map[tcpip.Subnet]onLinkPrefixState + // The timer used to send the next router solicitation message. + // If routers are being solicited, rtrSolicitTimer MUST NOT be nil. + rtrSolicitTimer *time.Timer + // The addresses generated by SLAAC. autoGenAddresses map[tcpip.Address]autoGenAddressState @@ -501,10 +555,12 @@ func (ndp *ndpState) doDuplicateAddressDetection(addr tcpip.Address, remaining u // address. panic(fmt.Sprintf("ndpdad: NIC(%d) is not in the solicited-node multicast group (%s) but it has addr %s", ndp.nic.ID(), snmc, addr)) } + snmcRef.incRef() // Use the unspecified address as the source address when performing // DAD. r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, ndp.nic.linkEP.LinkAddress(), snmcRef, false, false) + defer r.Release() hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborSolicitMinimumSize) pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize)) @@ -1132,3 +1188,84 @@ func (ndp *ndpState) cleanupHostOnlyState() { log.Fatalf("ndp: still have discovered default routers after cleaning up, found = %d", got) } } + +// startSolicitingRouters starts soliciting routers, as per RFC 4861 section +// 6.3.7. If routers are already being solicited, this function does nothing. +// +// The NIC ndp belongs to MUST be locked. +func (ndp *ndpState) startSolicitingRouters() { + if ndp.rtrSolicitTimer != nil { + // We are already soliciting routers. + return + } + + remaining := ndp.configs.MaxRtrSolicitations + if remaining == 0 { + return + } + + // Calculate the random delay before sending our first RS, as per RFC + // 4861 section 6.3.7. + var delay time.Duration + if ndp.configs.MaxRtrSolicitationDelay > 0 { + delay = time.Duration(rand.Int63n(int64(ndp.configs.MaxRtrSolicitationDelay))) + } + + ndp.rtrSolicitTimer = time.AfterFunc(delay, func() { + // Send an RS message with the unspecified source address. + ref := ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, true) + r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false) + defer r.Release() + + payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadSize) + pkt := header.ICMPv6(hdr.Prepend(payloadSize)) + pkt.SetType(header.ICMPv6RouterSolicit) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + + sent := r.Stats().ICMP.V6PacketsSent + if err := r.WritePacket(nil, + NetworkHeaderParams{ + Protocol: header.ICMPv6ProtocolNumber, + TTL: header.NDPHopLimit, + TOS: DefaultTOS, + }, tcpip.PacketBuffer{Header: hdr}, + ); err != nil { + sent.Dropped.Increment() + log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.nic.ID(), err) + // Don't send any more messages if we had an error. + remaining = 0 + } else { + sent.RouterSolicit.Increment() + remaining-- + } + + ndp.nic.mu.Lock() + defer ndp.nic.mu.Unlock() + if remaining == 0 { + ndp.rtrSolicitTimer = nil + } else if ndp.rtrSolicitTimer != nil { + // Note, we need to explicitly check to make sure that + // the timer field is not nil because if it was nil but + // we still reached this point, then we know the NIC + // was requested to stop soliciting routers so we don't + // need to send the next Router Solicitation message. + ndp.rtrSolicitTimer.Reset(ndp.configs.RtrSolicitationInterval) + } + }) + +} + +// stopSolicitingRouters stops soliciting routers. If routers are not currently +// being solicited, this function does nothing. +// +// The NIC ndp belongs to MUST be locked. +func (ndp *ndpState) stopSolicitingRouters() { + if ndp.rtrSolicitTimer == nil { + // Nothing to do. + return + } + + ndp.rtrSolicitTimer.Stop() + ndp.rtrSolicitTimer = nil +} diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index d390c6312..7c68e8ed4 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -3098,3 +3098,227 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) expectNoDHCPv6Event() } + +// TestRouterSolicitation tests the initial Router Solicitations that are sent +// when a NIC newly becomes enabled. +func TestRouterSolicitation(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + maxRtrSolicit uint8 + rtrSolicitInt time.Duration + effectiveRtrSolicitInt time.Duration + maxRtrSolicitDelay time.Duration + effectiveMaxRtrSolicitDelay time.Duration + }{ + { + name: "Single RS with delay", + maxRtrSolicit: 1, + rtrSolicitInt: time.Second, + effectiveRtrSolicitInt: time.Second, + maxRtrSolicitDelay: time.Second, + effectiveMaxRtrSolicitDelay: time.Second, + }, + { + name: "Two RS with delay", + maxRtrSolicit: 2, + rtrSolicitInt: time.Second, + effectiveRtrSolicitInt: time.Second, + maxRtrSolicitDelay: 500 * time.Millisecond, + effectiveMaxRtrSolicitDelay: 500 * time.Millisecond, + }, + { + name: "Single RS without delay", + maxRtrSolicit: 1, + rtrSolicitInt: time.Second, + effectiveRtrSolicitInt: time.Second, + maxRtrSolicitDelay: 0, + effectiveMaxRtrSolicitDelay: 0, + }, + { + name: "Two RS without delay and invalid zero interval", + maxRtrSolicit: 2, + rtrSolicitInt: 0, + effectiveRtrSolicitInt: 4 * time.Second, + maxRtrSolicitDelay: 0, + effectiveMaxRtrSolicitDelay: 0, + }, + { + name: "Three RS without delay", + maxRtrSolicit: 3, + rtrSolicitInt: 500 * time.Millisecond, + effectiveRtrSolicitInt: 500 * time.Millisecond, + maxRtrSolicitDelay: 0, + effectiveMaxRtrSolicitDelay: 0, + }, + { + name: "Two RS with invalid negative delay", + maxRtrSolicit: 2, + rtrSolicitInt: time.Second, + effectiveRtrSolicitInt: time.Second, + maxRtrSolicitDelay: -3 * time.Second, + effectiveMaxRtrSolicitDelay: time.Second, + }, + } + + // This Run will not return until the parallel tests finish. + // + // We need this because we need to do some teardown work after the + // parallel tests complete. + // + // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for + // more details. + t.Run("group", func(t *testing.T) { + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + e := channel.New(int(test.maxRtrSolicit), 1280, linkAddr1) + waitForPkt := func(timeout time.Duration) { + t.Helper() + select { + case p := <-e.C: + if p.Proto != header.IPv6ProtocolNumber { + t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) + } + checker.IPv6(t, + p.Pkt.Header.View(), + checker.SrcAddr(header.IPv6Any), + checker.DstAddr(header.IPv6AllRoutersMulticastAddress), + checker.TTL(header.NDPHopLimit), + checker.NDPRS(), + ) + + case <-time.After(timeout): + t.Fatal("timed out waiting for packet") + } + } + waitForNothing := func(timeout time.Duration) { + t.Helper() + select { + case <-e.C: + t.Fatal("unexpectedly got a packet") + case <-time.After(timeout): + } + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + MaxRtrSolicitations: test.maxRtrSolicit, + RtrSolicitationInterval: test.rtrSolicitInt, + MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, + }, + }) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Make sure each RS got sent at the right + // times. + remaining := test.maxRtrSolicit + if remaining > 0 { + waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultTimeout) + remaining-- + } + for ; remaining > 0; remaining-- { + waitForNothing(test.effectiveRtrSolicitInt - defaultTimeout) + waitForPkt(2 * defaultTimeout) + } + + // Make sure no more RS. + if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay { + waitForNothing(test.effectiveRtrSolicitInt + defaultTimeout) + } else { + waitForNothing(test.effectiveMaxRtrSolicitDelay + defaultTimeout) + } + + // Make sure the counter got properly + // incremented. + if got, want := s.Stats().ICMP.V6PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want { + t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want) + } + }) + } + }) +} + +// TestStopStartSolicitingRouters tests that when forwarding is enabled or +// disabled, router solicitations are stopped or started, respecitively. +func TestStopStartSolicitingRouters(t *testing.T) { + t.Parallel() + + const interval = 500 * time.Millisecond + const delay = time.Second + const maxRtrSolicitations = 3 + e := channel.New(maxRtrSolicitations, 1280, linkAddr1) + waitForPkt := func(timeout time.Duration) { + t.Helper() + select { + case p := <-e.C: + if p.Proto != header.IPv6ProtocolNumber { + t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) + } + checker.IPv6(t, p.Pkt.Header.View(), + checker.SrcAddr(header.IPv6Any), + checker.DstAddr(header.IPv6AllRoutersMulticastAddress), + checker.TTL(header.NDPHopLimit), + checker.NDPRS()) + + case <-time.After(timeout): + t.Fatal("timed out waiting for packet") + } + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + MaxRtrSolicitations: maxRtrSolicitations, + RtrSolicitationInterval: interval, + MaxRtrSolicitationDelay: delay, + }, + }) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } + + // Enable forwarding which should stop router solicitations. + s.SetForwarding(true) + select { + case <-e.C: + // A single RS may have been sent before forwarding was enabled. + select { + case <-e.C: + t.Fatal("Should not have sent more than one RS message") + case <-time.After(interval + defaultTimeout): + } + case <-time.After(delay + defaultTimeout): + } + + // Enabling forwarding again should do nothing. + s.SetForwarding(true) + select { + case <-e.C: + t.Fatal("unexpectedly got a packet after becoming a router") + case <-time.After(delay + defaultTimeout): + } + + // Disable forwarding which should start router solicitations. + s.SetForwarding(false) + waitForPkt(delay + defaultTimeout) + waitForPkt(interval + defaultTimeout) + waitForPkt(interval + defaultTimeout) + select { + case <-e.C: + t.Fatal("unexpectedly got an extra packet after sending out the expected RSs") + case <-time.After(interval + defaultTimeout): + } + + // Disabling forwarding again should do nothing. + s.SetForwarding(false) + select { + case <-e.C: + t.Fatal("unexpectedly got a packet after becoming a router") + case <-time.After(delay + defaultTimeout): + } +} diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 071221d5a..1089fdf35 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -177,49 +177,72 @@ func (n *NIC) enable() *tcpip.Error { } // Do not auto-generate an IPv6 link-local address for loopback devices. - if !n.stack.autoGenIPv6LinkLocal || n.isLoopback() { - return nil - } + if n.stack.autoGenIPv6LinkLocal && !n.isLoopback() { + var addr tcpip.Address + if oIID := n.stack.opaqueIIDOpts; oIID.NICNameFromID != nil { + addr = header.LinkLocalAddrWithOpaqueIID(oIID.NICNameFromID(n.ID(), n.name), 0, oIID.SecretKey) + } else { + l2addr := n.linkEP.LinkAddress() - var addr tcpip.Address - if oIID := n.stack.opaqueIIDOpts; oIID.NICNameFromID != nil { - addr = header.LinkLocalAddrWithOpaqueIID(oIID.NICNameFromID(n.ID(), n.name), 0, oIID.SecretKey) - } else { - l2addr := n.linkEP.LinkAddress() + // Only attempt to generate the link-local address if we have a valid MAC + // address. + // + // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by + // LinkEndpoint.LinkAddress) before reaching this point. + if !header.IsValidUnicastEthernetAddress(l2addr) { + return nil + } - // Only attempt to generate the link-local address if we have a valid MAC - // address. - // - // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by - // LinkEndpoint.LinkAddress) before reaching this point. - if !header.IsValidUnicastEthernetAddress(l2addr) { - return nil + addr = header.LinkLocalAddr(l2addr) } - addr = header.LinkLocalAddr(l2addr) + if _, err := n.addPermanentAddressLocked(tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: addr, + PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen, + }, + }, CanBePrimaryEndpoint); err != nil { + return err + } } - _, err := n.addPermanentAddressLocked(tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: addr, - PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen, - }, - }, CanBePrimaryEndpoint) + // If we are operating as a router, then do not solicit routers since we + // won't process the RAs anyways. + // + // Routers do not process Router Advertisements (RA) the same way a host + // does. That is, routers do not learn from RAs (e.g. on-link prefixes + // and default routers). Therefore, soliciting RAs from other routers on + // a link is unnecessary for routers. + if !n.stack.forwarding { + n.ndp.startSolicitingRouters() + } - return err + return nil } // becomeIPv6Router transitions n into an IPv6 router. // // When transitioning into an IPv6 router, host-only state (NDP discovered // routers, discovered on-link prefixes, and auto-generated addresses) will -// be cleaned up/invalidated. +// be cleaned up/invalidated and NDP router solicitations will be stopped. func (n *NIC) becomeIPv6Router() { n.mu.Lock() defer n.mu.Unlock() n.ndp.cleanupHostOnlyState() + n.ndp.stopSolicitingRouters() +} + +// becomeIPv6Host transitions n into an IPv6 host. +// +// When transitioning into an IPv6 host, NDP router solicitations will be +// started. +func (n *NIC) becomeIPv6Host() { + n.mu.Lock() + defer n.mu.Unlock() + + n.ndp.startSolicitingRouters() } // attachLinkEndpoint attaches the NIC to the endpoint, which will enable it diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 386eb6eec..fc56a6d79 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -750,7 +750,9 @@ func (s *Stack) Stats() tcpip.Stats { // SetForwarding enables or disables the packet forwarding between NICs. // // When forwarding becomes enabled, any host-only state on all NICs will be -// cleaned up. +// cleaned up and if IPv6 is enabled, NDP Router Solicitations will be started. +// When forwarding becomes disabled and if IPv6 is enabled, NDP Router +// Solicitations will be stopped. func (s *Stack) SetForwarding(enable bool) { // TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward. s.mu.Lock() @@ -772,6 +774,10 @@ func (s *Stack) SetForwarding(enable bool) { for _, nic := range s.nics { nic.becomeIPv6Router() } + } else { + for _, nic := range s.nics { + nic.becomeIPv6Host() + } } } -- cgit v1.2.3 From a7a1f00425c6a742a0c953ae3cb6de513011d41b Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Wed, 15 Jan 2020 20:21:20 -0800 Subject: Support upgrading expired/removed IPv6 addresses to permanent SLAAC addresses If a previously added IPv6 address (statically or via SLAAC) was removed, it would be left in an expired state waiting to be cleaned up if any references to it were still held. During this time, the same address could be regenerated via SLAAC, which should be allowed. This change supports this scenario. When upgrading an endpoint from temporary or permanentExpired to permanent, respect the new configuration type (static or SLAAC) and deprecated status, along with the new PrimaryEndpointBehavior (which was already supported). Test: stack.TestAutoGenAddrAfterRemoval PiperOrigin-RevId: 289990168 --- pkg/tcpip/stack/ndp.go | 2 +- pkg/tcpip/stack/ndp_test.go | 125 +++++++++++++++++++++++++++++++++++++++++--- pkg/tcpip/stack/nic.go | 30 ++++++++--- 3 files changed, 142 insertions(+), 15 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index acefc356a..c99d387d5 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -994,7 +994,7 @@ func (ndp *ndpState) newAutoGenAddress(prefix tcpip.Subnet, pl, vl time.Duration // If the preferred lifetime is zero, then the address should be considered // deprecated. deprecated := pl == 0 - ref, err := ndp.nic.addAddressLocked(protocolAddr, FirstPrimaryEndpoint, permanent, slaac, deprecated) + ref, err := ndp.nic.addPermanentAddressLocked(protocolAddr, FirstPrimaryEndpoint, slaac, deprecated) if err != nil { log.Fatalf("ndp: error when adding address %s: %s", protocolAddr, err) } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 7c68e8ed4..1a52e0e68 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -35,12 +35,12 @@ import ( ) const ( - addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03" - linkAddr1 = "\x02\x02\x03\x04\x05\x06" - linkAddr2 = "\x02\x02\x03\x04\x05\x07" - linkAddr3 = "\x02\x02\x03\x04\x05\x08" + addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") + linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") + linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07") + linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08") defaultTimeout = 100 * time.Millisecond ) @@ -2445,6 +2445,119 @@ func TestAutoGenAddrRemoval(t *testing.T) { } } +// TestAutoGenAddrAfterRemoval tests adding a SLAAC address that was previously +// assigned to the NIC but is in the permanentExpired state. +func TestAutoGenAddrAfterRemoval(t *testing.T) { + t.Parallel() + + const nicID = 1 + + prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) + ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { + t.Helper() + + if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) + } else if got != addr { + t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) + } + + if got := addrForNewConnection(t, s); got != addr.Address { + t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) + } + } + + // Receive a PI to auto-generate addr1 with a large valid and preferred + // lifetime. + const largeLifetimeSeconds = 999 + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix1, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) + expectAutoGenAddrEvent(addr1, newAddr) + expectPrimaryAddr(addr1) + + // Add addr2 as a static address. + protoAddr2 := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addr2, + } + if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil { + t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d, %s) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err) + } + // addr2 should be more preferred now since it is at the front of the primary + // list. + expectPrimaryAddr(addr2) + + // Get a route using addr2 to increment its reference count then remove it + // to leave it in the permanentExpired state. + r, err := s.FindRoute(nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, false) + if err != nil { + t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, err) + } + defer r.Release() + if err := s.RemoveAddress(nicID, addr2.Address); err != nil { + t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, addr2.Address, err) + } + // addr1 should be preferred again since addr2 is in the expired state. + expectPrimaryAddr(addr1) + + // Receive a PI to auto-generate addr2 as valid and preferred. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) + expectAutoGenAddrEvent(addr2, newAddr) + // addr2 should be more preferred now that it is closer to the front of the + // primary list and not deprecated. + expectPrimaryAddr(addr2) + + // Removing the address should result in an invalidation event immediately. + // It should still be in the permanentExpired state because r is still held. + // + // We remove addr2 here to make sure addr2 was marked as a SLAAC address + // (it was previously marked as a static address). + if err := s.RemoveAddress(1, addr2.Address); err != nil { + t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err) + } + expectAutoGenAddrEvent(addr2, invalidatedAddr) + // addr1 should be more preferred since addr2 is in the expired state. + expectPrimaryAddr(addr1) + + // Receive a PI to auto-generate addr2 as valid and deprecated. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, 0)) + expectAutoGenAddrEvent(addr2, newAddr) + // addr1 should still be more preferred since addr2 is deprecated, even though + // it is closer to the front of the primary list. + expectPrimaryAddr(addr1) + + // Receive a PI to refresh addr2's preferred lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto gen addr event") + default: + } + // addr2 should be more preferred now that it is not deprecated. + expectPrimaryAddr(addr2) + + if err := s.RemoveAddress(1, addr2.Address); err != nil { + t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err) + } + expectAutoGenAddrEvent(addr2, invalidatedAddr) + expectPrimaryAddr(addr1) +} + // TestAutoGenAddrStaticConflict tests that if SLAAC generates an address that // is already assigned to the NIC, the static address remains. func TestAutoGenAddrStaticConflict(t *testing.T) { diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 1089fdf35..4452a1302 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -202,7 +202,7 @@ func (n *NIC) enable() *tcpip.Error { Address: addr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen, }, - }, CanBePrimaryEndpoint); err != nil { + }, CanBePrimaryEndpoint, static, false /* deprecated */); err != nil { return err } } @@ -533,7 +533,12 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t return ref } -func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) (*referencedNetworkEndpoint, *tcpip.Error) { +// addPermanentAddressLocked adds a permanent address to n. +// +// If n already has the address in a non-permanent state, +// addPermanentAddressLocked will promote it to permanent and update the +// endpoint with the properties provided. +func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, configType networkEndpointConfigType, deprecated bool) (*referencedNetworkEndpoint, *tcpip.Error) { id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address} if ref, ok := n.endpoints[id]; ok { switch ref.getKind() { @@ -541,10 +546,14 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p // The NIC already have a permanent endpoint with that address. return nil, tcpip.ErrDuplicateAddress case permanentExpired, temporary: - // Promote the endpoint to become permanent and respect - // the new peb. + // Promote the endpoint to become permanent and respect the new peb, + // configType and deprecated status. if ref.tryIncRef() { + // TODO(b/147748385): Perform Duplicate Address Detection when promoting + // an IPv6 endpoint to permanent. ref.setKind(permanent) + ref.deprecated = deprecated + ref.configType = configType refs := n.primary[ref.protocol] for i, r := range refs { @@ -576,9 +585,13 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p } } - return n.addAddressLocked(protocolAddress, peb, permanent, static, false) + return n.addAddressLocked(protocolAddress, peb, permanent, configType, deprecated) } +// addAddressLocked adds a new protocolAddress to n. +// +// If the address is already known by n (irrespective of the state it is in), +// addAddressLocked does nothing and returns tcpip.ErrDuplicateAddress. func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind, configType networkEndpointConfigType, deprecated bool) (*referencedNetworkEndpoint, *tcpip.Error) { // TODO(b/141022673): Validate IP address before adding them. @@ -653,7 +666,7 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar func (n *NIC) AddAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error { // Add the endpoint. n.mu.Lock() - _, err := n.addPermanentAddressLocked(protocolAddress, peb) + _, err := n.addPermanentAddressLocked(protocolAddress, peb, static, false /* deprecated */) n.mu.Unlock() return err @@ -935,7 +948,7 @@ func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.A Address: addr, PrefixLen: netProto.DefaultPrefixLen(), }, - }, NeverPrimaryEndpoint); err != nil { + }, NeverPrimaryEndpoint, static, false /* deprecated */); err != nil { return err } } @@ -1313,7 +1326,8 @@ type referencedNetworkEndpoint struct { kind networkEndpointKind // configType is the method that was used to configure this endpoint. - // This must never change after the endpoint is added to a NIC. + // This must never change except during endpoint creation and promotion to + // permanent. configType networkEndpointConfigType // deprecated indicates whether or not the endpoint should be considered -- cgit v1.2.3 From 23fa847910eeee05babeea4f712b905115eeb865 Mon Sep 17 00:00:00 2001 From: Tamir Duberstein Date: Fri, 17 Jan 2020 11:40:51 -0800 Subject: Remove addPermanentAddressLocked It was possible to use this function incorrectly, and its separation wasn't buying us anything. PiperOrigin-RevId: 290311100 --- pkg/tcpip/stack/ndp.go | 21 ++++++++++----------- pkg/tcpip/stack/nic.go | 46 ++++++++++++++++++---------------------------- 2 files changed, 28 insertions(+), 39 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index c99d387d5..7d4b41dfa 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -432,13 +432,12 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref // Should not attempt to perform DAD on an address that is currently in // the DAD process. if _, ok := ndp.dad[addr]; ok { - // Should never happen because we should only ever call this - // function for newly created addresses. If we attemped to - // "add" an address that already existed, we would returned an - // error since we attempted to add a duplicate address, or its - // reference count would have been increased without doing the - // work that would have been done for an address that was brand - // new. See NIC.addPermanentAddressLocked. + // Should never happen because we should only ever call this function for + // newly created addresses. If we attemped to "add" an address that already + // existed, we would get an error since we attempted to add a duplicate + // address, or its reference count would have been increased without doing + // the work that would have been done for an address that was brand new. + // See NIC.addAddressLocked. panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.nic.ID())) } @@ -994,7 +993,7 @@ func (ndp *ndpState) newAutoGenAddress(prefix tcpip.Subnet, pl, vl time.Duration // If the preferred lifetime is zero, then the address should be considered // deprecated. deprecated := pl == 0 - ref, err := ndp.nic.addPermanentAddressLocked(protocolAddr, FirstPrimaryEndpoint, slaac, deprecated) + ref, err := ndp.nic.addAddressLocked(protocolAddr, FirstPrimaryEndpoint, permanent, slaac, deprecated) if err != nil { log.Fatalf("ndp: error when adding address %s: %s", protocolAddr, err) } @@ -1164,7 +1163,7 @@ func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bo // // The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) cleanupHostOnlyState() { - for addr, _ := range ndp.autoGenAddresses { + for addr := range ndp.autoGenAddresses { ndp.invalidateAutoGenAddress(addr) } @@ -1172,7 +1171,7 @@ func (ndp *ndpState) cleanupHostOnlyState() { log.Fatalf("ndp: still have auto-generated addresses after cleaning up, found = %d", got) } - for prefix, _ := range ndp.onLinkPrefixes { + for prefix := range ndp.onLinkPrefixes { ndp.invalidateOnLinkPrefix(prefix) } @@ -1180,7 +1179,7 @@ func (ndp *ndpState) cleanupHostOnlyState() { log.Fatalf("ndp: still have discovered on-link prefixes after cleaning up, found = %d", got) } - for router, _ := range ndp.defaultRouters { + for router := range ndp.defaultRouters { ndp.invalidateDefaultRouter(router) } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 4452a1302..53abf29e5 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -196,13 +196,13 @@ func (n *NIC) enable() *tcpip.Error { addr = header.LinkLocalAddr(l2addr) } - if _, err := n.addPermanentAddressLocked(tcpip.ProtocolAddress{ + if _, err := n.addAddressLocked(tcpip.ProtocolAddress{ Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: addr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen, }, - }, CanBePrimaryEndpoint, static, false /* deprecated */); err != nil { + }, CanBePrimaryEndpoint, permanent, static, false /* deprecated */); err != nil { return err } } @@ -533,14 +533,21 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t return ref } -// addPermanentAddressLocked adds a permanent address to n. +// addAddressLocked adds a new protocolAddress to n. // -// If n already has the address in a non-permanent state, -// addPermanentAddressLocked will promote it to permanent and update the -// endpoint with the properties provided. -func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, configType networkEndpointConfigType, deprecated bool) (*referencedNetworkEndpoint, *tcpip.Error) { - id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address} +// If n already has the address in a non-permanent state, and the kind given is +// permanent, that address will be promoted in place and its properties set to +// the properties provided. Otherwise, it returns tcpip.ErrDuplicateAddress. +func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind, configType networkEndpointConfigType, deprecated bool) (*referencedNetworkEndpoint, *tcpip.Error) { + // TODO(b/141022673): Validate IP addresses before adding them. + + // Sanity check. + id := NetworkEndpointID{LocalAddress: protocolAddress.AddressWithPrefix.Address} if ref, ok := n.endpoints[id]; ok { + // Endpoint already exists. + if kind != permanent { + return nil, tcpip.ErrDuplicateAddress + } switch ref.getKind() { case permanentTentative, permanent: // The NIC already have a permanent endpoint with that address. @@ -585,23 +592,6 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p } } - return n.addAddressLocked(protocolAddress, peb, permanent, configType, deprecated) -} - -// addAddressLocked adds a new protocolAddress to n. -// -// If the address is already known by n (irrespective of the state it is in), -// addAddressLocked does nothing and returns tcpip.ErrDuplicateAddress. -func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind, configType networkEndpointConfigType, deprecated bool) (*referencedNetworkEndpoint, *tcpip.Error) { - // TODO(b/141022673): Validate IP address before adding them. - - // Sanity check. - id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address} - if _, ok := n.endpoints[id]; ok { - // Endpoint already exists. - return nil, tcpip.ErrDuplicateAddress - } - netProto, ok := n.stack.networkProtocols[protocolAddress.Protocol] if !ok { return nil, tcpip.ErrUnknownProtocol @@ -666,7 +656,7 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar func (n *NIC) AddAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error { // Add the endpoint. n.mu.Lock() - _, err := n.addPermanentAddressLocked(protocolAddress, peb, static, false /* deprecated */) + _, err := n.addAddressLocked(protocolAddress, peb, permanent, static, false /* deprecated */) n.mu.Unlock() return err @@ -942,13 +932,13 @@ func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.A if !ok { return tcpip.ErrUnknownProtocol } - if _, err := n.addPermanentAddressLocked(tcpip.ProtocolAddress{ + if _, err := n.addAddressLocked(tcpip.ProtocolAddress{ Protocol: protocol, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: addr, PrefixLen: netProto.DefaultPrefixLen(), }, - }, NeverPrimaryEndpoint, static, false /* deprecated */); err != nil { + }, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil { return err } } -- cgit v1.2.3 From 47d85257d3d015f0b9f7739c81af0ee9f510aaf5 Mon Sep 17 00:00:00 2001 From: Eyal Soha Date: Fri, 17 Jan 2020 18:24:39 -0800 Subject: Filter out received packets with a local source IP address. CERT Advisory CA-96.21 III. Solution advises that devices drop packets which could not have correctly arrived on the wire, such as receiving a packet where the source IP address is owned by the device that sent it. Fixes #1507 PiperOrigin-RevId: 290378240 --- pkg/sentry/socket/netstack/netstack.go | 15 +++++----- pkg/sentry/socket/netstack/stack.go | 38 ++++++++++++------------- pkg/tcpip/stack/nic.go | 14 +++++++-- pkg/tcpip/tcpip.go | 10 +++++-- pkg/tcpip/transport/udp/udp_test.go | 52 ++++++++++++++++++++++++++++++++-- 5 files changed, 95 insertions(+), 34 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index d2f7e987d..fec575357 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -138,13 +138,14 @@ var Metrics = tcpip.Stats{ }, }, IP: tcpip.IPStats{ - PacketsReceived: mustCreateMetric("/netstack/ip/packets_received", "Total number of IP packets received from the link layer in nic.DeliverNetworkPacket."), - InvalidAddressesReceived: mustCreateMetric("/netstack/ip/invalid_addresses_received", "Total number of IP packets received with an unknown or invalid destination address."), - PacketsDelivered: mustCreateMetric("/netstack/ip/packets_delivered", "Total number of incoming IP packets that are successfully delivered to the transport layer via HandlePacket."), - PacketsSent: mustCreateMetric("/netstack/ip/packets_sent", "Total number of IP packets sent via WritePacket."), - OutgoingPacketErrors: mustCreateMetric("/netstack/ip/outgoing_packet_errors", "Total number of IP packets which failed to write to a link-layer endpoint."), - MalformedPacketsReceived: mustCreateMetric("/netstack/ip/malformed_packets_received", "Total number of IP packets which failed IP header validation checks."), - MalformedFragmentsReceived: mustCreateMetric("/netstack/ip/malformed_fragments_received", "Total number of IP fragments which failed IP fragment validation checks."), + PacketsReceived: mustCreateMetric("/netstack/ip/packets_received", "Total number of IP packets received from the link layer in nic.DeliverNetworkPacket."), + InvalidDestinationAddressesReceived: mustCreateMetric("/netstack/ip/invalid_addresses_received", "Total number of IP packets received with an unknown or invalid destination address."), + InvalidSourceAddressesReceived: mustCreateMetric("/netstack/ip/invalid_source_addresses_received", "Total number of IP packets received with an unknown or invalid source address."), + PacketsDelivered: mustCreateMetric("/netstack/ip/packets_delivered", "Total number of incoming IP packets that are successfully delivered to the transport layer via HandlePacket."), + PacketsSent: mustCreateMetric("/netstack/ip/packets_sent", "Total number of IP packets sent via WritePacket."), + OutgoingPacketErrors: mustCreateMetric("/netstack/ip/outgoing_packet_errors", "Total number of IP packets which failed to write to a link-layer endpoint."), + MalformedPacketsReceived: mustCreateMetric("/netstack/ip/malformed_packets_received", "Total number of IP packets which failed IP header validation checks."), + MalformedFragmentsReceived: mustCreateMetric("/netstack/ip/malformed_fragments_received", "Total number of IP fragments which failed IP fragment validation checks."), }, TCP: tcpip.TCPStats{ ActiveConnectionOpenings: mustCreateMetric("/netstack/tcp/active_connection_openings", "Number of connections opened successfully via Connect."), diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index a0db2d4fd..31ea66eca 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -148,25 +148,25 @@ func (s *Stack) Statistics(stat interface{}, arg string) error { case *inet.StatSNMPIP: ip := Metrics.IP *stats = inet.StatSNMPIP{ - 0, // TODO(gvisor.dev/issue/969): Support Ip/Forwarding. - 0, // TODO(gvisor.dev/issue/969): Support Ip/DefaultTTL. - ip.PacketsReceived.Value(), // InReceives. - 0, // TODO(gvisor.dev/issue/969): Support Ip/InHdrErrors. - ip.InvalidAddressesReceived.Value(), // InAddrErrors. - 0, // TODO(gvisor.dev/issue/969): Support Ip/ForwDatagrams. - 0, // TODO(gvisor.dev/issue/969): Support Ip/InUnknownProtos. - 0, // TODO(gvisor.dev/issue/969): Support Ip/InDiscards. - ip.PacketsDelivered.Value(), // InDelivers. - ip.PacketsSent.Value(), // OutRequests. - ip.OutgoingPacketErrors.Value(), // OutDiscards. - 0, // TODO(gvisor.dev/issue/969): Support Ip/OutNoRoutes. - 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmTimeout. - 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmReqds. - 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmOKs. - 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmFails. - 0, // TODO(gvisor.dev/issue/969): Support Ip/FragOKs. - 0, // TODO(gvisor.dev/issue/969): Support Ip/FragFails. - 0, // TODO(gvisor.dev/issue/969): Support Ip/FragCreates. + 0, // TODO(gvisor.dev/issue/969): Support Ip/Forwarding. + 0, // TODO(gvisor.dev/issue/969): Support Ip/DefaultTTL. + ip.PacketsReceived.Value(), // InReceives. + 0, // TODO(gvisor.dev/issue/969): Support Ip/InHdrErrors. + ip.InvalidDestinationAddressesReceived.Value(), // InAddrErrors. + 0, // TODO(gvisor.dev/issue/969): Support Ip/ForwDatagrams. + 0, // TODO(gvisor.dev/issue/969): Support Ip/InUnknownProtos. + 0, // TODO(gvisor.dev/issue/969): Support Ip/InDiscards. + ip.PacketsDelivered.Value(), // InDelivers. + ip.PacketsSent.Value(), // OutRequests. + ip.OutgoingPacketErrors.Value(), // OutDiscards. + 0, // TODO(gvisor.dev/issue/969): Support Ip/OutNoRoutes. + 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmTimeout. + 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmReqds. + 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmOKs. + 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmFails. + 0, // TODO(gvisor.dev/issue/969): Support Ip/FragOKs. + 0, // TODO(gvisor.dev/issue/969): Support Ip/FragFails. + 0, // TODO(gvisor.dev/issue/969): Support Ip/FragCreates. } case *inet.StatSNMPICMP: in := Metrics.ICMP.V4PacketsReceived.ICMPv4PacketStats diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 53abf29e5..4afe7b744 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -984,7 +984,7 @@ func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, // DeliverNetworkPacket finds the appropriate network protocol endpoint and // hands the packet over for further processing. This function is called when -// the NIC receives a packet from the physical interface. +// the NIC receives a packet from the link endpoint. // Note that the ownership of the slice backing vv is retained by the caller. // This rule applies only to the slice itself, not to the items of the slice; // the ownership of the items is not retained by the caller. @@ -1029,6 +1029,14 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link src, dst := netProto.ParseAddresses(pkt.Data.First()) + if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil { + // The source address is one of our own, so we never should have gotten a + // packet like this unless handleLocal is false. Loopback also calls this + // function even though the packets didn't come from the physical interface + // so don't drop those. + n.stack.stats.IP.InvalidSourceAddressesReceived.Increment() + return + } if ref := n.getRef(protocol, dst); ref != nil { handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, pkt) return @@ -1041,7 +1049,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link if n.stack.Forwarding() { r, err := n.stack.FindRoute(0, "", dst, protocol, false /* multicastLoop */) if err != nil { - n.stack.stats.IP.InvalidAddressesReceived.Increment() + n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() return } defer r.Release() @@ -1079,7 +1087,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link // If a packet socket handled the packet, don't treat it as invalid. if len(packetEPs) == 0 { - n.stack.stats.IP.InvalidAddressesReceived.Increment() + n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() } } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index b7813cbc0..6243762e3 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -903,9 +903,13 @@ type IPStats struct { // link layer in nic.DeliverNetworkPacket. PacketsReceived *StatCounter - // InvalidAddressesReceived is the total number of IP packets received - // with an unknown or invalid destination address. - InvalidAddressesReceived *StatCounter + // InvalidDestinationAddressesReceived is the total number of IP packets + // received with an unknown or invalid destination address. + InvalidDestinationAddressesReceived *StatCounter + + // InvalidSourceAddressesReceived is the total number of IP packets received + // with a source address that should never have been received on the wire. + InvalidSourceAddressesReceived *StatCounter // PacketsDelivered is the total number of incoming IP packets that // are successfully delivered to the transport layer via HandlePacket. diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index ee9d10555..51bb61167 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -274,11 +274,16 @@ type testContext struct { func newDualTestContext(t *testing.T, mtu uint32) *testContext { t.Helper() - - s := stack.New(stack.Options{ + return newDualTestContextWithOptions(t, mtu, stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, }) +} + +func newDualTestContextWithOptions(t *testing.T, mtu uint32, options stack.Options) *testContext { + t.Helper() + + s := stack.New(options) ep := channel.New(256, mtu, "") wep := stack.LinkEndpoint(ep) @@ -763,6 +768,49 @@ func TestV6ReadOnV6(t *testing.T) { testRead(c, unicastV6) } +// TestV4ReadSelfSource checks that packets coming from a local IP address are +// correctly dropped when handleLocal is true and not otherwise. +func TestV4ReadSelfSource(t *testing.T) { + for _, tt := range []struct { + name string + handleLocal bool + wantErr *tcpip.Error + wantInvalidSource uint64 + }{ + {"HandleLocal", false, nil, 0}, + {"NoHandleLocal", true, tcpip.ErrWouldBlock, 1}, + } { + t.Run(tt.name, func(t *testing.T) { + c := newDualTestContextWithOptions(t, defaultMTU, stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + HandleLocal: tt.handleLocal, + }) + defer c.cleanup() + + c.createEndpointForFlow(unicastV4) + + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + payload := newPayload() + h := unicastV4.header4Tuple(incoming) + h.srcAddr = h.dstAddr + + c.injectV4Packet(payload, &h, true /* valid */) + + if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != tt.wantInvalidSource { + t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource) + } + + if _, _, err := c.ep.Read(nil); err != tt.wantErr { + t.Errorf("c.ep.Read() got error %v, want %v", err, tt.wantErr) + } + }) + } +} + func TestV4ReadOnV4(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() -- cgit v1.2.3 From 1d97adaa6d73dd897bb4e89d4533936a95003951 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Wed, 22 Jan 2020 14:50:32 -0800 Subject: Use embedded mutex pattern for stack.NIC - Wrap NIC's fields that should only be accessed while holding the mutex in an anonymous struct with the embedded mutex. - Make sure NIC's spoofing and promiscuous mode flags are only read while holding the NIC's mutex. - Use the correct endpoint when sending DAD messages. - Do not hold the NIC's lock when sending DAD messages. This change does not introduce any behaviour changes. Tests: Existing tests continue to pass. PiperOrigin-RevId: 291036251 --- pkg/tcpip/stack/ndp.go | 181 +++++++++++++++------------------ pkg/tcpip/stack/ndp_test.go | 45 ++++----- pkg/tcpip/stack/nic.go | 236 +++++++++++++++++++++++++------------------- 3 files changed, 234 insertions(+), 228 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 7d4b41dfa..d983ac390 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -15,7 +15,6 @@ package stack import ( - "fmt" "log" "math/rand" "time" @@ -429,8 +428,13 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref return tcpip.ErrAddressFamilyNotSupported } - // Should not attempt to perform DAD on an address that is currently in - // the DAD process. + if ref.getKind() != permanentTentative { + // The endpoint should be marked as tentative since we are starting DAD. + log.Fatalf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.nic.ID()) + } + + // Should not attempt to perform DAD on an address that is currently in the + // DAD process. if _, ok := ndp.dad[addr]; ok { // Should never happen because we should only ever call this function for // newly created addresses. If we attemped to "add" an address that already @@ -438,77 +442,79 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref // address, or its reference count would have been increased without doing // the work that would have been done for an address that was brand new. // See NIC.addAddressLocked. - panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.nic.ID())) + log.Fatalf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.nic.ID()) } remaining := ndp.configs.DupAddrDetectTransmits - - { - done, err := ndp.doDuplicateAddressDetection(addr, remaining, ref) - if err != nil { - return err - } - if done { - return nil - } + if remaining == 0 { + ref.setKind(permanent) + return nil } - remaining-- - var done bool var timer *time.Timer - timer = time.AfterFunc(ndp.configs.RetransmitTimer, func() { - var d bool - var err *tcpip.Error - - // doDadIteration does a single iteration of the DAD loop. - // - // Returns true if the integrator needs to be informed of DAD - // completing. - doDadIteration := func() bool { - ndp.nic.mu.Lock() - defer ndp.nic.mu.Unlock() - - if done { - // If we reach this point, it means that the DAD - // timer fired after another goroutine already - // obtained the NIC lock and stopped DAD before - // this function obtained the NIC lock. Simply - // return here and do nothing further. - return false - } + // We initially start a timer to fire immediately because some of the DAD work + // cannot be done while holding the NIC's lock. This is effectively the same + // as starting a goroutine but we use a timer that fires immediately so we can + // reset it for the next DAD iteration. + timer = time.AfterFunc(0, func() { + ndp.nic.mu.RLock() + if done { + // If we reach this point, it means that the DAD timer fired after + // another goroutine already obtained the NIC lock and stopped DAD + // before this function obtained the NIC lock. Simply return here and do + // nothing further. + ndp.nic.mu.RUnlock() + return + } - ref, ok := ndp.nic.endpoints[NetworkEndpointID{addr}] - if !ok { - // This should never happen. - // We should have an endpoint for addr since we - // are still performing DAD on it. If the - // endpoint does not exist, but we are doing DAD - // on it, then we started DAD at some point, but - // forgot to stop it when the endpoint was - // deleted. - panic(fmt.Sprintf("ndpdad: unrecognized addr %s for NIC(%d)", addr, ndp.nic.ID())) - } + if ref.getKind() != permanentTentative { + // The endpoint should still be marked as tentative since we are still + // performing DAD on it. + log.Fatalf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.nic.ID()) + } - d, err = ndp.doDuplicateAddressDetection(addr, remaining, ref) - if err != nil || d { - delete(ndp.dad, addr) + dadDone := remaining == 0 + ndp.nic.mu.RUnlock() - if err != nil { - log.Printf("ndpdad: Error occured during DAD iteration for addr (%s) on NIC(%d); err = %s", addr, ndp.nic.ID(), err) - } + var err *tcpip.Error + if !dadDone { + err = ndp.sendDADPacket(addr) + } - // Let the integrator know DAD has completed. - return true - } + ndp.nic.mu.Lock() + if done { + // If we reach this point, it means that DAD was stopped after we released + // the NIC's read lock and before we obtained the write lock. + ndp.nic.mu.Unlock() + return + } + if dadDone { + // DAD has resolved. + ref.setKind(permanent) + } else if err == nil { + // DAD is not done and we had no errors when sending the last NDP NS, + // schedule the next DAD timer. remaining-- timer.Reset(ndp.nic.stack.ndpConfigs.RetransmitTimer) - return false + + ndp.nic.mu.Unlock() + return + } + + // At this point we know that either DAD is done or we hit an error sending + // the last NDP NS. Either way, clean up addr's DAD state and let the + // integrator know DAD has completed. + delete(ndp.dad, addr) + ndp.nic.mu.Unlock() + + if err != nil { + log.Printf("ndpdad: error occured during DAD iteration for addr (%s) on NIC(%d); err = %s", addr, ndp.nic.ID(), err) } - if doDadIteration() && ndp.nic.stack.ndpDisp != nil { - ndp.nic.stack.ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, d, err) + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, dadDone, err) } }) @@ -520,45 +526,16 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref return nil } -// doDuplicateAddressDetection is called on every iteration of the timer, and -// when DAD starts. -// -// It handles resolving the address (if there are no more NS to send), or -// sending the next NS if there are more NS to send. -// -// This function must only be called by IPv6 addresses that are currently -// tentative. -// -// The NIC that ndp belongs to (n) MUST be locked. +// sendDADPacket sends a NS message to see if any nodes on ndp's NIC's link owns +// addr. // -// Returns true if DAD has resolved; false if DAD is still ongoing. -func (ndp *ndpState) doDuplicateAddressDetection(addr tcpip.Address, remaining uint8, ref *referencedNetworkEndpoint) (bool, *tcpip.Error) { - if ref.getKind() != permanentTentative { - // The endpoint should still be marked as tentative - // since we are still performing DAD on it. - panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.nic.ID())) - } - - if remaining == 0 { - // DAD has resolved. - ref.setKind(permanent) - return true, nil - } - - // Send a new NS. +// addr must be a tentative IPv6 address on ndp's NIC. +func (ndp *ndpState) sendDADPacket(addr tcpip.Address) *tcpip.Error { snmc := header.SolicitedNodeAddr(addr) - snmcRef, ok := ndp.nic.endpoints[NetworkEndpointID{snmc}] - if !ok { - // This should never happen as if we have the - // address, we should have the solicited-node - // address. - panic(fmt.Sprintf("ndpdad: NIC(%d) is not in the solicited-node multicast group (%s) but it has addr %s", ndp.nic.ID(), snmc, addr)) - } - snmcRef.incRef() - // Use the unspecified address as the source address when performing - // DAD. - r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, ndp.nic.linkEP.LinkAddress(), snmcRef, false, false) + // Use the unspecified address as the source address when performing DAD. + ref := ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, forceSpoofing) + r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false) defer r.Release() hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborSolicitMinimumSize) @@ -569,15 +546,19 @@ func (ndp *ndpState) doDuplicateAddressDetection(addr tcpip.Address, remaining u pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) sent := r.Stats().ICMP.V6PacketsSent - if err := r.WritePacket(nil, NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS}, tcpip.PacketBuffer{ - Header: hdr, - }); err != nil { + if err := r.WritePacket(nil, + NetworkHeaderParams{ + Protocol: header.ICMPv6ProtocolNumber, + TTL: header.NDPHopLimit, + TOS: DefaultTOS, + }, tcpip.PacketBuffer{Header: hdr}, + ); err != nil { sent.Dropped.Increment() - return false, err + return err } sent.NeighborSolicit.Increment() - return false, nil + return nil } // stopDuplicateAddressDetection ends a running Duplicate Address Detection @@ -1212,7 +1193,7 @@ func (ndp *ndpState) startSolicitingRouters() { ndp.rtrSolicitTimer = time.AfterFunc(delay, func() { // Send an RS message with the unspecified source address. - ref := ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, true) + ref := ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, forceSpoofing) r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false) defer r.Release() diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 1a52e0e68..376681b30 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -301,6 +301,8 @@ func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration s // Included in the subtests is a test to make sure that an invalid // RetransmitTimer (<1ms) values get fixed to the default RetransmitTimer of 1s. func TestDADResolve(t *testing.T) { + const nicID = 1 + tests := []struct { name string dupAddrDetectTransmits uint8 @@ -331,44 +333,36 @@ func TestDADResolve(t *testing.T) { opts.NDPConfigs.RetransmitTimer = test.retransTimer opts.NDPConfigs.DupAddrDetectTransmits = test.dupAddrDetectTransmits - e := channel.New(10, 1280, linkAddr1) + e := channel.New(int(test.dupAddrDetectTransmits), 1280, linkAddr1) s := stack.New(opts) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err) - } - - stat := s.Stats().ICMP.V6PacketsSent.NeighborSolicit - - // Should have sent an NDP NS immediately. - if got := stat.Value(); got != 1 { - t.Fatalf("got NeighborSolicit = %d, want = 1", got) - + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) } // Address should not be considered bound to the NIC yet // (DAD ongoing). - addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) if err != nil { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) } if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) } // Wait for the remaining time - some delta (500ms), to // make sure the address is still not resolved. const delta = 500 * time.Millisecond time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - delta) - addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) if err != nil { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) } if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) } // Wait for DAD to resolve. @@ -385,8 +379,8 @@ func TestDADResolve(t *testing.T) { if e.err != nil { t.Fatal("got DAD error: ", e.err) } - if e.nicID != 1 { - t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID) + if e.nicID != nicID { + t.Fatalf("got DAD event w/ nicID = %d, want = %d", e.nicID, nicID) } if e.addr != addr1 { t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1) @@ -395,16 +389,16 @@ func TestDADResolve(t *testing.T) { t.Fatal("got DAD event w/ resolved = false, want = true") } } - addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) if err != nil { - t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) } if addr.Address != addr1 { - t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, addr1) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, addr, addr1) } // Should not have sent any more NS messages. - if got := stat.Value(); got != uint64(test.dupAddrDetectTransmits) { + if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != uint64(test.dupAddrDetectTransmits) { t.Fatalf("got NeighborSolicit = %d, want = %d", got, test.dupAddrDetectTransmits) } @@ -425,7 +419,6 @@ func TestDADResolve(t *testing.T) { } }) } - } // TestDADFail tests to make sure that the DAD process fails if another node is diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index de88c0bfa..79556a36f 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -35,24 +35,21 @@ type NIC struct { linkEP LinkEndpoint context NICContext - mu sync.RWMutex - spoofing bool - promiscuous bool - primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint - endpoints map[NetworkEndpointID]*referencedNetworkEndpoint - addressRanges []tcpip.Subnet - mcastJoins map[NetworkEndpointID]int32 - // packetEPs is protected by mu, but the contained PacketEndpoint - // values are not. - packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint - stats NICStats - // ndp is the NDP related state for NIC. - // - // Note, read and write operations on ndp require that the NIC is - // appropriately locked. - ndp ndpState + mu struct { + sync.RWMutex + spoofing bool + promiscuous bool + primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint + endpoints map[NetworkEndpointID]*referencedNetworkEndpoint + addressRanges []tcpip.Subnet + mcastJoins map[NetworkEndpointID]int32 + // packetEPs is protected by mu, but the contained PacketEndpoint + // values are not. + packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint + ndp ndpState + } } // NICStats includes transmitted and received stats. @@ -97,15 +94,11 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC // of IPv6 is supported on this endpoint's LinkEndpoint. nic := &NIC{ - stack: stack, - id: id, - name: name, - linkEP: ep, - context: ctx, - primary: make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint), - endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint), - mcastJoins: make(map[NetworkEndpointID]int32), - packetEPs: make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint), + stack: stack, + id: id, + name: name, + linkEP: ep, + context: ctx, stats: NICStats{ Tx: DirectionStats{ Packets: &tcpip.StatCounter{}, @@ -116,22 +109,26 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC Bytes: &tcpip.StatCounter{}, }, }, - ndp: ndpState{ - configs: stack.ndpConfigs, - dad: make(map[tcpip.Address]dadState), - defaultRouters: make(map[tcpip.Address]defaultRouterState), - onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState), - autoGenAddresses: make(map[tcpip.Address]autoGenAddressState), - }, } - nic.ndp.nic = nic + nic.mu.primary = make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint) + nic.mu.endpoints = make(map[NetworkEndpointID]*referencedNetworkEndpoint) + nic.mu.mcastJoins = make(map[NetworkEndpointID]int32) + nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint) + nic.mu.ndp = ndpState{ + nic: nic, + configs: stack.ndpConfigs, + dad: make(map[tcpip.Address]dadState), + defaultRouters: make(map[tcpip.Address]defaultRouterState), + onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState), + autoGenAddresses: make(map[tcpip.Address]autoGenAddressState), + } // Register supported packet endpoint protocols. for _, netProto := range header.Ethertypes { - nic.packetEPs[netProto] = []PacketEndpoint{} + nic.mu.packetEPs[netProto] = []PacketEndpoint{} } for _, netProto := range stack.networkProtocols { - nic.packetEPs[netProto.Number()] = []PacketEndpoint{} + nic.mu.packetEPs[netProto.Number()] = []PacketEndpoint{} } return nic @@ -215,7 +212,7 @@ func (n *NIC) enable() *tcpip.Error { // and default routers). Therefore, soliciting RAs from other routers on // a link is unnecessary for routers. if !n.stack.forwarding { - n.ndp.startSolicitingRouters() + n.mu.ndp.startSolicitingRouters() } return nil @@ -230,8 +227,8 @@ func (n *NIC) becomeIPv6Router() { n.mu.Lock() defer n.mu.Unlock() - n.ndp.cleanupHostOnlyState() - n.ndp.stopSolicitingRouters() + n.mu.ndp.cleanupHostOnlyState() + n.mu.ndp.stopSolicitingRouters() } // becomeIPv6Host transitions n into an IPv6 host. @@ -242,7 +239,7 @@ func (n *NIC) becomeIPv6Host() { n.mu.Lock() defer n.mu.Unlock() - n.ndp.startSolicitingRouters() + n.mu.ndp.startSolicitingRouters() } // attachLinkEndpoint attaches the NIC to the endpoint, which will enable it @@ -254,13 +251,13 @@ func (n *NIC) attachLinkEndpoint() { // setPromiscuousMode enables or disables promiscuous mode. func (n *NIC) setPromiscuousMode(enable bool) { n.mu.Lock() - n.promiscuous = enable + n.mu.promiscuous = enable n.mu.Unlock() } func (n *NIC) isPromiscuousMode() bool { n.mu.RLock() - rv := n.promiscuous + rv := n.mu.promiscuous n.mu.RUnlock() return rv } @@ -272,7 +269,7 @@ func (n *NIC) isLoopback() bool { // setSpoofing enables or disables address spoofing. func (n *NIC) setSpoofing(enable bool) { n.mu.Lock() - n.spoofing = enable + n.mu.spoofing = enable n.mu.Unlock() } @@ -291,8 +288,8 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr t defer n.mu.RUnlock() var deprecatedEndpoint *referencedNetworkEndpoint - for _, r := range n.primary[protocol] { - if !r.isValidForOutgoing() { + for _, r := range n.mu.primary[protocol] { + if !r.isValidForOutgoingRLocked() { continue } @@ -342,7 +339,7 @@ func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEn n.mu.RLock() defer n.mu.RUnlock() - primaryAddrs := n.primary[header.IPv6ProtocolNumber] + primaryAddrs := n.mu.primary[header.IPv6ProtocolNumber] if len(primaryAddrs) == 0 { return nil @@ -425,7 +422,7 @@ func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEn // hasPermanentAddrLocked returns true if n has a permanent (including currently // tentative) address, addr. func (n *NIC) hasPermanentAddrLocked(addr tcpip.Address) bool { - ref, ok := n.endpoints[NetworkEndpointID{addr}] + ref, ok := n.mu.endpoints[NetworkEndpointID{addr}] if !ok { return false @@ -436,24 +433,54 @@ func (n *NIC) hasPermanentAddrLocked(addr tcpip.Address) bool { return kind == permanent || kind == permanentTentative } +type getRefBehaviour int + +const ( + // spoofing indicates that the NIC's spoofing flag should be observed when + // getting a NIC's referenced network endpoint. + spoofing getRefBehaviour = iota + + // promiscuous indicates that the NIC's promiscuous flag should be observed + // when getting a NIC's referenced network endpoint. + promiscuous + + // forceSpoofing indicates that the NIC should be assumed to be spoofing, + // regardless of what the NIC's spoofing flag is when getting a NIC's + // referenced network endpoint. + forceSpoofing +) + func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint { - return n.getRefOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, n.promiscuous) + return n.getRefOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, promiscuous) } // findEndpoint finds the endpoint, if any, with the given address. func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint { - return n.getRefOrCreateTemp(protocol, address, peb, n.spoofing) + return n.getRefOrCreateTemp(protocol, address, peb, spoofing) } // getRefEpOrCreateTemp returns the referenced network endpoint for the given -// protocol and address. If none exists a temporary one may be created if -// we are in promiscuous mode or spoofing. -func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, spoofingOrPromiscuous bool) *referencedNetworkEndpoint { +// protocol and address. +// +// If none exists a temporary one may be created if we are in promiscuous mode +// or spoofing. Promiscuous mode will only be checked if promiscuous is true. +// Similarly, spoofing will only be checked if spoofing is true. +func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, tempRef getRefBehaviour) *referencedNetworkEndpoint { id := NetworkEndpointID{address} n.mu.RLock() - if ref, ok := n.endpoints[id]; ok { + var spoofingOrPromiscuous bool + switch tempRef { + case spoofing: + spoofingOrPromiscuous = n.mu.spoofing + case promiscuous: + spoofingOrPromiscuous = n.mu.promiscuous + case forceSpoofing: + spoofingOrPromiscuous = true + } + + if ref, ok := n.mu.endpoints[id]; ok { // An endpoint with this id exists, check if it can be used and return it. switch ref.getKind() { case permanentExpired: @@ -474,7 +501,7 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t // the caller or if the address is found in the NIC's subnets. createTempEP := spoofingOrPromiscuous if !createTempEP { - for _, sn := range n.addressRanges { + for _, sn := range n.mu.addressRanges { // Skip the subnet address. if address == sn.ID() { continue @@ -502,7 +529,7 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t // endpoint, create a new "temporary" endpoint. It will only exist while // there's a route through it. n.mu.Lock() - if ref, ok := n.endpoints[id]; ok { + if ref, ok := n.mu.endpoints[id]; ok { // No need to check the type as we are ok with expired endpoints at this // point. if ref.tryIncRef() { @@ -543,7 +570,7 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar // Sanity check. id := NetworkEndpointID{LocalAddress: protocolAddress.AddressWithPrefix.Address} - if ref, ok := n.endpoints[id]; ok { + if ref, ok := n.mu.endpoints[id]; ok { // Endpoint already exists. if kind != permanent { return nil, tcpip.ErrDuplicateAddress @@ -562,7 +589,7 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar ref.deprecated = deprecated ref.configType = configType - refs := n.primary[ref.protocol] + refs := n.mu.primary[ref.protocol] for i, r := range refs { if r == ref { switch peb { @@ -572,9 +599,9 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar if i == 0 { return ref, nil } - n.primary[r.protocol] = append(refs[:i], refs[i+1:]...) + n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...) case NeverPrimaryEndpoint: - n.primary[r.protocol] = append(refs[:i], refs[i+1:]...) + n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...) return ref, nil } } @@ -637,13 +664,13 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar } } - n.endpoints[id] = ref + n.mu.endpoints[id] = ref n.insertPrimaryEndpointLocked(ref, peb) // If we are adding a tentative IPv6 address, start DAD. if isIPv6Unicast && kind == permanentTentative { - if err := n.ndp.startDuplicateAddressDetection(protocolAddress.AddressWithPrefix.Address, ref); err != nil { + if err := n.mu.ndp.startDuplicateAddressDetection(protocolAddress.AddressWithPrefix.Address, ref); err != nil { return nil, err } } @@ -668,8 +695,8 @@ func (n *NIC) AllAddresses() []tcpip.ProtocolAddress { n.mu.RLock() defer n.mu.RUnlock() - addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints)) - for nid, ref := range n.endpoints { + addrs := make([]tcpip.ProtocolAddress, 0, len(n.mu.endpoints)) + for nid, ref := range n.mu.endpoints { // Don't include tentative, expired or temporary endpoints to // avoid confusion and prevent the caller from using those. switch ref.getKind() { @@ -695,7 +722,7 @@ func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress { defer n.mu.RUnlock() var addrs []tcpip.ProtocolAddress - for proto, list := range n.primary { + for proto, list := range n.mu.primary { for _, ref := range list { // Don't include tentative, expired or tempory endpoints // to avoid confusion and prevent the caller from using @@ -726,7 +753,7 @@ func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWit n.mu.RLock() defer n.mu.RUnlock() - list, ok := n.primary[proto] + list, ok := n.mu.primary[proto] if !ok { return tcpip.AddressWithPrefix{} } @@ -769,7 +796,7 @@ func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWit // address. func (n *NIC) AddAddressRange(protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) { n.mu.Lock() - n.addressRanges = append(n.addressRanges, subnet) + n.mu.addressRanges = append(n.mu.addressRanges, subnet) n.mu.Unlock() } @@ -778,13 +805,13 @@ func (n *NIC) RemoveAddressRange(subnet tcpip.Subnet) { n.mu.Lock() // Use the same underlying array. - tmp := n.addressRanges[:0] - for _, sub := range n.addressRanges { + tmp := n.mu.addressRanges[:0] + for _, sub := range n.mu.addressRanges { if sub != subnet { tmp = append(tmp, sub) } } - n.addressRanges = tmp + n.mu.addressRanges = tmp n.mu.Unlock() } @@ -793,8 +820,8 @@ func (n *NIC) RemoveAddressRange(subnet tcpip.Subnet) { func (n *NIC) AddressRanges() []tcpip.Subnet { n.mu.RLock() defer n.mu.RUnlock() - sns := make([]tcpip.Subnet, 0, len(n.addressRanges)+len(n.endpoints)) - for nid := range n.endpoints { + sns := make([]tcpip.Subnet, 0, len(n.mu.addressRanges)+len(n.mu.endpoints)) + for nid := range n.mu.endpoints { sn, err := tcpip.NewSubnet(nid.LocalAddress, tcpip.AddressMask(strings.Repeat("\xff", len(nid.LocalAddress)))) if err != nil { // This should never happen as the mask has been carefully crafted to @@ -803,7 +830,7 @@ func (n *NIC) AddressRanges() []tcpip.Subnet { } sns = append(sns, sn) } - return append(sns, n.addressRanges...) + return append(sns, n.mu.addressRanges...) } // insertPrimaryEndpointLocked adds r to n's primary endpoint list as required @@ -813,9 +840,9 @@ func (n *NIC) AddressRanges() []tcpip.Subnet { func (n *NIC) insertPrimaryEndpointLocked(r *referencedNetworkEndpoint, peb PrimaryEndpointBehavior) { switch peb { case CanBePrimaryEndpoint: - n.primary[r.protocol] = append(n.primary[r.protocol], r) + n.mu.primary[r.protocol] = append(n.mu.primary[r.protocol], r) case FirstPrimaryEndpoint: - n.primary[r.protocol] = append([]*referencedNetworkEndpoint{r}, n.primary[r.protocol]...) + n.mu.primary[r.protocol] = append([]*referencedNetworkEndpoint{r}, n.mu.primary[r.protocol]...) } } @@ -827,7 +854,7 @@ func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) { // and was waiting (on the lock) to be removed and 2) the same address was // re-added in the meantime by removing this endpoint from the list and // adding a new one. - if n.endpoints[id] != r { + if n.mu.endpoints[id] != r { return } @@ -835,11 +862,11 @@ func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) { panic("Reference count dropped to zero before being removed") } - delete(n.endpoints, id) - refs := n.primary[r.protocol] + delete(n.mu.endpoints, id) + refs := n.mu.primary[r.protocol] for i, ref := range refs { if ref == r { - n.primary[r.protocol] = append(refs[:i], refs[i+1:]...) + n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...) break } } @@ -854,7 +881,7 @@ func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) { } func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { - r, ok := n.endpoints[NetworkEndpointID{addr}] + r, ok := n.mu.endpoints[NetworkEndpointID{addr}] if !ok { return tcpip.ErrBadLocalAddress } @@ -870,13 +897,13 @@ func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { // If we are removing a tentative IPv6 unicast address, stop // DAD. if kind == permanentTentative { - n.ndp.stopDuplicateAddressDetection(addr) + n.mu.ndp.stopDuplicateAddressDetection(addr) } // If we are removing an address generated via SLAAC, cleanup // its SLAAC resources and notify the integrator. if r.configType == slaac { - n.ndp.cleanupAutoGenAddrResourcesAndNotify(addr) + n.mu.ndp.cleanupAutoGenAddrResourcesAndNotify(addr) } } @@ -926,7 +953,7 @@ func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.A // outlined in RFC 3810 section 5. id := NetworkEndpointID{addr} - joins := n.mcastJoins[id] + joins := n.mu.mcastJoins[id] if joins == 0 { netProto, ok := n.stack.networkProtocols[protocol] if !ok { @@ -942,7 +969,7 @@ func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.A return err } } - n.mcastJoins[id] = joins + 1 + n.mu.mcastJoins[id] = joins + 1 return nil } @@ -960,7 +987,7 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error { // before leaveGroupLocked is called. func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { id := NetworkEndpointID{addr} - joins := n.mcastJoins[id] + joins := n.mu.mcastJoins[id] switch joins { case 0: // There are no joins with this address on this NIC. @@ -971,7 +998,7 @@ func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { return err } } - n.mcastJoins[id] = joins - 1 + n.mu.mcastJoins[id] = joins - 1 return nil } @@ -1006,12 +1033,12 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link // Are any packet sockets listening for this network protocol? n.mu.RLock() - packetEPs := n.packetEPs[protocol] + packetEPs := n.mu.packetEPs[protocol] // Check whether there are packet sockets listening for every protocol. // If we received a packet with protocol EthernetProtocolAll, then the // previous for loop will have handled it. if protocol != header.EthernetProtocolAll { - packetEPs = append(packetEPs, n.packetEPs[header.EthernetProtocolAll]...) + packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...) } n.mu.RUnlock() for _, ep := range packetEPs { @@ -1060,8 +1087,8 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link // Found a NIC. n := r.ref.nic n.mu.RLock() - ref, ok := n.endpoints[NetworkEndpointID{dst}] - ok = ok && ref.isValidForOutgoing() && ref.tryIncRef() + ref, ok := n.mu.endpoints[NetworkEndpointID{dst}] + ok = ok && ref.isValidForOutgoingRLocked() && ref.tryIncRef() n.mu.RUnlock() if ok { r.RemoteAddress = src @@ -1181,7 +1208,7 @@ func (n *NIC) Stack() *Stack { // false. It will only return true if the address is associated with the NIC // AND it is tentative. func (n *NIC) isAddrTentative(addr tcpip.Address) bool { - ref, ok := n.endpoints[NetworkEndpointID{addr}] + ref, ok := n.mu.endpoints[NetworkEndpointID{addr}] if !ok { return false } @@ -1197,7 +1224,7 @@ func (n *NIC) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error { n.mu.Lock() defer n.mu.Unlock() - ref, ok := n.endpoints[NetworkEndpointID{addr}] + ref, ok := n.mu.endpoints[NetworkEndpointID{addr}] if !ok { return tcpip.ErrBadAddress } @@ -1217,7 +1244,7 @@ func (n *NIC) setNDPConfigs(c NDPConfigurations) { c.validate() n.mu.Lock() - n.ndp.configs = c + n.mu.ndp.configs = c n.mu.Unlock() } @@ -1226,7 +1253,7 @@ func (n *NIC) handleNDPRA(ip tcpip.Address, ra header.NDPRouterAdvert) { n.mu.Lock() defer n.mu.Unlock() - n.ndp.handleRA(ip, ra) + n.mu.ndp.handleRA(ip, ra) } type networkEndpointKind int32 @@ -1268,11 +1295,11 @@ func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep Pa n.mu.Lock() defer n.mu.Unlock() - eps, ok := n.packetEPs[netProto] + eps, ok := n.mu.packetEPs[netProto] if !ok { return tcpip.ErrNotSupported } - n.packetEPs[netProto] = append(eps, ep) + n.mu.packetEPs[netProto] = append(eps, ep) return nil } @@ -1281,14 +1308,14 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep n.mu.Lock() defer n.mu.Unlock() - eps, ok := n.packetEPs[netProto] + eps, ok := n.mu.packetEPs[netProto] if !ok { return } for i, epOther := range eps { if epOther == ep { - n.packetEPs[netProto] = append(eps[:i], eps[i+1:]...) + n.mu.packetEPs[netProto] = append(eps[:i], eps[i+1:]...) return } } @@ -1346,14 +1373,19 @@ func (r *referencedNetworkEndpoint) setKind(kind networkEndpointKind) { // packet. It requires the endpoint to not be marked expired (i.e., its address // has been removed), or the NIC to be in spoofing mode. func (r *referencedNetworkEndpoint) isValidForOutgoing() bool { - return r.getKind() != permanentExpired || r.nic.spoofing + r.nic.mu.RLock() + defer r.nic.mu.RUnlock() + + return r.isValidForOutgoingRLocked() } -// isValidForIncoming returns true if the endpoint can accept an incoming -// packet. It requires the endpoint to not be marked expired (i.e., its address -// has been removed), or the NIC to be in promiscuous mode. -func (r *referencedNetworkEndpoint) isValidForIncoming() bool { - return r.getKind() != permanentExpired || r.nic.promiscuous +// isValidForOutgoingRLocked returns true if the endpoint can be used to send +// out a packet. It requires the endpoint to not be marked expired (i.e., its +// address has been removed), or the NIC to be in spoofing mode. +// +// r's NIC must be read locked. +func (r *referencedNetworkEndpoint) isValidForOutgoingRLocked() bool { + return r.getKind() != permanentExpired || r.nic.mu.spoofing } // decRef decrements the ref count and cleans up the endpoint once it reaches -- cgit v1.2.3 From fb80979e3fe2614414d2d23c27e41bdb9e7c8541 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Fri, 24 Jan 2020 12:29:13 -0800 Subject: Increase timeouts for NDP tests' async events Increase the timeout to 1s when waiting for async NDP events to help reduce flakiness. This will not significantly increase test times as the async events continue to receive an event on a channel. The increased timeout allows more time for an event to be sent on the channel as the previous timeout of 100ms caused some flakes. Test: Existing tests pass PiperOrigin-RevId: 291420936 --- pkg/tcpip/stack/ndp_test.go | 47 +++++++++++++++++++++++---------------------- 1 file changed, 24 insertions(+), 23 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 376681b30..f9460bd51 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -35,13 +35,14 @@ import ( ) const ( - addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") - linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07") - linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08") - defaultTimeout = 100 * time.Millisecond + addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") + linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") + linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07") + linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08") + defaultTimeout = 100 * time.Millisecond + defaultAsyncEventTimeout = time.Second ) var ( @@ -1086,7 +1087,7 @@ func TestRouterDiscovery(t *testing.T) { // Wait for the normal lifetime plus an extra bit for the // router to get invalidated. If we don't get an invalidation // event after this time, then something is wrong. - expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultTimeout) + expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncEventTimeout) // Rx an RA from lladdr2 with huge lifetime. e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) @@ -1103,7 +1104,7 @@ func TestRouterDiscovery(t *testing.T) { // Wait for the normal lifetime plus an extra bit for the // router to get invalidated. If we don't get an invalidation // event after this time, then something is wrong. - expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultTimeout) + expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncEventTimeout) } // TestRouterDiscoveryMaxRouters tests that only @@ -1342,7 +1343,7 @@ func TestPrefixDiscovery(t *testing.T) { if diff := checkPrefixEvent(e, subnet2, false); diff != "" { t.Errorf("prefix event mismatch (-want +got):\n%s", diff) } - case <-time.After(time.Duration(lifetime)*time.Second + defaultTimeout): + case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncEventTimeout): t.Fatal("timed out waiting for prefix discovery event") } @@ -1681,7 +1682,7 @@ func TestAutoGenAddr(t *testing.T) { if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(newMinVLDuration + defaultTimeout): + case <-time.After(newMinVLDuration + defaultAsyncEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } if contains(s.NICInfo()[1].ProtocolAddresses, addr1) { @@ -1987,7 +1988,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { expectPrimaryAddr(addr1) // Wait for addr of prefix1 to be deprecated. - expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultTimeout) + expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncEventTimeout) if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should not have %s in the list of addresses", addr1) } @@ -2027,7 +2028,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { expectPrimaryAddr(addr1) // Wait for addr of prefix1 to be deprecated. - expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultTimeout) + expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncEventTimeout) if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should not have %s in the list of addresses", addr1) } @@ -2041,7 +2042,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { } // Wait for addr of prefix1 to be invalidated. - expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultTimeout) + expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncEventTimeout) if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should not have %s in the list of addresses", addr1) } @@ -2073,7 +2074,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultTimeout): + case <-time.After(defaultAsyncEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" { @@ -2088,7 +2089,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { t.Fatalf("got unexpected auto-generated event") } - case <-time.After(newMinVLDuration + defaultTimeout): + case <-time.After(newMinVLDuration + defaultAsyncEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { @@ -2213,7 +2214,7 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(minVLSeconds*time.Second + defaultTimeout): + case <-time.After(minVLSeconds*time.Second + defaultAsyncEventTimeout): t.Fatal("timeout waiting for addr auto gen event") } }) @@ -2701,7 +2702,7 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultTimeout): + case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultAsyncEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { @@ -3325,12 +3326,12 @@ func TestRouterSolicitation(t *testing.T) { // times. remaining := test.maxRtrSolicit if remaining > 0 { - waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultTimeout) + waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultAsyncEventTimeout) remaining-- } for ; remaining > 0; remaining-- { waitForNothing(test.effectiveRtrSolicitInt - defaultTimeout) - waitForPkt(2 * defaultTimeout) + waitForPkt(defaultAsyncEventTimeout) } // Make sure no more RS. @@ -3411,9 +3412,9 @@ func TestStopStartSolicitingRouters(t *testing.T) { // Disable forwarding which should start router solicitations. s.SetForwarding(false) - waitForPkt(delay + defaultTimeout) - waitForPkt(interval + defaultTimeout) - waitForPkt(interval + defaultTimeout) + waitForPkt(delay + defaultAsyncEventTimeout) + waitForPkt(interval + defaultAsyncEventTimeout) + waitForPkt(interval + defaultAsyncEventTimeout) select { case <-e.C: t.Fatal("unexpectedly got an extra packet after sending out the expected RSs") -- cgit v1.2.3 From 878bda6e19a0d55525ea6b1600f3413e0c5d6a84 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Fri, 24 Jan 2020 13:02:01 -0800 Subject: Lock the NIC when checking if an address is tentative PiperOrigin-RevId: 291426657 --- pkg/tcpip/stack/nic.go | 3 +++ 1 file changed, 3 insertions(+) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 79556a36f..7dad9a8cb 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1208,6 +1208,9 @@ func (n *NIC) Stack() *Stack { // false. It will only return true if the address is associated with the NIC // AND it is tentative. func (n *NIC) isAddrTentative(addr tcpip.Address) bool { + n.mu.RLock() + defer n.mu.RUnlock() + ref, ok := n.mu.endpoints[NetworkEndpointID{addr}] if !ok { return false -- cgit v1.2.3 From d29e59af9fbd420e34378bcbf7ae543134070217 Mon Sep 17 00:00:00 2001 From: Adin Scannell Date: Mon, 27 Jan 2020 10:04:07 -0800 Subject: Standardize on tools directory. PiperOrigin-RevId: 291745021 --- .bazelrc | 8 +- BUILD | 49 ++++++- benchmarks/defs.bzl | 18 --- benchmarks/harness/BUILD | 74 +++++----- benchmarks/harness/machine_producers/BUILD | 4 +- benchmarks/runner/BUILD | 24 ++-- benchmarks/tcp/BUILD | 3 +- benchmarks/workloads/ab/BUILD | 19 ++- benchmarks/workloads/absl/BUILD | 19 ++- benchmarks/workloads/curl/BUILD | 2 +- benchmarks/workloads/ffmpeg/BUILD | 2 +- benchmarks/workloads/fio/BUILD | 19 ++- benchmarks/workloads/httpd/BUILD | 2 +- benchmarks/workloads/iperf/BUILD | 19 ++- benchmarks/workloads/netcat/BUILD | 2 +- benchmarks/workloads/nginx/BUILD | 2 +- benchmarks/workloads/node/BUILD | 2 +- benchmarks/workloads/node_template/BUILD | 2 +- benchmarks/workloads/redis/BUILD | 2 +- benchmarks/workloads/redisbenchmark/BUILD | 19 ++- benchmarks/workloads/ruby/BUILD | 2 +- benchmarks/workloads/ruby_template/BUILD | 2 +- benchmarks/workloads/sleep/BUILD | 2 +- benchmarks/workloads/sysbench/BUILD | 19 ++- benchmarks/workloads/syscall/BUILD | 19 ++- benchmarks/workloads/tensorflow/BUILD | 2 +- benchmarks/workloads/true/BUILD | 2 +- pkg/abi/BUILD | 3 +- pkg/abi/linux/BUILD | 6 +- pkg/amutex/BUILD | 6 +- pkg/atomicbitops/BUILD | 6 +- pkg/binary/BUILD | 6 +- pkg/bits/BUILD | 6 +- pkg/bpf/BUILD | 6 +- pkg/compressio/BUILD | 6 +- pkg/control/client/BUILD | 3 +- pkg/control/server/BUILD | 3 +- pkg/cpuid/BUILD | 8 +- pkg/eventchannel/BUILD | 16 +-- pkg/fd/BUILD | 6 +- pkg/fdchannel/BUILD | 8 +- pkg/fdnotifier/BUILD | 3 +- pkg/flipcall/BUILD | 8 +- pkg/fspath/BUILD | 13 +- pkg/gate/BUILD | 4 +- pkg/goid/BUILD | 6 +- pkg/ilist/BUILD | 6 +- pkg/linewriter/BUILD | 6 +- pkg/log/BUILD | 6 +- pkg/memutil/BUILD | 3 +- pkg/metric/BUILD | 23 +-- pkg/p9/BUILD | 6 +- pkg/p9/p9test/BUILD | 6 +- pkg/procid/BUILD | 8 +- pkg/rand/BUILD | 3 +- pkg/refs/BUILD | 6 +- pkg/seccomp/BUILD | 6 +- pkg/secio/BUILD | 6 +- pkg/segment/test/BUILD | 6 +- pkg/sentry/BUILD | 2 + pkg/sentry/arch/BUILD | 20 +-- pkg/sentry/context/BUILD | 3 +- pkg/sentry/context/contexttest/BUILD | 3 +- pkg/sentry/control/BUILD | 8 +- pkg/sentry/device/BUILD | 6 +- pkg/sentry/fs/BUILD | 6 +- pkg/sentry/fs/anon/BUILD | 3 +- pkg/sentry/fs/dev/BUILD | 3 +- pkg/sentry/fs/fdpipe/BUILD | 6 +- pkg/sentry/fs/filetest/BUILD | 3 +- pkg/sentry/fs/fsutil/BUILD | 6 +- pkg/sentry/fs/gofer/BUILD | 6 +- pkg/sentry/fs/host/BUILD | 6 +- pkg/sentry/fs/lock/BUILD | 6 +- pkg/sentry/fs/proc/BUILD | 6 +- pkg/sentry/fs/proc/device/BUILD | 3 +- pkg/sentry/fs/proc/seqfile/BUILD | 6 +- pkg/sentry/fs/ramfs/BUILD | 6 +- pkg/sentry/fs/sys/BUILD | 3 +- pkg/sentry/fs/timerfd/BUILD | 3 +- pkg/sentry/fs/tmpfs/BUILD | 6 +- pkg/sentry/fs/tty/BUILD | 6 +- pkg/sentry/fsimpl/ext/BUILD | 6 +- pkg/sentry/fsimpl/ext/benchmark/BUILD | 2 +- pkg/sentry/fsimpl/ext/disklayout/BUILD | 6 +- pkg/sentry/fsimpl/kernfs/BUILD | 6 +- pkg/sentry/fsimpl/proc/BUILD | 8 +- pkg/sentry/fsimpl/sys/BUILD | 6 +- pkg/sentry/fsimpl/testutil/BUILD | 5 +- pkg/sentry/fsimpl/tmpfs/BUILD | 8 +- pkg/sentry/hostcpu/BUILD | 6 +- pkg/sentry/hostmm/BUILD | 3 +- pkg/sentry/inet/BUILD | 3 +- pkg/sentry/kernel/BUILD | 24 +--- pkg/sentry/kernel/auth/BUILD | 3 +- pkg/sentry/kernel/contexttest/BUILD | 3 +- pkg/sentry/kernel/epoll/BUILD | 6 +- pkg/sentry/kernel/eventfd/BUILD | 6 +- pkg/sentry/kernel/fasync/BUILD | 3 +- pkg/sentry/kernel/futex/BUILD | 6 +- pkg/sentry/kernel/memevent/BUILD | 20 +-- pkg/sentry/kernel/pipe/BUILD | 6 +- pkg/sentry/kernel/sched/BUILD | 6 +- pkg/sentry/kernel/semaphore/BUILD | 6 +- pkg/sentry/kernel/shm/BUILD | 3 +- pkg/sentry/kernel/signalfd/BUILD | 5 +- pkg/sentry/kernel/time/BUILD | 3 +- pkg/sentry/limits/BUILD | 6 +- pkg/sentry/loader/BUILD | 4 +- pkg/sentry/memmap/BUILD | 6 +- pkg/sentry/mm/BUILD | 6 +- pkg/sentry/pgalloc/BUILD | 6 +- pkg/sentry/platform/BUILD | 3 +- pkg/sentry/platform/interrupt/BUILD | 6 +- pkg/sentry/platform/kvm/BUILD | 6 +- pkg/sentry/platform/kvm/testutil/BUILD | 3 +- pkg/sentry/platform/ptrace/BUILD | 3 +- pkg/sentry/platform/ring0/BUILD | 3 +- pkg/sentry/platform/ring0/gen_offsets/BUILD | 2 +- pkg/sentry/platform/ring0/pagetables/BUILD | 16 +-- pkg/sentry/platform/safecopy/BUILD | 6 +- pkg/sentry/safemem/BUILD | 6 +- pkg/sentry/sighandling/BUILD | 3 +- pkg/sentry/socket/BUILD | 3 +- pkg/sentry/socket/control/BUILD | 3 +- pkg/sentry/socket/hostinet/BUILD | 3 +- pkg/sentry/socket/netfilter/BUILD | 3 +- pkg/sentry/socket/netlink/BUILD | 3 +- pkg/sentry/socket/netlink/port/BUILD | 6 +- pkg/sentry/socket/netlink/route/BUILD | 3 +- pkg/sentry/socket/netlink/uevent/BUILD | 3 +- pkg/sentry/socket/netstack/BUILD | 3 +- pkg/sentry/socket/unix/BUILD | 3 +- pkg/sentry/socket/unix/transport/BUILD | 3 +- pkg/sentry/state/BUILD | 3 +- pkg/sentry/strace/BUILD | 20 +-- pkg/sentry/syscalls/BUILD | 3 +- pkg/sentry/syscalls/linux/BUILD | 3 +- pkg/sentry/time/BUILD | 6 +- pkg/sentry/unimpl/BUILD | 21 +-- pkg/sentry/uniqueid/BUILD | 3 +- pkg/sentry/usage/BUILD | 5 +- pkg/sentry/usermem/BUILD | 7 +- pkg/sentry/vfs/BUILD | 8 +- pkg/sentry/watchdog/BUILD | 3 +- pkg/sleep/BUILD | 6 +- pkg/state/BUILD | 17 +-- pkg/state/statefile/BUILD | 6 +- pkg/sync/BUILD | 6 +- pkg/sync/atomicptrtest/BUILD | 6 +- pkg/sync/seqatomictest/BUILD | 6 +- pkg/syserr/BUILD | 3 +- pkg/syserror/BUILD | 4 +- pkg/tcpip/BUILD | 6 +- pkg/tcpip/adapters/gonet/BUILD | 6 +- pkg/tcpip/buffer/BUILD | 6 +- pkg/tcpip/checker/BUILD | 3 +- pkg/tcpip/hash/jenkins/BUILD | 6 +- pkg/tcpip/header/BUILD | 6 +- pkg/tcpip/iptables/BUILD | 3 +- pkg/tcpip/link/channel/BUILD | 3 +- pkg/tcpip/link/fdbased/BUILD | 6 +- pkg/tcpip/link/loopback/BUILD | 3 +- pkg/tcpip/link/muxed/BUILD | 6 +- pkg/tcpip/link/rawfile/BUILD | 3 +- pkg/tcpip/link/sharedmem/BUILD | 6 +- pkg/tcpip/link/sharedmem/pipe/BUILD | 6 +- pkg/tcpip/link/sharedmem/queue/BUILD | 6 +- pkg/tcpip/link/sniffer/BUILD | 3 +- pkg/tcpip/link/tun/BUILD | 3 +- pkg/tcpip/link/waitable/BUILD | 6 +- pkg/tcpip/network/BUILD | 2 +- pkg/tcpip/network/arp/BUILD | 4 +- pkg/tcpip/network/fragmentation/BUILD | 6 +- pkg/tcpip/network/hash/BUILD | 3 +- pkg/tcpip/network/ipv4/BUILD | 4 +- pkg/tcpip/network/ipv6/BUILD | 6 +- pkg/tcpip/ports/BUILD | 6 +- pkg/tcpip/sample/tun_tcp_connect/BUILD | 2 +- pkg/tcpip/sample/tun_tcp_echo/BUILD | 2 +- pkg/tcpip/seqnum/BUILD | 3 +- pkg/tcpip/stack/BUILD | 6 +- pkg/tcpip/transport/icmp/BUILD | 3 +- pkg/tcpip/transport/packet/BUILD | 3 +- pkg/tcpip/transport/raw/BUILD | 3 +- pkg/tcpip/transport/tcp/BUILD | 4 +- pkg/tcpip/transport/tcp/testing/context/BUILD | 3 +- pkg/tcpip/transport/tcpconntrack/BUILD | 4 +- pkg/tcpip/transport/udp/BUILD | 4 +- pkg/tmutex/BUILD | 6 +- pkg/unet/BUILD | 6 +- pkg/urpc/BUILD | 6 +- pkg/waiter/BUILD | 6 +- runsc/BUILD | 27 ++-- runsc/boot/BUILD | 5 +- runsc/boot/filter/BUILD | 3 +- runsc/boot/platforms/BUILD | 3 +- runsc/cgroup/BUILD | 5 +- runsc/cmd/BUILD | 5 +- runsc/console/BUILD | 3 +- runsc/container/BUILD | 5 +- runsc/container/test_app/BUILD | 4 +- runsc/criutil/BUILD | 3 +- runsc/dockerutil/BUILD | 3 +- runsc/fsgofer/BUILD | 9 +- runsc/fsgofer/filter/BUILD | 3 +- runsc/sandbox/BUILD | 3 +- runsc/specutils/BUILD | 5 +- runsc/testutil/BUILD | 3 +- runsc/version_test.sh | 2 +- scripts/common.sh | 6 +- scripts/common_bazel.sh | 99 ------------- scripts/common_build.sh | 99 +++++++++++++ test/BUILD | 45 +----- test/e2e/BUILD | 5 +- test/image/BUILD | 5 +- test/iptables/BUILD | 5 +- test/iptables/runner/BUILD | 12 +- test/root/BUILD | 5 +- test/root/testdata/BUILD | 3 +- test/runtimes/BUILD | 4 +- test/runtimes/build_defs.bzl | 5 +- test/runtimes/images/proctor/BUILD | 4 +- test/syscalls/BUILD | 2 +- test/syscalls/build_defs.bzl | 6 +- test/syscalls/gtest/BUILD | 7 +- test/syscalls/linux/BUILD | 23 ++- test/syscalls/linux/arch_prctl.cc | 2 + test/syscalls/linux/rseq/BUILD | 5 +- .../linux/udp_socket_errqueue_test_case.cc | 4 + test/uds/BUILD | 3 +- test/util/BUILD | 27 ++-- test/util/save_util_linux.cc | 4 + test/util/save_util_other.cc | 4 + test/util/test_util_runfiles.cc | 4 + tools/BUILD | 3 + tools/build/BUILD | 10 ++ tools/build/defs.bzl | 91 ++++++++++++ tools/checkunsafe/BUILD | 3 +- tools/defs.bzl | 154 +++++++++++++++++++++ tools/go_generics/BUILD | 2 +- tools/go_generics/globals/BUILD | 4 +- tools/go_generics/go_merge/BUILD | 2 +- tools/go_generics/rules_tests/BUILD | 2 +- tools/go_marshal/BUILD | 4 +- tools/go_marshal/README.md | 52 +------ tools/go_marshal/analysis/BUILD | 5 +- tools/go_marshal/defs.bzl | 112 ++------------- tools/go_marshal/gomarshal/BUILD | 6 +- tools/go_marshal/gomarshal/generator.go | 20 ++- tools/go_marshal/gomarshal/generator_tests.go | 6 +- tools/go_marshal/main.go | 11 +- tools/go_marshal/marshal/BUILD | 5 +- tools/go_marshal/test/BUILD | 7 +- tools/go_marshal/test/external/BUILD | 6 +- tools/go_stateify/BUILD | 2 +- tools/go_stateify/defs.bzl | 79 +---------- tools/images/BUILD | 2 +- tools/images/defs.bzl | 6 +- tools/issue_reviver/BUILD | 2 +- tools/issue_reviver/github/BUILD | 3 +- tools/issue_reviver/reviver/BUILD | 5 +- tools/workspace_status.sh | 2 +- vdso/BUILD | 33 ++--- 264 files changed, 1012 insertions(+), 1380 deletions(-) delete mode 100644 benchmarks/defs.bzl delete mode 100755 scripts/common_bazel.sh create mode 100755 scripts/common_build.sh create mode 100644 tools/BUILD create mode 100644 tools/build/BUILD create mode 100644 tools/build/defs.bzl create mode 100644 tools/defs.bzl (limited to 'pkg/tcpip/stack') diff --git a/.bazelrc b/.bazelrc index 9c35c5e7b..ef214bcfa 100644 --- a/.bazelrc +++ b/.bazelrc @@ -30,10 +30,10 @@ build:remote --auth_scope="https://www.googleapis.com/auth/cloud-source-tools" # Add a custom platform and toolchain that builds in a privileged docker # container, which is required by our syscall tests. -build:remote --host_platform=//test:rbe_ubuntu1604 -build:remote --extra_toolchains=//test:cc-toolchain-clang-x86_64-default -build:remote --extra_execution_platforms=//test:rbe_ubuntu1604 -build:remote --platforms=//test:rbe_ubuntu1604 +build:remote --host_platform=//:rbe_ubuntu1604 +build:remote --extra_toolchains=//:cc-toolchain-clang-x86_64-default +build:remote --extra_execution_platforms=//:rbe_ubuntu1604 +build:remote --platforms=//:rbe_ubuntu1604 build:remote --crosstool_top=@rbe_default//cc:toolchain build:remote --jobs=50 build:remote --remote_timeout=3600 diff --git a/BUILD b/BUILD index 76286174f..5fd929378 100644 --- a/BUILD +++ b/BUILD @@ -1,8 +1,8 @@ -package(licenses = ["notice"]) # Apache 2.0 - load("@io_bazel_rules_go//go:def.bzl", "go_path", "nogo") load("@bazel_gazelle//:def.bzl", "gazelle") +package(licenses = ["notice"]) + # The sandbox filegroup is used for sandbox-internal dependencies. package_group( name = "sandbox", @@ -49,9 +49,52 @@ gazelle(name = "gazelle") # live in the tools subdirectory (unless they are standard). nogo( name = "nogo", - config = "tools/nogo.js", + config = "//tools:nogo.js", visibility = ["//visibility:public"], deps = [ "//tools/checkunsafe", ], ) + +# We need to define a bazel platform and toolchain to specify dockerPrivileged +# and dockerRunAsRoot options, they are required to run tests on the RBE +# cluster in Kokoro. +alias( + name = "rbe_ubuntu1604", + actual = ":rbe_ubuntu1604_r346485", +) + +platform( + name = "rbe_ubuntu1604_r346485", + constraint_values = [ + "@bazel_tools//platforms:x86_64", + "@bazel_tools//platforms:linux", + "@bazel_tools//tools/cpp:clang", + "@bazel_toolchains//constraints:xenial", + "@bazel_toolchains//constraints/sanitizers:support_msan", + ], + remote_execution_properties = """ + properties: { + name: "container-image" + value:"docker://gcr.io/cloud-marketplace/google/rbe-ubuntu16-04@sha256:93f7e127196b9b653d39830c50f8b05d49ef6fd8739a9b5b8ab16e1df5399e50" + } + properties: { + name: "dockerAddCapabilities" + value: "SYS_ADMIN" + } + properties: { + name: "dockerPrivileged" + value: "true" + } + """, +) + +toolchain( + name = "cc-toolchain-clang-x86_64-default", + exec_compatible_with = [ + ], + target_compatible_with = [ + ], + toolchain = "@bazel_toolchains//configs/ubuntu16_04_clang/10.0.0/bazel_2.0.0/cc:cc-compiler-k8", + toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", +) diff --git a/benchmarks/defs.bzl b/benchmarks/defs.bzl deleted file mode 100644 index 79e6cdbc8..000000000 --- a/benchmarks/defs.bzl +++ /dev/null @@ -1,18 +0,0 @@ -"""Provides python helper functions.""" - -load("@pydeps//:requirements.bzl", _requirement = "requirement") - -def filter_deps(deps = None): - if deps == None: - deps = [] - return [dep for dep in deps if dep] - -def py_library(deps = None, **kwargs): - return native.py_library(deps = filter_deps(deps), **kwargs) - -def py_test(deps = None, **kwargs): - return native.py_test(deps = filter_deps(deps), **kwargs) - -def requirement(name, direct = True): - """ requirement returns the required dependency. """ - return _requirement(name) diff --git a/benchmarks/harness/BUILD b/benchmarks/harness/BUILD index 081a74243..52d4e42f8 100644 --- a/benchmarks/harness/BUILD +++ b/benchmarks/harness/BUILD @@ -1,4 +1,4 @@ -load("//benchmarks:defs.bzl", "py_library", "requirement") +load("//tools:defs.bzl", "py_library", "py_requirement") package( default_visibility = ["//benchmarks:__subpackages__"], @@ -25,16 +25,16 @@ py_library( srcs = ["container.py"], deps = [ "//benchmarks/workloads", - requirement("asn1crypto", False), - requirement("chardet", False), - requirement("certifi", False), - requirement("docker", True), - requirement("docker-pycreds", False), - requirement("idna", False), - requirement("ptyprocess", False), - requirement("requests", False), - requirement("urllib3", False), - requirement("websocket-client", False), + py_requirement("asn1crypto", False), + py_requirement("chardet", False), + py_requirement("certifi", False), + py_requirement("docker", True), + py_requirement("docker-pycreds", False), + py_requirement("idna", False), + py_requirement("ptyprocess", False), + py_requirement("requests", False), + py_requirement("urllib3", False), + py_requirement("websocket-client", False), ], ) @@ -47,17 +47,17 @@ py_library( "//benchmarks/harness:ssh_connection", "//benchmarks/harness:tunnel_dispatcher", "//benchmarks/harness/machine_mocks", - requirement("asn1crypto", False), - requirement("chardet", False), - requirement("certifi", False), - requirement("docker", True), - requirement("docker-pycreds", False), - requirement("idna", False), - requirement("ptyprocess", False), - requirement("requests", False), - requirement("six", False), - requirement("urllib3", False), - requirement("websocket-client", False), + py_requirement("asn1crypto", False), + py_requirement("chardet", False), + py_requirement("certifi", False), + py_requirement("docker", True), + py_requirement("docker-pycreds", False), + py_requirement("idna", False), + py_requirement("ptyprocess", False), + py_requirement("requests", False), + py_requirement("six", False), + py_requirement("urllib3", False), + py_requirement("websocket-client", False), ], ) @@ -66,10 +66,10 @@ py_library( srcs = ["ssh_connection.py"], deps = [ "//benchmarks/harness", - requirement("bcrypt", False), - requirement("cffi", True), - requirement("paramiko", True), - requirement("cryptography", False), + py_requirement("bcrypt", False), + py_requirement("cffi", True), + py_requirement("paramiko", True), + py_requirement("cryptography", False), ], ) @@ -77,16 +77,16 @@ py_library( name = "tunnel_dispatcher", srcs = ["tunnel_dispatcher.py"], deps = [ - requirement("asn1crypto", False), - requirement("chardet", False), - requirement("certifi", False), - requirement("docker", True), - requirement("docker-pycreds", False), - requirement("idna", False), - requirement("pexpect", True), - requirement("ptyprocess", False), - requirement("requests", False), - requirement("urllib3", False), - requirement("websocket-client", False), + py_requirement("asn1crypto", False), + py_requirement("chardet", False), + py_requirement("certifi", False), + py_requirement("docker", True), + py_requirement("docker-pycreds", False), + py_requirement("idna", False), + py_requirement("pexpect", True), + py_requirement("ptyprocess", False), + py_requirement("requests", False), + py_requirement("urllib3", False), + py_requirement("websocket-client", False), ], ) diff --git a/benchmarks/harness/machine_producers/BUILD b/benchmarks/harness/machine_producers/BUILD index c4e943882..48ea0ef39 100644 --- a/benchmarks/harness/machine_producers/BUILD +++ b/benchmarks/harness/machine_producers/BUILD @@ -1,4 +1,4 @@ -load("//benchmarks:defs.bzl", "py_library", "requirement") +load("//tools:defs.bzl", "py_library", "py_requirement") package( default_visibility = ["//benchmarks:__subpackages__"], @@ -31,7 +31,7 @@ py_library( deps = [ "//benchmarks/harness:machine", "//benchmarks/harness/machine_producers:machine_producer", - requirement("PyYAML", False), + py_requirement("PyYAML", False), ], ) diff --git a/benchmarks/runner/BUILD b/benchmarks/runner/BUILD index e1b2ea550..fae0ca800 100644 --- a/benchmarks/runner/BUILD +++ b/benchmarks/runner/BUILD @@ -1,4 +1,4 @@ -load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") +load("//tools:defs.bzl", "py_library", "py_requirement", "py_test") package(licenses = ["notice"]) @@ -28,7 +28,7 @@ py_library( "//benchmarks/suites:startup", "//benchmarks/suites:sysbench", "//benchmarks/suites:syscall", - requirement("click", True), + py_requirement("click", True), ], ) @@ -36,7 +36,7 @@ py_library( name = "commands", srcs = ["commands.py"], deps = [ - requirement("click", True), + py_requirement("click", True), ], ) @@ -50,14 +50,14 @@ py_test( ], deps = [ ":runner", - requirement("click", True), - requirement("attrs", False), - requirement("atomicwrites", False), - requirement("more-itertools", False), - requirement("pathlib2", False), - requirement("pluggy", False), - requirement("py", False), - requirement("pytest", True), - requirement("six", False), + py_requirement("click", True), + py_requirement("attrs", False), + py_requirement("atomicwrites", False), + py_requirement("more-itertools", False), + py_requirement("pathlib2", False), + py_requirement("pluggy", False), + py_requirement("py", False), + py_requirement("pytest", True), + py_requirement("six", False), ], ) diff --git a/benchmarks/tcp/BUILD b/benchmarks/tcp/BUILD index 735d7127f..d5e401acc 100644 --- a/benchmarks/tcp/BUILD +++ b/benchmarks/tcp/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") -load("@rules_cc//cc:defs.bzl", "cc_binary") +load("//tools:defs.bzl", "cc_binary", "go_binary") package(licenses = ["notice"]) diff --git a/benchmarks/workloads/ab/BUILD b/benchmarks/workloads/ab/BUILD index 4fc0ab735..4dd91ceb3 100644 --- a/benchmarks/workloads/ab/BUILD +++ b/benchmarks/workloads/ab/BUILD @@ -1,5 +1,4 @@ -load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement", "py_test") package( default_visibility = ["//benchmarks:__subpackages__"], @@ -17,14 +16,14 @@ py_test( python_version = "PY3", deps = [ ":ab", - requirement("attrs", False), - requirement("atomicwrites", False), - requirement("more-itertools", False), - requirement("pathlib2", False), - requirement("pluggy", False), - requirement("py", False), - requirement("pytest", True), - requirement("six", False), + py_requirement("attrs", False), + py_requirement("atomicwrites", False), + py_requirement("more-itertools", False), + py_requirement("pathlib2", False), + py_requirement("pluggy", False), + py_requirement("py", False), + py_requirement("pytest", True), + py_requirement("six", False), ], ) diff --git a/benchmarks/workloads/absl/BUILD b/benchmarks/workloads/absl/BUILD index 61e010096..55dae3baa 100644 --- a/benchmarks/workloads/absl/BUILD +++ b/benchmarks/workloads/absl/BUILD @@ -1,5 +1,4 @@ -load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement", "py_test") package( default_visibility = ["//benchmarks:__subpackages__"], @@ -17,14 +16,14 @@ py_test( python_version = "PY3", deps = [ ":absl", - requirement("attrs", False), - requirement("atomicwrites", False), - requirement("more-itertools", False), - requirement("pathlib2", False), - requirement("pluggy", False), - requirement("py", False), - requirement("pytest", True), - requirement("six", False), + py_requirement("attrs", False), + py_requirement("atomicwrites", False), + py_requirement("more-itertools", False), + py_requirement("pathlib2", False), + py_requirement("pluggy", False), + py_requirement("py", False), + py_requirement("pytest", True), + py_requirement("six", False), ], ) diff --git a/benchmarks/workloads/curl/BUILD b/benchmarks/workloads/curl/BUILD index eb0fb6165..a70873065 100644 --- a/benchmarks/workloads/curl/BUILD +++ b/benchmarks/workloads/curl/BUILD @@ -1,4 +1,4 @@ -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], diff --git a/benchmarks/workloads/ffmpeg/BUILD b/benchmarks/workloads/ffmpeg/BUILD index be472dfb2..7c41ba631 100644 --- a/benchmarks/workloads/ffmpeg/BUILD +++ b/benchmarks/workloads/ffmpeg/BUILD @@ -1,4 +1,4 @@ -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], diff --git a/benchmarks/workloads/fio/BUILD b/benchmarks/workloads/fio/BUILD index de257adad..7b78e8e75 100644 --- a/benchmarks/workloads/fio/BUILD +++ b/benchmarks/workloads/fio/BUILD @@ -1,5 +1,4 @@ -load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement", "py_test") package( default_visibility = ["//benchmarks:__subpackages__"], @@ -17,14 +16,14 @@ py_test( python_version = "PY3", deps = [ ":fio", - requirement("attrs", False), - requirement("atomicwrites", False), - requirement("more-itertools", False), - requirement("pathlib2", False), - requirement("pluggy", False), - requirement("py", False), - requirement("pytest", True), - requirement("six", False), + py_requirement("attrs", False), + py_requirement("atomicwrites", False), + py_requirement("more-itertools", False), + py_requirement("pathlib2", False), + py_requirement("pluggy", False), + py_requirement("py", False), + py_requirement("pytest", True), + py_requirement("six", False), ], ) diff --git a/benchmarks/workloads/httpd/BUILD b/benchmarks/workloads/httpd/BUILD index eb0fb6165..a70873065 100644 --- a/benchmarks/workloads/httpd/BUILD +++ b/benchmarks/workloads/httpd/BUILD @@ -1,4 +1,4 @@ -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], diff --git a/benchmarks/workloads/iperf/BUILD b/benchmarks/workloads/iperf/BUILD index 8832a996c..570f40148 100644 --- a/benchmarks/workloads/iperf/BUILD +++ b/benchmarks/workloads/iperf/BUILD @@ -1,5 +1,4 @@ -load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement", "py_test") package( default_visibility = ["//benchmarks:__subpackages__"], @@ -17,14 +16,14 @@ py_test( python_version = "PY3", deps = [ ":iperf", - requirement("attrs", False), - requirement("atomicwrites", False), - requirement("more-itertools", False), - requirement("pathlib2", False), - requirement("pluggy", False), - requirement("py", False), - requirement("pytest", True), - requirement("six", False), + py_requirement("attrs", False), + py_requirement("atomicwrites", False), + py_requirement("more-itertools", False), + py_requirement("pathlib2", False), + py_requirement("pluggy", False), + py_requirement("py", False), + py_requirement("pytest", True), + py_requirement("six", False), ], ) diff --git a/benchmarks/workloads/netcat/BUILD b/benchmarks/workloads/netcat/BUILD index eb0fb6165..a70873065 100644 --- a/benchmarks/workloads/netcat/BUILD +++ b/benchmarks/workloads/netcat/BUILD @@ -1,4 +1,4 @@ -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], diff --git a/benchmarks/workloads/nginx/BUILD b/benchmarks/workloads/nginx/BUILD index eb0fb6165..a70873065 100644 --- a/benchmarks/workloads/nginx/BUILD +++ b/benchmarks/workloads/nginx/BUILD @@ -1,4 +1,4 @@ -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], diff --git a/benchmarks/workloads/node/BUILD b/benchmarks/workloads/node/BUILD index 71cd9f519..bfcf78cf9 100644 --- a/benchmarks/workloads/node/BUILD +++ b/benchmarks/workloads/node/BUILD @@ -1,4 +1,4 @@ -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], diff --git a/benchmarks/workloads/node_template/BUILD b/benchmarks/workloads/node_template/BUILD index ca996f068..e142f082a 100644 --- a/benchmarks/workloads/node_template/BUILD +++ b/benchmarks/workloads/node_template/BUILD @@ -1,4 +1,4 @@ -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], diff --git a/benchmarks/workloads/redis/BUILD b/benchmarks/workloads/redis/BUILD index eb0fb6165..a70873065 100644 --- a/benchmarks/workloads/redis/BUILD +++ b/benchmarks/workloads/redis/BUILD @@ -1,4 +1,4 @@ -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], diff --git a/benchmarks/workloads/redisbenchmark/BUILD b/benchmarks/workloads/redisbenchmark/BUILD index f5994a815..f472a4443 100644 --- a/benchmarks/workloads/redisbenchmark/BUILD +++ b/benchmarks/workloads/redisbenchmark/BUILD @@ -1,5 +1,4 @@ -load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement", "py_test") package( default_visibility = ["//benchmarks:__subpackages__"], @@ -17,14 +16,14 @@ py_test( python_version = "PY3", deps = [ ":redisbenchmark", - requirement("attrs", False), - requirement("atomicwrites", False), - requirement("more-itertools", False), - requirement("pathlib2", False), - requirement("pluggy", False), - requirement("py", False), - requirement("pytest", True), - requirement("six", False), + py_requirement("attrs", False), + py_requirement("atomicwrites", False), + py_requirement("more-itertools", False), + py_requirement("pathlib2", False), + py_requirement("pluggy", False), + py_requirement("py", False), + py_requirement("pytest", True), + py_requirement("six", False), ], ) diff --git a/benchmarks/workloads/ruby/BUILD b/benchmarks/workloads/ruby/BUILD index e37d77804..a3be4fe92 100644 --- a/benchmarks/workloads/ruby/BUILD +++ b/benchmarks/workloads/ruby/BUILD @@ -1,4 +1,4 @@ -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], diff --git a/benchmarks/workloads/ruby_template/BUILD b/benchmarks/workloads/ruby_template/BUILD index 27f7c0c46..59443b14a 100644 --- a/benchmarks/workloads/ruby_template/BUILD +++ b/benchmarks/workloads/ruby_template/BUILD @@ -1,4 +1,4 @@ -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], diff --git a/benchmarks/workloads/sleep/BUILD b/benchmarks/workloads/sleep/BUILD index eb0fb6165..a70873065 100644 --- a/benchmarks/workloads/sleep/BUILD +++ b/benchmarks/workloads/sleep/BUILD @@ -1,4 +1,4 @@ -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], diff --git a/benchmarks/workloads/sysbench/BUILD b/benchmarks/workloads/sysbench/BUILD index fd2f8f03d..3834af7ed 100644 --- a/benchmarks/workloads/sysbench/BUILD +++ b/benchmarks/workloads/sysbench/BUILD @@ -1,5 +1,4 @@ -load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement", "py_test") package( default_visibility = ["//benchmarks:__subpackages__"], @@ -17,14 +16,14 @@ py_test( python_version = "PY3", deps = [ ":sysbench", - requirement("attrs", False), - requirement("atomicwrites", False), - requirement("more-itertools", False), - requirement("pathlib2", False), - requirement("pluggy", False), - requirement("py", False), - requirement("pytest", True), - requirement("six", False), + py_requirement("attrs", False), + py_requirement("atomicwrites", False), + py_requirement("more-itertools", False), + py_requirement("pathlib2", False), + py_requirement("pluggy", False), + py_requirement("py", False), + py_requirement("pytest", True), + py_requirement("six", False), ], ) diff --git a/benchmarks/workloads/syscall/BUILD b/benchmarks/workloads/syscall/BUILD index 5100cbb21..dba4bb1e7 100644 --- a/benchmarks/workloads/syscall/BUILD +++ b/benchmarks/workloads/syscall/BUILD @@ -1,5 +1,4 @@ -load("//benchmarks:defs.bzl", "py_library", "py_test", "requirement") -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement", "py_test") package( default_visibility = ["//benchmarks:__subpackages__"], @@ -17,14 +16,14 @@ py_test( python_version = "PY3", deps = [ ":syscall", - requirement("attrs", False), - requirement("atomicwrites", False), - requirement("more-itertools", False), - requirement("pathlib2", False), - requirement("pluggy", False), - requirement("py", False), - requirement("pytest", True), - requirement("six", False), + py_requirement("attrs", False), + py_requirement("atomicwrites", False), + py_requirement("more-itertools", False), + py_requirement("pathlib2", False), + py_requirement("pluggy", False), + py_requirement("py", False), + py_requirement("pytest", True), + py_requirement("six", False), ], ) diff --git a/benchmarks/workloads/tensorflow/BUILD b/benchmarks/workloads/tensorflow/BUILD index 026c3b316..a7b7742f4 100644 --- a/benchmarks/workloads/tensorflow/BUILD +++ b/benchmarks/workloads/tensorflow/BUILD @@ -1,4 +1,4 @@ -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], diff --git a/benchmarks/workloads/true/BUILD b/benchmarks/workloads/true/BUILD index 221c4b9a7..eba23d325 100644 --- a/benchmarks/workloads/true/BUILD +++ b/benchmarks/workloads/true/BUILD @@ -1,4 +1,4 @@ -load("@rules_pkg//:pkg.bzl", "pkg_tar") +load("//tools:defs.bzl", "pkg_tar") package( default_visibility = ["//benchmarks:__subpackages__"], diff --git a/pkg/abi/BUILD b/pkg/abi/BUILD index f5c08ea06..839f822eb 100644 --- a/pkg/abi/BUILD +++ b/pkg/abi/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -9,6 +9,5 @@ go_library( "abi_linux.go", "flag.go", ], - importpath = "gvisor.dev/gvisor/pkg/abi", visibility = ["//:sandbox"], ) diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index 716ff22d2..1f3c0c687 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") # Package linux contains the constants and types needed to interface with a # Linux kernel. It should be used instead of syscall or golang.org/x/sys/unix @@ -60,7 +59,6 @@ go_library( "wait.go", "xattr.go", ], - importpath = "gvisor.dev/gvisor/pkg/abi/linux", visibility = ["//visibility:public"], deps = [ "//pkg/abi", @@ -73,7 +71,7 @@ go_test( name = "linux_test", size = "small", srcs = ["netfilter_test.go"], - embed = [":linux"], + library = ":linux", deps = [ "//pkg/binary", ], diff --git a/pkg/amutex/BUILD b/pkg/amutex/BUILD index d99e37b40..9612f072e 100644 --- a/pkg/amutex/BUILD +++ b/pkg/amutex/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "amutex", srcs = ["amutex.go"], - importpath = "gvisor.dev/gvisor/pkg/amutex", visibility = ["//:sandbox"], ) @@ -14,6 +12,6 @@ go_test( name = "amutex_test", size = "small", srcs = ["amutex_test.go"], - embed = [":amutex"], + library = ":amutex", deps = ["//pkg/sync"], ) diff --git a/pkg/atomicbitops/BUILD b/pkg/atomicbitops/BUILD index 6403c60c2..3948074ba 100644 --- a/pkg/atomicbitops/BUILD +++ b/pkg/atomicbitops/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -11,7 +10,6 @@ go_library( "atomic_bitops_arm64.s", "atomic_bitops_common.go", ], - importpath = "gvisor.dev/gvisor/pkg/atomicbitops", visibility = ["//:sandbox"], ) @@ -19,6 +17,6 @@ go_test( name = "atomicbitops_test", size = "small", srcs = ["atomic_bitops_test.go"], - embed = [":atomicbitops"], + library = ":atomicbitops", deps = ["//pkg/sync"], ) diff --git a/pkg/binary/BUILD b/pkg/binary/BUILD index 543fb54bf..7ca2fda90 100644 --- a/pkg/binary/BUILD +++ b/pkg/binary/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "binary", srcs = ["binary.go"], - importpath = "gvisor.dev/gvisor/pkg/binary", visibility = ["//:sandbox"], ) @@ -14,5 +12,5 @@ go_test( name = "binary_test", size = "small", srcs = ["binary_test.go"], - embed = [":binary"], + library = ":binary", ) diff --git a/pkg/bits/BUILD b/pkg/bits/BUILD index 93b88a29a..63f4670d7 100644 --- a/pkg/bits/BUILD +++ b/pkg/bits/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") package(licenses = ["notice"]) @@ -15,7 +14,6 @@ go_library( "uint64_arch_arm64_asm.s", "uint64_arch_generic.go", ], - importpath = "gvisor.dev/gvisor/pkg/bits", visibility = ["//:sandbox"], ) @@ -53,5 +51,5 @@ go_test( name = "bits_test", size = "small", srcs = ["uint64_test.go"], - embed = [":bits"], + library = ":bits", ) diff --git a/pkg/bpf/BUILD b/pkg/bpf/BUILD index fba5643e8..2a6977f85 100644 --- a/pkg/bpf/BUILD +++ b/pkg/bpf/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -12,7 +11,6 @@ go_library( "interpreter.go", "program_builder.go", ], - importpath = "gvisor.dev/gvisor/pkg/bpf", visibility = ["//visibility:public"], deps = ["//pkg/abi/linux"], ) @@ -25,7 +23,7 @@ go_test( "interpreter_test.go", "program_builder_test.go", ], - embed = [":bpf"], + library = ":bpf", deps = [ "//pkg/abi/linux", "//pkg/binary", diff --git a/pkg/compressio/BUILD b/pkg/compressio/BUILD index 2bb581b18..1f75319a7 100644 --- a/pkg/compressio/BUILD +++ b/pkg/compressio/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "compressio", srcs = ["compressio.go"], - importpath = "gvisor.dev/gvisor/pkg/compressio", visibility = ["//:sandbox"], deps = [ "//pkg/binary", @@ -18,5 +16,5 @@ go_test( name = "compressio_test", size = "medium", srcs = ["compressio_test.go"], - embed = [":compressio"], + library = ":compressio", ) diff --git a/pkg/control/client/BUILD b/pkg/control/client/BUILD index 066d7b1a1..1b9e10ee7 100644 --- a/pkg/control/client/BUILD +++ b/pkg/control/client/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -7,7 +7,6 @@ go_library( srcs = [ "client.go", ], - importpath = "gvisor.dev/gvisor/pkg/control/client", visibility = ["//:sandbox"], deps = [ "//pkg/unet", diff --git a/pkg/control/server/BUILD b/pkg/control/server/BUILD index adbd1e3f8..002d2ef44 100644 --- a/pkg/control/server/BUILD +++ b/pkg/control/server/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "server", srcs = ["server.go"], - importpath = "gvisor.dev/gvisor/pkg/control/server", visibility = ["//:sandbox"], deps = [ "//pkg/log", diff --git a/pkg/cpuid/BUILD b/pkg/cpuid/BUILD index ed111fd2a..43a432190 100644 --- a/pkg/cpuid/BUILD +++ b/pkg/cpuid/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -9,7 +8,6 @@ go_library( "cpu_amd64.s", "cpuid.go", ], - importpath = "gvisor.dev/gvisor/pkg/cpuid", visibility = ["//:sandbox"], deps = ["//pkg/log"], ) @@ -18,7 +16,7 @@ go_test( name = "cpuid_test", size = "small", srcs = ["cpuid_test.go"], - embed = [":cpuid"], + library = ":cpuid", ) go_test( @@ -27,6 +25,6 @@ go_test( srcs = [ "cpuid_parse_test.go", ], - embed = [":cpuid"], + library = ":cpuid", tags = ["manual"], ) diff --git a/pkg/eventchannel/BUILD b/pkg/eventchannel/BUILD index 9d68682c7..bee28b68d 100644 --- a/pkg/eventchannel/BUILD +++ b/pkg/eventchannel/BUILD @@ -1,6 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test", "proto_library") package(licenses = ["notice"]) @@ -10,7 +8,6 @@ go_library( "event.go", "rate.go", ], - importpath = "gvisor.dev/gvisor/pkg/eventchannel", visibility = ["//:sandbox"], deps = [ ":eventchannel_go_proto", @@ -24,22 +21,15 @@ go_library( ) proto_library( - name = "eventchannel_proto", + name = "eventchannel", srcs = ["event.proto"], visibility = ["//:sandbox"], ) -go_proto_library( - name = "eventchannel_go_proto", - importpath = "gvisor.dev/gvisor/pkg/eventchannel/eventchannel_go_proto", - proto = ":eventchannel_proto", - visibility = ["//:sandbox"], -) - go_test( name = "eventchannel_test", srcs = ["event_test.go"], - embed = [":eventchannel"], + library = ":eventchannel", deps = [ "//pkg/sync", "@com_github_golang_protobuf//proto:go_default_library", diff --git a/pkg/fd/BUILD b/pkg/fd/BUILD index afa8f7659..872361546 100644 --- a/pkg/fd/BUILD +++ b/pkg/fd/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "fd", srcs = ["fd.go"], - importpath = "gvisor.dev/gvisor/pkg/fd", visibility = ["//visibility:public"], ) @@ -14,5 +12,5 @@ go_test( name = "fd_test", size = "small", srcs = ["fd_test.go"], - embed = [":fd"], + library = ":fd", ) diff --git a/pkg/fdchannel/BUILD b/pkg/fdchannel/BUILD index b0478c672..d9104ef02 100644 --- a/pkg/fdchannel/BUILD +++ b/pkg/fdchannel/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") -package(licenses = ["notice"]) +licenses(["notice"]) go_library( name = "fdchannel", srcs = ["fdchannel_unsafe.go"], - importpath = "gvisor.dev/gvisor/pkg/fdchannel", visibility = ["//visibility:public"], ) @@ -14,6 +12,6 @@ go_test( name = "fdchannel_test", size = "small", srcs = ["fdchannel_test.go"], - embed = [":fdchannel"], + library = ":fdchannel", deps = ["//pkg/sync"], ) diff --git a/pkg/fdnotifier/BUILD b/pkg/fdnotifier/BUILD index 91a202a30..235dcc490 100644 --- a/pkg/fdnotifier/BUILD +++ b/pkg/fdnotifier/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -8,7 +8,6 @@ go_library( "fdnotifier.go", "poll_unsafe.go", ], - importpath = "gvisor.dev/gvisor/pkg/fdnotifier", visibility = ["//:sandbox"], deps = [ "//pkg/sync", diff --git a/pkg/flipcall/BUILD b/pkg/flipcall/BUILD index 85bd83af1..9c5ad500b 100644 --- a/pkg/flipcall/BUILD +++ b/pkg/flipcall/BUILD @@ -1,7 +1,6 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") -package(licenses = ["notice"]) +licenses(["notice"]) go_library( name = "flipcall", @@ -13,7 +12,6 @@ go_library( "io.go", "packet_window_allocator.go", ], - importpath = "gvisor.dev/gvisor/pkg/flipcall", visibility = ["//visibility:public"], deps = [ "//pkg/abi/linux", @@ -30,6 +28,6 @@ go_test( "flipcall_example_test.go", "flipcall_test.go", ], - embed = [":flipcall"], + library = ":flipcall", deps = ["//pkg/sync"], ) diff --git a/pkg/fspath/BUILD b/pkg/fspath/BUILD index ca540363c..ee84471b2 100644 --- a/pkg/fspath/BUILD +++ b/pkg/fspath/BUILD @@ -1,10 +1,8 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], -) +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) go_library( name = "fspath", @@ -13,7 +11,6 @@ go_library( "builder_unsafe.go", "fspath.go", ], - importpath = "gvisor.dev/gvisor/pkg/fspath", ) go_test( @@ -23,5 +20,5 @@ go_test( "builder_test.go", "fspath_test.go", ], - embed = [":fspath"], + library = ":fspath", ) diff --git a/pkg/gate/BUILD b/pkg/gate/BUILD index f22bd070d..dd3141143 100644 --- a/pkg/gate/BUILD +++ b/pkg/gate/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -8,7 +7,6 @@ go_library( srcs = [ "gate.go", ], - importpath = "gvisor.dev/gvisor/pkg/gate", visibility = ["//visibility:public"], ) diff --git a/pkg/goid/BUILD b/pkg/goid/BUILD index 5d31e5366..ea8d2422c 100644 --- a/pkg/goid/BUILD +++ b/pkg/goid/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -11,7 +10,6 @@ go_library( "goid_race.go", "goid_unsafe.go", ], - importpath = "gvisor.dev/gvisor/pkg/goid", visibility = ["//visibility:public"], ) @@ -22,5 +20,5 @@ go_test( "empty_test.go", "goid_test.go", ], - embed = [":goid"], + library = ":goid", ) diff --git a/pkg/ilist/BUILD b/pkg/ilist/BUILD index 34d2673ef..3f6eb07df 100644 --- a/pkg/ilist/BUILD +++ b/pkg/ilist/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -9,7 +8,6 @@ go_library( srcs = [ "interface_list.go", ], - importpath = "gvisor.dev/gvisor/pkg/ilist", visibility = ["//visibility:public"], ) @@ -41,7 +39,7 @@ go_test( "list_test.go", "test_list.go", ], - embed = [":ilist"], + library = ":ilist", ) go_template( diff --git a/pkg/linewriter/BUILD b/pkg/linewriter/BUILD index bcde6d308..41bf104d0 100644 --- a/pkg/linewriter/BUILD +++ b/pkg/linewriter/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "linewriter", srcs = ["linewriter.go"], - importpath = "gvisor.dev/gvisor/pkg/linewriter", visibility = ["//visibility:public"], deps = ["//pkg/sync"], ) @@ -14,5 +12,5 @@ go_library( go_test( name = "linewriter_test", srcs = ["linewriter_test.go"], - embed = [":linewriter"], + library = ":linewriter", ) diff --git a/pkg/log/BUILD b/pkg/log/BUILD index 0df0f2849..935d06963 100644 --- a/pkg/log/BUILD +++ b/pkg/log/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -12,7 +11,6 @@ go_library( "json_k8s.go", "log.go", ], - importpath = "gvisor.dev/gvisor/pkg/log", visibility = [ "//visibility:public", ], @@ -29,5 +27,5 @@ go_test( "json_test.go", "log_test.go", ], - embed = [":log"], + library = ":log", ) diff --git a/pkg/memutil/BUILD b/pkg/memutil/BUILD index 7b50e2b28..9d07d98b4 100644 --- a/pkg/memutil/BUILD +++ b/pkg/memutil/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "memutil", srcs = ["memutil_unsafe.go"], - importpath = "gvisor.dev/gvisor/pkg/memutil", visibility = ["//visibility:public"], deps = ["@org_golang_x_sys//unix:go_default_library"], ) diff --git a/pkg/metric/BUILD b/pkg/metric/BUILD index 9145f3233..58305009d 100644 --- a/pkg/metric/BUILD +++ b/pkg/metric/BUILD @@ -1,14 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("@rules_cc//cc:defs.bzl", "cc_proto_library") +load("//tools:defs.bzl", "go_library", "go_test", "proto_library") package(licenses = ["notice"]) go_library( name = "metric", srcs = ["metric.go"], - importpath = "gvisor.dev/gvisor/pkg/metric", visibility = ["//:sandbox"], deps = [ ":metric_go_proto", @@ -19,28 +15,15 @@ go_library( ) proto_library( - name = "metric_proto", + name = "metric", srcs = ["metric.proto"], visibility = ["//:sandbox"], ) -cc_proto_library( - name = "metric_cc_proto", - visibility = ["//:sandbox"], - deps = [":metric_proto"], -) - -go_proto_library( - name = "metric_go_proto", - importpath = "gvisor.dev/gvisor/pkg/metric/metric_go_proto", - proto = ":metric_proto", - visibility = ["//:sandbox"], -) - go_test( name = "metric_test", srcs = ["metric_test.go"], - embed = [":metric"], + library = ":metric", deps = [ ":metric_go_proto", "//pkg/eventchannel", diff --git a/pkg/p9/BUILD b/pkg/p9/BUILD index a3e05c96d..4ccc1de86 100644 --- a/pkg/p9/BUILD +++ b/pkg/p9/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package( default_visibility = ["//visibility:public"], @@ -23,7 +22,6 @@ go_library( "transport_flipcall.go", "version.go", ], - importpath = "gvisor.dev/gvisor/pkg/p9", deps = [ "//pkg/fd", "//pkg/fdchannel", @@ -47,7 +45,7 @@ go_test( "transport_test.go", "version_test.go", ], - embed = [":p9"], + library = ":p9", deps = [ "//pkg/fd", "//pkg/unet", diff --git a/pkg/p9/p9test/BUILD b/pkg/p9/p9test/BUILD index f4edd68b2..7ca67cb19 100644 --- a/pkg/p9/p9test/BUILD +++ b/pkg/p9/p9test/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_test") +load("//tools:defs.bzl", "go_binary", "go_library", "go_test") package(licenses = ["notice"]) @@ -64,7 +63,6 @@ go_library( "mocks.go", "p9test.go", ], - importpath = "gvisor.dev/gvisor/pkg/p9/p9test", visibility = ["//:sandbox"], deps = [ "//pkg/fd", @@ -80,7 +78,7 @@ go_test( name = "client_test", size = "medium", srcs = ["client_test.go"], - embed = [":p9test"], + library = ":p9test", deps = [ "//pkg/fd", "//pkg/p9", diff --git a/pkg/procid/BUILD b/pkg/procid/BUILD index b506813f0..aa3e3ac0b 100644 --- a/pkg/procid/BUILD +++ b/pkg/procid/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -10,7 +9,6 @@ go_library( "procid_amd64.s", "procid_arm64.s", ], - importpath = "gvisor.dev/gvisor/pkg/procid", visibility = ["//visibility:public"], ) @@ -20,7 +18,7 @@ go_test( srcs = [ "procid_test.go", ], - embed = [":procid"], + library = ":procid", deps = ["//pkg/sync"], ) @@ -31,6 +29,6 @@ go_test( "procid_net_test.go", "procid_test.go", ], - embed = [":procid"], + library = ":procid", deps = ["//pkg/sync"], ) diff --git a/pkg/rand/BUILD b/pkg/rand/BUILD index 9d5b4859b..80b8ceb02 100644 --- a/pkg/rand/BUILD +++ b/pkg/rand/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -8,7 +8,6 @@ go_library( "rand.go", "rand_linux.go", ], - importpath = "gvisor.dev/gvisor/pkg/rand", visibility = ["//:sandbox"], deps = [ "//pkg/sync", diff --git a/pkg/refs/BUILD b/pkg/refs/BUILD index 974d9af9b..74affc887 100644 --- a/pkg/refs/BUILD +++ b/pkg/refs/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -23,7 +22,6 @@ go_library( "refcounter_state.go", "weak_ref_list.go", ], - importpath = "gvisor.dev/gvisor/pkg/refs", visibility = ["//:sandbox"], deps = [ "//pkg/log", @@ -35,6 +33,6 @@ go_test( name = "refs_test", size = "small", srcs = ["refcounter_test.go"], - embed = [":refs"], + library = ":refs", deps = ["//pkg/sync"], ) diff --git a/pkg/seccomp/BUILD b/pkg/seccomp/BUILD index af94e944d..742c8b79b 100644 --- a/pkg/seccomp/BUILD +++ b/pkg/seccomp/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_embed_data", "go_test") +load("//tools:defs.bzl", "go_binary", "go_embed_data", "go_library", "go_test") package(licenses = ["notice"]) @@ -27,7 +26,6 @@ go_library( "seccomp_rules.go", "seccomp_unsafe.go", ], - importpath = "gvisor.dev/gvisor/pkg/seccomp", visibility = ["//visibility:public"], deps = [ "//pkg/abi/linux", @@ -43,7 +41,7 @@ go_test( "seccomp_test.go", ":victim_data", ], - embed = [":seccomp"], + library = ":seccomp", deps = [ "//pkg/abi/linux", "//pkg/binary", diff --git a/pkg/secio/BUILD b/pkg/secio/BUILD index 22abdc69f..60f63c7a6 100644 --- a/pkg/secio/BUILD +++ b/pkg/secio/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -9,7 +8,6 @@ go_library( "full_reader.go", "secio.go", ], - importpath = "gvisor.dev/gvisor/pkg/secio", visibility = ["//pkg/sentry:internal"], ) @@ -17,5 +15,5 @@ go_test( name = "secio_test", size = "small", srcs = ["secio_test.go"], - embed = [":secio"], + library = ":secio", ) diff --git a/pkg/segment/test/BUILD b/pkg/segment/test/BUILD index a27c35e21..f2d8462d8 100644 --- a/pkg/segment/test/BUILD +++ b/pkg/segment/test/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") package( @@ -38,7 +37,6 @@ go_library( "int_set.go", "set_functions.go", ], - importpath = "gvisor.dev/gvisor/pkg/segment/segment", deps = [ "//pkg/state", ], @@ -48,5 +46,5 @@ go_test( name = "segment_test", size = "small", srcs = ["segment_test.go"], - embed = [":segment"], + library = ":segment", ) diff --git a/pkg/sentry/BUILD b/pkg/sentry/BUILD index 2d6379c86..e8b794179 100644 --- a/pkg/sentry/BUILD +++ b/pkg/sentry/BUILD @@ -6,6 +6,8 @@ package(licenses = ["notice"]) package_group( name = "internal", packages = [ + "//cloud/gvisor/gopkg/sentry/...", + "//cloud/gvisor/sentry/...", "//pkg/sentry/...", "//runsc/...", # Code generated by go_marshal relies on go_marshal libraries. diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD index 65f22af2b..51ca09b24 100644 --- a/pkg/sentry/arch/BUILD +++ b/pkg/sentry/arch/BUILD @@ -1,6 +1,4 @@ -load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") -load("@rules_cc//cc:defs.bzl", "cc_proto_library") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "proto_library") package(licenses = ["notice"]) @@ -27,7 +25,6 @@ go_library( "syscalls_amd64.go", "syscalls_arm64.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/arch", visibility = ["//:sandbox"], deps = [ ":registers_go_proto", @@ -44,20 +41,7 @@ go_library( ) proto_library( - name = "registers_proto", + name = "registers", srcs = ["registers.proto"], visibility = ["//visibility:public"], ) - -cc_proto_library( - name = "registers_cc_proto", - visibility = ["//visibility:public"], - deps = [":registers_proto"], -) - -go_proto_library( - name = "registers_go_proto", - importpath = "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto", - proto = ":registers_proto", - visibility = ["//visibility:public"], -) diff --git a/pkg/sentry/context/BUILD b/pkg/sentry/context/BUILD index 8dc1a77b1..e13a9ce20 100644 --- a/pkg/sentry/context/BUILD +++ b/pkg/sentry/context/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "context", srcs = ["context.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/context", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/amutex", diff --git a/pkg/sentry/context/contexttest/BUILD b/pkg/sentry/context/contexttest/BUILD index 581e7aa96..f91a6d4ed 100644 --- a/pkg/sentry/context/contexttest/BUILD +++ b/pkg/sentry/context/contexttest/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -6,7 +6,6 @@ go_library( name = "contexttest", testonly = 1, srcs = ["contexttest.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/context/contexttest", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/memutil", diff --git a/pkg/sentry/control/BUILD b/pkg/sentry/control/BUILD index 2561a6109..e69496477 100644 --- a/pkg/sentry/control/BUILD +++ b/pkg/sentry/control/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -12,9 +11,8 @@ go_library( "proc.go", "state.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/control", visibility = [ - "//pkg/sentry:internal", + "//:sandbox", ], deps = [ "//pkg/abi/linux", @@ -40,7 +38,7 @@ go_test( name = "control_test", size = "small", srcs = ["proc_test.go"], - embed = [":control"], + library = ":control", deps = [ "//pkg/log", "//pkg/sentry/kernel/time", diff --git a/pkg/sentry/device/BUILD b/pkg/sentry/device/BUILD index 97fa1512c..e403cbd8b 100644 --- a/pkg/sentry/device/BUILD +++ b/pkg/sentry/device/BUILD @@ -1,12 +1,10 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "device", srcs = ["device.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/device", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -18,5 +16,5 @@ go_test( name = "device_test", size = "small", srcs = ["device_test.go"], - embed = [":device"], + library = ":device", ) diff --git a/pkg/sentry/fs/BUILD b/pkg/sentry/fs/BUILD index 7d5d72d5a..605d61dbe 100644 --- a/pkg/sentry/fs/BUILD +++ b/pkg/sentry/fs/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -44,7 +43,6 @@ go_library( "splice.go", "sync.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -129,7 +127,7 @@ go_test( "mount_test.go", "path_test.go", ], - embed = [":fs"], + library = ":fs", deps = [ "//pkg/sentry/context", "//pkg/sentry/context/contexttest", diff --git a/pkg/sentry/fs/anon/BUILD b/pkg/sentry/fs/anon/BUILD index ae1c9cf76..c14e5405e 100644 --- a/pkg/sentry/fs/anon/BUILD +++ b/pkg/sentry/fs/anon/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -8,7 +8,6 @@ go_library( "anon.go", "device.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/anon", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/fs/dev/BUILD b/pkg/sentry/fs/dev/BUILD index a0d9e8496..0c7247bd7 100644 --- a/pkg/sentry/fs/dev/BUILD +++ b/pkg/sentry/fs/dev/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -13,7 +13,6 @@ go_library( "random.go", "tty.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/dev", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/fs/fdpipe/BUILD b/pkg/sentry/fs/fdpipe/BUILD index cc43de69d..25ef96299 100644 --- a/pkg/sentry/fs/fdpipe/BUILD +++ b/pkg/sentry/fs/fdpipe/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -10,7 +9,6 @@ go_library( "pipe_opener.go", "pipe_state.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/fdpipe", imports = ["gvisor.dev/gvisor/pkg/sentry/fs"], visibility = ["//pkg/sentry:internal"], deps = [ @@ -36,7 +34,7 @@ go_test( "pipe_opener_test.go", "pipe_test.go", ], - embed = [":fdpipe"], + library = ":fdpipe", deps = [ "//pkg/fd", "//pkg/fdnotifier", diff --git a/pkg/sentry/fs/filetest/BUILD b/pkg/sentry/fs/filetest/BUILD index 358dc2be3..9a7608cae 100644 --- a/pkg/sentry/fs/filetest/BUILD +++ b/pkg/sentry/fs/filetest/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -6,7 +6,6 @@ go_library( name = "filetest", testonly = 1, srcs = ["filetest.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/filetest", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/sentry/context", diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD index 945b6270d..9142f5bdf 100644 --- a/pkg/sentry/fs/fsutil/BUILD +++ b/pkg/sentry/fs/fsutil/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -75,7 +74,6 @@ go_library( "inode.go", "inode_cached.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/fsutil", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -106,7 +104,7 @@ go_test( "dirty_set_test.go", "inode_cached_test.go", ], - embed = [":fsutil"], + library = ":fsutil", deps = [ "//pkg/sentry/context", "//pkg/sentry/context/contexttest", diff --git a/pkg/sentry/fs/gofer/BUILD b/pkg/sentry/fs/gofer/BUILD index fd870e8e1..cf48e7c03 100644 --- a/pkg/sentry/fs/gofer/BUILD +++ b/pkg/sentry/fs/gofer/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -22,7 +21,6 @@ go_library( "socket.go", "util.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/gofer", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -56,7 +54,7 @@ go_test( name = "gofer_test", size = "small", srcs = ["gofer_test.go"], - embed = [":gofer"], + library = ":gofer", deps = [ "//pkg/p9", "//pkg/p9/p9test", diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD index 2b581aa69..f586f47c1 100644 --- a/pkg/sentry/fs/host/BUILD +++ b/pkg/sentry/fs/host/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -25,7 +24,6 @@ go_library( "util_arm64_unsafe.go", "util_unsafe.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/host", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -69,7 +67,7 @@ go_test( "socket_test.go", "wait_test.go", ], - embed = [":host"], + library = ":host", deps = [ "//pkg/fd", "//pkg/fdnotifier", diff --git a/pkg/sentry/fs/lock/BUILD b/pkg/sentry/fs/lock/BUILD index 2c332a82a..ae3331737 100644 --- a/pkg/sentry/fs/lock/BUILD +++ b/pkg/sentry/fs/lock/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -40,7 +39,6 @@ go_library( "lock_set.go", "lock_set_functions.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/lock", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/log", @@ -56,5 +54,5 @@ go_test( "lock_range_test.go", "lock_test.go", ], - embed = [":lock"], + library = ":lock", ) diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD index cb37c6c6b..b06bead41 100644 --- a/pkg/sentry/fs/proc/BUILD +++ b/pkg/sentry/fs/proc/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -27,7 +26,6 @@ go_library( "uptime.go", "version.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/proc", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -63,7 +61,7 @@ go_test( "net_test.go", "sys_net_test.go", ], - embed = [":proc"], + library = ":proc", deps = [ "//pkg/abi/linux", "//pkg/sentry/context", diff --git a/pkg/sentry/fs/proc/device/BUILD b/pkg/sentry/fs/proc/device/BUILD index 0394451d4..52c9aa93d 100644 --- a/pkg/sentry/fs/proc/device/BUILD +++ b/pkg/sentry/fs/proc/device/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "device", srcs = ["device.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/proc/device", visibility = ["//pkg/sentry:internal"], deps = ["//pkg/sentry/device"], ) diff --git a/pkg/sentry/fs/proc/seqfile/BUILD b/pkg/sentry/fs/proc/seqfile/BUILD index 38b246dff..310d8dd52 100644 --- a/pkg/sentry/fs/proc/seqfile/BUILD +++ b/pkg/sentry/fs/proc/seqfile/BUILD @@ -1,12 +1,10 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "seqfile", srcs = ["seqfile.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -26,7 +24,7 @@ go_test( name = "seqfile_test", size = "small", srcs = ["seqfile_test.go"], - embed = [":seqfile"], + library = ":seqfile", deps = [ "//pkg/sentry/context", "//pkg/sentry/context/contexttest", diff --git a/pkg/sentry/fs/ramfs/BUILD b/pkg/sentry/fs/ramfs/BUILD index 3fb7b0633..39c4b84f8 100644 --- a/pkg/sentry/fs/ramfs/BUILD +++ b/pkg/sentry/fs/ramfs/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -11,7 +10,6 @@ go_library( "symlink.go", "tree.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/ramfs", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -31,7 +29,7 @@ go_test( name = "ramfs_test", size = "small", srcs = ["tree_test.go"], - embed = [":ramfs"], + library = ":ramfs", deps = [ "//pkg/sentry/context/contexttest", "//pkg/sentry/fs", diff --git a/pkg/sentry/fs/sys/BUILD b/pkg/sentry/fs/sys/BUILD index 25f0f124e..cc6b3bfbf 100644 --- a/pkg/sentry/fs/sys/BUILD +++ b/pkg/sentry/fs/sys/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -10,7 +10,6 @@ go_library( "fs.go", "sys.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/sys", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/fs/timerfd/BUILD b/pkg/sentry/fs/timerfd/BUILD index a215c1b95..092668e8d 100644 --- a/pkg/sentry/fs/timerfd/BUILD +++ b/pkg/sentry/fs/timerfd/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "timerfd", srcs = ["timerfd.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/timerfd", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/sentry/context", diff --git a/pkg/sentry/fs/tmpfs/BUILD b/pkg/sentry/fs/tmpfs/BUILD index 3400b940c..04776555f 100644 --- a/pkg/sentry/fs/tmpfs/BUILD +++ b/pkg/sentry/fs/tmpfs/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -12,7 +11,6 @@ go_library( "inode_file.go", "tmpfs.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -41,7 +39,7 @@ go_test( name = "tmpfs_test", size = "small", srcs = ["file_test.go"], - embed = [":tmpfs"], + library = ":tmpfs", deps = [ "//pkg/sentry/context", "//pkg/sentry/fs", diff --git a/pkg/sentry/fs/tty/BUILD b/pkg/sentry/fs/tty/BUILD index f6f60d0cf..29f804c6c 100644 --- a/pkg/sentry/fs/tty/BUILD +++ b/pkg/sentry/fs/tty/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -14,7 +13,6 @@ go_library( "slave.go", "terminal.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fs/tty", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -40,7 +38,7 @@ go_test( name = "tty_test", size = "small", srcs = ["tty_test.go"], - embed = [":tty"], + library = ":tty", deps = [ "//pkg/abi/linux", "//pkg/sentry/context/contexttest", diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD index 903874141..a718920d5 100644 --- a/pkg/sentry/fsimpl/ext/BUILD +++ b/pkg/sentry/fsimpl/ext/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) @@ -32,7 +31,6 @@ go_library( "symlink.go", "utils.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -71,7 +69,7 @@ go_test( "//pkg/sentry/fsimpl/ext:assets/tiny.ext3", "//pkg/sentry/fsimpl/ext:assets/tiny.ext4", ], - embed = [":ext"], + library = ":ext", deps = [ "//pkg/abi/linux", "//pkg/binary", diff --git a/pkg/sentry/fsimpl/ext/benchmark/BUILD b/pkg/sentry/fsimpl/ext/benchmark/BUILD index 4fc8296ef..12f3990c1 100644 --- a/pkg/sentry/fsimpl/ext/benchmark/BUILD +++ b/pkg/sentry/fsimpl/ext/benchmark/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/sentry/fsimpl/ext/disklayout/BUILD b/pkg/sentry/fsimpl/ext/disklayout/BUILD index fcfaf5c3e..9bd9c76c0 100644 --- a/pkg/sentry/fsimpl/ext/disklayout/BUILD +++ b/pkg/sentry/fsimpl/ext/disklayout/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -23,7 +22,6 @@ go_library( "superblock_old.go", "test_utils.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -44,6 +42,6 @@ go_test( "inode_test.go", "superblock_test.go", ], - embed = [":disklayout"], + library = ":disklayout", deps = ["//pkg/sentry/kernel/time"], ) diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD index 66d409785..7bf83ccba 100644 --- a/pkg/sentry/fsimpl/kernfs/BUILD +++ b/pkg/sentry/fsimpl/kernfs/BUILD @@ -1,8 +1,7 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -package(licenses = ["notice"]) +licenses(["notice"]) go_template_instance( name = "slot_list", @@ -27,7 +26,6 @@ go_library( "slot_list.go", "symlink.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD index c5b79fb38..3768f55b2 100644 --- a/pkg/sentry/fsimpl/proc/BUILD +++ b/pkg/sentry/fsimpl/proc/BUILD @@ -1,7 +1,6 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") -package(licenses = ["notice"]) +licenses(["notice"]) go_library( name = "proc", @@ -15,7 +14,6 @@ go_library( "tasks_net.go", "tasks_sys.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/proc", deps = [ "//pkg/abi/linux", "//pkg/log", @@ -47,7 +45,7 @@ go_test( "tasks_sys_test.go", "tasks_test.go", ], - embed = [":proc"], + library = ":proc", deps = [ "//pkg/abi/linux", "//pkg/fspath", diff --git a/pkg/sentry/fsimpl/sys/BUILD b/pkg/sentry/fsimpl/sys/BUILD index ee3c842bd..beda141f1 100644 --- a/pkg/sentry/fsimpl/sys/BUILD +++ b/pkg/sentry/fsimpl/sys/BUILD @@ -1,14 +1,12 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") -package(licenses = ["notice"]) +licenses(["notice"]) go_library( name = "sys", srcs = [ "sys.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/sys", deps = [ "//pkg/abi/linux", "//pkg/sentry/context", diff --git a/pkg/sentry/fsimpl/testutil/BUILD b/pkg/sentry/fsimpl/testutil/BUILD index 4e70d84a7..12053a5b6 100644 --- a/pkg/sentry/fsimpl/testutil/BUILD +++ b/pkg/sentry/fsimpl/testutil/BUILD @@ -1,6 +1,6 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") -package(licenses = ["notice"]) +licenses(["notice"]) go_library( name = "testutil", @@ -9,7 +9,6 @@ go_library( "kernel.go", "testutil.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD index 691476b4f..857e98bc5 100644 --- a/pkg/sentry/fsimpl/tmpfs/BUILD +++ b/pkg/sentry/fsimpl/tmpfs/BUILD @@ -1,8 +1,7 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -package(licenses = ["notice"]) +licenses(["notice"]) go_template_instance( name = "dentry_list", @@ -28,7 +27,6 @@ go_library( "symlink.go", "tmpfs.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs", deps = [ "//pkg/abi/linux", "//pkg/amutex", @@ -81,7 +79,7 @@ go_test( "regular_file_test.go", "stat_test.go", ], - embed = [":tmpfs"], + library = ":tmpfs", deps = [ "//pkg/abi/linux", "//pkg/fspath", diff --git a/pkg/sentry/hostcpu/BUILD b/pkg/sentry/hostcpu/BUILD index 359468ccc..e6933aa70 100644 --- a/pkg/sentry/hostcpu/BUILD +++ b/pkg/sentry/hostcpu/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -10,7 +9,6 @@ go_library( "getcpu_arm64.s", "hostcpu.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/hostcpu", visibility = ["//:sandbox"], ) @@ -18,5 +16,5 @@ go_test( name = "hostcpu_test", size = "small", srcs = ["hostcpu_test.go"], - embed = [":hostcpu"], + library = ":hostcpu", ) diff --git a/pkg/sentry/hostmm/BUILD b/pkg/sentry/hostmm/BUILD index 67831d5a1..a145a5ca3 100644 --- a/pkg/sentry/hostmm/BUILD +++ b/pkg/sentry/hostmm/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -8,7 +8,6 @@ go_library( "cgroup.go", "hostmm.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/hostmm", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/fd", diff --git a/pkg/sentry/inet/BUILD b/pkg/sentry/inet/BUILD index 8d60ad4ad..aa621b724 100644 --- a/pkg/sentry/inet/BUILD +++ b/pkg/sentry/inet/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package( default_visibility = ["//:sandbox"], @@ -12,7 +12,6 @@ go_library( "inet.go", "test_stack.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/inet", deps = [ "//pkg/sentry/context", "//pkg/tcpip/stack", diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index ac85ba0c8..cebaccd92 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -1,8 +1,5 @@ -load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("@rules_cc//cc:defs.bzl", "cc_proto_library") +load("//tools:defs.bzl", "go_library", "go_test", "proto_library") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -78,26 +75,12 @@ go_template_instance( ) proto_library( - name = "uncaught_signal_proto", + name = "uncaught_signal", srcs = ["uncaught_signal.proto"], visibility = ["//visibility:public"], deps = ["//pkg/sentry/arch:registers_proto"], ) -cc_proto_library( - name = "uncaught_signal_cc_proto", - visibility = ["//visibility:public"], - deps = [":uncaught_signal_proto"], -) - -go_proto_library( - name = "uncaught_signal_go_proto", - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/uncaught_signal_go_proto", - proto = ":uncaught_signal_proto", - visibility = ["//visibility:public"], - deps = ["//pkg/sentry/arch:registers_go_proto"], -) - go_library( name = "kernel", srcs = [ @@ -156,7 +139,6 @@ go_library( "vdso.go", "version.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel", imports = [ "gvisor.dev/gvisor/pkg/bpf", "gvisor.dev/gvisor/pkg/sentry/device", @@ -227,7 +209,7 @@ go_test( "task_test.go", "timekeeper_test.go", ], - embed = [":kernel"], + library = ":kernel", deps = [ "//pkg/abi", "//pkg/sentry/arch", diff --git a/pkg/sentry/kernel/auth/BUILD b/pkg/sentry/kernel/auth/BUILD index 1aa72fa47..64537c9be 100644 --- a/pkg/sentry/kernel/auth/BUILD +++ b/pkg/sentry/kernel/auth/BUILD @@ -1,5 +1,5 @@ +load("//tools:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -57,7 +57,6 @@ go_library( "id_map_set.go", "user_namespace.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/auth", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/kernel/contexttest/BUILD b/pkg/sentry/kernel/contexttest/BUILD index 3a88a585c..daff608d7 100644 --- a/pkg/sentry/kernel/contexttest/BUILD +++ b/pkg/sentry/kernel/contexttest/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -6,7 +6,6 @@ go_library( name = "contexttest", testonly = 1, srcs = ["contexttest.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/sentry/context", diff --git a/pkg/sentry/kernel/epoll/BUILD b/pkg/sentry/kernel/epoll/BUILD index c47f6b6fc..19e16ab3a 100644 --- a/pkg/sentry/kernel/epoll/BUILD +++ b/pkg/sentry/kernel/epoll/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -23,7 +22,6 @@ go_library( "epoll_list.go", "epoll_state.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/epoll", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/refs", @@ -43,7 +41,7 @@ go_test( srcs = [ "epoll_test.go", ], - embed = [":epoll"], + library = ":epoll", deps = [ "//pkg/sentry/context/contexttest", "//pkg/sentry/fs/filetest", diff --git a/pkg/sentry/kernel/eventfd/BUILD b/pkg/sentry/kernel/eventfd/BUILD index c831fbab2..ee2d74864 100644 --- a/pkg/sentry/kernel/eventfd/BUILD +++ b/pkg/sentry/kernel/eventfd/BUILD @@ -1,12 +1,10 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "eventfd", srcs = ["eventfd.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/eventfd", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -26,7 +24,7 @@ go_test( name = "eventfd_test", size = "small", srcs = ["eventfd_test.go"], - embed = [":eventfd"], + library = ":eventfd", deps = [ "//pkg/sentry/context/contexttest", "//pkg/sentry/usermem", diff --git a/pkg/sentry/kernel/fasync/BUILD b/pkg/sentry/kernel/fasync/BUILD index 6b36bc63e..b9126e946 100644 --- a/pkg/sentry/kernel/fasync/BUILD +++ b/pkg/sentry/kernel/fasync/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "fasync", srcs = ["fasync.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/fasync", visibility = ["//:sandbox"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/kernel/futex/BUILD b/pkg/sentry/kernel/futex/BUILD index 50db443ce..f413d8ae2 100644 --- a/pkg/sentry/kernel/futex/BUILD +++ b/pkg/sentry/kernel/futex/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -34,7 +33,6 @@ go_library( "futex.go", "waiter_list.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/futex", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -51,7 +49,7 @@ go_test( name = "futex_test", size = "small", srcs = ["futex_test.go"], - embed = [":futex"], + library = ":futex", deps = [ "//pkg/sentry/usermem", "//pkg/sync", diff --git a/pkg/sentry/kernel/memevent/BUILD b/pkg/sentry/kernel/memevent/BUILD index 7f36252a9..4486848d2 100644 --- a/pkg/sentry/kernel/memevent/BUILD +++ b/pkg/sentry/kernel/memevent/BUILD @@ -1,13 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") -load("@rules_cc//cc:defs.bzl", "cc_proto_library") +load("//tools:defs.bzl", "go_library", "proto_library") package(licenses = ["notice"]) go_library( name = "memevent", srcs = ["memory_events.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/memevent", visibility = ["//:sandbox"], deps = [ ":memory_events_go_proto", @@ -21,20 +18,7 @@ go_library( ) proto_library( - name = "memory_events_proto", + name = "memory_events", srcs = ["memory_events.proto"], visibility = ["//visibility:public"], ) - -cc_proto_library( - name = "memory_events_cc_proto", - visibility = ["//visibility:public"], - deps = [":memory_events_proto"], -) - -go_proto_library( - name = "memory_events_go_proto", - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/memevent/memory_events_go_proto", - proto = ":memory_events_proto", - visibility = ["//visibility:public"], -) diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD index 5eeaeff66..2c7b6206f 100644 --- a/pkg/sentry/kernel/pipe/BUILD +++ b/pkg/sentry/kernel/pipe/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -30,7 +29,6 @@ go_library( "vfs.go", "writer.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/pipe", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -57,7 +55,7 @@ go_test( "node_test.go", "pipe_test.go", ], - embed = [":pipe"], + library = ":pipe", deps = [ "//pkg/sentry/context", "//pkg/sentry/context/contexttest", diff --git a/pkg/sentry/kernel/sched/BUILD b/pkg/sentry/kernel/sched/BUILD index 98ea7a0d8..1b82e087b 100644 --- a/pkg/sentry/kernel/sched/BUILD +++ b/pkg/sentry/kernel/sched/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -9,7 +8,6 @@ go_library( "cpuset.go", "sched.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/sched", visibility = ["//pkg/sentry:internal"], ) @@ -17,5 +15,5 @@ go_test( name = "sched_test", size = "small", srcs = ["cpuset_test.go"], - embed = [":sched"], + library = ":sched", ) diff --git a/pkg/sentry/kernel/semaphore/BUILD b/pkg/sentry/kernel/semaphore/BUILD index 13a961594..76e19b551 100644 --- a/pkg/sentry/kernel/semaphore/BUILD +++ b/pkg/sentry/kernel/semaphore/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -22,7 +21,6 @@ go_library( "semaphore.go", "waiter_list.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/semaphore", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -40,7 +38,7 @@ go_test( name = "semaphore_test", size = "small", srcs = ["semaphore_test.go"], - embed = [":semaphore"], + library = ":semaphore", deps = [ "//pkg/abi/linux", "//pkg/sentry/context", diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD index 7321b22ed..5547c5abf 100644 --- a/pkg/sentry/kernel/shm/BUILD +++ b/pkg/sentry/kernel/shm/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -8,7 +8,6 @@ go_library( "device.go", "shm.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/shm", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/kernel/signalfd/BUILD b/pkg/sentry/kernel/signalfd/BUILD index 89e4d84b1..5d44773d4 100644 --- a/pkg/sentry/kernel/signalfd/BUILD +++ b/pkg/sentry/kernel/signalfd/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") -package(licenses = ["notice"]) +licenses(["notice"]) go_library( name = "signalfd", srcs = ["signalfd.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/signalfd", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/kernel/time/BUILD b/pkg/sentry/kernel/time/BUILD index 4e4de0512..d49594d9f 100644 --- a/pkg/sentry/kernel/time/BUILD +++ b/pkg/sentry/kernel/time/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -8,7 +8,6 @@ go_library( "context.go", "time.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/time", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/limits/BUILD b/pkg/sentry/limits/BUILD index 9fa841e8b..67869757f 100644 --- a/pkg/sentry/limits/BUILD +++ b/pkg/sentry/limits/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -10,7 +9,6 @@ go_library( "limits.go", "linux.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/limits", visibility = ["//:sandbox"], deps = [ "//pkg/abi/linux", @@ -25,5 +23,5 @@ go_test( srcs = [ "limits_test.go", ], - embed = [":limits"], + library = ":limits", ) diff --git a/pkg/sentry/loader/BUILD b/pkg/sentry/loader/BUILD index 2890393bd..d4ad2bd6c 100644 --- a/pkg/sentry/loader/BUILD +++ b/pkg/sentry/loader/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_embed_data") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_embed_data", "go_library") package(licenses = ["notice"]) @@ -20,7 +19,6 @@ go_library( "vdso_state.go", ":vdso_bin", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/loader", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi", diff --git a/pkg/sentry/memmap/BUILD b/pkg/sentry/memmap/BUILD index 112794e9c..f9a65f086 100644 --- a/pkg/sentry/memmap/BUILD +++ b/pkg/sentry/memmap/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -37,7 +36,6 @@ go_library( "mapping_set_impl.go", "memmap.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/memmap", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/log", @@ -52,6 +50,6 @@ go_test( name = "memmap_test", size = "small", srcs = ["mapping_set_test.go"], - embed = [":memmap"], + library = ":memmap", deps = ["//pkg/sentry/usermem"], ) diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD index 83e248431..bd6399fa2 100644 --- a/pkg/sentry/mm/BUILD +++ b/pkg/sentry/mm/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -96,7 +95,6 @@ go_library( "vma.go", "vma_set.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/mm", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -128,7 +126,7 @@ go_test( name = "mm_test", size = "small", srcs = ["mm_test.go"], - embed = [":mm"], + library = ":mm", deps = [ "//pkg/sentry/arch", "//pkg/sentry/context", diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD index a9a2642c5..02385a3ce 100644 --- a/pkg/sentry/pgalloc/BUILD +++ b/pkg/sentry/pgalloc/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -60,7 +59,6 @@ go_library( "save_restore.go", "usage_set.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/pgalloc", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/log", @@ -82,6 +80,6 @@ go_test( name = "pgalloc_test", size = "small", srcs = ["pgalloc_test.go"], - embed = [":pgalloc"], + library = ":pgalloc", deps = ["//pkg/sentry/usermem"], ) diff --git a/pkg/sentry/platform/BUILD b/pkg/sentry/platform/BUILD index 157bffa81..006450b2d 100644 --- a/pkg/sentry/platform/BUILD +++ b/pkg/sentry/platform/BUILD @@ -1,5 +1,5 @@ +load("//tools:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -22,7 +22,6 @@ go_library( "mmap_min_addr.go", "platform.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/platform", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/platform/interrupt/BUILD b/pkg/sentry/platform/interrupt/BUILD index 85e882df9..83b385f14 100644 --- a/pkg/sentry/platform/interrupt/BUILD +++ b/pkg/sentry/platform/interrupt/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -8,7 +7,6 @@ go_library( srcs = [ "interrupt.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/platform/interrupt", visibility = ["//pkg/sentry:internal"], deps = ["//pkg/sync"], ) @@ -17,5 +15,5 @@ go_test( name = "interrupt_test", size = "small", srcs = ["interrupt_test.go"], - embed = [":interrupt"], + library = ":interrupt", ) diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD index 6a358d1d4..a4532a766 100644 --- a/pkg/sentry/platform/kvm/BUILD +++ b/pkg/sentry/platform/kvm/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -38,7 +37,6 @@ go_library( "physical_map_arm64.go", "virtual_map.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/platform/kvm", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -65,7 +63,7 @@ go_test( "kvm_test.go", "virtual_map_test.go", ], - embed = [":kvm"], + library = ":kvm", tags = [ "manual", "nogotsan", diff --git a/pkg/sentry/platform/kvm/testutil/BUILD b/pkg/sentry/platform/kvm/testutil/BUILD index b0e45f159..f7605df8a 100644 --- a/pkg/sentry/platform/kvm/testutil/BUILD +++ b/pkg/sentry/platform/kvm/testutil/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -12,6 +12,5 @@ go_library( "testutil_arm64.go", "testutil_arm64.s", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/platform/kvm/testutil", visibility = ["//pkg/sentry/platform/kvm:__pkg__"], ) diff --git a/pkg/sentry/platform/ptrace/BUILD b/pkg/sentry/platform/ptrace/BUILD index cd13390c3..3bcc5e040 100644 --- a/pkg/sentry/platform/ptrace/BUILD +++ b/pkg/sentry/platform/ptrace/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -20,7 +20,6 @@ go_library( "subprocess_linux_unsafe.go", "subprocess_unsafe.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/platform/ptrace", visibility = ["//:sandbox"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/platform/ring0/BUILD b/pkg/sentry/platform/ring0/BUILD index 87f4552b5..6dee8fcc5 100644 --- a/pkg/sentry/platform/ring0/BUILD +++ b/pkg/sentry/platform/ring0/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") package(licenses = ["notice"]) @@ -74,7 +74,6 @@ go_library( "lib_arm64.s", "ring0.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/platform/ring0", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/cpuid", diff --git a/pkg/sentry/platform/ring0/gen_offsets/BUILD b/pkg/sentry/platform/ring0/gen_offsets/BUILD index 42076fb04..147311ed3 100644 --- a/pkg/sentry/platform/ring0/gen_offsets/BUILD +++ b/pkg/sentry/platform/ring0/gen_offsets/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("//tools:defs.bzl", "go_binary") load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) diff --git a/pkg/sentry/platform/ring0/pagetables/BUILD b/pkg/sentry/platform/ring0/pagetables/BUILD index 387a7f6c3..8b5cdd6c1 100644 --- a/pkg/sentry/platform/ring0/pagetables/BUILD +++ b/pkg/sentry/platform/ring0/pagetables/BUILD @@ -1,17 +1,14 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test", "select_arch") load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") package(licenses = ["notice"]) -config_setting( - name = "aarch64", - constraint_values = ["@bazel_tools//platforms:aarch64"], -) - go_template( name = "generic_walker", - srcs = ["walker_amd64.go"], + srcs = select_arch( + amd64 = ["walker_amd64.go"], + arm64 = ["walker_amd64.go"], + ), opt_types = [ "Visitor", ], @@ -91,7 +88,6 @@ go_library( "walker_map.go", "walker_unmap.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables", visibility = [ "//pkg/sentry/platform/kvm:__subpackages__", "//pkg/sentry/platform/ring0:__subpackages__", @@ -111,6 +107,6 @@ go_test( "pagetables_test.go", "walker_check.go", ], - embed = [":pagetables"], + library = ":pagetables", deps = ["//pkg/sentry/usermem"], ) diff --git a/pkg/sentry/platform/safecopy/BUILD b/pkg/sentry/platform/safecopy/BUILD index 6769cd0a5..b8747585b 100644 --- a/pkg/sentry/platform/safecopy/BUILD +++ b/pkg/sentry/platform/safecopy/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -17,7 +16,6 @@ go_library( "sighandler_amd64.s", "sighandler_arm64.s", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/platform/safecopy", visibility = ["//pkg/sentry:internal"], deps = ["//pkg/syserror"], ) @@ -27,5 +25,5 @@ go_test( srcs = [ "safecopy_test.go", ], - embed = [":safecopy"], + library = ":safecopy", ) diff --git a/pkg/sentry/safemem/BUILD b/pkg/sentry/safemem/BUILD index 884020f7b..3ab76da97 100644 --- a/pkg/sentry/safemem/BUILD +++ b/pkg/sentry/safemem/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -11,7 +10,6 @@ go_library( "safemem.go", "seq_unsafe.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/safemem", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/sentry/platform/safecopy", @@ -25,5 +23,5 @@ go_test( "io_test.go", "seq_test.go", ], - embed = [":safemem"], + library = ":safemem", ) diff --git a/pkg/sentry/sighandling/BUILD b/pkg/sentry/sighandling/BUILD index f561670c7..6c38a3f44 100644 --- a/pkg/sentry/sighandling/BUILD +++ b/pkg/sentry/sighandling/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -8,7 +8,6 @@ go_library( "sighandling.go", "sighandling_unsafe.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/sighandling", visibility = ["//pkg/sentry:internal"], deps = ["//pkg/abi/linux"], ) diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD index 26176b10d..8e2b97afb 100644 --- a/pkg/sentry/socket/BUILD +++ b/pkg/sentry/socket/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "socket", srcs = ["socket.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD index 357517ed4..3850f6345 100644 --- a/pkg/sentry/socket/control/BUILD +++ b/pkg/sentry/socket/control/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "control", srcs = ["control.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/control", imports = [ "gvisor.dev/gvisor/pkg/sentry/fs", ], diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index 4c44c7c0f..42bf7be6a 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -12,7 +12,6 @@ go_library( "socket_unsafe.go", "stack.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/hostinet", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD index b70047d81..ed34a8308 100644 --- a/pkg/sentry/socket/netfilter/BUILD +++ b/pkg/sentry/socket/netfilter/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -7,7 +7,6 @@ go_library( srcs = [ "netfilter.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netfilter", # This target depends on netstack and should only be used by epsocket, # which is allowed to depend on netstack. visibility = ["//pkg/sentry:internal"], diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index 103933144..baaac13c6 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -9,7 +9,6 @@ go_library( "provider.go", "socket.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netlink", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/socket/netlink/port/BUILD b/pkg/sentry/socket/netlink/port/BUILD index 2d9f4ba9b..3a22923d8 100644 --- a/pkg/sentry/socket/netlink/port/BUILD +++ b/pkg/sentry/socket/netlink/port/BUILD @@ -1,12 +1,10 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "port", srcs = ["port.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netlink/port", visibility = ["//pkg/sentry:internal"], deps = ["//pkg/sync"], ) @@ -14,5 +12,5 @@ go_library( go_test( name = "port_test", srcs = ["port_test.go"], - embed = [":port"], + library = ":port", ) diff --git a/pkg/sentry/socket/netlink/route/BUILD b/pkg/sentry/socket/netlink/route/BUILD index 1d4912753..2137c7aeb 100644 --- a/pkg/sentry/socket/netlink/route/BUILD +++ b/pkg/sentry/socket/netlink/route/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "route", srcs = ["protocol.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netlink/route", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/socket/netlink/uevent/BUILD b/pkg/sentry/socket/netlink/uevent/BUILD index 0777f3baf..73fbdf1eb 100644 --- a/pkg/sentry/socket/netlink/uevent/BUILD +++ b/pkg/sentry/socket/netlink/uevent/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "uevent", srcs = ["protocol.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netlink/uevent", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index f78784569..e3d1f90cb 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -11,7 +11,6 @@ go_library( "save_restore.go", "stack.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netstack", visibility = [ "//pkg/sentry:internal", ], diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD index 5b6a154f6..bade18686 100644 --- a/pkg/sentry/socket/unix/BUILD +++ b/pkg/sentry/socket/unix/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -9,7 +9,6 @@ go_library( "io.go", "unix.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/unix", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD index d7ba95dff..4bdfc9208 100644 --- a/pkg/sentry/socket/unix/transport/BUILD +++ b/pkg/sentry/socket/unix/transport/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) @@ -25,7 +25,6 @@ go_library( "transport_message_list.go", "unix.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport", visibility = ["//:sandbox"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/state/BUILD b/pkg/sentry/state/BUILD index 88765f4d6..0ea4aab8b 100644 --- a/pkg/sentry/state/BUILD +++ b/pkg/sentry/state/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -9,7 +9,6 @@ go_library( "state_metadata.go", "state_unsafe.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/state", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD index aa1ac720c..ff6fafa63 100644 --- a/pkg/sentry/strace/BUILD +++ b/pkg/sentry/strace/BUILD @@ -1,6 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") -load("@rules_cc//cc:defs.bzl", "cc_proto_library") +load("//tools:defs.bzl", "go_library", "proto_library") package(licenses = ["notice"]) @@ -21,7 +19,6 @@ go_library( "strace.go", "syscalls.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/strace", visibility = ["//:sandbox"], deps = [ ":strace_go_proto", @@ -42,20 +39,7 @@ go_library( ) proto_library( - name = "strace_proto", + name = "strace", srcs = ["strace.proto"], visibility = ["//visibility:public"], ) - -cc_proto_library( - name = "strace_cc_proto", - visibility = ["//visibility:public"], - deps = [":strace_proto"], -) - -go_proto_library( - name = "strace_go_proto", - importpath = "gvisor.dev/gvisor/pkg/sentry/strace/strace_go_proto", - proto = ":strace_proto", - visibility = ["//visibility:public"], -) diff --git a/pkg/sentry/syscalls/BUILD b/pkg/sentry/syscalls/BUILD index 79d972202..b8d1bd415 100644 --- a/pkg/sentry/syscalls/BUILD +++ b/pkg/sentry/syscalls/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -8,7 +8,6 @@ go_library( "epoll.go", "syscalls.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/syscalls", visibility = ["//:sandbox"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD index 917f74e07..7d74e0f70 100644 --- a/pkg/sentry/syscalls/linux/BUILD +++ b/pkg/sentry/syscalls/linux/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -57,7 +57,6 @@ go_library( "sys_xattr.go", "timespec.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/syscalls/linux", visibility = ["//:sandbox"], deps = [ "//pkg/abi", diff --git a/pkg/sentry/time/BUILD b/pkg/sentry/time/BUILD index 3cde3a0be..04f81a35b 100644 --- a/pkg/sentry/time/BUILD +++ b/pkg/sentry/time/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) @@ -31,7 +30,6 @@ go_library( "tsc_amd64.s", "tsc_arm64.s", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/time", visibility = ["//:sandbox"], deps = [ "//pkg/log", @@ -48,5 +46,5 @@ go_test( "parameters_test.go", "sampler_test.go", ], - embed = [":time"], + library = ":time", ) diff --git a/pkg/sentry/unimpl/BUILD b/pkg/sentry/unimpl/BUILD index fc7614fff..370fa6ec5 100644 --- a/pkg/sentry/unimpl/BUILD +++ b/pkg/sentry/unimpl/BUILD @@ -1,34 +1,17 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") -load("@rules_cc//cc:defs.bzl", "cc_proto_library") +load("//tools:defs.bzl", "go_library", "proto_library") package(licenses = ["notice"]) proto_library( - name = "unimplemented_syscall_proto", + name = "unimplemented_syscall", srcs = ["unimplemented_syscall.proto"], visibility = ["//visibility:public"], deps = ["//pkg/sentry/arch:registers_proto"], ) -cc_proto_library( - name = "unimplemented_syscall_cc_proto", - visibility = ["//visibility:public"], - deps = [":unimplemented_syscall_proto"], -) - -go_proto_library( - name = "unimplemented_syscall_go_proto", - importpath = "gvisor.dev/gvisor/pkg/sentry/unimpl/unimplemented_syscall_go_proto", - proto = ":unimplemented_syscall_proto", - visibility = ["//visibility:public"], - deps = ["//pkg/sentry/arch:registers_go_proto"], -) - go_library( name = "unimpl", srcs = ["events.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/unimpl", visibility = ["//:sandbox"], deps = [ "//pkg/log", diff --git a/pkg/sentry/uniqueid/BUILD b/pkg/sentry/uniqueid/BUILD index 86a87edd4..e9c18f170 100644 --- a/pkg/sentry/uniqueid/BUILD +++ b/pkg/sentry/uniqueid/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "uniqueid", srcs = ["context.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/uniqueid", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/sentry/context", diff --git a/pkg/sentry/usage/BUILD b/pkg/sentry/usage/BUILD index 5518ac3d0..099315613 100644 --- a/pkg/sentry/usage/BUILD +++ b/pkg/sentry/usage/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -11,9 +11,8 @@ go_library( "memory_unsafe.go", "usage.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/usage", visibility = [ - "//pkg/sentry:internal", + "//:sandbox", ], deps = [ "//pkg/bits", diff --git a/pkg/sentry/usermem/BUILD b/pkg/sentry/usermem/BUILD index 684f59a6b..c8322e29e 100644 --- a/pkg/sentry/usermem/BUILD +++ b/pkg/sentry/usermem/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -29,7 +28,6 @@ go_library( "usermem_unsafe.go", "usermem_x86.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/usermem", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/atomicbitops", @@ -38,7 +36,6 @@ go_library( "//pkg/sentry/context", "//pkg/sentry/safemem", "//pkg/syserror", - "//pkg/tcpip/buffer", ], ) @@ -49,7 +46,7 @@ go_test( "addr_range_seq_test.go", "usermem_test.go", ], - embed = [":usermem"], + library = ":usermem", deps = [ "//pkg/sentry/context", "//pkg/sentry/safemem", diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD index 35c7be259..51acdc4e9 100644 --- a/pkg/sentry/vfs/BUILD +++ b/pkg/sentry/vfs/BUILD @@ -1,7 +1,6 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") -package(licenses = ["notice"]) +licenses(["notice"]) go_library( name = "vfs", @@ -24,7 +23,6 @@ go_library( "testutil.go", "vfs.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/vfs", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", @@ -47,7 +45,7 @@ go_test( "file_description_impl_util_test.go", "mount_test.go", ], - embed = [":vfs"], + library = ":vfs", deps = [ "//pkg/abi/linux", "//pkg/sentry/context", diff --git a/pkg/sentry/watchdog/BUILD b/pkg/sentry/watchdog/BUILD index 28f21f13d..1c5a1c9b6 100644 --- a/pkg/sentry/watchdog/BUILD +++ b/pkg/sentry/watchdog/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "watchdog", srcs = ["watchdog.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/watchdog", visibility = ["//:sandbox"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sleep/BUILD b/pkg/sleep/BUILD index a23c86fb1..e131455f7 100644 --- a/pkg/sleep/BUILD +++ b/pkg/sleep/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -12,7 +11,6 @@ go_library( "commit_noasm.go", "sleep_unsafe.go", ], - importpath = "gvisor.dev/gvisor/pkg/sleep", visibility = ["//:sandbox"], ) @@ -22,5 +20,5 @@ go_test( srcs = [ "sleep_test.go", ], - embed = [":sleep"], + library = ":sleep", ) diff --git a/pkg/state/BUILD b/pkg/state/BUILD index be93750bf..921af9d63 100644 --- a/pkg/state/BUILD +++ b/pkg/state/BUILD @@ -1,6 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test", "proto_library") load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) @@ -49,7 +47,7 @@ go_library( "state.go", "stats.go", ], - importpath = "gvisor.dev/gvisor/pkg/state", + stateify = False, visibility = ["//:sandbox"], deps = [ ":object_go_proto", @@ -58,21 +56,14 @@ go_library( ) proto_library( - name = "object_proto", + name = "object", srcs = ["object.proto"], visibility = ["//:sandbox"], ) -go_proto_library( - name = "object_go_proto", - importpath = "gvisor.dev/gvisor/pkg/state/object_go_proto", - proto = ":object_proto", - visibility = ["//:sandbox"], -) - go_test( name = "state_test", timeout = "long", srcs = ["state_test.go"], - embed = [":state"], + library = ":state", ) diff --git a/pkg/state/statefile/BUILD b/pkg/state/statefile/BUILD index 8a865d229..e7581c09b 100644 --- a/pkg/state/statefile/BUILD +++ b/pkg/state/statefile/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "statefile", srcs = ["statefile.go"], - importpath = "gvisor.dev/gvisor/pkg/state/statefile", visibility = ["//:sandbox"], deps = [ "//pkg/binary", @@ -18,6 +16,6 @@ go_test( name = "statefile_test", size = "small", srcs = ["statefile_test.go"], - embed = [":statefile"], + library = ":statefile", deps = ["//pkg/compressio"], ) diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD index 97c4b3b1e..5340cf0d6 100644 --- a/pkg/sync/BUILD +++ b/pkg/sync/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template") package( @@ -40,7 +39,6 @@ go_library( "syncutil.go", "tmutex_unsafe.go", ], - importpath = "gvisor.dev/gvisor/pkg/sync", ) go_test( @@ -51,5 +49,5 @@ go_test( "seqcount_test.go", "tmutex_test.go", ], - embed = [":sync"], + library = ":sync", ) diff --git a/pkg/sync/atomicptrtest/BUILD b/pkg/sync/atomicptrtest/BUILD index 418eda29c..e97553254 100644 --- a/pkg/sync/atomicptrtest/BUILD +++ b/pkg/sync/atomicptrtest/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) @@ -18,12 +17,11 @@ go_template_instance( go_library( name = "atomicptr", srcs = ["atomicptr_int_unsafe.go"], - importpath = "gvisor.dev/gvisor/pkg/sync/atomicptr", ) go_test( name = "atomicptr_test", size = "small", srcs = ["atomicptr_test.go"], - embed = [":atomicptr"], + library = ":atomicptr", ) diff --git a/pkg/sync/seqatomictest/BUILD b/pkg/sync/seqatomictest/BUILD index eba21518d..5c38c783e 100644 --- a/pkg/sync/seqatomictest/BUILD +++ b/pkg/sync/seqatomictest/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) @@ -18,7 +17,6 @@ go_template_instance( go_library( name = "seqatomic", srcs = ["seqatomic_int_unsafe.go"], - importpath = "gvisor.dev/gvisor/pkg/sync/seqatomic", deps = [ "//pkg/sync", ], @@ -28,6 +26,6 @@ go_test( name = "seqatomic_test", size = "small", srcs = ["seqatomic_test.go"], - embed = [":seqatomic"], + library = ":seqatomic", deps = ["//pkg/sync"], ) diff --git a/pkg/syserr/BUILD b/pkg/syserr/BUILD index 5665ad4ee..7d760344a 100644 --- a/pkg/syserr/BUILD +++ b/pkg/syserr/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -9,7 +9,6 @@ go_library( "netstack.go", "syserr.go", ], - importpath = "gvisor.dev/gvisor/pkg/syserr", visibility = ["//visibility:public"], deps = [ "//pkg/abi/linux", diff --git a/pkg/syserror/BUILD b/pkg/syserror/BUILD index bd3f9fd28..b13c15d9b 100644 --- a/pkg/syserror/BUILD +++ b/pkg/syserror/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "syserror", srcs = ["syserror.go"], - importpath = "gvisor.dev/gvisor/pkg/syserror", visibility = ["//visibility:public"], ) diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index 23e4b09e7..26f7ba86b 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -12,7 +11,6 @@ go_library( "time_unsafe.go", "timer.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip", visibility = ["//visibility:public"], deps = [ "//pkg/sync", @@ -25,7 +23,7 @@ go_test( name = "tcpip_test", size = "small", srcs = ["tcpip_test.go"], - embed = [":tcpip"], + library = ":tcpip", ) go_test( diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD index 3df7d18d3..a984f1712 100644 --- a/pkg/tcpip/adapters/gonet/BUILD +++ b/pkg/tcpip/adapters/gonet/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "gonet", srcs = ["gonet.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet", visibility = ["//visibility:public"], deps = [ "//pkg/sync", @@ -23,7 +21,7 @@ go_test( name = "gonet_test", size = "small", srcs = ["gonet_test.go"], - embed = [":gonet"], + library = ":gonet", deps = [ "//pkg/tcpip", "//pkg/tcpip/header", diff --git a/pkg/tcpip/buffer/BUILD b/pkg/tcpip/buffer/BUILD index d6c31bfa2..563bc78ea 100644 --- a/pkg/tcpip/buffer/BUILD +++ b/pkg/tcpip/buffer/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -9,7 +8,6 @@ go_library( "prependable.go", "view.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/buffer", visibility = ["//visibility:public"], ) @@ -17,5 +15,5 @@ go_test( name = "buffer_test", size = "small", srcs = ["view_test.go"], - embed = [":buffer"], + library = ":buffer", ) diff --git a/pkg/tcpip/checker/BUILD b/pkg/tcpip/checker/BUILD index b6fa6fc37..ed434807f 100644 --- a/pkg/tcpip/checker/BUILD +++ b/pkg/tcpip/checker/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -6,7 +6,6 @@ go_library( name = "checker", testonly = 1, srcs = ["checker.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/checker", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", diff --git a/pkg/tcpip/hash/jenkins/BUILD b/pkg/tcpip/hash/jenkins/BUILD index e648efa71..ff2719291 100644 --- a/pkg/tcpip/hash/jenkins/BUILD +++ b/pkg/tcpip/hash/jenkins/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "jenkins", srcs = ["jenkins.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins", visibility = ["//visibility:public"], ) @@ -16,5 +14,5 @@ go_test( srcs = [ "jenkins_test.go", ], - embed = [":jenkins"], + library = ":jenkins", ) diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD index cd747d100..9da0d71f8 100644 --- a/pkg/tcpip/header/BUILD +++ b/pkg/tcpip/header/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -24,7 +23,6 @@ go_library( "tcp.go", "udp.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/header", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", @@ -59,7 +57,7 @@ go_test( "eth_test.go", "ndp_test.go", ], - embed = [":header"], + library = ":header", deps = [ "//pkg/tcpip", "@com_github_google_go-cmp//cmp:go_default_library", diff --git a/pkg/tcpip/iptables/BUILD b/pkg/tcpip/iptables/BUILD index 297eaccaf..d1b73cfdf 100644 --- a/pkg/tcpip/iptables/BUILD +++ b/pkg/tcpip/iptables/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -9,7 +9,6 @@ go_library( "targets.go", "types.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/iptables", visibility = ["//visibility:public"], deps = [ "//pkg/log", diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD index 7dbc05754..3974c464e 100644 --- a/pkg/tcpip/link/channel/BUILD +++ b/pkg/tcpip/link/channel/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "channel", srcs = ["channel.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/channel", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD index 66cc53ed4..abe725548 100644 --- a/pkg/tcpip/link/fdbased/BUILD +++ b/pkg/tcpip/link/fdbased/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -13,7 +12,6 @@ go_library( "mmap_unsafe.go", "packet_dispatchers.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/fdbased", visibility = ["//visibility:public"], deps = [ "//pkg/sync", @@ -30,7 +28,7 @@ go_test( name = "fdbased_test", size = "small", srcs = ["endpoint_test.go"], - embed = [":fdbased"], + library = ":fdbased", deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/link/loopback/BUILD b/pkg/tcpip/link/loopback/BUILD index f35fcdff4..6bf3805b7 100644 --- a/pkg/tcpip/link/loopback/BUILD +++ b/pkg/tcpip/link/loopback/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "loopback", srcs = ["loopback.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/loopback", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD index 1ac7948b6..82b441b79 100644 --- a/pkg/tcpip/link/muxed/BUILD +++ b/pkg/tcpip/link/muxed/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "muxed", srcs = ["injectable.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/muxed", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", @@ -19,7 +17,7 @@ go_test( name = "muxed_test", size = "small", srcs = ["injectable_test.go"], - embed = [":muxed"], + library = ":muxed", deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD index d8211e93d..14b527bc2 100644 --- a/pkg/tcpip/link/rawfile/BUILD +++ b/pkg/tcpip/link/rawfile/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -12,7 +12,6 @@ go_library( "errors.go", "rawfile_unsafe.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/rawfile", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD index 09165dd4c..13243ebbb 100644 --- a/pkg/tcpip/link/sharedmem/BUILD +++ b/pkg/tcpip/link/sharedmem/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -11,7 +10,6 @@ go_library( "sharedmem_unsafe.go", "tx.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem", visibility = ["//visibility:public"], deps = [ "//pkg/log", @@ -30,7 +28,7 @@ go_test( srcs = [ "sharedmem_test.go", ], - embed = [":sharedmem"], + library = ":sharedmem", deps = [ "//pkg/sync", "//pkg/tcpip", diff --git a/pkg/tcpip/link/sharedmem/pipe/BUILD b/pkg/tcpip/link/sharedmem/pipe/BUILD index a0d4ad0be..87020ec08 100644 --- a/pkg/tcpip/link/sharedmem/pipe/BUILD +++ b/pkg/tcpip/link/sharedmem/pipe/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -11,7 +10,6 @@ go_library( "rx.go", "tx.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe", visibility = ["//visibility:public"], ) @@ -20,6 +18,6 @@ go_test( srcs = [ "pipe_test.go", ], - embed = [":pipe"], + library = ":pipe", deps = ["//pkg/sync"], ) diff --git a/pkg/tcpip/link/sharedmem/queue/BUILD b/pkg/tcpip/link/sharedmem/queue/BUILD index 8c9234d54..3ba06af73 100644 --- a/pkg/tcpip/link/sharedmem/queue/BUILD +++ b/pkg/tcpip/link/sharedmem/queue/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -9,7 +8,6 @@ go_library( "rx.go", "tx.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue", visibility = ["//visibility:public"], deps = [ "//pkg/log", @@ -22,7 +20,7 @@ go_test( srcs = [ "queue_test.go", ], - embed = [":queue"], + library = ":queue", deps = [ "//pkg/tcpip/link/sharedmem/pipe", ], diff --git a/pkg/tcpip/link/sniffer/BUILD b/pkg/tcpip/link/sniffer/BUILD index d6ae0368a..230a8d53a 100644 --- a/pkg/tcpip/link/sniffer/BUILD +++ b/pkg/tcpip/link/sniffer/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -8,7 +8,6 @@ go_library( "pcap.go", "sniffer.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sniffer", visibility = ["//visibility:public"], deps = [ "//pkg/log", diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD index a71a493fc..e5096ea38 100644 --- a/pkg/tcpip/link/tun/BUILD +++ b/pkg/tcpip/link/tun/BUILD @@ -1,10 +1,9 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "tun", srcs = ["tun_unsafe.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/tun", visibility = ["//visibility:public"], ) diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD index 134837943..0956d2c65 100644 --- a/pkg/tcpip/link/waitable/BUILD +++ b/pkg/tcpip/link/waitable/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -8,7 +7,6 @@ go_library( srcs = [ "waitable.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/waitable", visibility = ["//visibility:public"], deps = [ "//pkg/gate", @@ -23,7 +21,7 @@ go_test( srcs = [ "waitable_test.go", ], - embed = [":waitable"], + library = ":waitable", deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD index 9d16ff8c9..6a4839fb8 100644 --- a/pkg/tcpip/network/BUILD +++ b/pkg/tcpip/network/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index e7617229b..eddf7b725 100644 --- a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "arp", srcs = ["arp.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/network/arp", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD index ed16076fd..d1c728ccf 100644 --- a/pkg/tcpip/network/fragmentation/BUILD +++ b/pkg/tcpip/network/fragmentation/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -24,7 +23,6 @@ go_library( "reassembler.go", "reassembler_list.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation", visibility = ["//visibility:public"], deps = [ "//pkg/log", @@ -42,6 +40,6 @@ go_test( "fragmentation_test.go", "reassembler_test.go", ], - embed = [":fragmentation"], + library = ":fragmentation", deps = ["//pkg/tcpip/buffer"], ) diff --git a/pkg/tcpip/network/hash/BUILD b/pkg/tcpip/network/hash/BUILD index e6db5c0b0..872165866 100644 --- a/pkg/tcpip/network/hash/BUILD +++ b/pkg/tcpip/network/hash/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "hash", srcs = ["hash.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/network/hash", visibility = ["//visibility:public"], deps = [ "//pkg/rand", diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 4e2aae9a3..0fef2b1f1 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -9,7 +8,6 @@ go_library( "icmp.go", "ipv4.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/network/ipv4", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index e4e273460..fb11874c6 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -9,7 +8,6 @@ go_library( "icmp.go", "ipv6.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/network/ipv6", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", @@ -27,7 +25,7 @@ go_test( "ipv6_test.go", "ndp_test.go", ], - embed = [":ipv6"], + library = ":ipv6", deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD index a6ef3bdcc..2bad05a2e 100644 --- a/pkg/tcpip/ports/BUILD +++ b/pkg/tcpip/ports/BUILD @@ -1,12 +1,10 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "ports", srcs = ["ports.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/ports", visibility = ["//visibility:public"], deps = [ "//pkg/sync", @@ -17,7 +15,7 @@ go_library( go_test( name = "ports_test", srcs = ["ports_test.go"], - embed = [":ports"], + library = ":ports", deps = [ "//pkg/tcpip", ], diff --git a/pkg/tcpip/sample/tun_tcp_connect/BUILD b/pkg/tcpip/sample/tun_tcp_connect/BUILD index d7496fde6..cf0a5fefe 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/BUILD +++ b/pkg/tcpip/sample/tun_tcp_connect/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("//tools:defs.bzl", "go_binary") package(licenses = ["notice"]) diff --git a/pkg/tcpip/sample/tun_tcp_echo/BUILD b/pkg/tcpip/sample/tun_tcp_echo/BUILD index 875561566..43264b76d 100644 --- a/pkg/tcpip/sample/tun_tcp_echo/BUILD +++ b/pkg/tcpip/sample/tun_tcp_echo/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("//tools:defs.bzl", "go_binary") package(licenses = ["notice"]) diff --git a/pkg/tcpip/seqnum/BUILD b/pkg/tcpip/seqnum/BUILD index b31ddba2f..45f503845 100644 --- a/pkg/tcpip/seqnum/BUILD +++ b/pkg/tcpip/seqnum/BUILD @@ -1,10 +1,9 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "seqnum", srcs = ["seqnum.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/seqnum", visibility = ["//visibility:public"], ) diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 783351a69..f5b750046 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -30,7 +29,6 @@ go_library( "stack_global_state.go", "transport_demuxer.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/stack", visibility = ["//visibility:public"], deps = [ "//pkg/ilist", @@ -81,7 +79,7 @@ go_test( name = "stack_test", size = "small", srcs = ["linkaddrcache_test.go"], - embed = [":stack"], + library = ":stack", deps = [ "//pkg/sleep", "//pkg/sync", diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD index 3aa23d529..ac18ec5b1 100644 --- a/pkg/tcpip/transport/icmp/BUILD +++ b/pkg/tcpip/transport/icmp/BUILD @@ -1,5 +1,5 @@ +load("//tools:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -23,7 +23,6 @@ go_library( "icmp_packet_list.go", "protocol.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/icmp", imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], visibility = ["//visibility:public"], deps = [ diff --git a/pkg/tcpip/transport/packet/BUILD b/pkg/tcpip/transport/packet/BUILD index 4858d150c..d22de6b26 100644 --- a/pkg/tcpip/transport/packet/BUILD +++ b/pkg/tcpip/transport/packet/BUILD @@ -1,5 +1,5 @@ +load("//tools:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -22,7 +22,6 @@ go_library( "endpoint_state.go", "packet_list.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/packet", imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], visibility = ["//visibility:public"], deps = [ diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD index 2f2131ff7..c9baf4600 100644 --- a/pkg/tcpip/transport/raw/BUILD +++ b/pkg/tcpip/transport/raw/BUILD @@ -1,5 +1,5 @@ +load("//tools:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -23,7 +23,6 @@ go_library( "protocol.go", "raw_packet_list.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/raw", imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], visibility = ["//visibility:public"], deps = [ diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 0e3ab05ad..4acd9fb9a 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -55,7 +54,6 @@ go_library( "tcp_segment_list.go", "timer.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/tcp", imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], visibility = ["//visibility:public"], deps = [ diff --git a/pkg/tcpip/transport/tcp/testing/context/BUILD b/pkg/tcpip/transport/tcp/testing/context/BUILD index b33ec2087..ce6a2c31d 100644 --- a/pkg/tcpip/transport/tcp/testing/context/BUILD +++ b/pkg/tcpip/transport/tcp/testing/context/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -6,7 +6,6 @@ go_library( name = "context", testonly = 1, srcs = ["context.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context", visibility = [ "//visibility:public", ], diff --git a/pkg/tcpip/transport/tcpconntrack/BUILD b/pkg/tcpip/transport/tcpconntrack/BUILD index 43fcc27f0..3ad6994a7 100644 --- a/pkg/tcpip/transport/tcpconntrack/BUILD +++ b/pkg/tcpip/transport/tcpconntrack/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "tcpconntrack", srcs = ["tcp_conntrack.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip/header", diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index 57ff123e3..adc908e24 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -25,7 +24,6 @@ go_library( "protocol.go", "udp_packet_list.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/udp", imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], visibility = ["//visibility:public"], deps = [ diff --git a/pkg/tmutex/BUILD b/pkg/tmutex/BUILD index 07778e4f7..2dcba84ae 100644 --- a/pkg/tmutex/BUILD +++ b/pkg/tmutex/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "tmutex", srcs = ["tmutex.go"], - importpath = "gvisor.dev/gvisor/pkg/tmutex", visibility = ["//:sandbox"], ) @@ -14,6 +12,6 @@ go_test( name = "tmutex_test", size = "medium", srcs = ["tmutex_test.go"], - embed = [":tmutex"], + library = ":tmutex", deps = ["//pkg/sync"], ) diff --git a/pkg/unet/BUILD b/pkg/unet/BUILD index d1885ae66..a86501fa2 100644 --- a/pkg/unet/BUILD +++ b/pkg/unet/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -9,7 +8,6 @@ go_library( "unet.go", "unet_unsafe.go", ], - importpath = "gvisor.dev/gvisor/pkg/unet", visibility = ["//visibility:public"], deps = [ "//pkg/gate", @@ -23,6 +21,6 @@ go_test( srcs = [ "unet_test.go", ], - embed = [":unet"], + library = ":unet", deps = ["//pkg/sync"], ) diff --git a/pkg/urpc/BUILD b/pkg/urpc/BUILD index b8fdc3125..850c34ed0 100644 --- a/pkg/urpc/BUILD +++ b/pkg/urpc/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "urpc", srcs = ["urpc.go"], - importpath = "gvisor.dev/gvisor/pkg/urpc", visibility = ["//:sandbox"], deps = [ "//pkg/fd", @@ -20,6 +18,6 @@ go_test( name = "urpc_test", size = "small", srcs = ["urpc_test.go"], - embed = [":urpc"], + library = ":urpc", deps = ["//pkg/unet"], ) diff --git a/pkg/waiter/BUILD b/pkg/waiter/BUILD index 1c6890e52..852480a09 100644 --- a/pkg/waiter/BUILD +++ b/pkg/waiter/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -22,7 +21,6 @@ go_library( "waiter.go", "waiter_list.go", ], - importpath = "gvisor.dev/gvisor/pkg/waiter", visibility = ["//visibility:public"], deps = ["//pkg/sync"], ) @@ -33,5 +31,5 @@ go_test( srcs = [ "waiter_test.go", ], - embed = [":waiter"], + library = ":waiter", ) diff --git a/runsc/BUILD b/runsc/BUILD index e5587421d..b35b41d81 100644 --- a/runsc/BUILD +++ b/runsc/BUILD @@ -1,7 +1,6 @@ -package(licenses = ["notice"]) # Apache 2.0 +load("//tools:defs.bzl", "go_binary", "pkg_deb", "pkg_tar") -load("@io_bazel_rules_go//go:def.bzl", "go_binary") -load("@rules_pkg//:pkg.bzl", "pkg_deb", "pkg_tar") +package(licenses = ["notice"]) go_binary( name = "runsc", @@ -9,7 +8,7 @@ go_binary( "main.go", "version.go", ], - pure = "on", + pure = True, visibility = [ "//visibility:public", ], @@ -26,10 +25,12 @@ go_binary( ) # The runsc-race target is a race-compatible BUILD target. This must be built -# via "bazel build --features=race //runsc:runsc-race", since the race feature -# must apply to all dependencies due a bug in gazelle file selection. The pure -# attribute must be off because the race detector requires linking with non-Go -# components, although we still require a static binary. +# via: bazel build --features=race //runsc:runsc-race +# +# This is neccessary because the race feature must apply to all dependencies +# due a bug in gazelle file selection. The pure attribute must be off because +# the race detector requires linking with non-Go components, although we still +# require a static binary. # # Note that in the future this might be convertible to a compatible target by # using the pure and static attributes within a select function, but select is @@ -42,7 +43,7 @@ go_binary( "main.go", "version.go", ], - static = "on", + static = True, visibility = [ "//visibility:public", ], @@ -82,7 +83,12 @@ genrule( # because they are assumes to be hermetic). srcs = [":runsc"], outs = ["version.txt"], - cmd = "$(location :runsc) -version | grep 'runsc version' | sed 's/^[^0-9]*//' > $@", + # Note that the little dance here is necessary because files in the $(SRCS) + # attribute are not executable by default, and we can't touch in place. + cmd = "cp $(location :runsc) $(@D)/runsc && \ + chmod a+x $(@D)/runsc && \ + $(@D)/runsc -version | grep version | sed 's/^[^0-9]*//' > $@ && \ + rm -f $(@D)/runsc", stamp = 1, ) @@ -109,5 +115,6 @@ sh_test( name = "version_test", size = "small", srcs = ["version_test.sh"], + args = ["$(location :runsc)"], data = [":runsc"], ) diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD index 3e20f8f2f..f3ebc0231 100644 --- a/runsc/boot/BUILD +++ b/runsc/boot/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -23,7 +23,6 @@ go_library( "strace.go", "user.go", ], - importpath = "gvisor.dev/gvisor/runsc/boot", visibility = [ "//runsc:__subpackages__", "//test:__subpackages__", @@ -107,7 +106,7 @@ go_test( "loader_test.go", "user_test.go", ], - embed = [":boot"], + library = ":boot", deps = [ "//pkg/control/server", "//pkg/log", diff --git a/runsc/boot/filter/BUILD b/runsc/boot/filter/BUILD index 3a9dcfc04..ce30f6c53 100644 --- a/runsc/boot/filter/BUILD +++ b/runsc/boot/filter/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -13,7 +13,6 @@ go_library( "extra_filters_race.go", "filter.go", ], - importpath = "gvisor.dev/gvisor/runsc/boot/filter", visibility = [ "//runsc/boot:__subpackages__", ], diff --git a/runsc/boot/platforms/BUILD b/runsc/boot/platforms/BUILD index 03391cdca..77774f43c 100644 --- a/runsc/boot/platforms/BUILD +++ b/runsc/boot/platforms/BUILD @@ -1,11 +1,10 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "platforms", srcs = ["platforms.go"], - importpath = "gvisor.dev/gvisor/runsc/boot/platforms", visibility = [ "//runsc:__subpackages__", ], diff --git a/runsc/cgroup/BUILD b/runsc/cgroup/BUILD index d6165f9e5..d4c7bdfbb 100644 --- a/runsc/cgroup/BUILD +++ b/runsc/cgroup/BUILD @@ -1,11 +1,10 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "cgroup", srcs = ["cgroup.go"], - importpath = "gvisor.dev/gvisor/runsc/cgroup", visibility = ["//:sandbox"], deps = [ "//pkg/log", @@ -19,6 +18,6 @@ go_test( name = "cgroup_test", size = "small", srcs = ["cgroup_test.go"], - embed = [":cgroup"], + library = ":cgroup", tags = ["local"], ) diff --git a/runsc/cmd/BUILD b/runsc/cmd/BUILD index b94bc4fa0..09aa46434 100644 --- a/runsc/cmd/BUILD +++ b/runsc/cmd/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -34,7 +34,6 @@ go_library( "syscalls.go", "wait.go", ], - importpath = "gvisor.dev/gvisor/runsc/cmd", visibility = [ "//runsc:__subpackages__", ], @@ -73,7 +72,7 @@ go_test( data = [ "//runsc", ], - embed = [":cmd"], + library = ":cmd", deps = [ "//pkg/abi/linux", "//pkg/log", diff --git a/runsc/console/BUILD b/runsc/console/BUILD index e623c1a0f..06924bccd 100644 --- a/runsc/console/BUILD +++ b/runsc/console/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -7,7 +7,6 @@ go_library( srcs = [ "console.go", ], - importpath = "gvisor.dev/gvisor/runsc/console", visibility = [ "//runsc:__subpackages__", ], diff --git a/runsc/container/BUILD b/runsc/container/BUILD index 6dea179e4..e21431e4c 100644 --- a/runsc/container/BUILD +++ b/runsc/container/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -10,7 +10,6 @@ go_library( "state_file.go", "status.go", ], - importpath = "gvisor.dev/gvisor/runsc/container", visibility = [ "//runsc:__subpackages__", "//test:__subpackages__", @@ -42,7 +41,7 @@ go_test( "//runsc", "//runsc/container/test_app", ], - embed = [":container"], + library = ":container", shard_count = 5, tags = [ "requires-kvm", diff --git a/runsc/container/test_app/BUILD b/runsc/container/test_app/BUILD index bfd338bb6..e200bafd9 100644 --- a/runsc/container/test_app/BUILD +++ b/runsc/container/test_app/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("//tools:defs.bzl", "go_binary") package(licenses = ["notice"]) @@ -9,7 +9,7 @@ go_binary( "fds.go", "test_app.go", ], - pure = "on", + pure = True, visibility = ["//runsc/container:__pkg__"], deps = [ "//pkg/unet", diff --git a/runsc/criutil/BUILD b/runsc/criutil/BUILD index 558133a0e..8a571a000 100644 --- a/runsc/criutil/BUILD +++ b/runsc/criutil/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -6,7 +6,6 @@ go_library( name = "criutil", testonly = 1, srcs = ["criutil.go"], - importpath = "gvisor.dev/gvisor/runsc/criutil", visibility = ["//:sandbox"], deps = ["//runsc/testutil"], ) diff --git a/runsc/dockerutil/BUILD b/runsc/dockerutil/BUILD index 0e0423504..8621af901 100644 --- a/runsc/dockerutil/BUILD +++ b/runsc/dockerutil/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -6,7 +6,6 @@ go_library( name = "dockerutil", testonly = 1, srcs = ["dockerutil.go"], - importpath = "gvisor.dev/gvisor/runsc/dockerutil", visibility = ["//:sandbox"], deps = [ "//runsc/testutil", diff --git a/runsc/fsgofer/BUILD b/runsc/fsgofer/BUILD index a9582d92b..64a406ae2 100644 --- a/runsc/fsgofer/BUILD +++ b/runsc/fsgofer/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -10,10 +10,7 @@ go_library( "fsgofer_arm64_unsafe.go", "fsgofer_unsafe.go", ], - importpath = "gvisor.dev/gvisor/runsc/fsgofer", - visibility = [ - "//runsc:__subpackages__", - ], + visibility = ["//runsc:__subpackages__"], deps = [ "//pkg/abi/linux", "//pkg/fd", @@ -30,7 +27,7 @@ go_test( name = "fsgofer_test", size = "small", srcs = ["fsgofer_test.go"], - embed = [":fsgofer"], + library = ":fsgofer", deps = [ "//pkg/log", "//pkg/p9", diff --git a/runsc/fsgofer/filter/BUILD b/runsc/fsgofer/filter/BUILD index bac73f89d..82b48ef32 100644 --- a/runsc/fsgofer/filter/BUILD +++ b/runsc/fsgofer/filter/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -13,7 +13,6 @@ go_library( "extra_filters_race.go", "filter.go", ], - importpath = "gvisor.dev/gvisor/runsc/fsgofer/filter", visibility = [ "//runsc:__subpackages__", ], diff --git a/runsc/sandbox/BUILD b/runsc/sandbox/BUILD index ddbc37456..c95d50294 100644 --- a/runsc/sandbox/BUILD +++ b/runsc/sandbox/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -9,7 +9,6 @@ go_library( "network_unsafe.go", "sandbox.go", ], - importpath = "gvisor.dev/gvisor/runsc/sandbox", visibility = [ "//runsc:__subpackages__", ], diff --git a/runsc/specutils/BUILD b/runsc/specutils/BUILD index 205638803..4ccd77f63 100644 --- a/runsc/specutils/BUILD +++ b/runsc/specutils/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -10,7 +10,6 @@ go_library( "namespace.go", "specutils.go", ], - importpath = "gvisor.dev/gvisor/runsc/specutils", visibility = ["//:sandbox"], deps = [ "//pkg/abi/linux", @@ -28,6 +27,6 @@ go_test( name = "specutils_test", size = "small", srcs = ["specutils_test.go"], - embed = [":specutils"], + library = ":specutils", deps = ["@com_github_opencontainers_runtime-spec//specs-go:go_default_library"], ) diff --git a/runsc/testutil/BUILD b/runsc/testutil/BUILD index 3c3027cb5..f845120b0 100644 --- a/runsc/testutil/BUILD +++ b/runsc/testutil/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -6,7 +6,6 @@ go_library( name = "testutil", testonly = 1, srcs = ["testutil.go"], - importpath = "gvisor.dev/gvisor/runsc/testutil", visibility = ["//:sandbox"], deps = [ "//pkg/log", diff --git a/runsc/version_test.sh b/runsc/version_test.sh index cc0ca3f05..747350654 100755 --- a/runsc/version_test.sh +++ b/runsc/version_test.sh @@ -16,7 +16,7 @@ set -euf -x -o pipefail -readonly runsc="${TEST_SRCDIR}/__main__/runsc/linux_amd64_pure_stripped/runsc" +readonly runsc="$1" readonly version=$($runsc --version) # Version should should not match VERSION, which is the default and which will diff --git a/scripts/common.sh b/scripts/common.sh index fdb1aa142..cd91b9f8e 100755 --- a/scripts/common.sh +++ b/scripts/common.sh @@ -16,11 +16,7 @@ set -xeou pipefail -if [[ -f $(dirname $0)/common_google.sh ]]; then - source $(dirname $0)/common_google.sh -else - source $(dirname $0)/common_bazel.sh -fi +source $(dirname $0)/common_build.sh # Ensure it attempts to collect logs in all cases. trap collect_logs EXIT diff --git a/scripts/common_bazel.sh b/scripts/common_bazel.sh deleted file mode 100755 index a473a88a4..000000000 --- a/scripts/common_bazel.sh +++ /dev/null @@ -1,99 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Install the latest version of Bazel and log the version. -(which use_bazel.sh && use_bazel.sh latest) || which bazel -bazel version - -# Switch into the workspace; only necessary if run with kokoro. -if [[ -v KOKORO_GIT_COMMIT ]] && [[ -d git/repo ]]; then - cd git/repo -elif [[ -v KOKORO_GIT_COMMIT ]] && [[ -d github/repo ]]; then - cd github/repo -fi - -# Set the standard bazel flags. -declare -r BAZEL_FLAGS=( - "--show_timestamps" - "--test_output=errors" - "--keep_going" - "--verbose_failures=true" -) -if [[ -v KOKORO_BAZEL_AUTH_CREDENTIAL ]]; then - declare -r BAZEL_RBE_AUTH_FLAGS=( - "--auth_credentials=${KOKORO_BAZEL_AUTH_CREDENTIAL}" - ) - declare -r BAZEL_RBE_FLAGS=("--config=remote") -fi - -# Wrap bazel. -function build() { - bazel build "${BAZEL_RBE_FLAGS[@]}" "${BAZEL_RBE_AUTH_FLAGS[@]}" "${BAZEL_FLAGS[@]}" "$@" 2>&1 | - tee /dev/fd/2 | grep -E '^ bazel-bin/' | awk '{ print $1; }' -} - -function test() { - bazel test "${BAZEL_RBE_FLAGS[@]}" "${BAZEL_RBE_AUTH_FLAGS[@]}" "${BAZEL_FLAGS[@]}" "$@" -} - -function run() { - local binary=$1 - shift - bazel run "${binary}" -- "$@" -} - -function run_as_root() { - local binary=$1 - shift - bazel run --run_under="sudo" "${binary}" -- "$@" -} - -function collect_logs() { - # Zip out everything into a convenient form. - if [[ -v KOKORO_ARTIFACTS_DIR ]] && [[ -e bazel-testlogs ]]; then - # Merge results files of all shards for each test suite. - for d in `find -L "bazel-testlogs" -name 'shard_*_of_*' | xargs dirname | sort | uniq`; do - junitparser merge `find $d -name test.xml` $d/test.xml - cat $d/shard_*_of_*/test.log > $d/test.log - ls -l $d/shard_*_of_*/test.outputs/outputs.zip && zip -r -1 $d/outputs.zip $d/shard_*_of_*/test.outputs/outputs.zip - done - find -L "bazel-testlogs" -name 'shard_*_of_*' | xargs rm -rf - # Move test logs to Kokoro directory. tar is used to conveniently perform - # renames while moving files. - find -L "bazel-testlogs" -name "test.xml" -o -name "test.log" -o -name "outputs.zip" | - tar --create --files-from - --transform 's/test\./sponge_log./' | - tar --extract --directory ${KOKORO_ARTIFACTS_DIR} - - # Collect sentry logs, if any. - if [[ -v RUNSC_LOGS_DIR ]] && [[ -d "${RUNSC_LOGS_DIR}" ]]; then - # Check if the directory is empty or not (only the first line it needed). - local -r logs=$(ls "${RUNSC_LOGS_DIR}" | head -n1) - if [[ "${logs}" ]]; then - local -r archive=runsc_logs_"${RUNTIME}".tar.gz - if [[ -v KOKORO_BUILD_ARTIFACTS_SUBDIR ]]; then - echo "runsc logs will be uploaded to:" - echo " gsutil cp gs://gvisor/logs/${KOKORO_BUILD_ARTIFACTS_SUBDIR}/${archive} /tmp" - echo " https://storage.cloud.google.com/gvisor/logs/${KOKORO_BUILD_ARTIFACTS_SUBDIR}/${archive}" - fi - tar --create --gzip --file="${KOKORO_ARTIFACTS_DIR}/${archive}" -C "${RUNSC_LOGS_DIR}" . - fi - fi - fi -} - -function find_branch_name() { - git branch --show-current || git rev-parse HEAD || bazel info workspace | xargs basename -} diff --git a/scripts/common_build.sh b/scripts/common_build.sh new file mode 100755 index 000000000..a473a88a4 --- /dev/null +++ b/scripts/common_build.sh @@ -0,0 +1,99 @@ +#!/bin/bash + +# Copyright 2019 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Install the latest version of Bazel and log the version. +(which use_bazel.sh && use_bazel.sh latest) || which bazel +bazel version + +# Switch into the workspace; only necessary if run with kokoro. +if [[ -v KOKORO_GIT_COMMIT ]] && [[ -d git/repo ]]; then + cd git/repo +elif [[ -v KOKORO_GIT_COMMIT ]] && [[ -d github/repo ]]; then + cd github/repo +fi + +# Set the standard bazel flags. +declare -r BAZEL_FLAGS=( + "--show_timestamps" + "--test_output=errors" + "--keep_going" + "--verbose_failures=true" +) +if [[ -v KOKORO_BAZEL_AUTH_CREDENTIAL ]]; then + declare -r BAZEL_RBE_AUTH_FLAGS=( + "--auth_credentials=${KOKORO_BAZEL_AUTH_CREDENTIAL}" + ) + declare -r BAZEL_RBE_FLAGS=("--config=remote") +fi + +# Wrap bazel. +function build() { + bazel build "${BAZEL_RBE_FLAGS[@]}" "${BAZEL_RBE_AUTH_FLAGS[@]}" "${BAZEL_FLAGS[@]}" "$@" 2>&1 | + tee /dev/fd/2 | grep -E '^ bazel-bin/' | awk '{ print $1; }' +} + +function test() { + bazel test "${BAZEL_RBE_FLAGS[@]}" "${BAZEL_RBE_AUTH_FLAGS[@]}" "${BAZEL_FLAGS[@]}" "$@" +} + +function run() { + local binary=$1 + shift + bazel run "${binary}" -- "$@" +} + +function run_as_root() { + local binary=$1 + shift + bazel run --run_under="sudo" "${binary}" -- "$@" +} + +function collect_logs() { + # Zip out everything into a convenient form. + if [[ -v KOKORO_ARTIFACTS_DIR ]] && [[ -e bazel-testlogs ]]; then + # Merge results files of all shards for each test suite. + for d in `find -L "bazel-testlogs" -name 'shard_*_of_*' | xargs dirname | sort | uniq`; do + junitparser merge `find $d -name test.xml` $d/test.xml + cat $d/shard_*_of_*/test.log > $d/test.log + ls -l $d/shard_*_of_*/test.outputs/outputs.zip && zip -r -1 $d/outputs.zip $d/shard_*_of_*/test.outputs/outputs.zip + done + find -L "bazel-testlogs" -name 'shard_*_of_*' | xargs rm -rf + # Move test logs to Kokoro directory. tar is used to conveniently perform + # renames while moving files. + find -L "bazel-testlogs" -name "test.xml" -o -name "test.log" -o -name "outputs.zip" | + tar --create --files-from - --transform 's/test\./sponge_log./' | + tar --extract --directory ${KOKORO_ARTIFACTS_DIR} + + # Collect sentry logs, if any. + if [[ -v RUNSC_LOGS_DIR ]] && [[ -d "${RUNSC_LOGS_DIR}" ]]; then + # Check if the directory is empty or not (only the first line it needed). + local -r logs=$(ls "${RUNSC_LOGS_DIR}" | head -n1) + if [[ "${logs}" ]]; then + local -r archive=runsc_logs_"${RUNTIME}".tar.gz + if [[ -v KOKORO_BUILD_ARTIFACTS_SUBDIR ]]; then + echo "runsc logs will be uploaded to:" + echo " gsutil cp gs://gvisor/logs/${KOKORO_BUILD_ARTIFACTS_SUBDIR}/${archive} /tmp" + echo " https://storage.cloud.google.com/gvisor/logs/${KOKORO_BUILD_ARTIFACTS_SUBDIR}/${archive}" + fi + tar --create --gzip --file="${KOKORO_ARTIFACTS_DIR}/${archive}" -C "${RUNSC_LOGS_DIR}" . + fi + fi + fi +} + +function find_branch_name() { + git branch --show-current || git rev-parse HEAD || bazel info workspace | xargs basename +} diff --git a/test/BUILD b/test/BUILD index bf834d994..34b950644 100644 --- a/test/BUILD +++ b/test/BUILD @@ -1,44 +1 @@ -package(licenses = ["notice"]) # Apache 2.0 - -# We need to define a bazel platform and toolchain to specify dockerPrivileged -# and dockerRunAsRoot options, they are required to run tests on the RBE -# cluster in Kokoro. -alias( - name = "rbe_ubuntu1604", - actual = ":rbe_ubuntu1604_r346485", -) - -platform( - name = "rbe_ubuntu1604_r346485", - constraint_values = [ - "@bazel_tools//platforms:x86_64", - "@bazel_tools//platforms:linux", - "@bazel_tools//tools/cpp:clang", - "@bazel_toolchains//constraints:xenial", - "@bazel_toolchains//constraints/sanitizers:support_msan", - ], - remote_execution_properties = """ - properties: { - name: "container-image" - value:"docker://gcr.io/cloud-marketplace/google/rbe-ubuntu16-04@sha256:93f7e127196b9b653d39830c50f8b05d49ef6fd8739a9b5b8ab16e1df5399e50" - } - properties: { - name: "dockerAddCapabilities" - value: "SYS_ADMIN" - } - properties: { - name: "dockerPrivileged" - value: "true" - } - """, -) - -toolchain( - name = "cc-toolchain-clang-x86_64-default", - exec_compatible_with = [ - ], - target_compatible_with = [ - ], - toolchain = "@bazel_toolchains//configs/ubuntu16_04_clang/10.0.0/bazel_2.0.0/cc:cc-compiler-k8", - toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", -) +package(licenses = ["notice"]) diff --git a/test/e2e/BUILD b/test/e2e/BUILD index 4fe03a220..76e04f878 100644 --- a/test/e2e/BUILD +++ b/test/e2e/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -10,7 +10,7 @@ go_test( "integration_test.go", "regression_test.go", ], - embed = [":integration"], + library = ":integration", tags = [ # Requires docker and runsc to be configured before the test runs. "manual", @@ -29,5 +29,4 @@ go_test( go_library( name = "integration", srcs = ["integration.go"], - importpath = "gvisor.dev/gvisor/test/integration", ) diff --git a/test/image/BUILD b/test/image/BUILD index 09b0a0ad5..7392ac54e 100644 --- a/test/image/BUILD +++ b/test/image/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -14,7 +14,7 @@ go_test( "ruby.rb", "ruby.sh", ], - embed = [":image"], + library = ":image", tags = [ # Requires docker and runsc to be configured before the test runs. "manual", @@ -30,5 +30,4 @@ go_test( go_library( name = "image", srcs = ["image.go"], - importpath = "gvisor.dev/gvisor/test/image", ) diff --git a/test/iptables/BUILD b/test/iptables/BUILD index 22f470092..6bb3b82b5 100644 --- a/test/iptables/BUILD +++ b/test/iptables/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -12,7 +12,6 @@ go_library( "iptables_util.go", "nat.go", ], - importpath = "gvisor.dev/gvisor/test/iptables", visibility = ["//test/iptables:__subpackages__"], deps = [ "//runsc/testutil", @@ -24,7 +23,7 @@ go_test( srcs = [ "iptables_test.go", ], - embed = [":iptables"], + library = ":iptables", tags = [ "local", "manual", diff --git a/test/iptables/runner/BUILD b/test/iptables/runner/BUILD index a5b6f082c..b9199387a 100644 --- a/test/iptables/runner/BUILD +++ b/test/iptables/runner/BUILD @@ -1,15 +1,21 @@ -load("@io_bazel_rules_docker//go:image.bzl", "go_image") -load("@io_bazel_rules_docker//container:container.bzl", "container_image") +load("//tools:defs.bzl", "container_image", "go_binary", "go_image") package(licenses = ["notice"]) +go_binary( + name = "runner", + testonly = 1, + srcs = ["main.go"], + deps = ["//test/iptables"], +) + container_image( name = "iptables-base", base = "@iptables-test//image", ) go_image( - name = "runner", + name = "runner-image", testonly = 1, srcs = ["main.go"], base = ":iptables-base", diff --git a/test/root/BUILD b/test/root/BUILD index d5dd9bca2..23ce2a70f 100644 --- a/test/root/BUILD +++ b/test/root/BUILD @@ -1,11 +1,10 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "root", srcs = ["root.go"], - importpath = "gvisor.dev/gvisor/test/root", ) go_test( @@ -21,7 +20,7 @@ go_test( data = [ "//runsc", ], - embed = [":root"], + library = ":root", tags = [ # Requires docker and runsc to be configured before the test runs. # Also test only runs as root. diff --git a/test/root/testdata/BUILD b/test/root/testdata/BUILD index 125633680..bca5f9cab 100644 --- a/test/root/testdata/BUILD +++ b/test/root/testdata/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -12,7 +12,6 @@ go_library( "sandbox.go", "simple.go", ], - importpath = "gvisor.dev/gvisor/test/root/testdata", visibility = [ "//visibility:public", ], diff --git a/test/runtimes/BUILD b/test/runtimes/BUILD index 367295206..2c472bf8d 100644 --- a/test/runtimes/BUILD +++ b/test/runtimes/BUILD @@ -1,6 +1,6 @@ # These packages are used to run language runtime tests inside gVisor sandboxes. -load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_test") +load("//tools:defs.bzl", "go_binary", "go_test") load("//test/runtimes:build_defs.bzl", "runtime_test") package(licenses = ["notice"]) @@ -49,5 +49,5 @@ go_test( name = "blacklist_test", size = "small", srcs = ["blacklist_test.go"], - embed = [":runner"], + library = ":runner", ) diff --git a/test/runtimes/build_defs.bzl b/test/runtimes/build_defs.bzl index 6f84ca852..92e275a76 100644 --- a/test/runtimes/build_defs.bzl +++ b/test/runtimes/build_defs.bzl @@ -1,6 +1,6 @@ """Defines a rule for runtime test targets.""" -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_test", "loopback") def runtime_test( name, @@ -34,6 +34,7 @@ def runtime_test( ] data = [ ":runner", + loopback, ] if blacklist_file: args += ["--blacklist_file", "test/runtimes/" + blacklist_file] @@ -61,7 +62,7 @@ def blacklist_test(name, blacklist_file): """Test that a blacklist parses correctly.""" go_test( name = name + "_blacklist_test", - embed = [":runner"], + library = ":runner", srcs = ["blacklist_test.go"], args = ["--blacklist_file", "test/runtimes/" + blacklist_file], data = [blacklist_file], diff --git a/test/runtimes/images/proctor/BUILD b/test/runtimes/images/proctor/BUILD index 09dc6c42f..85e004c45 100644 --- a/test/runtimes/images/proctor/BUILD +++ b/test/runtimes/images/proctor/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_test") +load("//tools:defs.bzl", "go_binary", "go_test") package(licenses = ["notice"]) @@ -19,7 +19,7 @@ go_test( name = "proctor_test", size = "small", srcs = ["proctor_test.go"], - embed = [":proctor"], + library = ":proctor", deps = [ "//runsc/testutil", ], diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index 90d52e73b..40e974314 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("//tools:defs.bzl", "go_binary") load("//test/syscalls:build_defs.bzl", "syscall_test") package(licenses = ["notice"]) diff --git a/test/syscalls/build_defs.bzl b/test/syscalls/build_defs.bzl index aaf77c65b..1df761dd0 100644 --- a/test/syscalls/build_defs.bzl +++ b/test/syscalls/build_defs.bzl @@ -1,5 +1,7 @@ """Defines a rule for syscall test targets.""" +load("//tools:defs.bzl", "loopback") + # syscall_test is a macro that will create targets to run the given test target # on the host (native) and runsc. def syscall_test( @@ -135,6 +137,7 @@ def _syscall_test( name = name, data = [ ":syscall_test_runner", + loopback, test, ], args = args, @@ -148,6 +151,3 @@ def sh_test(**kwargs): native.sh_test( **kwargs ) - -def select_for_linux(for_linux, for_others = []): - return for_linux diff --git a/test/syscalls/gtest/BUILD b/test/syscalls/gtest/BUILD index 9293f25cb..de4b2727c 100644 --- a/test/syscalls/gtest/BUILD +++ b/test/syscalls/gtest/BUILD @@ -1,12 +1,9 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "gtest", srcs = ["gtest.go"], - importpath = "gvisor.dev/gvisor/test/syscalls/gtest", - visibility = [ - "//test:__subpackages__", - ], + visibility = ["//:sandbox"], ) diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 4c7ec3f06..c2ef50c1d 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -1,5 +1,4 @@ -load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library") -load("//test/syscalls:build_defs.bzl", "select_for_linux") +load("//tools:defs.bzl", "cc_binary", "cc_library", "default_net_util", "select_system") package( default_visibility = ["//:sandbox"], @@ -126,13 +125,11 @@ cc_library( testonly = 1, srcs = [ "socket_test_util.cc", - ] + select_for_linux( - [ - "socket_test_util_impl.cc", - ], - ), + "socket_test_util_impl.cc", + ], hdrs = ["socket_test_util.h"], - deps = [ + defines = select_system(), + deps = default_net_util() + [ "@com_google_googletest//:gtest", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -143,8 +140,7 @@ cc_library( "//test/util:temp_path", "//test/util:test_util", "//test/util:thread_util", - ] + select_for_linux([ - ]), + ], ) cc_library( @@ -1443,6 +1439,7 @@ cc_binary( srcs = ["arch_prctl.cc"], linkstatic = 1, deps = [ + "//test/util:file_descriptor", "//test/util:test_main", "//test/util:test_util", "@com_google_googletest//:gtest", @@ -3383,11 +3380,11 @@ cc_library( name = "udp_socket_test_cases", testonly = 1, srcs = [ - "udp_socket_test_cases.cc", - ] + select_for_linux([ "udp_socket_errqueue_test_case.cc", - ]), + "udp_socket_test_cases.cc", + ], hdrs = ["udp_socket_test_cases.h"], + defines = select_system(), deps = [ ":socket_test_util", ":unix_domain_socket_test_util", diff --git a/test/syscalls/linux/arch_prctl.cc b/test/syscalls/linux/arch_prctl.cc index 81bf5a775..3a901faf5 100644 --- a/test/syscalls/linux/arch_prctl.cc +++ b/test/syscalls/linux/arch_prctl.cc @@ -14,8 +14,10 @@ #include #include +#include #include "gtest/gtest.h" +#include "test/util/file_descriptor.h" #include "test/util/test_util.h" // glibc does not provide a prototype for arch_prctl() so declare it here. diff --git a/test/syscalls/linux/rseq/BUILD b/test/syscalls/linux/rseq/BUILD index 5cfe4e56f..ed488dbc2 100644 --- a/test/syscalls/linux/rseq/BUILD +++ b/test/syscalls/linux/rseq/BUILD @@ -1,8 +1,7 @@ # This package contains a standalone rseq test binary. This binary must not # depend on libc, which might use rseq itself. -load("@bazel_tools//tools/cpp:cc_flags_supplier.bzl", "cc_flags_supplier") -load("@rules_cc//cc:defs.bzl", "cc_library") +load("//tools:defs.bzl", "cc_flags_supplier", "cc_library", "cc_toolchain") package(licenses = ["notice"]) @@ -37,8 +36,8 @@ genrule( "$(location start.S)", ]), toolchains = [ + cc_toolchain, ":no_pie_cc_flags", - "@bazel_tools//tools/cpp:current_cc_toolchain", ], visibility = ["//:sandbox"], ) diff --git a/test/syscalls/linux/udp_socket_errqueue_test_case.cc b/test/syscalls/linux/udp_socket_errqueue_test_case.cc index 147978f46..9a24e1df0 100644 --- a/test/syscalls/linux/udp_socket_errqueue_test_case.cc +++ b/test/syscalls/linux/udp_socket_errqueue_test_case.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#ifndef __fuchsia__ + #include "test/syscalls/linux/udp_socket_test_cases.h" #include @@ -52,3 +54,5 @@ TEST_P(UdpSocketTest, ErrorQueue) { } // namespace testing } // namespace gvisor + +#endif // __fuchsia__ diff --git a/test/uds/BUILD b/test/uds/BUILD index a3843e699..51e2c7ce8 100644 --- a/test/uds/BUILD +++ b/test/uds/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package( default_visibility = ["//:sandbox"], @@ -9,7 +9,6 @@ go_library( name = "uds", testonly = 1, srcs = ["uds.go"], - importpath = "gvisor.dev/gvisor/test/uds", deps = [ "//pkg/log", "//pkg/unet", diff --git a/test/util/BUILD b/test/util/BUILD index cbc728159..3c732be62 100644 --- a/test/util/BUILD +++ b/test/util/BUILD @@ -1,5 +1,4 @@ -load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test") -load("//test/syscalls:build_defs.bzl", "select_for_linux") +load("//tools:defs.bzl", "cc_library", "cc_test", "select_system") package( default_visibility = ["//:sandbox"], @@ -142,12 +141,13 @@ cc_library( cc_library( name = "save_util", testonly = 1, - srcs = ["save_util.cc"] + - select_for_linux( - ["save_util_linux.cc"], - ["save_util_other.cc"], - ), + srcs = [ + "save_util.cc", + "save_util_linux.cc", + "save_util_other.cc", + ], hdrs = ["save_util.h"], + defines = select_system(), ) cc_library( @@ -234,13 +234,16 @@ cc_library( testonly = 1, srcs = [ "test_util.cc", - ] + select_for_linux( - [ - "test_util_impl.cc", - "test_util_runfiles.cc", + "test_util_impl.cc", + "test_util_runfiles.cc", + ], + hdrs = ["test_util.h"], + defines = select_system( + fuchsia = [ + "__opensource__", + "__fuchsia__", ], ), - hdrs = ["test_util.h"], deps = [ ":fs_util", ":logging", diff --git a/test/util/save_util_linux.cc b/test/util/save_util_linux.cc index cd56118c0..d0aea8e6a 100644 --- a/test/util/save_util_linux.cc +++ b/test/util/save_util_linux.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#ifdef __linux__ + #include #include #include @@ -43,3 +45,5 @@ void MaybeSave() { } // namespace testing } // namespace gvisor + +#endif diff --git a/test/util/save_util_other.cc b/test/util/save_util_other.cc index 1aca663b7..931af2c29 100644 --- a/test/util/save_util_other.cc +++ b/test/util/save_util_other.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#ifndef __linux__ + namespace gvisor { namespace testing { @@ -21,3 +23,5 @@ void MaybeSave() { } // namespace testing } // namespace gvisor + +#endif diff --git a/test/util/test_util_runfiles.cc b/test/util/test_util_runfiles.cc index 7210094eb..694d21692 100644 --- a/test/util/test_util_runfiles.cc +++ b/test/util/test_util_runfiles.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#ifndef __fuchsia__ + #include #include @@ -44,3 +46,5 @@ std::string RunfilePath(std::string path) { } // namespace testing } // namespace gvisor + +#endif // __fuchsia__ diff --git a/tools/BUILD b/tools/BUILD new file mode 100644 index 000000000..e73a9c885 --- /dev/null +++ b/tools/BUILD @@ -0,0 +1,3 @@ +package(licenses = ["notice"]) + +exports_files(["nogo.js"]) diff --git a/tools/build/BUILD b/tools/build/BUILD new file mode 100644 index 000000000..0c0ce3f4d --- /dev/null +++ b/tools/build/BUILD @@ -0,0 +1,10 @@ +package(licenses = ["notice"]) + +# In bazel, no special support is required for loopback networking. This is +# just a dummy data target that does not change the test environment. +genrule( + name = "loopback", + outs = ["loopback.txt"], + cmd = "touch $@", + visibility = ["//visibility:public"], +) diff --git a/tools/build/defs.bzl b/tools/build/defs.bzl new file mode 100644 index 000000000..d0556abd1 --- /dev/null +++ b/tools/build/defs.bzl @@ -0,0 +1,91 @@ +"""Bazel implementations of standard rules.""" + +load("@bazel_tools//tools/cpp:cc_flags_supplier.bzl", _cc_flags_supplier = "cc_flags_supplier") +load("@io_bazel_rules_go//go:def.bzl", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_library = "go_library", _go_test = "go_test", _go_tool_library = "go_tool_library") +load("@io_bazel_rules_go//proto:def.bzl", _go_proto_library = "go_proto_library") +load("@rules_cc//cc:defs.bzl", _cc_binary = "cc_binary", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test") +load("@rules_pkg//:pkg.bzl", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar") +load("@io_bazel_rules_docker//go:image.bzl", _go_image = "go_image") +load("@io_bazel_rules_docker//container:container.bzl", _container_image = "container_image") +load("@pydeps//:requirements.bzl", _py_requirement = "requirement") + +container_image = _container_image +cc_binary = _cc_binary +cc_library = _cc_library +cc_flags_supplier = _cc_flags_supplier +cc_proto_library = _cc_proto_library +cc_test = _cc_test +cc_toolchain = "@bazel_tools//tools/cpp:current_cc_toolchain" +go_image = _go_image +go_embed_data = _go_embed_data +loopback = "//tools/build:loopback" +proto_library = native.proto_library +pkg_deb = _pkg_deb +pkg_tar = _pkg_tar +py_library = native.py_library +py_binary = native.py_binary +py_test = native.py_test + +def go_binary(name, static = False, pure = False, **kwargs): + if static: + kwargs["static"] = "on" + if pure: + kwargs["pure"] = "on" + _go_binary( + name = name, + **kwargs + ) + +def go_library(name, **kwargs): + _go_library( + name = name, + importpath = "gvisor.dev/gvisor/" + native.package_name(), + **kwargs + ) + +def go_tool_library(name, **kwargs): + _go_tool_library( + name = name, + importpath = "gvisor.dev/gvisor/" + native.package_name(), + **kwargs + ) + +def go_proto_library(name, proto, **kwargs): + deps = kwargs.pop("deps", []) + _go_proto_library( + name = name, + importpath = "gvisor.dev/gvisor/" + native.package_name() + "/" + name, + proto = proto, + deps = [dep.replace("_proto", "_go_proto") for dep in deps], + **kwargs + ) + +def go_test(name, **kwargs): + library = kwargs.pop("library", None) + if library: + kwargs["embed"] = [library] + _go_test( + name = name, + **kwargs + ) + +def py_requirement(name, direct = False): + return _py_requirement(name) + +def select_arch(amd64 = "amd64", arm64 = "arm64", default = None, **kwargs): + values = { + "@bazel_tools//src/conditions:linux_x86_64": amd64, + "@bazel_tools//src/conditions:linux_aarch64": arm64, + } + if default: + values["//conditions:default"] = default + return select(values, **kwargs) + +def select_system(linux = ["__linux__"], **kwargs): + return linux # Only Linux supported. + +def default_installer(): + return None + +def default_net_util(): + return [] # Nothing needed. diff --git a/tools/checkunsafe/BUILD b/tools/checkunsafe/BUILD index d85c56131..92ba8ab06 100644 --- a/tools/checkunsafe/BUILD +++ b/tools/checkunsafe/BUILD @@ -1,11 +1,10 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_tool_library") +load("//tools:defs.bzl", "go_tool_library") package(licenses = ["notice"]) go_tool_library( name = "checkunsafe", srcs = ["check_unsafe.go"], - importpath = "checkunsafe", visibility = ["//visibility:public"], deps = [ "@org_golang_x_tools//go/analysis:go_tool_library", diff --git a/tools/defs.bzl b/tools/defs.bzl new file mode 100644 index 000000000..819f12b0d --- /dev/null +++ b/tools/defs.bzl @@ -0,0 +1,154 @@ +"""Wrappers for common build rules. + +These wrappers apply common BUILD configurations (e.g., proto_library +automagically creating cc_ and go_ proto targets) and act as a single point of +change for Google-internal and bazel-compatible rules. +""" + +load("//tools/go_stateify:defs.bzl", "go_stateify") +load("//tools/go_marshal:defs.bzl", "go_marshal", "marshal_deps", "marshal_test_deps") +load("//tools/build:defs.bzl", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _container_image = "container_image", _default_installer = "default_installer", _default_net_util = "default_net_util", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_image = "go_image", _go_library = "go_library", _go_proto_library = "go_proto_library", _go_test = "go_test", _go_tool_library = "go_tool_library", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _py_library = "py_library", _py_requirement = "py_requirement", _py_test = "py_test", _select_arch = "select_arch", _select_system = "select_system") + +# Delegate directly. +cc_binary = _cc_binary +cc_library = _cc_library +cc_test = _cc_test +cc_toolchain = _cc_toolchain +cc_flags_supplier = _cc_flags_supplier +container_image = _container_image +go_embed_data = _go_embed_data +go_image = _go_image +go_test = _go_test +go_tool_library = _go_tool_library +pkg_deb = _pkg_deb +pkg_tar = _pkg_tar +py_library = _py_library +py_binary = _py_binary +py_test = _py_test +py_requirement = _py_requirement +select_arch = _select_arch +select_system = _select_system +loopback = _loopback +default_installer = _default_installer +default_net_util = _default_net_util + +def go_binary(name, **kwargs): + """Wraps the standard go_binary. + + Args: + name: the rule name. + **kwargs: standard go_binary arguments. + """ + _go_binary( + name = name, + **kwargs + ) + +def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = False, **kwargs): + """Wraps the standard go_library and does stateification and marshalling. + + The recommended way is to use this rule with mostly identical configuration as the native + go_library rule. + + These definitions provide additional flags (stateify, marshal) that can be used + with the generators to automatically supplement the library code. + + load("//tools:defs.bzl", "go_library") + + go_library( + name = "foo", + srcs = ["foo.go"], + ) + + Args: + name: the rule name. + srcs: the library sources. + deps: the library dependencies. + imports: imports required for stateify. + stateify: whether statify is enabled (default: true). + marshal: whether marshal is enabled (default: false). + **kwargs: standard go_library arguments. + """ + if stateify: + # Only do stateification for non-state packages without manual autogen. + go_stateify( + name = name + "_state_autogen", + srcs = [src for src in srcs if src.endswith(".go")], + imports = imports, + package = name, + arch = select_arch(), + out = name + "_state_autogen.go", + ) + all_srcs = srcs + [name + "_state_autogen.go"] + if "//pkg/state" not in deps: + all_deps = deps + ["//pkg/state"] + else: + all_deps = deps + else: + all_deps = deps + all_srcs = srcs + if marshal: + go_marshal( + name = name + "_abi_autogen", + srcs = [src for src in srcs if src.endswith(".go")], + debug = False, + imports = imports, + package = name, + ) + extra_deps = [ + dep + for dep in marshal_deps + if not dep in all_deps + ] + all_deps = all_deps + extra_deps + all_srcs = srcs + [name + "_abi_autogen_unsafe.go"] + + _go_library( + name = name, + srcs = all_srcs, + deps = all_deps, + **kwargs + ) + + if marshal: + # Ignore importpath for go_test. + kwargs.pop("importpath", None) + + _go_test( + name = name + "_abi_autogen_test", + srcs = [name + "_abi_autogen_test.go"], + library = ":" + name, + deps = marshal_test_deps, + **kwargs + ) + +def proto_library(name, srcs, **kwargs): + """Wraps the standard proto_library. + + Given a proto_library named "foo", this produces three different targets: + - foo_proto: proto_library rule. + - foo_go_proto: go_proto_library rule. + - foo_cc_proto: cc_proto_library rule. + + Args: + srcs: the proto sources. + **kwargs: standard proto_library arguments. + """ + deps = kwargs.pop("deps", []) + _proto_library( + name = name + "_proto", + srcs = srcs, + deps = deps, + **kwargs + ) + _go_proto_library( + name = name + "_go_proto", + proto = ":" + name + "_proto", + deps = deps, + **kwargs + ) + _cc_proto_library( + name = name + "_cc_proto", + deps = [":" + name + "_proto"], + **kwargs + ) diff --git a/tools/go_generics/BUILD b/tools/go_generics/BUILD index 39318b877..069df3856 100644 --- a/tools/go_generics/BUILD +++ b/tools/go_generics/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("//tools:defs.bzl", "go_binary") package(licenses = ["notice"]) diff --git a/tools/go_generics/globals/BUILD b/tools/go_generics/globals/BUILD index 74853c7d2..38caa3ce7 100644 --- a/tools/go_generics/globals/BUILD +++ b/tools/go_generics/globals/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -8,6 +8,6 @@ go_library( "globals_visitor.go", "scope.go", ], - importpath = "gvisor.dev/gvisor/tools/go_generics/globals", + stateify = False, visibility = ["//tools/go_generics:__pkg__"], ) diff --git a/tools/go_generics/go_merge/BUILD b/tools/go_generics/go_merge/BUILD index 02b09120e..b7d35e272 100644 --- a/tools/go_generics/go_merge/BUILD +++ b/tools/go_generics/go_merge/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("//tools:defs.bzl", "go_binary") package(licenses = ["notice"]) diff --git a/tools/go_generics/rules_tests/BUILD b/tools/go_generics/rules_tests/BUILD index 9d26a88b7..8a329dfc6 100644 --- a/tools/go_generics/rules_tests/BUILD +++ b/tools/go_generics/rules_tests/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_test") load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") package(licenses = ["notice"]) diff --git a/tools/go_marshal/BUILD b/tools/go_marshal/BUILD index c862b277c..80d9c0504 100644 --- a/tools/go_marshal/BUILD +++ b/tools/go_marshal/BUILD @@ -1,6 +1,6 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("//tools:defs.bzl", "go_binary") -package(licenses = ["notice"]) +licenses(["notice"]) go_binary( name = "go_marshal", diff --git a/tools/go_marshal/README.md b/tools/go_marshal/README.md index 481575bd3..4886efddf 100644 --- a/tools/go_marshal/README.md +++ b/tools/go_marshal/README.md @@ -20,19 +20,7 @@ comment `// +marshal`. # Usage -See `defs.bzl`: two new rules are provided, `go_marshal` and `go_library`. - -The recommended way to generate a go library with marshalling is to use the -`go_library` with mostly identical configuration as the native go_library rule. - -``` -load("/gvisor/tools/go_marshal:defs.bzl", "go_library") - -go_library( - name = "foo", - srcs = ["foo.go"], -) -``` +See `defs.bzl`: a new rule is provided, `go_marshal`. Under the hood, the `go_marshal` rule is used to generate a file that will appear in a Go target; the output file should appear explicitly in a srcs list. @@ -54,11 +42,7 @@ go_library( "foo.go", "foo_abi.go", ], - deps = [ - "/gvisor/pkg/abi", - "/gvisor/pkg/sentry/safemem/safemem", - "/gvisor/pkg/sentry/usermem/usermem", - ], + ... ) ``` @@ -69,22 +53,6 @@ These tests use reflection to verify properties of the ABI struct, and should be considered part of the generated interfaces (but are too expensive to execute at runtime). Ensure these tests run at some point. -``` -$ cat BUILD -load("/gvisor/tools/go_marshal:defs.bzl", "go_library") - -go_library( - name = "foo", - srcs = ["foo.go"], -) -$ blaze build :foo -$ blaze query ... -:foo_abi_autogen -:foo_abi_autogen_test -$ blaze test :foo_abi_autogen_test - -``` - # Restrictions Not all valid go type definitions can be used with `go_marshal`. `go_marshal` is @@ -131,22 +99,6 @@ for embedded structs that are not aligned. Because of this, it's generally best to avoid using `marshal:"unaligned"` and insert explicit padding fields instead. -## Debugging go_marshal - -To enable debugging output from the go marshal tool, pass the `-debug` flag to -the tool. When using the build rules from above, add a `debug = True` field to -the build rule like this: - -``` -load("/gvisor/tools/go_marshal:defs.bzl", "go_library") - -go_library( - name = "foo", - srcs = ["foo.go"], - debug = True, -) -``` - ## Modifying the `go_marshal` Tool The following are some guidelines for modifying the `go_marshal` tool: diff --git a/tools/go_marshal/analysis/BUILD b/tools/go_marshal/analysis/BUILD index c859ced77..c2a4d45c4 100644 --- a/tools/go_marshal/analysis/BUILD +++ b/tools/go_marshal/analysis/BUILD @@ -1,12 +1,11 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") -package(licenses = ["notice"]) +licenses(["notice"]) go_library( name = "analysis", testonly = 1, srcs = ["analysis_unsafe.go"], - importpath = "gvisor.dev/gvisor/tools/go_marshal/analysis", visibility = [ "//:sandbox", ], diff --git a/tools/go_marshal/defs.bzl b/tools/go_marshal/defs.bzl index c32eb559f..2918ceffe 100644 --- a/tools/go_marshal/defs.bzl +++ b/tools/go_marshal/defs.bzl @@ -1,57 +1,14 @@ -"""Marshal is a tool for generating marshalling interfaces for Go types. - -The recommended way is to use the go_library rule defined below with mostly -identical configuration as the native go_library rule. - -load("//tools/go_marshal:defs.bzl", "go_library") - -go_library( - name = "foo", - srcs = ["foo.go"], -) - -Under the hood, the go_marshal rule is used to generate a file that will -appear in a Go target; the output file should appear explicitly in a srcs list. -For example (the above is still the preferred way): - -load("//tools/go_marshal:defs.bzl", "go_marshal") - -go_marshal( - name = "foo_abi", - srcs = ["foo.go"], - out = "foo_abi.go", - package = "foo", -) - -go_library( - name = "foo", - srcs = [ - "foo.go", - "foo_abi.go", - ], - deps = [ - "//tools/go_marshal:marshal", - "//pkg/sentry/platform/safecopy", - "//pkg/sentry/usermem", - ], -) -""" - -load("@io_bazel_rules_go//go:def.bzl", _go_library = "go_library", _go_test = "go_test") +"""Marshal is a tool for generating marshalling interfaces for Go types.""" def _go_marshal_impl(ctx): """Execute the go_marshal tool.""" output = ctx.outputs.lib output_test = ctx.outputs.test - (build_dir, _, _) = ctx.build_file_path.rpartition("/BUILD") - - decl = "/".join(["gvisor.dev/gvisor", build_dir]) # Run the marshal command. args = ["-output=%s" % output.path] args += ["-pkg=%s" % ctx.attr.package] args += ["-output_test=%s" % output_test.path] - args += ["-declarationPkg=%s" % decl] if ctx.attr.debug: args += ["-debug"] @@ -83,7 +40,6 @@ go_marshal = rule( implementation = _go_marshal_impl, attrs = { "srcs": attr.label_list(mandatory = True, allow_files = True), - "libname": attr.string(mandatory = True), "imports": attr.string_list(mandatory = False), "package": attr.string(mandatory = True), "debug": attr.bool(doc = "enable debugging output from the go_marshal tool"), @@ -95,58 +51,14 @@ go_marshal = rule( }, ) -def go_library(name, srcs, deps = [], imports = [], debug = False, **kwargs): - """wraps the standard go_library and does mashalling interface generation. - - Args: - name: Same as native go_library. - srcs: Same as native go_library. - deps: Same as native go_library. - imports: Extra import paths to pass to the go_marshal tool. - debug: Enables debugging output from the go_marshal tool. - **kwargs: Remaining args to pass to the native go_library rule unmodified. - """ - go_marshal( - name = name + "_abi_autogen", - libname = name, - srcs = [src for src in srcs if src.endswith(".go")], - debug = debug, - imports = imports, - package = name, - ) - - extra_deps = [ - "//tools/go_marshal/marshal", - "//pkg/sentry/platform/safecopy", - "//pkg/sentry/usermem", - ] - - all_srcs = srcs + [name + "_abi_autogen_unsafe.go"] - all_deps = deps + [] # + extra_deps - - for extra in extra_deps: - if extra not in deps: - all_deps.append(extra) - - _go_library( - name = name, - srcs = all_srcs, - deps = all_deps, - **kwargs - ) - - # Don't pass importpath arg to go_test. - kwargs.pop("importpath", "") - - _go_test( - name = name + "_abi_autogen_test", - srcs = [name + "_abi_autogen_test.go"], - # Generated test has a fixed set of dependencies since we generate these - # tests. They should only depend on the library generated above, and the - # Marshallable interface. - deps = [ - ":" + name, - "//tools/go_marshal/analysis", - ], - **kwargs - ) +# marshal_deps are the dependencies requied by generated code. +marshal_deps = [ + "//tools/go_marshal/marshal", + "//pkg/sentry/platform/safecopy", + "//pkg/sentry/usermem", +] + +# marshal_test_deps are required by test targets. +marshal_test_deps = [ + "//tools/go_marshal/analysis", +] diff --git a/tools/go_marshal/gomarshal/BUILD b/tools/go_marshal/gomarshal/BUILD index a0eae6492..c92b59dd6 100644 --- a/tools/go_marshal/gomarshal/BUILD +++ b/tools/go_marshal/gomarshal/BUILD @@ -1,6 +1,6 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") -package(licenses = ["notice"]) +licenses(["notice"]) go_library( name = "gomarshal", @@ -10,7 +10,7 @@ go_library( "generator_tests.go", "util.go", ], - importpath = "gvisor.dev/gvisor/tools/go_marshal/gomarshal", + stateify = False, visibility = [ "//:sandbox", ], diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go index 641ccd938..8392f3f6d 100644 --- a/tools/go_marshal/gomarshal/generator.go +++ b/tools/go_marshal/gomarshal/generator.go @@ -62,15 +62,12 @@ type Generator struct { outputTest *os.File // Package name for the generated file. pkg string - // Go import path for package we're processing. This package should directly - // declare the type we're generating code for. - declaration string // Set of extra packages to import in the generated file. imports *importTable } // NewGenerator creates a new code Generator. -func NewGenerator(srcs []string, out, outTest, pkg, declaration string, imports []string) (*Generator, error) { +func NewGenerator(srcs []string, out, outTest, pkg string, imports []string) (*Generator, error) { f, err := os.OpenFile(out, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) if err != nil { return nil, fmt.Errorf("Couldn't open output file %q: %v", out, err) @@ -80,12 +77,11 @@ func NewGenerator(srcs []string, out, outTest, pkg, declaration string, imports return nil, fmt.Errorf("Couldn't open test output file %q: %v", out, err) } g := Generator{ - inputs: srcs, - output: f, - outputTest: fTest, - pkg: pkg, - declaration: declaration, - imports: newImportTable(), + inputs: srcs, + output: f, + outputTest: fTest, + pkg: pkg, + imports: newImportTable(), } for _, i := range imports { // All imports on the extra imports list are unconditionally marked as @@ -264,7 +260,7 @@ func (g *Generator) generateOne(t *ast.TypeSpec, fset *token.FileSet) *interface // generateOneTestSuite generates a test suite for the automatically generated // implementations type t. func (g *Generator) generateOneTestSuite(t *ast.TypeSpec) *testGenerator { - i := newTestGenerator(t, g.declaration) + i := newTestGenerator(t) i.emitTests() return i } @@ -359,7 +355,7 @@ func (g *Generator) Run() error { // source file. func (g *Generator) writeTests(ts []*testGenerator) error { var b sourceBuffer - b.emit("package %s_test\n\n", g.pkg) + b.emit("package %s\n\n", g.pkg) if err := b.write(g.outputTest); err != nil { return err } diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go index df25cb5b2..bcda17c3b 100644 --- a/tools/go_marshal/gomarshal/generator_tests.go +++ b/tools/go_marshal/gomarshal/generator_tests.go @@ -46,7 +46,7 @@ type testGenerator struct { decl *importStmt } -func newTestGenerator(t *ast.TypeSpec, declaration string) *testGenerator { +func newTestGenerator(t *ast.TypeSpec) *testGenerator { if _, ok := t.Type.(*ast.StructType); !ok { panic(fmt.Sprintf("Attempting to generate code for a not struct type %v", t)) } @@ -59,14 +59,12 @@ func newTestGenerator(t *ast.TypeSpec, declaration string) *testGenerator { for _, i := range standardImports { g.imports.add(i).markUsed() } - g.decl = g.imports.add(declaration) - g.decl.markUsed() return g } func (g *testGenerator) typeName() string { - return fmt.Sprintf("%s.%s", g.decl.name, g.t.Name.Name) + return g.t.Name.Name } func (g *testGenerator) forEachField(fn func(f *ast.Field)) { diff --git a/tools/go_marshal/main.go b/tools/go_marshal/main.go index 3d12eb93c..e1a97b311 100644 --- a/tools/go_marshal/main.go +++ b/tools/go_marshal/main.go @@ -31,11 +31,10 @@ import ( ) var ( - pkg = flag.String("pkg", "", "output package") - output = flag.String("output", "", "output file") - outputTest = flag.String("output_test", "", "output file for tests") - imports = flag.String("imports", "", "comma-separated list of extra packages to import in generated code") - declarationPkg = flag.String("declarationPkg", "", "import path of target declaring the types we're generating on") + pkg = flag.String("pkg", "", "output package") + output = flag.String("output", "", "output file") + outputTest = flag.String("output_test", "", "output file for tests") + imports = flag.String("imports", "", "comma-separated list of extra packages to import in generated code") ) func main() { @@ -62,7 +61,7 @@ func main() { // as an import. extraImports = strings.Split(*imports, ",") } - g, err := gomarshal.NewGenerator(flag.Args(), *output, *outputTest, *pkg, *declarationPkg, extraImports) + g, err := gomarshal.NewGenerator(flag.Args(), *output, *outputTest, *pkg, extraImports) if err != nil { panic(err) } diff --git a/tools/go_marshal/marshal/BUILD b/tools/go_marshal/marshal/BUILD index 47dda97a1..ad508c72f 100644 --- a/tools/go_marshal/marshal/BUILD +++ b/tools/go_marshal/marshal/BUILD @@ -1,13 +1,12 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") -package(licenses = ["notice"]) +licenses(["notice"]) go_library( name = "marshal", srcs = [ "marshal.go", ], - importpath = "gvisor.dev/gvisor/tools/go_marshal/marshal", visibility = [ "//:sandbox", ], diff --git a/tools/go_marshal/test/BUILD b/tools/go_marshal/test/BUILD index d412e1ccf..38ba49fed 100644 --- a/tools/go_marshal/test/BUILD +++ b/tools/go_marshal/test/BUILD @@ -1,7 +1,6 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_marshal:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") -package(licenses = ["notice"]) +licenses(["notice"]) package_group( name = "gomarshal_test", @@ -25,6 +24,6 @@ go_library( name = "test", testonly = 1, srcs = ["test.go"], - importpath = "gvisor.dev/gvisor/tools/go_marshal/test", + marshal = True, deps = ["//tools/go_marshal/test/external"], ) diff --git a/tools/go_marshal/test/external/BUILD b/tools/go_marshal/test/external/BUILD index 9bb89e1da..0cf6da603 100644 --- a/tools/go_marshal/test/external/BUILD +++ b/tools/go_marshal/test/external/BUILD @@ -1,11 +1,11 @@ -load("//tools/go_marshal:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") -package(licenses = ["notice"]) +licenses(["notice"]) go_library( name = "external", testonly = 1, srcs = ["external.go"], - importpath = "gvisor.dev/gvisor/tools/go_marshal/test/external", + marshal = True, visibility = ["//tools/go_marshal/test:gomarshal_test"], ) diff --git a/tools/go_stateify/BUILD b/tools/go_stateify/BUILD index bb53f8ae9..a133d6f8b 100644 --- a/tools/go_stateify/BUILD +++ b/tools/go_stateify/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("//tools:defs.bzl", "go_binary") package(licenses = ["notice"]) diff --git a/tools/go_stateify/defs.bzl b/tools/go_stateify/defs.bzl index 33267c074..0f261d89f 100644 --- a/tools/go_stateify/defs.bzl +++ b/tools/go_stateify/defs.bzl @@ -1,41 +1,4 @@ -"""Stateify is a tool for generating state wrappers for Go types. - -The recommended way is to use the go_library rule defined below with mostly -identical configuration as the native go_library rule. - -load("//tools/go_stateify:defs.bzl", "go_library") - -go_library( - name = "foo", - srcs = ["foo.go"], -) - -Under the hood, the go_stateify rule is used to generate a file that will -appear in a Go target; the output file should appear explicitly in a srcs list. -For example (the above is still the preferred way): - -load("//tools/go_stateify:defs.bzl", "go_stateify") - -go_stateify( - name = "foo_state", - srcs = ["foo.go"], - out = "foo_state.go", - package = "foo", -) - -go_library( - name = "foo", - srcs = [ - "foo.go", - "foo_state.go", - ], - deps = [ - "//pkg/state", - ], -) -""" - -load("@io_bazel_rules_go//go:def.bzl", _go_library = "go_library") +"""Stateify is a tool for generating state wrappers for Go types.""" def _go_stateify_impl(ctx): """Implementation for the stateify tool.""" @@ -103,43 +66,3 @@ files and must be added to the srcs of the relevant go_library. "_statepkg": attr.string(default = "gvisor.dev/gvisor/pkg/state"), }, ) - -def go_library(name, srcs, deps = [], imports = [], **kwargs): - """Standard go_library wrapped which generates state source files. - - Args: - name: the name of the go_library rule. - srcs: sources of the go_library. Each will be processed for stateify - annotations. - deps: dependencies for the go_library. - imports: an optional list of extra non-aliased, Go-style absolute import - paths required for stateified types. - **kwargs: passed to go_library. - """ - if "encode_unsafe.go" not in srcs and (name + "_state_autogen.go") not in srcs: - # Only do stateification for non-state packages without manual autogen. - go_stateify( - name = name + "_state_autogen", - srcs = [src for src in srcs if src.endswith(".go")], - imports = imports, - package = name, - arch = select({ - "@bazel_tools//src/conditions:linux_aarch64": "arm64", - "//conditions:default": "amd64", - }), - out = name + "_state_autogen.go", - ) - all_srcs = srcs + [name + "_state_autogen.go"] - if "//pkg/state" not in deps: - all_deps = deps + ["//pkg/state"] - else: - all_deps = deps - else: - all_deps = deps - all_srcs = srcs - _go_library( - name = name, - srcs = all_srcs, - deps = all_deps, - **kwargs - ) diff --git a/tools/images/BUILD b/tools/images/BUILD index 2b77c2737..f1699b184 100644 --- a/tools/images/BUILD +++ b/tools/images/BUILD @@ -1,4 +1,4 @@ -load("@rules_cc//cc:defs.bzl", "cc_binary") +load("//tools:defs.bzl", "cc_binary") load("//tools/images:defs.bzl", "vm_image", "vm_test") package( diff --git a/tools/images/defs.bzl b/tools/images/defs.bzl index d8e422a5d..32235813a 100644 --- a/tools/images/defs.bzl +++ b/tools/images/defs.bzl @@ -28,6 +28,8 @@ The vm_test rule can be used to execute a command remotely. For example, ) """ +load("//tools:defs.bzl", "default_installer") + def _vm_image_impl(ctx): script_paths = [] for script in ctx.files.scripts: @@ -165,8 +167,8 @@ def vm_test( targets = kwargs.pop("targets", []) if installer: targets = [installer] + targets - targets = [ - ] + targets + if default_installer(): + targets = [default_installer()] + targets _vm_test( tags = [ "local", diff --git a/tools/issue_reviver/BUILD b/tools/issue_reviver/BUILD index ee7ea11fd..4ef1a3124 100644 --- a/tools/issue_reviver/BUILD +++ b/tools/issue_reviver/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("//tools:defs.bzl", "go_binary") package(licenses = ["notice"]) diff --git a/tools/issue_reviver/github/BUILD b/tools/issue_reviver/github/BUILD index 6da22ba1c..da4133472 100644 --- a/tools/issue_reviver/github/BUILD +++ b/tools/issue_reviver/github/BUILD @@ -1,11 +1,10 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "github", srcs = ["github.go"], - importpath = "gvisor.dev/gvisor/tools/issue_reviver/github", visibility = [ "//tools/issue_reviver:__subpackages__", ], diff --git a/tools/issue_reviver/reviver/BUILD b/tools/issue_reviver/reviver/BUILD index 2c3675977..d262932bd 100644 --- a/tools/issue_reviver/reviver/BUILD +++ b/tools/issue_reviver/reviver/BUILD @@ -1,11 +1,10 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "reviver", srcs = ["reviver.go"], - importpath = "gvisor.dev/gvisor/tools/issue_reviver/reviver", visibility = [ "//tools/issue_reviver:__subpackages__", ], @@ -15,5 +14,5 @@ go_test( name = "reviver_test", size = "small", srcs = ["reviver_test.go"], - embed = [":reviver"], + library = ":reviver", ) diff --git a/tools/workspace_status.sh b/tools/workspace_status.sh index fb09ff331..a22c8c9f2 100755 --- a/tools/workspace_status.sh +++ b/tools/workspace_status.sh @@ -15,4 +15,4 @@ # limitations under the License. # The STABLE_ prefix will trigger a re-link if it changes. -echo STABLE_VERSION $(git describe --always --tags --abbrev=12 --dirty) +echo STABLE_VERSION $(git describe --always --tags --abbrev=12 --dirty || echo 0.0.0) diff --git a/vdso/BUILD b/vdso/BUILD index 2b6744c26..d37d4266d 100644 --- a/vdso/BUILD +++ b/vdso/BUILD @@ -3,20 +3,10 @@ # normal system VDSO (time, gettimeofday, clock_gettimeofday) but which uses # timekeeping parameters managed by the sandbox kernel. -load("@bazel_tools//tools/cpp:cc_flags_supplier.bzl", "cc_flags_supplier") +load("//tools:defs.bzl", "cc_flags_supplier", "cc_toolchain", "select_arch") package(licenses = ["notice"]) -config_setting( - name = "x86_64", - constraint_values = ["@bazel_tools//platforms:x86_64"], -) - -config_setting( - name = "aarch64", - constraint_values = ["@bazel_tools//platforms:aarch64"], -) - genrule( name = "vdso", srcs = [ @@ -39,14 +29,15 @@ genrule( "-O2 " + "-std=c++11 " + "-fPIC " + + "-fno-sanitize=all " + # Some toolchains enable stack protector by default. Disable it, the # VDSO has no hooks to handle failures. "-fno-stack-protector " + "-fuse-ld=gold " + - select({ - ":x86_64": "-m64 ", - "//conditions:default": "", - }) + + select_arch( + amd64 = "-m64 ", + arm64 = "", + ) + "-shared " + "-nostdlib " + "-Wl,-soname=linux-vdso.so.1 " + @@ -55,12 +46,10 @@ genrule( "-Wl,-Bsymbolic " + "-Wl,-z,max-page-size=4096 " + "-Wl,-z,common-page-size=4096 " + - select( - { - ":x86_64": "-Wl,-T$(location vdso_amd64.lds) ", - ":aarch64": "-Wl,-T$(location vdso_arm64.lds) ", - }, - no_match_error = "Unsupported architecture", + select_arch( + amd64 = "-Wl,-T$(location vdso_amd64.lds) ", + arm64 = "-Wl,-T$(location vdso_arm64.lds) ", + no_match_error = "unsupported architecture", ) + "-o $(location vdso.so) " + "$(location vdso.cc) " + @@ -73,7 +62,7 @@ genrule( ], features = ["-pie"], toolchains = [ - "@bazel_tools//tools/cpp:current_cc_toolchain", + cc_toolchain, ":no_pie_cc_flags", ], visibility = ["//:sandbox"], -- cgit v1.2.3 From 6b14be4246e8ed3779bf69dbd59e669caf3f5704 Mon Sep 17 00:00:00 2001 From: Ting-Yu Wang Date: Mon, 27 Jan 2020 10:08:18 -0800 Subject: Refactor to hide C from channel.Endpoint. This is to aid later implementation for /dev/net/tun device. PiperOrigin-RevId: 291746025 --- pkg/tcpip/link/channel/channel.go | 43 ++++- pkg/tcpip/network/arp/arp_test.go | 16 +- pkg/tcpip/network/ipv6/icmp_test.go | 7 +- pkg/tcpip/stack/ndp_test.go | 87 +++++----- pkg/tcpip/stack/stack_test.go | 4 +- pkg/tcpip/stack/transport_test.go | 6 +- pkg/tcpip/transport/tcp/testing/context/context.go | 86 +++++----- pkg/tcpip/transport/udp/udp_test.go | 182 +++++++++++---------- 8 files changed, 229 insertions(+), 202 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index 70188551f..71b9da797 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -18,6 +18,8 @@ package channel import ( + "context" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -38,25 +40,52 @@ type Endpoint struct { linkAddr tcpip.LinkAddress GSO bool - // C is where outbound packets are queued. - C chan PacketInfo + // c is where outbound packets are queued. + c chan PacketInfo } // New creates a new channel endpoint. func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) *Endpoint { return &Endpoint{ - C: make(chan PacketInfo, size), + c: make(chan PacketInfo, size), mtu: mtu, linkAddr: linkAddr, } } +// Close closes e. Further packet injections will panic. Reads continue to +// succeed until all packets are read. +func (e *Endpoint) Close() { + close(e.c) +} + +// Read does non-blocking read for one packet from the outbound packet queue. +func (e *Endpoint) Read() (PacketInfo, bool) { + select { + case pkt := <-e.c: + return pkt, true + default: + return PacketInfo{}, false + } +} + +// ReadContext does blocking read for one packet from the outbound packet queue. +// It can be cancelled by ctx, and in this case, it returns false. +func (e *Endpoint) ReadContext(ctx context.Context) (PacketInfo, bool) { + select { + case pkt := <-e.c: + return pkt, true + case <-ctx.Done(): + return PacketInfo{}, false + } +} + // Drain removes all outbound packets from the channel and counts them. func (e *Endpoint) Drain() int { c := 0 for { select { - case <-e.C: + case <-e.c: c++ default: return c @@ -125,7 +154,7 @@ func (e *Endpoint) WritePacket(_ *stack.Route, gso *stack.GSO, protocol tcpip.Ne } select { - case e.C <- p: + case e.c <- p: default: } @@ -150,7 +179,7 @@ packetLoop: } select { - case e.C <- p: + case e.c <- p: n++ default: break packetLoop @@ -169,7 +198,7 @@ func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { } select { - case e.C <- p: + case e.c <- p: default: } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 8e6048a21..03cf03b6d 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -15,6 +15,7 @@ package arp_test import ( + "context" "strconv" "testing" "time" @@ -83,7 +84,7 @@ func newTestContext(t *testing.T) *testContext { } func (c *testContext) cleanup() { - close(c.linkEP.C) + c.linkEP.Close() } func TestDirectRequest(t *testing.T) { @@ -110,7 +111,7 @@ func TestDirectRequest(t *testing.T) { for i, address := range []tcpip.Address{stackAddr1, stackAddr2} { t.Run(strconv.Itoa(i), func(t *testing.T) { inject(address) - pi := <-c.linkEP.C + pi, _ := c.linkEP.ReadContext(context.Background()) if pi.Proto != arp.ProtocolNumber { t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto) } @@ -134,12 +135,11 @@ func TestDirectRequest(t *testing.T) { } inject(stackAddrBad) - select { - case pkt := <-c.linkEP.C: + // Sleep tests are gross, but this will only potentially flake + // if there's a bug. If there is no bug this will reliably + // succeed. + ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + if pkt, ok := c.linkEP.ReadContext(ctx); ok { t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto) - case <-time.After(100 * time.Millisecond): - // Sleep tests are gross, but this will only potentially flake - // if there's a bug. If there is no bug this will reliably - // succeed. } } diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index a2fdc5dcd..7a6820643 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -15,6 +15,7 @@ package ipv6 import ( + "context" "reflect" "strings" "testing" @@ -264,8 +265,8 @@ func newTestContext(t *testing.T) *testContext { } func (c *testContext) cleanup() { - close(c.linkEP0.C) - close(c.linkEP1.C) + c.linkEP0.Close() + c.linkEP1.Close() } type routeArgs struct { @@ -276,7 +277,7 @@ type routeArgs struct { func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.ICMPv6)) { t.Helper() - pi := <-args.src.C + pi, _ := args.src.ReadContext(context.Background()) { views := []buffer.View{pi.Pkt.Header.View(), pi.Pkt.Data.ToView()} diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index f9460bd51..ad2c6f601 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -15,6 +15,7 @@ package stack_test import ( + "context" "encoding/binary" "fmt" "testing" @@ -405,7 +406,7 @@ func TestDADResolve(t *testing.T) { // Validate the sent Neighbor Solicitation messages. for i := uint8(0); i < test.dupAddrDetectTransmits; i++ { - p := <-e.C + p, _ := e.ReadContext(context.Background()) // Make sure its an IPv6 packet. if p.Proto != header.IPv6ProtocolNumber { @@ -3285,29 +3286,29 @@ func TestRouterSolicitation(t *testing.T) { e := channel.New(int(test.maxRtrSolicit), 1280, linkAddr1) waitForPkt := func(timeout time.Duration) { t.Helper() - select { - case p := <-e.C: - if p.Proto != header.IPv6ProtocolNumber { - t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) - } - checker.IPv6(t, - p.Pkt.Header.View(), - checker.SrcAddr(header.IPv6Any), - checker.DstAddr(header.IPv6AllRoutersMulticastAddress), - checker.TTL(header.NDPHopLimit), - checker.NDPRS(), - ) - - case <-time.After(timeout): + ctx, _ := context.WithTimeout(context.Background(), timeout) + p, ok := e.ReadContext(ctx) + if !ok { t.Fatal("timed out waiting for packet") + return } + + if p.Proto != header.IPv6ProtocolNumber { + t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) + } + checker.IPv6(t, + p.Pkt.Header.View(), + checker.SrcAddr(header.IPv6Any), + checker.DstAddr(header.IPv6AllRoutersMulticastAddress), + checker.TTL(header.NDPHopLimit), + checker.NDPRS(), + ) } waitForNothing := func(timeout time.Duration) { t.Helper() - select { - case <-e.C: + ctx, _ := context.WithTimeout(context.Background(), timeout) + if _, ok := e.ReadContext(ctx); ok { t.Fatal("unexpectedly got a packet") - case <-time.After(timeout): } } s := stack.New(stack.Options{ @@ -3362,20 +3363,21 @@ func TestStopStartSolicitingRouters(t *testing.T) { e := channel.New(maxRtrSolicitations, 1280, linkAddr1) waitForPkt := func(timeout time.Duration) { t.Helper() - select { - case p := <-e.C: - if p.Proto != header.IPv6ProtocolNumber { - t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) - } - checker.IPv6(t, p.Pkt.Header.View(), - checker.SrcAddr(header.IPv6Any), - checker.DstAddr(header.IPv6AllRoutersMulticastAddress), - checker.TTL(header.NDPHopLimit), - checker.NDPRS()) - - case <-time.After(timeout): + ctx, _ := context.WithTimeout(context.Background(), timeout) + p, ok := e.ReadContext(ctx) + if !ok { t.Fatal("timed out waiting for packet") + return } + + if p.Proto != header.IPv6ProtocolNumber { + t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) + } + checker.IPv6(t, p.Pkt.Header.View(), + checker.SrcAddr(header.IPv6Any), + checker.DstAddr(header.IPv6AllRoutersMulticastAddress), + checker.TTL(header.NDPHopLimit), + checker.NDPRS()) } s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, @@ -3391,23 +3393,20 @@ func TestStopStartSolicitingRouters(t *testing.T) { // Enable forwarding which should stop router solicitations. s.SetForwarding(true) - select { - case <-e.C: + ctx, _ := context.WithTimeout(context.Background(), delay+defaultTimeout) + if _, ok := e.ReadContext(ctx); ok { // A single RS may have been sent before forwarding was enabled. - select { - case <-e.C: + ctx, _ = context.WithTimeout(context.Background(), interval+defaultTimeout) + if _, ok = e.ReadContext(ctx); ok { t.Fatal("Should not have sent more than one RS message") - case <-time.After(interval + defaultTimeout): } - case <-time.After(delay + defaultTimeout): } // Enabling forwarding again should do nothing. s.SetForwarding(true) - select { - case <-e.C: + ctx, _ = context.WithTimeout(context.Background(), delay+defaultTimeout) + if _, ok := e.ReadContext(ctx); ok { t.Fatal("unexpectedly got a packet after becoming a router") - case <-time.After(delay + defaultTimeout): } // Disable forwarding which should start router solicitations. @@ -3415,17 +3414,15 @@ func TestStopStartSolicitingRouters(t *testing.T) { waitForPkt(delay + defaultAsyncEventTimeout) waitForPkt(interval + defaultAsyncEventTimeout) waitForPkt(interval + defaultAsyncEventTimeout) - select { - case <-e.C: + ctx, _ = context.WithTimeout(context.Background(), interval+defaultTimeout) + if _, ok := e.ReadContext(ctx); ok { t.Fatal("unexpectedly got an extra packet after sending out the expected RSs") - case <-time.After(interval + defaultTimeout): } // Disabling forwarding again should do nothing. s.SetForwarding(false) - select { - case <-e.C: + ctx, _ = context.WithTimeout(context.Background(), delay+defaultTimeout) + if _, ok := e.ReadContext(ctx); ok { t.Fatal("unexpectedly got a packet after becoming a router") - case <-time.After(delay + defaultTimeout): } } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index dad288642..834fe9487 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -1880,9 +1880,7 @@ func TestNICForwarding(t *testing.T) { Data: buf.ToVectorisedView(), }) - select { - case <-ep2.C: - default: + if _, ok := ep2.Read(); !ok { t.Fatal("Packet not forwarded") } diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index f50604a8a..869c69a6d 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -623,10 +623,8 @@ func TestTransportForwarding(t *testing.T) { t.Fatalf("Write failed: %v", err) } - var p channel.PacketInfo - select { - case p = <-ep2.C: - default: + p, ok := ep2.Read() + if !ok { t.Fatal("Response packet not forwarded") } diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 822907998..730ac4292 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -18,6 +18,7 @@ package context import ( "bytes" + "context" "testing" "time" @@ -215,11 +216,9 @@ func (c *Context) Stack() *stack.Stack { func (c *Context) CheckNoPacketTimeout(errMsg string, wait time.Duration) { c.t.Helper() - select { - case <-c.linkEP.C: + ctx, _ := context.WithTimeout(context.Background(), wait) + if _, ok := c.linkEP.ReadContext(ctx); ok { c.t.Fatal(errMsg) - - case <-time.After(wait): } } @@ -234,27 +233,27 @@ func (c *Context) CheckNoPacket(errMsg string) { // 2 seconds. func (c *Context) GetPacket() []byte { c.t.Helper() - select { - case p := <-c.linkEP.C: - if p.Proto != ipv4.ProtocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) - } - hdr := p.Pkt.Header.View() - b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) + ctx, _ := context.WithTimeout(context.Background(), 2*time.Second) + p, ok := c.linkEP.ReadContext(ctx) + if !ok { + c.t.Fatalf("Packet wasn't written out") + return nil + } - if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize { - c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize) - } + if p.Proto != ipv4.ProtocolNumber { + c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) + } - checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) - return b + hdr := p.Pkt.Header.View() + b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) - case <-time.After(2 * time.Second): - c.t.Fatalf("Packet wasn't written out") + if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize { + c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize) } - return nil + checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) + return b } // GetPacketNonBlocking reads a packet from the link layer endpoint @@ -263,20 +262,21 @@ func (c *Context) GetPacket() []byte { // nil immediately. func (c *Context) GetPacketNonBlocking() []byte { c.t.Helper() - select { - case p := <-c.linkEP.C: - if p.Proto != ipv4.ProtocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) - } - hdr := p.Pkt.Header.View() - b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) - - checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) - return b - default: + p, ok := c.linkEP.Read() + if !ok { return nil } + + if p.Proto != ipv4.ProtocolNumber { + c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) + } + + hdr := p.Pkt.Header.View() + b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) + + checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) + return b } // SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint. @@ -484,23 +484,23 @@ func (c *Context) CreateV6Endpoint(v6only bool) { // and asserts that it is an IPv6 Packet with the expected src/dest addresses. func (c *Context) GetV6Packet() []byte { c.t.Helper() - select { - case p := <-c.linkEP.C: - if p.Proto != ipv6.ProtocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber) - } - b := make([]byte, p.Pkt.Header.UsedLength()+p.Pkt.Data.Size()) - copy(b, p.Pkt.Header.View()) - copy(b[p.Pkt.Header.UsedLength():], p.Pkt.Data.ToView()) - - checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr)) - return b - case <-time.After(2 * time.Second): + ctx, _ := context.WithTimeout(context.Background(), 2*time.Second) + p, ok := c.linkEP.ReadContext(ctx) + if !ok { c.t.Fatalf("Packet wasn't written out") + return nil + } + + if p.Proto != ipv6.ProtocolNumber { + c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber) } + b := make([]byte, p.Pkt.Header.UsedLength()+p.Pkt.Data.Size()) + copy(b, p.Pkt.Header.View()) + copy(b[p.Pkt.Header.UsedLength():], p.Pkt.Data.ToView()) - return nil + checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr)) + return b } // SendV6Packet builds and sends an IPv6 Packet via the link layer endpoint of diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index c6927cfe3..f0ff3fe71 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -16,6 +16,7 @@ package udp_test import ( "bytes" + "context" "fmt" "math/rand" "testing" @@ -357,30 +358,29 @@ func (c *testContext) createEndpointForFlow(flow testFlow) { func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte { c.t.Helper() - select { - case p := <-c.linkEP.C: - if p.Proto != flow.netProto() { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto()) - } - - hdr := p.Pkt.Header.View() - b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) - - h := flow.header4Tuple(outgoing) - checkers := append( - checkers, - checker.SrcAddr(h.srcAddr.Addr), - checker.DstAddr(h.dstAddr.Addr), - checker.UDP(checker.DstPort(h.dstAddr.Port)), - ) - flow.checkerFn()(c.t, b, checkers...) - return b - - case <-time.After(2 * time.Second): + ctx, _ := context.WithTimeout(context.Background(), 2*time.Second) + p, ok := c.linkEP.ReadContext(ctx) + if !ok { c.t.Fatalf("Packet wasn't written out") + return nil } - return nil + if p.Proto != flow.netProto() { + c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto()) + } + + hdr := p.Pkt.Header.View() + b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) + + h := flow.header4Tuple(outgoing) + checkers = append( + checkers, + checker.SrcAddr(h.srcAddr.Addr), + checker.DstAddr(h.dstAddr.Addr), + checker.UDP(checker.DstPort(h.dstAddr.Port)), + ) + flow.checkerFn()(c.t, b, checkers...) + return b } // injectPacket creates a packet of the given flow and with the given payload, @@ -1541,48 +1541,50 @@ func TestV4UnknownDestination(t *testing.T) { } c.injectPacket(tc.flow, payload) if !tc.icmpRequired { - select { - case p := <-c.linkEP.C: + ctx, _ := context.WithTimeout(context.Background(), time.Second) + if p, ok := c.linkEP.ReadContext(ctx); ok { t.Fatalf("unexpected packet received: %+v", p) - case <-time.After(1 * time.Second): - return } + return } - select { - case p := <-c.linkEP.C: - var pkt []byte - pkt = append(pkt, p.Pkt.Header.View()...) - pkt = append(pkt, p.Pkt.Data.ToView()...) - if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want { - t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) - } + // ICMP required. + ctx, _ := context.WithTimeout(context.Background(), time.Second) + p, ok := c.linkEP.ReadContext(ctx) + if !ok { + t.Fatalf("packet wasn't written out") + return + } - hdr := header.IPv4(pkt) - checker.IPv4(t, hdr, checker.ICMPv4( - checker.ICMPv4Type(header.ICMPv4DstUnreachable), - checker.ICMPv4Code(header.ICMPv4PortUnreachable))) + var pkt []byte + pkt = append(pkt, p.Pkt.Header.View()...) + pkt = append(pkt, p.Pkt.Data.ToView()...) + if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want { + t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) + } - icmpPkt := header.ICMPv4(hdr.Payload()) - payloadIPHeader := header.IPv4(icmpPkt.Payload()) - wantLen := len(payload) - if tc.largePayload { - wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize - } + hdr := header.IPv4(pkt) + checker.IPv4(t, hdr, checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4DstUnreachable), + checker.ICMPv4Code(header.ICMPv4PortUnreachable))) - // In case of large payloads the IP packet may be truncated. Update - // the length field before retrieving the udp datagram payload. - payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize)) + icmpPkt := header.ICMPv4(hdr.Payload()) + payloadIPHeader := header.IPv4(icmpPkt.Payload()) + wantLen := len(payload) + if tc.largePayload { + wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize + } - origDgram := header.UDP(payloadIPHeader.Payload()) - if got, want := len(origDgram.Payload()), wantLen; got != want { - t.Fatalf("unexpected payload length got: %d, want: %d", got, want) - } - if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { - t.Fatalf("unexpected payload got: %d, want: %d", got, want) - } - case <-time.After(1 * time.Second): - t.Fatalf("packet wasn't written out") + // In case of large payloads the IP packet may be truncated. Update + // the length field before retrieving the udp datagram payload. + payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize)) + + origDgram := header.UDP(payloadIPHeader.Payload()) + if got, want := len(origDgram.Payload()), wantLen; got != want { + t.Fatalf("unexpected payload length got: %d, want: %d", got, want) + } + if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { + t.Fatalf("unexpected payload got: %d, want: %d", got, want) } }) } @@ -1615,47 +1617,49 @@ func TestV6UnknownDestination(t *testing.T) { } c.injectPacket(tc.flow, payload) if !tc.icmpRequired { - select { - case p := <-c.linkEP.C: + ctx, _ := context.WithTimeout(context.Background(), time.Second) + if p, ok := c.linkEP.ReadContext(ctx); ok { t.Fatalf("unexpected packet received: %+v", p) - case <-time.After(1 * time.Second): - return } + return } - select { - case p := <-c.linkEP.C: - var pkt []byte - pkt = append(pkt, p.Pkt.Header.View()...) - pkt = append(pkt, p.Pkt.Data.ToView()...) - if got, want := len(pkt), header.IPv6MinimumMTU; got > want { - t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) - } + // ICMP required. + ctx, _ := context.WithTimeout(context.Background(), time.Second) + p, ok := c.linkEP.ReadContext(ctx) + if !ok { + t.Fatalf("packet wasn't written out") + return + } + + var pkt []byte + pkt = append(pkt, p.Pkt.Header.View()...) + pkt = append(pkt, p.Pkt.Data.ToView()...) + if got, want := len(pkt), header.IPv6MinimumMTU; got > want { + t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) + } - hdr := header.IPv6(pkt) - checker.IPv6(t, hdr, checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6DstUnreachable), - checker.ICMPv6Code(header.ICMPv6PortUnreachable))) + hdr := header.IPv6(pkt) + checker.IPv6(t, hdr, checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6DstUnreachable), + checker.ICMPv6Code(header.ICMPv6PortUnreachable))) - icmpPkt := header.ICMPv6(hdr.Payload()) - payloadIPHeader := header.IPv6(icmpPkt.Payload()) - wantLen := len(payload) - if tc.largePayload { - wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize - } - // In case of large payloads the IP packet may be truncated. Update - // the length field before retrieving the udp datagram payload. - payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize)) + icmpPkt := header.ICMPv6(hdr.Payload()) + payloadIPHeader := header.IPv6(icmpPkt.Payload()) + wantLen := len(payload) + if tc.largePayload { + wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize + } + // In case of large payloads the IP packet may be truncated. Update + // the length field before retrieving the udp datagram payload. + payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize)) - origDgram := header.UDP(payloadIPHeader.Payload()) - if got, want := len(origDgram.Payload()), wantLen; got != want { - t.Fatalf("unexpected payload length got: %d, want: %d", got, want) - } - if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { - t.Fatalf("unexpected payload got: %v, want: %v", got, want) - } - case <-time.After(1 * time.Second): - t.Fatalf("packet wasn't written out") + origDgram := header.UDP(payloadIPHeader.Payload()) + if got, want := len(origDgram.Payload()), wantLen; got != want { + t.Fatalf("unexpected payload length got: %d, want: %d", got, want) + } + if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { + t.Fatalf("unexpected payload got: %v, want: %v", got, want) } }) } -- cgit v1.2.3 From ce0bac4be9d808877248c328fac07ff0d66b9607 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Tue, 28 Jan 2020 13:37:10 -0800 Subject: Include the NDP Source Link Layer option when sending DAD messages Test: stack_test.TestDADResolve PiperOrigin-RevId: 292003124 --- pkg/tcpip/checker/checker.go | 50 +++++++++++++++++++++++++++++++++++++++++ pkg/tcpip/header/icmpv6.go | 2 +- pkg/tcpip/header/ndp_options.go | 35 ++++++++++++++++++++++++++--- pkg/tcpip/stack/ndp.go | 22 ++++++++++++++++-- pkg/tcpip/stack/ndp_test.go | 6 ++++- 5 files changed, 108 insertions(+), 7 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 885d773b0..4d6ae0871 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -771,6 +771,56 @@ func NDPNSTargetAddress(want tcpip.Address) TransportChecker { } } +// NDPNSOptions creates a checker that checks that the packet contains the +// provided NDP options within an NDP Neighbor Solicitation message. +// +// The returned TransportChecker assumes that a valid ICMPv6 is passed to it +// containing a valid NDPNS message as far as the size is concerned. +func NDPNSOptions(opts []header.NDPOption) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmp := h.(header.ICMPv6) + ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + it, err := ns.Options().Iter(true) + if err != nil { + t.Errorf("opts.Iter(true): %s", err) + return + } + + i := 0 + for { + opt, done, _ := it.Next() + if done { + break + } + + if i >= len(opts) { + t.Errorf("got unexpected option: %s", opt) + continue + } + + switch wantOpt := opts[i].(type) { + case header.NDPSourceLinkLayerAddressOption: + gotOpt, ok := opt.(header.NDPSourceLinkLayerAddressOption) + if !ok { + t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt) + } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want { + t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want) + } + default: + panic("not implemented") + } + + i++ + } + + if missing := opts[i:]; len(missing) > 0 { + t.Errorf("missing options: %s", missing) + } + } +} + // NDPRS creates a checker that checks that the packet contains a valid NDP // Router Solicitation message (as per the raw wire format). func NDPRS() NetworkChecker { diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go index b095dc0ab..c7ee2de57 100644 --- a/pkg/tcpip/header/icmpv6.go +++ b/pkg/tcpip/header/icmpv6.go @@ -52,7 +52,7 @@ const ( // ICMPv6NeighborAdvertSize is size of a neighbor advertisement // including the NDP Target Link Layer option for an Ethernet // address. - ICMPv6NeighborAdvertSize = ICMPv6HeaderSize + NDPNAMinimumSize + ndpLinkLayerAddressSize + ICMPv6NeighborAdvertSize = ICMPv6HeaderSize + NDPNAMinimumSize + NDPLinkLayerAddressSize // ICMPv6EchoMinimumSize is the minimum size of a valid ICMP echo packet. ICMPv6EchoMinimumSize = 8 diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go index 1e60f3d4f..e6a6ad39b 100644 --- a/pkg/tcpip/header/ndp_options.go +++ b/pkg/tcpip/header/ndp_options.go @@ -17,6 +17,7 @@ package header import ( "encoding/binary" "errors" + "fmt" "math" "time" @@ -32,9 +33,9 @@ const ( // Address option, as per RFC 4861 section 4.6.1. NDPTargetLinkLayerAddressOptionType = 2 - // ndpLinkLayerAddressSize is the size of a Source or Target Link Layer - // Address option. - ndpLinkLayerAddressSize = 8 + // NDPLinkLayerAddressSize is the size of a Source or Target Link Layer + // Address option for an Ethernet address. + NDPLinkLayerAddressSize = 8 // NDPPrefixInformationType is the type of the Prefix Information // option, as per RFC 4861 section 4.6.2. @@ -300,6 +301,8 @@ func (b NDPOptions) Serialize(s NDPOptionsSerializer) int { // NDPOption is the set of functions to be implemented by all NDP option types. type NDPOption interface { + fmt.Stringer + // Type returns the type of the receiver. Type() uint8 @@ -397,6 +400,11 @@ func (o NDPSourceLinkLayerAddressOption) serializeInto(b []byte) int { return copy(b, o) } +// String implements fmt.Stringer.String. +func (o NDPSourceLinkLayerAddressOption) String() string { + return fmt.Sprintf("%T(%s)", o, tcpip.LinkAddress(o)) +} + // EthernetAddress will return an ethernet (MAC) address if the // NDPSourceLinkLayerAddressOption's body has at minimum EthernetAddressSize // bytes. If the body has more than EthernetAddressSize bytes, only the first @@ -432,6 +440,11 @@ func (o NDPTargetLinkLayerAddressOption) serializeInto(b []byte) int { return copy(b, o) } +// String implements fmt.Stringer.String. +func (o NDPTargetLinkLayerAddressOption) String() string { + return fmt.Sprintf("%T(%s)", o, tcpip.LinkAddress(o)) +} + // EthernetAddress will return an ethernet (MAC) address if the // NDPTargetLinkLayerAddressOption's body has at minimum EthernetAddressSize // bytes. If the body has more than EthernetAddressSize bytes, only the first @@ -478,6 +491,17 @@ func (o NDPPrefixInformation) serializeInto(b []byte) int { return used } +// String implements fmt.Stringer.String. +func (o NDPPrefixInformation) String() string { + return fmt.Sprintf("%T(O=%t, A=%t, PL=%s, VL=%s, Prefix=%s)", + o, + o.OnLinkFlag(), + o.AutonomousAddressConfigurationFlag(), + o.PreferredLifetime(), + o.ValidLifetime(), + o.Subnet()) +} + // PrefixLength returns the value in the number of leading bits in the Prefix // that are valid. // @@ -587,6 +611,11 @@ func (o NDPRecursiveDNSServer) serializeInto(b []byte) int { return used } +// String implements fmt.Stringer.String. +func (o NDPRecursiveDNSServer) String() string { + return fmt.Sprintf("%T(%s valid for %s)", o, o.Addresses(), o.Lifetime()) +} + // Lifetime returns the length of time that the DNS server addresses // in this option may be used for name resolution. // diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index d983ac390..245694118 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -538,11 +538,29 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address) *tcpip.Error { r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false) defer r.Release() - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborSolicitMinimumSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize)) + linkAddr := ndp.nic.linkEP.LinkAddress() + isValidLinkAddr := header.IsValidUnicastEthernetAddress(linkAddr) + ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + if isValidLinkAddr { + // Only include a Source Link Layer Address option if the NIC has a valid + // link layer address. + // + // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by + // LinkEndpoint.LinkAddress) before reaching this point. + ndpNSSize += header.NDPLinkLayerAddressSize + } + + hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + ndpNSSize) + pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) pkt.SetType(header.ICMPv6NeighborSolicit) ns := header.NDPNeighborSolicit(pkt.NDPPayload()) ns.SetTargetAddress(addr) + + if isValidLinkAddr { + ns.Options().Serialize(header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(linkAddr), + }) + } pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) sent := r.Stats().ICMP.V6PacketsSent diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index ad2c6f601..726468e41 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -417,7 +417,11 @@ func TestDADResolve(t *testing.T) { checker.IPv6(t, p.Pkt.Header.View().ToVectorisedView().First(), checker.TTL(header.NDPHopLimit), checker.NDPNS( - checker.NDPNSTargetAddress(addr1))) + checker.NDPNSTargetAddress(addr1), + checker.NDPNSOptions([]header.NDPOption{ + header.NDPSourceLinkLayerAddressOption(linkAddr1), + }), + )) } }) } -- cgit v1.2.3 From 6f841c304d7bd9af6167d7d049bd5c594358a1b9 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Wed, 29 Jan 2020 19:54:11 -0800 Subject: Do not spawn a goroutine when calling stack.NDPDispatcher's methods Do not start a new goroutine when calling stack.NDPDispatcher.OnDuplicateAddressDetectionStatus. PiperOrigin-RevId: 292268574 --- pkg/tcpip/stack/ndp.go | 8 ++++---- pkg/tcpip/stack/ndp_test.go | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 245694118..281ae786d 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -167,8 +167,8 @@ type NDPDispatcher interface { // reason, such as the address being removed). If an error occured // during DAD, err will be set and resolved must be ignored. // - // This function is permitted to block indefinitely without interfering - // with the stack's operation. + // This function is not permitted to block indefinitely. This function + // is also not permitted to call into the stack. OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) // OnDefaultRouterDiscovered will be called when a new default router is @@ -607,8 +607,8 @@ func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) { delete(ndp.dad, addr) // Let the integrator know DAD did not resolve. - if ndp.nic.stack.ndpDisp != nil { - go ndp.nic.stack.ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, false, nil) + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, false, nil) } } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 726468e41..8c76e80f2 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -497,7 +497,7 @@ func TestDADFail(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent), + dadC: make(chan ndpDADEvent, 1), } ndpConfigs := stack.DefaultNDPConfigurations() opts := stack.Options{ @@ -576,7 +576,7 @@ func TestDADFail(t *testing.T) { // removed. func TestDADStop(t *testing.T) { ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent), + dadC: make(chan ndpDADEvent, 1), } ndpConfigs := stack.NDPConfigurations{ RetransmitTimer: time.Second, -- cgit v1.2.3 From ec0679737e8f9ab31ef6c7c3adb5a0005586b5a7 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Thu, 30 Jan 2020 07:12:04 -0800 Subject: Do not include the Source Link Layer option with an unspecified source address When sending NDP messages with an unspecified source address, the Source Link Layer address must not be included. Test: stack_test.TestDADResolve PiperOrigin-RevId: 292341334 --- pkg/tcpip/stack/ndp.go | 22 ++-------------------- pkg/tcpip/stack/ndp_test.go | 12 ++++++++---- 2 files changed, 10 insertions(+), 24 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 281ae786d..31294345d 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -538,29 +538,11 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address) *tcpip.Error { r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false) defer r.Release() - linkAddr := ndp.nic.linkEP.LinkAddress() - isValidLinkAddr := header.IsValidUnicastEthernetAddress(linkAddr) - ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize - if isValidLinkAddr { - // Only include a Source Link Layer Address option if the NIC has a valid - // link layer address. - // - // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by - // LinkEndpoint.LinkAddress) before reaching this point. - ndpNSSize += header.NDPLinkLayerAddressSize - } - - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + ndpNSSize) - pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) + hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborSolicitMinimumSize) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize)) pkt.SetType(header.ICMPv6NeighborSolicit) ns := header.NDPNeighborSolicit(pkt.NDPPayload()) ns.SetTargetAddress(addr) - - if isValidLinkAddr { - ns.Options().Serialize(header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(linkAddr), - }) - } pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) sent := r.Stats().ICMP.V6PacketsSent diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 8c76e80f2..bc7cfbcb4 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -413,14 +413,18 @@ func TestDADResolve(t *testing.T) { t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) } - // Check NDP packet. + // Check NDP NS packet. + // + // As per RFC 4861 section 4.3, a possible option is the Source Link + // Layer option, but this option MUST NOT be included when the source + // address of the packet is the unspecified address. checker.IPv6(t, p.Pkt.Header.View().ToVectorisedView().First(), + checker.SrcAddr(header.IPv6Any), + checker.DstAddr(header.SolicitedNodeAddr(addr1)), checker.TTL(header.NDPHopLimit), checker.NDPNS( checker.NDPNSTargetAddress(addr1), - checker.NDPNSOptions([]header.NDPOption{ - header.NDPSourceLinkLayerAddressOption(linkAddr1), - }), + checker.NDPNSOptions(nil), )) } }) -- cgit v1.2.3 From 77bf586db75b3dbd9dcb14c349bde8372d26425c Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Fri, 31 Jan 2020 13:54:57 -0800 Subject: Use multicast Ethernet address for multicast NDP As per RFC 2464 section 7, an IPv6 packet with a multicast destination address is transmitted to the mapped Ethernet multicast address. Test: - ipv6.TestLinkResolution - stack_test.TestDADResolve - stack_test.TestRouterSolicitation PiperOrigin-RevId: 292610529 --- pkg/tcpip/header/ipv6_test.go | 29 ++++++++++++++++++++++ pkg/tcpip/link/channel/channel.go | 29 ++++++++++++++-------- pkg/tcpip/network/ipv6/icmp.go | 6 ++++- pkg/tcpip/network/ipv6/icmp_test.go | 12 ++++++--- pkg/tcpip/stack/ndp.go | 17 +++++++++++++ pkg/tcpip/stack/ndp_test.go | 16 +++++++++++- pkg/tcpip/stack/route.go | 4 ++- pkg/tcpip/transport/tcp/testing/context/context.go | 6 ++++- 8 files changed, 101 insertions(+), 18 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go index 29f54bc57..c3ad503aa 100644 --- a/pkg/tcpip/header/ipv6_test.go +++ b/pkg/tcpip/header/ipv6_test.go @@ -17,6 +17,7 @@ package header_test import ( "bytes" "crypto/sha256" + "fmt" "testing" "github.com/google/go-cmp/cmp" @@ -300,3 +301,31 @@ func TestScopeForIPv6Address(t *testing.T) { }) } } + +func TestSolicitedNodeAddr(t *testing.T) { + tests := []struct { + addr tcpip.Address + want tcpip.Address + }{ + { + addr: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\xa0", + want: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff\x0e\x0f\xa0", + }, + { + addr: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\xdd\x0e\x0f\xa0", + want: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff\x0e\x0f\xa0", + }, + { + addr: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\xdd\x01\x02\x03", + want: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff\x01\x02\x03", + }, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("%s", test.addr), func(t *testing.T) { + if got := header.SolicitedNodeAddr(test.addr); got != test.want { + t.Fatalf("got header.SolicitedNodeAddr(%s) = %s, want = %s", test.addr, got, test.want) + } + }) + } +} diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index 71b9da797..78d447acd 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -30,15 +30,16 @@ type PacketInfo struct { Pkt tcpip.PacketBuffer Proto tcpip.NetworkProtocolNumber GSO *stack.GSO + Route stack.Route } // Endpoint is link layer endpoint that stores outbound packets in a channel // and allows injection of inbound packets. type Endpoint struct { - dispatcher stack.NetworkDispatcher - mtu uint32 - linkAddr tcpip.LinkAddress - GSO bool + dispatcher stack.NetworkDispatcher + mtu uint32 + linkAddr tcpip.LinkAddress + LinkEPCapabilities stack.LinkEndpointCapabilities // c is where outbound packets are queued. c chan PacketInfo @@ -122,11 +123,7 @@ func (e *Endpoint) MTU() uint32 { // Capabilities implements stack.LinkEndpoint.Capabilities. func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities { - caps := stack.LinkEndpointCapabilities(0) - if e.GSO { - caps |= stack.CapabilityHardwareGSO - } - return caps + return e.LinkEPCapabilities } // GSOMaxSize returns the maximum GSO packet size. @@ -146,11 +143,16 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { } // WritePacket stores outbound packets into the channel. -func (e *Endpoint) WritePacket(_ *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { + // Clone r then release its resource so we only get the relevant fields from + // stack.Route without holding a reference to a NIC's endpoint. + route := r.Clone() + route.Release() p := PacketInfo{ Pkt: pkt, Proto: protocol, GSO: gso, + Route: route, } select { @@ -162,7 +164,11 @@ func (e *Endpoint) WritePacket(_ *stack.Route, gso *stack.GSO, protocol tcpip.Ne } // WritePackets stores outbound packets into the channel. -func (e *Endpoint) WritePackets(_ *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + // Clone r then release its resource so we only get the relevant fields from + // stack.Route without holding a reference to a NIC's endpoint. + route := r.Clone() + route.Release() payloadView := pkts[0].Data.ToView() n := 0 packetLoop: @@ -176,6 +182,7 @@ packetLoop: }, Proto: protocol, GSO: gso, + Route: route, } select { diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 7491cfc41..60817d36d 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -408,10 +408,14 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { // LinkAddressRequest implements stack.LinkAddressResolver. func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error { snaddr := header.SolicitedNodeAddr(addr) + + // TODO(b/148672031): Use stack.FindRoute instead of manually creating the + // route here. Note, we would need the nicID to do this properly so the right + // NIC (associated to linkEP) is used to send the NDP NS message. r := &stack.Route{ LocalAddress: localAddr, RemoteAddress: snaddr, - RemoteLinkAddress: broadcastMAC, + RemoteLinkAddress: header.EthernetAddressFromMulticastIPv6Address(snaddr), } hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize) pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 7a6820643..d0e930e20 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -270,8 +270,9 @@ func (c *testContext) cleanup() { } type routeArgs struct { - src, dst *channel.Endpoint - typ header.ICMPv6Type + src, dst *channel.Endpoint + typ header.ICMPv6Type + remoteLinkAddr tcpip.LinkAddress } func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.ICMPv6)) { @@ -292,6 +293,11 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header. t.Errorf("unexpected protocol number %d", pi.Proto) return } + + if len(args.remoteLinkAddr) != 0 && args.remoteLinkAddr != pi.Route.RemoteLinkAddress { + t.Errorf("got remote link address = %s, want = %s", pi.Route.RemoteLinkAddress, args.remoteLinkAddr) + } + ipv6 := header.IPv6(pi.Pkt.Header.View()) transProto := tcpip.TransportProtocolNumber(ipv6.NextHeader()) if transProto != header.ICMPv6ProtocolNumber { @@ -339,7 +345,7 @@ func TestLinkResolution(t *testing.T) { t.Fatalf("ep.Write(_) = _, , %s, want = _, , tcpip.ErrNoLinkAddress", err) } for _, args := range []routeArgs{ - {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit}, + {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit, remoteLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.SolicitedNodeAddr(lladdr1))}, {src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6NeighborAdvert}, } { routeICMPv6Packet(t, args, func(t *testing.T, icmpv6 header.ICMPv6) { diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 31294345d..6123fda33 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -538,6 +538,14 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address) *tcpip.Error { r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false) defer r.Release() + // Route should resolve immediately since snmc is a multicast address so a + // remote link address can be calculated without a resolution process. + if c, err := r.Resolve(nil); err != nil { + log.Fatalf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.nic.ID(), err) + } else if c != nil { + log.Fatalf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.nic.ID()) + } + hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborSolicitMinimumSize) pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize)) pkt.SetType(header.ICMPv6NeighborSolicit) @@ -1197,6 +1205,15 @@ func (ndp *ndpState) startSolicitingRouters() { r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false) defer r.Release() + // Route should resolve immediately since + // header.IPv6AllRoutersMulticastAddress is a multicast address so a + // remote link address can be calculated without a resolution process. + if c, err := r.Resolve(nil); err != nil { + log.Fatalf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID(), err) + } else if c != nil { + log.Fatalf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID()) + } + payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadSize) pkt := header.ICMPv6(hdr.Prepend(payloadSize)) diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index bc7cfbcb4..8af8565f7 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -336,6 +336,7 @@ func TestDADResolve(t *testing.T) { opts.NDPConfigs.DupAddrDetectTransmits = test.dupAddrDetectTransmits e := channel.New(int(test.dupAddrDetectTransmits), 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired s := stack.New(opts) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -413,6 +414,12 @@ func TestDADResolve(t *testing.T) { t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) } + // Make sure the right remote link address is used. + snmc := header.SolicitedNodeAddr(addr1) + if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want { + t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) + } + // Check NDP NS packet. // // As per RFC 4861 section 4.3, a possible option is the Source Link @@ -420,7 +427,7 @@ func TestDADResolve(t *testing.T) { // address of the packet is the unspecified address. checker.IPv6(t, p.Pkt.Header.View().ToVectorisedView().First(), checker.SrcAddr(header.IPv6Any), - checker.DstAddr(header.SolicitedNodeAddr(addr1)), + checker.DstAddr(snmc), checker.TTL(header.NDPHopLimit), checker.NDPNS( checker.NDPNSTargetAddress(addr1), @@ -3292,6 +3299,7 @@ func TestRouterSolicitation(t *testing.T) { t.Run(test.name, func(t *testing.T) { t.Parallel() e := channel.New(int(test.maxRtrSolicit), 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired waitForPkt := func(timeout time.Duration) { t.Helper() ctx, _ := context.WithTimeout(context.Background(), timeout) @@ -3304,6 +3312,12 @@ func TestRouterSolicitation(t *testing.T) { if p.Proto != header.IPv6ProtocolNumber { t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) } + + // Make sure the right remote link address is used. + if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); p.Route.RemoteLinkAddress != want { + t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) + } + checker.IPv6(t, p.Pkt.Header.View(), checker.SrcAddr(header.IPv6Any), diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 517f4b941..f565aafb2 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -225,7 +225,9 @@ func (r *Route) Release() { // Clone Clone a route such that the original one can be released and the new // one will remain valid. func (r *Route) Clone() Route { - r.ref.incRef() + if r.ref != nil { + r.ref.incRef() + } return *r } diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 730ac4292..1e9a0dea3 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -1082,7 +1082,11 @@ func (c *Context) SACKEnabled() bool { // SetGSOEnabled enables or disables generic segmentation offload. func (c *Context) SetGSOEnabled(enable bool) { - c.linkEP.GSO = enable + if enable { + c.linkEP.LinkEPCapabilities |= stack.CapabilityHardwareGSO + } else { + c.linkEP.LinkEPCapabilities &^= stack.CapabilityHardwareGSO + } } // MSSWithoutOptions returns the value for the MSS used by the stack when no -- cgit v1.2.3 From 665b614e4a6e715bac25bea15c5c29184016e549 Mon Sep 17 00:00:00 2001 From: Ting-Yu Wang Date: Tue, 4 Feb 2020 18:04:26 -0800 Subject: Support RTM_NEWADDR and RTM_GETLINK in (rt)netlink. PiperOrigin-RevId: 293271055 --- pkg/sentry/inet/inet.go | 4 + pkg/sentry/inet/test_stack.go | 6 + pkg/sentry/socket/hostinet/stack.go | 5 + pkg/sentry/socket/netlink/BUILD | 14 +- pkg/sentry/socket/netlink/message.go | 129 +++++++++++ pkg/sentry/socket/netlink/message_test.go | 312 +++++++++++++++++++++++++++ pkg/sentry/socket/netlink/provider.go | 2 +- pkg/sentry/socket/netlink/route/BUILD | 2 - pkg/sentry/socket/netlink/route/protocol.go | 238 ++++++++++++++------ pkg/sentry/socket/netlink/socket.go | 54 ++--- pkg/sentry/socket/netlink/uevent/protocol.go | 2 +- pkg/sentry/socket/netstack/stack.go | 55 +++++ pkg/tcpip/stack/stack.go | 9 + test/syscalls/linux/BUILD | 2 + test/syscalls/linux/socket_netlink_route.cc | 296 ++++++++++++++++++++----- test/syscalls/linux/socket_netlink_util.cc | 45 +++- test/syscalls/linux/socket_netlink_util.h | 9 + 17 files changed, 1022 insertions(+), 162 deletions(-) create mode 100644 pkg/sentry/socket/netlink/message_test.go (limited to 'pkg/tcpip/stack') diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go index a7dfb78a7..2916a0644 100644 --- a/pkg/sentry/inet/inet.go +++ b/pkg/sentry/inet/inet.go @@ -28,6 +28,10 @@ type Stack interface { // interface indexes to a slice of associated interface address properties. InterfaceAddrs() map[int32][]InterfaceAddr + // AddInterfaceAddr adds an address to the network interface identified by + // index. + AddInterfaceAddr(idx int32, addr InterfaceAddr) error + // SupportsIPv6 returns true if the stack supports IPv6 connectivity. SupportsIPv6() bool diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go index dcfcbd97e..d8961fc94 100644 --- a/pkg/sentry/inet/test_stack.go +++ b/pkg/sentry/inet/test_stack.go @@ -47,6 +47,12 @@ func (s *TestStack) InterfaceAddrs() map[int32][]InterfaceAddr { return s.InterfaceAddrsMap } +// AddInterfaceAddr implements Stack.AddInterfaceAddr. +func (s *TestStack) AddInterfaceAddr(idx int32, addr InterfaceAddr) error { + s.InterfaceAddrsMap[idx] = append(s.InterfaceAddrsMap[idx], addr) + return nil +} + // SupportsIPv6 implements Stack.SupportsIPv6. func (s *TestStack) SupportsIPv6() bool { return s.SupportsIPv6Flag diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index 034eca676..a48082631 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -310,6 +310,11 @@ func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr { return addrs } +// AddInterfaceAddr implements inet.Stack.AddInterfaceAddr. +func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { + return syserror.EACCES +} + // SupportsIPv6 implements inet.Stack.SupportsIPv6. func (s *Stack) SupportsIPv6() bool { return s.supportsIPv6 diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index f8b8e467d..1911cd9b8 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -33,3 +33,15 @@ go_library( "//pkg/waiter", ], ) + +go_test( + name = "netlink_test", + size = "small", + srcs = [ + "message_test.go", + ], + deps = [ + ":netlink", + "//pkg/abi/linux", + ], +) diff --git a/pkg/sentry/socket/netlink/message.go b/pkg/sentry/socket/netlink/message.go index b21e0ca4b..4ea252ccb 100644 --- a/pkg/sentry/socket/netlink/message.go +++ b/pkg/sentry/socket/netlink/message.go @@ -30,8 +30,16 @@ func alignUp(length int, align uint) int { return (length + int(align) - 1) &^ (int(align) - 1) } +// alignPad returns the length of padding required for alignment. +// +// Preconditions: align is a power of two. +func alignPad(length int, align uint) int { + return alignUp(length, align) - length +} + // Message contains a complete serialized netlink message. type Message struct { + hdr linux.NetlinkMessageHeader buf []byte } @@ -40,10 +48,86 @@ type Message struct { // The header length will be updated by Finalize. func NewMessage(hdr linux.NetlinkMessageHeader) *Message { return &Message{ + hdr: hdr, buf: binary.Marshal(nil, usermem.ByteOrder, hdr), } } +// ParseMessage parses the first message seen at buf, returning the rest of the +// buffer. If message is malformed, ok of false is returned. For last message, +// padding check is loose, if there isn't enought padding, whole buf is consumed +// and ok is set to true. +func ParseMessage(buf []byte) (msg *Message, rest []byte, ok bool) { + b := BytesView(buf) + + hdrBytes, ok := b.Extract(linux.NetlinkMessageHeaderSize) + if !ok { + return + } + var hdr linux.NetlinkMessageHeader + binary.Unmarshal(hdrBytes, usermem.ByteOrder, &hdr) + + // Msg portion. + totalMsgLen := int(hdr.Length) + _, ok = b.Extract(totalMsgLen - linux.NetlinkMessageHeaderSize) + if !ok { + return + } + + // Padding. + numPad := alignPad(totalMsgLen, linux.NLMSG_ALIGNTO) + // Linux permits the last message not being aligned, just consume all of it. + // Ref: net/netlink/af_netlink.c:netlink_rcv_skb + if numPad > len(b) { + numPad = len(b) + } + _, ok = b.Extract(numPad) + if !ok { + return + } + + return &Message{ + hdr: hdr, + buf: buf[:totalMsgLen], + }, []byte(b), true +} + +// Header returns the header of this message. +func (m *Message) Header() linux.NetlinkMessageHeader { + return m.hdr +} + +// GetData unmarshals the payload message header from this netlink message, and +// returns the attributes portion. +func (m *Message) GetData(msg interface{}) (AttrsView, bool) { + b := BytesView(m.buf) + + _, ok := b.Extract(linux.NetlinkMessageHeaderSize) + if !ok { + return nil, false + } + + size := int(binary.Size(msg)) + msgBytes, ok := b.Extract(size) + if !ok { + return nil, false + } + binary.Unmarshal(msgBytes, usermem.ByteOrder, msg) + + numPad := alignPad(linux.NetlinkMessageHeaderSize+size, linux.NLMSG_ALIGNTO) + // Linux permits the last message not being aligned, just consume all of it. + // Ref: net/netlink/af_netlink.c:netlink_rcv_skb + if numPad > len(b) { + numPad = len(b) + } + _, ok = b.Extract(numPad) + if !ok { + return nil, false + } + + return AttrsView(b), true +} + // Finalize returns the []byte containing the entire message, with the total // length set in the message header. The Message must not be modified after // calling Finalize. @@ -157,3 +241,48 @@ func (ms *MessageSet) AddMessage(hdr linux.NetlinkMessageHeader) *Message { ms.Messages = append(ms.Messages, m) return m } + +// AttrsView is a view into the attributes portion of a netlink message. +type AttrsView []byte + +// Empty returns whether there is no attribute left in v. +func (v AttrsView) Empty() bool { + return len(v) == 0 +} + +// ParseFirst parses first netlink attribute at the beginning of v. +func (v AttrsView) ParseFirst() (hdr linux.NetlinkAttrHeader, value []byte, rest AttrsView, ok bool) { + b := BytesView(v) + + hdrBytes, ok := b.Extract(linux.NetlinkAttrHeaderSize) + if !ok { + return + } + binary.Unmarshal(hdrBytes, usermem.ByteOrder, &hdr) + + value, ok = b.Extract(int(hdr.Length) - linux.NetlinkAttrHeaderSize) + if !ok { + return + } + + _, ok = b.Extract(alignPad(int(hdr.Length), linux.NLA_ALIGNTO)) + if !ok { + return + } + + return hdr, value, AttrsView(b), ok +} + +// BytesView supports extracting data from a byte slice with bounds checking. +type BytesView []byte + +// Extract removes the first n bytes from v and returns it. If n is out of +// bounds, it returns false. +func (v *BytesView) Extract(n int) ([]byte, bool) { + if n < 0 || n > len(*v) { + return nil, false + } + extracted := (*v)[:n] + *v = (*v)[n:] + return extracted, true +} diff --git a/pkg/sentry/socket/netlink/message_test.go b/pkg/sentry/socket/netlink/message_test.go new file mode 100644 index 000000000..ef13d9386 --- /dev/null +++ b/pkg/sentry/socket/netlink/message_test.go @@ -0,0 +1,312 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package message_test + +import ( + "bytes" + "reflect" + "testing" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/socket/netlink" +) + +type dummyNetlinkMsg struct { + Foo uint16 +} + +func TestParseMessage(t *testing.T) { + tests := []struct { + desc string + input []byte + + header linux.NetlinkMessageHeader + dataMsg *dummyNetlinkMsg + restLen int + ok bool + }{ + { + desc: "valid", + input: []byte{ + 0x14, 0x00, 0x00, 0x00, // Length + 0x01, 0x00, // Type + 0x02, 0x00, // Flags + 0x03, 0x00, 0x00, 0x00, // Seq + 0x04, 0x00, 0x00, 0x00, // PortID + 0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding + }, + header: linux.NetlinkMessageHeader{ + Length: 20, + Type: 1, + Flags: 2, + Seq: 3, + PortID: 4, + }, + dataMsg: &dummyNetlinkMsg{ + Foo: 0x3130, + }, + restLen: 0, + ok: true, + }, + { + desc: "valid with next message", + input: []byte{ + 0x14, 0x00, 0x00, 0x00, // Length + 0x01, 0x00, // Type + 0x02, 0x00, // Flags + 0x03, 0x00, 0x00, 0x00, // Seq + 0x04, 0x00, 0x00, 0x00, // PortID + 0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding + 0xFF, // Next message (rest) + }, + header: linux.NetlinkMessageHeader{ + Length: 20, + Type: 1, + Flags: 2, + Seq: 3, + PortID: 4, + }, + dataMsg: &dummyNetlinkMsg{ + Foo: 0x3130, + }, + restLen: 1, + ok: true, + }, + { + desc: "valid for last message without padding", + input: []byte{ + 0x12, 0x00, 0x00, 0x00, // Length + 0x01, 0x00, // Type + 0x02, 0x00, // Flags + 0x03, 0x00, 0x00, 0x00, // Seq + 0x04, 0x00, 0x00, 0x00, // PortID + 0x30, 0x31, // Data message + }, + header: linux.NetlinkMessageHeader{ + Length: 18, + Type: 1, + Flags: 2, + Seq: 3, + PortID: 4, + }, + dataMsg: &dummyNetlinkMsg{ + Foo: 0x3130, + }, + restLen: 0, + ok: true, + }, + { + desc: "valid for last message not to be aligned", + input: []byte{ + 0x13, 0x00, 0x00, 0x00, // Length + 0x01, 0x00, // Type + 0x02, 0x00, // Flags + 0x03, 0x00, 0x00, 0x00, // Seq + 0x04, 0x00, 0x00, 0x00, // PortID + 0x30, 0x31, // Data message + 0x00, // Excessive 1 byte permitted at end + }, + header: linux.NetlinkMessageHeader{ + Length: 19, + Type: 1, + Flags: 2, + Seq: 3, + PortID: 4, + }, + dataMsg: &dummyNetlinkMsg{ + Foo: 0x3130, + }, + restLen: 0, + ok: true, + }, + { + desc: "header.Length too short", + input: []byte{ + 0x04, 0x00, 0x00, 0x00, // Length + 0x01, 0x00, // Type + 0x02, 0x00, // Flags + 0x03, 0x00, 0x00, 0x00, // Seq + 0x04, 0x00, 0x00, 0x00, // PortID + 0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding + }, + ok: false, + }, + { + desc: "header.Length too long", + input: []byte{ + 0xFF, 0xFF, 0x00, 0x00, // Length + 0x01, 0x00, // Type + 0x02, 0x00, // Flags + 0x03, 0x00, 0x00, 0x00, // Seq + 0x04, 0x00, 0x00, 0x00, // PortID + 0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding + }, + ok: false, + }, + { + desc: "header incomplete", + input: []byte{ + 0x04, 0x00, 0x00, 0x00, // Length + }, + ok: false, + }, + { + desc: "empty message", + input: []byte{}, + ok: false, + }, + } + for _, test := range tests { + msg, rest, ok := netlink.ParseMessage(test.input) + if ok != test.ok { + t.Errorf("%v: got ok = %v, want = %v", test.desc, ok, test.ok) + continue + } + if !test.ok { + continue + } + if !reflect.DeepEqual(msg.Header(), test.header) { + t.Errorf("%v: got hdr = %+v, want = %+v", test.desc, msg.Header(), test.header) + } + + dataMsg := &dummyNetlinkMsg{} + _, dataOk := msg.GetData(dataMsg) + if !dataOk { + t.Errorf("%v: GetData.ok = %v, want = true", test.desc, dataOk) + } else if !reflect.DeepEqual(dataMsg, test.dataMsg) { + t.Errorf("%v: GetData.msg = %+v, want = %+v", test.desc, dataMsg, test.dataMsg) + } + + if got, want := rest, test.input[len(test.input)-test.restLen:]; !bytes.Equal(got, want) { + t.Errorf("%v: got rest = %v, want = %v", test.desc, got, want) + } + } +} + +func TestAttrView(t *testing.T) { + tests := []struct { + desc string + input []byte + + // Outputs for ParseFirst. + hdr linux.NetlinkAttrHeader + value []byte + restLen int + ok bool + + // Outputs for Empty. + isEmpty bool + }{ + { + desc: "valid", + input: []byte{ + 0x06, 0x00, // Length + 0x01, 0x00, // Type + 0x30, 0x31, 0x00, 0x00, // Data with 2 bytes padding + }, + hdr: linux.NetlinkAttrHeader{ + Length: 6, + Type: 1, + }, + value: []byte{0x30, 0x31}, + restLen: 0, + ok: true, + isEmpty: false, + }, + { + desc: "at alignment", + input: []byte{ + 0x08, 0x00, // Length + 0x01, 0x00, // Type + 0x30, 0x31, 0x32, 0x33, // Data + }, + hdr: linux.NetlinkAttrHeader{ + Length: 8, + Type: 1, + }, + value: []byte{0x30, 0x31, 0x32, 0x33}, + restLen: 0, + ok: true, + isEmpty: false, + }, + { + desc: "at alignment with rest data", + input: []byte{ + 0x08, 0x00, // Length + 0x01, 0x00, // Type + 0x30, 0x31, 0x32, 0x33, // Data + 0xFF, 0xFE, // Rest data + }, + hdr: linux.NetlinkAttrHeader{ + Length: 8, + Type: 1, + }, + value: []byte{0x30, 0x31, 0x32, 0x33}, + restLen: 2, + ok: true, + isEmpty: false, + }, + { + desc: "hdr.Length too long", + input: []byte{ + 0xFF, 0x00, // Length + 0x01, 0x00, // Type + 0x30, 0x31, 0x32, 0x33, // Data + }, + ok: false, + isEmpty: false, + }, + { + desc: "hdr.Length too short", + input: []byte{ + 0x01, 0x00, // Length + 0x01, 0x00, // Type + 0x30, 0x31, 0x32, 0x33, // Data + }, + ok: false, + isEmpty: false, + }, + { + desc: "empty", + input: []byte{}, + ok: false, + isEmpty: true, + }, + } + for _, test := range tests { + attrs := netlink.AttrsView(test.input) + + // Test ParseFirst(). + hdr, value, rest, ok := attrs.ParseFirst() + if ok != test.ok { + t.Errorf("%v: got ok = %v, want = %v", test.desc, ok, test.ok) + } else if test.ok { + if !reflect.DeepEqual(hdr, test.hdr) { + t.Errorf("%v: got hdr = %+v, want = %+v", test.desc, hdr, test.hdr) + } + if !bytes.Equal(value, test.value) { + t.Errorf("%v: got value = %v, want = %v", test.desc, value, test.value) + } + if wantRest := test.input[len(test.input)-test.restLen:]; !bytes.Equal(rest, wantRest) { + t.Errorf("%v: got rest = %v, want = %v", test.desc, rest, wantRest) + } + } + + // Test Empty(). + if got, want := attrs.Empty(), test.isEmpty; got != want { + t.Errorf("%v: got empty = %v, want = %v", test.desc, got, want) + } + } +} diff --git a/pkg/sentry/socket/netlink/provider.go b/pkg/sentry/socket/netlink/provider.go index 07f860a49..b0dc70e5c 100644 --- a/pkg/sentry/socket/netlink/provider.go +++ b/pkg/sentry/socket/netlink/provider.go @@ -42,7 +42,7 @@ type Protocol interface { // If err == nil, any messages added to ms will be sent back to the // other end of the socket. Setting ms.Multi will cause an NLMSG_DONE // message to be sent even if ms contains no messages. - ProcessMessage(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *MessageSet) *syserr.Error + ProcessMessage(ctx context.Context, msg *Message, ms *MessageSet) *syserr.Error } // Provider is a function that creates a new Protocol for a specific netlink diff --git a/pkg/sentry/socket/netlink/route/BUILD b/pkg/sentry/socket/netlink/route/BUILD index 622a1eafc..93127398d 100644 --- a/pkg/sentry/socket/netlink/route/BUILD +++ b/pkg/sentry/socket/netlink/route/BUILD @@ -10,13 +10,11 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/binary", "//pkg/context", "//pkg/sentry/inet", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/socket/netlink", "//pkg/syserr", - "//pkg/usermem", ], ) diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go index 2b3c7f5b3..c84d8bd7c 100644 --- a/pkg/sentry/socket/netlink/route/protocol.go +++ b/pkg/sentry/socket/netlink/route/protocol.go @@ -17,16 +17,15 @@ package route import ( "bytes" + "syscall" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/socket/netlink" "gvisor.dev/gvisor/pkg/syserr" - "gvisor.dev/gvisor/pkg/usermem" ) // commandKind describes the operational class of a message type. @@ -69,13 +68,7 @@ func (p *Protocol) CanSend() bool { } // dumpLinks handles RTM_GETLINK dump requests. -func (p *Protocol) dumpLinks(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error { - // TODO(b/68878065): Only the dump variant of the types below are - // supported. - if hdr.Flags&linux.NLM_F_DUMP != linux.NLM_F_DUMP { - return syserr.ErrNotSupported - } - +func (p *Protocol) dumpLinks(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error { // NLM_F_DUMP + RTM_GETLINK messages are supposed to include an // ifinfomsg. However, Linux <3.9 only checked for rtgenmsg, and some // userspace applications (including glibc) still include rtgenmsg. @@ -99,44 +92,105 @@ func (p *Protocol) dumpLinks(ctx context.Context, hdr linux.NetlinkMessageHeader return nil } - for id, i := range stack.Interfaces() { - m := ms.AddMessage(linux.NetlinkMessageHeader{ - Type: linux.RTM_NEWLINK, - }) + for idx, i := range stack.Interfaces() { + addNewLinkMessage(ms, idx, i) + } - m.Put(linux.InterfaceInfoMessage{ - Family: linux.AF_UNSPEC, - Type: i.DeviceType, - Index: id, - Flags: i.Flags, - }) + return nil +} - m.PutAttrString(linux.IFLA_IFNAME, i.Name) - m.PutAttr(linux.IFLA_MTU, i.MTU) +// getLinks handles RTM_GETLINK requests. +func (p *Protocol) getLink(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error { + stack := inet.StackFromContext(ctx) + if stack == nil { + // No network devices. + return nil + } - mac := make([]byte, 6) - brd := mac - if len(i.Addr) > 0 { - mac = i.Addr - brd = bytes.Repeat([]byte{0xff}, len(i.Addr)) + // Parse message. + var ifi linux.InterfaceInfoMessage + attrs, ok := msg.GetData(&ifi) + if !ok { + return syserr.ErrInvalidArgument + } + + // Parse attributes. + var byName []byte + for !attrs.Empty() { + ahdr, value, rest, ok := attrs.ParseFirst() + if !ok { + return syserr.ErrInvalidArgument } - m.PutAttr(linux.IFLA_ADDRESS, mac) - m.PutAttr(linux.IFLA_BROADCAST, brd) + attrs = rest - // TODO(gvisor.dev/issue/578): There are many more attributes. + switch ahdr.Type { + case linux.IFLA_IFNAME: + if len(value) < 1 { + return syserr.ErrInvalidArgument + } + byName = value[:len(value)-1] + + // TODO(gvisor.dev/issue/578): Support IFLA_EXT_MASK. + } } + found := false + for idx, i := range stack.Interfaces() { + switch { + case ifi.Index > 0: + if idx != ifi.Index { + continue + } + case byName != nil: + if string(byName) != i.Name { + continue + } + default: + // Criteria not specified. + return syserr.ErrInvalidArgument + } + + addNewLinkMessage(ms, idx, i) + found = true + break + } + if !found { + return syserr.ErrNoDevice + } return nil } -// dumpAddrs handles RTM_GETADDR dump requests. -func (p *Protocol) dumpAddrs(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error { - // TODO(b/68878065): Only the dump variant of the types below are - // supported. - if hdr.Flags&linux.NLM_F_DUMP != linux.NLM_F_DUMP { - return syserr.ErrNotSupported +// addNewLinkMessage appends RTM_NEWLINK message for the given interface into +// the message set. +func addNewLinkMessage(ms *netlink.MessageSet, idx int32, i inet.Interface) { + m := ms.AddMessage(linux.NetlinkMessageHeader{ + Type: linux.RTM_NEWLINK, + }) + + m.Put(linux.InterfaceInfoMessage{ + Family: linux.AF_UNSPEC, + Type: i.DeviceType, + Index: idx, + Flags: i.Flags, + }) + + m.PutAttrString(linux.IFLA_IFNAME, i.Name) + m.PutAttr(linux.IFLA_MTU, i.MTU) + + mac := make([]byte, 6) + brd := mac + if len(i.Addr) > 0 { + mac = i.Addr + brd = bytes.Repeat([]byte{0xff}, len(i.Addr)) } + m.PutAttr(linux.IFLA_ADDRESS, mac) + m.PutAttr(linux.IFLA_BROADCAST, brd) + + // TODO(gvisor.dev/issue/578): There are many more attributes. +} +// dumpAddrs handles RTM_GETADDR dump requests. +func (p *Protocol) dumpAddrs(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error { // RTM_GETADDR dump requests need not contain anything more than the // netlink header and 1 byte protocol family common to all // NETLINK_ROUTE requests. @@ -168,6 +222,7 @@ func (p *Protocol) dumpAddrs(ctx context.Context, hdr linux.NetlinkMessageHeader Index: uint32(id), }) + m.PutAttr(linux.IFA_LOCAL, []byte(a.Addr)) m.PutAttr(linux.IFA_ADDRESS, []byte(a.Addr)) // TODO(gvisor.dev/issue/578): There are many more attributes. @@ -252,12 +307,12 @@ func fillRoute(routes []inet.Route, addr []byte) (inet.Route, *syserr.Error) { } // parseForDestination parses a message as format of RouteMessage-RtAttr-dst. -func parseForDestination(data []byte) ([]byte, *syserr.Error) { +func parseForDestination(msg *netlink.Message) ([]byte, *syserr.Error) { var rtMsg linux.RouteMessage - if len(data) < linux.SizeOfRouteMessage { + attrs, ok := msg.GetData(&rtMsg) + if !ok { return nil, syserr.ErrInvalidArgument } - binary.Unmarshal(data[:linux.SizeOfRouteMessage], usermem.ByteOrder, &rtMsg) // iproute2 added the RTM_F_LOOKUP_TABLE flag in version v4.4.0. See // commit bc234301af12. Note we don't check this flag for backward // compatibility. @@ -265,26 +320,15 @@ func parseForDestination(data []byte) ([]byte, *syserr.Error) { return nil, syserr.ErrNotSupported } - data = data[linux.SizeOfRouteMessage:] - - // TODO(gvisor.dev/issue/1611): Add generic attribute parsing. - var rtAttr linux.RtAttr - if len(data) < linux.SizeOfRtAttr { - return nil, syserr.ErrInvalidArgument + // Expect first attribute is RTA_DST. + if hdr, value, _, ok := attrs.ParseFirst(); ok && hdr.Type == linux.RTA_DST { + return value, nil } - binary.Unmarshal(data[:linux.SizeOfRtAttr], usermem.ByteOrder, &rtAttr) - if rtAttr.Type != linux.RTA_DST { - return nil, syserr.ErrInvalidArgument - } - - if len(data) < int(rtAttr.Len) { - return nil, syserr.ErrInvalidArgument - } - return data[linux.SizeOfRtAttr:rtAttr.Len], nil + return nil, syserr.ErrInvalidArgument } // dumpRoutes handles RTM_GETROUTE requests. -func (p *Protocol) dumpRoutes(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error { +func (p *Protocol) dumpRoutes(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error { // RTM_GETROUTE dump requests need not contain anything more than the // netlink header and 1 byte protocol family common to all // NETLINK_ROUTE requests. @@ -295,10 +339,11 @@ func (p *Protocol) dumpRoutes(ctx context.Context, hdr linux.NetlinkMessageHeade return nil } + hdr := msg.Header() routeTables := stack.RouteTable() if hdr.Flags == linux.NLM_F_REQUEST { - dst, err := parseForDestination(data) + dst, err := parseForDestination(msg) if err != nil { return err } @@ -357,10 +402,55 @@ func (p *Protocol) dumpRoutes(ctx context.Context, hdr linux.NetlinkMessageHeade return nil } +// newAddr handles RTM_NEWADDR requests. +func (p *Protocol) newAddr(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error { + stack := inet.StackFromContext(ctx) + if stack == nil { + // No network stack. + return syserr.ErrProtocolNotSupported + } + + var ifa linux.InterfaceAddrMessage + attrs, ok := msg.GetData(&ifa) + if !ok { + return syserr.ErrInvalidArgument + } + + for !attrs.Empty() { + ahdr, value, rest, ok := attrs.ParseFirst() + if !ok { + return syserr.ErrInvalidArgument + } + attrs = rest + + switch ahdr.Type { + case linux.IFA_LOCAL: + err := stack.AddInterfaceAddr(int32(ifa.Index), inet.InterfaceAddr{ + Family: ifa.Family, + PrefixLen: ifa.PrefixLen, + Flags: ifa.Flags, + Addr: value, + }) + if err == syscall.EEXIST { + flags := msg.Header().Flags + if flags&linux.NLM_F_EXCL != 0 { + return syserr.ErrExists + } + } else if err != nil { + return syserr.ErrInvalidArgument + } + } + } + return nil +} + // ProcessMessage implements netlink.Protocol.ProcessMessage. -func (p *Protocol) ProcessMessage(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error { +func (p *Protocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error { + hdr := msg.Header() + // All messages start with a 1 byte protocol family. - if len(data) < 1 { + var family uint8 + if _, ok := msg.GetData(&family); !ok { // Linux ignores messages missing the protocol family. See // net/core/rtnetlink.c:rtnetlink_rcv_msg. return nil @@ -374,16 +464,32 @@ func (p *Protocol) ProcessMessage(ctx context.Context, hdr linux.NetlinkMessageH } } - switch hdr.Type { - case linux.RTM_GETLINK: - return p.dumpLinks(ctx, hdr, data, ms) - case linux.RTM_GETADDR: - return p.dumpAddrs(ctx, hdr, data, ms) - case linux.RTM_GETROUTE: - return p.dumpRoutes(ctx, hdr, data, ms) - default: - return syserr.ErrNotSupported + if hdr.Flags&linux.NLM_F_DUMP == linux.NLM_F_DUMP { + // TODO(b/68878065): Only the dump variant of the types below are + // supported. + switch hdr.Type { + case linux.RTM_GETLINK: + return p.dumpLinks(ctx, msg, ms) + case linux.RTM_GETADDR: + return p.dumpAddrs(ctx, msg, ms) + case linux.RTM_GETROUTE: + return p.dumpRoutes(ctx, msg, ms) + default: + return syserr.ErrNotSupported + } + } else if hdr.Flags&linux.NLM_F_REQUEST == linux.NLM_F_REQUEST { + switch hdr.Type { + case linux.RTM_GETLINK: + return p.getLink(ctx, msg, ms) + case linux.RTM_GETROUTE: + return p.dumpRoutes(ctx, msg, ms) + case linux.RTM_NEWADDR: + return p.newAddr(ctx, msg, ms) + default: + return syserr.ErrNotSupported + } } + return syserr.ErrNotSupported } // init registers the NETLINK_ROUTE provider. diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index c4b95debb..2ca02567d 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -644,47 +644,38 @@ func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error return nil } -func (s *Socket) dumpErrorMesage(ctx context.Context, hdr linux.NetlinkMessageHeader, ms *MessageSet, err *syserr.Error) *syserr.Error { +func dumpErrorMesage(hdr linux.NetlinkMessageHeader, ms *MessageSet, err *syserr.Error) { m := ms.AddMessage(linux.NetlinkMessageHeader{ Type: linux.NLMSG_ERROR, }) - m.Put(linux.NetlinkErrorMessage{ Error: int32(-err.ToLinux().Number()), Header: hdr, }) - return nil +} +func dumpAckMesage(hdr linux.NetlinkMessageHeader, ms *MessageSet) { + m := ms.AddMessage(linux.NetlinkMessageHeader{ + Type: linux.NLMSG_ERROR, + }) + m.Put(linux.NetlinkErrorMessage{ + Error: 0, + Header: hdr, + }) } // processMessages handles each message in buf, passing it to the protocol // handler for final handling. func (s *Socket) processMessages(ctx context.Context, buf []byte) *syserr.Error { for len(buf) > 0 { - if len(buf) < linux.NetlinkMessageHeaderSize { + msg, rest, ok := ParseMessage(buf) + if !ok { // Linux ignores messages that are too short. See // net/netlink/af_netlink.c:netlink_rcv_skb. break } - - var hdr linux.NetlinkMessageHeader - binary.Unmarshal(buf[:linux.NetlinkMessageHeaderSize], usermem.ByteOrder, &hdr) - - if hdr.Length < linux.NetlinkMessageHeaderSize || uint64(hdr.Length) > uint64(len(buf)) { - // Linux ignores malformed messages. See - // net/netlink/af_netlink.c:netlink_rcv_skb. - break - } - - // Data from this message. - data := buf[linux.NetlinkMessageHeaderSize:hdr.Length] - - // Advance to the next message. - next := alignUp(int(hdr.Length), linux.NLMSG_ALIGNTO) - if next >= len(buf)-1 { - next = len(buf) - 1 - } - buf = buf[next:] + buf = rest + hdr := msg.Header() // Ignore control messages. if hdr.Type < linux.NLMSG_MIN_TYPE { @@ -692,19 +683,10 @@ func (s *Socket) processMessages(ctx context.Context, buf []byte) *syserr.Error } ms := NewMessageSet(s.portID, hdr.Seq) - var err *syserr.Error - // TODO(b/68877377): ACKs not supported yet. - if hdr.Flags&linux.NLM_F_ACK == linux.NLM_F_ACK { - err = syserr.ErrNotSupported - } else { - - err = s.protocol.ProcessMessage(ctx, hdr, data, ms) - } - if err != nil { - ms = NewMessageSet(s.portID, hdr.Seq) - if err := s.dumpErrorMesage(ctx, hdr, ms, err); err != nil { - return err - } + if err := s.protocol.ProcessMessage(ctx, msg, ms); err != nil { + dumpErrorMesage(hdr, ms, err) + } else if hdr.Flags&linux.NLM_F_ACK == linux.NLM_F_ACK { + dumpAckMesage(hdr, ms) } if err := s.sendResponse(ctx, ms); err != nil { diff --git a/pkg/sentry/socket/netlink/uevent/protocol.go b/pkg/sentry/socket/netlink/uevent/protocol.go index 1ee4296bc..029ba21b5 100644 --- a/pkg/sentry/socket/netlink/uevent/protocol.go +++ b/pkg/sentry/socket/netlink/uevent/protocol.go @@ -49,7 +49,7 @@ func (p *Protocol) CanSend() bool { } // ProcessMessage implements netlink.Protocol.ProcessMessage. -func (p *Protocol) ProcessMessage(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error { +func (p *Protocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error { // Silently ignore all messages. return nil } diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index 31ea66eca..0692482e9 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -20,6 +20,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/socket/netfilter" "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" @@ -88,6 +90,59 @@ func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr { return nicAddrs } +// AddInterfaceAddr implements inet.Stack.AddInterfaceAddr. +func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { + var ( + protocol tcpip.NetworkProtocolNumber + address tcpip.Address + ) + switch addr.Family { + case linux.AF_INET: + if len(addr.Addr) < header.IPv4AddressSize { + return syserror.EINVAL + } + if addr.PrefixLen > header.IPv4AddressSize*8 { + return syserror.EINVAL + } + protocol = ipv4.ProtocolNumber + address = tcpip.Address(addr.Addr[:header.IPv4AddressSize]) + + case linux.AF_INET6: + if len(addr.Addr) < header.IPv6AddressSize { + return syserror.EINVAL + } + if addr.PrefixLen > header.IPv6AddressSize*8 { + return syserror.EINVAL + } + protocol = ipv6.ProtocolNumber + address = tcpip.Address(addr.Addr[:header.IPv6AddressSize]) + + default: + return syserror.ENOTSUP + } + + protocolAddress := tcpip.ProtocolAddress{ + Protocol: protocol, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: address, + PrefixLen: int(addr.PrefixLen), + }, + } + + // Attach address to interface. + if err := s.Stack.AddProtocolAddressWithOptions(tcpip.NICID(idx), protocolAddress, stack.CanBePrimaryEndpoint); err != nil { + return syserr.TranslateNetstackError(err).ToError() + } + + // Add route for local network. + s.Stack.AddRoute(tcpip.Route{ + Destination: protocolAddress.AddressWithPrefix.Subnet(), + Gateway: "", // No gateway for local network. + NIC: tcpip.NICID(idx), + }) + return nil +} + // TCPReceiveBufferSize implements inet.Stack.TCPReceiveBufferSize. func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) { var rs tcp.ReceiveBufferSizeOption diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 7057b110e..b793f1d74 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -795,6 +795,8 @@ func (s *Stack) Forwarding() bool { // SetRouteTable assigns the route table to be used by this stack. It // specifies which NIC to use for given destination address ranges. +// +// This method takes ownership of the table. func (s *Stack) SetRouteTable(table []tcpip.Route) { s.mu.Lock() defer s.mu.Unlock() @@ -809,6 +811,13 @@ func (s *Stack) GetRouteTable() []tcpip.Route { return append([]tcpip.Route(nil), s.routeTable...) } +// AddRoute appends a route to the route table. +func (s *Stack) AddRoute(route tcpip.Route) { + s.mu.Lock() + defer s.mu.Unlock() + s.routeTable = append(s.routeTable, route) +} + // NewEndpoint creates a new transport layer endpoint of the given protocol. func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { t, ok := s.transportProtocols[transport] diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 273b014d6..f2e3c7072 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -2769,9 +2769,11 @@ cc_binary( deps = [ ":socket_netlink_util", ":socket_test_util", + "//test/util:capability_util", "//test/util:cleanup", "//test/util:file_descriptor", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:optional", gtest, "//test/util:test_main", "//test/util:test_util", diff --git a/test/syscalls/linux/socket_netlink_route.cc b/test/syscalls/linux/socket_netlink_route.cc index 1e28e658d..e5aed1eec 100644 --- a/test/syscalls/linux/socket_netlink_route.cc +++ b/test/syscalls/linux/socket_netlink_route.cc @@ -14,6 +14,7 @@ #include #include +#include #include #include #include @@ -25,8 +26,10 @@ #include "gtest/gtest.h" #include "absl/strings/str_format.h" +#include "absl/types/optional.h" #include "test/syscalls/linux/socket_netlink_util.h" #include "test/syscalls/linux/socket_test_util.h" +#include "test/util/capability_util.h" #include "test/util/cleanup.h" #include "test/util/file_descriptor.h" #include "test/util/test_util.h" @@ -38,6 +41,8 @@ namespace testing { namespace { +constexpr uint32_t kSeq = 12345; + using ::testing::AnyOf; using ::testing::Eq; @@ -113,58 +118,224 @@ void CheckGetLinkResponse(const struct nlmsghdr* hdr, int seq, int port) { // TODO(mpratt): Check ifinfomsg contents and following attrs. } +PosixError DumpLinks( + const FileDescriptor& fd, uint32_t seq, + const std::function& fn) { + struct request { + struct nlmsghdr hdr; + struct ifinfomsg ifm; + }; + + struct request req = {}; + req.hdr.nlmsg_len = sizeof(req); + req.hdr.nlmsg_type = RTM_GETLINK; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; + req.hdr.nlmsg_seq = seq; + req.ifm.ifi_family = AF_UNSPEC; + + return NetlinkRequestResponse(fd, &req, sizeof(req), fn, false); +} + TEST(NetlinkRouteTest, GetLinkDump) { FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); uint32_t port = ASSERT_NO_ERRNO_AND_VALUE(NetlinkPortID(fd.get())); + // Loopback is common among all tests, check that it's found. + bool loopbackFound = false; + ASSERT_NO_ERRNO(DumpLinks(fd, kSeq, [&](const struct nlmsghdr* hdr) { + CheckGetLinkResponse(hdr, kSeq, port); + if (hdr->nlmsg_type != RTM_NEWLINK) { + return; + } + ASSERT_GE(hdr->nlmsg_len, NLMSG_SPACE(sizeof(struct ifinfomsg))); + const struct ifinfomsg* msg = + reinterpret_cast(NLMSG_DATA(hdr)); + std::cout << "Found interface idx=" << msg->ifi_index + << ", type=" << std::hex << msg->ifi_type; + if (msg->ifi_type == ARPHRD_LOOPBACK) { + loopbackFound = true; + EXPECT_NE(msg->ifi_flags & IFF_LOOPBACK, 0); + } + })); + EXPECT_TRUE(loopbackFound); +} + +struct Link { + int index; + std::string name; +}; + +PosixErrorOr> FindLoopbackLink() { + ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE)); + + absl::optional link; + RETURN_IF_ERRNO(DumpLinks(fd, kSeq, [&](const struct nlmsghdr* hdr) { + if (hdr->nlmsg_type != RTM_NEWLINK || + hdr->nlmsg_len < NLMSG_SPACE(sizeof(struct ifinfomsg))) { + return; + } + const struct ifinfomsg* msg = + reinterpret_cast(NLMSG_DATA(hdr)); + if (msg->ifi_type == ARPHRD_LOOPBACK) { + const auto* rta = FindRtAttr(hdr, msg, IFLA_IFNAME); + if (rta == nullptr) { + // Ignore links that do not have a name. + return; + } + + link = Link(); + link->index = msg->ifi_index; + link->name = std::string(reinterpret_cast(RTA_DATA(rta))); + } + })); + return link; +} + +// CheckLinkMsg checks a netlink message against an expected link. +void CheckLinkMsg(const struct nlmsghdr* hdr, const Link& link) { + ASSERT_THAT(hdr->nlmsg_type, Eq(RTM_NEWLINK)); + ASSERT_GE(hdr->nlmsg_len, NLMSG_SPACE(sizeof(struct ifinfomsg))); + const struct ifinfomsg* msg = + reinterpret_cast(NLMSG_DATA(hdr)); + EXPECT_EQ(msg->ifi_index, link.index); + + const struct rtattr* rta = FindRtAttr(hdr, msg, IFLA_IFNAME); + EXPECT_NE(nullptr, rta) << "IFLA_IFNAME not found in message."; + if (rta != nullptr) { + std::string name(reinterpret_cast(RTA_DATA(rta))); + EXPECT_EQ(name, link.name); + } +} + +TEST(NetlinkRouteTest, GetLinkByIndex) { + absl::optional loopback_link = + ASSERT_NO_ERRNO_AND_VALUE(FindLoopbackLink()); + ASSERT_TRUE(loopback_link.has_value()); + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + struct request { struct nlmsghdr hdr; struct ifinfomsg ifm; }; - constexpr uint32_t kSeq = 12345; - struct request req = {}; req.hdr.nlmsg_len = sizeof(req); req.hdr.nlmsg_type = RTM_GETLINK; - req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; + req.hdr.nlmsg_flags = NLM_F_REQUEST; req.hdr.nlmsg_seq = kSeq; req.ifm.ifi_family = AF_UNSPEC; + req.ifm.ifi_index = loopback_link->index; - // Loopback is common among all tests, check that it's found. - bool loopbackFound = false; + bool found = false; ASSERT_NO_ERRNO(NetlinkRequestResponse( fd, &req, sizeof(req), [&](const struct nlmsghdr* hdr) { - CheckGetLinkResponse(hdr, kSeq, port); - if (hdr->nlmsg_type != RTM_NEWLINK) { - return; - } - ASSERT_GE(hdr->nlmsg_len, NLMSG_SPACE(sizeof(struct ifinfomsg))); - const struct ifinfomsg* msg = - reinterpret_cast(NLMSG_DATA(hdr)); - std::cout << "Found interface idx=" << msg->ifi_index - << ", type=" << std::hex << msg->ifi_type; - if (msg->ifi_type == ARPHRD_LOOPBACK) { - loopbackFound = true; - EXPECT_NE(msg->ifi_flags & IFF_LOOPBACK, 0); - } + CheckLinkMsg(hdr, *loopback_link); + found = true; }, false)); - EXPECT_TRUE(loopbackFound); + EXPECT_TRUE(found) << "Netlink response does not contain any links."; } -TEST(NetlinkRouteTest, MsgHdrMsgUnsuppType) { +TEST(NetlinkRouteTest, GetLinkByName) { + absl::optional loopback_link = + ASSERT_NO_ERRNO_AND_VALUE(FindLoopbackLink()); + ASSERT_TRUE(loopback_link.has_value()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); struct request { struct nlmsghdr hdr; struct ifinfomsg ifm; + struct rtattr rtattr; + char ifname[IFNAMSIZ]; + char pad[NLMSG_ALIGNTO + RTA_ALIGNTO]; }; - constexpr uint32_t kSeq = 12345; + struct request req = {}; + req.hdr.nlmsg_type = RTM_GETLINK; + req.hdr.nlmsg_flags = NLM_F_REQUEST; + req.hdr.nlmsg_seq = kSeq; + req.ifm.ifi_family = AF_UNSPEC; + req.rtattr.rta_type = IFLA_IFNAME; + req.rtattr.rta_len = RTA_LENGTH(loopback_link->name.size() + 1); + strncpy(req.ifname, loopback_link->name.c_str(), sizeof(req.ifname)); + req.hdr.nlmsg_len = + NLMSG_LENGTH(sizeof(req.ifm)) + NLMSG_ALIGN(req.rtattr.rta_len); + + bool found = false; + ASSERT_NO_ERRNO(NetlinkRequestResponse( + fd, &req, sizeof(req), + [&](const struct nlmsghdr* hdr) { + CheckLinkMsg(hdr, *loopback_link); + found = true; + }, + false)); + EXPECT_TRUE(found) << "Netlink response does not contain any links."; +} + +TEST(NetlinkRouteTest, GetLinkByIndexNotFound) { + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + + struct request { + struct nlmsghdr hdr; + struct ifinfomsg ifm; + }; + + struct request req = {}; + req.hdr.nlmsg_len = sizeof(req); + req.hdr.nlmsg_type = RTM_GETLINK; + req.hdr.nlmsg_flags = NLM_F_REQUEST; + req.hdr.nlmsg_seq = kSeq; + req.ifm.ifi_family = AF_UNSPEC; + req.ifm.ifi_index = 1234590; + + EXPECT_THAT(NetlinkRequestAckOrError(fd, kSeq, &req, sizeof(req)), + PosixErrorIs(ENODEV, ::testing::_)); +} + +TEST(NetlinkRouteTest, GetLinkByNameNotFound) { + const std::string name = "nodevice?!"; + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + + struct request { + struct nlmsghdr hdr; + struct ifinfomsg ifm; + struct rtattr rtattr; + char ifname[IFNAMSIZ]; + char pad[NLMSG_ALIGNTO + RTA_ALIGNTO]; + }; + + struct request req = {}; + req.hdr.nlmsg_type = RTM_GETLINK; + req.hdr.nlmsg_flags = NLM_F_REQUEST; + req.hdr.nlmsg_seq = kSeq; + req.ifm.ifi_family = AF_UNSPEC; + req.rtattr.rta_type = IFLA_IFNAME; + req.rtattr.rta_len = RTA_LENGTH(name.size() + 1); + strncpy(req.ifname, name.c_str(), sizeof(req.ifname)); + req.hdr.nlmsg_len = + NLMSG_LENGTH(sizeof(req.ifm)) + NLMSG_ALIGN(req.rtattr.rta_len); + + EXPECT_THAT(NetlinkRequestAckOrError(fd, kSeq, &req, sizeof(req)), + PosixErrorIs(ENODEV, ::testing::_)); +} + +TEST(NetlinkRouteTest, MsgHdrMsgUnsuppType) { + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + + struct request { + struct nlmsghdr hdr; + struct ifinfomsg ifm; + }; struct request req = {}; req.hdr.nlmsg_len = sizeof(req); @@ -175,18 +346,8 @@ TEST(NetlinkRouteTest, MsgHdrMsgUnsuppType) { req.hdr.nlmsg_seq = kSeq; req.ifm.ifi_family = AF_UNSPEC; - ASSERT_NO_ERRNO(NetlinkRequestResponse( - fd, &req, sizeof(req), - [&](const struct nlmsghdr* hdr) { - EXPECT_THAT(hdr->nlmsg_type, Eq(NLMSG_ERROR)); - EXPECT_EQ(hdr->nlmsg_seq, kSeq); - EXPECT_GE(hdr->nlmsg_len, sizeof(*hdr) + sizeof(struct nlmsgerr)); - - const struct nlmsgerr* msg = - reinterpret_cast(NLMSG_DATA(hdr)); - EXPECT_EQ(msg->error, -EOPNOTSUPP); - }, - true)); + EXPECT_THAT(NetlinkRequestAckOrError(fd, kSeq, &req, sizeof(req)), + PosixErrorIs(EOPNOTSUPP, ::testing::_)); } TEST(NetlinkRouteTest, MsgHdrMsgTrunc) { @@ -198,8 +359,6 @@ TEST(NetlinkRouteTest, MsgHdrMsgTrunc) { struct ifinfomsg ifm; }; - constexpr uint32_t kSeq = 12345; - struct request req = {}; req.hdr.nlmsg_len = sizeof(req); req.hdr.nlmsg_type = RTM_GETLINK; @@ -238,8 +397,6 @@ TEST(NetlinkRouteTest, MsgTruncMsgHdrMsgTrunc) { struct ifinfomsg ifm; }; - constexpr uint32_t kSeq = 12345; - struct request req = {}; req.hdr.nlmsg_len = sizeof(req); req.hdr.nlmsg_type = RTM_GETLINK; @@ -282,8 +439,6 @@ TEST(NetlinkRouteTest, ControlMessageIgnored) { struct ifinfomsg ifm; }; - constexpr uint32_t kSeq = 12345; - struct request req = {}; // This control message is ignored. We still receive a response for the @@ -317,8 +472,6 @@ TEST(NetlinkRouteTest, GetAddrDump) { struct rtgenmsg rgm; }; - constexpr uint32_t kSeq = 12345; - struct request req; req.hdr.nlmsg_len = sizeof(req); req.hdr.nlmsg_type = RTM_GETADDR; @@ -367,6 +520,57 @@ TEST(NetlinkRouteTest, LookupAll) { ASSERT_GT(count, 0); } +TEST(NetlinkRouteTest, AddAddr) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + absl::optional loopback_link = + ASSERT_NO_ERRNO_AND_VALUE(FindLoopbackLink()); + ASSERT_TRUE(loopback_link.has_value()); + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + + struct request { + struct nlmsghdr hdr; + struct ifaddrmsg ifa; + struct rtattr rtattr; + struct in_addr addr; + char pad[NLMSG_ALIGNTO + RTA_ALIGNTO]; + }; + + struct request req = {}; + req.hdr.nlmsg_type = RTM_NEWADDR; + req.hdr.nlmsg_seq = kSeq; + req.ifa.ifa_family = AF_INET; + req.ifa.ifa_prefixlen = 24; + req.ifa.ifa_flags = 0; + req.ifa.ifa_scope = 0; + req.ifa.ifa_index = loopback_link->index; + req.rtattr.rta_type = IFA_LOCAL; + req.rtattr.rta_len = RTA_LENGTH(sizeof(req.addr)); + inet_pton(AF_INET, "10.0.0.1", &req.addr); + req.hdr.nlmsg_len = + NLMSG_LENGTH(sizeof(req.ifa)) + NLMSG_ALIGN(req.rtattr.rta_len); + + // Create should succeed, as no such address in kernel. + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_CREATE | NLM_F_ACK; + EXPECT_NO_ERRNO( + NetlinkRequestAckOrError(fd, req.hdr.nlmsg_seq, &req, req.hdr.nlmsg_len)); + + // Replace an existing address should succeed. + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_REPLACE | NLM_F_ACK; + req.hdr.nlmsg_seq++; + EXPECT_NO_ERRNO( + NetlinkRequestAckOrError(fd, req.hdr.nlmsg_seq, &req, req.hdr.nlmsg_len)); + + // Create exclusive should fail, as we created the address above. + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_CREATE | NLM_F_EXCL | NLM_F_ACK; + req.hdr.nlmsg_seq++; + EXPECT_THAT( + NetlinkRequestAckOrError(fd, req.hdr.nlmsg_seq, &req, req.hdr.nlmsg_len), + PosixErrorIs(EEXIST, ::testing::_)); +} + // GetRouteDump tests a RTM_GETROUTE + NLM_F_DUMP request. TEST(NetlinkRouteTest, GetRouteDump) { FileDescriptor fd = @@ -378,8 +582,6 @@ TEST(NetlinkRouteTest, GetRouteDump) { struct rtmsg rtm; }; - constexpr uint32_t kSeq = 12345; - struct request req = {}; req.hdr.nlmsg_len = sizeof(req); req.hdr.nlmsg_type = RTM_GETROUTE; @@ -538,8 +740,6 @@ TEST(NetlinkRouteTest, RecvmsgTrunc) { struct rtgenmsg rgm; }; - constexpr uint32_t kSeq = 12345; - struct request req; req.hdr.nlmsg_len = sizeof(req); req.hdr.nlmsg_type = RTM_GETADDR; @@ -615,8 +815,6 @@ TEST(NetlinkRouteTest, RecvmsgTruncPeek) { struct rtgenmsg rgm; }; - constexpr uint32_t kSeq = 12345; - struct request req; req.hdr.nlmsg_len = sizeof(req); req.hdr.nlmsg_type = RTM_GETADDR; @@ -695,8 +893,6 @@ TEST(NetlinkRouteTest, NoPasscredNoCreds) { struct rtgenmsg rgm; }; - constexpr uint32_t kSeq = 12345; - struct request req; req.hdr.nlmsg_len = sizeof(req); req.hdr.nlmsg_type = RTM_GETADDR; @@ -743,8 +939,6 @@ TEST(NetlinkRouteTest, PasscredCreds) { struct rtgenmsg rgm; }; - constexpr uint32_t kSeq = 12345; - struct request req; req.hdr.nlmsg_len = sizeof(req); req.hdr.nlmsg_type = RTM_GETADDR; diff --git a/test/syscalls/linux/socket_netlink_util.cc b/test/syscalls/linux/socket_netlink_util.cc index cd2212a1a..952eecfe8 100644 --- a/test/syscalls/linux/socket_netlink_util.cc +++ b/test/syscalls/linux/socket_netlink_util.cc @@ -16,6 +16,7 @@ #include #include +#include #include #include @@ -71,9 +72,10 @@ PosixError NetlinkRequestResponse( iov.iov_base = buf.data(); iov.iov_len = buf.size(); - // Response is a series of NLM_F_MULTI messages, ending with a NLMSG_DONE - // message. + // If NLM_F_MULTI is set, response is a series of messages that ends with a + // NLMSG_DONE message. int type = -1; + int flags = 0; do { int len; RETURN_ERROR_IF_SYSCALL_FAIL(len = RetryEINTR(recvmsg)(fd.get(), &msg, 0)); @@ -89,6 +91,7 @@ PosixError NetlinkRequestResponse( for (struct nlmsghdr* hdr = reinterpret_cast(buf.data()); NLMSG_OK(hdr, len); hdr = NLMSG_NEXT(hdr, len)) { fn(hdr); + flags = hdr->nlmsg_flags; type = hdr->nlmsg_type; // Done should include an integer payload for dump_done_errno. // See net/netlink/af_netlink.c:netlink_dump @@ -98,11 +101,11 @@ PosixError NetlinkRequestResponse( EXPECT_GE(hdr->nlmsg_len, NLMSG_LENGTH(sizeof(int))); } } - } while (type != NLMSG_DONE && type != NLMSG_ERROR); + } while ((flags & NLM_F_MULTI) && type != NLMSG_DONE && type != NLMSG_ERROR); if (expect_nlmsgerr) { EXPECT_EQ(type, NLMSG_ERROR); - } else { + } else if (flags & NLM_F_MULTI) { EXPECT_EQ(type, NLMSG_DONE); } return NoError(); @@ -146,5 +149,39 @@ PosixError NetlinkRequestResponseSingle( return NoError(); } +PosixError NetlinkRequestAckOrError(const FileDescriptor& fd, uint32_t seq, + void* request, size_t len) { + // Dummy negative number for no error message received. + // We won't get a negative error number so there will be no confusion. + int err = -42; + RETURN_IF_ERRNO(NetlinkRequestResponse( + fd, request, len, + [&](const struct nlmsghdr* hdr) { + EXPECT_EQ(NLMSG_ERROR, hdr->nlmsg_type); + EXPECT_EQ(hdr->nlmsg_seq, seq); + EXPECT_GE(hdr->nlmsg_len, sizeof(*hdr) + sizeof(struct nlmsgerr)); + + const struct nlmsgerr* msg = + reinterpret_cast(NLMSG_DATA(hdr)); + err = -msg->error; + }, + true)); + return PosixError(err); +} + +const struct rtattr* FindRtAttr(const struct nlmsghdr* hdr, + const struct ifinfomsg* msg, int16_t attr) { + const int ifi_space = NLMSG_SPACE(sizeof(*msg)); + int attrlen = hdr->nlmsg_len - ifi_space; + const struct rtattr* rta = reinterpret_cast( + reinterpret_cast(hdr) + NLMSG_ALIGN(ifi_space)); + for (; RTA_OK(rta, attrlen); rta = RTA_NEXT(rta, attrlen)) { + if (rta->rta_type == attr) { + return rta; + } + } + return nullptr; +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_netlink_util.h b/test/syscalls/linux/socket_netlink_util.h index 3678c0599..e13ead406 100644 --- a/test/syscalls/linux/socket_netlink_util.h +++ b/test/syscalls/linux/socket_netlink_util.h @@ -19,6 +19,7 @@ // socket.h has to be included before if_arp.h. #include #include +#include #include "test/util/file_descriptor.h" #include "test/util/posix_error.h" @@ -47,6 +48,14 @@ PosixError NetlinkRequestResponseSingle( const FileDescriptor& fd, void* request, size_t len, const std::function& fn); +// Send the passed request then expect and return an ack or error. +PosixError NetlinkRequestAckOrError(const FileDescriptor& fd, uint32_t seq, + void* request, size_t len); + +// Find rtnetlink attribute in message. +const struct rtattr* FindRtAttr(const struct nlmsghdr* hdr, + const struct ifinfomsg* msg, int16_t attr); + } // namespace testing } // namespace gvisor -- cgit v1.2.3 From 6bd59b4e08893281468e8af5aebb5fab0f7a8c0d Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Thu, 6 Feb 2020 11:12:41 -0800 Subject: Update link address for targets of Neighbor Adverts Get the link address for the target of an NDP Neighbor Advertisement from the NDP Target Link Layer Address option. Tests: - ipv6.TestNeighorAdvertisementWithTargetLinkLayerOption - ipv6.TestNeighorAdvertisementWithInvalidTargetLinkLayerOption PiperOrigin-RevId: 293632609 --- pkg/tcpip/network/ipv6/icmp.go | 44 ++++-- pkg/tcpip/network/ipv6/icmp_test.go | 186 +++++++++++++++--------- pkg/tcpip/network/ipv6/ndp_test.go | 278 +++++++++++++++++++++++++----------- pkg/tcpip/stack/ndp_test.go | 8 +- 4 files changed, 352 insertions(+), 164 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 60817d36d..45dc757c7 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -15,6 +15,8 @@ package ipv6 import ( + "log" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -194,7 +196,11 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P // TODO(b/148429853): Properly process the NS message and do Neighbor // Unreachability Detection. for { - opt, done, _ := it.Next() + opt, done, err := it.Next() + if err != nil { + // This should never happen as Iter(true) above did not return an error. + log.Fatalf("unexpected error when iterating over NDP options: %s", err) + } if done { break } @@ -253,21 +259,25 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P } na := header.NDPNeighborAdvert(h.NDPPayload()) + it, err := na.Options().Iter(true) + if err != nil { + // If we have a malformed NDP NA option, drop the packet. + received.Invalid.Increment() + return + } + targetAddr := na.TargetAddress() stack := r.Stack() rxNICID := r.NICID() - isTentative, err := stack.IsAddrTentative(rxNICID, targetAddr) - if err != nil { + if isTentative, err := stack.IsAddrTentative(rxNICID, targetAddr); err != nil { // We will only get an error if rxNICID is unrecognized, // which should not happen. For now short-circuit this // packet. // // TODO(b/141002840): Handle this better? return - } - - if isTentative { + } else if isTentative { // We just got an NA from a node that owns an address we // are performing DAD on, implying the address is not // unique. In this case we let the stack know so it can @@ -283,13 +293,29 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P // scenario is beyond the scope of RFC 4862. As such, we simply // ignore such a scenario for now and proceed as normal. // + // If the NA message has the target link layer option, update the link + // address cache with the link address for the target of the message. + // // TODO(b/143147598): Handle the scenario described above. Also // inform the netstack integration that a duplicate address was // detected outside of DAD. + // + // TODO(b/148429853): Properly process the NA message and do Neighbor + // Unreachability Detection. + for { + opt, done, err := it.Next() + if err != nil { + // This should never happen as Iter(true) above did not return an error. + log.Fatalf("unexpected error when iterating over NDP options: %s", err) + } + if done { + break + } - e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, r.RemoteLinkAddress) - if targetAddr != r.RemoteAddress { - e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, r.RemoteLinkAddress) + switch opt := opt.(type) { + case header.NDPTargetLinkLayerAddressOption: + e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, opt.EthernetAddress()) + } } case header.ICMPv6EchoRequest: diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index d0e930e20..50c4b6474 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -121,21 +121,60 @@ func TestICMPCounts(t *testing.T) { } defer r.Release() + var tllData [header.NDPLinkLayerAddressSize]byte + header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(linkAddr1), + }) + types := []struct { - typ header.ICMPv6Type - size int + typ header.ICMPv6Type + size int + extraData []byte }{ - {header.ICMPv6DstUnreachable, header.ICMPv6DstUnreachableMinimumSize}, - {header.ICMPv6PacketTooBig, header.ICMPv6PacketTooBigMinimumSize}, - {header.ICMPv6TimeExceeded, header.ICMPv6MinimumSize}, - {header.ICMPv6ParamProblem, header.ICMPv6MinimumSize}, - {header.ICMPv6EchoRequest, header.ICMPv6EchoMinimumSize}, - {header.ICMPv6EchoReply, header.ICMPv6EchoMinimumSize}, - {header.ICMPv6RouterSolicit, header.ICMPv6MinimumSize}, - {header.ICMPv6RouterAdvert, header.ICMPv6HeaderSize + header.NDPRAMinimumSize}, - {header.ICMPv6NeighborSolicit, header.ICMPv6NeighborSolicitMinimumSize}, - {header.ICMPv6NeighborAdvert, header.ICMPv6NeighborAdvertSize}, - {header.ICMPv6RedirectMsg, header.ICMPv6MinimumSize}, + { + typ: header.ICMPv6DstUnreachable, + size: header.ICMPv6DstUnreachableMinimumSize, + }, + { + typ: header.ICMPv6PacketTooBig, + size: header.ICMPv6PacketTooBigMinimumSize, + }, + { + typ: header.ICMPv6TimeExceeded, + size: header.ICMPv6MinimumSize, + }, + { + typ: header.ICMPv6ParamProblem, + size: header.ICMPv6MinimumSize, + }, + { + typ: header.ICMPv6EchoRequest, + size: header.ICMPv6EchoMinimumSize, + }, + { + typ: header.ICMPv6EchoReply, + size: header.ICMPv6EchoMinimumSize, + }, + { + typ: header.ICMPv6RouterSolicit, + size: header.ICMPv6MinimumSize, + }, + { + typ: header.ICMPv6RouterAdvert, + size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize, + }, + { + typ: header.ICMPv6NeighborSolicit, + size: header.ICMPv6NeighborSolicitMinimumSize}, + { + typ: header.ICMPv6NeighborAdvert, + size: header.ICMPv6NeighborAdvertMinimumSize, + extraData: tllData[:], + }, + { + typ: header.ICMPv6RedirectMsg, + size: header.ICMPv6MinimumSize, + }, } handleIPv6Payload := func(hdr buffer.Prependable) { @@ -154,10 +193,13 @@ func TestICMPCounts(t *testing.T) { } for _, typ := range types { - hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size) + extraDataLen := len(typ.extraData) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen) + extraData := buffer.View(hdr.Prepend(extraDataLen)) + copy(extraData, typ.extraData) pkt := header.ICMPv6(hdr.Prepend(typ.size)) pkt.SetType(typ.typ) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, extraData.ToVectorisedView())) handleIPv6Payload(hdr) } @@ -372,97 +414,104 @@ func TestLinkResolution(t *testing.T) { } func TestICMPChecksumValidationSimple(t *testing.T) { + var tllData [header.NDPLinkLayerAddressSize]byte + header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(linkAddr1), + }) + types := []struct { name string typ header.ICMPv6Type size int + extraData []byte statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter }{ { - "DstUnreachable", - header.ICMPv6DstUnreachable, - header.ICMPv6DstUnreachableMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + name: "DstUnreachable", + typ: header.ICMPv6DstUnreachable, + size: header.ICMPv6DstUnreachableMinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.DstUnreachable }, }, { - "PacketTooBig", - header.ICMPv6PacketTooBig, - header.ICMPv6PacketTooBigMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + name: "PacketTooBig", + typ: header.ICMPv6PacketTooBig, + size: header.ICMPv6PacketTooBigMinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.PacketTooBig }, }, { - "TimeExceeded", - header.ICMPv6TimeExceeded, - header.ICMPv6MinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + name: "TimeExceeded", + typ: header.ICMPv6TimeExceeded, + size: header.ICMPv6MinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.TimeExceeded }, }, { - "ParamProblem", - header.ICMPv6ParamProblem, - header.ICMPv6MinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + name: "ParamProblem", + typ: header.ICMPv6ParamProblem, + size: header.ICMPv6MinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.ParamProblem }, }, { - "EchoRequest", - header.ICMPv6EchoRequest, - header.ICMPv6EchoMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + name: "EchoRequest", + typ: header.ICMPv6EchoRequest, + size: header.ICMPv6EchoMinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.EchoRequest }, }, { - "EchoReply", - header.ICMPv6EchoReply, - header.ICMPv6EchoMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + name: "EchoReply", + typ: header.ICMPv6EchoReply, + size: header.ICMPv6EchoMinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.EchoReply }, }, { - "RouterSolicit", - header.ICMPv6RouterSolicit, - header.ICMPv6MinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + name: "RouterSolicit", + typ: header.ICMPv6RouterSolicit, + size: header.ICMPv6MinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.RouterSolicit }, }, { - "RouterAdvert", - header.ICMPv6RouterAdvert, - header.ICMPv6HeaderSize + header.NDPRAMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + name: "RouterAdvert", + typ: header.ICMPv6RouterAdvert, + size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.RouterAdvert }, }, { - "NeighborSolicit", - header.ICMPv6NeighborSolicit, - header.ICMPv6NeighborSolicitMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + name: "NeighborSolicit", + typ: header.ICMPv6NeighborSolicit, + size: header.ICMPv6NeighborSolicitMinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.NeighborSolicit }, }, { - "NeighborAdvert", - header.ICMPv6NeighborAdvert, - header.ICMPv6NeighborAdvertSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + name: "NeighborAdvert", + typ: header.ICMPv6NeighborAdvert, + size: header.ICMPv6NeighborAdvertMinimumSize, + extraData: tllData[:], + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.NeighborAdvert }, }, { - "RedirectMsg", - header.ICMPv6RedirectMsg, - header.ICMPv6MinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + name: "RedirectMsg", + typ: header.ICMPv6RedirectMsg, + size: header.ICMPv6MinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.RedirectMsg }, }, @@ -494,16 +543,19 @@ func TestICMPChecksumValidationSimple(t *testing.T) { ) } - handleIPv6Payload := func(typ header.ICMPv6Type, size int, checksum bool) { - hdr := buffer.NewPrependable(header.IPv6MinimumSize + size) - pkt := header.ICMPv6(hdr.Prepend(size)) - pkt.SetType(typ) + handleIPv6Payload := func(checksum bool) { + extraDataLen := len(typ.extraData) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen) + extraData := buffer.View(hdr.Prepend(extraDataLen)) + copy(extraData, typ.extraData) + pkt := header.ICMPv6(hdr.Prepend(typ.size)) + pkt.SetType(typ.typ) if checksum { - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, extraData.ToVectorisedView())) } ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(size), + PayloadLength: uint16(typ.size + extraDataLen), NextHeader: uint8(header.ICMPv6ProtocolNumber), HopLimit: header.NDPHopLimit, SrcAddr: lladdr1, @@ -528,7 +580,7 @@ func TestICMPChecksumValidationSimple(t *testing.T) { // Without setting checksum, the incoming packet should // be invalid. - handleIPv6Payload(typ.typ, typ.size, false) + handleIPv6Payload(false) if got := invalid.Value(); got != 1 { t.Fatalf("got invalid = %d, want = 1", got) } @@ -538,7 +590,7 @@ func TestICMPChecksumValidationSimple(t *testing.T) { } // When checksum is set, it should be received. - handleIPv6Payload(typ.typ, typ.size, true) + handleIPv6Payload(true) if got := typStat.Value(); got != 1 { t.Fatalf("got %s = %d, want = 1", typ.name, got) } diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index bd732f93f..c9395de52 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -70,76 +70,29 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack return s, ep } -// TestNeighorSolicitationWithSourceLinkLayerOption tests that receiving an -// NDP NS message with the Source Link Layer Address option results in a +// TestNeighorSolicitationWithSourceLinkLayerOption tests that receiving a +// valid NDP NS message with the Source Link Layer Address option results in a // new entry in the link address cache for the sender of the message. func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { const nicID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - }) - e := channel.New(0, 1280, linkAddr0) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) - } - - ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize - hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) - pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) - pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.NDPPayload()) - ns.SetTargetAddress(lladdr0) - ns.Options().Serialize(header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(linkAddr1), - }) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 255, - SrcAddr: lladdr1, - DstAddr: lladdr0, - }) - e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ - Data: hdr.View().ToVectorisedView(), - }) - - linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil) - if err != nil { - t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err) - } - if c != nil { - t.Errorf("got unexpected channel") - } - if linkAddr != linkAddr1 { - t.Errorf("got link address = %s, want = %s", linkAddr, linkAddr1) - } -} - -// TestNeighorSolicitationWithInvalidSourceLinkLayerOption tests that receiving -// an NDP NS message with an invalid Source Link Layer Address option does not -// result in a new entry in the link address cache for the sender of the -// message. -func TestNeighorSolicitationWithInvalidSourceLinkLayerOption(t *testing.T) { - const nicID = 1 - tests := []struct { - name string - optsBuf []byte + name string + optsBuf []byte + expectedLinkAddr tcpip.LinkAddress }{ + { + name: "Valid", + optsBuf: []byte{1, 1, 2, 3, 4, 5, 6, 7}, + expectedLinkAddr: "\x02\x03\x04\x05\x06\x07", + }, { name: "Too Small", - optsBuf: []byte{1, 1, 1, 2, 3, 4, 5}, + optsBuf: []byte{1, 1, 2, 3, 4, 5, 6}, }, { name: "Invalid Length", - optsBuf: []byte{1, 2, 1, 2, 3, 4, 5, 6}, + optsBuf: []byte{1, 2, 2, 3, 4, 5, 6, 7}, }, } @@ -186,20 +139,138 @@ func TestNeighorSolicitationWithInvalidSourceLinkLayerOption(t *testing.T) { Data: hdr.View().ToVectorisedView(), }) - // Invalid count should have increased. - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) + linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil) + if linkAddr != test.expectedLinkAddr { + t.Errorf("got link address = %s, want = %s", linkAddr, test.expectedLinkAddr) } - linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil) - if err != tcpip.ErrWouldBlock { - t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock) + if test.expectedLinkAddr != "" { + if err != nil { + t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err) + } + if c != nil { + t.Errorf("got unexpected channel") + } + + // Invalid count should not have increased. + if got := invalid.Value(); got != 0 { + t.Errorf("got invalid = %d, want = 0", got) + } + } else { + if err != tcpip.ErrWouldBlock { + t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock) + } + if c == nil { + t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber) + } + + // Invalid count should have increased. + if got := invalid.Value(); got != 1 { + t.Errorf("got invalid = %d, want = 1", got) + } + } + }) + } +} + +// TestNeighorAdvertisementWithTargetLinkLayerOption tests that receiving a +// valid NDP NA message with the Target Link Layer Address option results in a +// new entry in the link address cache for the target of the message. +func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { + const nicID = 1 + + tests := []struct { + name string + optsBuf []byte + expectedLinkAddr tcpip.LinkAddress + }{ + { + name: "Valid", + optsBuf: []byte{2, 1, 2, 3, 4, 5, 6, 7}, + expectedLinkAddr: "\x02\x03\x04\x05\x06\x07", + }, + { + name: "Too Small", + optsBuf: []byte{2, 1, 2, 3, 4, 5, 6}, + }, + { + name: "Invalid Length", + optsBuf: []byte{2, 2, 2, 3, 4, 5, 6, 7}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + e := channel.New(0, 1280, linkAddr0) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) + } + + ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + len(test.optsBuf) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) + pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) + pkt.SetType(header.ICMPv6NeighborAdvert) + ns := header.NDPNeighborAdvert(pkt.NDPPayload()) + ns.SetTargetAddress(lladdr1) + opts := ns.Options() + copy(opts, test.optsBuf) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: 255, + SrcAddr: lladdr1, + DstAddr: lladdr0, + }) + + invalid := s.Stats().ICMP.V6PacketsReceived.Invalid + + // Invalid count should initially be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) } - if c == nil { - t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber) + + e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + + linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil) + if linkAddr != test.expectedLinkAddr { + t.Errorf("got link address = %s, want = %s", linkAddr, test.expectedLinkAddr) } - if linkAddr != "" { - t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (%s, _, ), want = ('', _, _)", nicID, lladdr1, lladdr0, ProtocolNumber, linkAddr) + + if test.expectedLinkAddr != "" { + if err != nil { + t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err) + } + if c != nil { + t.Errorf("got unexpected channel") + } + + // Invalid count should not have increased. + if got := invalid.Value(); got != 0 { + t.Errorf("got invalid = %d, want = 0", got) + } + } else { + if err != tcpip.ErrWouldBlock { + t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock) + } + if c == nil { + t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber) + } + + // Invalid count should have increased. + if got := invalid.Value(); got != 1 { + t.Errorf("got invalid = %d, want = 1", got) + } } }) } @@ -238,27 +309,59 @@ func TestHopLimitValidation(t *testing.T) { }) } + var tllData [header.NDPLinkLayerAddressSize]byte + header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(linkAddr1), + }) + types := []struct { name string typ header.ICMPv6Type size int + extraData []byte statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter }{ - {"RouterSolicit", header.ICMPv6RouterSolicit, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RouterSolicit - }}, - {"RouterAdvert", header.ICMPv6RouterAdvert, header.ICMPv6HeaderSize + header.NDPRAMinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RouterAdvert - }}, - {"NeighborSolicit", header.ICMPv6NeighborSolicit, header.ICMPv6NeighborSolicitMinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.NeighborSolicit - }}, - {"NeighborAdvert", header.ICMPv6NeighborAdvert, header.ICMPv6NeighborAdvertSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.NeighborAdvert - }}, - {"RedirectMsg", header.ICMPv6RedirectMsg, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RedirectMsg - }}, + { + name: "RouterSolicit", + typ: header.ICMPv6RouterSolicit, + size: header.ICMPv6MinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RouterSolicit + }, + }, + { + name: "RouterAdvert", + typ: header.ICMPv6RouterAdvert, + size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RouterAdvert + }, + }, + { + name: "NeighborSolicit", + typ: header.ICMPv6NeighborSolicit, + size: header.ICMPv6NeighborSolicitMinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.NeighborSolicit + }, + }, + { + name: "NeighborAdvert", + typ: header.ICMPv6NeighborAdvert, + size: header.ICMPv6NeighborAdvertMinimumSize, + extraData: tllData[:], + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.NeighborAdvert + }, + }, + { + name: "RedirectMsg", + typ: header.ICMPv6RedirectMsg, + size: header.ICMPv6MinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RedirectMsg + }, + }, } for _, typ := range types { @@ -270,10 +373,13 @@ func TestHopLimitValidation(t *testing.T) { invalid := stats.Invalid typStat := typ.statCounter(stats) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size) + extraDataLen := len(typ.extraData) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen) + extraData := buffer.View(hdr.Prepend(extraDataLen)) + copy(extraData, typ.extraData) pkt := header.ICMPv6(hdr.Prepend(typ.size)) pkt.SetType(typ.typ) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, extraData.ToVectorisedView())) // Invalid count should initially be 0. if got := invalid.Value(); got != 0 { diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 8af8565f7..9a4607dcb 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -478,13 +478,17 @@ func TestDADFail(t *testing.T) { { "RxAdvert", func(tgt tcpip.Address) buffer.Prependable { - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) + naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize + hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize) + pkt := header.ICMPv6(hdr.Prepend(naSize)) pkt.SetType(header.ICMPv6NeighborAdvert) na := header.NDPNeighborAdvert(pkt.NDPPayload()) na.SetSolicitedFlag(true) na.SetOverrideFlag(true) na.SetTargetAddress(tgt) + na.Options().Serialize(header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(linkAddr1), + }) pkt.SetChecksum(header.ICMPv6Checksum(pkt, tgt, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) -- cgit v1.2.3 From 940d255971c38af9f91ceed1345fd973f8fdb41d Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Thu, 6 Feb 2020 15:57:34 -0800 Subject: Perform DAD on IPv6 addresses when enabling a NIC Addresses may be added before a NIC is enabled. Make sure DAD is performed on the permanent IPv6 addresses when they get enabled. Test: - stack_test.TestDoDADWhenNICEnabled - stack.TestDisabledRxStatsWhenNICDisabled PiperOrigin-RevId: 293697429 --- pkg/tcpip/stack/BUILD | 6 ++- pkg/tcpip/stack/ndp_test.go | 74 +++++++++++++-------------- pkg/tcpip/stack/nic.go | 84 ++++++++++++++++++++++-------- pkg/tcpip/stack/nic_test.go | 62 +++++++++++++++++++++++ pkg/tcpip/stack/stack_test.go | 115 ++++++++++++++++++++++++++++++++++++++++++ pkg/tcpip/tcpip.go | 8 +-- 6 files changed, 287 insertions(+), 62 deletions(-) create mode 100644 pkg/tcpip/stack/nic_test.go (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index f5b750046..705cf01ee 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -78,11 +78,15 @@ go_test( go_test( name = "stack_test", size = "small", - srcs = ["linkaddrcache_test.go"], + srcs = [ + "linkaddrcache_test.go", + "nic_test.go", + ], library = ":stack", deps = [ "//pkg/sleep", "//pkg/sync", "//pkg/tcpip", + "//pkg/tcpip/buffer", ], ) diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 9a4607dcb..1e575bdaf 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -1539,7 +1539,7 @@ func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { } // Checks to see if list contains an IPv6 address, item. -func contains(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) bool { +func containsV6Addr(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) bool { protocolAddress := tcpip.ProtocolAddress{ Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: item, @@ -1665,7 +1665,7 @@ func TestAutoGenAddr(t *testing.T) { // with non-zero lifetime. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) expectAutoGenAddrEvent(addr1, newAddr) - if !contains(s.NICInfo()[1].ProtocolAddresses, addr1) { + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { t.Fatalf("Should have %s in the list of addresses", addr1) } @@ -1681,10 +1681,10 @@ func TestAutoGenAddr(t *testing.T) { // Receive an RA with prefix2 in a PI. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) expectAutoGenAddrEvent(addr2, newAddr) - if !contains(s.NICInfo()[1].ProtocolAddresses, addr1) { + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { t.Fatalf("Should have %s in the list of addresses", addr1) } - if !contains(s.NICInfo()[1].ProtocolAddresses, addr2) { + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) { t.Fatalf("Should have %s in the list of addresses", addr2) } @@ -1705,10 +1705,10 @@ func TestAutoGenAddr(t *testing.T) { case <-time.After(newMinVLDuration + defaultAsyncEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } - if contains(s.NICInfo()[1].ProtocolAddresses, addr1) { + if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { t.Fatalf("Should not have %s in the list of addresses", addr1) } - if !contains(s.NICInfo()[1].ProtocolAddresses, addr2) { + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) { t.Fatalf("Should have %s in the list of addresses", addr2) } } @@ -1853,7 +1853,7 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) { // Receive PI for prefix1. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) expectAutoGenAddrEvent(addr1, newAddr) - if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should have %s in the list of addresses", addr1) } expectPrimaryAddr(addr1) @@ -1861,7 +1861,7 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) { // Deprecate addr for prefix1 immedaitely. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) expectAutoGenAddrEvent(addr1, deprecatedAddr) - if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should have %s in the list of addresses", addr1) } // addr should still be the primary endpoint as there are no other addresses. @@ -1879,7 +1879,7 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) { // Receive PI for prefix2. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) expectAutoGenAddrEvent(addr2, newAddr) - if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { t.Fatalf("should have %s in the list of addresses", addr2) } expectPrimaryAddr(addr2) @@ -1887,7 +1887,7 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) { // Deprecate addr for prefix2 immedaitely. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) expectAutoGenAddrEvent(addr2, deprecatedAddr) - if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { t.Fatalf("should have %s in the list of addresses", addr2) } // addr1 should be the primary endpoint now since addr2 is deprecated but @@ -1982,7 +1982,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { // Receive PI for prefix2. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) expectAutoGenAddrEvent(addr2, newAddr) - if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { t.Fatalf("should have %s in the list of addresses", addr2) } expectPrimaryAddr(addr2) @@ -1990,10 +1990,10 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { // Receive a PI for prefix1. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 90)) expectAutoGenAddrEvent(addr1, newAddr) - if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should have %s in the list of addresses", addr1) } - if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { t.Fatalf("should have %s in the list of addresses", addr2) } expectPrimaryAddr(addr1) @@ -2009,10 +2009,10 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { // Wait for addr of prefix1 to be deprecated. expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncEventTimeout) - if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should not have %s in the list of addresses", addr1) } - if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { t.Fatalf("should have %s in the list of addresses", addr2) } // addr2 should be the primary endpoint now since addr1 is deprecated but @@ -2049,10 +2049,10 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { // Wait for addr of prefix1 to be deprecated. expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncEventTimeout) - if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should not have %s in the list of addresses", addr1) } - if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { t.Fatalf("should have %s in the list of addresses", addr2) } // addr2 should be the primary endpoint now since it is not deprecated. @@ -2063,10 +2063,10 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { // Wait for addr of prefix1 to be invalidated. expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncEventTimeout) - if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should not have %s in the list of addresses", addr1) } - if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { t.Fatalf("should have %s in the list of addresses", addr2) } expectPrimaryAddr(addr2) @@ -2112,10 +2112,10 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { case <-time.After(newMinVLDuration + defaultAsyncEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } - if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should not have %s in the list of addresses", addr1) } - if contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { t.Fatalf("should not have %s in the list of addresses", addr2) } // Should not have any primary endpoints. @@ -2600,7 +2600,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { if err := s.AddProtocolAddress(1, tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr}); err != nil { t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr.Address, err) } - if !contains(s.NICInfo()[1].ProtocolAddresses, addr) { + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) { t.Fatalf("Should have %s in the list of addresses", addr1) } @@ -2613,7 +2613,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { t.Fatal("unexpectedly received an auto gen addr event for an address we already have statically") default: } - if !contains(s.NICInfo()[1].ProtocolAddresses, addr) { + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) { t.Fatalf("Should have %s in the list of addresses", addr1) } @@ -2624,7 +2624,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { t.Fatal("unexpectedly received an auto gen addr event") case <-time.After(lifetimeSeconds*time.Second + defaultTimeout): } - if !contains(s.NICInfo()[1].ProtocolAddresses, addr) { + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) { t.Fatalf("Should have %s in the list of addresses", addr1) } } @@ -2702,17 +2702,17 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { const validLifetimeSecondPrefix1 = 1 e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, validLifetimeSecondPrefix1, 0)) expectAutoGenAddrEvent(addr1, newAddr) - if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should have %s in the list of addresses", addr1) } // Receive an RA with prefix2 in a PI with a large valid lifetime. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) expectAutoGenAddrEvent(addr2, newAddr) - if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should have %s in the list of addresses", addr1) } - if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { t.Fatalf("should have %s in the list of addresses", addr2) } @@ -2725,10 +2725,10 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultAsyncEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } - if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should not have %s in the list of addresses", addr1) } - if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { t.Fatalf("should have %s in the list of addresses", addr2) } } @@ -3014,16 +3014,16 @@ func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) { nicinfo := s.NICInfo() nic1Addrs := nicinfo[nicID1].ProtocolAddresses nic2Addrs := nicinfo[nicID2].ProtocolAddresses - if !contains(nic1Addrs, e1Addr1) { + if !containsV6Addr(nic1Addrs, e1Addr1) { t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) } - if !contains(nic1Addrs, e1Addr2) { + if !containsV6Addr(nic1Addrs, e1Addr2) { t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) } - if !contains(nic2Addrs, e2Addr1) { + if !containsV6Addr(nic2Addrs, e2Addr1) { t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) } - if !contains(nic2Addrs, e2Addr2) { + if !containsV6Addr(nic2Addrs, e2Addr2) { t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs) } @@ -3102,16 +3102,16 @@ func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) { nicinfo = s.NICInfo() nic1Addrs = nicinfo[nicID1].ProtocolAddresses nic2Addrs = nicinfo[nicID2].ProtocolAddresses - if contains(nic1Addrs, e1Addr1) { + if containsV6Addr(nic1Addrs, e1Addr1) { t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) } - if contains(nic1Addrs, e1Addr2) { + if containsV6Addr(nic1Addrs, e1Addr2) { t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) } - if contains(nic2Addrs, e2Addr1) { + if containsV6Addr(nic2Addrs, e2Addr1) { t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) } - if contains(nic2Addrs, e2Addr2) { + if containsV6Addr(nic2Addrs, e2Addr2) { t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs) } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 7dad9a8cb..682e9c416 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -16,6 +16,7 @@ package stack import ( "log" + "reflect" "sort" "strings" "sync/atomic" @@ -39,6 +40,7 @@ type NIC struct { mu struct { sync.RWMutex + enabled bool spoofing bool promiscuous bool primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint @@ -56,6 +58,14 @@ type NIC struct { type NICStats struct { Tx DirectionStats Rx DirectionStats + + DisabledRx DirectionStats +} + +func makeNICStats() NICStats { + var s NICStats + tcpip.InitStatCounters(reflect.ValueOf(&s).Elem()) + return s } // DirectionStats includes packet and byte counts. @@ -99,16 +109,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC name: name, linkEP: ep, context: ctx, - stats: NICStats{ - Tx: DirectionStats{ - Packets: &tcpip.StatCounter{}, - Bytes: &tcpip.StatCounter{}, - }, - Rx: DirectionStats{ - Packets: &tcpip.StatCounter{}, - Bytes: &tcpip.StatCounter{}, - }, - }, + stats: makeNICStats(), } nic.mu.primary = make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint) nic.mu.endpoints = make(map[NetworkEndpointID]*referencedNetworkEndpoint) @@ -137,14 +138,30 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC // enable enables the NIC. enable will attach the link to its LinkEndpoint and // join the IPv6 All-Nodes Multicast address (ff02::1). func (n *NIC) enable() *tcpip.Error { + n.mu.RLock() + enabled := n.mu.enabled + n.mu.RUnlock() + if enabled { + return nil + } + + n.mu.Lock() + defer n.mu.Unlock() + + if n.mu.enabled { + return nil + } + + n.mu.enabled = true + n.attachLinkEndpoint() // Create an endpoint to receive broadcast packets on this interface. if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok { - if err := n.AddAddress(tcpip.ProtocolAddress{ + if _, err := n.addAddressLocked(tcpip.ProtocolAddress{ Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize}, - }, NeverPrimaryEndpoint); err != nil { + }, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil { return err } } @@ -166,8 +183,22 @@ func (n *NIC) enable() *tcpip.Error { return nil } - n.mu.Lock() - defer n.mu.Unlock() + // Perform DAD on the all the unicast IPv6 endpoints that are in the permanent + // state. + // + // Addresses may have aleady completed DAD but in the time since the NIC was + // last enabled, other devices may have acquired the same addresses. + for _, r := range n.mu.endpoints { + addr := r.ep.ID().LocalAddress + if k := r.getKind(); (k != permanent && k != permanentTentative) || !header.IsV6UnicastAddress(addr) { + continue + } + + r.setKind(permanentTentative) + if err := n.mu.ndp.startDuplicateAddressDetection(addr, r); err != nil { + return err + } + } if err := n.joinGroupLocked(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress); err != nil { return err @@ -633,7 +664,9 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar isIPv6Unicast := protocolAddress.Protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(protocolAddress.AddressWithPrefix.Address) // If the address is an IPv6 address and it is a permanent address, - // mark it as tentative so it goes through the DAD process. + // mark it as tentative so it goes through the DAD process if the NIC is + // enabled. If the NIC is not enabled, DAD will be started when the NIC is + // enabled. if isIPv6Unicast && kind == permanent { kind = permanentTentative } @@ -668,8 +701,8 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar n.insertPrimaryEndpointLocked(ref, peb) - // If we are adding a tentative IPv6 address, start DAD. - if isIPv6Unicast && kind == permanentTentative { + // If we are adding a tentative IPv6 address, start DAD if the NIC is enabled. + if isIPv6Unicast && kind == permanentTentative && n.mu.enabled { if err := n.mu.ndp.startDuplicateAddressDetection(protocolAddress.AddressWithPrefix.Address, ref); err != nil { return nil, err } @@ -700,9 +733,7 @@ func (n *NIC) AllAddresses() []tcpip.ProtocolAddress { // Don't include tentative, expired or temporary endpoints to // avoid confusion and prevent the caller from using those. switch ref.getKind() { - case permanentTentative, permanentExpired, temporary: - // TODO(b/140898488): Should tentative addresses be - // returned? + case permanentExpired, temporary: continue } addrs = append(addrs, tcpip.ProtocolAddress{ @@ -1016,11 +1047,23 @@ func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, // This rule applies only to the slice itself, not to the items of the slice; // the ownership of the items is not retained by the caller. func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { + n.mu.RLock() + enabled := n.mu.enabled + // If the NIC is not yet enabled, don't receive any packets. + if !enabled { + n.mu.RUnlock() + + n.stats.DisabledRx.Packets.Increment() + n.stats.DisabledRx.Bytes.IncrementBy(uint64(pkt.Data.Size())) + return + } + n.stats.Rx.Packets.Increment() n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data.Size())) netProto, ok := n.stack.networkProtocols[protocol] if !ok { + n.mu.RUnlock() n.stack.stats.UnknownProtocolRcvdPackets.Increment() return } @@ -1032,7 +1075,6 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link } // Are any packet sockets listening for this network protocol? - n.mu.RLock() packetEPs := n.mu.packetEPs[protocol] // Check whether there are packet sockets listening for every protocol. // If we received a packet with protocol EthernetProtocolAll, then the diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go new file mode 100644 index 000000000..edaee3b86 --- /dev/null +++ b/pkg/tcpip/stack/nic_test.go @@ -0,0 +1,62 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { + // When the NIC is disabled, the only field that matters is the stats field. + // This test is limited to stats counter checks. + nic := NIC{ + stats: makeNICStats(), + } + + if got := nic.stats.DisabledRx.Packets.Value(); got != 0 { + t.Errorf("got DisabledRx.Packets = %d, want = 0", got) + } + if got := nic.stats.DisabledRx.Bytes.Value(); got != 0 { + t.Errorf("got DisabledRx.Bytes = %d, want = 0", got) + } + if got := nic.stats.Rx.Packets.Value(); got != 0 { + t.Errorf("got Rx.Packets = %d, want = 0", got) + } + if got := nic.stats.Rx.Bytes.Value(); got != 0 { + t.Errorf("got Rx.Bytes = %d, want = 0", got) + } + + if t.Failed() { + t.FailNow() + } + + nic.DeliverNetworkPacket(nil, "", "", 0, tcpip.PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()}) + + if got := nic.stats.DisabledRx.Packets.Value(); got != 1 { + t.Errorf("got DisabledRx.Packets = %d, want = 1", got) + } + if got := nic.stats.DisabledRx.Bytes.Value(); got != 4 { + t.Errorf("got DisabledRx.Bytes = %d, want = 4", got) + } + if got := nic.stats.Rx.Packets.Value(); got != 0 { + t.Errorf("got Rx.Packets = %d, want = 0", got) + } + if got := nic.stats.Rx.Bytes.Value(); got != 0 { + t.Errorf("got Rx.Bytes = %d, want = 0", got) + } +} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 834fe9487..243868f3a 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -2561,3 +2561,118 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { }) } } + +// TestDoDADWhenNICEnabled tests that IPv6 endpoints that were added while a NIC +// was disabled have DAD performed on them when the NIC is enabled. +func TestDoDADWhenNICEnabled(t *testing.T) { + t.Parallel() + + const dadTransmits = 1 + const retransmitTimer = time.Second + const nicID = 1 + + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent), + } + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + }, + NDPDisp: &ndpDisp, + } + + e := channel.New(dadTransmits, 1280, linkAddr1) + s := stack.New(opts) + nicOpts := stack.NICOptions{Disabled: true} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err) + } + + addr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: llAddr1, + PrefixLen: 128, + }, + } + if err := s.AddProtocolAddress(nicID, addr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, addr, err) + } + + // Address should be in the list of all addresses. + if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { + t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) + } + + // Address should be tentative so it should not be a main address. + got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) + } + if want := (tcpip.AddressWithPrefix{}); got != want { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, want) + } + + // Enabling the NIC should start DAD for the address. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { + t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) + } + + // Address should not be considered bound to the NIC yet (DAD ongoing). + got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) + } + if want := (tcpip.AddressWithPrefix{}); got != want { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, want) + } + + // Wait for DAD to resolve. + select { + case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout): + t.Fatal("timed out waiting for DAD resolution") + case e := <-ndpDisp.dadC: + if e.err != nil { + t.Fatal("got DAD error: ", e.err) + } + if e.nicID != nicID { + t.Fatalf("got DAD event w/ nicID = %d, want = %d", e.nicID, nicID) + } + if e.addr != addr.AddressWithPrefix.Address { + t.Fatalf("got DAD event w/ addr = %s, want = %s", e.addr, addr.AddressWithPrefix.Address) + } + if !e.resolved { + t.Fatal("got DAD event w/ resolved = false, want = true") + } + } + if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { + t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) + } + got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) + } + if got != addr.AddressWithPrefix { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr.AddressWithPrefix) + } + + // Enabling the NIC again should be a no-op. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { + t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) + } + got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) + } + if got != addr.AddressWithPrefix { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, addr.AddressWithPrefix) + } +} diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index d29d9a704..0e944712f 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -1170,7 +1170,9 @@ type TransportEndpointStats struct { // marker interface. func (*TransportEndpointStats) IsEndpointStats() {} -func fillIn(v reflect.Value) { +// InitStatCounters initializes v's fields with nil StatCounter fields to new +// StatCounters. +func InitStatCounters(v reflect.Value) { for i := 0; i < v.NumField(); i++ { v := v.Field(i) if s, ok := v.Addr().Interface().(**StatCounter); ok { @@ -1178,14 +1180,14 @@ func fillIn(v reflect.Value) { *s = new(StatCounter) } } else { - fillIn(v) + InitStatCounters(v) } } } // FillIn returns a copy of s with nil fields initialized to new StatCounters. func (s Stats) FillIn() Stats { - fillIn(reflect.ValueOf(&s).Elem()) + InitStatCounters(reflect.ValueOf(&s).Elem()) return s } -- cgit v1.2.3 From 3700221b1f3ff0779a0f4479fd2bafa3312d5a23 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Thu, 6 Feb 2020 16:42:37 -0800 Subject: Auto-generate link-local address as a SLAAC address Auto-generated link-local addresses should have the same lifecycle hooks as global SLAAC addresses. The Stack's NDP dispatcher should be notified when link-local addresses are auto-generated and invalidated. They should also be removed when a NIC is disabled (which will be supported in a later change). Tests: - stack_test.TestNICAutoGenAddrWithOpaque - stack_test.TestNICAutoGenAddr PiperOrigin-RevId: 293706760 --- pkg/tcpip/stack/ndp.go | 36 +++-- pkg/tcpip/stack/ndp_test.go | 85 +++++++----- pkg/tcpip/stack/nic.go | 30 +---- pkg/tcpip/stack/stack_test.go | 307 +++++++++++++++++++----------------------- 4 files changed, 218 insertions(+), 240 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 6123fda33..fae5f5014 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -906,22 +906,21 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform return } - // We do not already have an address within the prefix, prefix. Do the + // We do not already have an address with the prefix prefix. Do the // work as outlined by RFC 4862 section 5.5.3.d if n is configured - // to auto-generated global addresses by SLAAC. - ndp.newAutoGenAddress(prefix, pl, vl) + // to auto-generate global addresses by SLAAC. + if !ndp.configs.AutoGenGlobalAddresses { + return + } + + ndp.doSLAAC(prefix, pl, vl) } -// newAutoGenAddress generates a new SLAAC address with the provided lifetimes +// doSLAAC generates a new SLAAC address with the provided lifetimes // for prefix. // // pl is the new preferred lifetime. vl is the new valid lifetime. -func (ndp *ndpState) newAutoGenAddress(prefix tcpip.Subnet, pl, vl time.Duration) { - // Are we configured to auto-generate new global addresses? - if !ndp.configs.AutoGenGlobalAddresses { - return - } - +func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { // If we do not already have an address for this prefix and the valid // lifetime is 0, no need to do anything further, as per RFC 4862 // section 5.5.3.d. @@ -1152,12 +1151,21 @@ func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bo // // The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) cleanupHostOnlyState() { + linkLocalSubnet := header.IPv6LinkLocalPrefix.Subnet() + linkLocalAddrs := 0 for addr := range ndp.autoGenAddresses { + // RFC 4862 section 5 states that routers are also expected to generate a + // link-local address so we do not invalidate them. + if linkLocalSubnet.Contains(addr) { + linkLocalAddrs++ + continue + } + ndp.invalidateAutoGenAddress(addr) } - if got := len(ndp.autoGenAddresses); got != 0 { - log.Fatalf("ndp: still have auto-generated addresses after cleaning up, found = %d", got) + if got := len(ndp.autoGenAddresses); got != linkLocalAddrs { + log.Fatalf("ndp: still have non-linklocal auto-generated addresses after cleaning up; found = %d prefixes, of which %d are link-local", got, linkLocalAddrs) } for prefix := range ndp.onLinkPrefixes { @@ -1165,7 +1173,7 @@ func (ndp *ndpState) cleanupHostOnlyState() { } if got := len(ndp.onLinkPrefixes); got != 0 { - log.Fatalf("ndp: still have discovered on-link prefixes after cleaning up, found = %d", got) + log.Fatalf("ndp: still have discovered on-link prefixes after cleaning up; found = %d", got) } for router := range ndp.defaultRouters { @@ -1173,7 +1181,7 @@ func (ndp *ndpState) cleanupHostOnlyState() { } if got := len(ndp.defaultRouters); got != 0 { - log.Fatalf("ndp: still have discovered default routers after cleaning up, found = %d", got) + log.Fatalf("ndp: still have discovered default routers after cleaning up; found = %d", got) } } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 1e575bdaf..e13509fbd 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -42,6 +42,7 @@ const ( linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07") linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08") + linkAddr4 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x09") defaultTimeout = 100 * time.Millisecond defaultAsyncEventTimeout = time.Second ) @@ -50,6 +51,7 @@ var ( llAddr1 = header.LinkLocalAddr(linkAddr1) llAddr2 = header.LinkLocalAddr(linkAddr2) llAddr3 = header.LinkLocalAddr(linkAddr3) + llAddr4 = header.LinkLocalAddr(linkAddr4) dstAddr = tcpip.FullAddress{ Addr: "\x0a\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", Port: 25, @@ -2882,8 +2884,8 @@ func TestNDPRecursiveDNSServerDispatch(t *testing.T) { } // TestCleanupHostOnlyStateOnBecomingRouter tests that all discovered routers -// and prefixes, and auto-generated addresses get invalidated when a NIC -// becomes a router. +// and prefixes, and non-linklocal auto-generated addresses are invalidated when +// a NIC becomes a router. func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) { t.Parallel() @@ -2898,6 +2900,14 @@ func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) { prefix2, subnet2, e1Addr2 := prefixSubnetAddr(1, linkAddr1) e2Addr1 := addrForSubnet(subnet1, linkAddr2) e2Addr2 := addrForSubnet(subnet2, linkAddr2) + llAddrWithPrefix1 := tcpip.AddressWithPrefix{ + Address: llAddr1, + PrefixLen: 64, + } + llAddrWithPrefix2 := tcpip.AddressWithPrefix{ + Address: llAddr2, + PrefixLen: 64, + } ndpDisp := ndpDispatcher{ routerC: make(chan ndpRouterEvent, maxEvents), @@ -2907,7 +2917,8 @@ func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, maxEvents), } s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + AutoGenIPv6LinkLocal: true, NDPConfigs: stack.NDPConfigurations{ HandleRAs: true, DiscoverDefaultRouters: true, @@ -2917,16 +2928,6 @@ func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) { NDPDisp: &ndpDisp, }) - e1 := channel.New(0, 1280, linkAddr1) - if err := s.CreateNIC(nicID1, e1); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err) - } - - e2 := channel.New(0, 1280, linkAddr2) - if err := s.CreateNIC(nicID2, e2); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err) - } - expectRouterEvent := func() (bool, ndpRouterEvent) { select { case e := <-ndpDisp.routerC: @@ -2957,18 +2958,30 @@ func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) { return false, ndpAutoGenAddrEvent{} } - // Receive RAs on NIC(1) and NIC(2) from default routers (llAddr1 and - // llAddr2) w/ PI (for prefix1 in RA from llAddr1 and prefix2 in RA from - // llAddr2) to discover multiple routers and prefixes, and auto-gen - // multiple addresses. - - e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr1, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) + e1 := channel.New(0, 1280, linkAddr1) + if err := s.CreateNIC(nicID1, e1); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err) + } // We have other tests that make sure we receive the *correct* events // on normal discovery of routers/prefixes, and auto-generated // addresses. Here we just make sure we get an event and let other tests // handle the correctness check. + expectAutoGenAddrEvent() + + e2 := channel.New(0, 1280, linkAddr2) + if err := s.CreateNIC(nicID2, e2); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err) + } + expectAutoGenAddrEvent() + + // Receive RAs on NIC(1) and NIC(2) from default routers (llAddr3 and + // llAddr4) w/ PI (for prefix1 in RA from llAddr3 and prefix2 in RA from + // llAddr4) to discover multiple routers and prefixes, and auto-gen + // multiple addresses. + + e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr1, nicID1) + t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID1) } if ok, _ := expectPrefixEvent(); !ok { t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID1) @@ -2977,9 +2990,9 @@ func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) { t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr1, nicID1) } - e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) + e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr2, nicID1) + t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID1) } if ok, _ := expectPrefixEvent(); !ok { t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID1) @@ -2988,9 +3001,9 @@ func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) { t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID1) } - e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr1, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) + e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr1, nicID2) + t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID2) } if ok, _ := expectPrefixEvent(); !ok { t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID2) @@ -2999,9 +3012,9 @@ func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) { t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID2) } - e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) + e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr2, nicID2) + t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID2) } if ok, _ := expectPrefixEvent(); !ok { t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID2) @@ -3014,12 +3027,18 @@ func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) { nicinfo := s.NICInfo() nic1Addrs := nicinfo[nicID1].ProtocolAddresses nic2Addrs := nicinfo[nicID2].ProtocolAddresses + if !containsV6Addr(nic1Addrs, llAddrWithPrefix1) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) + } if !containsV6Addr(nic1Addrs, e1Addr1) { t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) } if !containsV6Addr(nic1Addrs, e1Addr2) { t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) } + if !containsV6Addr(nic2Addrs, llAddrWithPrefix2) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) + } if !containsV6Addr(nic2Addrs, e2Addr1) { t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) } @@ -3071,10 +3090,10 @@ func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) { } expectedRouterEvents := map[ndpRouterEvent]int{ - {nicID: nicID1, addr: llAddr1, discovered: false}: 1, - {nicID: nicID1, addr: llAddr2, discovered: false}: 1, - {nicID: nicID2, addr: llAddr1, discovered: false}: 1, - {nicID: nicID2, addr: llAddr2, discovered: false}: 1, + {nicID: nicID1, addr: llAddr3, discovered: false}: 1, + {nicID: nicID1, addr: llAddr4, discovered: false}: 1, + {nicID: nicID2, addr: llAddr3, discovered: false}: 1, + {nicID: nicID2, addr: llAddr4, discovered: false}: 1, } if diff := cmp.Diff(expectedRouterEvents, gotRouterEvents); diff != "" { t.Errorf("router events mismatch (-want +got):\n%s", diff) @@ -3102,12 +3121,18 @@ func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) { nicinfo = s.NICInfo() nic1Addrs = nicinfo[nicID1].ProtocolAddresses nic2Addrs = nicinfo[nicID2].ProtocolAddresses + if !containsV6Addr(nic1Addrs, llAddrWithPrefix1) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) + } if containsV6Addr(nic1Addrs, e1Addr1) { t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) } if containsV6Addr(nic1Addrs, e1Addr2) { t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) } + if !containsV6Addr(nic2Addrs, llAddrWithPrefix2) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) + } if containsV6Addr(nic2Addrs, e2Addr1) { t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 682e9c416..78d451cca 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -206,33 +206,9 @@ func (n *NIC) enable() *tcpip.Error { // Do not auto-generate an IPv6 link-local address for loopback devices. if n.stack.autoGenIPv6LinkLocal && !n.isLoopback() { - var addr tcpip.Address - if oIID := n.stack.opaqueIIDOpts; oIID.NICNameFromID != nil { - addr = header.LinkLocalAddrWithOpaqueIID(oIID.NICNameFromID(n.ID(), n.name), 0, oIID.SecretKey) - } else { - l2addr := n.linkEP.LinkAddress() - - // Only attempt to generate the link-local address if we have a valid MAC - // address. - // - // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by - // LinkEndpoint.LinkAddress) before reaching this point. - if !header.IsValidUnicastEthernetAddress(l2addr) { - return nil - } - - addr = header.LinkLocalAddr(l2addr) - } - - if _, err := n.addAddressLocked(tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: addr, - PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen, - }, - }, CanBePrimaryEndpoint, permanent, static, false /* deprecated */); err != nil { - return err - } + // The valid and preferred lifetime is infinite for the auto-generated + // link-local address. + n.mu.ndp.doSLAAC(header.IPv6LinkLocalPrefix.Subnet(), header.NDPInfiniteLifetime, header.NDPInfiniteLifetime) } // If we are operating as a router, then do not solicit routers since we diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 243868f3a..b2c1763bf 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -1894,112 +1894,6 @@ func TestNICForwarding(t *testing.T) { } } -// TestNICAutoGenAddr tests the auto-generation of IPv6 link-local addresses -// using the modified EUI-64 of the NIC's MAC address (or lack there-of if -// disabled (default)). Note, DAD will be disabled in these tests. -func TestNICAutoGenAddr(t *testing.T) { - tests := []struct { - name string - autoGen bool - linkAddr tcpip.LinkAddress - iidOpts stack.OpaqueInterfaceIdentifierOptions - shouldGen bool - }{ - { - "Disabled", - false, - linkAddr1, - stack.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: func(nicID tcpip.NICID, _ string) string { - return fmt.Sprintf("nic%d", nicID) - }, - }, - false, - }, - { - "Enabled", - true, - linkAddr1, - stack.OpaqueInterfaceIdentifierOptions{}, - true, - }, - { - "Nil MAC", - true, - tcpip.LinkAddress([]byte(nil)), - stack.OpaqueInterfaceIdentifierOptions{}, - false, - }, - { - "Empty MAC", - true, - tcpip.LinkAddress(""), - stack.OpaqueInterfaceIdentifierOptions{}, - false, - }, - { - "Invalid MAC", - true, - tcpip.LinkAddress("\x01\x02\x03"), - stack.OpaqueInterfaceIdentifierOptions{}, - false, - }, - { - "Multicast MAC", - true, - tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), - stack.OpaqueInterfaceIdentifierOptions{}, - false, - }, - { - "Unspecified MAC", - true, - tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"), - stack.OpaqueInterfaceIdentifierOptions{}, - false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - OpaqueIIDOpts: test.iidOpts, - } - - if test.autoGen { - // Only set opts.AutoGenIPv6LinkLocal when test.autoGen is true because - // opts.AutoGenIPv6LinkLocal should be false by default. - opts.AutoGenIPv6LinkLocal = true - } - - e := channel.New(10, 1280, test.linkAddr) - s := stack.New(opts) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - - addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) - } - - if test.shouldGen { - // Should have auto-generated an address and resolved immediately (DAD - // is disabled). - if want := (tcpip.AddressWithPrefix{Address: header.LinkLocalAddr(test.linkAddr), PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want) - } - } else { - // Should not have auto-generated an address. - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) - } - } - }) - } -} - // TestNICContextPreservation tests that you can read out via stack.NICInfo the // Context data you pass via NICContext.Context in stack.CreateNICWithOptions. func TestNICContextPreservation(t *testing.T) { @@ -2040,11 +1934,9 @@ func TestNICContextPreservation(t *testing.T) { } } -// TestNICAutoGenAddrWithOpaque tests the auto-generation of IPv6 link-local -// addresses with opaque interface identifiers. Link Local addresses should -// always be generated with opaque IIDs if configured to use them, even if the -// NIC has an invalid MAC address. -func TestNICAutoGenAddrWithOpaque(t *testing.T) { +// TestNICAutoGenLinkLocalAddr tests the auto-generation of IPv6 link-local +// addresses. +func TestNICAutoGenLinkLocalAddr(t *testing.T) { const nicID = 1 var secretKey [header.OpaqueIIDSecretKeyMinBytes]byte @@ -2056,108 +1948,185 @@ func TestNICAutoGenAddrWithOpaque(t *testing.T) { t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", header.OpaqueIIDSecretKeyMinBytes, n) } + nicNameFunc := func(_ tcpip.NICID, name string) string { + return name + } + tests := []struct { - name string - nicName string - autoGen bool - linkAddr tcpip.LinkAddress - secretKey []byte + name string + nicName string + autoGen bool + linkAddr tcpip.LinkAddress + iidOpts stack.OpaqueInterfaceIdentifierOptions + shouldGen bool + expectedAddr tcpip.Address }{ { name: "Disabled", nicName: "nic1", autoGen: false, linkAddr: linkAddr1, - secretKey: secretKey[:], + shouldGen: false, }, { - name: "Enabled", - nicName: "nic1", - autoGen: true, - linkAddr: linkAddr1, - secretKey: secretKey[:], + name: "Disabled without OIID options", + nicName: "nic1", + autoGen: false, + linkAddr: linkAddr1, + iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: nicNameFunc, + SecretKey: secretKey[:], + }, + shouldGen: false, }, - // These are all cases where we would not have generated a - // link-local address if opaque IIDs were disabled. + + // Tests for EUI64 based addresses. { - name: "Nil MAC and empty nicName", - nicName: "", + name: "EUI64 Enabled", + autoGen: true, + linkAddr: linkAddr1, + shouldGen: true, + expectedAddr: header.LinkLocalAddr(linkAddr1), + }, + { + name: "EUI64 Empty MAC", autoGen: true, - linkAddr: tcpip.LinkAddress([]byte(nil)), - secretKey: secretKey[:1], + shouldGen: false, }, { - name: "Empty MAC and empty nicName", + name: "EUI64 Invalid MAC", autoGen: true, - linkAddr: tcpip.LinkAddress(""), - secretKey: secretKey[:2], + linkAddr: "\x01\x02\x03", + shouldGen: false, }, { - name: "Invalid MAC", - nicName: "test", + name: "EUI64 Multicast MAC", autoGen: true, - linkAddr: tcpip.LinkAddress("\x01\x02\x03"), - secretKey: secretKey[:3], + linkAddr: "\x01\x02\x03\x04\x05\x06", + shouldGen: false, }, { - name: "Multicast MAC", - nicName: "test2", + name: "EUI64 Unspecified MAC", autoGen: true, - linkAddr: tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), - secretKey: secretKey[:4], + linkAddr: "\x00\x00\x00\x00\x00\x00", + shouldGen: false, }, + + // Tests for Opaque IID based addresses. { - name: "Unspecified MAC and nil SecretKey", + name: "OIID Enabled", + nicName: "nic1", + autoGen: true, + linkAddr: linkAddr1, + iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: nicNameFunc, + SecretKey: secretKey[:], + }, + shouldGen: true, + expectedAddr: header.LinkLocalAddrWithOpaqueIID("nic1", 0, secretKey[:]), + }, + // These are all cases where we would not have generated a + // link-local address if opaque IIDs were disabled. + { + name: "OIID Empty MAC and empty nicName", + autoGen: true, + iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: nicNameFunc, + SecretKey: secretKey[:1], + }, + shouldGen: true, + expectedAddr: header.LinkLocalAddrWithOpaqueIID("", 0, secretKey[:1]), + }, + { + name: "OIID Invalid MAC", + nicName: "test", + autoGen: true, + linkAddr: "\x01\x02\x03", + iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: nicNameFunc, + SecretKey: secretKey[:2], + }, + shouldGen: true, + expectedAddr: header.LinkLocalAddrWithOpaqueIID("test", 0, secretKey[:2]), + }, + { + name: "OIID Multicast MAC", + nicName: "test2", + autoGen: true, + linkAddr: "\x01\x02\x03\x04\x05\x06", + iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: nicNameFunc, + SecretKey: secretKey[:3], + }, + shouldGen: true, + expectedAddr: header.LinkLocalAddrWithOpaqueIID("test2", 0, secretKey[:3]), + }, + { + name: "OIID Unspecified MAC and nil SecretKey", nicName: "test3", autoGen: true, - linkAddr: tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"), + linkAddr: "\x00\x00\x00\x00\x00\x00", + iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: nicNameFunc, + }, + shouldGen: true, + expectedAddr: header.LinkLocalAddrWithOpaqueIID("test3", 0, nil), }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: func(_ tcpip.NICID, nicName string) string { - return nicName - }, - SecretKey: test.secretKey, - }, + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } - - if test.autoGen { - // Only set opts.AutoGenIPv6LinkLocal when - // test.autoGen is true because - // opts.AutoGenIPv6LinkLocal should be false by - // default. - opts.AutoGenIPv6LinkLocal = true + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + AutoGenIPv6LinkLocal: test.autoGen, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: test.iidOpts, } - e := channel.New(10, 1280, test.linkAddr) + e := channel.New(0, 1280, test.linkAddr) s := stack.New(opts) nicOpts := stack.NICOptions{Name: test.nicName} if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err) } - addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err) - } + var expectedMainAddr tcpip.AddressWithPrefix + + if test.shouldGen { + expectedMainAddr = tcpip.AddressWithPrefix{ + Address: test.expectedAddr, + PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen, + } - if test.autoGen { - // Should have auto-generated an address and - // resolved immediately (DAD is disabled). - if want := (tcpip.AddressWithPrefix{Address: header.LinkLocalAddrWithOpaqueIID(test.nicName, 0, test.secretKey), PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want) + // Should have auto-generated an address and resolved immediately (DAD + // is disabled). + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, expectedMainAddr, newAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") } } else { // Should not have auto-generated an address. - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address") + default: } } + + gotMainAddr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) + } + if gotMainAddr != expectedMainAddr { + t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", gotMainAddr, expectedMainAddr) + } }) } } @@ -2226,7 +2195,7 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { NDPDisp: &ndpDisp, } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(int(ndpConfigs.DupAddrDetectTransmits), 1280, linkAddr1) s := stack.New(opts) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(_) = %s", err) -- cgit v1.2.3 From ca30dfa065f5458228b06dcf5379ed4edf29c165 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Thu, 6 Feb 2020 19:49:30 -0800 Subject: Send DAD event when DAD resolves immediately Previously, a DAD event would not be sent if DAD was disabled. This allows integrators to do some work when an IPv6 address is bound to a NIC without special logic that checks if DAD is enabled. Without this change, integrators would need to check if a NIC has DAD enabled when an address is auto-generated. If DAD is enabled, it would need to delay the work until the DAD completion event; otherwise, it would need to do the work in the address auto-generated event handler. Test: stack_test.TestDADDisabled PiperOrigin-RevId: 293732914 --- pkg/tcpip/stack/ndp.go | 7 ++ pkg/tcpip/stack/ndp_test.go | 263 +++++++++++++++++++++--------------------- pkg/tcpip/stack/stack_test.go | 44 +++---- 3 files changed, 154 insertions(+), 160 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index fae5f5014..045409bda 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -448,6 +448,13 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref remaining := ndp.configs.DupAddrDetectTransmits if remaining == 0 { ref.setKind(permanent) + + // Consider DAD to have resolved even if no DAD messages were actually + // transmitted. + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, true, nil) + } + return nil } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index e13509fbd..1f6f77439 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -86,39 +86,6 @@ func prefixSubnetAddr(offset uint8, linkAddr tcpip.LinkAddress) (tcpip.AddressWi return prefix, subnet, addrForSubnet(subnet, linkAddr) } -// TestDADDisabled tests that an address successfully resolves immediately -// when DAD is not enabled (the default for an empty stack.Options). -func TestDADDisabled(t *testing.T) { - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - } - - e := channel.New(0, 1280, linkAddr1) - s := stack.New(opts) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - - if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err) - } - - // Should get the address immediately since we should not have performed - // DAD on it. - addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) - } - if addr.Address != addr1 { - t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, addr1) - } - - // We should not have sent any NDP NS messages. - if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != 0 { - t.Fatalf("got NeighborSolicit = %d, want = 0", got) - } -} - // ndpDADEvent is a set of parameters that was passed to // ndpDispatcher.OnDuplicateAddressDetectionStatus. type ndpDADEvent struct { @@ -300,6 +267,58 @@ func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration s } } +// Check e to make sure that the event is for addr on nic with ID 1, and the +// resolved flag set to resolved with the specified err. +func checkDADEvent(e ndpDADEvent, nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) string { + return cmp.Diff(ndpDADEvent{nicID: nicID, addr: addr, resolved: resolved, err: err}, e, cmp.AllowUnexported(e)) +} + +// TestDADDisabled tests that an address successfully resolves immediately +// when DAD is not enabled (the default for an empty stack.Options). +func TestDADDisabled(t *testing.T) { + const nicID = 1 + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 1), + } + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPDisp: &ndpDisp, + } + + e := channel.New(0, 1280, linkAddr1) + s := stack.New(opts) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) + } + + // Should get the address immediately since we should not have performed + // DAD on it. + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr1, true, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected DAD event") + } + addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(%d, %d) err = %s", nicID, header.IPv6ProtocolNumber, err) + } + if addr.Address != addr1 { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, addr, addr1) + } + + // We should not have sent any NDP NS messages. + if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != 0 { + t.Fatalf("got NeighborSolicit = %d, want = 0", got) + } +} + // TestDADResolve tests that an address successfully resolves after performing // DAD for various values of DupAddrDetectTransmits and RetransmitTimer. // Included in the subtests is a test to make sure that an invalid @@ -381,17 +400,8 @@ func TestDADResolve(t *testing.T) { // means something is wrong. t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: - if e.err != nil { - t.Fatal("got DAD error: ", e.err) - } - if e.nicID != nicID { - t.Fatalf("got DAD event w/ nicID = %d, want = %d", e.nicID, nicID) - } - if e.addr != addr1 { - t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1) - } - if !e.resolved { - t.Fatal("got DAD event w/ resolved = false, want = true") + if diff := checkDADEvent(e, nicID, addr1, true, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) } } addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) @@ -445,6 +455,8 @@ func TestDADResolve(t *testing.T) { // a node doing DAD for the same address), or if another node is detected to own // the address already (receive an NA message for the tentative address). func TestDADFail(t *testing.T) { + const nicID = 1 + tests := []struct { name string makeBuf func(tgt tcpip.Address) buffer.Prependable @@ -526,22 +538,22 @@ func TestDADFail(t *testing.T) { e := channel.New(0, 1280, linkAddr1) s := stack.New(opts) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err) + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) } // Address should not be considered bound to the NIC yet // (DAD ongoing). - addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) if err != nil { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) } if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) } // Receive a packet to simulate multiple nodes owning or @@ -565,25 +577,16 @@ func TestDADFail(t *testing.T) { // something is wrong. t.Fatal("timed out waiting for DAD failure") case e := <-ndpDisp.dadC: - if e.err != nil { - t.Fatal("got DAD error: ", e.err) - } - if e.nicID != 1 { - t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID) - } - if e.addr != addr1 { - t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1) - } - if e.resolved { - t.Fatal("got DAD event w/ resolved = true, want = false") + if diff := checkDADEvent(e, nicID, addr1, false, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) } } - addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) if err != nil { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) } if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) } }) } @@ -592,6 +595,8 @@ func TestDADFail(t *testing.T) { // TestDADStop tests to make sure that the DAD process stops when an address is // removed. func TestDADStop(t *testing.T) { + const nicID = 1 + ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent, 1), } @@ -607,26 +612,26 @@ func TestDADStop(t *testing.T) { e := channel.New(0, 1280, linkAddr1) s := stack.New(opts) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err) + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) } // Address should not be considered bound to the NIC yet (DAD ongoing). - addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) if err != nil { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) } if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) } // Remove the address. This should stop DAD. - if err := s.RemoveAddress(1, addr1); err != nil { - t.Fatalf("RemoveAddress(_, %s) = %s", addr1, err) + if err := s.RemoveAddress(nicID, addr1); err != nil { + t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr1, err) } // Wait for DAD to fail (since the address was removed during DAD). @@ -636,26 +641,16 @@ func TestDADStop(t *testing.T) { // time + extra 1s buffer, something is wrong. t.Fatal("timed out waiting for DAD failure") case e := <-ndpDisp.dadC: - if e.err != nil { - t.Fatal("got DAD error: ", e.err) - } - if e.nicID != 1 { - t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID) - } - if e.addr != addr1 { - t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1) + if diff := checkDADEvent(e, nicID, addr1, false, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) } - if e.resolved { - t.Fatal("got DAD event w/ resolved = true, want = false") - } - } - addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) if err != nil { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) } if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) } // Should not have sent more than 1 NS message. @@ -681,6 +676,10 @@ func TestSetNDPConfigurationFailsForBadNICID(t *testing.T) { // configurations without affecting the default NDP configurations or other // interfaces' configurations. func TestSetNDPConfigurations(t *testing.T) { + const nicID1 = 1 + const nicID2 = 2 + const nicID3 = 3 + tests := []struct { name string dupAddrDetectTransmits uint8 @@ -704,7 +703,7 @@ func TestSetNDPConfigurations(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent), + dadC: make(chan ndpDADEvent, 1), } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ @@ -712,17 +711,28 @@ func TestSetNDPConfigurations(t *testing.T) { NDPDisp: &ndpDisp, }) + expectDADEvent := func(nicID tcpip.NICID, addr tcpip.Address) { + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatalf("expected DAD event for %s", addr) + } + } + // This NIC(1)'s NDP configurations will be updated to // be different from the default. - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) + if err := s.CreateNIC(nicID1, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err) } // Created before updating NIC(1)'s NDP configurations // but updating NIC(1)'s NDP configurations should not // affect other existing NICs. - if err := s.CreateNIC(2, e); err != nil { - t.Fatalf("CreateNIC(2) = %s", err) + if err := s.CreateNIC(nicID2, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err) } // Update the NDP configurations on NIC(1) to use DAD. @@ -730,36 +740,38 @@ func TestSetNDPConfigurations(t *testing.T) { DupAddrDetectTransmits: test.dupAddrDetectTransmits, RetransmitTimer: test.retransmitTimer, } - if err := s.SetNDPConfigurations(1, configs); err != nil { - t.Fatalf("got SetNDPConfigurations(1, _) = %s", err) + if err := s.SetNDPConfigurations(nicID1, configs); err != nil { + t.Fatalf("got SetNDPConfigurations(%d, _) = %s", nicID1, err) } // Created after updating NIC(1)'s NDP configurations // but the stack's default NDP configurations should not // have been updated. - if err := s.CreateNIC(3, e); err != nil { - t.Fatalf("CreateNIC(3) = %s", err) + if err := s.CreateNIC(nicID3, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID3, err) } // Add addresses for each NIC. - if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(1, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err) + if err := s.AddAddress(nicID1, header.IPv6ProtocolNumber, addr1); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID1, header.IPv6ProtocolNumber, addr1, err) } - if err := s.AddAddress(2, header.IPv6ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(2, %d, %s) = %s", header.IPv6ProtocolNumber, addr2, err) + if err := s.AddAddress(nicID2, header.IPv6ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID2, header.IPv6ProtocolNumber, addr2, err) } - if err := s.AddAddress(3, header.IPv6ProtocolNumber, addr3); err != nil { - t.Fatalf("AddAddress(3, %d, %s) = %s", header.IPv6ProtocolNumber, addr3, err) + expectDADEvent(nicID2, addr2) + if err := s.AddAddress(nicID3, header.IPv6ProtocolNumber, addr3); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID3, header.IPv6ProtocolNumber, addr3, err) } + expectDADEvent(nicID3, addr3) // Address should not be considered bound to NIC(1) yet // (DAD ongoing). - addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + addr, err := s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber) if err != nil { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err) } if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID1, header.IPv6ProtocolNumber, addr, want) } // Should get the address on NIC(2) and NIC(3) @@ -767,31 +779,31 @@ func TestSetNDPConfigurations(t *testing.T) { // it as the stack was configured to not do DAD by // default and we only updated the NDP configurations on // NIC(1). - addr, err = s.GetMainNICAddress(2, header.IPv6ProtocolNumber) + addr, err = s.GetMainNICAddress(nicID2, header.IPv6ProtocolNumber) if err != nil { - t.Fatalf("stack.GetMainNICAddress(2, _) err = %s", err) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID2, header.IPv6ProtocolNumber, err) } if addr.Address != addr2 { - t.Fatalf("got stack.GetMainNICAddress(2, _) = %s, want = %s", addr, addr2) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID2, header.IPv6ProtocolNumber, addr, addr2) } - addr, err = s.GetMainNICAddress(3, header.IPv6ProtocolNumber) + addr, err = s.GetMainNICAddress(nicID3, header.IPv6ProtocolNumber) if err != nil { - t.Fatalf("stack.GetMainNICAddress(3, _) err = %s", err) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID3, header.IPv6ProtocolNumber, err) } if addr.Address != addr3 { - t.Fatalf("got stack.GetMainNICAddress(3, _) = %s, want = %s", addr, addr3) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID3, header.IPv6ProtocolNumber, addr, addr3) } // Sleep until right (500ms before) before resolution to // make sure the address didn't resolve on NIC(1) yet. const delta = 500 * time.Millisecond time.Sleep(time.Duration(test.dupAddrDetectTransmits)*test.expectedRetransmitTimer - delta) - addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + addr, err = s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber) if err != nil { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err) } if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID1, header.IPv6ProtocolNumber, addr, want) } // Wait for DAD to resolve. @@ -805,25 +817,16 @@ func TestSetNDPConfigurations(t *testing.T) { // means something is wrong. t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: - if e.err != nil { - t.Fatal("got DAD error: ", e.err) - } - if e.nicID != 1 { - t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID) - } - if e.addr != addr1 { - t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1) - } - if !e.resolved { - t.Fatal("got DAD event w/ resolved = false, want = true") + if diff := checkDADEvent(e, nicID1, addr1, true, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) } } - addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + addr, err = s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber) if err != nil { - t.Fatalf("stack.GetMainNICAddress(1, _) err = %s", err) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err) } if addr.Address != addr1 { - t.Fatalf("got stack.GetMainNICAddress(1, _) = %s, want = %s", addr, addr1) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID1, header.IPv6ProtocolNumber, addr, addr1) } }) } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index b2c1763bf..24133e6f2 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -2184,6 +2184,8 @@ func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) { // TestNICAutoGenAddrDoesDAD tests that the successful auto-generation of IPv6 // link-local addresses will only be assigned after the DAD process resolves. func TestNICAutoGenAddrDoesDAD(t *testing.T) { + const nicID = 1 + ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent), } @@ -2197,18 +2199,18 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { e := channel.New(int(ndpConfigs.DupAddrDetectTransmits), 1280, linkAddr1) s := stack.New(opts) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } // Address should not be considered bound to the // NIC yet (DAD ongoing). - addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) if err != nil { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) } if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) } linkLocalAddr := header.LinkLocalAddr(linkAddr1) @@ -2222,25 +2224,16 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { // means something is wrong. t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: - if e.err != nil { - t.Fatal("got DAD error: ", e.err) - } - if e.nicID != 1 { - t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID) - } - if e.addr != linkLocalAddr { - t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, linkLocalAddr) - } - if !e.resolved { - t.Fatal("got DAD event w/ resolved = false, want = true") + if diff := checkDADEvent(e, nicID, linkLocalAddr, true, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) } } - addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) if err != nil { - t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) } if want := (tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) } } @@ -2606,17 +2599,8 @@ func TestDoDADWhenNICEnabled(t *testing.T) { case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout): t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: - if e.err != nil { - t.Fatal("got DAD error: ", e.err) - } - if e.nicID != nicID { - t.Fatalf("got DAD event w/ nicID = %d, want = %d", e.nicID, nicID) - } - if e.addr != addr.AddressWithPrefix.Address { - t.Fatalf("got DAD event w/ addr = %s, want = %s", e.addr, addr.AddressWithPrefix.Address) - } - if !e.resolved { - t.Fatal("got DAD event w/ resolved = false, want = true") + if diff := checkDADEvent(e, nicID, addr.AddressWithPrefix.Address, true, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) } } if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { -- cgit v1.2.3 From b8e22e241cab625d2809034c74d0ff808b948b4c Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Tue, 11 Feb 2020 12:58:13 -0800 Subject: Disallow duplicate NIC names. PiperOrigin-RevId: 294500858 --- pkg/tcpip/stack/nic.go | 5 +++ pkg/tcpip/stack/stack.go | 9 +++++ pkg/tcpip/stack/stack_test.go | 85 +++++++++++++++++++++++++++++++++++++++++++ runsc/sandbox/network.go | 30 +++++++-------- 4 files changed, 114 insertions(+), 15 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 78d451cca..ca3a7a07e 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1215,6 +1215,11 @@ func (n *NIC) ID() tcpip.NICID { return n.id } +// Name returns the name of n. +func (n *NIC) Name() string { + return n.name +} + // Stack returns the instance of the Stack that owns this NIC. func (n *NIC) Stack() *Stack { return n.stack diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index b793f1d74..6eac16e16 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -890,6 +890,15 @@ func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOp return tcpip.ErrDuplicateNICID } + // Make sure name is unique, unless unnamed. + if opts.Name != "" { + for _, n := range s.nics { + if n.Name() == opts.Name { + return tcpip.ErrDuplicateNICID + } + } + } + n := newNIC(s, id, opts.Name, ep, opts.Context) s.nics[id] = n diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 24133e6f2..7ba604442 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -1792,6 +1792,91 @@ func TestAddProtocolAddressWithOptions(t *testing.T) { verifyAddresses(t, expectedAddresses, gotAddresses) } +func TestCreateNICWithOptions(t *testing.T) { + type callArgsAndExpect struct { + nicID tcpip.NICID + opts stack.NICOptions + err *tcpip.Error + } + + tests := []struct { + desc string + calls []callArgsAndExpect + }{ + { + desc: "DuplicateNICID", + calls: []callArgsAndExpect{ + { + nicID: tcpip.NICID(1), + opts: stack.NICOptions{Name: "eth1"}, + err: nil, + }, + { + nicID: tcpip.NICID(1), + opts: stack.NICOptions{Name: "eth2"}, + err: tcpip.ErrDuplicateNICID, + }, + }, + }, + { + desc: "DuplicateName", + calls: []callArgsAndExpect{ + { + nicID: tcpip.NICID(1), + opts: stack.NICOptions{Name: "lo"}, + err: nil, + }, + { + nicID: tcpip.NICID(2), + opts: stack.NICOptions{Name: "lo"}, + err: tcpip.ErrDuplicateNICID, + }, + }, + }, + { + desc: "Unnamed", + calls: []callArgsAndExpect{ + { + nicID: tcpip.NICID(1), + opts: stack.NICOptions{}, + err: nil, + }, + { + nicID: tcpip.NICID(2), + opts: stack.NICOptions{}, + err: nil, + }, + }, + }, + { + desc: "UnnamedDuplicateNICID", + calls: []callArgsAndExpect{ + { + nicID: tcpip.NICID(1), + opts: stack.NICOptions{}, + err: nil, + }, + { + nicID: tcpip.NICID(1), + opts: stack.NICOptions{}, + err: tcpip.ErrDuplicateNICID, + }, + }, + }, + } + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + s := stack.New(stack.Options{}) + ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")) + for _, call := range test.calls { + if got, want := s.CreateNICWithOptions(call.nicID, ep, call.opts), call.err; got != want { + t.Fatalf("CreateNICWithOptions(%v, _, %+v) = %v, want %v", call.nicID, call.opts, got, want) + } + } + }) + } +} + func TestNICStats(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, diff --git a/runsc/sandbox/network.go b/runsc/sandbox/network.go index ff48f5646..99e143696 100644 --- a/runsc/sandbox/network.go +++ b/runsc/sandbox/network.go @@ -174,13 +174,13 @@ func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, hardwareG return fmt.Errorf("fetching interface addresses for %q: %v", iface.Name, err) } - // We build our own loopback devices. + // We build our own loopback device. if iface.Flags&net.FlagLoopback != 0 { - links, err := loopbackLinks(iface, allAddrs) + link, err := loopbackLink(iface, allAddrs) if err != nil { - return fmt.Errorf("getting loopback routes and links for iface %q: %v", iface.Name, err) + return fmt.Errorf("getting loopback link for iface %q: %v", iface.Name, err) } - args.LoopbackLinks = append(args.LoopbackLinks, links...) + args.LoopbackLinks = append(args.LoopbackLinks, link) continue } @@ -339,25 +339,25 @@ func createSocket(iface net.Interface, ifaceLink netlink.Link, enableGSO bool) ( return &socketEntry{deviceFile, gsoMaxSize}, nil } -// loopbackLinks collects the links for a loopback interface. -func loopbackLinks(iface net.Interface, addrs []net.Addr) ([]boot.LoopbackLink, error) { - var links []boot.LoopbackLink +// loopbackLink returns the link with addresses and routes for a loopback +// interface. +func loopbackLink(iface net.Interface, addrs []net.Addr) (boot.LoopbackLink, error) { + link := boot.LoopbackLink{ + Name: iface.Name, + } for _, addr := range addrs { ipNet, ok := addr.(*net.IPNet) if !ok { - return nil, fmt.Errorf("address is not IPNet: %+v", addr) + return boot.LoopbackLink{}, fmt.Errorf("address is not IPNet: %+v", addr) } dst := *ipNet dst.IP = dst.IP.Mask(dst.Mask) - links = append(links, boot.LoopbackLink{ - Name: iface.Name, - Addresses: []net.IP{ipNet.IP}, - Routes: []boot.Route{{ - Destination: dst, - }}, + link.Addresses = append(link.Addresses, ipNet.IP) + link.Routes = append(link.Routes, boot.Route{ + Destination: dst, }) } - return links, nil + return link, nil } // routesForIface iterates over all routes for the given interface and converts -- cgit v1.2.3 From b30b7f3422202232ad1c385a7ac0d775151fee2f Mon Sep 17 00:00:00 2001 From: Nayana Bidari Date: Tue, 18 Feb 2020 11:30:42 -0800 Subject: Add nat table support for iptables. Add nat table support for Prerouting hook with Redirect option. Add tests to check redirect of ports. --- pkg/abi/linux/netfilter.go | 27 ++++++++++++++ pkg/sentry/socket/netfilter/netfilter.go | 58 +++++++++++++++++++++++++++-- pkg/tcpip/iptables/iptables.go | 21 +++++++++++ pkg/tcpip/iptables/targets.go | 24 ++++++++++++ pkg/tcpip/stack/nic.go | 23 ++++++++++++ test/iptables/iptables_test.go | 12 ++++++ test/iptables/iptables_util.go | 10 +++++ test/iptables/nat.go | 64 +++++++++++++++++++++++++++++++- 8 files changed, 234 insertions(+), 5 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go index bbc4df74c..ba4d84962 100644 --- a/pkg/abi/linux/netfilter.go +++ b/pkg/abi/linux/netfilter.go @@ -250,6 +250,33 @@ type XTErrorTarget struct { // SizeOfXTErrorTarget is the size of an XTErrorTarget. const SizeOfXTErrorTarget = 64 +// NfNATIPV4Range. It corresponds to struct nf_nat_ipv4_range +// in include/uapi/linux/netfilter/nf_nat.h. +type NfNATIPV4Range struct { + Flags uint32 + MinIP [4]byte + MaxIP [4]byte + MinPort uint16 + MaxPort uint16 +} + +// NfNATIPV4MultiRangeCompat. It corresponds to struct +// nf_nat_ipv4_multi_range_compat in include/uapi/linux/netfilter/nf_nat.h. +type NfNATIPV4MultiRangeCompat struct { + Rangesize uint32 + RangeIPV4 [1]NfNATIPV4Range +} + +// XTRedirectTarget triggers a redirect when reached. +type XTRedirectTarget struct { + Target XTEntryTarget + NfRange NfNATIPV4MultiRangeCompat + _ [4]byte +} + +// SizeOfXTRedirectTarget is the size of an XTRedirectTarget. +const SizeOfXTRedirectTarget = 56 + // IPTGetinfo is the argument for the IPT_SO_GET_INFO sockopt. It corresponds // to struct ipt_getinfo in include/uapi/linux/netfilter_ipv4/ip_tables.h. type IPTGetinfo struct { diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 3fc80e0de..512ad624a 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -35,6 +35,11 @@ import ( // shouldn't be reached - an error has occurred if we fall through to one. const errorTargetName = "ERROR" +// redirectTargetName is used to mark targets as redirect targets. Redirect +// targets should be reached for only NAT and Mangle tables. These targets will +// change the destination port/destination IP for packets. +const redirectTargetName = "REDIRECT" + // Metadata is used to verify that we are correctly serializing and // deserializing iptables into structs consumable by the iptables tool. We save // a metadata struct when the tables are written, and when they are read out we @@ -240,6 +245,8 @@ func marshalTarget(target iptables.Target) []byte { return marshalErrorTarget(tg.Name) case iptables.ReturnTarget: return marshalStandardTarget(iptables.RuleReturn) + case iptables.RedirectTarget: + return marshalRedirectTarget() default: panic(fmt.Errorf("unknown target of type %T", target)) } @@ -274,6 +281,18 @@ func marshalErrorTarget(errorName string) []byte { return binary.Marshal(ret, usermem.ByteOrder, target) } +func marshalRedirectTarget() []byte { + // This is a redirect target named redirect + target := linux.XTRedirectTarget{ + Target: linux.XTEntryTarget{ + TargetSize: linux.SizeOfXTRedirectTarget, + }, + } + + ret := make([]byte, 0, linux.SizeOfXTRedirectTarget) + return binary.Marshal(ret, usermem.ByteOrder, target) +} + // translateFromStandardVerdict translates verdicts the same way as the iptables // tool. func translateFromStandardVerdict(verdict iptables.RuleVerdict) int32 { @@ -326,6 +345,8 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { switch replace.Name.String() { case iptables.TablenameFilter: table = iptables.EmptyFilterTable() + case iptables.TablenameNat: + table = iptables.EmptyNatTable() default: nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String()) return syserr.ErrInvalidArgument @@ -455,10 +476,11 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { } // TODO(gvisor.dev/issue/170): Support other chains. - // Since we only support modifying the INPUT chain right now, make sure - // all other chains point to ACCEPT rules. + // Since we only support modifying the INPUT chain and redirect for + // PREROUTING chain right now, make sure all other chains point to + // ACCEPT rules. for hook, ruleIdx := range table.BuiltinChains { - if hook != iptables.Input { + if hook != iptables.Input && hook != iptables.Prerouting { if _, ok := table.Rules[ruleIdx].Target.(iptables.AcceptTarget); !ok { nflog("hook %d is unsupported.", hook) return syserr.ErrInvalidArgument @@ -575,6 +597,36 @@ func parseTarget(optVal []byte) (iptables.Target, error) { nflog("set entries: user-defined target %q", name) return iptables.UserChainTarget{Name: name}, nil } + + case redirectTargetName: + // Redirect target. + if len(optVal) < linux.SizeOfXTRedirectTarget { + return nil, fmt.Errorf("netfilter.SetEntries: optVal has insufficient size for redirect target %d", len(optVal)) + } + + var redirectTarget linux.XTRedirectTarget + buf = optVal[:linux.SizeOfXTRedirectTarget] + binary.Unmarshal(buf, usermem.ByteOrder, &redirectTarget) + + // Copy linux.XTRedirectTarget to iptables.RedirectTarget. + var target iptables.RedirectTarget + nfRange := redirectTarget.NfRange + + target.RangeSize = nfRange.Rangesize + target.Flags = nfRange.RangeIPV4[0].Flags + + target.MinIP = tcpip.Address(nfRange.RangeIPV4[0].MinIP[:]) + target.MaxIP = tcpip.Address(nfRange.RangeIPV4[0].MaxIP[:]) + + // Convert port from big endian to little endian. + port := make([]byte, 2) + binary.BigEndian.PutUint16(port, nfRange.RangeIPV4[0].MinPort) + target.MinPort = binary.LittleEndian.Uint16(port) + + binary.BigEndian.PutUint16(port, nfRange.RangeIPV4[0].MaxPort) + target.MaxPort = binary.LittleEndian.Uint16(port) + return target, nil + } // Unknown target. diff --git a/pkg/tcpip/iptables/iptables.go b/pkg/tcpip/iptables/iptables.go index 75a433a3b..c00d012c0 100644 --- a/pkg/tcpip/iptables/iptables.go +++ b/pkg/tcpip/iptables/iptables.go @@ -135,6 +135,27 @@ func EmptyFilterTable() Table { } } +// EmptyNatTable returns a Table with no rules and the filter table chains +// mapped to HookUnset. +func EmptyNatTable() Table { + return Table{ + Rules: []Rule{}, + BuiltinChains: map[Hook]int{ + Prerouting: HookUnset, + Input: HookUnset, + Output: HookUnset, + Postrouting: HookUnset, + }, + Underflows: map[Hook]int{ + Prerouting: HookUnset, + Input: HookUnset, + Output: HookUnset, + Postrouting: HookUnset, + }, + UserChains: map[string]int{}, + } +} + // Check runs pkt through the rules for hook. It returns true when the packet // should continue traversing the network stack and false when it should be // dropped. diff --git a/pkg/tcpip/iptables/targets.go b/pkg/tcpip/iptables/targets.go index 9fc60cfad..06e65bece 100644 --- a/pkg/tcpip/iptables/targets.go +++ b/pkg/tcpip/iptables/targets.go @@ -19,6 +19,7 @@ package iptables import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" ) // AcceptTarget accepts packets. @@ -65,3 +66,26 @@ type ReturnTarget struct{} func (ReturnTarget) Action(tcpip.PacketBuffer) (RuleVerdict, string) { return RuleReturn, "" } + +// RedirectTarget redirects the packet by modifying the destination port/IP. +type RedirectTarget struct { + RangeSize uint32 + Flags uint32 + MinIP tcpip.Address + MaxIP tcpip.Address + MinPort uint16 + MaxPort uint16 +} + +// Action implements Target.Action. +func (rt RedirectTarget) Action(packet tcpip.PacketBuffer) (RuleVerdict, string) { + log.Infof("RedirectTarget triggered.") + + // TODO(gvisor.dev/issue/170): Checking only for UDP protocol. + // We're yet to support for TCP protocol. + headerView := packet.Data.First() + h := header.UDP(headerView) + h.SetDestinationPort(rt.MinPort) + + return RuleAccept, "" +} diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index ca3a7a07e..2028f5201 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/iptables" ) // NIC represents a "network interface card" to which the networking stack is @@ -1012,6 +1013,7 @@ func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt tcpip.PacketBuffer) { r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */) r.RemoteLinkAddress = remotelinkAddr + ref.ep.HandlePacket(&r, pkt) ref.decRef() } @@ -1082,6 +1084,27 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link n.stack.stats.IP.InvalidSourceAddressesReceived.Increment() return } + + // TODO(gvisor.dev/issue/170): Not supporting iptables for IPv6 yet. + if protocol == header.IPv4ProtocolNumber { + newPkt := pkt.Clone() + + headerView := newPkt.Data.First() + h := header.IPv4(headerView) + newPkt.NetworkHeader = headerView[:h.HeaderLength()] + + hlen := int(h.HeaderLength()) + tlen := int(h.TotalLength()) + newPkt.Data.TrimFront(hlen) + newPkt.Data.CapLength(tlen - hlen) + + ipt := n.stack.IPTables() + if ok := ipt.Check(iptables.Prerouting, newPkt); !ok { + // iptables is telling us to drop the packet. + return + } + } + if ref := n.getRef(protocol, dst); ref != nil { handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, pkt) return diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go index 46a7c99b0..7d061acba 100644 --- a/test/iptables/iptables_test.go +++ b/test/iptables/iptables_test.go @@ -196,12 +196,24 @@ func TestNATRedirectUDPPort(t *testing.T) { } } +func TestNATRedirectTCPPort(t *testing.T) { + if err := singleTest(NATRedirectTCPPort{}); err != nil { + t.Fatal(err) + } +} + func TestNATDropUDP(t *testing.T) { if err := singleTest(NATDropUDP{}); err != nil { t.Fatal(err) } } +func TestNATAcceptAll(t *testing.T) { + if err := singleTest(NATAcceptAll{}); err != nil { + t.Fatal(err) + } +} + func TestFilterInputDropTCPDestPort(t *testing.T) { if err := singleTest(FilterInputDropTCPDestPort{}); err != nil { t.Fatal(err) diff --git a/test/iptables/iptables_util.go b/test/iptables/iptables_util.go index 043114c78..5c9199abf 100644 --- a/test/iptables/iptables_util.go +++ b/test/iptables/iptables_util.go @@ -35,6 +35,16 @@ func filterTable(args ...string) error { return nil } +// natTable calls `iptables -t nat` with the given args. +func natTable(args ...string) error { + args = append([]string{"-t", "nat"}, args...) + cmd := exec.Command(iptablesBinary, args...) + if out, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("error running iptables with args %v\nerror: %v\noutput: %s", args, err, string(out)) + } + return nil +} + // listenUDP listens on a UDP port and returns the value of net.Conn.Read() for // the first read on that port. func listenUDP(port int, timeout time.Duration) error { diff --git a/test/iptables/nat.go b/test/iptables/nat.go index b5c6f927e..306cbd1b3 100644 --- a/test/iptables/nat.go +++ b/test/iptables/nat.go @@ -25,7 +25,9 @@ const ( func init() { RegisterTestCase(NATRedirectUDPPort{}) + RegisterTestCase(NATRedirectTCPPort{}) RegisterTestCase(NATDropUDP{}) + RegisterTestCase(NATAcceptAll{}) } // NATRedirectUDPPort tests that packets are redirected to different port. @@ -38,13 +40,14 @@ func (NATRedirectUDPPort) Name() string { // ContainerAction implements TestCase.ContainerAction. func (NATRedirectUDPPort) ContainerAction(ip net.IP) error { - if err := filterTable("-t", "nat", "-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil { + if err := natTable("-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil { return err } if err := listenUDP(redirectPort, sendloopDuration); err != nil { return fmt.Errorf("packets on port %d should be allowed, but encountered an error: %v", redirectPort, err) } + return nil } @@ -53,6 +56,37 @@ func (NATRedirectUDPPort) LocalAction(ip net.IP) error { return sendUDPLoop(ip, acceptPort, sendloopDuration) } +// NATRedirectTCPPort tests that connections are redirected on specified ports. +type NATRedirectTCPPort struct{} + +// Name implements TestCase.Name. +func (NATRedirectTCPPort) Name() string { + return "NATRedirectTCPPort" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATRedirectTCPPort) ContainerAction(ip net.IP) error { + if err := natTable("-A", "PREROUTING", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil { + return err + } + + // Listen for TCP packets on redirect port. + if err := listenTCP(redirectPort, sendloopDuration); err != nil { + return fmt.Errorf("connection on port %d should be accepted, but got error %v", redirectPort, err) + } + + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (NATRedirectTCPPort) LocalAction(ip net.IP) error { + if err := connectTCP(ip, dropPort, acceptPort, sendloopDuration); err != nil { + return fmt.Errorf("connection destined to port %d should be accepted, but got error %v", dropPort, err) + } + + return nil +} + // NATDropUDP tests that packets are not received in ports other than redirect port. type NATDropUDP struct{} @@ -63,7 +97,7 @@ func (NATDropUDP) Name() string { // ContainerAction implements TestCase.ContainerAction. func (NATDropUDP) ContainerAction(ip net.IP) error { - if err := filterTable("-t", "nat", "-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil { + if err := natTable("-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil { return err } @@ -78,3 +112,29 @@ func (NATDropUDP) ContainerAction(ip net.IP) error { func (NATDropUDP) LocalAction(ip net.IP) error { return sendUDPLoop(ip, acceptPort, sendloopDuration) } + +// NATAcceptAll tests that all UDP packets are accepted. +type NATAcceptAll struct{} + +// Name implements TestCase.Name. +func (NATAcceptAll) Name() string { + return "NATAcceptAll" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATAcceptAll) ContainerAction(ip net.IP) error { + if err := natTable("-A", "PREROUTING", "-p", "udp", "-j", "ACCEPT"); err != nil { + return err + } + + if err := listenUDP(acceptPort, sendloopDuration); err != nil { + return fmt.Errorf("packets on port %d should be allowed, but encountered an error: %v", acceptPort, err) + } + + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (NATAcceptAll) LocalAction(ip net.IP) error { + return sendUDPLoop(ip, acceptPort, sendloopDuration) +} -- cgit v1.2.3 From 92d2d78876a938871327685e1104d7b4ff46986e Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Tue, 18 Feb 2020 21:20:41 -0800 Subject: Fix mis-named comment. --- pkg/tcpip/stack/registration.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index ec91f60dd..d83adf0ec 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -277,7 +277,7 @@ type NetworkProtocol interface { // DefaultPrefixLen returns the protocol's default prefix length. DefaultPrefixLen() int - // ParsePorts returns the source and destination addresses stored in a + // ParseAddresses returns the source and destination addresses stored in a // packet of this protocol. ParseAddresses(v buffer.View) (src, dst tcpip.Address) -- cgit v1.2.3 From 67b615b86f2aa1d4ded3dcf2eb8aca4e7fec57a0 Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Thu, 20 Feb 2020 14:31:39 -0800 Subject: Support disabling a NIC - Disabled NICs will have their associated NDP state cleared. - Disabled NICs will not accept incoming packets. - Writes through a Route with a disabled NIC will return an invalid endpoint state error. - stack.Stack.FindRoute will not return a route with a disabled NIC. - NIC's Running flag will report the NIC's enabled status. Tests: - stack_test.TestDisableUnknownNIC - stack_test.TestDisabledNICsNICInfoAndCheckNIC - stack_test.TestRoutesWithDisabledNIC - stack_test.TestRouteWritePacketWithDisabledNIC - stack_test.TestStopStartSolicitingRouters - stack_test.TestCleanupNDPState - stack_test.TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable - stack_test.TestJoinLeaveAllNodesMulticastOnNICEnableDisable PiperOrigin-RevId: 296298588 --- pkg/tcpip/stack/ndp.go | 23 +- pkg/tcpip/stack/ndp_test.go | 822 ++++++++++++++++++++++++------------------ pkg/tcpip/stack/nic.go | 110 +++++- pkg/tcpip/stack/stack.go | 45 ++- pkg/tcpip/stack/stack_test.go | 377 ++++++++++++++++++- 5 files changed, 998 insertions(+), 379 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 045409bda..19bd05aa3 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -1148,22 +1148,27 @@ func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bo return true } -// cleanupHostOnlyState cleans up any state that is only useful for hosts. +// cleanupState cleans up ndp's state. // -// cleanupHostOnlyState MUST be called when ndp's NIC is transitioning from a -// host to a router. This function will invalidate all discovered on-link -// prefixes, discovered routers, and auto-generated addresses as routers do not -// normally process Router Advertisements to discover default routers and -// on-link prefixes, and auto-generate addresses via SLAAC. +// If hostOnly is true, then only host-specific state will be cleaned up. +// +// cleanupState MUST be called with hostOnly set to true when ndp's NIC is +// transitioning from a host to a router. This function will invalidate all +// discovered on-link prefixes, discovered routers, and auto-generated +// addresses. +// +// If hostOnly is true, then the link-local auto-generated address will not be +// invalidated as routers are also expected to generate a link-local address. // // The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) cleanupHostOnlyState() { +func (ndp *ndpState) cleanupState(hostOnly bool) { linkLocalSubnet := header.IPv6LinkLocalPrefix.Subnet() linkLocalAddrs := 0 for addr := range ndp.autoGenAddresses { // RFC 4862 section 5 states that routers are also expected to generate a - // link-local address so we do not invalidate them. - if linkLocalSubnet.Contains(addr) { + // link-local address so we do not invalidate them if we are cleaning up + // host-only state. + if hostOnly && linkLocalSubnet.Contains(addr) { linkLocalAddrs++ continue } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 1f6f77439..f7b75b74e 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -592,70 +592,94 @@ func TestDADFail(t *testing.T) { } } -// TestDADStop tests to make sure that the DAD process stops when an address is -// removed. func TestDADStop(t *testing.T) { const nicID = 1 - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - } - ndpConfigs := stack.NDPConfigurations{ - RetransmitTimer: time.Second, - DupAddrDetectTransmits: 2, - } - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPDisp: &ndpDisp, - NDPConfigs: ndpConfigs, - } + tests := []struct { + name string + stopFn func(t *testing.T, s *stack.Stack) + }{ + // Tests to make sure that DAD stops when an address is removed. + { + name: "Remove address", + stopFn: func(t *testing.T, s *stack.Stack) { + if err := s.RemoveAddress(nicID, addr1); err != nil { + t.Fatalf("RemoveAddress(%d, %s): %s", nicID, addr1, err) + } + }, + }, - e := channel.New(0, 1280, linkAddr1) - s := stack.New(opts) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + // Tests to make sure that DAD stops when the NIC is disabled. + { + name: "Disable NIC", + stopFn: func(t *testing.T, s *stack.Stack) { + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("DisableNIC(%d): %s", nicID, err) + } + }, + }, } - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 1), + } + ndpConfigs := stack.NDPConfigurations{ + RetransmitTimer: time.Second, + DupAddrDetectTransmits: 2, + } + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPDisp: &ndpDisp, + NDPConfigs: ndpConfigs, + } - // Address should not be considered bound to the NIC yet (DAD ongoing). - addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) - } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(opts) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } - // Remove the address. This should stop DAD. - if err := s.RemoveAddress(nicID, addr1); err != nil { - t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr1, err) - } + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, header.IPv6ProtocolNumber, addr1, err) + } - // Wait for DAD to fail (since the address was removed during DAD). - select { - case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second): - // If we don't get a failure event after the expected resolution - // time + extra 1s buffer, something is wrong. - t.Fatal("timed out waiting for DAD failure") - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr1, false, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) - } - } - addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) - } + // Address should not be considered bound to the NIC yet (DAD ongoing). + addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + } + + test.stopFn(t, s) + + // Wait for DAD to fail (since the address was removed during DAD). + select { + case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second): + // If we don't get a failure event after the expected resolution + // time + extra 1s buffer, something is wrong. + t.Fatal("timed out waiting for DAD failure") + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr1, false, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + } + addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + } - // Should not have sent more than 1 NS message. - if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got > 1 { - t.Fatalf("got NeighborSolicit = %d, want <= 1", got) + // Should not have sent more than 1 NS message. + if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got > 1 { + t.Errorf("got NeighborSolicit = %d, want <= 1", got) + } + }) } } @@ -2886,17 +2910,16 @@ func TestNDPRecursiveDNSServerDispatch(t *testing.T) { } } -// TestCleanupHostOnlyStateOnBecomingRouter tests that all discovered routers -// and prefixes, and non-linklocal auto-generated addresses are invalidated when -// a NIC becomes a router. -func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) { +// TestCleanupNDPState tests that all discovered routers and prefixes, and +// auto-generated addresses are invalidated when a NIC becomes a router. +func TestCleanupNDPState(t *testing.T) { t.Parallel() const ( - lifetimeSeconds = 5 - maxEvents = 4 - nicID1 = 1 - nicID2 = 2 + lifetimeSeconds = 5 + maxRouterAndPrefixEvents = 4 + nicID1 = 1 + nicID2 = 2 ) prefix1, subnet1, e1Addr1 := prefixSubnetAddr(0, linkAddr1) @@ -2912,254 +2935,308 @@ func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) { PrefixLen: 64, } - ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, maxEvents), - rememberRouter: true, - prefixC: make(chan ndpPrefixEvent, maxEvents), - rememberPrefix: true, - autoGenAddrC: make(chan ndpAutoGenAddrEvent, maxEvents), - } - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - AutoGenIPv6LinkLocal: true, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: true, - DiscoverOnLinkPrefixes: true, - AutoGenGlobalAddresses: true, + tests := []struct { + name string + cleanupFn func(t *testing.T, s *stack.Stack) + keepAutoGenLinkLocal bool + maxAutoGenAddrEvents int + }{ + // A NIC should still keep its auto-generated link-local address when + // becoming a router. + { + name: "Forwarding Enable", + cleanupFn: func(t *testing.T, s *stack.Stack) { + t.Helper() + s.SetForwarding(true) + }, + keepAutoGenLinkLocal: true, + maxAutoGenAddrEvents: 4, }, - NDPDisp: &ndpDisp, - }) - expectRouterEvent := func() (bool, ndpRouterEvent) { - select { - case e := <-ndpDisp.routerC: - return true, e - default: - } + // A NIC should cleanup all NDP state when it is disabled. + { + name: "NIC Disable", + cleanupFn: func(t *testing.T, s *stack.Stack) { + t.Helper() - return false, ndpRouterEvent{} + if err := s.DisableNIC(nicID1); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID1, err) + } + if err := s.DisableNIC(nicID2); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID2, err) + } + }, + keepAutoGenLinkLocal: false, + maxAutoGenAddrEvents: 6, + }, } - expectPrefixEvent := func() (bool, ndpPrefixEvent) { - select { - case e := <-ndpDisp.prefixC: - return true, e - default: - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, maxRouterAndPrefixEvents), + rememberRouter: true, + prefixC: make(chan ndpPrefixEvent, maxRouterAndPrefixEvents), + rememberPrefix: true, + autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents), + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + AutoGenIPv6LinkLocal: true, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: true, + DiscoverOnLinkPrefixes: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + }) - return false, ndpPrefixEvent{} - } + expectRouterEvent := func() (bool, ndpRouterEvent) { + select { + case e := <-ndpDisp.routerC: + return true, e + default: + } - expectAutoGenAddrEvent := func() (bool, ndpAutoGenAddrEvent) { - select { - case e := <-ndpDisp.autoGenAddrC: - return true, e - default: - } + return false, ndpRouterEvent{} + } - return false, ndpAutoGenAddrEvent{} - } + expectPrefixEvent := func() (bool, ndpPrefixEvent) { + select { + case e := <-ndpDisp.prefixC: + return true, e + default: + } - e1 := channel.New(0, 1280, linkAddr1) - if err := s.CreateNIC(nicID1, e1); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err) - } - // We have other tests that make sure we receive the *correct* events - // on normal discovery of routers/prefixes, and auto-generated - // addresses. Here we just make sure we get an event and let other tests - // handle the correctness check. - expectAutoGenAddrEvent() + return false, ndpPrefixEvent{} + } - e2 := channel.New(0, 1280, linkAddr2) - if err := s.CreateNIC(nicID2, e2); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err) - } - expectAutoGenAddrEvent() + expectAutoGenAddrEvent := func() (bool, ndpAutoGenAddrEvent) { + select { + case e := <-ndpDisp.autoGenAddrC: + return true, e + default: + } - // Receive RAs on NIC(1) and NIC(2) from default routers (llAddr3 and - // llAddr4) w/ PI (for prefix1 in RA from llAddr3 and prefix2 in RA from - // llAddr4) to discover multiple routers and prefixes, and auto-gen - // multiple addresses. + return false, ndpAutoGenAddrEvent{} + } - e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID1) - } - if ok, _ := expectPrefixEvent(); !ok { - t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID1) - } - if ok, _ := expectAutoGenAddrEvent(); !ok { - t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr1, nicID1) - } + e1 := channel.New(0, 1280, linkAddr1) + if err := s.CreateNIC(nicID1, e1); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err) + } + // We have other tests that make sure we receive the *correct* events + // on normal discovery of routers/prefixes, and auto-generated + // addresses. Here we just make sure we get an event and let other tests + // handle the correctness check. + expectAutoGenAddrEvent() - e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID1) - } - if ok, _ := expectPrefixEvent(); !ok { - t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID1) - } - if ok, _ := expectAutoGenAddrEvent(); !ok { - t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID1) - } + e2 := channel.New(0, 1280, linkAddr2) + if err := s.CreateNIC(nicID2, e2); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err) + } + expectAutoGenAddrEvent() - e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID2) - } - if ok, _ := expectPrefixEvent(); !ok { - t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID2) - } - if ok, _ := expectAutoGenAddrEvent(); !ok { - t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID2) - } + // Receive RAs on NIC(1) and NIC(2) from default routers (llAddr3 and + // llAddr4) w/ PI (for prefix1 in RA from llAddr3 and prefix2 in RA from + // llAddr4) to discover multiple routers and prefixes, and auto-gen + // multiple addresses. - e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID2) - } - if ok, _ := expectPrefixEvent(); !ok { - t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID2) - } - if ok, _ := expectAutoGenAddrEvent(); !ok { - t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e2Addr2, nicID2) - } + e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) + if ok, _ := expectRouterEvent(); !ok { + t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID1) + } + if ok, _ := expectPrefixEvent(); !ok { + t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID1) + } + if ok, _ := expectAutoGenAddrEvent(); !ok { + t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr1, nicID1) + } - // We should have the auto-generated addresses added. - nicinfo := s.NICInfo() - nic1Addrs := nicinfo[nicID1].ProtocolAddresses - nic2Addrs := nicinfo[nicID2].ProtocolAddresses - if !containsV6Addr(nic1Addrs, llAddrWithPrefix1) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) - } - if !containsV6Addr(nic1Addrs, e1Addr1) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) - } - if !containsV6Addr(nic1Addrs, e1Addr2) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) - } - if !containsV6Addr(nic2Addrs, llAddrWithPrefix2) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) - } - if !containsV6Addr(nic2Addrs, e2Addr1) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) - } - if !containsV6Addr(nic2Addrs, e2Addr2) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs) - } + e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) + if ok, _ := expectRouterEvent(); !ok { + t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID1) + } + if ok, _ := expectPrefixEvent(); !ok { + t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID1) + } + if ok, _ := expectAutoGenAddrEvent(); !ok { + t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID1) + } - // We can't proceed any further if we already failed the test (missing - // some discovery/auto-generated address events or addresses). - if t.Failed() { - t.FailNow() - } + e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) + if ok, _ := expectRouterEvent(); !ok { + t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID2) + } + if ok, _ := expectPrefixEvent(); !ok { + t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID2) + } + if ok, _ := expectAutoGenAddrEvent(); !ok { + t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID2) + } - s.SetForwarding(true) + e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) + if ok, _ := expectRouterEvent(); !ok { + t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID2) + } + if ok, _ := expectPrefixEvent(); !ok { + t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID2) + } + if ok, _ := expectAutoGenAddrEvent(); !ok { + t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e2Addr2, nicID2) + } - // Collect invalidation events after becoming a router - gotRouterEvents := make(map[ndpRouterEvent]int) - for i := 0; i < maxEvents; i++ { - ok, e := expectRouterEvent() - if !ok { - t.Errorf("expected %d router events after becoming a router; got = %d", maxEvents, i) - break - } - gotRouterEvents[e]++ - } - gotPrefixEvents := make(map[ndpPrefixEvent]int) - for i := 0; i < maxEvents; i++ { - ok, e := expectPrefixEvent() - if !ok { - t.Errorf("expected %d prefix events after becoming a router; got = %d", maxEvents, i) - break - } - gotPrefixEvents[e]++ - } - gotAutoGenAddrEvents := make(map[ndpAutoGenAddrEvent]int) - for i := 0; i < maxEvents; i++ { - ok, e := expectAutoGenAddrEvent() - if !ok { - t.Errorf("expected %d auto-generated address events after becoming a router; got = %d", maxEvents, i) - break - } - gotAutoGenAddrEvents[e]++ - } + // We should have the auto-generated addresses added. + nicinfo := s.NICInfo() + nic1Addrs := nicinfo[nicID1].ProtocolAddresses + nic2Addrs := nicinfo[nicID2].ProtocolAddresses + if !containsV6Addr(nic1Addrs, llAddrWithPrefix1) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) + } + if !containsV6Addr(nic1Addrs, e1Addr1) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) + } + if !containsV6Addr(nic1Addrs, e1Addr2) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) + } + if !containsV6Addr(nic2Addrs, llAddrWithPrefix2) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) + } + if !containsV6Addr(nic2Addrs, e2Addr1) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) + } + if !containsV6Addr(nic2Addrs, e2Addr2) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs) + } - // No need to proceed any further if we already failed the test (missing - // some invalidation events). - if t.Failed() { - t.FailNow() - } + // We can't proceed any further if we already failed the test (missing + // some discovery/auto-generated address events or addresses). + if t.Failed() { + t.FailNow() + } - expectedRouterEvents := map[ndpRouterEvent]int{ - {nicID: nicID1, addr: llAddr3, discovered: false}: 1, - {nicID: nicID1, addr: llAddr4, discovered: false}: 1, - {nicID: nicID2, addr: llAddr3, discovered: false}: 1, - {nicID: nicID2, addr: llAddr4, discovered: false}: 1, - } - if diff := cmp.Diff(expectedRouterEvents, gotRouterEvents); diff != "" { - t.Errorf("router events mismatch (-want +got):\n%s", diff) - } - expectedPrefixEvents := map[ndpPrefixEvent]int{ - {nicID: nicID1, prefix: subnet1, discovered: false}: 1, - {nicID: nicID1, prefix: subnet2, discovered: false}: 1, - {nicID: nicID2, prefix: subnet1, discovered: false}: 1, - {nicID: nicID2, prefix: subnet2, discovered: false}: 1, - } - if diff := cmp.Diff(expectedPrefixEvents, gotPrefixEvents); diff != "" { - t.Errorf("prefix events mismatch (-want +got):\n%s", diff) - } - expectedAutoGenAddrEvents := map[ndpAutoGenAddrEvent]int{ - {nicID: nicID1, addr: e1Addr1, eventType: invalidatedAddr}: 1, - {nicID: nicID1, addr: e1Addr2, eventType: invalidatedAddr}: 1, - {nicID: nicID2, addr: e2Addr1, eventType: invalidatedAddr}: 1, - {nicID: nicID2, addr: e2Addr2, eventType: invalidatedAddr}: 1, - } - if diff := cmp.Diff(expectedAutoGenAddrEvents, gotAutoGenAddrEvents); diff != "" { - t.Errorf("auto-generated address events mismatch (-want +got):\n%s", diff) - } + test.cleanupFn(t, s) - // Make sure the auto-generated addresses got removed. - nicinfo = s.NICInfo() - nic1Addrs = nicinfo[nicID1].ProtocolAddresses - nic2Addrs = nicinfo[nicID2].ProtocolAddresses - if !containsV6Addr(nic1Addrs, llAddrWithPrefix1) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) - } - if containsV6Addr(nic1Addrs, e1Addr1) { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) - } - if containsV6Addr(nic1Addrs, e1Addr2) { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) - } - if !containsV6Addr(nic2Addrs, llAddrWithPrefix2) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) - } - if containsV6Addr(nic2Addrs, e2Addr1) { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) - } - if containsV6Addr(nic2Addrs, e2Addr2) { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs) - } + // Collect invalidation events after having NDP state cleaned up. + gotRouterEvents := make(map[ndpRouterEvent]int) + for i := 0; i < maxRouterAndPrefixEvents; i++ { + ok, e := expectRouterEvent() + if !ok { + t.Errorf("expected %d router events after becoming a router; got = %d", maxRouterAndPrefixEvents, i) + break + } + gotRouterEvents[e]++ + } + gotPrefixEvents := make(map[ndpPrefixEvent]int) + for i := 0; i < maxRouterAndPrefixEvents; i++ { + ok, e := expectPrefixEvent() + if !ok { + t.Errorf("expected %d prefix events after becoming a router; got = %d", maxRouterAndPrefixEvents, i) + break + } + gotPrefixEvents[e]++ + } + gotAutoGenAddrEvents := make(map[ndpAutoGenAddrEvent]int) + for i := 0; i < test.maxAutoGenAddrEvents; i++ { + ok, e := expectAutoGenAddrEvent() + if !ok { + t.Errorf("expected %d auto-generated address events after becoming a router; got = %d", test.maxAutoGenAddrEvents, i) + break + } + gotAutoGenAddrEvents[e]++ + } - // Should not get any more events (invalidation timers should have been - // cancelled when we transitioned into a router). - time.Sleep(lifetimeSeconds*time.Second + defaultTimeout) - select { - case <-ndpDisp.routerC: - t.Error("unexpected router event") - default: - } - select { - case <-ndpDisp.prefixC: - t.Error("unexpected prefix event") - default: - } - select { - case <-ndpDisp.autoGenAddrC: - t.Error("unexpected auto-generated address event") - default: + // No need to proceed any further if we already failed the test (missing + // some invalidation events). + if t.Failed() { + t.FailNow() + } + + expectedRouterEvents := map[ndpRouterEvent]int{ + {nicID: nicID1, addr: llAddr3, discovered: false}: 1, + {nicID: nicID1, addr: llAddr4, discovered: false}: 1, + {nicID: nicID2, addr: llAddr3, discovered: false}: 1, + {nicID: nicID2, addr: llAddr4, discovered: false}: 1, + } + if diff := cmp.Diff(expectedRouterEvents, gotRouterEvents); diff != "" { + t.Errorf("router events mismatch (-want +got):\n%s", diff) + } + expectedPrefixEvents := map[ndpPrefixEvent]int{ + {nicID: nicID1, prefix: subnet1, discovered: false}: 1, + {nicID: nicID1, prefix: subnet2, discovered: false}: 1, + {nicID: nicID2, prefix: subnet1, discovered: false}: 1, + {nicID: nicID2, prefix: subnet2, discovered: false}: 1, + } + if diff := cmp.Diff(expectedPrefixEvents, gotPrefixEvents); diff != "" { + t.Errorf("prefix events mismatch (-want +got):\n%s", diff) + } + expectedAutoGenAddrEvents := map[ndpAutoGenAddrEvent]int{ + {nicID: nicID1, addr: e1Addr1, eventType: invalidatedAddr}: 1, + {nicID: nicID1, addr: e1Addr2, eventType: invalidatedAddr}: 1, + {nicID: nicID2, addr: e2Addr1, eventType: invalidatedAddr}: 1, + {nicID: nicID2, addr: e2Addr2, eventType: invalidatedAddr}: 1, + } + + if !test.keepAutoGenLinkLocal { + expectedAutoGenAddrEvents[ndpAutoGenAddrEvent{nicID: nicID1, addr: llAddrWithPrefix1, eventType: invalidatedAddr}] = 1 + expectedAutoGenAddrEvents[ndpAutoGenAddrEvent{nicID: nicID2, addr: llAddrWithPrefix2, eventType: invalidatedAddr}] = 1 + } + + if diff := cmp.Diff(expectedAutoGenAddrEvents, gotAutoGenAddrEvents); diff != "" { + t.Errorf("auto-generated address events mismatch (-want +got):\n%s", diff) + } + + // Make sure the auto-generated addresses got removed. + nicinfo = s.NICInfo() + nic1Addrs = nicinfo[nicID1].ProtocolAddresses + nic2Addrs = nicinfo[nicID2].ProtocolAddresses + if containsV6Addr(nic1Addrs, llAddrWithPrefix1) != test.keepAutoGenLinkLocal { + if test.keepAutoGenLinkLocal { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) + } else { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) + } + } + if containsV6Addr(nic1Addrs, e1Addr1) { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) + } + if containsV6Addr(nic1Addrs, e1Addr2) { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) + } + if containsV6Addr(nic2Addrs, llAddrWithPrefix2) != test.keepAutoGenLinkLocal { + if test.keepAutoGenLinkLocal { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) + } else { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) + } + } + if containsV6Addr(nic2Addrs, e2Addr1) { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) + } + if containsV6Addr(nic2Addrs, e2Addr2) { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs) + } + + // Should not get any more events (invalidation timers should have been + // cancelled when the NDP state was cleaned up). + time.Sleep(lifetimeSeconds*time.Second + defaultTimeout) + select { + case <-ndpDisp.routerC: + t.Error("unexpected router event") + default: + } + select { + case <-ndpDisp.prefixC: + t.Error("unexpected prefix event") + default: + } + select { + case <-ndpDisp.autoGenAddrC: + t.Error("unexpected auto-generated address event") + default: + } + }) } } @@ -3406,77 +3483,130 @@ func TestRouterSolicitation(t *testing.T) { }) } -// TestStopStartSolicitingRouters tests that when forwarding is enabled or -// disabled, router solicitations are stopped or started, respecitively. func TestStopStartSolicitingRouters(t *testing.T) { t.Parallel() + const nicID = 1 const interval = 500 * time.Millisecond const delay = time.Second const maxRtrSolicitations = 3 - e := channel.New(maxRtrSolicitations, 1280, linkAddr1) - waitForPkt := func(timeout time.Duration) { - t.Helper() - ctx, _ := context.WithTimeout(context.Background(), timeout) - p, ok := e.ReadContext(ctx) - if !ok { - t.Fatal("timed out waiting for packet") - return - } - if p.Proto != header.IPv6ProtocolNumber { - t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) - } - checker.IPv6(t, p.Pkt.Header.View(), - checker.SrcAddr(header.IPv6Any), - checker.DstAddr(header.IPv6AllRoutersMulticastAddress), - checker.TTL(header.NDPHopLimit), - checker.NDPRS()) - } - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - MaxRtrSolicitations: maxRtrSolicitations, - RtrSolicitationInterval: interval, - MaxRtrSolicitationDelay: delay, + tests := []struct { + name string + startFn func(t *testing.T, s *stack.Stack) + stopFn func(t *testing.T, s *stack.Stack) + }{ + // Tests that when forwarding is enabled or disabled, router solicitations + // are stopped or started, respectively. + { + name: "Forwarding enabled and disabled", + startFn: func(t *testing.T, s *stack.Stack) { + t.Helper() + s.SetForwarding(false) + }, + stopFn: func(t *testing.T, s *stack.Stack) { + t.Helper() + s.SetForwarding(true) + }, }, - }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - // Enable forwarding which should stop router solicitations. - s.SetForwarding(true) - ctx, _ := context.WithTimeout(context.Background(), delay+defaultTimeout) - if _, ok := e.ReadContext(ctx); ok { - // A single RS may have been sent before forwarding was enabled. - ctx, _ = context.WithTimeout(context.Background(), interval+defaultTimeout) - if _, ok = e.ReadContext(ctx); ok { - t.Fatal("Should not have sent more than one RS message") - } - } + // Tests that when a NIC is enabled or disabled, router solicitations + // are started or stopped, respectively. + { + name: "NIC disabled and enabled", + startFn: func(t *testing.T, s *stack.Stack) { + t.Helper() - // Enabling forwarding again should do nothing. - s.SetForwarding(true) - ctx, _ = context.WithTimeout(context.Background(), delay+defaultTimeout) - if _, ok := e.ReadContext(ctx); ok { - t.Fatal("unexpectedly got a packet after becoming a router") - } + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + }, + stopFn: func(t *testing.T, s *stack.Stack) { + t.Helper() - // Disable forwarding which should start router solicitations. - s.SetForwarding(false) - waitForPkt(delay + defaultAsyncEventTimeout) - waitForPkt(interval + defaultAsyncEventTimeout) - waitForPkt(interval + defaultAsyncEventTimeout) - ctx, _ = context.WithTimeout(context.Background(), interval+defaultTimeout) - if _, ok := e.ReadContext(ctx); ok { - t.Fatal("unexpectedly got an extra packet after sending out the expected RSs") + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + }, + }, } - // Disabling forwarding again should do nothing. - s.SetForwarding(false) - ctx, _ = context.WithTimeout(context.Background(), delay+defaultTimeout) - if _, ok := e.ReadContext(ctx); ok { - t.Fatal("unexpectedly got a packet after becoming a router") + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e := channel.New(maxRtrSolicitations, 1280, linkAddr1) + waitForPkt := func(timeout time.Duration) { + t.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + p, ok := e.ReadContext(ctx) + if !ok { + t.Fatal("timed out waiting for packet") + return + } + + if p.Proto != header.IPv6ProtocolNumber { + t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) + } + checker.IPv6(t, p.Pkt.Header.View(), + checker.SrcAddr(header.IPv6Any), + checker.DstAddr(header.IPv6AllRoutersMulticastAddress), + checker.TTL(header.NDPHopLimit), + checker.NDPRS()) + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + MaxRtrSolicitations: maxRtrSolicitations, + RtrSolicitationInterval: interval, + MaxRtrSolicitationDelay: delay, + }, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + // Stop soliciting routers. + test.stopFn(t, s) + ctx, cancel := context.WithTimeout(context.Background(), delay+defaultTimeout) + defer cancel() + if _, ok := e.ReadContext(ctx); ok { + // A single RS may have been sent before forwarding was enabled. + ctx, cancel := context.WithTimeout(context.Background(), interval+defaultTimeout) + defer cancel() + if _, ok = e.ReadContext(ctx); ok { + t.Fatal("should not have sent more than one RS message") + } + } + + // Stopping router solicitations after it has already been stopped should + // do nothing. + test.stopFn(t, s) + ctx, cancel = context.WithTimeout(context.Background(), delay+defaultTimeout) + defer cancel() + if _, ok := e.ReadContext(ctx); ok { + t.Fatal("unexpectedly got a packet after router solicitation has been stopepd") + } + + // Start soliciting routers. + test.startFn(t, s) + waitForPkt(delay + defaultAsyncEventTimeout) + waitForPkt(interval + defaultAsyncEventTimeout) + waitForPkt(interval + defaultAsyncEventTimeout) + ctx, cancel = context.WithTimeout(context.Background(), interval+defaultTimeout) + defer cancel() + if _, ok := e.ReadContext(ctx); ok { + t.Fatal("unexpectedly got an extra packet after sending out the expected RSs") + } + + // Starting router solicitations after it has already completed should do + // nothing. + test.startFn(t, s) + ctx, cancel = context.WithTimeout(context.Background(), delay+defaultTimeout) + defer cancel() + if _, ok := e.ReadContext(ctx); ok { + t.Fatal("unexpectedly got a packet after finishing router solicitations") + } + }) } } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index ca3a7a07e..b2be18e47 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -27,6 +27,14 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" ) +var ipv4BroadcastAddr = tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: header.IPv4Broadcast, + PrefixLen: 8 * header.IPv4AddressSize, + }, +} + // NIC represents a "network interface card" to which the networking stack is // attached. type NIC struct { @@ -36,7 +44,8 @@ type NIC struct { linkEP LinkEndpoint context NICContext - stats NICStats + stats NICStats + attach sync.Once mu struct { sync.RWMutex @@ -135,7 +144,69 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC return nic } -// enable enables the NIC. enable will attach the link to its LinkEndpoint and +// enabled returns true if n is enabled. +func (n *NIC) enabled() bool { + n.mu.RLock() + enabled := n.mu.enabled + n.mu.RUnlock() + return enabled +} + +// disable disables n. +// +// It undoes the work done by enable. +func (n *NIC) disable() *tcpip.Error { + n.mu.RLock() + enabled := n.mu.enabled + n.mu.RUnlock() + if !enabled { + return nil + } + + n.mu.Lock() + defer n.mu.Unlock() + + if !n.mu.enabled { + return nil + } + + // TODO(b/147015577): Should Routes that are currently bound to n be + // invalidated? Currently, Routes will continue to work when a NIC is enabled + // again, and applications may not know that the underlying NIC was ever + // disabled. + + if _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]; ok { + n.mu.ndp.stopSolicitingRouters() + n.mu.ndp.cleanupState(false /* hostOnly */) + + // Stop DAD for all the unicast IPv6 endpoints that are in the + // permanentTentative state. + for _, r := range n.mu.endpoints { + if addr := r.ep.ID().LocalAddress; r.getKind() == permanentTentative && header.IsV6UnicastAddress(addr) { + n.mu.ndp.stopDuplicateAddressDetection(addr) + } + } + + // The NIC may have already left the multicast group. + if err := n.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress { + return err + } + } + + if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok { + // The address may have already been removed. + if err := n.removePermanentAddressLocked(ipv4BroadcastAddr.AddressWithPrefix.Address); err != nil && err != tcpip.ErrBadLocalAddress { + return err + } + } + + // TODO(b/147015577): Should n detach from its LinkEndpoint? + + n.mu.enabled = false + return nil +} + +// enable enables n. enable will attach the nic to its LinkEndpoint and // join the IPv6 All-Nodes Multicast address (ff02::1). func (n *NIC) enable() *tcpip.Error { n.mu.RLock() @@ -158,10 +229,7 @@ func (n *NIC) enable() *tcpip.Error { // Create an endpoint to receive broadcast packets on this interface. if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok { - if _, err := n.addAddressLocked(tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize}, - }, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil { + if _, err := n.addAddressLocked(ipv4BroadcastAddr, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil { return err } } @@ -183,6 +251,14 @@ func (n *NIC) enable() *tcpip.Error { return nil } + // Join the All-Nodes multicast group before starting DAD as responses to DAD + // (NDP NS) messages may be sent to the All-Nodes multicast group if the + // source address of the NDP NS is the unspecified address, as per RFC 4861 + // section 7.2.4. + if err := n.joinGroupLocked(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress); err != nil { + return err + } + // Perform DAD on the all the unicast IPv6 endpoints that are in the permanent // state. // @@ -200,10 +276,6 @@ func (n *NIC) enable() *tcpip.Error { } } - if err := n.joinGroupLocked(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress); err != nil { - return err - } - // Do not auto-generate an IPv6 link-local address for loopback devices. if n.stack.autoGenIPv6LinkLocal && !n.isLoopback() { // The valid and preferred lifetime is infinite for the auto-generated @@ -234,7 +306,7 @@ func (n *NIC) becomeIPv6Router() { n.mu.Lock() defer n.mu.Unlock() - n.mu.ndp.cleanupHostOnlyState() + n.mu.ndp.cleanupState(true /* hostOnly */) n.mu.ndp.stopSolicitingRouters() } @@ -252,7 +324,9 @@ func (n *NIC) becomeIPv6Host() { // attachLinkEndpoint attaches the NIC to the endpoint, which will enable it // to start delivering packets. func (n *NIC) attachLinkEndpoint() { - n.linkEP.Attach(n) + n.attach.Do(func() { + n.linkEP.Attach(n) + }) } // setPromiscuousMode enables or disables promiscuous mode. @@ -712,6 +786,7 @@ func (n *NIC) AllAddresses() []tcpip.ProtocolAddress { case permanentExpired, temporary: continue } + addrs = append(addrs, tcpip.ProtocolAddress{ Protocol: ref.protocol, AddressWithPrefix: tcpip.AddressWithPrefix{ @@ -1009,6 +1084,15 @@ func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { return nil } +// isInGroup returns true if n has joined the multicast group addr. +func (n *NIC) isInGroup(addr tcpip.Address) bool { + n.mu.RLock() + joins := n.mu.mcastJoins[NetworkEndpointID{addr}] + n.mu.RUnlock() + + return joins != 0 +} + func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt tcpip.PacketBuffer) { r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */) r.RemoteLinkAddress = remotelinkAddr @@ -1411,7 +1495,7 @@ func (r *referencedNetworkEndpoint) isValidForOutgoing() bool { // // r's NIC must be read locked. func (r *referencedNetworkEndpoint) isValidForOutgoingRLocked() bool { - return r.getKind() != permanentExpired || r.nic.mu.spoofing + return r.nic.mu.enabled && (r.getKind() != permanentExpired || r.nic.mu.spoofing) } // decRef decrements the ref count and cleans up the endpoint once it reaches diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 6eac16e16..fabc976a7 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -921,23 +921,38 @@ func (s *Stack) EnableNIC(id tcpip.NICID) *tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() - nic := s.nics[id] - if nic == nil { + nic, ok := s.nics[id] + if !ok { return tcpip.ErrUnknownNICID } return nic.enable() } +// DisableNIC disables the given NIC. +func (s *Stack) DisableNIC(id tcpip.NICID) *tcpip.Error { + s.mu.RLock() + defer s.mu.RUnlock() + + nic, ok := s.nics[id] + if !ok { + return tcpip.ErrUnknownNICID + } + + return nic.disable() +} + // CheckNIC checks if a NIC is usable. func (s *Stack) CheckNIC(id tcpip.NICID) bool { s.mu.RLock() + defer s.mu.RUnlock() + nic, ok := s.nics[id] - s.mu.RUnlock() - if ok { - return nic.linkEP.IsAttached() + if !ok { + return false } - return false + + return nic.enabled() } // NICAddressRanges returns a map of NICIDs to their associated subnets. @@ -989,7 +1004,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { for id, nic := range s.nics { flags := NICStateFlags{ Up: true, // Netstack interfaces are always up. - Running: nic.linkEP.IsAttached(), + Running: nic.enabled(), Promiscuous: nic.isPromiscuousMode(), Loopback: nic.isLoopback(), } @@ -1151,7 +1166,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr) needRoute := !(isBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr)) if id != 0 && !needRoute { - if nic, ok := s.nics[id]; ok { + if nic, ok := s.nics[id]; ok && nic.enabled() { if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil { return makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()), nil } @@ -1161,7 +1176,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr)) { continue } - if nic, ok := s.nics[route.NIC]; ok { + if nic, ok := s.nics[route.NIC]; ok && nic.enabled() { if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil { if len(remoteAddr) == 0 { // If no remote address was provided, then the route @@ -1614,6 +1629,18 @@ func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NIC return tcpip.ErrUnknownNICID } +// IsInGroup returns true if the NIC with ID nicID has joined the multicast +// group multicastAddr. +func (s *Stack) IsInGroup(nicID tcpip.NICID, multicastAddr tcpip.Address) (bool, *tcpip.Error) { + s.mu.RLock() + defer s.mu.RUnlock() + + if nic, ok := s.nics[nicID]; ok { + return nic.isInGroup(multicastAddr), nil + } + return false, tcpip.ErrUnknownNICID +} + // IPTables returns the stack's iptables. func (s *Stack) IPTables() iptables.IPTables { s.tablesMu.RLock() diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 7ba604442..eb6f7d1fc 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -33,6 +33,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" @@ -509,6 +510,257 @@ func testNoRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr } } +func TestDisableUnknownNIC(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + + if err := s.DisableNIC(1); err != tcpip.ErrUnknownNICID { + t.Fatalf("got s.DisableNIC(1) = %v, want = %s", err, tcpip.ErrUnknownNICID) + } +} + +func TestDisabledNICsNICInfoAndCheckNIC(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + + e := loopback.New() + nicOpts := stack.NICOptions{Disabled: true} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, nicOpts, err) + } + + checkNIC := func(enabled bool) { + t.Helper() + + allNICInfo := s.NICInfo() + nicInfo, ok := allNICInfo[nicID] + if !ok { + t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo) + } else if nicInfo.Flags.Running != enabled { + t.Errorf("got nicInfo.Flags.Running = %t, want = %t", nicInfo.Flags.Running, enabled) + } + + if got := s.CheckNIC(nicID); got != enabled { + t.Errorf("got s.CheckNIC(%d) = %t, want = %t", nicID, got, enabled) + } + } + + // NIC should initially report itself as disabled. + checkNIC(false) + + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + checkNIC(true) + + // If the NIC is not reporting a correct enabled status, we cannot trust the + // next check so end the test here. + if t.Failed() { + t.FailNow() + } + + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + checkNIC(false) +} + +func TestRoutesWithDisabledNIC(t *testing.T) { + const unspecifiedNIC = 0 + const nicID1 = 1 + const nicID2 = 2 + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + + ep1 := channel.New(0, defaultMTU, "") + if err := s.CreateNIC(nicID1, ep1); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } + + addr1 := tcpip.Address("\x01") + if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err) + } + + ep2 := channel.New(0, defaultMTU, "") + if err := s.CreateNIC(nicID2, ep2); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) + } + + addr2 := tcpip.Address("\x02") + if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err) + } + + // Set a route table that sends all packets with odd destination + // addresses through the first NIC, and all even destination address + // through the second one. + { + subnet0, err := tcpip.NewSubnet("\x00", "\x01") + if err != nil { + t.Fatal(err) + } + subnet1, err := tcpip.NewSubnet("\x01", "\x01") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{ + {Destination: subnet1, Gateway: "\x00", NIC: nicID1}, + {Destination: subnet0, Gateway: "\x00", NIC: nicID2}, + }) + } + + // Test routes to odd address. + testRoute(t, s, unspecifiedNIC, "", "\x05", addr1) + testRoute(t, s, unspecifiedNIC, addr1, "\x05", addr1) + testRoute(t, s, nicID1, addr1, "\x05", addr1) + + // Test routes to even address. + testRoute(t, s, unspecifiedNIC, "", "\x06", addr2) + testRoute(t, s, unspecifiedNIC, addr2, "\x06", addr2) + testRoute(t, s, nicID2, addr2, "\x06", addr2) + + // Disabling NIC1 should result in no routes to odd addresses. Routes to even + // addresses should continue to be available as NIC2 is still enabled. + if err := s.DisableNIC(nicID1); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID1, err) + } + nic1Dst := tcpip.Address("\x05") + testNoRoute(t, s, unspecifiedNIC, "", nic1Dst) + testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst) + testNoRoute(t, s, nicID1, addr1, nic1Dst) + nic2Dst := tcpip.Address("\x06") + testRoute(t, s, unspecifiedNIC, "", nic2Dst, addr2) + testRoute(t, s, unspecifiedNIC, addr2, nic2Dst, addr2) + testRoute(t, s, nicID2, addr2, nic2Dst, addr2) + + // Disabling NIC2 should result in no routes to even addresses. No route + // should be available to any address as routes to odd addresses were made + // unavailable by disabling NIC1 above. + if err := s.DisableNIC(nicID2); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID2, err) + } + testNoRoute(t, s, unspecifiedNIC, "", nic1Dst) + testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst) + testNoRoute(t, s, nicID1, addr1, nic1Dst) + testNoRoute(t, s, unspecifiedNIC, "", nic2Dst) + testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst) + testNoRoute(t, s, nicID2, addr2, nic2Dst) + + // Enabling NIC1 should make routes to odd addresses available again. Routes + // to even addresses should continue to be unavailable as NIC2 is still + // disabled. + if err := s.EnableNIC(nicID1); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID1, err) + } + testRoute(t, s, unspecifiedNIC, "", nic1Dst, addr1) + testRoute(t, s, unspecifiedNIC, addr1, nic1Dst, addr1) + testRoute(t, s, nicID1, addr1, nic1Dst, addr1) + testNoRoute(t, s, unspecifiedNIC, "", nic2Dst) + testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst) + testNoRoute(t, s, nicID2, addr2, nic2Dst) +} + +func TestRouteWritePacketWithDisabledNIC(t *testing.T) { + const unspecifiedNIC = 0 + const nicID1 = 1 + const nicID2 = 2 + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + + ep1 := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicID1, ep1); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } + + addr1 := tcpip.Address("\x01") + if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err) + } + + ep2 := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicID2, ep2); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) + } + + addr2 := tcpip.Address("\x02") + if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err) + } + + // Set a route table that sends all packets with odd destination + // addresses through the first NIC, and all even destination address + // through the second one. + { + subnet0, err := tcpip.NewSubnet("\x00", "\x01") + if err != nil { + t.Fatal(err) + } + subnet1, err := tcpip.NewSubnet("\x01", "\x01") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{ + {Destination: subnet1, Gateway: "\x00", NIC: nicID1}, + {Destination: subnet0, Gateway: "\x00", NIC: nicID2}, + }) + } + + nic1Dst := tcpip.Address("\x05") + r1, err := s.FindRoute(nicID1, addr1, nic1Dst, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID1, addr1, nic1Dst, fakeNetNumber, err) + } + defer r1.Release() + + nic2Dst := tcpip.Address("\x06") + r2, err := s.FindRoute(nicID2, addr2, nic2Dst, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID2, addr2, nic2Dst, fakeNetNumber, err) + } + defer r2.Release() + + // If we failed to get routes r1 or r2, we cannot proceed with the test. + if t.Failed() { + t.FailNow() + } + + buf := buffer.View([]byte{1}) + testSend(t, r1, ep1, buf) + testSend(t, r2, ep2, buf) + + // Writes with Routes that use the disabled NIC1 should fail. + if err := s.DisableNIC(nicID1); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID1, err) + } + testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState) + testSend(t, r2, ep2, buf) + + // Writes with Routes that use the disabled NIC2 should fail. + if err := s.DisableNIC(nicID2); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID2, err) + } + testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState) + testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState) + + // Writes with Routes that use the re-enabled NIC1 should succeed. + // TODO(b/147015577): Should we instead completely invalidate all Routes that + // were bound to a disabled NIC at some point? + if err := s.EnableNIC(nicID1); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID1, err) + } + testSend(t, r1, ep1, buf) + testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState) +} + func TestRoutes(t *testing.T) { // Create a stack with the fake network protocol, two nics, and two // addresses per nic, the first nic has odd address, the second one has @@ -2173,13 +2425,29 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { e := channel.New(0, 1280, test.linkAddr) s := stack.New(opts) - nicOpts := stack.NICOptions{Name: test.nicName} + nicOpts := stack.NICOptions{Name: test.nicName, Disabled: true} if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err) } - var expectedMainAddr tcpip.AddressWithPrefix + // A new disabled NIC should not have any address, even if auto generation + // was enabled. + allStackAddrs := s.AllAddresses() + allNICAddrs, ok := allStackAddrs[nicID] + if !ok { + t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) + } + if l := len(allNICAddrs); l != 0 { + t.Fatalf("got len(allNICAddrs) = %d, want = 0", l) + } + // Enabling the NIC should attempt auto-generation of a link-local + // address. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + + var expectedMainAddr tcpip.AddressWithPrefix if test.shouldGen { expectedMainAddr = tcpip.AddressWithPrefix{ Address: test.expectedAddr, @@ -2609,6 +2877,111 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { } } +func TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable(t *testing.T) { + const nicID = 1 + + e := loopback.New() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + }) + nicOpts := stack.NICOptions{Disabled: true} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err) + } + + allStackAddrs := s.AllAddresses() + allNICAddrs, ok := allStackAddrs[nicID] + if !ok { + t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) + } + if l := len(allNICAddrs); l != 0 { + t.Fatalf("got len(allNICAddrs) = %d, want = 0", l) + } + + // Enabling the NIC should add the IPv4 broadcast address. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + allStackAddrs = s.AllAddresses() + allNICAddrs, ok = allStackAddrs[nicID] + if !ok { + t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) + } + if l := len(allNICAddrs); l != 1 { + t.Fatalf("got len(allNICAddrs) = %d, want = 1", l) + } + want := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: header.IPv4Broadcast, + PrefixLen: 32, + }, + } + if allNICAddrs[0] != want { + t.Fatalf("got allNICAddrs[0] = %+v, want = %+v", allNICAddrs[0], want) + } + + // Disabling the NIC should remove the IPv4 broadcast address. + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + allStackAddrs = s.AllAddresses() + allNICAddrs, ok = allStackAddrs[nicID] + if !ok { + t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) + } + if l := len(allNICAddrs); l != 0 { + t.Fatalf("got len(allNICAddrs) = %d, want = 0", l) + } +} + +func TestJoinLeaveAllNodesMulticastOnNICEnableDisable(t *testing.T) { + const nicID = 1 + + e := loopback.New() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + }) + nicOpts := stack.NICOptions{Disabled: true} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err) + } + + // Should not be in the IPv6 all-nodes multicast group yet because the NIC has + // not been enabled yet. + isInGroup, err := s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress) + if err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err) + } + if isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, header.IPv6AllNodesMulticastAddress) + } + + // The all-nodes multicast group should be joined when the NIC is enabled. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + isInGroup, err = s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress) + if err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err) + } + if !isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, header.IPv6AllNodesMulticastAddress) + } + + // The all-nodes multicast group should be left when the NIC is disabled. + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + isInGroup, err = s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress) + if err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err) + } + if isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, header.IPv6AllNodesMulticastAddress) + } +} + // TestDoDADWhenNICEnabled tests that IPv6 endpoints that were added while a NIC // was disabled have DAD performed on them when the NIC is enabled. func TestDoDADWhenNICEnabled(t *testing.T) { -- cgit v1.2.3 From 97c07242c37e56f6cfdc52036d554052ba95f2ae Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Fri, 21 Feb 2020 09:53:56 -0800 Subject: Use Route.MaxHeaderLength when constructing NDP RS Test: stack_test.TestRouterSolicitation PiperOrigin-RevId: 296454766 --- pkg/tcpip/stack/ndp.go | 2 +- pkg/tcpip/stack/ndp_test.go | 78 ++++++++++++++++++++++++++++++++++++++------- 2 files changed, 68 insertions(+), 12 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 19bd05aa3..f651871ce 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -1235,7 +1235,7 @@ func (ndp *ndpState) startSolicitingRouters() { } payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize - hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadSize) + hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + payloadSize) pkt := header.ICMPv6(hdr.Prepend(payloadSize)) pkt.SetType(header.ICMPv6RouterSolicit) pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index f7b75b74e..6e9306d09 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -267,6 +267,17 @@ func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration s } } +// channelLinkWithHeaderLength is a channel.Endpoint with a configurable +// header length. +type channelLinkWithHeaderLength struct { + *channel.Endpoint + headerLength uint16 +} + +func (l *channelLinkWithHeaderLength) MaxHeaderLength() uint16 { + return l.headerLength +} + // Check e to make sure that the event is for addr on nic with ID 1, and the // resolved flag set to resolved with the specified err. func checkDADEvent(e ndpDADEvent, nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) string { @@ -323,21 +334,46 @@ func TestDADDisabled(t *testing.T) { // DAD for various values of DupAddrDetectTransmits and RetransmitTimer. // Included in the subtests is a test to make sure that an invalid // RetransmitTimer (<1ms) values get fixed to the default RetransmitTimer of 1s. +// This tests also validates the NDP NS packet that is transmitted. func TestDADResolve(t *testing.T) { const nicID = 1 tests := []struct { name string + linkHeaderLen uint16 dupAddrDetectTransmits uint8 retransTimer time.Duration expectedRetransmitTimer time.Duration }{ - {"1:1s:1s", 1, time.Second, time.Second}, - {"2:1s:1s", 2, time.Second, time.Second}, - {"1:2s:2s", 1, 2 * time.Second, 2 * time.Second}, + { + name: "1:1s:1s", + dupAddrDetectTransmits: 1, + retransTimer: time.Second, + expectedRetransmitTimer: time.Second, + }, + { + name: "2:1s:1s", + linkHeaderLen: 1, + dupAddrDetectTransmits: 2, + retransTimer: time.Second, + expectedRetransmitTimer: time.Second, + }, + { + name: "1:2s:2s", + linkHeaderLen: 2, + dupAddrDetectTransmits: 1, + retransTimer: 2 * time.Second, + expectedRetransmitTimer: 2 * time.Second, + }, // 0s is an invalid RetransmitTimer timer and will be fixed to // the default RetransmitTimer value of 1s. - {"1:0s:1s", 1, 0, time.Second}, + { + name: "1:0s:1s", + linkHeaderLen: 3, + dupAddrDetectTransmits: 1, + retransTimer: 0, + expectedRetransmitTimer: time.Second, + }, } for _, test := range tests { @@ -356,10 +392,13 @@ func TestDADResolve(t *testing.T) { opts.NDPConfigs.RetransmitTimer = test.retransTimer opts.NDPConfigs.DupAddrDetectTransmits = test.dupAddrDetectTransmits - e := channel.New(int(test.dupAddrDetectTransmits), 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + e := channelLinkWithHeaderLength{ + Endpoint: channel.New(int(test.dupAddrDetectTransmits), 1280, linkAddr1), + headerLength: test.linkHeaderLen, + } + e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired s := stack.New(opts) - if err := s.CreateNIC(nicID, e); err != nil { + if err := s.CreateNIC(nicID, &e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -445,6 +484,10 @@ func TestDADResolve(t *testing.T) { checker.NDPNSTargetAddress(addr1), checker.NDPNSOptions(nil), )) + + if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want { + t.Errorf("got p.Pkt.Header.AvailableLength() = %d; want = %d", l, want) + } } }) } @@ -3336,8 +3379,11 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { func TestRouterSolicitation(t *testing.T) { t.Parallel() + const nicID = 1 + tests := []struct { name string + linkHeaderLen uint16 maxRtrSolicit uint8 rtrSolicitInt time.Duration effectiveRtrSolicitInt time.Duration @@ -3354,6 +3400,7 @@ func TestRouterSolicitation(t *testing.T) { }, { name: "Two RS with delay", + linkHeaderLen: 1, maxRtrSolicit: 2, rtrSolicitInt: time.Second, effectiveRtrSolicitInt: time.Second, @@ -3362,6 +3409,7 @@ func TestRouterSolicitation(t *testing.T) { }, { name: "Single RS without delay", + linkHeaderLen: 2, maxRtrSolicit: 1, rtrSolicitInt: time.Second, effectiveRtrSolicitInt: time.Second, @@ -3370,6 +3418,7 @@ func TestRouterSolicitation(t *testing.T) { }, { name: "Two RS without delay and invalid zero interval", + linkHeaderLen: 3, maxRtrSolicit: 2, rtrSolicitInt: 0, effectiveRtrSolicitInt: 4 * time.Second, @@ -3407,8 +3456,11 @@ func TestRouterSolicitation(t *testing.T) { t.Run(test.name, func(t *testing.T) { t.Parallel() - e := channel.New(int(test.maxRtrSolicit), 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + e := channelLinkWithHeaderLength{ + Endpoint: channel.New(int(test.maxRtrSolicit), 1280, linkAddr1), + headerLength: test.linkHeaderLen, + } + e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired waitForPkt := func(timeout time.Duration) { t.Helper() ctx, _ := context.WithTimeout(context.Background(), timeout) @@ -3434,6 +3486,10 @@ func TestRouterSolicitation(t *testing.T) { checker.TTL(header.NDPHopLimit), checker.NDPRS(), ) + + if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want { + t.Errorf("got p.Pkt.Header.AvailableLength() = %d; want = %d", l, want) + } } waitForNothing := func(timeout time.Duration) { t.Helper() @@ -3450,8 +3506,8 @@ func TestRouterSolicitation(t *testing.T) { MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, }, }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) + if err := s.CreateNIC(nicID, &e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } // Make sure each RS got sent at the right -- cgit v1.2.3 From a155a23480abfafe096ff50f2c4aaf2c215b6c44 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Fri, 21 Feb 2020 11:20:15 -0800 Subject: Attach LinkEndpoint to NetworkDispatcher immediately Tests: stack_test.TestAttachToLinkEndpointImmediately PiperOrigin-RevId: 296474068 --- pkg/tcpip/stack/nic.go | 25 +++++++------------ pkg/tcpip/stack/stack.go | 5 ++-- pkg/tcpip/stack/stack_test.go | 56 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 68 insertions(+), 18 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index b2be18e47..862954ab2 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -44,8 +44,7 @@ type NIC struct { linkEP LinkEndpoint context NICContext - stats NICStats - attach sync.Once + stats NICStats mu struct { sync.RWMutex @@ -141,6 +140,8 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC nic.mu.packetEPs[netProto.Number()] = []PacketEndpoint{} } + nic.linkEP.Attach(nic) + return nic } @@ -200,14 +201,16 @@ func (n *NIC) disable() *tcpip.Error { } } - // TODO(b/147015577): Should n detach from its LinkEndpoint? - n.mu.enabled = false return nil } -// enable enables n. enable will attach the nic to its LinkEndpoint and -// join the IPv6 All-Nodes Multicast address (ff02::1). +// enable enables n. +// +// If the stack has IPv6 enabled, enable will join the IPv6 All-Nodes Multicast +// address (ff02::1), start DAD for permanent addresses, and start soliciting +// routers if the stack is not operating as a router. If the stack is also +// configured to auto-generate a link-local address, one will be generated. func (n *NIC) enable() *tcpip.Error { n.mu.RLock() enabled := n.mu.enabled @@ -225,8 +228,6 @@ func (n *NIC) enable() *tcpip.Error { n.mu.enabled = true - n.attachLinkEndpoint() - // Create an endpoint to receive broadcast packets on this interface. if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok { if _, err := n.addAddressLocked(ipv4BroadcastAddr, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil { @@ -321,14 +322,6 @@ func (n *NIC) becomeIPv6Host() { n.mu.ndp.startSolicitingRouters() } -// attachLinkEndpoint attaches the NIC to the endpoint, which will enable it -// to start delivering packets. -func (n *NIC) attachLinkEndpoint() { - n.attach.Do(func() { - n.linkEP.Attach(n) - }) -} - // setPromiscuousMode enables or disables promiscuous mode. func (n *NIC) setPromiscuousMode(enable bool) { n.mu.Lock() diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index fabc976a7..f0ed76fbe 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -881,6 +881,8 @@ type NICOptions struct { // CreateNICWithOptions creates a NIC with the provided id, LinkEndpoint, and // NICOptions. See the documentation on type NICOptions for details on how // NICs can be configured. +// +// LinkEndpoint.Attach will be called to bind ep with a NetworkDispatcher. func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOptions) *tcpip.Error { s.mu.Lock() defer s.mu.Unlock() @@ -900,7 +902,6 @@ func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOp } n := newNIC(s, id, opts.Name, ep, opts.Context) - s.nics[id] = n if !opts.Disabled { return n.enable() @@ -910,7 +911,7 @@ func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOp } // CreateNIC creates a NIC with the provided id and LinkEndpoint and calls -// `LinkEndpoint.Attach` to start delivering packets to it. +// LinkEndpoint.Attach to bind ep with a NetworkDispatcher. func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error { return s.CreateNICWithOptions(id, ep, NICOptions{}) } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index eb6f7d1fc..18016e7db 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -239,6 +239,23 @@ func fakeNetFactory() stack.NetworkProtocol { return &fakeNetworkProtocol{} } +// linkEPWithMockedAttach is a stack.LinkEndpoint that tests can use to verify +// that LinkEndpoint.Attach was called. +type linkEPWithMockedAttach struct { + stack.LinkEndpoint + attached bool +} + +// Attach implements stack.LinkEndpoint.Attach. +func (l *linkEPWithMockedAttach) Attach(d stack.NetworkDispatcher) { + l.LinkEndpoint.Attach(d) + l.attached = true +} + +func (l *linkEPWithMockedAttach) isAttached() bool { + return l.attached +} + func TestNetworkReceive(t *testing.T) { // Create a stack with the fake network protocol, one nic, and two // addresses attached to it: 1 & 2. @@ -510,6 +527,45 @@ func testNoRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr } } +// TestAttachToLinkEndpointImmediately tests that a LinkEndpoint is attached to +// a NetworkDispatcher when the NIC is created. +func TestAttachToLinkEndpointImmediately(t *testing.T) { + const nicID = 1 + + tests := []struct { + name string + nicOpts stack.NICOptions + }{ + { + name: "Create enabled NIC", + nicOpts: stack.NICOptions{Disabled: false}, + }, + { + name: "Create disabled NIC", + nicOpts: stack.NICOptions{Disabled: true}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + + e := linkEPWithMockedAttach{ + LinkEndpoint: loopback.New(), + } + + if err := s.CreateNICWithOptions(nicID, &e, test.nicOpts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, test.nicOpts, err) + } + if !e.isAttached() { + t.Fatalf("link endpoint not attached to a network disatcher") + } + }) + } +} + func TestDisableUnknownNIC(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, -- cgit v1.2.3 From b8f56c79be40d9c75f4e2f279c9d821d1c1c3569 Mon Sep 17 00:00:00 2001 From: Ting-Yu Wang Date: Fri, 21 Feb 2020 15:41:56 -0800 Subject: Implement tap/tun device in vfs. PiperOrigin-RevId: 296526279 --- pkg/abi/linux/BUILD | 1 + pkg/abi/linux/ioctl.go | 26 ++ pkg/abi/linux/ioctl_tun.go | 29 ++ pkg/sentry/fs/dev/BUILD | 5 + pkg/sentry/fs/dev/dev.go | 10 +- pkg/sentry/fs/dev/net_tun.go | 170 +++++++++++ pkg/syserror/syserror.go | 1 + pkg/tcpip/buffer/view.go | 6 + pkg/tcpip/link/channel/BUILD | 1 + pkg/tcpip/link/channel/channel.go | 180 +++++++++--- pkg/tcpip/link/tun/BUILD | 18 +- pkg/tcpip/link/tun/device.go | 352 +++++++++++++++++++++++ pkg/tcpip/link/tun/protocol.go | 56 ++++ pkg/tcpip/stack/nic.go | 32 +++ pkg/tcpip/stack/stack.go | 39 +++ test/syscalls/BUILD | 2 + test/syscalls/linux/BUILD | 30 ++ test/syscalls/linux/dev.cc | 7 + test/syscalls/linux/socket_netlink_route_util.cc | 163 +++++++++++ test/syscalls/linux/socket_netlink_route_util.h | 55 ++++ test/syscalls/linux/tuntap.cc | 346 ++++++++++++++++++++++ 21 files changed, 1490 insertions(+), 39 deletions(-) create mode 100644 pkg/abi/linux/ioctl_tun.go create mode 100644 pkg/sentry/fs/dev/net_tun.go create mode 100644 pkg/tcpip/link/tun/device.go create mode 100644 pkg/tcpip/link/tun/protocol.go create mode 100644 test/syscalls/linux/socket_netlink_route_util.cc create mode 100644 test/syscalls/linux/socket_netlink_route_util.h create mode 100644 test/syscalls/linux/tuntap.cc (limited to 'pkg/tcpip/stack') diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index a89f34d4b..322d1ccc4 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -30,6 +30,7 @@ go_library( "futex.go", "inotify.go", "ioctl.go", + "ioctl_tun.go", "ip.go", "ipc.go", "limits.go", diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go index 0e18db9ef..2062e6a4b 100644 --- a/pkg/abi/linux/ioctl.go +++ b/pkg/abi/linux/ioctl.go @@ -72,3 +72,29 @@ const ( SIOCGMIIPHY = 0x8947 SIOCGMIIREG = 0x8948 ) + +// ioctl(2) directions. Used to calculate requests number. +// Constants from asm-generic/ioctl.h. +const ( + _IOC_NONE = 0 + _IOC_WRITE = 1 + _IOC_READ = 2 +) + +// Constants from asm-generic/ioctl.h. +const ( + _IOC_NRBITS = 8 + _IOC_TYPEBITS = 8 + _IOC_SIZEBITS = 14 + _IOC_DIRBITS = 2 + + _IOC_NRSHIFT = 0 + _IOC_TYPESHIFT = _IOC_NRSHIFT + _IOC_NRBITS + _IOC_SIZESHIFT = _IOC_TYPESHIFT + _IOC_TYPEBITS + _IOC_DIRSHIFT = _IOC_SIZESHIFT + _IOC_SIZEBITS +) + +// IOC outputs the result of _IOC macro in asm-generic/ioctl.h. +func IOC(dir, typ, nr, size uint32) uint32 { + return uint32(dir)<<_IOC_DIRSHIFT | typ<<_IOC_TYPESHIFT | nr<<_IOC_NRSHIFT | size<<_IOC_SIZESHIFT +} diff --git a/pkg/abi/linux/ioctl_tun.go b/pkg/abi/linux/ioctl_tun.go new file mode 100644 index 000000000..c59c9c136 --- /dev/null +++ b/pkg/abi/linux/ioctl_tun.go @@ -0,0 +1,29 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +// ioctl(2) request numbers from linux/if_tun.h +var ( + TUNSETIFF = IOC(_IOC_WRITE, 'T', 202, 4) + TUNGETIFF = IOC(_IOC_READ, 'T', 210, 4) +) + +// Flags from net/if_tun.h +const ( + IFF_TUN = 0x0001 + IFF_TAP = 0x0002 + IFF_NO_PI = 0x1000 + IFF_NOFILTER = 0x1000 +) diff --git a/pkg/sentry/fs/dev/BUILD b/pkg/sentry/fs/dev/BUILD index 4c4b7d5cc..9b6bb26d0 100644 --- a/pkg/sentry/fs/dev/BUILD +++ b/pkg/sentry/fs/dev/BUILD @@ -9,6 +9,7 @@ go_library( "device.go", "fs.go", "full.go", + "net_tun.go", "null.go", "random.go", "tty.go", @@ -19,15 +20,19 @@ go_library( "//pkg/context", "//pkg/rand", "//pkg/safemem", + "//pkg/sentry/arch", "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", "//pkg/sentry/fs/ramfs", "//pkg/sentry/fs/tmpfs", + "//pkg/sentry/kernel", "//pkg/sentry/memmap", "//pkg/sentry/mm", "//pkg/sentry/pgalloc", + "//pkg/sentry/socket/netstack", "//pkg/syserror", + "//pkg/tcpip/link/tun", "//pkg/usermem", "//pkg/waiter", ], diff --git a/pkg/sentry/fs/dev/dev.go b/pkg/sentry/fs/dev/dev.go index 35bd23991..7e66c29b0 100644 --- a/pkg/sentry/fs/dev/dev.go +++ b/pkg/sentry/fs/dev/dev.go @@ -66,8 +66,8 @@ func newMemDevice(ctx context.Context, iops fs.InodeOperations, msrc *fs.MountSo }) } -func newDirectory(ctx context.Context, msrc *fs.MountSource) *fs.Inode { - iops := ramfs.NewDir(ctx, nil, fs.RootOwner, fs.FilePermsFromMode(0555)) +func newDirectory(ctx context.Context, contents map[string]*fs.Inode, msrc *fs.MountSource) *fs.Inode { + iops := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555)) return fs.NewInode(ctx, iops, msrc, fs.StableAttr{ DeviceID: devDevice.DeviceID(), InodeID: devDevice.NextIno(), @@ -111,7 +111,7 @@ func New(ctx context.Context, msrc *fs.MountSource) *fs.Inode { // A devpts is typically mounted at /dev/pts to provide // pseudoterminal support. Place an empty directory there for // the devpts to be mounted over. - "pts": newDirectory(ctx, msrc), + "pts": newDirectory(ctx, nil, msrc), // Similarly, applications expect a ptmx device at /dev/ptmx // connected to the terminals provided by /dev/pts/. Rather // than creating a device directly (which requires a hairy @@ -124,6 +124,10 @@ func New(ctx context.Context, msrc *fs.MountSource) *fs.Inode { "ptmx": newSymlink(ctx, "pts/ptmx", msrc), "tty": newCharacterDevice(ctx, newTTYDevice(ctx, fs.RootOwner, 0666), msrc, ttyDevMajor, ttyDevMinor), + + "net": newDirectory(ctx, map[string]*fs.Inode{ + "tun": newCharacterDevice(ctx, newNetTunDevice(ctx, fs.RootOwner, 0666), msrc, netTunDevMajor, netTunDevMinor), + }, msrc), } iops := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555)) diff --git a/pkg/sentry/fs/dev/net_tun.go b/pkg/sentry/fs/dev/net_tun.go new file mode 100644 index 000000000..755644488 --- /dev/null +++ b/pkg/sentry/fs/dev/net_tun.go @@ -0,0 +1,170 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/socket/netstack" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip/link/tun" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/waiter" +) + +const ( + netTunDevMajor = 10 + netTunDevMinor = 200 +) + +// +stateify savable +type netTunInodeOperations struct { + fsutil.InodeGenericChecker `state:"nosave"` + fsutil.InodeNoExtendedAttributes `state:"nosave"` + fsutil.InodeNoopAllocate `state:"nosave"` + fsutil.InodeNoopRelease `state:"nosave"` + fsutil.InodeNoopTruncate `state:"nosave"` + fsutil.InodeNoopWriteOut `state:"nosave"` + fsutil.InodeNotDirectory `state:"nosave"` + fsutil.InodeNotMappable `state:"nosave"` + fsutil.InodeNotSocket `state:"nosave"` + fsutil.InodeNotSymlink `state:"nosave"` + fsutil.InodeVirtual `state:"nosave"` + + fsutil.InodeSimpleAttributes +} + +var _ fs.InodeOperations = (*netTunInodeOperations)(nil) + +func newNetTunDevice(ctx context.Context, owner fs.FileOwner, mode linux.FileMode) *netTunInodeOperations { + return &netTunInodeOperations{ + InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, owner, fs.FilePermsFromMode(mode), linux.TMPFS_MAGIC), + } +} + +// GetFile implements fs.InodeOperations.GetFile. +func (iops *netTunInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { + return fs.NewFile(ctx, d, flags, &netTunFileOperations{}), nil +} + +// +stateify savable +type netTunFileOperations struct { + fsutil.FileNoSeek `state:"nosave"` + fsutil.FileNoMMap `state:"nosave"` + fsutil.FileNoSplice `state:"nosave"` + fsutil.FileNoopFlush `state:"nosave"` + fsutil.FileNoopFsync `state:"nosave"` + fsutil.FileNotDirReaddir `state:"nosave"` + fsutil.FileUseInodeUnstableAttr `state:"nosave"` + + device tun.Device +} + +var _ fs.FileOperations = (*netTunFileOperations)(nil) + +// Release implements fs.FileOperations.Release. +func (fops *netTunFileOperations) Release() { + fops.device.Release() +} + +// Ioctl implements fs.FileOperations.Ioctl. +func (fops *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + request := args[1].Uint() + data := args[2].Pointer() + + switch request { + case linux.TUNSETIFF: + t := kernel.TaskFromContext(ctx) + if t == nil { + panic("Ioctl should be called from a task context") + } + if !t.HasCapability(linux.CAP_NET_ADMIN) { + return 0, syserror.EPERM + } + stack, ok := t.NetworkContext().(*netstack.Stack) + if !ok { + return 0, syserror.EINVAL + } + + var req linux.IFReq + if _, err := usermem.CopyObjectIn(ctx, io, data, &req, usermem.IOOpts{ + AddressSpaceActive: true, + }); err != nil { + return 0, err + } + flags := usermem.ByteOrder.Uint16(req.Data[:]) + return 0, fops.device.SetIff(stack.Stack, req.Name(), flags) + + case linux.TUNGETIFF: + var req linux.IFReq + + copy(req.IFName[:], fops.device.Name()) + + // Linux adds IFF_NOFILTER (the same value as IFF_NO_PI unfortunately) when + // there is no sk_filter. See __tun_chr_ioctl() in net/drivers/tun.c. + flags := fops.device.Flags() | linux.IFF_NOFILTER + usermem.ByteOrder.PutUint16(req.Data[:], flags) + + _, err := usermem.CopyObjectOut(ctx, io, data, &req, usermem.IOOpts{ + AddressSpaceActive: true, + }) + return 0, err + + default: + return 0, syserror.ENOTTY + } +} + +// Write implements fs.FileOperations.Write. +func (fops *netTunFileOperations) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) { + data := make([]byte, src.NumBytes()) + if _, err := src.CopyIn(ctx, data); err != nil { + return 0, err + } + return fops.device.Write(data) +} + +// Read implements fs.FileOperations.Read. +func (fops *netTunFileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { + data, err := fops.device.Read() + if err != nil { + return 0, err + } + n, err := dst.CopyOut(ctx, data) + if n > 0 && n < len(data) { + // Not an error for partial copying. Packet truncated. + err = nil + } + return int64(n), err +} + +// Readiness implements watier.Waitable.Readiness. +func (fops *netTunFileOperations) Readiness(mask waiter.EventMask) waiter.EventMask { + return fops.device.Readiness(mask) +} + +// EventRegister implements watier.Waitable.EventRegister. +func (fops *netTunFileOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) { + fops.device.EventRegister(e, mask) +} + +// EventUnregister implements watier.Waitable.EventUnregister. +func (fops *netTunFileOperations) EventUnregister(e *waiter.Entry) { + fops.device.EventUnregister(e) +} diff --git a/pkg/syserror/syserror.go b/pkg/syserror/syserror.go index 2269f6237..4b5a0fca6 100644 --- a/pkg/syserror/syserror.go +++ b/pkg/syserror/syserror.go @@ -29,6 +29,7 @@ var ( EACCES = error(syscall.EACCES) EAGAIN = error(syscall.EAGAIN) EBADF = error(syscall.EBADF) + EBADFD = error(syscall.EBADFD) EBUSY = error(syscall.EBUSY) ECHILD = error(syscall.ECHILD) ECONNREFUSED = error(syscall.ECONNREFUSED) diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index 150310c11..17e94c562 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -156,3 +156,9 @@ func (vv *VectorisedView) Append(vv2 VectorisedView) { vv.views = append(vv.views, vv2.views...) vv.size += vv2.size } + +// AppendView appends the given view into this vectorised view. +func (vv *VectorisedView) AppendView(v View) { + vv.views = append(vv.views, v) + vv.size += len(v) +} diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD index 3974c464e..b8b93e78e 100644 --- a/pkg/tcpip/link/channel/BUILD +++ b/pkg/tcpip/link/channel/BUILD @@ -7,6 +7,7 @@ go_library( srcs = ["channel.go"], visibility = ["//visibility:public"], deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/stack", diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index 78d447acd..5944ba190 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -20,6 +20,7 @@ package channel import ( "context" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -33,6 +34,118 @@ type PacketInfo struct { Route stack.Route } +// Notification is the interface for receiving notification from the packet +// queue. +type Notification interface { + // WriteNotify will be called when a write happens to the queue. + WriteNotify() +} + +// NotificationHandle is an opaque handle to the registered notification target. +// It can be used to unregister the notification when no longer interested. +// +// +stateify savable +type NotificationHandle struct { + n Notification +} + +type queue struct { + // mu protects fields below. + mu sync.RWMutex + // c is the outbound packet channel. Sending to c should hold mu. + c chan PacketInfo + numWrite int + numRead int + notify []*NotificationHandle +} + +func (q *queue) Close() { + close(q.c) +} + +func (q *queue) Read() (PacketInfo, bool) { + q.mu.Lock() + defer q.mu.Unlock() + select { + case p := <-q.c: + q.numRead++ + return p, true + default: + return PacketInfo{}, false + } +} + +func (q *queue) ReadContext(ctx context.Context) (PacketInfo, bool) { + // We have to receive from channel without holding the lock, since it can + // block indefinitely. This will cause a window that numWrite - numRead + // produces a larger number, but won't go to negative. numWrite >= numRead + // still holds. + select { + case pkt := <-q.c: + q.mu.Lock() + defer q.mu.Unlock() + q.numRead++ + return pkt, true + case <-ctx.Done(): + return PacketInfo{}, false + } +} + +func (q *queue) Write(p PacketInfo) bool { + wrote := false + + // It's important to make sure nobody can see numWrite until we increment it, + // so numWrite >= numRead holds. + q.mu.Lock() + select { + case q.c <- p: + wrote = true + q.numWrite++ + default: + } + notify := q.notify + q.mu.Unlock() + + if wrote { + // Send notification outside of lock. + for _, h := range notify { + h.n.WriteNotify() + } + } + return wrote +} + +func (q *queue) Num() int { + q.mu.RLock() + defer q.mu.RUnlock() + n := q.numWrite - q.numRead + if n < 0 { + panic("numWrite < numRead") + } + return n +} + +func (q *queue) AddNotify(notify Notification) *NotificationHandle { + q.mu.Lock() + defer q.mu.Unlock() + h := &NotificationHandle{n: notify} + q.notify = append(q.notify, h) + return h +} + +func (q *queue) RemoveNotify(handle *NotificationHandle) { + q.mu.Lock() + defer q.mu.Unlock() + // Make a copy, since we reads the array outside of lock when notifying. + notify := make([]*NotificationHandle, 0, len(q.notify)) + for _, h := range q.notify { + if h != handle { + notify = append(notify, h) + } + } + q.notify = notify +} + // Endpoint is link layer endpoint that stores outbound packets in a channel // and allows injection of inbound packets. type Endpoint struct { @@ -41,14 +154,16 @@ type Endpoint struct { linkAddr tcpip.LinkAddress LinkEPCapabilities stack.LinkEndpointCapabilities - // c is where outbound packets are queued. - c chan PacketInfo + // Outbound packet queue. + q *queue } // New creates a new channel endpoint. func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) *Endpoint { return &Endpoint{ - c: make(chan PacketInfo, size), + q: &queue{ + c: make(chan PacketInfo, size), + }, mtu: mtu, linkAddr: linkAddr, } @@ -57,43 +172,36 @@ func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) *Endpoint { // Close closes e. Further packet injections will panic. Reads continue to // succeed until all packets are read. func (e *Endpoint) Close() { - close(e.c) + e.q.Close() } -// Read does non-blocking read for one packet from the outbound packet queue. +// Read does non-blocking read one packet from the outbound packet queue. func (e *Endpoint) Read() (PacketInfo, bool) { - select { - case pkt := <-e.c: - return pkt, true - default: - return PacketInfo{}, false - } + return e.q.Read() } // ReadContext does blocking read for one packet from the outbound packet queue. // It can be cancelled by ctx, and in this case, it returns false. func (e *Endpoint) ReadContext(ctx context.Context) (PacketInfo, bool) { - select { - case pkt := <-e.c: - return pkt, true - case <-ctx.Done(): - return PacketInfo{}, false - } + return e.q.ReadContext(ctx) } // Drain removes all outbound packets from the channel and counts them. func (e *Endpoint) Drain() int { c := 0 for { - select { - case <-e.c: - c++ - default: + if _, ok := e.Read(); !ok { return c } + c++ } } +// NumQueued returns the number of packet queued for outbound. +func (e *Endpoint) NumQueued() int { + return e.q.Num() +} + // InjectInbound injects an inbound packet. func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { e.InjectLinkAddr(protocol, "", pkt) @@ -155,10 +263,7 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne Route: route, } - select { - case e.c <- p: - default: - } + e.q.Write(p) return nil } @@ -171,7 +276,6 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.Pac route.Release() payloadView := pkts[0].Data.ToView() n := 0 -packetLoop: for _, pkt := range pkts { off := pkt.DataOffset size := pkt.DataSize @@ -185,12 +289,10 @@ packetLoop: Route: route, } - select { - case e.c <- p: - n++ - default: - break packetLoop + if !e.q.Write(p) { + break } + n++ } return n, nil @@ -204,13 +306,21 @@ func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { GSO: nil, } - select { - case e.c <- p: - default: - } + e.q.Write(p) return nil } // Wait implements stack.LinkEndpoint.Wait. func (*Endpoint) Wait() {} + +// AddNotify adds a notification target for receiving event about outgoing +// packets. +func (e *Endpoint) AddNotify(notify Notification) *NotificationHandle { + return e.q.AddNotify(notify) +} + +// RemoveNotify removes handle from the list of notification targets. +func (e *Endpoint) RemoveNotify(handle *NotificationHandle) { + e.q.RemoveNotify(handle) +} diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD index e5096ea38..e0db6cf54 100644 --- a/pkg/tcpip/link/tun/BUILD +++ b/pkg/tcpip/link/tun/BUILD @@ -4,6 +4,22 @@ package(licenses = ["notice"]) go_library( name = "tun", - srcs = ["tun_unsafe.go"], + srcs = [ + "device.go", + "protocol.go", + "tun_unsafe.go", + ], visibility = ["//visibility:public"], + deps = [ + "//pkg/abi/linux", + "//pkg/refs", + "//pkg/sync", + "//pkg/syserror", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", + "//pkg/tcpip/stack", + "//pkg/waiter", + ], ) diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go new file mode 100644 index 000000000..6ff47a742 --- /dev/null +++ b/pkg/tcpip/link/tun/device.go @@ -0,0 +1,352 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tun + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/waiter" +) + +const ( + // drivers/net/tun.c:tun_net_init() + defaultDevMtu = 1500 + + // Queue length for outbound packet, arriving at fd side for read. Overflow + // causes packet drops. gVisor implementation-specific. + defaultDevOutQueueLen = 1024 +) + +var zeroMAC [6]byte + +// Device is an opened /dev/net/tun device. +// +// +stateify savable +type Device struct { + waiter.Queue + + mu sync.RWMutex `state:"nosave"` + endpoint *tunEndpoint + notifyHandle *channel.NotificationHandle + flags uint16 +} + +// beforeSave is invoked by stateify. +func (d *Device) beforeSave() { + d.mu.Lock() + defer d.mu.Unlock() + // TODO(b/110961832): Restore the device to stack. At this moment, the stack + // is not savable. + if d.endpoint != nil { + panic("/dev/net/tun does not support save/restore when a device is associated with it.") + } +} + +// Release implements fs.FileOperations.Release. +func (d *Device) Release() { + d.mu.Lock() + defer d.mu.Unlock() + + // Decrease refcount if there is an endpoint associated with this file. + if d.endpoint != nil { + d.endpoint.RemoveNotify(d.notifyHandle) + d.endpoint.DecRef() + d.endpoint = nil + } +} + +// SetIff services TUNSETIFF ioctl(2) request. +func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error { + d.mu.Lock() + defer d.mu.Unlock() + + if d.endpoint != nil { + return syserror.EINVAL + } + + // Input validations. + isTun := flags&linux.IFF_TUN != 0 + isTap := flags&linux.IFF_TAP != 0 + supportedFlags := uint16(linux.IFF_TUN | linux.IFF_TAP | linux.IFF_NO_PI) + if isTap && isTun || !isTap && !isTun || flags&^supportedFlags != 0 { + return syserror.EINVAL + } + + prefix := "tun" + if isTap { + prefix = "tap" + } + + endpoint, err := attachOrCreateNIC(s, name, prefix) + if err != nil { + return syserror.EINVAL + } + + d.endpoint = endpoint + d.notifyHandle = d.endpoint.AddNotify(d) + d.flags = flags + return nil +} + +func attachOrCreateNIC(s *stack.Stack, name, prefix string) (*tunEndpoint, error) { + for { + // 1. Try to attach to an existing NIC. + if name != "" { + if nic, found := s.GetNICByName(name); found { + endpoint, ok := nic.LinkEndpoint().(*tunEndpoint) + if !ok { + // Not a NIC created by tun device. + return nil, syserror.EOPNOTSUPP + } + if !endpoint.TryIncRef() { + // Race detected: NIC got deleted in between. + continue + } + return endpoint, nil + } + } + + // 2. Creating a new NIC. + id := tcpip.NICID(s.UniqueID()) + endpoint := &tunEndpoint{ + Endpoint: channel.New(defaultDevOutQueueLen, defaultDevMtu, ""), + stack: s, + nicID: id, + name: name, + } + if endpoint.name == "" { + endpoint.name = fmt.Sprintf("%s%d", prefix, id) + } + err := s.CreateNICWithOptions(endpoint.nicID, endpoint, stack.NICOptions{ + Name: endpoint.name, + }) + switch err { + case nil: + return endpoint, nil + case tcpip.ErrDuplicateNICID: + // Race detected: A NIC has been created in between. + continue + default: + return nil, syserror.EINVAL + } + } +} + +// Write inject one inbound packet to the network interface. +func (d *Device) Write(data []byte) (int64, error) { + d.mu.RLock() + endpoint := d.endpoint + d.mu.RUnlock() + if endpoint == nil { + return 0, syserror.EBADFD + } + if !endpoint.IsAttached() { + return 0, syserror.EIO + } + + dataLen := int64(len(data)) + + // Packet information. + var pktInfoHdr PacketInfoHeader + if !d.hasFlags(linux.IFF_NO_PI) { + if len(data) < PacketInfoHeaderSize { + // Ignore bad packet. + return dataLen, nil + } + pktInfoHdr = PacketInfoHeader(data[:PacketInfoHeaderSize]) + data = data[PacketInfoHeaderSize:] + } + + // Ethernet header (TAP only). + var ethHdr header.Ethernet + if d.hasFlags(linux.IFF_TAP) { + if len(data) < header.EthernetMinimumSize { + // Ignore bad packet. + return dataLen, nil + } + ethHdr = header.Ethernet(data[:header.EthernetMinimumSize]) + data = data[header.EthernetMinimumSize:] + } + + // Try to determine network protocol number, default zero. + var protocol tcpip.NetworkProtocolNumber + switch { + case pktInfoHdr != nil: + protocol = pktInfoHdr.Protocol() + case ethHdr != nil: + protocol = ethHdr.Type() + } + + // Try to determine remote link address, default zero. + var remote tcpip.LinkAddress + switch { + case ethHdr != nil: + remote = ethHdr.SourceAddress() + default: + remote = tcpip.LinkAddress(zeroMAC[:]) + } + + pkt := tcpip.PacketBuffer{ + Data: buffer.View(data).ToVectorisedView(), + } + if ethHdr != nil { + pkt.LinkHeader = buffer.View(ethHdr) + } + endpoint.InjectLinkAddr(protocol, remote, pkt) + return dataLen, nil +} + +// Read reads one outgoing packet from the network interface. +func (d *Device) Read() ([]byte, error) { + d.mu.RLock() + endpoint := d.endpoint + d.mu.RUnlock() + if endpoint == nil { + return nil, syserror.EBADFD + } + + for { + info, ok := endpoint.Read() + if !ok { + return nil, syserror.ErrWouldBlock + } + + v, ok := d.encodePkt(&info) + if !ok { + // Ignore unsupported packet. + continue + } + return v, nil + } +} + +// encodePkt encodes packet for fd side. +func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) { + var vv buffer.VectorisedView + + // Packet information. + if !d.hasFlags(linux.IFF_NO_PI) { + hdr := make(PacketInfoHeader, PacketInfoHeaderSize) + hdr.Encode(&PacketInfoFields{ + Protocol: info.Proto, + }) + vv.AppendView(buffer.View(hdr)) + } + + // If the packet does not already have link layer header, and the route + // does not exist, we can't compute it. This is possibly a raw packet, tun + // device doesn't support this at the moment. + if info.Pkt.LinkHeader == nil && info.Route.RemoteLinkAddress == "" { + return nil, false + } + + // Ethernet header (TAP only). + if d.hasFlags(linux.IFF_TAP) { + // Add ethernet header if not provided. + if info.Pkt.LinkHeader == nil { + hdr := &header.EthernetFields{ + SrcAddr: info.Route.LocalLinkAddress, + DstAddr: info.Route.RemoteLinkAddress, + Type: info.Proto, + } + if hdr.SrcAddr == "" { + hdr.SrcAddr = d.endpoint.LinkAddress() + } + + eth := make(header.Ethernet, header.EthernetMinimumSize) + eth.Encode(hdr) + vv.AppendView(buffer.View(eth)) + } else { + vv.AppendView(info.Pkt.LinkHeader) + } + } + + // Append upper headers. + vv.AppendView(buffer.View(info.Pkt.Header.View()[len(info.Pkt.LinkHeader):])) + // Append data payload. + vv.Append(info.Pkt.Data) + + return vv.ToView(), true +} + +// Name returns the name of the attached network interface. Empty string if +// unattached. +func (d *Device) Name() string { + d.mu.RLock() + defer d.mu.RUnlock() + if d.endpoint != nil { + return d.endpoint.name + } + return "" +} + +// Flags returns the flags set for d. Zero value if unset. +func (d *Device) Flags() uint16 { + d.mu.RLock() + defer d.mu.RUnlock() + return d.flags +} + +func (d *Device) hasFlags(flags uint16) bool { + return d.flags&flags == flags +} + +// Readiness implements watier.Waitable.Readiness. +func (d *Device) Readiness(mask waiter.EventMask) waiter.EventMask { + if mask&waiter.EventIn != 0 { + d.mu.RLock() + endpoint := d.endpoint + d.mu.RUnlock() + if endpoint != nil && endpoint.NumQueued() == 0 { + mask &= ^waiter.EventIn + } + } + return mask & (waiter.EventIn | waiter.EventOut) +} + +// WriteNotify implements channel.Notification.WriteNotify. +func (d *Device) WriteNotify() { + d.Notify(waiter.EventIn) +} + +// tunEndpoint is the link endpoint for the NIC created by the tun device. +// +// It is ref-counted as multiple opening files can attach to the same NIC. +// The last owner is responsible for deleting the NIC. +type tunEndpoint struct { + *channel.Endpoint + + refs.AtomicRefCount + + stack *stack.Stack + nicID tcpip.NICID + name string +} + +// DecRef decrements refcount of e, removes NIC if refcount goes to 0. +func (e *tunEndpoint) DecRef() { + e.DecRefWithDestructor(func() { + e.stack.RemoveNIC(e.nicID) + }) +} diff --git a/pkg/tcpip/link/tun/protocol.go b/pkg/tcpip/link/tun/protocol.go new file mode 100644 index 000000000..89d9d91a9 --- /dev/null +++ b/pkg/tcpip/link/tun/protocol.go @@ -0,0 +1,56 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tun + +import ( + "encoding/binary" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +const ( + // PacketInfoHeaderSize is the size of the packet information header. + PacketInfoHeaderSize = 4 + + offsetFlags = 0 + offsetProtocol = 2 +) + +// PacketInfoFields contains fields sent through the wire if IFF_NO_PI flag is +// not set. +type PacketInfoFields struct { + Flags uint16 + Protocol tcpip.NetworkProtocolNumber +} + +// PacketInfoHeader is the wire representation of the packet information sent if +// IFF_NO_PI flag is not set. +type PacketInfoHeader []byte + +// Encode encodes f into h. +func (h PacketInfoHeader) Encode(f *PacketInfoFields) { + binary.BigEndian.PutUint16(h[offsetFlags:][:2], f.Flags) + binary.BigEndian.PutUint16(h[offsetProtocol:][:2], uint16(f.Protocol)) +} + +// Flags returns the flag field in h. +func (h PacketInfoHeader) Flags() uint16 { + return binary.BigEndian.Uint16(h[offsetFlags:]) +} + +// Protocol returns the protocol field in h. +func (h PacketInfoHeader) Protocol() tcpip.NetworkProtocolNumber { + return tcpip.NetworkProtocolNumber(binary.BigEndian.Uint16(h[offsetProtocol:])) +} diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 862954ab2..46d3a6646 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -298,6 +298,33 @@ func (n *NIC) enable() *tcpip.Error { return nil } +// remove detaches NIC from the link endpoint, and marks existing referenced +// network endpoints expired. This guarantees no packets between this NIC and +// the network stack. +func (n *NIC) remove() *tcpip.Error { + n.mu.Lock() + defer n.mu.Unlock() + + // Detach from link endpoint, so no packet comes in. + n.linkEP.Attach(nil) + + // Remove permanent and permanentTentative addresses, so no packet goes out. + var errs []*tcpip.Error + for nid, ref := range n.mu.endpoints { + switch ref.getKind() { + case permanentTentative, permanent: + if err := n.removePermanentAddressLocked(nid.LocalAddress); err != nil { + errs = append(errs, err) + } + } + } + if len(errs) > 0 { + return errs[0] + } + + return nil +} + // becomeIPv6Router transitions n into an IPv6 router. // // When transitioning into an IPv6 router, host-only state (NDP discovered @@ -1302,6 +1329,11 @@ func (n *NIC) Stack() *Stack { return n.stack } +// LinkEndpoint returns the link endpoint of n. +func (n *NIC) LinkEndpoint() LinkEndpoint { + return n.linkEP +} + // isAddrTentative returns true if addr is tentative on n. // // Note that if addr is not associated with n, then this function will return diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index f0ed76fbe..900dd46c5 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -916,6 +916,18 @@ func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error { return s.CreateNICWithOptions(id, ep, NICOptions{}) } +// GetNICByName gets the NIC specified by name. +func (s *Stack) GetNICByName(name string) (*NIC, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + for _, nic := range s.nics { + if nic.Name() == name { + return nic, true + } + } + return nil, false +} + // EnableNIC enables the given NIC so that the link-layer endpoint can start // delivering packets to it. func (s *Stack) EnableNIC(id tcpip.NICID) *tcpip.Error { @@ -956,6 +968,33 @@ func (s *Stack) CheckNIC(id tcpip.NICID) bool { return nic.enabled() } +// RemoveNIC removes NIC and all related routes from the network stack. +func (s *Stack) RemoveNIC(id tcpip.NICID) *tcpip.Error { + s.mu.Lock() + defer s.mu.Unlock() + + nic, ok := s.nics[id] + if !ok { + return tcpip.ErrUnknownNICID + } + delete(s.nics, id) + + // Remove routes in-place. n tracks the number of routes written. + n := 0 + for i, r := range s.routeTable { + if r.NIC != id { + // Keep this route. + if i > n { + s.routeTable[n] = r + } + n++ + } + } + s.routeTable = s.routeTable[:n] + + return nic.remove() +} + // NICAddressRanges returns a map of NICIDs to their associated subnets. func (s *Stack) NICAddressRanges() map[tcpip.NICID][]tcpip.Subnet { s.mu.RLock() diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index d1977d4de..3518e862d 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -678,6 +678,8 @@ syscall_test( test = "//test/syscalls/linux:truncate_test", ) +syscall_test(test = "//test/syscalls/linux:tuntap_test") + syscall_test(test = "//test/syscalls/linux:udp_bind_test") syscall_test( diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index aa303af84..704bae17b 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -131,6 +131,17 @@ cc_library( ], ) +cc_library( + name = "socket_netlink_route_util", + testonly = 1, + srcs = ["socket_netlink_route_util.cc"], + hdrs = ["socket_netlink_route_util.h"], + deps = [ + ":socket_netlink_util", + "@com_google_absl//absl/types:optional", + ], +) + cc_library( name = "socket_test_util", testonly = 1, @@ -3430,6 +3441,25 @@ cc_binary( ], ) +cc_binary( + name = "tuntap_test", + testonly = 1, + srcs = ["tuntap.cc"], + linkstatic = 1, + deps = [ + ":socket_test_util", + gtest, + "//test/syscalls/linux:socket_netlink_route_util", + "//test/util:capability_util", + "//test/util:file_descriptor", + "//test/util:fs_util", + "//test/util:posix_error", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "udp_socket_test_cases", testonly = 1, diff --git a/test/syscalls/linux/dev.cc b/test/syscalls/linux/dev.cc index 4dd302eed..4e473268c 100644 --- a/test/syscalls/linux/dev.cc +++ b/test/syscalls/linux/dev.cc @@ -153,6 +153,13 @@ TEST(DevTest, TTYExists) { EXPECT_EQ(statbuf.st_mode, S_IFCHR | 0666); } +TEST(DevTest, NetTunExists) { + struct stat statbuf = {}; + ASSERT_THAT(stat("/dev/net/tun", &statbuf), SyscallSucceeds()); + // Check that it's a character device with rw-rw-rw- permissions. + EXPECT_EQ(statbuf.st_mode, S_IFCHR | 0666); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/socket_netlink_route_util.cc b/test/syscalls/linux/socket_netlink_route_util.cc new file mode 100644 index 000000000..53eb3b6b2 --- /dev/null +++ b/test/syscalls/linux/socket_netlink_route_util.cc @@ -0,0 +1,163 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/syscalls/linux/socket_netlink_route_util.h" + +#include +#include +#include + +#include "absl/types/optional.h" +#include "test/syscalls/linux/socket_netlink_util.h" + +namespace gvisor { +namespace testing { +namespace { + +constexpr uint32_t kSeq = 12345; + +} // namespace + +PosixError DumpLinks( + const FileDescriptor& fd, uint32_t seq, + const std::function& fn) { + struct request { + struct nlmsghdr hdr; + struct ifinfomsg ifm; + }; + + struct request req = {}; + req.hdr.nlmsg_len = sizeof(req); + req.hdr.nlmsg_type = RTM_GETLINK; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; + req.hdr.nlmsg_seq = seq; + req.ifm.ifi_family = AF_UNSPEC; + + return NetlinkRequestResponse(fd, &req, sizeof(req), fn, false); +} + +PosixErrorOr> DumpLinks() { + ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE)); + + std::vector links; + RETURN_IF_ERRNO(DumpLinks(fd, kSeq, [&](const struct nlmsghdr* hdr) { + if (hdr->nlmsg_type != RTM_NEWLINK || + hdr->nlmsg_len < NLMSG_SPACE(sizeof(struct ifinfomsg))) { + return; + } + const struct ifinfomsg* msg = + reinterpret_cast(NLMSG_DATA(hdr)); + const auto* rta = FindRtAttr(hdr, msg, IFLA_IFNAME); + if (rta == nullptr) { + // Ignore links that do not have a name. + return; + } + + links.emplace_back(); + links.back().index = msg->ifi_index; + links.back().type = msg->ifi_type; + links.back().name = + std::string(reinterpret_cast(RTA_DATA(rta))); + })); + return links; +} + +PosixErrorOr> FindLoopbackLink() { + ASSIGN_OR_RETURN_ERRNO(auto links, DumpLinks()); + for (const auto& link : links) { + if (link.type == ARPHRD_LOOPBACK) { + return absl::optional(link); + } + } + return absl::optional(); +} + +PosixError LinkAddLocalAddr(int index, int family, int prefixlen, + const void* addr, int addrlen) { + ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE)); + + struct request { + struct nlmsghdr hdr; + struct ifaddrmsg ifaddr; + char attrbuf[512]; + }; + + struct request req = {}; + req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifaddr)); + req.hdr.nlmsg_type = RTM_NEWADDR; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; + req.hdr.nlmsg_seq = kSeq; + req.ifaddr.ifa_index = index; + req.ifaddr.ifa_family = family; + req.ifaddr.ifa_prefixlen = prefixlen; + + struct rtattr* rta = reinterpret_cast( + reinterpret_cast(&req) + NLMSG_ALIGN(req.hdr.nlmsg_len)); + rta->rta_type = IFA_LOCAL; + rta->rta_len = RTA_LENGTH(addrlen); + req.hdr.nlmsg_len = NLMSG_ALIGN(req.hdr.nlmsg_len) + RTA_LENGTH(addrlen); + memcpy(RTA_DATA(rta), addr, addrlen); + + return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len); +} + +PosixError LinkChangeFlags(int index, unsigned int flags, unsigned int change) { + ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE)); + + struct request { + struct nlmsghdr hdr; + struct ifinfomsg ifinfo; + char pad[NLMSG_ALIGNTO]; + }; + + struct request req = {}; + req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifinfo)); + req.hdr.nlmsg_type = RTM_NEWLINK; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; + req.hdr.nlmsg_seq = kSeq; + req.ifinfo.ifi_index = index; + req.ifinfo.ifi_flags = flags; + req.ifinfo.ifi_change = change; + + return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len); +} + +PosixError LinkSetMacAddr(int index, const void* addr, int addrlen) { + ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE)); + + struct request { + struct nlmsghdr hdr; + struct ifinfomsg ifinfo; + char attrbuf[512]; + }; + + struct request req = {}; + req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifinfo)); + req.hdr.nlmsg_type = RTM_NEWLINK; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; + req.hdr.nlmsg_seq = kSeq; + req.ifinfo.ifi_index = index; + + struct rtattr* rta = reinterpret_cast( + reinterpret_cast(&req) + NLMSG_ALIGN(req.hdr.nlmsg_len)); + rta->rta_type = IFLA_ADDRESS; + rta->rta_len = RTA_LENGTH(addrlen); + req.hdr.nlmsg_len = NLMSG_ALIGN(req.hdr.nlmsg_len) + RTA_LENGTH(addrlen); + memcpy(RTA_DATA(rta), addr, addrlen); + + return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len); +} + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_netlink_route_util.h b/test/syscalls/linux/socket_netlink_route_util.h new file mode 100644 index 000000000..2c018e487 --- /dev/null +++ b/test/syscalls/linux/socket_netlink_route_util.h @@ -0,0 +1,55 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NETLINK_ROUTE_UTIL_H_ +#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NETLINK_ROUTE_UTIL_H_ + +#include +#include + +#include + +#include "absl/types/optional.h" +#include "test/syscalls/linux/socket_netlink_util.h" + +namespace gvisor { +namespace testing { + +struct Link { + int index; + int16_t type; + std::string name; +}; + +PosixError DumpLinks(const FileDescriptor& fd, uint32_t seq, + const std::function& fn); + +PosixErrorOr> DumpLinks(); + +PosixErrorOr> FindLoopbackLink(); + +// LinkAddLocalAddr sets IFA_LOCAL attribute on the interface. +PosixError LinkAddLocalAddr(int index, int family, int prefixlen, + const void* addr, int addrlen); + +// LinkChangeFlags changes interface flags. E.g. IFF_UP. +PosixError LinkChangeFlags(int index, unsigned int flags, unsigned int change); + +// LinkSetMacAddr sets IFLA_ADDRESS attribute of the interface. +PosixError LinkSetMacAddr(int index, const void* addr, int addrlen); + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NETLINK_ROUTE_UTIL_H_ diff --git a/test/syscalls/linux/tuntap.cc b/test/syscalls/linux/tuntap.cc new file mode 100644 index 000000000..f6ac9d7b8 --- /dev/null +++ b/test/syscalls/linux/tuntap.cc @@ -0,0 +1,346 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_split.h" +#include "test/syscalls/linux/socket_netlink_route_util.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/capability_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/fs_util.h" +#include "test/util/posix_error.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { +namespace { + +constexpr int kIPLen = 4; + +constexpr const char kDevNetTun[] = "/dev/net/tun"; +constexpr const char kTapName[] = "tap0"; + +constexpr const uint8_t kMacA[ETH_ALEN] = {0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}; +constexpr const uint8_t kMacB[ETH_ALEN] = {0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB}; + +PosixErrorOr> DumpLinkNames() { + ASSIGN_OR_RETURN_ERRNO(auto links, DumpLinks()); + std::set names; + for (const auto& link : links) { + names.emplace(link.name); + } + return names; +} + +PosixErrorOr> GetLinkByName(const std::string& name) { + ASSIGN_OR_RETURN_ERRNO(auto links, DumpLinks()); + for (const auto& link : links) { + if (link.name == name) { + return absl::optional(link); + } + } + return absl::optional(); +} + +struct pihdr { + uint16_t pi_flags; + uint16_t pi_protocol; +} __attribute__((packed)); + +struct ping_pkt { + pihdr pi; + struct ethhdr eth; + struct iphdr ip; + struct icmphdr icmp; + char payload[64]; +} __attribute__((packed)); + +ping_pkt CreatePingPacket(const uint8_t srcmac[ETH_ALEN], const char* srcip, + const uint8_t dstmac[ETH_ALEN], const char* dstip) { + ping_pkt pkt = {}; + + pkt.pi.pi_protocol = htons(ETH_P_IP); + + memcpy(pkt.eth.h_dest, dstmac, sizeof(pkt.eth.h_dest)); + memcpy(pkt.eth.h_source, srcmac, sizeof(pkt.eth.h_source)); + pkt.eth.h_proto = htons(ETH_P_IP); + + pkt.ip.ihl = 5; + pkt.ip.version = 4; + pkt.ip.tos = 0; + pkt.ip.tot_len = htons(sizeof(struct iphdr) + sizeof(struct icmphdr) + + sizeof(pkt.payload)); + pkt.ip.id = 1; + pkt.ip.frag_off = 1 << 6; // Do not fragment + pkt.ip.ttl = 64; + pkt.ip.protocol = IPPROTO_ICMP; + inet_pton(AF_INET, dstip, &pkt.ip.daddr); + inet_pton(AF_INET, srcip, &pkt.ip.saddr); + pkt.ip.check = IPChecksum(pkt.ip); + + pkt.icmp.type = ICMP_ECHO; + pkt.icmp.code = 0; + pkt.icmp.checksum = 0; + pkt.icmp.un.echo.sequence = 1; + pkt.icmp.un.echo.id = 1; + + strncpy(pkt.payload, "abcd", sizeof(pkt.payload)); + pkt.icmp.checksum = ICMPChecksum(pkt.icmp, pkt.payload, sizeof(pkt.payload)); + + return pkt; +} + +struct arp_pkt { + pihdr pi; + struct ethhdr eth; + struct arphdr arp; + uint8_t arp_sha[ETH_ALEN]; + uint8_t arp_spa[kIPLen]; + uint8_t arp_tha[ETH_ALEN]; + uint8_t arp_tpa[kIPLen]; +} __attribute__((packed)); + +std::string CreateArpPacket(const uint8_t srcmac[ETH_ALEN], const char* srcip, + const uint8_t dstmac[ETH_ALEN], const char* dstip) { + std::string buffer; + buffer.resize(sizeof(arp_pkt)); + + arp_pkt* pkt = reinterpret_cast(&buffer[0]); + { + pkt->pi.pi_protocol = htons(ETH_P_ARP); + + memcpy(pkt->eth.h_dest, kMacA, sizeof(pkt->eth.h_dest)); + memcpy(pkt->eth.h_source, kMacB, sizeof(pkt->eth.h_source)); + pkt->eth.h_proto = htons(ETH_P_ARP); + + pkt->arp.ar_hrd = htons(ARPHRD_ETHER); + pkt->arp.ar_pro = htons(ETH_P_IP); + pkt->arp.ar_hln = ETH_ALEN; + pkt->arp.ar_pln = kIPLen; + pkt->arp.ar_op = htons(ARPOP_REPLY); + + memcpy(pkt->arp_sha, srcmac, sizeof(pkt->arp_sha)); + inet_pton(AF_INET, srcip, pkt->arp_spa); + memcpy(pkt->arp_tha, dstmac, sizeof(pkt->arp_tha)); + inet_pton(AF_INET, dstip, pkt->arp_tpa); + } + return buffer; +} + +} // namespace + +class TuntapTest : public ::testing::Test { + protected: + void TearDown() override { + if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))) { + // Bring back capability if we had dropped it in test case. + ASSERT_NO_ERRNO(SetCapability(CAP_NET_ADMIN, true)); + } + } +}; + +TEST_F(TuntapTest, CreateInterfaceNoCap) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + ASSERT_NO_ERRNO(SetCapability(CAP_NET_ADMIN, false)); + + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR)); + + struct ifreq ifr = {}; + ifr.ifr_flags = IFF_TAP; + strncpy(ifr.ifr_name, kTapName, IFNAMSIZ); + + EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr), SyscallFailsWithErrno(EPERM)); +} + +TEST_F(TuntapTest, CreateFixedNameInterface) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR)); + + struct ifreq ifr_set = {}; + ifr_set.ifr_flags = IFF_TAP; + strncpy(ifr_set.ifr_name, kTapName, IFNAMSIZ); + EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr_set), + SyscallSucceedsWithValue(0)); + + struct ifreq ifr_get = {}; + EXPECT_THAT(ioctl(fd.get(), TUNGETIFF, &ifr_get), + SyscallSucceedsWithValue(0)); + + struct ifreq ifr_expect = ifr_set; + // See __tun_chr_ioctl() in net/drivers/tun.c. + ifr_expect.ifr_flags |= IFF_NOFILTER; + + EXPECT_THAT(DumpLinkNames(), + IsPosixErrorOkAndHolds(::testing::Contains(kTapName))); + EXPECT_THAT(memcmp(&ifr_expect, &ifr_get, sizeof(ifr_get)), ::testing::Eq(0)); +} + +TEST_F(TuntapTest, CreateInterface) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR)); + + struct ifreq ifr = {}; + ifr.ifr_flags = IFF_TAP; + // Empty ifr.ifr_name. Let kernel assign. + + EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr), SyscallSucceedsWithValue(0)); + + struct ifreq ifr_get = {}; + EXPECT_THAT(ioctl(fd.get(), TUNGETIFF, &ifr_get), + SyscallSucceedsWithValue(0)); + + std::string ifname = ifr_get.ifr_name; + EXPECT_THAT(ifname, ::testing::StartsWith("tap")); + EXPECT_THAT(DumpLinkNames(), + IsPosixErrorOkAndHolds(::testing::Contains(ifname))); +} + +TEST_F(TuntapTest, InvalidReadWrite) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR)); + + char buf[128] = {}; + EXPECT_THAT(read(fd.get(), buf, sizeof(buf)), SyscallFailsWithErrno(EBADFD)); + EXPECT_THAT(write(fd.get(), buf, sizeof(buf)), SyscallFailsWithErrno(EBADFD)); +} + +TEST_F(TuntapTest, WriteToDownDevice) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + // FIXME: gVisor always creates enabled/up'd interfaces. + SKIP_IF(IsRunningOnGvisor()); + + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR)); + + // Device created should be down by default. + struct ifreq ifr = {}; + ifr.ifr_flags = IFF_TAP; + EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr), SyscallSucceedsWithValue(0)); + + char buf[128] = {}; + EXPECT_THAT(write(fd.get(), buf, sizeof(buf)), SyscallFailsWithErrno(EIO)); +} + +// This test sets up a TAP device and pings kernel by sending ICMP echo request. +// +// It works as the following: +// * Open /dev/net/tun, and create kTapName interface. +// * Use rtnetlink to do initial setup of the interface: +// * Assign IP address 10.0.0.1/24 to kernel. +// * MAC address: kMacA +// * Bring up the interface. +// * Send an ICMP echo reqest (ping) packet from 10.0.0.2 (kMacB) to kernel. +// * Loop to receive packets from TAP device/fd: +// * If packet is an ICMP echo reply, it stops and passes the test. +// * If packet is an ARP request, it responds with canned reply and resends +// the +// ICMP request packet. +TEST_F(TuntapTest, PingKernel) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + // Interface creation. + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR)); + + struct ifreq ifr_set = {}; + ifr_set.ifr_flags = IFF_TAP; + strncpy(ifr_set.ifr_name, kTapName, IFNAMSIZ); + EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr_set), + SyscallSucceedsWithValue(0)); + + absl::optional link = + ASSERT_NO_ERRNO_AND_VALUE(GetLinkByName(kTapName)); + ASSERT_TRUE(link.has_value()); + + // Interface setup. + struct in_addr addr; + inet_pton(AF_INET, "10.0.0.1", &addr); + EXPECT_NO_ERRNO(LinkAddLocalAddr(link->index, AF_INET, /*prefixlen=*/24, + &addr, sizeof(addr))); + + if (!IsRunningOnGvisor()) { + // FIXME: gVisor doesn't support setting MAC address on interfaces yet. + EXPECT_NO_ERRNO(LinkSetMacAddr(link->index, kMacA, sizeof(kMacA))); + + // FIXME: gVisor always creates enabled/up'd interfaces. + EXPECT_NO_ERRNO(LinkChangeFlags(link->index, IFF_UP, IFF_UP)); + } + + ping_pkt ping_req = CreatePingPacket(kMacB, "10.0.0.2", kMacA, "10.0.0.1"); + std::string arp_rep = CreateArpPacket(kMacB, "10.0.0.2", kMacA, "10.0.0.1"); + + // Send ping, this would trigger an ARP request on Linux. + EXPECT_THAT(write(fd.get(), &ping_req, sizeof(ping_req)), + SyscallSucceedsWithValue(sizeof(ping_req))); + + // Receive loop to process inbound packets. + struct inpkt { + union { + pihdr pi; + ping_pkt ping; + arp_pkt arp; + }; + }; + while (1) { + inpkt r = {}; + int n = read(fd.get(), &r, sizeof(r)); + EXPECT_THAT(n, SyscallSucceeds()); + + if (n < sizeof(pihdr)) { + std::cerr << "Ignored packet, protocol: " << r.pi.pi_protocol + << " len: " << n << std::endl; + continue; + } + + // Process ARP packet. + if (n >= sizeof(arp_pkt) && r.pi.pi_protocol == htons(ETH_P_ARP)) { + // Respond with canned ARP reply. + EXPECT_THAT(write(fd.get(), arp_rep.data(), arp_rep.size()), + SyscallSucceedsWithValue(arp_rep.size())); + // First ping request might have been dropped due to mac address not in + // ARP cache. Send it again. + EXPECT_THAT(write(fd.get(), &ping_req, sizeof(ping_req)), + SyscallSucceedsWithValue(sizeof(ping_req))); + } + + // Process ping response packet. + if (n >= sizeof(ping_pkt) && r.pi.pi_protocol == ping_req.pi.pi_protocol && + r.ping.ip.protocol == ping_req.ip.protocol && + !memcmp(&r.ping.ip.saddr, &ping_req.ip.daddr, kIPLen) && + !memcmp(&r.ping.ip.daddr, &ping_req.ip.saddr, kIPLen) && + r.ping.icmp.type == 0 && r.ping.icmp.code == 0) { + // Ends and passes the test. + break; + } + } +} + +} // namespace testing +} // namespace gvisor -- cgit v1.2.3 From c37b196455e8b3816298e3eea98e4ee2dab8d368 Mon Sep 17 00:00:00 2001 From: Ian Gudger Date: Mon, 24 Feb 2020 10:31:01 -0800 Subject: Add support for tearing down protocol dispatchers and TIME_WAIT endpoints. Protocol dispatchers were previously leaked. Bypassing TIME_WAIT is required to test this change. Also fix a race when a socket in SYN-RCVD is closed. This is also required to test this change. PiperOrigin-RevId: 296922548 --- pkg/tcpip/adapters/gonet/gonet_test.go | 63 ++++++++++++++++++++++++++-------- pkg/tcpip/network/arp/arp.go | 20 +++++++---- pkg/tcpip/network/ipv4/ipv4.go | 6 ++++ pkg/tcpip/network/ipv6/ipv6.go | 6 ++++ pkg/tcpip/stack/registration.go | 23 ++++++++++--- pkg/tcpip/stack/stack.go | 14 +++++++- pkg/tcpip/stack/stack_test.go | 6 ++++ pkg/tcpip/stack/transport_demuxer.go | 20 ----------- pkg/tcpip/stack/transport_test.go | 15 +++++++- pkg/tcpip/tcpip.go | 8 ++++- pkg/tcpip/transport/icmp/endpoint.go | 5 +++ pkg/tcpip/transport/icmp/protocol.go | 16 ++++++--- pkg/tcpip/transport/packet/endpoint.go | 5 +++ pkg/tcpip/transport/raw/endpoint.go | 5 +++ pkg/tcpip/transport/tcp/accept.go | 9 ++++- pkg/tcpip/transport/tcp/connect.go | 4 +-- pkg/tcpip/transport/tcp/dispatcher.go | 31 ++++++++++++++++- pkg/tcpip/transport/tcp/endpoint.go | 33 ++++++++++++++++-- pkg/tcpip/transport/tcp/protocol.go | 14 ++++++-- pkg/tcpip/transport/udp/endpoint.go | 5 +++ pkg/tcpip/transport/udp/protocol.go | 14 +++++--- 21 files changed, 256 insertions(+), 66 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index ea0a0409a..3c552988a 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -127,6 +127,10 @@ func TestCloseReader(t *testing.T) { if err != nil { t.Fatalf("newLoopbackStack() = %v", err) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} @@ -175,6 +179,10 @@ func TestCloseReaderWithForwarder(t *testing.T) { if err != nil { t.Fatalf("newLoopbackStack() = %v", err) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) @@ -225,30 +233,21 @@ func TestCloseRead(t *testing.T) { if terr != nil { t.Fatalf("newLoopbackStack() = %v", terr) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { var wq waiter.Queue - ep, err := r.CreateEndpoint(&wq) + _, err := r.CreateEndpoint(&wq) if err != nil { t.Fatalf("r.CreateEndpoint() = %v", err) } - defer ep.Close() - r.Complete(false) - - c := NewTCPConn(&wq, ep) - - buf := make([]byte, 256) - n, e := c.Read(buf) - if e != nil || string(buf[:n]) != "abc123" { - t.Fatalf("c.Read() = (%d, %v), want (6, nil)", n, e) - } - - if n, e = c.Write([]byte("abc123")); e != nil { - t.Errorf("c.Write() = (%d, %v), want (6, nil)", n, e) - } + // Endpoint will be closed in deferred s.Close (above). }) s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) @@ -278,6 +277,10 @@ func TestCloseWrite(t *testing.T) { if terr != nil { t.Fatalf("newLoopbackStack() = %v", terr) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) @@ -334,6 +337,10 @@ func TestUDPForwarder(t *testing.T) { if terr != nil { t.Fatalf("newLoopbackStack() = %v", terr) } + defer func() { + s.Close() + s.Wait() + }() ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr1 := tcpip.FullAddress{NICID, ip1, 11211} @@ -391,6 +398,10 @@ func TestDeadlineChange(t *testing.T) { if err != nil { t.Fatalf("newLoopbackStack() = %v", err) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} @@ -440,6 +451,10 @@ func TestPacketConnTransfer(t *testing.T) { if e != nil { t.Fatalf("newLoopbackStack() = %v", e) } + defer func() { + s.Close() + s.Wait() + }() ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr1 := tcpip.FullAddress{NICID, ip1, 11211} @@ -492,6 +507,10 @@ func TestConnectedPacketConnTransfer(t *testing.T) { if e != nil { t.Fatalf("newLoopbackStack() = %v", e) } + defer func() { + s.Close() + s.Wait() + }() ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr := tcpip.FullAddress{NICID, ip, 11211} @@ -562,6 +581,8 @@ func makePipe() (c1, c2 net.Conn, stop func(), err error) { stop = func() { c1.Close() c2.Close() + s.Close() + s.Wait() } if err := l.Close(); err != nil { @@ -624,6 +645,10 @@ func TestTCPDialError(t *testing.T) { if e != nil { t.Fatalf("newLoopbackStack() = %v", e) } + defer func() { + s.Close() + s.Wait() + }() ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr := tcpip.FullAddress{NICID, ip, 11211} @@ -641,6 +666,10 @@ func TestDialContextTCPCanceled(t *testing.T) { if err != nil { t.Fatalf("newLoopbackStack() = %v", err) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) @@ -659,6 +688,10 @@ func TestDialContextTCPTimeout(t *testing.T) { if err != nil { t.Fatalf("newLoopbackStack() = %v", err) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 4da13c5df..e9fcc89a8 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -148,12 +148,12 @@ func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWi }, nil } -// LinkAddressProtocol implements stack.LinkAddressResolver. +// LinkAddressProtocol implements stack.LinkAddressResolver.LinkAddressProtocol. func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return header.IPv4ProtocolNumber } -// LinkAddressRequest implements stack.LinkAddressResolver. +// LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest. func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error { r := &stack.Route{ RemoteLinkAddress: broadcastMAC, @@ -172,7 +172,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. }) } -// ResolveStaticAddress implements stack.LinkAddressResolver. +// ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress. func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { if addr == header.IPv4Broadcast { return broadcastMAC, true @@ -183,16 +183,22 @@ func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo return tcpip.LinkAddress([]byte(nil)), false } -// SetOption implements NetworkProtocol. -func (p *protocol) SetOption(option interface{}) *tcpip.Error { +// SetOption implements stack.NetworkProtocol.SetOption. +func (*protocol) SetOption(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } -// Option implements NetworkProtocol. -func (p *protocol) Option(option interface{}) *tcpip.Error { +// Option implements stack.NetworkProtocol.Option. +func (*protocol) Option(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } +// Close implements stack.TransportProtocol.Close. +func (*protocol) Close() {} + +// Wait implements stack.TransportProtocol.Wait. +func (*protocol) Wait() {} + var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff}) // NewProtocol returns an ARP network protocol. diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 6597e6781..4f1742938 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -473,6 +473,12 @@ func (p *protocol) DefaultTTL() uint8 { return uint8(atomic.LoadUint32(&p.defaultTTL)) } +// Close implements stack.TransportProtocol.Close. +func (*protocol) Close() {} + +// Wait implements stack.TransportProtocol.Wait. +func (*protocol) Wait() {} + // calculateMTU calculates the network-layer payload MTU based on the link-layer // payload mtu. func calculateMTU(mtu uint32) uint32 { diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 180a480fd..9aef5234b 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -265,6 +265,12 @@ func (p *protocol) DefaultTTL() uint8 { return uint8(atomic.LoadUint32(&p.defaultTTL)) } +// Close implements stack.TransportProtocol.Close. +func (*protocol) Close() {} + +// Wait implements stack.TransportProtocol.Wait. +func (*protocol) Wait() {} + // calculateMTU calculates the network-layer payload MTU based on the link-layer // payload mtu. func calculateMTU(mtu uint32) uint32 { diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index d83adf0ec..f9fd8f18f 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -74,10 +74,11 @@ type TransportEndpoint interface { // HandleControlPacket takes ownership of pkt. HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) - // Close puts the endpoint in a closed state and frees all resources - // associated with it. This cleanup may happen asynchronously. Wait can - // be used to block on this asynchronous cleanup. - Close() + // Abort initiates an expedited endpoint teardown. It puts the endpoint + // in a closed state and frees all resources associated with it. This + // cleanup may happen asynchronously. Wait can be used to block on this + // asynchronous cleanup. + Abort() // Wait waits for any worker goroutines owned by the endpoint to stop. // @@ -160,6 +161,13 @@ type TransportProtocol interface { // Option returns an error if the option is not supported or the // provided option value is invalid. Option(option interface{}) *tcpip.Error + + // Close requests that any worker goroutines owned by the protocol + // stop. + Close() + + // Wait waits for any worker goroutines owned by the protocol to stop. + Wait() } // TransportDispatcher contains the methods used by the network stack to deliver @@ -293,6 +301,13 @@ type NetworkProtocol interface { // Option returns an error if the option is not supported or the // provided option value is invalid. Option(option interface{}) *tcpip.Error + + // Close requests that any worker goroutines owned by the protocol + // stop. + Close() + + // Wait waits for any worker goroutines owned by the protocol to stop. + Wait() } // NetworkDispatcher contains the methods used by the network stack to deliver diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 900dd46c5..ebb6c5e3b 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1446,7 +1446,13 @@ func (s *Stack) RestoreCleanupEndpoints(es []TransportEndpoint) { // Endpoints created or modified during this call may not get closed. func (s *Stack) Close() { for _, e := range s.RegisteredEndpoints() { - e.Close() + e.Abort() + } + for _, p := range s.transportProtocols { + p.proto.Close() + } + for _, p := range s.networkProtocols { + p.Close() } } @@ -1464,6 +1470,12 @@ func (s *Stack) Wait() { for _, e := range s.CleanupEndpoints() { e.Wait() } + for _, p := range s.transportProtocols { + p.proto.Wait() + } + for _, p := range s.networkProtocols { + p.Wait() + } s.mu.RLock() defer s.mu.RUnlock() diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 18016e7db..edf6bec52 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -235,6 +235,12 @@ func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error { } } +// Close implements TransportProtocol.Close. +func (*fakeNetworkProtocol) Close() {} + +// Wait implements TransportProtocol.Wait. +func (*fakeNetworkProtocol) Wait() {} + func fakeNetFactory() stack.NetworkProtocol { return &fakeNetworkProtocol{} } diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index d686e6eb8..778c0a4d6 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -306,26 +306,6 @@ func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, p ep.mu.RUnlock() // Don't use defer for performance reasons. } -// Close implements stack.TransportEndpoint.Close. -func (ep *multiPortEndpoint) Close() { - ep.mu.RLock() - eps := append([]TransportEndpoint(nil), ep.endpointsArr...) - ep.mu.RUnlock() - for _, e := range eps { - e.Close() - } -} - -// Wait implements stack.TransportEndpoint.Wait. -func (ep *multiPortEndpoint) Wait() { - ep.mu.RLock() - eps := append([]TransportEndpoint(nil), ep.endpointsArr...) - ep.mu.RUnlock() - for _, e := range eps { - e.Wait() - } -} - // singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint // list. The list might be empty already. func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error { diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 869c69a6d..5d1da2f8b 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -61,6 +61,10 @@ func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netP return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} } +func (f *fakeTransportEndpoint) Abort() { + f.Close() +} + func (f *fakeTransportEndpoint) Close() { f.route.Release() } @@ -272,7 +276,7 @@ func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.N return newFakeTransportEndpoint(stack, f, netProto, stack.UniqueID()), nil } -func (f *fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (*fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { return nil, tcpip.ErrUnknownProtocol } @@ -310,6 +314,15 @@ func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error { } } +// Abort implements TransportProtocol.Abort. +func (*fakeTransportProtocol) Abort() {} + +// Close implements tcpip.Endpoint.Close. +func (*fakeTransportProtocol) Close() {} + +// Wait implements TransportProtocol.Wait. +func (*fakeTransportProtocol) Wait() {} + func fakeTransFactory() stack.TransportProtocol { return &fakeTransportProtocol{} } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index ce5527391..3dc5d87d6 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -341,9 +341,15 @@ type ControlMessages struct { // networking stack. type Endpoint interface { // Close puts the endpoint in a closed state and frees all resources - // associated with it. + // associated with it. Close initiates the teardown process, the + // Endpoint may not be fully closed when Close returns. Close() + // Abort initiates an expedited endpoint teardown. As compared to + // Close, Abort prioritizes closing the Endpoint quickly over cleanly. + // Abort is best effort; implementing Abort with Close is acceptable. + Abort() + // Read reads data from the endpoint and optionally returns the sender. // // This method does not block if there is no data pending. It will also diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 42afb3f5b..426da1ee6 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -96,6 +96,11 @@ func (e *endpoint) UniqueID() uint64 { return e.uniqueID } +// Abort implements stack.TransportEndpoint.Abort. +func (e *endpoint) Abort() { + e.Close() +} + // Close puts the endpoint in a closed state and frees all resources // associated with it. func (e *endpoint) Close() { diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 9ce500e80..113d92901 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -104,20 +104,26 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool { +func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool { return true } -// SetOption implements TransportProtocol.SetOption. -func (p *protocol) SetOption(option interface{}) *tcpip.Error { +// SetOption implements stack.TransportProtocol.SetOption. +func (*protocol) SetOption(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } -// Option implements TransportProtocol.Option. -func (p *protocol) Option(option interface{}) *tcpip.Error { +// Option implements stack.TransportProtocol.Option. +func (*protocol) Option(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } +// Close implements stack.TransportProtocol.Close. +func (*protocol) Close() {} + +// Wait implements stack.TransportProtocol.Wait. +func (*protocol) Wait() {} + // NewProtocol4 returns an ICMPv4 transport protocol. func NewProtocol4() stack.TransportProtocol { return &protocol{ProtocolNumber4} diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index fc5bc69fa..5722815e9 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -98,6 +98,11 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb return ep, nil } +// Abort implements stack.TransportEndpoint.Abort. +func (e *endpoint) Abort() { + e.Close() +} + // Close implements tcpip.Endpoint.Close. func (ep *endpoint) Close() { ep.mu.Lock() diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index ee9c4c58b..2ef5fac76 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -121,6 +121,11 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt return e, nil } +// Abort implements stack.TransportEndpoint.Abort. +func (e *endpoint) Abort() { + e.Close() +} + // Close implements tcpip.Endpoint.Close. func (e *endpoint) Close() { e.mu.Lock() diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 08afb7c17..13e383ffc 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -299,6 +299,13 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept) if err := h.execute(); err != nil { ep.Close() + // Wake up any waiters. This is strictly not required normally + // as a socket that was never accepted can't really have any + // registered waiters except when stack.Wait() is called which + // waits for all registered endpoints to stop and expects an + // EventHUp. + ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) + if l.listenEP != nil { l.removePendingEndpoint(ep) } @@ -607,7 +614,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { e.mu.Unlock() // Notify waiters that the endpoint is shutdown. - e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut) + e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr) }() s := sleep.Sleeper{} diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 5c5397823..7730e6445 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -1372,7 +1372,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ e.snd.updateMaxPayloadSize(mtu, count) } - if n¬ifyReset != 0 { + if n¬ifyReset != 0 || n¬ifyAbort != 0 { return tcpip.ErrConnectionAborted } @@ -1655,7 +1655,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) { } case notification: n := e.fetchNotifications() - if n¬ifyClose != 0 { + if n¬ifyClose != 0 || n¬ifyAbort != 0 { return nil } if n¬ifyDrain != 0 { diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go index e18012ac0..d792b07d6 100644 --- a/pkg/tcpip/transport/tcp/dispatcher.go +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -68,17 +68,28 @@ func (q *epQueue) empty() bool { type processor struct { epQ epQueue newEndpointWaker sleep.Waker + closeWaker sleep.Waker id int + wg sync.WaitGroup } func newProcessor(id int) *processor { p := &processor{ id: id, } + p.wg.Add(1) go p.handleSegments() return p } +func (p *processor) close() { + p.closeWaker.Assert() +} + +func (p *processor) wait() { + p.wg.Wait() +} + func (p *processor) queueEndpoint(ep *endpoint) { // Queue an endpoint for processing by the processor goroutine. p.epQ.enqueue(ep) @@ -87,11 +98,17 @@ func (p *processor) queueEndpoint(ep *endpoint) { func (p *processor) handleSegments() { const newEndpointWaker = 1 + const closeWaker = 2 s := sleep.Sleeper{} s.AddWaker(&p.newEndpointWaker, newEndpointWaker) + s.AddWaker(&p.closeWaker, closeWaker) defer s.Done() for { - s.Fetch(true) + id, ok := s.Fetch(true) + if ok && id == closeWaker { + p.wg.Done() + return + } for ep := p.epQ.dequeue(); ep != nil; ep = p.epQ.dequeue() { if ep.segmentQueue.empty() { continue @@ -160,6 +177,18 @@ func newDispatcher(nProcessors int) *dispatcher { } } +func (d *dispatcher) close() { + for _, p := range d.processors { + p.close() + } +} + +func (d *dispatcher) wait() { + for _, p := range d.processors { + p.wait() + } +} + func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { ep := stackEP.(*endpoint) s := newSegment(r, id, pkt) diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index f2be0e651..f1ad19dac 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -121,6 +121,8 @@ const ( notifyDrain notifyReset notifyResetByPeer + // notifyAbort is a request for an expedited teardown. + notifyAbort notifyKeepaliveChanged notifyMSSChanged // notifyTickleWorker is used to tickle the protocol main loop during a @@ -785,6 +787,24 @@ func (e *endpoint) notifyProtocolGoroutine(n uint32) { } } +// Abort implements stack.TransportEndpoint.Abort. +func (e *endpoint) Abort() { + // The abort notification is not processed synchronously, so no + // synchronization is needed. + // + // If the endpoint becomes connected after this check, we still close + // the endpoint. This worst case results in a slower abort. + // + // If the endpoint disconnected after the check, nothing needs to be + // done, so sending a notification which will potentially be ignored is + // fine. + if e.EndpointState().connected() { + e.notifyProtocolGoroutine(notifyAbort) + return + } + e.Close() +} + // Close puts the endpoint in a closed state and frees all resources associated // with it. It must be called only once and with no other concurrent calls to // the endpoint. @@ -829,9 +849,18 @@ func (e *endpoint) closeNoShutdown() { // Either perform the local cleanup or kick the worker to make sure it // knows it needs to cleanup. tcpip.AddDanglingEndpoint(e) - if !e.workerRunning { + switch e.EndpointState() { + // Sockets in StateSynRecv state(passive connections) are closed when + // the handshake fails or if the listening socket is closed while + // handshake was in progress. In such cases the handshake goroutine + // is already gone by the time Close is called and we need to cleanup + // here. + case StateInitial, StateBound, StateSynRecv: e.cleanupLocked() - } else { + e.setEndpointState(StateClose) + case StateError, StateClose: + // do nothing. + default: e.workerCleanup = true e.notifyProtocolGoroutine(notifyClose) } diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 958c06fa7..73098d904 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -194,7 +194,7 @@ func replyWithReset(s *segment) { sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, flags, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */) } -// SetOption implements TransportProtocol.SetOption. +// SetOption implements stack.TransportProtocol.SetOption. func (p *protocol) SetOption(option interface{}) *tcpip.Error { switch v := option.(type) { case SACKEnabled: @@ -269,7 +269,7 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { } } -// Option implements TransportProtocol.Option. +// Option implements stack.TransportProtocol.Option. func (p *protocol) Option(option interface{}) *tcpip.Error { switch v := option.(type) { case *SACKEnabled: @@ -331,6 +331,16 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { } } +// Close implements stack.TransportProtocol.Close. +func (p *protocol) Close() { + p.dispatcher.close() +} + +// Wait implements stack.TransportProtocol.Wait. +func (p *protocol) Wait() { + p.dispatcher.wait() +} + // NewProtocol returns a TCP transport protocol. func NewProtocol() stack.TransportProtocol { return &protocol{ diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index eff7f3600..1c6a600b8 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -186,6 +186,11 @@ func (e *endpoint) UniqueID() uint64 { return e.uniqueID } +// Abort implements stack.TransportEndpoint.Abort. +func (e *endpoint) Abort() { + e.Close() +} + // Close puts the endpoint in a closed state and frees all resources // associated with it. func (e *endpoint) Close() { diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 259c3072a..8df089d22 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -180,16 +180,22 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans return true } -// SetOption implements TransportProtocol.SetOption. -func (p *protocol) SetOption(option interface{}) *tcpip.Error { +// SetOption implements stack.TransportProtocol.SetOption. +func (*protocol) SetOption(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } -// Option implements TransportProtocol.Option. -func (p *protocol) Option(option interface{}) *tcpip.Error { +// Option implements stack.TransportProtocol.Option. +func (*protocol) Option(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } +// Close implements stack.TransportProtocol.Close. +func (*protocol) Close() {} + +// Wait implements stack.TransportProtocol.Wait. +func (*protocol) Wait() {} + // NewProtocol returns a UDP transport protocol. func NewProtocol() stack.TransportProtocol { return &protocol{} -- cgit v1.2.3 From acc405ba60834f5dce9ce04cd762d5cda02224cb Mon Sep 17 00:00:00 2001 From: Nayana Bidari Date: Tue, 25 Feb 2020 15:03:51 -0800 Subject: Add nat table support for iptables. - commit the changes for the comments. --- pkg/abi/linux/netfilter.go | 23 +++++++++-- pkg/sentry/socket/netfilter/netfilter.go | 33 +++++++++++----- pkg/tcpip/iptables/iptables.go | 10 ++++- pkg/tcpip/iptables/targets.go | 66 ++++++++++++++++++++++---------- pkg/tcpip/iptables/types.go | 4 +- pkg/tcpip/stack/nic.go | 13 +------ test/iptables/nat.go | 12 +----- 7 files changed, 103 insertions(+), 58 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go index ba4d84962..2179ac995 100644 --- a/pkg/abi/linux/netfilter.go +++ b/pkg/abi/linux/netfilter.go @@ -250,8 +250,24 @@ type XTErrorTarget struct { // SizeOfXTErrorTarget is the size of an XTErrorTarget. const SizeOfXTErrorTarget = 64 +// Flag values for NfNATIPV4Range. The values indicate whether to map +// protocol specific part(ports) or IPs. It corresponds to values in +// include/uapi/linux/netfilter/nf_nat.h. +const ( + NF_NAT_RANGE_MAP_IPS = 1 << 0 + NF_NAT_RANGE_PROTO_SPECIFIED = 1 << 1 + NF_NAT_RANGE_PROTO_RANDOM = 1 << 2 + NF_NAT_RANGE_PERSISTENT = 1 << 3 + NF_NAT_RANGE_PROTO_RANDOM_FULLY = 1 << 4 + NF_NAT_RANGE_PROTO_RANDOM_ALL = (NF_NAT_RANGE_PROTO_RANDOM | NF_NAT_RANGE_PROTO_RANDOM_FULLY) + NF_NAT_RANGE_MASK = (NF_NAT_RANGE_MAP_IPS | + NF_NAT_RANGE_PROTO_SPECIFIED | NF_NAT_RANGE_PROTO_RANDOM | + NF_NAT_RANGE_PERSISTENT | NF_NAT_RANGE_PROTO_RANDOM_FULLY) +) + // NfNATIPV4Range. It corresponds to struct nf_nat_ipv4_range -// in include/uapi/linux/netfilter/nf_nat.h. +// in include/uapi/linux/netfilter/nf_nat.h. The fields are in +// network byte order. type NfNATIPV4Range struct { Flags uint32 MinIP [4]byte @@ -263,11 +279,12 @@ type NfNATIPV4Range struct { // NfNATIPV4MultiRangeCompat. It corresponds to struct // nf_nat_ipv4_multi_range_compat in include/uapi/linux/netfilter/nf_nat.h. type NfNATIPV4MultiRangeCompat struct { - Rangesize uint32 - RangeIPV4 [1]NfNATIPV4Range + RangeSize uint32 + RangeIPV4 NfNATIPV4Range } // XTRedirectTarget triggers a redirect when reached. +// Adding 4 bytes of padding to make the struct 8 byte aligned. type XTRedirectTarget struct { Target XTEntryTarget NfRange NfNATIPV4MultiRangeCompat diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 512ad624a..257cb485b 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -26,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/usermem" @@ -288,6 +289,7 @@ func marshalRedirectTarget() []byte { TargetSize: linux.SizeOfXTRedirectTarget, }, } + copy(target.Target.Name[:], redirectTargetName) ret := make([]byte, 0, linux.SizeOfXTRedirectTarget) return binary.Marshal(ret, usermem.ByteOrder, target) @@ -405,7 +407,7 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { nflog("entry doesn't have enough room for its target (only %d bytes remain)", len(optVal)) return syserr.ErrInvalidArgument } - target, err := parseTarget(optVal[:targetSize]) + target, err := parseTarget(filter, optVal[:targetSize]) if err != nil { nflog("failed to parse target: %v", err) return syserr.ErrInvalidArgument @@ -552,7 +554,7 @@ func parseMatchers(filter iptables.IPHeaderFilter, optVal []byte) ([]iptables.Ma // parseTarget parses a target from optVal. optVal should contain only the // target. -func parseTarget(optVal []byte) (iptables.Target, error) { +func parseTarget(filter iptables.IPHeaderFilter, optVal []byte) (iptables.Target, error) { nflog("set entries: parsing target of size %d", len(optVal)) if len(optVal) < linux.SizeOfXTEntryTarget { return nil, fmt.Errorf("optVal has insufficient size for entry target %d", len(optVal)) @@ -604,6 +606,10 @@ func parseTarget(optVal []byte) (iptables.Target, error) { return nil, fmt.Errorf("netfilter.SetEntries: optVal has insufficient size for redirect target %d", len(optVal)) } + if filter.Protocol != header.TCPProtocolNumber && filter.Protocol != header.UDPProtocolNumber { + return nil, fmt.Errorf("netfilter.SetEntries: invalid argument") + } + var redirectTarget linux.XTRedirectTarget buf = optVal[:linux.SizeOfXTRedirectTarget] binary.Unmarshal(buf, usermem.ByteOrder, &redirectTarget) @@ -612,21 +618,30 @@ func parseTarget(optVal []byte) (iptables.Target, error) { var target iptables.RedirectTarget nfRange := redirectTarget.NfRange - target.RangeSize = nfRange.Rangesize - target.Flags = nfRange.RangeIPV4[0].Flags + // RangeSize should be 1. + if nfRange.RangeSize != 1 { + return nil, fmt.Errorf("netfilter.SetEntries: invalid argument") + } + + // TODO(gvisor.dev/issue/170): Check if the flags are valid. + // Also check if we need to map ports or IP. + // For now, redirect target only supports dest port change. + if nfRange.RangeIPV4.Flags&linux.NF_NAT_RANGE_PROTO_SPECIFIED == 0 { + return nil, fmt.Errorf("netfilter.SetEntries: invalid argument.") + } + target.Flags = nfRange.RangeIPV4.Flags - target.MinIP = tcpip.Address(nfRange.RangeIPV4[0].MinIP[:]) - target.MaxIP = tcpip.Address(nfRange.RangeIPV4[0].MaxIP[:]) + target.MinIP = tcpip.Address(nfRange.RangeIPV4.MinIP[:]) + target.MaxIP = tcpip.Address(nfRange.RangeIPV4.MaxIP[:]) // Convert port from big endian to little endian. port := make([]byte, 2) - binary.BigEndian.PutUint16(port, nfRange.RangeIPV4[0].MinPort) + binary.BigEndian.PutUint16(port, nfRange.RangeIPV4.MinPort) target.MinPort = binary.LittleEndian.Uint16(port) - binary.BigEndian.PutUint16(port, nfRange.RangeIPV4[0].MaxPort) + binary.BigEndian.PutUint16(port, nfRange.RangeIPV4.MaxPort) target.MaxPort = binary.LittleEndian.Uint16(port) return target, nil - } // Unknown target. diff --git a/pkg/tcpip/iptables/iptables.go b/pkg/tcpip/iptables/iptables.go index c00d012c0..f7dc4f720 100644 --- a/pkg/tcpip/iptables/iptables.go +++ b/pkg/tcpip/iptables/iptables.go @@ -207,7 +207,7 @@ func (it *IPTables) checkTable(hook Hook, pkt tcpip.PacketBuffer, tablename stri underflow := table.Rules[table.Underflows[hook]] // Underflow is guaranteed to be an unconditional // ACCEPT or DROP. - switch v, _ := underflow.Target.Action(pkt); v { + switch v, _ := underflow.Target.Action(pkt, underflow.Filter); v { case RuleAccept: return TableAccept case RuleDrop: @@ -233,6 +233,12 @@ func (it *IPTables) checkTable(hook Hook, pkt tcpip.PacketBuffer, tablename stri func (it *IPTables) checkRule(hook Hook, pkt tcpip.PacketBuffer, table Table, ruleIdx int) RuleVerdict { rule := table.Rules[ruleIdx] + // If pkt.NetworkHeader hasn't been set yet, it will be contained in + // pkt.Data.First(). + if pkt.NetworkHeader == nil { + pkt.NetworkHeader = pkt.Data.First() + } + // First check whether the packet matches the IP header filter. // TODO(gvisor.dev/issue/170): Support other fields of the filter. if rule.Filter.Protocol != 0 && rule.Filter.Protocol != header.IPv4(pkt.NetworkHeader).TransportProtocol() { @@ -252,6 +258,6 @@ func (it *IPTables) checkRule(hook Hook, pkt tcpip.PacketBuffer, table Table, ru } // All the matchers matched, so run the target. - verdict, _ := rule.Target.Action(pkt) + verdict, _ := rule.Target.Action(pkt, rule.Filter) return verdict } diff --git a/pkg/tcpip/iptables/targets.go b/pkg/tcpip/iptables/targets.go index 06e65bece..a75938da3 100644 --- a/pkg/tcpip/iptables/targets.go +++ b/pkg/tcpip/iptables/targets.go @@ -26,7 +26,7 @@ import ( type AcceptTarget struct{} // Action implements Target.Action. -func (AcceptTarget) Action(packet tcpip.PacketBuffer) (RuleVerdict, string) { +func (AcceptTarget) Action(packet tcpip.PacketBuffer, filter IPHeaderFilter) (RuleVerdict, string) { return RuleAccept, "" } @@ -34,7 +34,7 @@ func (AcceptTarget) Action(packet tcpip.PacketBuffer) (RuleVerdict, string) { type DropTarget struct{} // Action implements Target.Action. -func (DropTarget) Action(packet tcpip.PacketBuffer) (RuleVerdict, string) { +func (DropTarget) Action(packet tcpip.PacketBuffer, filter IPHeaderFilter) (RuleVerdict, string) { return RuleDrop, "" } @@ -43,7 +43,7 @@ func (DropTarget) Action(packet tcpip.PacketBuffer) (RuleVerdict, string) { type ErrorTarget struct{} // Action implements Target.Action. -func (ErrorTarget) Action(packet tcpip.PacketBuffer) (RuleVerdict, string) { +func (ErrorTarget) Action(packet tcpip.PacketBuffer, filter IPHeaderFilter) (RuleVerdict, string) { log.Debugf("ErrorTarget triggered.") return RuleDrop, "" } @@ -54,7 +54,7 @@ type UserChainTarget struct { } // Action implements Target.Action. -func (UserChainTarget) Action(tcpip.PacketBuffer) (RuleVerdict, string) { +func (UserChainTarget) Action(tcpip.PacketBuffer, IPHeaderFilter) (RuleVerdict, string) { panic("UserChainTarget should never be called.") } @@ -63,29 +63,55 @@ func (UserChainTarget) Action(tcpip.PacketBuffer) (RuleVerdict, string) { type ReturnTarget struct{} // Action implements Target.Action. -func (ReturnTarget) Action(tcpip.PacketBuffer) (RuleVerdict, string) { +func (ReturnTarget) Action(tcpip.PacketBuffer, IPHeaderFilter) (RuleVerdict, string) { return RuleReturn, "" } // RedirectTarget redirects the packet by modifying the destination port/IP. +// Min and Max values for IP and Ports in the struct indicate the range of +// values which can be used to redirect. type RedirectTarget struct { - RangeSize uint32 - Flags uint32 - MinIP tcpip.Address - MaxIP tcpip.Address - MinPort uint16 - MaxPort uint16 -} + // Flags to check if the redirect is for address or ports. + Flags uint32 -// Action implements Target.Action. -func (rt RedirectTarget) Action(packet tcpip.PacketBuffer) (RuleVerdict, string) { - log.Infof("RedirectTarget triggered.") + // Min address used to redirect. + MinIP tcpip.Address + + // Max address used to redirect. + MaxIP tcpip.Address - // TODO(gvisor.dev/issue/170): Checking only for UDP protocol. - // We're yet to support for TCP protocol. - headerView := packet.Data.First() - h := header.UDP(headerView) - h.SetDestinationPort(rt.MinPort) + // Min port used to redirect. + MinPort uint16 + // Max port used to redirect. + MaxPort uint16 +} + +// Action implements Target.Action. +func (rt RedirectTarget) Action(pkt tcpip.PacketBuffer, filter IPHeaderFilter) (RuleVerdict, string) { + headerView := pkt.Data.First() + + // Network header should be set. + netHeader := header.IPv4(headerView) + if netHeader == nil { + return RuleDrop, "" + } + + // TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if + // we need to change dest address (for OUTPUT chain) or ports. + hlen := int(netHeader.HeaderLength()) + + switch protocol := filter.Protocol; protocol { + case header.UDPProtocolNumber: + udp := header.UDP(headerView[hlen:]) + udp.SetDestinationPort(rt.MinPort) + case header.TCPProtocolNumber: + // TODO(gvisor.dev/issue/170): Need to recompute checksum + // and implement nat connection tracking to support TCP. + tcp := header.TCP(headerView[hlen:]) + tcp.SetDestinationPort(rt.MinPort) + default: + return RuleDrop, "" + } return RuleAccept, "" } diff --git a/pkg/tcpip/iptables/types.go b/pkg/tcpip/iptables/types.go index 5735d001b..0102831d0 100644 --- a/pkg/tcpip/iptables/types.go +++ b/pkg/tcpip/iptables/types.go @@ -63,7 +63,7 @@ const ( // TableAccept indicates the packet should continue through netstack. TableAccept TableVerdict = iota - // TableAccept indicates the packet should be dropped. + // TableDrop indicates the packet should be dropped. TableDrop ) @@ -175,5 +175,5 @@ type Target interface { // Action takes an action on the packet and returns a verdict on how // traversal should (or should not) continue. If the return value is // Jump, it also returns the name of the chain to jump to. - Action(packet tcpip.PacketBuffer) (RuleVerdict, string) + Action(packet tcpip.PacketBuffer, filter IPHeaderFilter) (RuleVerdict, string) } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 2028f5201..a75dc0322 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1087,19 +1087,8 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link // TODO(gvisor.dev/issue/170): Not supporting iptables for IPv6 yet. if protocol == header.IPv4ProtocolNumber { - newPkt := pkt.Clone() - - headerView := newPkt.Data.First() - h := header.IPv4(headerView) - newPkt.NetworkHeader = headerView[:h.HeaderLength()] - - hlen := int(h.HeaderLength()) - tlen := int(h.TotalLength()) - newPkt.Data.TrimFront(hlen) - newPkt.Data.CapLength(tlen - hlen) - ipt := n.stack.IPTables() - if ok := ipt.Check(iptables.Prerouting, newPkt); !ok { + if ok := ipt.Check(iptables.Prerouting, pkt); !ok { // iptables is telling us to drop the packet. return } diff --git a/test/iptables/nat.go b/test/iptables/nat.go index 306cbd1b3..899d1c9d3 100644 --- a/test/iptables/nat.go +++ b/test/iptables/nat.go @@ -71,20 +71,12 @@ func (NATRedirectTCPPort) ContainerAction(ip net.IP) error { } // Listen for TCP packets on redirect port. - if err := listenTCP(redirectPort, sendloopDuration); err != nil { - return fmt.Errorf("connection on port %d should be accepted, but got error %v", redirectPort, err) - } - - return nil + return listenTCP(redirectPort, sendloopDuration) } // LocalAction implements TestCase.LocalAction. func (NATRedirectTCPPort) LocalAction(ip net.IP) error { - if err := connectTCP(ip, dropPort, acceptPort, sendloopDuration); err != nil { - return fmt.Errorf("connection destined to port %d should be accepted, but got error %v", dropPort, err) - } - - return nil + return connectTCP(ip, dropPort, acceptPort, sendloopDuration) } // NATDropUDP tests that packets are not received in ports other than redirect port. -- cgit v1.2.3 From 5f1f9dd9d23d2b805c77b5c38d5900d13e6a29fe Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Tue, 25 Feb 2020 15:15:28 -0800 Subject: Use link-local source address for link-local multicast Tests: - header_test.TestIsV6LinkLocalMulticastAddress - header_test.TestScopeForIPv6Address - stack_test.TestIPv6SourceAddressSelectionScopeAndSameAddress PiperOrigin-RevId: 297215576 --- pkg/tcpip/header/ipv6.go | 22 ++++++++++ pkg/tcpip/header/ipv6_test.go | 98 ++++++++++++++++++++++++++++++++++++++++--- pkg/tcpip/stack/stack_test.go | 27 ++++++++---- 3 files changed, 134 insertions(+), 13 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index 70e6ce095..76e88e9b3 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -115,6 +115,19 @@ const ( // for the secret key used to generate an opaque interface identifier as // outlined by RFC 7217. OpaqueIIDSecretKeyMinBytes = 16 + + // ipv6MulticastAddressScopeByteIdx is the byte where the scope (scop) field + // is located within a multicast IPv6 address, as per RFC 4291 section 2.7. + ipv6MulticastAddressScopeByteIdx = 1 + + // ipv6MulticastAddressScopeMask is the mask for the scope (scop) field, + // within the byte holding the field, as per RFC 4291 section 2.7. + ipv6MulticastAddressScopeMask = 0xF + + // ipv6LinkLocalMulticastScope is the value of the scope (scop) field within + // a multicast IPv6 address that indicates the address has link-local scope, + // as per RFC 4291 section 2.7. + ipv6LinkLocalMulticastScope = 2 ) // IPv6EmptySubnet is the empty IPv6 subnet. It may also be known as the @@ -340,6 +353,12 @@ func IsV6LinkLocalAddress(addr tcpip.Address) bool { return addr[0] == 0xfe && (addr[1]&0xc0) == 0x80 } +// IsV6LinkLocalMulticastAddress determines if the provided address is an IPv6 +// link-local multicast address. +func IsV6LinkLocalMulticastAddress(addr tcpip.Address) bool { + return IsV6MulticastAddress(addr) && addr[ipv6MulticastAddressScopeByteIdx]&ipv6MulticastAddressScopeMask == ipv6LinkLocalMulticastScope +} + // IsV6UniqueLocalAddress determines if the provided address is an IPv6 // unique-local address (within the prefix FC00::/7). func IsV6UniqueLocalAddress(addr tcpip.Address) bool { @@ -411,6 +430,9 @@ func ScopeForIPv6Address(addr tcpip.Address) (IPv6AddressScope, *tcpip.Error) { } switch { + case IsV6LinkLocalMulticastAddress(addr): + return LinkLocalScope, nil + case IsV6LinkLocalAddress(addr): return LinkLocalScope, nil diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go index c3ad503aa..426a873b1 100644 --- a/pkg/tcpip/header/ipv6_test.go +++ b/pkg/tcpip/header/ipv6_test.go @@ -27,11 +27,12 @@ import ( ) const ( - linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - linkLocalAddr = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - globalAddr = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") + linkLocalAddr = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + linkLocalMulticastAddr = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + globalAddr = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") ) func TestEthernetAdddressToModifiedEUI64(t *testing.T) { @@ -256,6 +257,85 @@ func TestIsV6UniqueLocalAddress(t *testing.T) { } } +func TestIsV6LinkLocalMulticastAddress(t *testing.T) { + tests := []struct { + name string + addr tcpip.Address + expected bool + }{ + { + name: "Valid Link Local Multicast", + addr: linkLocalMulticastAddr, + expected: true, + }, + { + name: "Valid Link Local Multicast with flags", + addr: "\xff\xf2\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + expected: true, + }, + { + name: "Link Local Unicast", + addr: linkLocalAddr, + expected: false, + }, + { + name: "IPv4 Multicast", + addr: "\xe0\x00\x00\x01", + expected: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if got := header.IsV6LinkLocalMulticastAddress(test.addr); got != test.expected { + t.Errorf("got header.IsV6LinkLocalMulticastAddress(%s) = %t, want = %t", test.addr, got, test.expected) + } + }) + } +} + +func TestIsV6LinkLocalAddress(t *testing.T) { + tests := []struct { + name string + addr tcpip.Address + expected bool + }{ + { + name: "Valid Link Local Unicast", + addr: linkLocalAddr, + expected: true, + }, + { + name: "Link Local Multicast", + addr: linkLocalMulticastAddr, + expected: false, + }, + { + name: "Unique Local", + addr: uniqueLocalAddr1, + expected: false, + }, + { + name: "Global", + addr: globalAddr, + expected: false, + }, + { + name: "IPv4 Link Local", + addr: "\xa9\xfe\x00\x01", + expected: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if got := header.IsV6LinkLocalAddress(test.addr); got != test.expected { + t.Errorf("got header.IsV6LinkLocalAddress(%s) = %t, want = %t", test.addr, got, test.expected) + } + }) + } +} + func TestScopeForIPv6Address(t *testing.T) { tests := []struct { name string @@ -270,11 +350,17 @@ func TestScopeForIPv6Address(t *testing.T) { err: nil, }, { - name: "Link Local", + name: "Link Local Unicast", addr: linkLocalAddr, scope: header.LinkLocalScope, err: nil, }, + { + name: "Link Local Multicast", + addr: linkLocalMulticastAddr, + scope: header.LinkLocalScope, + err: nil, + }, { name: "Global", addr: globalAddr, diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index edf6bec52..e15db40fb 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -2790,13 +2790,14 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { const ( - linkLocalAddr1 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - linkLocalAddr2 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - globalAddr1 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - globalAddr2 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - nicID = 1 + linkLocalAddr1 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + linkLocalAddr2 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + linkLocalMulticastAddr = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + globalAddr1 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + globalAddr2 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + nicID = 1 ) // Rule 3 is not tested here, and is instead tested by NDP's AutoGenAddr test. @@ -2869,6 +2870,18 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { connectAddr: linkLocalAddr2, expectedLocalAddr: linkLocalAddr1, }, + { + name: "Link Local most preferred for link local multicast (last address)", + nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1}, + connectAddr: linkLocalMulticastAddr, + expectedLocalAddr: linkLocalAddr1, + }, + { + name: "Link Local most preferred for link local multicast (first address)", + nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, + connectAddr: linkLocalMulticastAddr, + expectedLocalAddr: linkLocalAddr1, + }, { name: "Unique Local most preferred (last address)", nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, linkLocalAddr1}, -- cgit v1.2.3 From 8821a7104f8c8f263b88def1a646d518ec3f5dd2 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Mon, 2 Mar 2020 13:14:12 -0800 Subject: Do not read-lock NIC recursively A deadlock may occur if a write lock on a RWMutex is blocked between nested read lock attempts as the inner read lock attempt will be blocked in this scenario. Example (T1 and T2 are differnt goroutines): T1: obtain read-lock T2: attempt write-lock (blocks) T1: attempt inner/nested read-lock (blocks) Here we can see that T1 and T2 are deadlocked. Tests: Existing tests pass. PiperOrigin-RevId: 298426678 --- pkg/tcpip/stack/nic.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 46d3a6646..3e6196aee 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -451,7 +451,7 @@ func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEn cs := make([]ipv6AddrCandidate, 0, len(primaryAddrs)) for _, r := range primaryAddrs { // If r is not valid for outgoing connections, it is not a valid endpoint. - if !r.isValidForOutgoing() { + if !r.isValidForOutgoingRLocked() { continue } -- cgit v1.2.3 From c15b8515eb4a07699e5f2401f0332286f0a51043 Mon Sep 17 00:00:00 2001 From: Ian Gudger Date: Tue, 3 Mar 2020 13:40:59 -0800 Subject: Fix datarace on TransportEndpointInfo.ID and clean up semantics. Ensures that all access to TransportEndpointInfo.ID is either: * In a function ending in a Locked suffix. * While holding the appropriate mutex. This primary affects the checkV4Mapped method on affected endpoints, which has been renamed to checkV4MappedLocked. Also document the method and change its argument to be a value instead of a pointer which had caused some awkwardness. This race was possible in the udp and icmp endpoints between Connect and uses of TransportEndpointInfo.ID including in both itself and Bind. The tcp endpoint did not suffer from this bug, but benefited from better documentation. Updates #357 PiperOrigin-RevId: 298682913 --- pkg/tcpip/stack/stack.go | 12 +++++++----- pkg/tcpip/transport/icmp/endpoint.go | 23 +++++++++++------------ pkg/tcpip/transport/tcp/endpoint.go | 15 ++++++++------- pkg/tcpip/transport/udp/endpoint.go | 30 ++++++++++++++++-------------- pkg/tcpip/transport/udp/endpoint_state.go | 3 +++ 5 files changed, 45 insertions(+), 38 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index ebb6c5e3b..13354d884 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -551,11 +551,13 @@ type TransportEndpointInfo struct { RegisterNICID tcpip.NICID } -// AddrNetProto unwraps the specified address if it is a V4-mapped V6 address -// and returns the network protocol number to be used to communicate with the -// specified address. It returns an error if the passed address is incompatible -// with the receiver. -func (e *TransportEndpointInfo) AddrNetProto(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { +// AddrNetProtoLocked unwraps the specified address if it is a V4-mapped V6 +// address and returns the network protocol number to be used to communicate +// with the specified address. It returns an error if the passed address is +// incompatible with the receiver. +// +// Preconditon: the parent endpoint mu must be held while calling this method. +func (e *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { netProto := e.NetProto switch len(addr.Addr) { case header.IPv4AddressSize: diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 426da1ee6..2a396e9bc 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -291,15 +291,13 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c nicID = e.BindNICID } - toCopy := *to - to = &toCopy - netProto, err := e.checkV4Mapped(to) + dst, netProto, err := e.checkV4MappedLocked(*to) if err != nil { return 0, nil, err } - // Find the enpoint. - r, err := e.stack.FindRoute(nicID, e.BindAddr, to.Addr, netProto, false /* multicastLoop */) + // Find the endpoint. + r, err := e.stack.FindRoute(nicID, e.BindAddr, dst.Addr, netProto, false /* multicastLoop */) if err != nil { return 0, nil, err } @@ -480,13 +478,14 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err }) } -func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) { - unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProto(*addr, false /* v6only */) +// checkV4MappedLocked determines the effective network protocol and converts +// addr to its canonical form. +func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { + unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, false /* v6only */) if err != nil { - return 0, err + return tcpip.FullAddress{}, 0, err } - *addr = unwrapped - return netProto, nil + return unwrapped, netProto, nil } // Disconnect implements tcpip.Endpoint.Disconnect. @@ -517,7 +516,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - netProto, err := e.checkV4Mapped(&addr) + addr, netProto, err := e.checkV4MappedLocked(addr) if err != nil { return err } @@ -630,7 +629,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - netProto, err := e.checkV4Mapped(&addr) + addr, netProto, err := e.checkV4MappedLocked(addr) if err != nil { return err } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 8b9154e69..40cc664c0 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1874,13 +1874,14 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } } -func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) { - unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProto(*addr, e.v6only) +// checkV4MappedLocked determines the effective network protocol and converts +// addr to its canonical form. +func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { + unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only) if err != nil { - return 0, err + return tcpip.FullAddress{}, 0, err } - *addr = unwrapped - return netProto, nil + return unwrapped, netProto, nil } // Disconnect implements tcpip.Endpoint.Disconnect. @@ -1910,7 +1911,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc connectingAddr := addr.Addr - netProto, err := e.checkV4Mapped(&addr) + addr, netProto, err := e.checkV4MappedLocked(addr) if err != nil { return err } @@ -2276,7 +2277,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { } e.BindAddr = addr.Addr - netProto, err := e.checkV4Mapped(&addr) + addr, netProto, err := e.checkV4MappedLocked(addr) if err != nil { return err } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 1c6a600b8..0af4514e1 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -443,19 +443,19 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return 0, nil, tcpip.ErrBroadcastDisabled } - netProto, err := e.checkV4Mapped(to) + dst, netProto, err := e.checkV4MappedLocked(*to) if err != nil { return 0, nil, err } - r, _, err := e.connectRoute(nicID, *to, netProto) + r, _, err := e.connectRoute(nicID, dst, netProto) if err != nil { return 0, nil, err } defer r.Release() route = &r - dstPort = to.Port + dstPort = dst.Port } if route.IsResolutionRequired() { @@ -566,7 +566,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { defer e.mu.Unlock() fa := tcpip.FullAddress{Addr: v.InterfaceAddr} - netProto, err := e.checkV4Mapped(&fa) + fa, netProto, err := e.checkV4MappedLocked(fa) if err != nil { return err } @@ -927,13 +927,14 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u return nil } -func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) { - unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProto(*addr, e.v6only) +// checkV4MappedLocked determines the effective network protocol and converts +// addr to its canonical form. +func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { + unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only) if err != nil { - return 0, err + return tcpip.FullAddress{}, 0, err } - *addr = unwrapped - return netProto, nil + return unwrapped, netProto, nil } // Disconnect implements tcpip.Endpoint.Disconnect. @@ -981,10 +982,6 @@ func (e *endpoint) Disconnect() *tcpip.Error { // Connect connects the endpoint to its peer. Specifying a NIC is optional. func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { - netProto, err := e.checkV4Mapped(&addr) - if err != nil { - return err - } if addr.Port == 0 { // We don't support connecting to port zero. return tcpip.ErrInvalidEndpointState @@ -1012,6 +1009,11 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrInvalidEndpointState } + addr, netProto, err := e.checkV4MappedLocked(addr) + if err != nil { + return err + } + r, nicID, err := e.connectRoute(nicID, addr, netProto) if err != nil { return err @@ -1139,7 +1141,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - netProto, err := e.checkV4Mapped(&addr) + addr, netProto, err := e.checkV4MappedLocked(addr) if err != nil { return err } diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 43fb047ed..466bd9381 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -69,6 +69,9 @@ func (e *endpoint) afterLoad() { // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { + e.mu.Lock() + defer e.mu.Unlock() + e.stack = s for _, m := range e.multicastMemberships { -- cgit v1.2.3 From d6f5e71df2c8ff3d763cba703786af68af1f9841 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Fri, 6 Mar 2020 08:01:45 -0800 Subject: Get strings for stack.DHCPv6ConfigurationFromNDPRA Useful for logs to print the string representation of the value instead of the integer value. PiperOrigin-RevId: 299356847 --- pkg/tcpip/stack/BUILD | 1 + .../stack/dhcpv6configurationfromndpra_string.go | 39 ++++++++++++++++++++++ 2 files changed, 40 insertions(+) create mode 100644 pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 705cf01ee..8febd54c8 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -18,6 +18,7 @@ go_template_instance( go_library( name = "stack", srcs = [ + "dhcpv6configurationfromndpra_string.go", "icmp_rate_limit.go", "linkaddrcache.go", "linkaddrentry_list.go", diff --git a/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go b/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go new file mode 100644 index 000000000..8b4213eec --- /dev/null +++ b/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go @@ -0,0 +1,39 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by "stringer -type=DHCPv6ConfigurationFromNDPRA"; DO NOT EDIT. + +package stack + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[DHCPv6NoConfiguration-0] + _ = x[DHCPv6ManagedAddress-1] + _ = x[DHCPv6OtherConfigurations-2] +} + +const _DHCPv6ConfigurationFromNDPRA_name = "DHCPv6NoConfigurationDHCPv6ManagedAddressDHCPv6OtherConfigurations" + +var _DHCPv6ConfigurationFromNDPRA_index = [...]uint8{0, 21, 41, 66} + +func (i DHCPv6ConfigurationFromNDPRA) String() string { + if i < 0 || i >= DHCPv6ConfigurationFromNDPRA(len(_DHCPv6ConfigurationFromNDPRA_index)-1) { + return "DHCPv6ConfigurationFromNDPRA(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _DHCPv6ConfigurationFromNDPRA_name[_DHCPv6ConfigurationFromNDPRA_index[i]:_DHCPv6ConfigurationFromNDPRA_index[i+1]] +} -- cgit v1.2.3 From f50d9a31e9e734a02e0191f6bc91b387bb21f9ab Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Fri, 6 Mar 2020 11:32:13 -0800 Subject: Specify the source of outgoing NDP RS If the NIC has a valid IPv6 address assigned, use it as the source address for outgoing NDP Router Solicitation packets. Test: stack_test.TestRouterSolicitation PiperOrigin-RevId: 299398763 --- pkg/tcpip/checker/checker.go | 106 +++++++++++++++++++++++++++---------------- pkg/tcpip/stack/ndp.go | 29 ++++++++++-- pkg/tcpip/stack/ndp_test.go | 39 +++++++++++++--- 3 files changed, 125 insertions(+), 49 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index c6c160dfc..8dc0f7c0e 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -785,6 +785,52 @@ func NDPNSTargetAddress(want tcpip.Address) TransportChecker { } } +// ndpOptions checks that optsBuf only contains opts. +func ndpOptions(t *testing.T, optsBuf header.NDPOptions, opts []header.NDPOption) { + t.Helper() + + it, err := optsBuf.Iter(true) + if err != nil { + t.Errorf("optsBuf.Iter(true): %s", err) + return + } + + i := 0 + for { + opt, done, err := it.Next() + if err != nil { + // This should never happen as Iter(true) above did not return an error. + t.Fatalf("unexpected error when iterating over NDP options: %s", err) + } + if done { + break + } + + if i >= len(opts) { + t.Errorf("got unexpected option: %s", opt) + continue + } + + switch wantOpt := opts[i].(type) { + case header.NDPSourceLinkLayerAddressOption: + gotOpt, ok := opt.(header.NDPSourceLinkLayerAddressOption) + if !ok { + t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt) + } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want { + t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want) + } + default: + t.Fatalf("checker not implemented for expected NDP option: %T", wantOpt) + } + + i++ + } + + if missing := opts[i:]; len(missing) > 0 { + t.Errorf("missing options: %s", missing) + } +} + // NDPNSOptions creates a checker that checks that the packet contains the // provided NDP options within an NDP Neighbor Solicitation message. // @@ -796,47 +842,31 @@ func NDPNSOptions(opts []header.NDPOption) TransportChecker { icmp := h.(header.ICMPv6) ns := header.NDPNeighborSolicit(icmp.NDPPayload()) - it, err := ns.Options().Iter(true) - if err != nil { - t.Errorf("opts.Iter(true): %s", err) - return - } - - i := 0 - for { - opt, done, _ := it.Next() - if done { - break - } - - if i >= len(opts) { - t.Errorf("got unexpected option: %s", opt) - continue - } - - switch wantOpt := opts[i].(type) { - case header.NDPSourceLinkLayerAddressOption: - gotOpt, ok := opt.(header.NDPSourceLinkLayerAddressOption) - if !ok { - t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt) - } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want { - t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want) - } - default: - panic("not implemented") - } - - i++ - } - - if missing := opts[i:]; len(missing) > 0 { - t.Errorf("missing options: %s", missing) - } + ndpOptions(t, ns.Options(), opts) } } // NDPRS creates a checker that checks that the packet contains a valid NDP // Router Solicitation message (as per the raw wire format). -func NDPRS() NetworkChecker { - return NDP(header.ICMPv6RouterSolicit, header.NDPRSMinimumSize) +// +// checkers may assume that a valid ICMPv6 is passed to it containing a valid +// NDPRS as far as the size of the message is concerned. The values within the +// message are up to checkers to validate. +func NDPRS(checkers ...TransportChecker) NetworkChecker { + return NDP(header.ICMPv6RouterSolicit, header.NDPRSMinimumSize, checkers...) +} + +// NDPRSOptions creates a checker that checks that the packet contains the +// provided NDP options within an NDP Router Solicitation message. +// +// The returned TransportChecker assumes that a valid ICMPv6 is passed to it +// containing a valid NDPRS message as far as the size is concerned. +func NDPRSOptions(opts []header.NDPOption) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmp := h.(header.ICMPv6) + rs := header.NDPRouterSolicit(icmp.NDPPayload()) + ndpOptions(t, rs.Options(), opts) + } } diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index f651871ce..a9f4d5dad 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -1220,9 +1220,15 @@ func (ndp *ndpState) startSolicitingRouters() { } ndp.rtrSolicitTimer = time.AfterFunc(delay, func() { - // Send an RS message with the unspecified source address. - ref := ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, forceSpoofing) - r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false) + // As per RFC 4861 section 4.1, the source of the RS is an address assigned + // to the sending interface, or the unspecified address if no address is + // assigned to the sending interface. + ref := ndp.nic.primaryIPv6Endpoint(header.IPv6AllRoutersMulticastAddress) + if ref == nil { + ref = ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, forceSpoofing) + } + localAddr := ref.ep.ID().LocalAddress + r := makeRoute(header.IPv6ProtocolNumber, localAddr, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false) defer r.Release() // Route should resolve immediately since @@ -1234,10 +1240,25 @@ func (ndp *ndpState) startSolicitingRouters() { log.Fatalf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID()) } - payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + // As per RFC 4861 section 4.1, an NDP RS SHOULD include the source + // link-layer address option if the source address of the NDP RS is + // specified. This option MUST NOT be included if the source address is + // unspecified. + // + // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by + // LinkEndpoint.LinkAddress) before reaching this point. + var optsSerializer header.NDPOptionsSerializer + if localAddr != header.IPv6Any && header.IsValidUnicastEthernetAddress(r.LocalLinkAddress) { + optsSerializer = header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(r.LocalLinkAddress), + } + } + payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + int(optsSerializer.Length()) hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + payloadSize) pkt := header.ICMPv6(hdr.Prepend(payloadSize)) pkt.SetType(header.ICMPv6RouterSolicit) + rs := header.NDPRouterSolicit(pkt.NDPPayload()) + rs.Options().Serialize(optsSerializer) pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) sent := r.Stats().ICMP.V6PacketsSent diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 6e9306d09..98b1c807c 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -3384,6 +3384,10 @@ func TestRouterSolicitation(t *testing.T) { tests := []struct { name string linkHeaderLen uint16 + linkAddr tcpip.LinkAddress + nicAddr tcpip.Address + expectedSrcAddr tcpip.Address + expectedNDPOpts []header.NDPOption maxRtrSolicit uint8 rtrSolicitInt time.Duration effectiveRtrSolicitInt time.Duration @@ -3392,6 +3396,7 @@ func TestRouterSolicitation(t *testing.T) { }{ { name: "Single RS with delay", + expectedSrcAddr: header.IPv6Any, maxRtrSolicit: 1, rtrSolicitInt: time.Second, effectiveRtrSolicitInt: time.Second, @@ -3401,6 +3406,8 @@ func TestRouterSolicitation(t *testing.T) { { name: "Two RS with delay", linkHeaderLen: 1, + nicAddr: llAddr1, + expectedSrcAddr: llAddr1, maxRtrSolicit: 2, rtrSolicitInt: time.Second, effectiveRtrSolicitInt: time.Second, @@ -3408,8 +3415,14 @@ func TestRouterSolicitation(t *testing.T) { effectiveMaxRtrSolicitDelay: 500 * time.Millisecond, }, { - name: "Single RS without delay", - linkHeaderLen: 2, + name: "Single RS without delay", + linkHeaderLen: 2, + linkAddr: linkAddr1, + nicAddr: llAddr1, + expectedSrcAddr: llAddr1, + expectedNDPOpts: []header.NDPOption{ + header.NDPSourceLinkLayerAddressOption(linkAddr1), + }, maxRtrSolicit: 1, rtrSolicitInt: time.Second, effectiveRtrSolicitInt: time.Second, @@ -3419,6 +3432,8 @@ func TestRouterSolicitation(t *testing.T) { { name: "Two RS without delay and invalid zero interval", linkHeaderLen: 3, + linkAddr: linkAddr1, + expectedSrcAddr: header.IPv6Any, maxRtrSolicit: 2, rtrSolicitInt: 0, effectiveRtrSolicitInt: 4 * time.Second, @@ -3427,6 +3442,8 @@ func TestRouterSolicitation(t *testing.T) { }, { name: "Three RS without delay", + linkAddr: linkAddr1, + expectedSrcAddr: header.IPv6Any, maxRtrSolicit: 3, rtrSolicitInt: 500 * time.Millisecond, effectiveRtrSolicitInt: 500 * time.Millisecond, @@ -3435,6 +3452,8 @@ func TestRouterSolicitation(t *testing.T) { }, { name: "Two RS with invalid negative delay", + linkAddr: linkAddr1, + expectedSrcAddr: header.IPv6Any, maxRtrSolicit: 2, rtrSolicitInt: time.Second, effectiveRtrSolicitInt: time.Second, @@ -3457,7 +3476,7 @@ func TestRouterSolicitation(t *testing.T) { t.Run(test.name, func(t *testing.T) { t.Parallel() e := channelLinkWithHeaderLength{ - Endpoint: channel.New(int(test.maxRtrSolicit), 1280, linkAddr1), + Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr), headerLength: test.linkHeaderLen, } e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired @@ -3481,10 +3500,10 @@ func TestRouterSolicitation(t *testing.T) { checker.IPv6(t, p.Pkt.Header.View(), - checker.SrcAddr(header.IPv6Any), + checker.SrcAddr(test.expectedSrcAddr), checker.DstAddr(header.IPv6AllRoutersMulticastAddress), checker.TTL(header.NDPHopLimit), - checker.NDPRS(), + checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)), ) if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want { @@ -3510,13 +3529,19 @@ func TestRouterSolicitation(t *testing.T) { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - // Make sure each RS got sent at the right - // times. + if addr := test.nicAddr; addr != "" { + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err) + } + } + + // Make sure each RS is sent at the right time. remaining := test.maxRtrSolicit if remaining > 0 { waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultAsyncEventTimeout) remaining-- } + for ; remaining > 0; remaining-- { waitForNothing(test.effectiveRtrSolicitInt - defaultTimeout) waitForPkt(defaultAsyncEventTimeout) -- cgit v1.2.3 From d6440ec5a125746b76f189c0a5d5946dde9afc37 Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Tue, 10 Mar 2020 14:49:16 -0700 Subject: The packet forwarding should resolve the link address if necessary. Fixes #1510 Test: - stack_test.TestForwardingWithStaticResolver - stack_test.TestForwardingWithFakeResolver - stack_test.TestForwardingWithNoResolver - stack_test.TestForwardingWithFakeResolverPartialTimeout - stack_test.TestForwardingWithFakeResolverTwoPackets - stack_test.TestForwardingWithFakeResolverManyPackets - stack_test.TestForwardingWithFakeResolverManyResolutions PiperOrigin-RevId: 300182570 --- pkg/tcpip/stack/BUILD | 2 + pkg/tcpip/stack/forwarder.go | 131 ++++++++ pkg/tcpip/stack/forwarder_test.go | 635 ++++++++++++++++++++++++++++++++++++++ pkg/tcpip/stack/nic.go | 51 ++- pkg/tcpip/stack/stack.go | 5 + 5 files changed, 808 insertions(+), 16 deletions(-) create mode 100644 pkg/tcpip/stack/forwarder.go create mode 100644 pkg/tcpip/stack/forwarder_test.go (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 8febd54c8..6c029b2fb 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -19,6 +19,7 @@ go_library( name = "stack", srcs = [ "dhcpv6configurationfromndpra_string.go", + "forwarder.go", "icmp_rate_limit.go", "linkaddrcache.go", "linkaddrentry_list.go", @@ -80,6 +81,7 @@ go_test( name = "stack_test", size = "small", srcs = [ + "forwarder_test.go", "linkaddrcache_test.go", "nic_test.go", ], diff --git a/pkg/tcpip/stack/forwarder.go b/pkg/tcpip/stack/forwarder.go new file mode 100644 index 000000000..631953935 --- /dev/null +++ b/pkg/tcpip/stack/forwarder.go @@ -0,0 +1,131 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" +) + +const ( + // maxPendingResolutions is the maximum number of pending link-address + // resolutions. + maxPendingResolutions = 64 + maxPendingPacketsPerResolution = 256 +) + +type pendingPacket struct { + nic *NIC + route *Route + proto tcpip.NetworkProtocolNumber + pkt tcpip.PacketBuffer +} + +type forwardQueue struct { + sync.Mutex + + // The packets to send once the resolver completes. + packets map[<-chan struct{}][]*pendingPacket + + // FIFO of channels used to cancel the oldest goroutine waiting for + // link-address resolution. + cancelChans []chan struct{} +} + +func newForwardQueue() *forwardQueue { + return &forwardQueue{packets: make(map[<-chan struct{}][]*pendingPacket)} +} + +func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { + shouldWait := false + + f.Lock() + packets, ok := f.packets[ch] + if !ok { + shouldWait = true + } + for len(packets) == maxPendingPacketsPerResolution { + p := packets[0] + packets = packets[1:] + p.nic.stack.stats.IP.OutgoingPacketErrors.Increment() + p.route.Release() + } + if l := len(packets); l >= maxPendingPacketsPerResolution { + panic(fmt.Sprintf("max pending packets for resolution reached; got %d packets, max = %d", l, maxPendingPacketsPerResolution)) + } + f.packets[ch] = append(packets, &pendingPacket{ + nic: n, + route: r, + proto: protocol, + pkt: pkt, + }) + f.Unlock() + + if !shouldWait { + return + } + + // Wait for the link-address resolution to complete. + // Start a goroutine with a forwarding-cancel channel so that we can + // limit the maximum number of goroutines running concurrently. + cancel := f.newCancelChannel() + go func() { + cancelled := false + select { + case <-ch: + case <-cancel: + cancelled = true + } + + f.Lock() + packets := f.packets[ch] + delete(f.packets, ch) + f.Unlock() + + for _, p := range packets { + if cancelled { + p.nic.stack.stats.IP.OutgoingPacketErrors.Increment() + } else if _, err := p.route.Resolve(nil); err != nil { + p.nic.stack.stats.IP.OutgoingPacketErrors.Increment() + } else { + p.nic.forwardPacket(p.route, p.proto, p.pkt) + } + p.route.Release() + } + }() +} + +// newCancelChannel creates a channel that can cancel a pending forwarding +// activity. The oldest channel is closed if the number of open channels would +// exceed maxPendingResolutions. +func (f *forwardQueue) newCancelChannel() chan struct{} { + f.Lock() + defer f.Unlock() + + if len(f.cancelChans) == maxPendingResolutions { + ch := f.cancelChans[0] + f.cancelChans = f.cancelChans[1:] + close(ch) + } + if l := len(f.cancelChans); l >= maxPendingResolutions { + panic(fmt.Sprintf("max pending resolutions reached; got %d active resolutions, max = %d", l, maxPendingResolutions)) + } + + ch := make(chan struct{}) + f.cancelChans = append(f.cancelChans, ch) + return ch +} diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go new file mode 100644 index 000000000..321b7524d --- /dev/null +++ b/pkg/tcpip/stack/forwarder_test.go @@ -0,0 +1,635 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "encoding/binary" + "math" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +const ( + fwdTestNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 + fwdTestNetHeaderLen = 12 + fwdTestNetDefaultPrefixLen = 8 + + // fwdTestNetDefaultMTU is the MTU, in bytes, used throughout the tests, + // except where another value is explicitly used. It is chosen to match + // the MTU of loopback interfaces on linux systems. + fwdTestNetDefaultMTU = 65536 +) + +// fwdTestNetworkEndpoint is a network-layer protocol endpoint. +// Headers of this protocol are fwdTestNetHeaderLen bytes, but we currently only +// use the first three: destination address, source address, and transport +// protocol. They're all one byte fields to simplify parsing. +type fwdTestNetworkEndpoint struct { + nicID tcpip.NICID + id NetworkEndpointID + prefixLen int + proto *fwdTestNetworkProtocol + dispatcher TransportDispatcher + ep LinkEndpoint +} + +func (f *fwdTestNetworkEndpoint) MTU() uint32 { + return f.ep.MTU() - uint32(f.MaxHeaderLength()) +} + +func (f *fwdTestNetworkEndpoint) NICID() tcpip.NICID { + return f.nicID +} + +func (f *fwdTestNetworkEndpoint) PrefixLen() int { + return f.prefixLen +} + +func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 { + return 123 +} + +func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID { + return &f.id +} + +func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt tcpip.PacketBuffer) { + // Consume the network header. + b := pkt.Data.First() + pkt.Data.TrimFront(fwdTestNetHeaderLen) + + // Dispatch the packet to the transport protocol. + f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), pkt) +} + +func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { + return f.ep.MaxHeaderLength() + fwdTestNetHeaderLen +} + +func (f *fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { + return 0 +} + +func (f *fwdTestNetworkEndpoint) Capabilities() LinkEndpointCapabilities { + return f.ep.Capabilities() +} + +func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error { + // Add the protocol's header to the packet and send it to the link + // endpoint. + b := pkt.Header.Prepend(fwdTestNetHeaderLen) + b[0] = r.RemoteAddress[0] + b[1] = f.id.LocalAddress[0] + b[2] = byte(params.Protocol) + + return f.ep.WritePacket(r, gso, fwdTestNetNumber, pkt) +} + +// WritePackets implements LinkEndpoint.WritePackets. +func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) { + panic("not implemented") +} + +func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt tcpip.PacketBuffer) *tcpip.Error { + return tcpip.ErrNotSupported +} + +func (*fwdTestNetworkEndpoint) Close() {} + +// fwdTestNetworkProtocol is a network-layer protocol that implements Address +// resolution. +type fwdTestNetworkProtocol struct { + addrCache *linkAddrCache + addrResolveDelay time.Duration + onLinkAddressResolved func(cache *linkAddrCache, addr tcpip.Address) + onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool) +} + +func (f *fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber { + return fwdTestNetNumber +} + +func (f *fwdTestNetworkProtocol) MinimumPacketSize() int { + return fwdTestNetHeaderLen +} + +func (f *fwdTestNetworkProtocol) DefaultPrefixLen() int { + return fwdTestNetDefaultPrefixLen +} + +func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { + return tcpip.Address(v[1:2]), tcpip.Address(v[0:1]) +} + +func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) (NetworkEndpoint, *tcpip.Error) { + return &fwdTestNetworkEndpoint{ + nicID: nicID, + id: NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, + prefixLen: addrWithPrefix.PrefixLen, + proto: f, + dispatcher: dispatcher, + ep: ep, + }, nil +} + +func (f *fwdTestNetworkProtocol) SetOption(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +func (f *fwdTestNetworkProtocol) Option(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +func (f *fwdTestNetworkProtocol) Close() {} + +func (f *fwdTestNetworkProtocol) Wait() {} + +func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP LinkEndpoint) *tcpip.Error { + if f.addrCache != nil && f.onLinkAddressResolved != nil { + time.AfterFunc(f.addrResolveDelay, func() { + f.onLinkAddressResolved(f.addrCache, addr) + }) + } + return nil +} + +func (f *fwdTestNetworkProtocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { + if f.onResolveStaticAddress != nil { + return f.onResolveStaticAddress(addr) + } + return "", false +} + +func (f *fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { + return fwdTestNetNumber +} + +// fwdTestPacketInfo holds all the information about an outbound packet. +type fwdTestPacketInfo struct { + RemoteLinkAddress tcpip.LinkAddress + LocalLinkAddress tcpip.LinkAddress + Pkt tcpip.PacketBuffer +} + +type fwdTestLinkEndpoint struct { + dispatcher NetworkDispatcher + mtu uint32 + linkAddr tcpip.LinkAddress + + // C is where outbound packets are queued. + C chan fwdTestPacketInfo +} + +// InjectInbound injects an inbound packet. +func (e *fwdTestLinkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { + e.InjectLinkAddr(protocol, "", pkt) +} + +// InjectLinkAddr injects an inbound packet with a remote link address. +func (e *fwdTestLinkEndpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt tcpip.PacketBuffer) { + e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, pkt) +} + +// Attach saves the stack network-layer dispatcher for use later when packets +// are injected. +func (e *fwdTestLinkEndpoint) Attach(dispatcher NetworkDispatcher) { + e.dispatcher = dispatcher +} + +// IsAttached implements stack.LinkEndpoint.IsAttached. +func (e *fwdTestLinkEndpoint) IsAttached() bool { + return e.dispatcher != nil +} + +// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized +// during construction. +func (e *fwdTestLinkEndpoint) MTU() uint32 { + return e.mtu +} + +// Capabilities implements stack.LinkEndpoint.Capabilities. +func (e fwdTestLinkEndpoint) Capabilities() LinkEndpointCapabilities { + caps := LinkEndpointCapabilities(0) + return caps | CapabilityResolutionRequired +} + +// GSOMaxSize returns the maximum GSO packet size. +func (*fwdTestLinkEndpoint) GSOMaxSize() uint32 { + return 1 << 15 +} + +// MaxHeaderLength returns the maximum size of the link layer header. Given it +// doesn't have a header, it just returns 0. +func (*fwdTestLinkEndpoint) MaxHeaderLength() uint16 { + return 0 +} + +// LinkAddress returns the link address of this endpoint. +func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress { + return e.linkAddr +} + +func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { + p := fwdTestPacketInfo{ + RemoteLinkAddress: r.RemoteLinkAddress, + LocalLinkAddress: r.LocalLinkAddress, + Pkt: pkt, + } + + select { + case e.C <- p: + default: + } + + return nil +} + +// WritePackets stores outbound packets into the channel. +func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + n := 0 + for _, pkt := range pkts { + e.WritePacket(r, gso, protocol, pkt) + n++ + } + + return n, nil +} + +// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. +func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + p := fwdTestPacketInfo{ + Pkt: tcpip.PacketBuffer{Data: vv}, + } + + select { + case e.C <- p: + default: + } + + return nil +} + +// Wait implements stack.LinkEndpoint.Wait. +func (*fwdTestLinkEndpoint) Wait() {} + +func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) { + // Create a stack with the network protocol and two NICs. + s := New(Options{ + NetworkProtocols: []NetworkProtocol{proto}, + }) + + proto.addrCache = s.linkAddrCache + + // Enable forwarding. + s.SetForwarding(true) + + // NIC 1 has the link address "a", and added the network address 1. + ep1 = &fwdTestLinkEndpoint{ + C: make(chan fwdTestPacketInfo, 300), + mtu: fwdTestNetDefaultMTU, + linkAddr: "a", + } + if err := s.CreateNIC(1, ep1); err != nil { + t.Fatal("CreateNIC #1 failed:", err) + } + if err := s.AddAddress(1, fwdTestNetNumber, "\x01"); err != nil { + t.Fatal("AddAddress #1 failed:", err) + } + + // NIC 2 has the link address "b", and added the network address 2. + ep2 = &fwdTestLinkEndpoint{ + C: make(chan fwdTestPacketInfo, 300), + mtu: fwdTestNetDefaultMTU, + linkAddr: "b", + } + if err := s.CreateNIC(2, ep2); err != nil { + t.Fatal("CreateNIC #2 failed:", err) + } + if err := s.AddAddress(2, fwdTestNetNumber, "\x02"); err != nil { + t.Fatal("AddAddress #2 failed:", err) + } + + // Route all packets to NIC 2. + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: 2}}) + } + + return ep1, ep2 +} + +func TestForwardingWithStaticResolver(t *testing.T) { + // Create a network protocol with a static resolver. + proto := &fwdTestNetworkProtocol{ + onResolveStaticAddress: + // The network address 3 is resolved to the link address "c". + func(addr tcpip.Address) (tcpip.LinkAddress, bool) { + if addr == "\x03" { + return "c", true + } + return "", false + }, + } + + ep1, ep2 := fwdTestNetFactory(t, proto) + + // Inject an inbound packet to address 3 on NIC 1, and see if it is + // forwarded to NIC 2. + buf := buffer.NewView(30) + buf[0] = 3 + ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + default: + t.Fatal("packet not forwarded") + } + + // Test that the static address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } +} + +func TestForwardingWithFakeResolver(t *testing.T) { + // Create a network protocol with a fake resolver. + proto := &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { + // Any address will be resolved to the link address "c". + cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + }, + } + + ep1, ep2 := fwdTestNetFactory(t, proto) + + // Inject an inbound packet to address 3 on NIC 1, and see if it is + // forwarded to NIC 2. + buf := buffer.NewView(30) + buf[0] = 3 + ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } +} + +func TestForwardingWithNoResolver(t *testing.T) { + // Create a network protocol without a resolver. + proto := &fwdTestNetworkProtocol{} + + ep1, ep2 := fwdTestNetFactory(t, proto) + + // inject an inbound packet to address 3 on NIC 1, and see if it is + // forwarded to NIC 2. + buf := buffer.NewView(30) + buf[0] = 3 + ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + select { + case <-ep2.C: + t.Fatal("Packet should not be forwarded") + case <-time.After(time.Second): + } +} + +func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { + // Create a network protocol with a fake resolver. + proto := &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { + // Only packets to address 3 will be resolved to the + // link address "c". + if addr == "\x03" { + cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + } + }, + } + + ep1, ep2 := fwdTestNetFactory(t, proto) + + // Inject an inbound packet to address 4 on NIC 1. This packet should + // not be forwarded. + buf := buffer.NewView(30) + buf[0] = 4 + ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + // Inject an inbound packet to address 3 on NIC 1, and see if it is + // forwarded to NIC 2. + buf = buffer.NewView(30) + buf[0] = 3 + ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + b := p.Pkt.Header.View() + if b[0] != 3 { + t.Fatalf("got b[0] = %d, want = 3", b[0]) + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } +} + +func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { + // Create a network protocol with a fake resolver. + proto := &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { + // Any packets will be resolved to the link address "c". + cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + }, + } + + ep1, ep2 := fwdTestNetFactory(t, proto) + + // Inject two inbound packets to address 3 on NIC 1. + for i := 0; i < 2; i++ { + buf := buffer.NewView(30) + buf[0] = 3 + ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + } + + for i := 0; i < 2; i++ { + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + b := p.Pkt.Header.View() + if b[0] != 3 { + t.Fatalf("got b[0] = %d, want = 3", b[0]) + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } + } +} + +func TestForwardingWithFakeResolverManyPackets(t *testing.T) { + // Create a network protocol with a fake resolver. + proto := &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { + // Any packets will be resolved to the link address "c". + cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + }, + } + + ep1, ep2 := fwdTestNetFactory(t, proto) + + for i := 0; i < maxPendingPacketsPerResolution+5; i++ { + // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1. + buf := buffer.NewView(30) + buf[0] = 3 + // Set the packet sequence number. + binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i)) + ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + } + + for i := 0; i < maxPendingPacketsPerResolution; i++ { + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + b := p.Pkt.Header.View() + if b[0] != 3 { + t.Fatalf("got b[0] = %d, want = 3", b[0]) + } + // The first 5 packets should not be forwarded so the the + // sequemnce number should start with 5. + want := uint16(i + 5) + if n := binary.BigEndian.Uint16(b[fwdTestNetHeaderLen:]); n != want { + t.Fatalf("got the packet #%d, want = #%d", n, want) + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } + } +} + +func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { + // Create a network protocol with a fake resolver. + proto := &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { + // Any packets will be resolved to the link address "c". + cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + }, + } + + ep1, ep2 := fwdTestNetFactory(t, proto) + + for i := 0; i < maxPendingResolutions+5; i++ { + // Inject inbound 'maxPendingResolutions + 5' packets on NIC 1. + // Each packet has a different destination address (3 to + // maxPendingResolutions + 7). + buf := buffer.NewView(30) + buf[0] = byte(3 + i) + ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + } + + for i := 0; i < maxPendingResolutions; i++ { + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + // The first 5 packets (address 3 to 7) should not be forwarded + // because their address resolutions are interrupted. + b := p.Pkt.Header.View() + if b[0] < 8 { + t.Fatalf("got b[0] = %d, want b[0] >= 8", b[0]) + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } + } +} diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 3e6196aee..cd9202aed 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1201,10 +1201,6 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() return } - defer r.Release() - - r.LocalLinkAddress = n.linkEP.LinkAddress() - r.RemoteLinkAddress = remote // Found a NIC. n := r.ref.nic @@ -1213,24 +1209,33 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link ok = ok && ref.isValidForOutgoingRLocked() && ref.tryIncRef() n.mu.RUnlock() if ok { + r.LocalLinkAddress = n.linkEP.LinkAddress() + r.RemoteLinkAddress = remote r.RemoteAddress = src // TODO(b/123449044): Update the source NIC as well. ref.ep.HandlePacket(&r, pkt) ref.decRef() - } else { - // n doesn't have a destination endpoint. - // Send the packet out of n. - pkt.Header = buffer.NewPrependableFromView(pkt.Data.First()) - pkt.Data.RemoveFirst() - - // TODO(b/128629022): use route.WritePacket. - if err := n.linkEP.WritePacket(&r, nil /* gso */, protocol, pkt); err != nil { - r.Stats().IP.OutgoingPacketErrors.Increment() - } else { - n.stats.Tx.Packets.Increment() - n.stats.Tx.Bytes.IncrementBy(uint64(pkt.Header.UsedLength() + pkt.Data.Size())) + r.Release() + return + } + + // n doesn't have a destination endpoint. + // Send the packet out of n. + // TODO(b/128629022): move this logic to route.WritePacket. + if ch, err := r.Resolve(nil); err != nil { + if err == tcpip.ErrWouldBlock { + n.stack.forwarder.enqueue(ch, n, &r, protocol, pkt) + // forwarder will release route. + return } + n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() + r.Release() + return } + + // The link-address resolution finished immediately. + n.forwardPacket(&r, protocol, pkt) + r.Release() return } @@ -1240,6 +1245,20 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link } } +func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { + // TODO(b/143425874) Decrease the TTL field in forwarded packets. + pkt.Header = buffer.NewPrependableFromView(pkt.Data.First()) + pkt.Data.RemoveFirst() + + if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil { + r.Stats().IP.OutgoingPacketErrors.Increment() + return + } + + n.stats.Tx.Packets.Increment() + n.stats.Tx.Bytes.IncrementBy(uint64(pkt.Header.UsedLength() + pkt.Data.Size())) +} + // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) { diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 13354d884..6f423874a 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -462,6 +462,10 @@ type Stack struct { // opaqueIIDOpts hold the options for generating opaque interface identifiers // (IIDs) as outlined by RFC 7217. opaqueIIDOpts OpaqueInterfaceIdentifierOptions + + // forwarder holds the packets that wait for their link-address resolutions + // to complete, and forwards them when each resolution is done. + forwarder *forwardQueue } // UniqueID is an abstract generator of unique identifiers. @@ -641,6 +645,7 @@ func New(opts Options) *Stack { uniqueIDGenerator: opts.UniqueID, ndpDisp: opts.NDPDisp, opaqueIIDOpts: opts.OpaqueIIDOpts, + forwarder: newForwardQueue(), } // Add specified network protocols. -- cgit v1.2.3 From f56fe66b13b979f2ac96e8fce6fb0a5dec9a32e0 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Tue, 10 Mar 2020 17:50:47 -0700 Subject: Honour the link's MaxHeaderLength when forwarding This change also updates where the IP packet buffer is held in an outbound tcpip.PacketBuffer from Header to Data. This change removes unncessary copying of the IP packet buffer when forwarding. Test: stack_test.TestNICForwarding PiperOrigin-RevId: 300217972 --- pkg/tcpip/packet_buffer.go | 8 ++- pkg/tcpip/stack/forwarder_test.go | 8 +-- pkg/tcpip/stack/nic.go | 6 +- pkg/tcpip/stack/stack_test.go | 112 ++++++++++++++++++++++++-------------- pkg/tcpip/stack/transport_test.go | 4 +- 5 files changed, 85 insertions(+), 53 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/packet_buffer.go b/pkg/tcpip/packet_buffer.go index ab24372e7..04852132c 100644 --- a/pkg/tcpip/packet_buffer.go +++ b/pkg/tcpip/packet_buffer.go @@ -39,8 +39,12 @@ type PacketBuffer struct { // payload. DataSize int - // Header holds the headers of outbound packets. As a packet is passed - // down the stack, each layer adds to Header. + // Header holds the headers of outbound packets generated by the netstack. As + // a packet is passed down the stack, each layer adds to Header. + // + // Note, if a packet is being forwarded at the IP layer, the headers for the + // IP layer and above (transport) will be held in Data as the packet was not + // passed down the stack it arrived at before being forwarded. Header buffer.Prependable // These fields are used by both inbound and outbound packets. They diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index 321b7524d..5a04590d5 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -473,7 +473,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Header.View() + b := p.Pkt.Data.First() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -517,7 +517,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Header.View() + b := p.Pkt.Data.First() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -564,7 +564,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Header.View() + b := p.Pkt.Data.First() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -619,7 +619,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // The first 5 packets (address 3 to 7) should not be forwarded // because their address resolutions are interrupted. - b := p.Pkt.Header.View() + b := p.Pkt.Data.First() if b[0] < 8 { t.Fatalf("got b[0] = %d, want b[0] >= 8", b[0]) } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index cd9202aed..e46bd86c6 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1246,10 +1246,10 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link } func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { - // TODO(b/143425874) Decrease the TTL field in forwarded packets. - pkt.Header = buffer.NewPrependableFromView(pkt.Data.First()) - pkt.Data.RemoveFirst() + // TODO(b/143425874): Decrease the TTL field in forwarded packets. + // pkt.Header should have enough capacity to hold the link's headers. + pkt.Header = buffer.NewPrependable(int(n.linkEP.MaxHeaderLength())) if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index e15db40fb..9515426d6 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -2240,56 +2240,84 @@ func TestNICStats(t *testing.T) { } func TestNICForwarding(t *testing.T) { - // Create a stack with the fake network protocol, two NICs, each with - // an address. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - s.SetForwarding(true) + const nicID1 = 1 + const nicID2 = 2 + const dstAddr = tcpip.Address("\x03") - ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatal("CreateNIC #1 failed:", err) - } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress #1 failed:", err) + tests := []struct { + name string + headerLen uint16 + }{ + { + name: "Zero header length", + }, + { + name: "Non-zero header length", + headerLen: 16, + }, } - ep2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, ep2); err != nil { - t.Fatal("CreateNIC #2 failed:", err) - } - if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress #2 failed:", err) - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + s.SetForwarding(true) - // Route all packets to address 3 to NIC 2. - { - subnet, err := tcpip.NewSubnet("\x03", "\xff") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 2}}) - } + ep1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicID1, ep1); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } + if err := s.AddAddress(nicID1, fakeNetNumber, "\x01"); err != nil { + t.Fatalf("AddAddress(%d, %d, 0x01): %s", nicID1, fakeNetNumber, err) + } - // Send a packet to address 3. - buf := buffer.NewView(30) - buf[0] = 3 - ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) + ep2 := channelLinkWithHeaderLength{ + Endpoint: channel.New(10, defaultMTU, ""), + headerLength: test.headerLen, + } + if err := s.CreateNIC(nicID2, &ep2); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) + } + if err := s.AddAddress(nicID2, fakeNetNumber, "\x02"); err != nil { + t.Fatalf("AddAddress(%d, %d, 0x02): %s", nicID2, fakeNetNumber, err) + } - if _, ok := ep2.Read(); !ok { - t.Fatal("Packet not forwarded") - } + // Route all packets to dstAddr to NIC 2. + { + subnet, err := tcpip.NewSubnet(dstAddr, "\xff") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicID2}}) + } - // Test that forwarding increments Tx stats correctly. - if got, want := s.NICInfo()[2].Stats.Tx.Packets.Value(), uint64(1); got != want { - t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want) - } + // Send a packet to dstAddr. + buf := buffer.NewView(30) + buf[0] = dstAddr[0] + ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) - if got, want := s.NICInfo()[2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want { - t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) + pkt, ok := ep2.Read() + if !ok { + t.Fatal("packet not forwarded") + } + + // Test that the link's MaxHeaderLength is honoured. + if capacity, want := pkt.Pkt.Header.AvailableLength(), int(test.headerLen); capacity != want { + t.Errorf("got Header.AvailableLength() = %d, want = %d", capacity, want) + } + + // Test that forwarding increments Tx stats correctly. + if got, want := s.NICInfo()[nicID2].Stats.Tx.Packets.Value(), uint64(1); got != want { + t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want) + } + + if got, want := s.NICInfo()[nicID2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want { + t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) + } + }) } } diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 5d1da2f8b..3609a25b6 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -641,10 +641,10 @@ func TestTransportForwarding(t *testing.T) { t.Fatal("Response packet not forwarded") } - if dst := p.Pkt.Header.View()[0]; dst != 3 { + if dst := p.Pkt.Data.First()[0]; dst != 3 { t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) } - if src := p.Pkt.Header.View()[1]; src != 1 { + if src := p.Pkt.Data.First()[1]; src != 1 { t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) } } -- cgit v1.2.3 From 7bca09107b4efc0a7f36f932612061f13a146d6f Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Wed, 11 Mar 2020 06:07:52 -0700 Subject: Automated rollback of changelist 300217972 PiperOrigin-RevId: 300308974 --- pkg/tcpip/packet_buffer.go | 8 +-- pkg/tcpip/stack/forwarder_test.go | 8 +-- pkg/tcpip/stack/nic.go | 6 +- pkg/tcpip/stack/stack_test.go | 112 ++++++++++++++------------------------ pkg/tcpip/stack/transport_test.go | 4 +- 5 files changed, 53 insertions(+), 85 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/packet_buffer.go b/pkg/tcpip/packet_buffer.go index 04852132c..ab24372e7 100644 --- a/pkg/tcpip/packet_buffer.go +++ b/pkg/tcpip/packet_buffer.go @@ -39,12 +39,8 @@ type PacketBuffer struct { // payload. DataSize int - // Header holds the headers of outbound packets generated by the netstack. As - // a packet is passed down the stack, each layer adds to Header. - // - // Note, if a packet is being forwarded at the IP layer, the headers for the - // IP layer and above (transport) will be held in Data as the packet was not - // passed down the stack it arrived at before being forwarded. + // Header holds the headers of outbound packets. As a packet is passed + // down the stack, each layer adds to Header. Header buffer.Prependable // These fields are used by both inbound and outbound packets. They diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index 5a04590d5..321b7524d 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -473,7 +473,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Data.First() + b := p.Pkt.Header.View() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -517,7 +517,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Data.First() + b := p.Pkt.Header.View() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -564,7 +564,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Data.First() + b := p.Pkt.Header.View() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -619,7 +619,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // The first 5 packets (address 3 to 7) should not be forwarded // because their address resolutions are interrupted. - b := p.Pkt.Data.First() + b := p.Pkt.Header.View() if b[0] < 8 { t.Fatalf("got b[0] = %d, want b[0] >= 8", b[0]) } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index e46bd86c6..cd9202aed 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1246,10 +1246,10 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link } func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { - // TODO(b/143425874): Decrease the TTL field in forwarded packets. + // TODO(b/143425874) Decrease the TTL field in forwarded packets. + pkt.Header = buffer.NewPrependableFromView(pkt.Data.First()) + pkt.Data.RemoveFirst() - // pkt.Header should have enough capacity to hold the link's headers. - pkt.Header = buffer.NewPrependable(int(n.linkEP.MaxHeaderLength())) if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 9515426d6..e15db40fb 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -2240,84 +2240,56 @@ func TestNICStats(t *testing.T) { } func TestNICForwarding(t *testing.T) { - const nicID1 = 1 - const nicID2 = 2 - const dstAddr = tcpip.Address("\x03") + // Create a stack with the fake network protocol, two NICs, each with + // an address. + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + s.SetForwarding(true) - tests := []struct { - name string - headerLen uint16 - }{ - { - name: "Zero header length", - }, - { - name: "Non-zero header length", - headerLen: 16, - }, + ep1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep1); err != nil { + t.Fatal("CreateNIC #1 failed:", err) + } + if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { + t.Fatal("AddAddress #1 failed:", err) } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - s.SetForwarding(true) - - ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID1, ep1); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) - } - if err := s.AddAddress(nicID1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress(%d, %d, 0x01): %s", nicID1, fakeNetNumber, err) - } - - ep2 := channelLinkWithHeaderLength{ - Endpoint: channel.New(10, defaultMTU, ""), - headerLength: test.headerLen, - } - if err := s.CreateNIC(nicID2, &ep2); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) - } - if err := s.AddAddress(nicID2, fakeNetNumber, "\x02"); err != nil { - t.Fatalf("AddAddress(%d, %d, 0x02): %s", nicID2, fakeNetNumber, err) - } - - // Route all packets to dstAddr to NIC 2. - { - subnet, err := tcpip.NewSubnet(dstAddr, "\xff") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicID2}}) - } + ep2 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(2, ep2); err != nil { + t.Fatal("CreateNIC #2 failed:", err) + } + if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { + t.Fatal("AddAddress #2 failed:", err) + } - // Send a packet to dstAddr. - buf := buffer.NewView(30) - buf[0] = dstAddr[0] - ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) + // Route all packets to address 3 to NIC 2. + { + subnet, err := tcpip.NewSubnet("\x03", "\xff") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 2}}) + } - pkt, ok := ep2.Read() - if !ok { - t.Fatal("packet not forwarded") - } + // Send a packet to address 3. + buf := buffer.NewView(30) + buf[0] = 3 + ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) - // Test that the link's MaxHeaderLength is honoured. - if capacity, want := pkt.Pkt.Header.AvailableLength(), int(test.headerLen); capacity != want { - t.Errorf("got Header.AvailableLength() = %d, want = %d", capacity, want) - } + if _, ok := ep2.Read(); !ok { + t.Fatal("Packet not forwarded") + } - // Test that forwarding increments Tx stats correctly. - if got, want := s.NICInfo()[nicID2].Stats.Tx.Packets.Value(), uint64(1); got != want { - t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want) - } + // Test that forwarding increments Tx stats correctly. + if got, want := s.NICInfo()[2].Stats.Tx.Packets.Value(), uint64(1); got != want { + t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want) + } - if got, want := s.NICInfo()[nicID2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want { - t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) - } - }) + if got, want := s.NICInfo()[2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want { + t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) } } diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 3609a25b6..5d1da2f8b 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -641,10 +641,10 @@ func TestTransportForwarding(t *testing.T) { t.Fatal("Response packet not forwarded") } - if dst := p.Pkt.Data.First()[0]; dst != 3 { + if dst := p.Pkt.Header.View()[0]; dst != 3 { t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) } - if src := p.Pkt.Data.First()[1]; src != 1 { + if src := p.Pkt.Header.View()[1]; src != 1 { t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) } } -- cgit v1.2.3 From 035f7434e978f3f246ae05e9c748e8ca7d8d7fd1 Mon Sep 17 00:00:00 2001 From: Tamir Duberstein Date: Wed, 11 Mar 2020 21:12:41 -0700 Subject: Use a heap in transport demuxer ...instead of sorting at various times. Plug a memory leak by setting removed elements to nil. PiperOrigin-RevId: 300471087 --- pkg/tcpip/stack/transport_demuxer.go | 155 +++++++++++++++--------------- pkg/tcpip/stack/transport_demuxer_test.go | 14 ++- 2 files changed, 89 insertions(+), 80 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 778c0a4d6..ff1845bfb 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -15,9 +15,9 @@ package stack import ( + "container/heap" "fmt" "math/rand" - "sort" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -141,16 +141,17 @@ func (epsByNic *endpointsByNic) registerEndpoint(d *transportDemuxer, netProto t epsByNic.mu.Lock() defer epsByNic.mu.Unlock() - if multiPortEp, ok := epsByNic.endpoints[bindToDevice]; ok { - // There was already a bind. - return multiPortEp.singleRegisterEndpoint(t, reusePort) + multiPortEp, ok := epsByNic.endpoints[bindToDevice] + if !ok { + multiPortEp = &multiPortEndpoint{ + demux: d, + netProto: netProto, + transProto: transProto, + reuse: reusePort, + } + epsByNic.endpoints[bindToDevice] = multiPortEp } - // This is a new binding. - multiPortEp := &multiPortEndpoint{demux: d, netProto: netProto, transProto: transProto} - multiPortEp.endpointsMap = make(map[TransportEndpoint]int) - multiPortEp.reuse = reusePort - epsByNic.endpoints[bindToDevice] = multiPortEp return multiPortEp.singleRegisterEndpoint(t, reusePort) } @@ -222,6 +223,35 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum return nil } +type transportEndpointHeap []TransportEndpoint + +var _ heap.Interface = (*transportEndpointHeap)(nil) + +func (h *transportEndpointHeap) Len() int { + return len(*h) +} + +func (h *transportEndpointHeap) Less(i, j int) bool { + return (*h)[i].UniqueID() < (*h)[j].UniqueID() +} + +func (h *transportEndpointHeap) Swap(i, j int) { + (*h)[i], (*h)[j] = (*h)[j], (*h)[i] +} + +func (h *transportEndpointHeap) Push(x interface{}) { + *h = append(*h, x.(TransportEndpoint)) +} + +func (h *transportEndpointHeap) Pop() interface{} { + old := *h + n := len(old) + x := old[n-1] + old[n-1] = nil + *h = old[:n-1] + return x +} + // multiPortEndpoint is a container for TransportEndpoints which are bound to // the same pair of address and port. endpointsArr always has at least one // element. @@ -237,15 +267,14 @@ type multiPortEndpoint struct { netProto tcpip.NetworkProtocolNumber transProto tcpip.TransportProtocolNumber - endpointsArr []TransportEndpoint - endpointsMap map[TransportEndpoint]int + endpoints transportEndpointHeap // reuse indicates if more than one endpoint is allowed. reuse bool } func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint { ep.mu.RLock() - eps := append([]TransportEndpoint(nil), ep.endpointsArr...) + eps := append([]TransportEndpoint(nil), ep.endpoints...) ep.mu.RUnlock() return eps } @@ -262,8 +291,8 @@ func reciprocalScale(val, n uint32) uint32 { // ports then uses it to select a socket. In this case, all packets from one // address will be sent to same endpoint. func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32) TransportEndpoint { - if len(mpep.endpointsArr) == 1 { - return mpep.endpointsArr[0] + if len(mpep.endpoints) == 1 { + return mpep.endpoints[0] } payload := []byte{ @@ -279,29 +308,26 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32 h.Write([]byte(id.RemoteAddress)) hash := h.Sum32() - idx := reciprocalScale(hash, uint32(len(mpep.endpointsArr))) - return mpep.endpointsArr[idx] + idx := reciprocalScale(hash, uint32(len(mpep.endpoints))) + return mpep.endpoints[idx] } func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) { ep.mu.RLock() queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}] - for i, endpoint := range ep.endpointsArr { - // HandlePacket takes ownership of pkt, so each endpoint needs - // its own copy except for the final one. - if i == len(ep.endpointsArr)-1 { - if mustQueue { - queuedProtocol.QueuePacket(r, endpoint, id, pkt) - break - } - endpoint.HandlePacket(r, id, pkt) - break - } + // HandlePacket takes ownership of pkt, so each endpoint needs + // its own copy except for the final one. + for _, endpoint := range ep.endpoints[:len(ep.endpoints)-1] { if mustQueue { queuedProtocol.QueuePacket(r, endpoint, id, pkt.Clone()) - continue + } else { + endpoint.HandlePacket(r, id, pkt.Clone()) } - endpoint.HandlePacket(r, id, pkt.Clone()) + } + if endpoint := ep.endpoints[len(ep.endpoints)-1]; mustQueue { + queuedProtocol.QueuePacket(r, endpoint, id, pkt) + } else { + endpoint.HandlePacket(r, id, pkt) } ep.mu.RUnlock() // Don't use defer for performance reasons. } @@ -312,26 +338,15 @@ func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePo ep.mu.Lock() defer ep.mu.Unlock() - if len(ep.endpointsArr) > 0 { + if len(ep.endpoints) != 0 { // If it was previously bound, we need to check if we can bind again. if !ep.reuse || !reusePort { return tcpip.ErrPortInUse } } - // A new endpoint is added into endpointsArr and its index there is saved in - // endpointsMap. This will allow us to remove endpoint from the array fast. - ep.endpointsMap[t] = len(ep.endpointsArr) - ep.endpointsArr = append(ep.endpointsArr, t) + heap.Push(&ep.endpoints, t) - // ep.endpointsArr is sorted by endpoint unique IDs, so that endpoints - // can be restored in the same order. - sort.Slice(ep.endpointsArr, func(i, j int) bool { - return ep.endpointsArr[i].UniqueID() < ep.endpointsArr[j].UniqueID() - }) - for i, e := range ep.endpointsArr { - ep.endpointsMap[e] = i - } return nil } @@ -340,21 +355,13 @@ func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint) bool { ep.mu.Lock() defer ep.mu.Unlock() - idx, ok := ep.endpointsMap[t] - if !ok { - return false - } - delete(ep.endpointsMap, t) - l := len(ep.endpointsArr) - if l > 1 { - // The last endpoint in endpointsArr is moved instead of the deleted one. - lastEp := ep.endpointsArr[l-1] - ep.endpointsArr[idx] = lastEp - ep.endpointsMap[lastEp] = idx - ep.endpointsArr = ep.endpointsArr[0 : l-1] - return false + for i, endpoint := range ep.endpoints { + if endpoint == t { + heap.Remove(&ep.endpoints, i) + break + } } - return true + return len(ep.endpoints) == 0 } func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { @@ -371,17 +378,14 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol eps.mu.Lock() defer eps.mu.Unlock() - if epsByNic, ok := eps.endpoints[id]; ok { - // There was already a binding. - return epsByNic.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice) - } - - // This is a new binding. - epsByNic := &endpointsByNic{ - endpoints: make(map[tcpip.NICID]*multiPortEndpoint), - seed: rand.Uint32(), + epsByNic, ok := eps.endpoints[id] + if !ok { + epsByNic = &endpointsByNic{ + endpoints: make(map[tcpip.NICID]*multiPortEndpoint), + seed: rand.Uint32(), + } + eps.endpoints[id] = epsByNic } - eps.endpoints[id] = epsByNic return epsByNic.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice) } @@ -396,14 +400,6 @@ func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolN } } -var loopbackSubnet = func() tcpip.Subnet { - sn, err := tcpip.NewSubnet("\x7f\x00\x00\x00", "\xff\x00\x00\x00") - if err != nil { - panic(err) - } - return sn -}() - // deliverPacket attempts to find one or more matching transport endpoints, and // then, if matches are found, delivers the packet to them. Returns true if // the packet no longer needs to be handled. @@ -601,8 +597,8 @@ func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNum } eps.mu.Lock() - defer eps.mu.Unlock() eps.rawEndpoints = append(eps.rawEndpoints, ep) + eps.mu.Unlock() return nil } @@ -616,13 +612,16 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN } eps.mu.Lock() - defer eps.mu.Unlock() for i, rawEP := range eps.rawEndpoints { if rawEP == ep { - eps.rawEndpoints = append(eps.rawEndpoints[:i], eps.rawEndpoints[i+1:]...) - return + lastIdx := len(eps.rawEndpoints) - 1 + eps.rawEndpoints[i] = eps.rawEndpoints[lastIdx] + eps.rawEndpoints[lastIdx] = nil + eps.rawEndpoints = eps.rawEndpoints[:lastIdx] + break } } + eps.mu.Unlock() } func isMulticastOrBroadcast(addr tcpip.Address) bool { diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 5e9237de9..0e3e239c5 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -167,8 +167,18 @@ func TestTransportDemuxerRegister(t *testing.T) { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) - if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, nil, false, 0), test.want; got != want { + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + }) + var wq waiter.Queue + ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatal(err) + } + tEP, ok := ep.(stack.TransportEndpoint) + if !ok { + t.Fatalf("%T does not implement stack.TransportEndpoint", ep) + } + if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, false, 0), test.want; got != want { t.Fatalf("s.RegisterTransportEndpoint(...) = %v, want %v", got, want) } }) -- cgit v1.2.3 From 28d26d2c4f231c447a10bcbcfb8223a804c9d8bc Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Fri, 13 Mar 2020 10:43:09 -0700 Subject: Honour the link's MaxHeaderLength when forwarding LinkEndpoints may expect/assume that the a tcpip.PacketBuffer's Header has enough capacity for its own headers, as per documentation for LinkEndpoint.MaxHeaderLength. Test: stack_test.TestNICForwarding PiperOrigin-RevId: 300784192 --- pkg/tcpip/stack/nic.go | 18 ++++++- pkg/tcpip/stack/stack_test.go | 112 ++++++++++++++++++++++++++---------------- 2 files changed, 87 insertions(+), 43 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 3cd5fec71..230ee0697 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -15,6 +15,7 @@ package stack import ( + "fmt" "log" "reflect" "sort" @@ -1259,9 +1260,24 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { // TODO(b/143425874) Decrease the TTL field in forwarded packets. - pkt.Header = buffer.NewPrependableFromView(pkt.Data.First()) + + firstData := pkt.Data.First() pkt.Data.RemoveFirst() + if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen == 0 { + pkt.Header = buffer.NewPrependableFromView(firstData) + } else { + firstDataLen := len(firstData) + + // pkt.Header should have enough capacity to hold n.linkEP's headers. + pkt.Header = buffer.NewPrependable(firstDataLen + linkHeaderLen) + + // TODO(b/151227689): avoid copying the packet when forwarding + if n := copy(pkt.Header.Prepend(firstDataLen), firstData); n != firstDataLen { + panic(fmt.Sprintf("copied %d bytes, expected %d", n, firstDataLen)) + } + } + if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index e15db40fb..9515426d6 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -2240,56 +2240,84 @@ func TestNICStats(t *testing.T) { } func TestNICForwarding(t *testing.T) { - // Create a stack with the fake network protocol, two NICs, each with - // an address. - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - s.SetForwarding(true) + const nicID1 = 1 + const nicID2 = 2 + const dstAddr = tcpip.Address("\x03") - ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatal("CreateNIC #1 failed:", err) - } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress #1 failed:", err) + tests := []struct { + name string + headerLen uint16 + }{ + { + name: "Zero header length", + }, + { + name: "Non-zero header length", + headerLen: 16, + }, } - ep2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, ep2); err != nil { - t.Fatal("CreateNIC #2 failed:", err) - } - if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatal("AddAddress #2 failed:", err) - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + s.SetForwarding(true) - // Route all packets to address 3 to NIC 2. - { - subnet, err := tcpip.NewSubnet("\x03", "\xff") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 2}}) - } + ep1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicID1, ep1); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } + if err := s.AddAddress(nicID1, fakeNetNumber, "\x01"); err != nil { + t.Fatalf("AddAddress(%d, %d, 0x01): %s", nicID1, fakeNetNumber, err) + } - // Send a packet to address 3. - buf := buffer.NewView(30) - buf[0] = 3 - ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) + ep2 := channelLinkWithHeaderLength{ + Endpoint: channel.New(10, defaultMTU, ""), + headerLength: test.headerLen, + } + if err := s.CreateNIC(nicID2, &ep2); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) + } + if err := s.AddAddress(nicID2, fakeNetNumber, "\x02"); err != nil { + t.Fatalf("AddAddress(%d, %d, 0x02): %s", nicID2, fakeNetNumber, err) + } - if _, ok := ep2.Read(); !ok { - t.Fatal("Packet not forwarded") - } + // Route all packets to dstAddr to NIC 2. + { + subnet, err := tcpip.NewSubnet(dstAddr, "\xff") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicID2}}) + } - // Test that forwarding increments Tx stats correctly. - if got, want := s.NICInfo()[2].Stats.Tx.Packets.Value(), uint64(1); got != want { - t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want) - } + // Send a packet to dstAddr. + buf := buffer.NewView(30) + buf[0] = dstAddr[0] + ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) - if got, want := s.NICInfo()[2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want { - t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) + pkt, ok := ep2.Read() + if !ok { + t.Fatal("packet not forwarded") + } + + // Test that the link's MaxHeaderLength is honoured. + if capacity, want := pkt.Pkt.Header.AvailableLength(), int(test.headerLen); capacity != want { + t.Errorf("got Header.AvailableLength() = %d, want = %d", capacity, want) + } + + // Test that forwarding increments Tx stats correctly. + if got, want := s.NICInfo()[nicID2].Stats.Tx.Packets.Value(), uint64(1); got != want { + t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want) + } + + if got, want := s.NICInfo()[nicID2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want { + t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) + } + }) } } -- cgit v1.2.3 From 86409c91813256f45ebcb08efeac9d7f9e56a804 Mon Sep 17 00:00:00 2001 From: Jamie Liu Date: Fri, 13 Mar 2020 12:20:09 -0700 Subject: Avoid unnecessary work in transportDemuxer.deliverPacket(). - Don't allocate []*endpointsByNic in transportDemuxer.deliverPacket() unless actually needed for UDP broadcast/multicast. - Don't allocate []*endpointsByNic via transportDemuxer.findEndpointLocked() => transportDemuxer.findAllEndpointsLocked(). - Skip unnecessary map lookups in transportDemuxer.findEndpointLocked() => transportDemuxer.findAllEndpointsLocked() (now iterEndpointsLocked). For most deliverable packets other than UDP broadcast/multicast packets, this saves two slice allocations and three map lookups per packet. PiperOrigin-RevId: 300804135 --- pkg/tcpip/stack/transport_demuxer.go | 112 ++++++++++++++++++----------------- 1 file changed, 59 insertions(+), 53 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index ff1845bfb..d4c0359e8 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -409,61 +409,45 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto return false } - eps.mu.RLock() - - // Determine which transport endpoint or endpoints to deliver this packet to. // If the packet is a UDP broadcast or multicast, then find all matching - // transport endpoints. If the packet is a TCP packet with a non-unicast - // source or destination address, then do nothing further and instruct - // the caller to do the same. - var destEps []*endpointsByNic - switch protocol { - case header.UDPProtocolNumber: - if isMulticastOrBroadcast(id.LocalAddress) { - destEps = d.findAllEndpointsLocked(eps, id) - break - } - - if ep := d.findEndpointLocked(eps, id); ep != nil { - destEps = append(destEps, ep) + // transport endpoints. + if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) { + eps.mu.RLock() + destEPs := d.findAllEndpointsLocked(eps, id) + eps.mu.RUnlock() + // Fail if we didn't find at least one matching transport endpoint. + if len(destEPs) == 0 { + r.Stats().UDP.UnknownPortErrors.Increment() + return false } - - case header.TCPProtocolNumber: - if !(isUnicast(r.LocalAddress) && isUnicast(r.RemoteAddress)) { - // TCP can only be used to communicate between a single - // source and a single destination; the addresses must - // be unicast. - eps.mu.RUnlock() - r.Stats().TCP.InvalidSegmentsReceived.Increment() - return true + // handlePacket takes ownership of pkt, so each endpoint needs its own + // copy except for the final one. + for _, ep := range destEPs[:len(destEPs)-1] { + ep.handlePacket(r, id, pkt.Clone()) } + destEPs[len(destEPs)-1].handlePacket(r, id, pkt) + return true + } - fallthrough - - default: - if ep := d.findEndpointLocked(eps, id); ep != nil { - destEps = append(destEps, ep) - } + // If the packet is a TCP packet with a non-unicast source or destination + // address, then do nothing further and instruct the caller to do the same. + if protocol == header.TCPProtocolNumber && (!isUnicast(r.LocalAddress) || !isUnicast(r.RemoteAddress)) { + // TCP can only be used to communicate between a single source and a + // single destination; the addresses must be unicast. + r.Stats().TCP.InvalidSegmentsReceived.Increment() + return true } + eps.mu.RLock() + ep := d.findEndpointLocked(eps, id) eps.mu.RUnlock() - - // Fail if we didn't find at least one matching transport endpoint. - if len(destEps) == 0 { - // UDP packet could not be delivered to an unknown destination port. + if ep == nil { if protocol == header.UDPProtocolNumber { r.Stats().UDP.UnknownPortErrors.Increment() } return false } - - // HandlePacket takes ownership of pkt, so each endpoint needs its own - // copy except for the final one. - for _, ep := range destEps[:len(destEps)-1] { - ep.handlePacket(r, id, pkt.Clone()) - } - destEps[len(destEps)-1].handlePacket(r, id, pkt) - + ep.handlePacket(r, id, pkt) return true } @@ -515,11 +499,17 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco return true } -func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id TransportEndpointID) []*endpointsByNic { - var matchedEPs []*endpointsByNic +// iterEndpointsLocked yields all endpointsByNic in eps that match id, in +// descending order of match quality. If a call to yield returns false, +// iterEndpointsLocked stops iteration and returns immediately. +// +// Preconditions: eps.mu must be locked. +func (d *transportDemuxer) iterEndpointsLocked(eps *transportEndpoints, id TransportEndpointID, yield func(*endpointsByNic) bool) { // Try to find a match with the id as provided. if ep, ok := eps.endpoints[id]; ok { - matchedEPs = append(matchedEPs, ep) + if !yield(ep) { + return + } } // Try to find a match with the id minus the local address. @@ -527,7 +517,9 @@ func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id Tr nid.LocalAddress = "" if ep, ok := eps.endpoints[nid]; ok { - matchedEPs = append(matchedEPs, ep) + if !yield(ep) { + return + } } // Try to find a match with the id minus the remote part. @@ -535,14 +527,26 @@ func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id Tr nid.RemoteAddress = "" nid.RemotePort = 0 if ep, ok := eps.endpoints[nid]; ok { - matchedEPs = append(matchedEPs, ep) + if !yield(ep) { + return + } } // Try to find a match with only the local port. nid.LocalAddress = "" if ep, ok := eps.endpoints[nid]; ok { - matchedEPs = append(matchedEPs, ep) + if !yield(ep) { + return + } } +} + +func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id TransportEndpointID) []*endpointsByNic { + var matchedEPs []*endpointsByNic + d.iterEndpointsLocked(eps, id, func(ep *endpointsByNic) bool { + matchedEPs = append(matchedEPs, ep) + return true + }) return matchedEPs } @@ -580,10 +584,12 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN // findEndpointLocked returns the endpoint that most closely matches the given // id. func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, id TransportEndpointID) *endpointsByNic { - if matchedEPs := d.findAllEndpointsLocked(eps, id); len(matchedEPs) > 0 { - return matchedEPs[0] - } - return nil + var matchedEP *endpointsByNic + d.iterEndpointsLocked(eps, id, func(ep *endpointsByNic) bool { + matchedEP = ep + return false + }) + return matchedEP } // registerRawEndpoint registers the given endpoint with the dispatcher such -- cgit v1.2.3 From 530a31f3c08b10fbd2f8135c5b76380cf5e7f4e8 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Fri, 13 Mar 2020 12:29:19 -0700 Subject: Disable a NIC before removing it When a NIC is removed, attempt to disable the NIC first to cleanup dynamic state and stop ongoing periodic tasks (e.g. IPv6 router solicitations, DAD) so that a removed NIC does not attempt to send packets. Tests: - stack_test.TestRemoveUnknownNIC - stack_test.TestRemoveNIC - stack_test.TestDADStop - stack_test.TestCleanupNDPState - stack_test.TestRouteWithDownNIC - stack_test.TestStopStartSolicitingRouters PiperOrigin-RevId: 300805857 --- pkg/tcpip/stack/ndp_test.go | 154 ++++++++++----- pkg/tcpip/stack/nic.go | 77 +++++--- pkg/tcpip/stack/registration.go | 3 + pkg/tcpip/stack/stack_test.go | 408 ++++++++++++++++++++++++---------------- 4 files changed, 413 insertions(+), 229 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 98b1c807c..4368c236c 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -639,8 +639,9 @@ func TestDADStop(t *testing.T) { const nicID = 1 tests := []struct { - name string - stopFn func(t *testing.T, s *stack.Stack) + name string + stopFn func(t *testing.T, s *stack.Stack) + skipFinalAddrCheck bool }{ // Tests to make sure that DAD stops when an address is removed. { @@ -661,6 +662,19 @@ func TestDADStop(t *testing.T) { } }, }, + + // Tests to make sure that DAD stops when the NIC is removed. + { + name: "Remove NIC", + stopFn: func(t *testing.T, s *stack.Stack) { + if err := s.RemoveNIC(nicID); err != nil { + t.Fatalf("RemoveNIC(%d): %s", nicID, err) + } + }, + // The NIC is removed so we can't check its addresses after calling + // stopFn. + skipFinalAddrCheck: true, + }, } for _, test := range tests { @@ -710,12 +724,15 @@ func TestDADStop(t *testing.T) { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } } - addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + + if !test.skipFinalAddrCheck { + addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + } } // Should not have sent more than 1 NS message. @@ -2983,11 +3000,12 @@ func TestCleanupNDPState(t *testing.T) { cleanupFn func(t *testing.T, s *stack.Stack) keepAutoGenLinkLocal bool maxAutoGenAddrEvents int + skipFinalAddrCheck bool }{ // A NIC should still keep its auto-generated link-local address when // becoming a router. { - name: "Forwarding Enable", + name: "Enable forwarding", cleanupFn: func(t *testing.T, s *stack.Stack) { t.Helper() s.SetForwarding(true) @@ -2998,7 +3016,7 @@ func TestCleanupNDPState(t *testing.T) { // A NIC should cleanup all NDP state when it is disabled. { - name: "NIC Disable", + name: "Disable NIC", cleanupFn: func(t *testing.T, s *stack.Stack) { t.Helper() @@ -3012,6 +3030,26 @@ func TestCleanupNDPState(t *testing.T) { keepAutoGenLinkLocal: false, maxAutoGenAddrEvents: 6, }, + + // A NIC should cleanup all NDP state when it is removed. + { + name: "Remove NIC", + cleanupFn: func(t *testing.T, s *stack.Stack) { + t.Helper() + + if err := s.RemoveNIC(nicID1); err != nil { + t.Fatalf("s.RemoveNIC(%d): %s", nicID1, err) + } + if err := s.RemoveNIC(nicID2); err != nil { + t.Fatalf("s.RemoveNIC(%d): %s", nicID2, err) + } + }, + keepAutoGenLinkLocal: false, + maxAutoGenAddrEvents: 6, + // The NICs are removed so we can't check their addresses after calling + // stopFn. + skipFinalAddrCheck: true, + }, } for _, test := range tests { @@ -3230,35 +3268,37 @@ func TestCleanupNDPState(t *testing.T) { t.Errorf("auto-generated address events mismatch (-want +got):\n%s", diff) } - // Make sure the auto-generated addresses got removed. - nicinfo = s.NICInfo() - nic1Addrs = nicinfo[nicID1].ProtocolAddresses - nic2Addrs = nicinfo[nicID2].ProtocolAddresses - if containsV6Addr(nic1Addrs, llAddrWithPrefix1) != test.keepAutoGenLinkLocal { - if test.keepAutoGenLinkLocal { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) - } else { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) + if !test.skipFinalAddrCheck { + // Make sure the auto-generated addresses got removed. + nicinfo = s.NICInfo() + nic1Addrs = nicinfo[nicID1].ProtocolAddresses + nic2Addrs = nicinfo[nicID2].ProtocolAddresses + if containsV6Addr(nic1Addrs, llAddrWithPrefix1) != test.keepAutoGenLinkLocal { + if test.keepAutoGenLinkLocal { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) + } else { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) + } } - } - if containsV6Addr(nic1Addrs, e1Addr1) { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) - } - if containsV6Addr(nic1Addrs, e1Addr2) { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) - } - if containsV6Addr(nic2Addrs, llAddrWithPrefix2) != test.keepAutoGenLinkLocal { - if test.keepAutoGenLinkLocal { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) - } else { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) + if containsV6Addr(nic1Addrs, e1Addr1) { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) + } + if containsV6Addr(nic1Addrs, e1Addr2) { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) + } + if containsV6Addr(nic2Addrs, llAddrWithPrefix2) != test.keepAutoGenLinkLocal { + if test.keepAutoGenLinkLocal { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) + } else { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) + } + } + if containsV6Addr(nic2Addrs, e2Addr1) { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) + } + if containsV6Addr(nic2Addrs, e2Addr2) { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs) } - } - if containsV6Addr(nic2Addrs, e2Addr1) { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) - } - if containsV6Addr(nic2Addrs, e2Addr2) { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs) } // Should not get any more events (invalidation timers should have been @@ -3575,17 +3615,19 @@ func TestStopStartSolicitingRouters(t *testing.T) { tests := []struct { name string startFn func(t *testing.T, s *stack.Stack) - stopFn func(t *testing.T, s *stack.Stack) + // first is used to tell stopFn that it is being called for the first time + // after router solicitations were last enabled. + stopFn func(t *testing.T, s *stack.Stack, first bool) }{ // Tests that when forwarding is enabled or disabled, router solicitations // are stopped or started, respectively. { - name: "Forwarding enabled and disabled", + name: "Enable and disable forwarding", startFn: func(t *testing.T, s *stack.Stack) { t.Helper() s.SetForwarding(false) }, - stopFn: func(t *testing.T, s *stack.Stack) { + stopFn: func(t *testing.T, s *stack.Stack, _ bool) { t.Helper() s.SetForwarding(true) }, @@ -3594,7 +3636,7 @@ func TestStopStartSolicitingRouters(t *testing.T) { // Tests that when a NIC is enabled or disabled, router solicitations // are started or stopped, respectively. { - name: "NIC disabled and enabled", + name: "Enable and disable NIC", startFn: func(t *testing.T, s *stack.Stack) { t.Helper() @@ -3602,7 +3644,7 @@ func TestStopStartSolicitingRouters(t *testing.T) { t.Fatalf("s.EnableNIC(%d): %s", nicID, err) } }, - stopFn: func(t *testing.T, s *stack.Stack) { + stopFn: func(t *testing.T, s *stack.Stack, _ bool) { t.Helper() if err := s.DisableNIC(nicID); err != nil { @@ -3610,6 +3652,25 @@ func TestStopStartSolicitingRouters(t *testing.T) { } }, }, + + // Tests that when a NIC is removed, router solicitations are stopped. We + // cannot start router solications on a removed NIC. + { + name: "Remove NIC", + stopFn: func(t *testing.T, s *stack.Stack, first bool) { + t.Helper() + + // Only try to remove the NIC the first time stopFn is called since it's + // impossible to remove an already removed NIC. + if !first { + return + } + + if err := s.RemoveNIC(nicID); err != nil { + t.Fatalf("s.RemoveNIC(%d): %s", nicID, err) + } + }, + }, } for _, test := range tests { @@ -3648,7 +3709,7 @@ func TestStopStartSolicitingRouters(t *testing.T) { } // Stop soliciting routers. - test.stopFn(t, s) + test.stopFn(t, s, true /* first */) ctx, cancel := context.WithTimeout(context.Background(), delay+defaultTimeout) defer cancel() if _, ok := e.ReadContext(ctx); ok { @@ -3662,13 +3723,18 @@ func TestStopStartSolicitingRouters(t *testing.T) { // Stopping router solicitations after it has already been stopped should // do nothing. - test.stopFn(t, s) + test.stopFn(t, s, false /* first */) ctx, cancel = context.WithTimeout(context.Background(), delay+defaultTimeout) defer cancel() if _, ok := e.ReadContext(ctx); ok { t.Fatal("unexpectedly got a packet after router solicitation has been stopepd") } + // If test.startFn is nil, there is no way to restart router solications. + if test.startFn == nil { + return + } + // Start soliciting routers. test.startFn(t, s) waitForPkt(delay + defaultAsyncEventTimeout) diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 230ee0697..11eaa6a2c 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -56,7 +56,7 @@ type NIC struct { primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint endpoints map[NetworkEndpointID]*referencedNetworkEndpoint addressRanges []tcpip.Subnet - mcastJoins map[NetworkEndpointID]int32 + mcastJoins map[NetworkEndpointID]uint32 // packetEPs is protected by mu, but the contained PacketEndpoint // values are not. packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint @@ -123,7 +123,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC } nic.mu.primary = make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint) nic.mu.endpoints = make(map[NetworkEndpointID]*referencedNetworkEndpoint) - nic.mu.mcastJoins = make(map[NetworkEndpointID]int32) + nic.mu.mcastJoins = make(map[NetworkEndpointID]uint32) nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint) nic.mu.ndp = ndpState{ nic: nic, @@ -167,8 +167,17 @@ func (n *NIC) disable() *tcpip.Error { } n.mu.Lock() - defer n.mu.Unlock() + err := n.disableLocked() + n.mu.Unlock() + return err +} +// disableLocked disables n. +// +// It undoes the work done by enable. +// +// n MUST be locked. +func (n *NIC) disableLocked() *tcpip.Error { if !n.mu.enabled { return nil } @@ -191,7 +200,7 @@ func (n *NIC) disable() *tcpip.Error { } // The NIC may have already left the multicast group. - if err := n.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress { + if err := n.leaveGroupLocked(header.IPv6AllNodesMulticastAddress, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress { return err } } @@ -307,24 +316,33 @@ func (n *NIC) remove() *tcpip.Error { n.mu.Lock() defer n.mu.Unlock() - // Detach from link endpoint, so no packet comes in. - n.linkEP.Attach(nil) + n.disableLocked() + + // TODO(b/151378115): come up with a better way to pick an error than the + // first one. + var err *tcpip.Error + + // Forcefully leave multicast groups. + for nid := range n.mu.mcastJoins { + if tempErr := n.leaveGroupLocked(nid.LocalAddress, true /* force */); tempErr != nil && err == nil { + err = tempErr + } + } // Remove permanent and permanentTentative addresses, so no packet goes out. - var errs []*tcpip.Error for nid, ref := range n.mu.endpoints { switch ref.getKind() { case permanentTentative, permanent: - if err := n.removePermanentAddressLocked(nid.LocalAddress); err != nil { - errs = append(errs, err) + if tempErr := n.removePermanentAddressLocked(nid.LocalAddress); tempErr != nil && err == nil { + err = tempErr } } } - if len(errs) > 0 { - return errs[0] - } - return nil + // Detach from link endpoint, so no packet comes in. + n.linkEP.Attach(nil) + + return err } // becomeIPv6Router transitions n into an IPv6 router. @@ -971,6 +989,7 @@ func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) { for i, ref := range refs { if ref == r { n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...) + refs[len(refs)-1] = nil break } } @@ -1021,9 +1040,12 @@ func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { // If we are removing an IPv6 unicast address, leave the solicited-node // multicast address. + // + // We ignore the tcpip.ErrBadLocalAddress error because the solicited-node + // multicast group may be left by user action. if isIPv6Unicast { snmc := header.SolicitedNodeAddr(addr) - if err := n.leaveGroupLocked(snmc); err != nil { + if err := n.leaveGroupLocked(snmc, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress { return err } } @@ -1083,26 +1105,31 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error { n.mu.Lock() defer n.mu.Unlock() - return n.leaveGroupLocked(addr) + return n.leaveGroupLocked(addr, false /* force */) } // leaveGroupLocked decrements the count for the given multicast address, and // when it reaches zero removes the endpoint for this address. n MUST be locked // before leaveGroupLocked is called. -func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { +// +// If force is true, then the count for the multicast addres is ignored and the +// endpoint will be removed immediately. +func (n *NIC) leaveGroupLocked(addr tcpip.Address, force bool) *tcpip.Error { id := NetworkEndpointID{addr} - joins := n.mu.mcastJoins[id] - switch joins { - case 0: + joins, ok := n.mu.mcastJoins[id] + if !ok { // There are no joins with this address on this NIC. return tcpip.ErrBadLocalAddress - case 1: - // This is the last one, clean up. - if err := n.removePermanentAddressLocked(addr); err != nil { - return err - } } - n.mu.mcastJoins[id] = joins - 1 + + joins-- + if force || joins == 0 { + // There are no outstanding joins or we are forced to leave, clean up. + delete(n.mu.mcastJoins, id) + return n.removePermanentAddressLocked(addr) + } + + n.mu.mcastJoins[id] = joins return nil } diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index f9fd8f18f..fa28b46b1 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -401,6 +401,9 @@ type LinkEndpoint interface { // Attach attaches the data link layer endpoint to the network-layer // dispatcher of the stack. + // + // Attach will be called with a nil dispatcher if the receiver's associated + // NIC is being removed. Attach(dispatcher NetworkDispatcher) // IsAttached returns whether a NetworkDispatcher is attached to the diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 9515426d6..9836b340f 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -255,7 +255,7 @@ type linkEPWithMockedAttach struct { // Attach implements stack.LinkEndpoint.Attach. func (l *linkEPWithMockedAttach) Attach(d stack.NetworkDispatcher) { l.LinkEndpoint.Attach(d) - l.attached = true + l.attached = d != nil } func (l *linkEPWithMockedAttach) isAttached() bool { @@ -566,7 +566,7 @@ func TestAttachToLinkEndpointImmediately(t *testing.T) { t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, test.nicOpts, err) } if !e.isAttached() { - t.Fatalf("link endpoint not attached to a network disatcher") + t.Fatal("link endpoint not attached to a network dispatcher") } }) } @@ -631,196 +631,240 @@ func TestDisabledNICsNICInfoAndCheckNIC(t *testing.T) { checkNIC(false) } -func TestRoutesWithDisabledNIC(t *testing.T) { - const unspecifiedNIC = 0 - const nicID1 = 1 - const nicID2 = 2 - +func TestRemoveUnknownNIC(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, }) - ep1 := channel.New(0, defaultMTU, "") - if err := s.CreateNIC(nicID1, ep1); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + if err := s.RemoveNIC(1); err != tcpip.ErrUnknownNICID { + t.Fatalf("got s.RemoveNIC(1) = %v, want = %s", err, tcpip.ErrUnknownNICID) } +} - addr1 := tcpip.Address("\x01") - if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err) - } +func TestRemoveNIC(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - ep2 := channel.New(0, defaultMTU, "") - if err := s.CreateNIC(nicID2, ep2); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) + e := linkEPWithMockedAttach{ + LinkEndpoint: loopback.New(), + } + if err := s.CreateNIC(nicID, &e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - addr2 := tcpip.Address("\x02") - if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err) + // NIC should be present in NICInfo and attached to a NetworkDispatcher. + allNICInfo := s.NICInfo() + if _, ok := allNICInfo[nicID]; !ok { + t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo) + } + if !e.isAttached() { + t.Fatal("link endpoint not attached to a network dispatcher") } - // Set a route table that sends all packets with odd destination - // addresses through the first NIC, and all even destination address - // through the second one. - { - subnet0, err := tcpip.NewSubnet("\x00", "\x01") - if err != nil { - t.Fatal(err) - } - subnet1, err := tcpip.NewSubnet("\x01", "\x01") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{ - {Destination: subnet1, Gateway: "\x00", NIC: nicID1}, - {Destination: subnet0, Gateway: "\x00", NIC: nicID2}, - }) + // Removing a NIC should remove it from NICInfo and e should be detached from + // the NetworkDispatcher. + if err := s.RemoveNIC(nicID); err != nil { + t.Fatalf("s.RemoveNIC(%d): %s", nicID, err) } + if nicInfo, ok := s.NICInfo()[nicID]; ok { + t.Errorf("got unexpected NICInfo entry for deleted NIC %d = %+v", nicID, nicInfo) + } + if e.isAttached() { + t.Error("link endpoint for removed NIC still attached to a network dispatcher") + } +} - // Test routes to odd address. - testRoute(t, s, unspecifiedNIC, "", "\x05", addr1) - testRoute(t, s, unspecifiedNIC, addr1, "\x05", addr1) - testRoute(t, s, nicID1, addr1, "\x05", addr1) +func TestRouteWithDownNIC(t *testing.T) { + tests := []struct { + name string + downFn func(s *stack.Stack, nicID tcpip.NICID) *tcpip.Error + upFn func(s *stack.Stack, nicID tcpip.NICID) *tcpip.Error + }{ + { + name: "Disabled NIC", + downFn: (*stack.Stack).DisableNIC, + upFn: (*stack.Stack).EnableNIC, + }, + + // Once a NIC is removed, it cannot be brought up. + { + name: "Removed NIC", + downFn: (*stack.Stack).RemoveNIC, + }, + } - // Test routes to even address. - testRoute(t, s, unspecifiedNIC, "", "\x06", addr2) - testRoute(t, s, unspecifiedNIC, addr2, "\x06", addr2) - testRoute(t, s, nicID2, addr2, "\x06", addr2) - - // Disabling NIC1 should result in no routes to odd addresses. Routes to even - // addresses should continue to be available as NIC2 is still enabled. - if err := s.DisableNIC(nicID1); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID1, err) - } - nic1Dst := tcpip.Address("\x05") - testNoRoute(t, s, unspecifiedNIC, "", nic1Dst) - testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst) - testNoRoute(t, s, nicID1, addr1, nic1Dst) - nic2Dst := tcpip.Address("\x06") - testRoute(t, s, unspecifiedNIC, "", nic2Dst, addr2) - testRoute(t, s, unspecifiedNIC, addr2, nic2Dst, addr2) - testRoute(t, s, nicID2, addr2, nic2Dst, addr2) - - // Disabling NIC2 should result in no routes to even addresses. No route - // should be available to any address as routes to odd addresses were made - // unavailable by disabling NIC1 above. - if err := s.DisableNIC(nicID2); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID2, err) - } - testNoRoute(t, s, unspecifiedNIC, "", nic1Dst) - testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst) - testNoRoute(t, s, nicID1, addr1, nic1Dst) - testNoRoute(t, s, unspecifiedNIC, "", nic2Dst) - testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst) - testNoRoute(t, s, nicID2, addr2, nic2Dst) - - // Enabling NIC1 should make routes to odd addresses available again. Routes - // to even addresses should continue to be unavailable as NIC2 is still - // disabled. - if err := s.EnableNIC(nicID1); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID1, err) - } - testRoute(t, s, unspecifiedNIC, "", nic1Dst, addr1) - testRoute(t, s, unspecifiedNIC, addr1, nic1Dst, addr1) - testRoute(t, s, nicID1, addr1, nic1Dst, addr1) - testNoRoute(t, s, unspecifiedNIC, "", nic2Dst) - testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst) - testNoRoute(t, s, nicID2, addr2, nic2Dst) -} - -func TestRouteWritePacketWithDisabledNIC(t *testing.T) { const unspecifiedNIC = 0 const nicID1 = 1 const nicID2 = 2 + const addr1 = tcpip.Address("\x01") + const addr2 = tcpip.Address("\x02") + const nic1Dst = tcpip.Address("\x05") + const nic2Dst = tcpip.Address("\x06") + + setup := func(t *testing.T) (*stack.Stack, *channel.Endpoint, *channel.Endpoint) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) + ep1 := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicID1, ep1); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } - ep1 := channel.New(1, defaultMTU, "") - if err := s.CreateNIC(nicID1, ep1); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) - } + if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err) + } - addr1 := tcpip.Address("\x01") - if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err) - } + ep2 := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicID2, ep2); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) + } - ep2 := channel.New(1, defaultMTU, "") - if err := s.CreateNIC(nicID2, ep2); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) - } + if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err) + } - addr2 := tcpip.Address("\x02") - if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil { - t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err) + // Set a route table that sends all packets with odd destination + // addresses through the first NIC, and all even destination address + // through the second one. + { + subnet0, err := tcpip.NewSubnet("\x00", "\x01") + if err != nil { + t.Fatal(err) + } + subnet1, err := tcpip.NewSubnet("\x01", "\x01") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{ + {Destination: subnet1, Gateway: "\x00", NIC: nicID1}, + {Destination: subnet0, Gateway: "\x00", NIC: nicID2}, + }) + } + + return s, ep1, ep2 } - // Set a route table that sends all packets with odd destination - // addresses through the first NIC, and all even destination address - // through the second one. - { - subnet0, err := tcpip.NewSubnet("\x00", "\x01") - if err != nil { - t.Fatal(err) - } - subnet1, err := tcpip.NewSubnet("\x01", "\x01") - if err != nil { - t.Fatal(err) + // Tests that routes through a down NIC are not used when looking up a route + // for a destination. + t.Run("Find", func(t *testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s, _, _ := setup(t) + + // Test routes to odd address. + testRoute(t, s, unspecifiedNIC, "", "\x05", addr1) + testRoute(t, s, unspecifiedNIC, addr1, "\x05", addr1) + testRoute(t, s, nicID1, addr1, "\x05", addr1) + + // Test routes to even address. + testRoute(t, s, unspecifiedNIC, "", "\x06", addr2) + testRoute(t, s, unspecifiedNIC, addr2, "\x06", addr2) + testRoute(t, s, nicID2, addr2, "\x06", addr2) + + // Bringing NIC1 down should result in no routes to odd addresses. Routes to + // even addresses should continue to be available as NIC2 is still up. + if err := test.downFn(s, nicID1); err != nil { + t.Fatalf("test.downFn(_, %d): %s", nicID1, err) + } + testNoRoute(t, s, unspecifiedNIC, "", nic1Dst) + testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst) + testNoRoute(t, s, nicID1, addr1, nic1Dst) + testRoute(t, s, unspecifiedNIC, "", nic2Dst, addr2) + testRoute(t, s, unspecifiedNIC, addr2, nic2Dst, addr2) + testRoute(t, s, nicID2, addr2, nic2Dst, addr2) + + // Bringing NIC2 down should result in no routes to even addresses. No + // route should be available to any address as routes to odd addresses + // were made unavailable by bringing NIC1 down above. + if err := test.downFn(s, nicID2); err != nil { + t.Fatalf("test.downFn(_, %d): %s", nicID2, err) + } + testNoRoute(t, s, unspecifiedNIC, "", nic1Dst) + testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst) + testNoRoute(t, s, nicID1, addr1, nic1Dst) + testNoRoute(t, s, unspecifiedNIC, "", nic2Dst) + testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst) + testNoRoute(t, s, nicID2, addr2, nic2Dst) + + if upFn := test.upFn; upFn != nil { + // Bringing NIC1 up should make routes to odd addresses available + // again. Routes to even addresses should continue to be unavailable + // as NIC2 is still down. + if err := upFn(s, nicID1); err != nil { + t.Fatalf("test.upFn(_, %d): %s", nicID1, err) + } + testRoute(t, s, unspecifiedNIC, "", nic1Dst, addr1) + testRoute(t, s, unspecifiedNIC, addr1, nic1Dst, addr1) + testRoute(t, s, nicID1, addr1, nic1Dst, addr1) + testNoRoute(t, s, unspecifiedNIC, "", nic2Dst) + testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst) + testNoRoute(t, s, nicID2, addr2, nic2Dst) + } + }) } - s.SetRouteTable([]tcpip.Route{ - {Destination: subnet1, Gateway: "\x00", NIC: nicID1}, - {Destination: subnet0, Gateway: "\x00", NIC: nicID2}, - }) - } + }) - nic1Dst := tcpip.Address("\x05") - r1, err := s.FindRoute(nicID1, addr1, nic1Dst, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID1, addr1, nic1Dst, fakeNetNumber, err) - } - defer r1.Release() + // Tests that writing a packet using a Route through a down NIC fails. + t.Run("WritePacket", func(t *testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s, ep1, ep2 := setup(t) - nic2Dst := tcpip.Address("\x06") - r2, err := s.FindRoute(nicID2, addr2, nic2Dst, fakeNetNumber, false /* multicastLoop */) - if err != nil { - t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID2, addr2, nic2Dst, fakeNetNumber, err) - } - defer r2.Release() + r1, err := s.FindRoute(nicID1, addr1, nic1Dst, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID1, addr1, nic1Dst, fakeNetNumber, err) + } + defer r1.Release() - // If we failed to get routes r1 or r2, we cannot proceed with the test. - if t.Failed() { - t.FailNow() - } + r2, err := s.FindRoute(nicID2, addr2, nic2Dst, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID2, addr2, nic2Dst, fakeNetNumber, err) + } + defer r2.Release() - buf := buffer.View([]byte{1}) - testSend(t, r1, ep1, buf) - testSend(t, r2, ep2, buf) + // If we failed to get routes r1 or r2, we cannot proceed with the test. + if t.Failed() { + t.FailNow() + } - // Writes with Routes that use the disabled NIC1 should fail. - if err := s.DisableNIC(nicID1); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID1, err) - } - testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState) - testSend(t, r2, ep2, buf) + buf := buffer.View([]byte{1}) + testSend(t, r1, ep1, buf) + testSend(t, r2, ep2, buf) - // Writes with Routes that use the disabled NIC2 should fail. - if err := s.DisableNIC(nicID2); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID2, err) - } - testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState) - testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState) + // Writes with Routes that use NIC1 after being brought down should fail. + if err := test.downFn(s, nicID1); err != nil { + t.Fatalf("test.downFn(_, %d): %s", nicID1, err) + } + testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState) + testSend(t, r2, ep2, buf) - // Writes with Routes that use the re-enabled NIC1 should succeed. - // TODO(b/147015577): Should we instead completely invalidate all Routes that - // were bound to a disabled NIC at some point? - if err := s.EnableNIC(nicID1); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID1, err) - } - testSend(t, r1, ep1, buf) - testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState) + // Writes with Routes that use NIC2 after being brought down should fail. + if err := test.downFn(s, nicID2); err != nil { + t.Fatalf("test.downFn(_, %d): %s", nicID2, err) + } + testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState) + testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState) + + if upFn := test.upFn; upFn != nil { + // Writes with Routes that use NIC1 after being brought up should + // succeed. + // + // TODO(b/147015577): Should we instead completely invalidate all + // Routes that were bound to a NIC that was brought down at some + // point? + if err := upFn(s, nicID1); err != nil { + t.Fatalf("test.upFn(_, %d): %s", nicID1, err) + } + testSend(t, r1, ep1, buf) + testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState) + } + }) + } + }) } func TestRoutes(t *testing.T) { @@ -3038,6 +3082,50 @@ func TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable(t *testing.T) { } } +// TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval tests that removing an IPv6 +// address after leaving its solicited node multicast address does not result in +// an error. +func TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + }) + e := channel.New(10, 1280, linkAddr1) + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + + if err := s.AddAddress(nicID, ipv6.ProtocolNumber, addr1); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, addr1, err) + } + + // The NIC should have joined addr1's solicited node multicast address. + snmc := header.SolicitedNodeAddr(addr1) + in, err := s.IsInGroup(nicID, snmc) + if err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, snmc, err) + } + if !in { + t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, snmc) + } + + if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, snmc); err != nil { + t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, snmc, err) + } + in, err = s.IsInGroup(nicID, snmc) + if err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, snmc, err) + } + if in { + t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, snmc) + } + + if err := s.RemoveAddress(nicID, addr1); err != nil { + t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr1, err) + } +} + func TestJoinLeaveAllNodesMulticastOnNICEnableDisable(t *testing.T) { const nicID = 1 -- cgit v1.2.3 From 645b1b2e9cd40084a42d6168de72a915449780b7 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Fri, 13 Mar 2020 14:58:16 -0700 Subject: Refactor SLAAC address state into SLAAC prefix state Previously, SLAAC related state was stored on a per-address basis. This was sufficient for the simple case of a single SLAAC address per prefix, but future CLs will introduce temporary addresses which will result in multiple SLAAC addresses for a prefix. This refactor allows storing multiple addresses for a prefix in a single SLAAC prefix state. No behaviour changes - existing tests continue to pass. PiperOrigin-RevId: 300832812 --- pkg/tcpip/stack/ndp.go | 296 ++++++++++++++++++++++++++----------------------- pkg/tcpip/stack/nic.go | 20 ++-- 2 files changed, 167 insertions(+), 149 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index a9f4d5dad..d689a006d 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -361,16 +361,16 @@ type ndpState struct { // The default routers discovered through Router Advertisements. defaultRouters map[tcpip.Address]defaultRouterState + // The timer used to send the next router solicitation message. + rtrSolicitTimer *time.Timer + // The on-link prefixes discovered through Router Advertisements' Prefix // Information option. onLinkPrefixes map[tcpip.Subnet]onLinkPrefixState - // The timer used to send the next router solicitation message. - // If routers are being solicited, rtrSolicitTimer MUST NOT be nil. - rtrSolicitTimer *time.Timer - - // The addresses generated by SLAAC. - autoGenAddresses map[tcpip.Address]autoGenAddressState + // The SLAAC prefixes discovered through Router Advertisements' Prefix + // Information option. + slaacPrefixes map[tcpip.Subnet]slaacPrefixState // The last learned DHCPv6 configuration from an NDP RA. dhcpv6Configuration DHCPv6ConfigurationFromNDPRA @@ -402,18 +402,16 @@ type onLinkPrefixState struct { invalidationTimer tcpip.CancellableTimer } -// autoGenAddressState holds data associated with an address generated via -// SLAAC. -type autoGenAddressState struct { - // A reference to the referencedNetworkEndpoint that this autoGenAddressState - // is holding state for. - ref *referencedNetworkEndpoint - +// slaacPrefixState holds state associated with a SLAAC prefix. +type slaacPrefixState struct { deprecationTimer tcpip.CancellableTimer invalidationTimer tcpip.CancellableTimer // Nonzero only when the address is not valid forever. validUntil time.Time + + // The prefix's permanent address endpoint. + ref *referencedNetworkEndpoint } // startDuplicateAddressDetection performs Duplicate Address Detection. @@ -899,23 +897,15 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform prefix := pi.Subnet() - // Check if we already have an auto-generated address for prefix. - for addr, addrState := range ndp.autoGenAddresses { - refAddrWithPrefix := tcpip.AddressWithPrefix{Address: addr, PrefixLen: addrState.ref.ep.PrefixLen()} - if refAddrWithPrefix.Subnet() != prefix { - continue - } - - // At this point, we know we are refreshing a SLAAC generated IPv6 address - // with the prefix prefix. Do the work as outlined by RFC 4862 section - // 5.5.3.e. - ndp.refreshAutoGenAddressLifetimes(addr, pl, vl) + // Check if we already maintain SLAAC state for prefix. + if _, ok := ndp.slaacPrefixes[prefix]; ok { + // As per RFC 4862 section 5.5.3.e, refresh prefix's SLAAC lifetimes. + ndp.refreshSLAACPrefixLifetimes(prefix, pl, vl) return } - // We do not already have an address with the prefix prefix. Do the - // work as outlined by RFC 4862 section 5.5.3.d if n is configured - // to auto-generate global addresses by SLAAC. + // prefix is a new SLAAC prefix. Do the work as outlined by RFC 4862 section + // 5.5.3.d if ndp is configured to auto-generate new addresses via SLAAC. if !ndp.configs.AutoGenGlobalAddresses { return } @@ -927,6 +917,8 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform // for prefix. // // pl is the new preferred lifetime. vl is the new valid lifetime. +// +// The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { // If we do not already have an address for this prefix and the valid // lifetime is 0, no need to do anything further, as per RFC 4862 @@ -942,9 +934,59 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { return } + // If the preferred lifetime is zero, then the prefix should be considered + // deprecated. + deprecated := pl == 0 + ref := ndp.addSLAACAddr(prefix, deprecated) + if ref == nil { + // We were unable to generate a permanent address for prefix so do nothing + // further as there is no reason to maintain state for a SLAAC prefix we + // cannot generate a permanent address for. + return + } + + state := slaacPrefixState{ + deprecationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() { + prefixState, ok := ndp.slaacPrefixes[prefix] + if !ok { + log.Fatalf("ndp: must have a slaacPrefixes entry for the SLAAC prefix %s", prefix) + } + + ndp.deprecateSLAACAddress(prefixState.ref) + }), + invalidationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() { + ndp.invalidateSLAACPrefix(prefix, true) + }), + ref: ref, + } + + // Setup the initial timers to deprecate and invalidate prefix. + + if !deprecated && pl < header.NDPInfiniteLifetime { + state.deprecationTimer.Reset(pl) + } + + if vl < header.NDPInfiniteLifetime { + state.invalidationTimer.Reset(vl) + state.validUntil = time.Now().Add(vl) + } + + ndp.slaacPrefixes[prefix] = state +} + +// addSLAACAddr adds a SLAAC address for prefix. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) addSLAACAddr(prefix tcpip.Subnet, deprecated bool) *referencedNetworkEndpoint { addrBytes := []byte(prefix.ID()) if oIID := ndp.nic.stack.opaqueIIDOpts; oIID.NICNameFromID != nil { - addrBytes = header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], prefix, oIID.NICNameFromID(ndp.nic.ID(), ndp.nic.name), 0 /* dadCounter */, oIID.SecretKey) + addrBytes = header.AppendOpaqueInterfaceIdentifier( + addrBytes[:header.IIDOffsetInIPv6Address], + prefix, + oIID.NICNameFromID(ndp.nic.ID(), ndp.nic.name), + 0, /* dadCounter */ + oIID.SecretKey, + ) } else { // Only attempt to generate an interface-specific IID if we have a valid // link address. @@ -953,137 +995,103 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { // LinkEndpoint.LinkAddress) before reaching this point. linkAddr := ndp.nic.linkEP.LinkAddress() if !header.IsValidUnicastEthernetAddress(linkAddr) { - return + return nil } // Generate an address within prefix from the modified EUI-64 of ndp's NIC's // Ethernet MAC address. header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:]) } - addr := tcpip.Address(addrBytes) - addrWithPrefix := tcpip.AddressWithPrefix{ - Address: addr, - PrefixLen: validPrefixLenForAutoGen, + + generatedAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(addrBytes), + PrefixLen: validPrefixLenForAutoGen, + }, } // If the nic already has this address, do nothing further. - if ndp.nic.hasPermanentAddrLocked(addr) { - return + if ndp.nic.hasPermanentAddrLocked(generatedAddr.AddressWithPrefix.Address) { + return nil } // Inform the integrator that we have a new SLAAC address. ndpDisp := ndp.nic.stack.ndpDisp if ndpDisp == nil { - return + return nil } - if !ndpDisp.OnAutoGenAddress(ndp.nic.ID(), addrWithPrefix) { + + if !ndpDisp.OnAutoGenAddress(ndp.nic.ID(), generatedAddr.AddressWithPrefix) { // Informed by the integrator not to add the address. - return + return nil } - protocolAddr := tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: addrWithPrefix, - } - // If the preferred lifetime is zero, then the address should be considered - // deprecated. - deprecated := pl == 0 - ref, err := ndp.nic.addAddressLocked(protocolAddr, FirstPrimaryEndpoint, permanent, slaac, deprecated) + ref, err := ndp.nic.addAddressLocked(generatedAddr, FirstPrimaryEndpoint, permanent, slaac, deprecated) if err != nil { - log.Fatalf("ndp: error when adding address %s: %s", protocolAddr, err) - } - - state := autoGenAddressState{ - ref: ref, - deprecationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() { - addrState, ok := ndp.autoGenAddresses[addr] - if !ok { - log.Fatalf("ndp: must have an autoGenAddressess entry for the SLAAC generated IPv6 address %s", addr) - } - addrState.ref.deprecated = true - ndp.notifyAutoGenAddressDeprecated(addr) - }), - invalidationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() { - ndp.invalidateAutoGenAddress(addr) - }), + log.Fatalf("ndp: error when adding address %+v: %s", generatedAddr, err) } - // Setup the initial timers to deprecate and invalidate this newly generated - // address. - - if !deprecated && pl < header.NDPInfiniteLifetime { - state.deprecationTimer.Reset(pl) - } - - if vl < header.NDPInfiniteLifetime { - state.invalidationTimer.Reset(vl) - state.validUntil = time.Now().Add(vl) - } - - ndp.autoGenAddresses[addr] = state + return ref } -// refreshAutoGenAddressLifetimes refreshes the lifetime of a SLAAC generated -// address addr. +// refreshSLAACPrefixLifetimes refreshes the lifetimes of a SLAAC prefix. // // pl is the new preferred lifetime. vl is the new valid lifetime. -func (ndp *ndpState) refreshAutoGenAddressLifetimes(addr tcpip.Address, pl, vl time.Duration) { - addrState, ok := ndp.autoGenAddresses[addr] +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, pl, vl time.Duration) { + prefixState, ok := ndp.slaacPrefixes[prefix] if !ok { - log.Fatalf("ndp: SLAAC state not found to refresh lifetimes for %s", addr) + log.Fatalf("ndp: SLAAC prefix state not found to refresh lifetimes for %s", prefix) } - defer func() { ndp.autoGenAddresses[addr] = addrState }() + defer func() { ndp.slaacPrefixes[prefix] = prefixState }() - // If the preferred lifetime is zero, then the address should be considered - // deprecated. + // If the preferred lifetime is zero, then the prefix should be deprecated. deprecated := pl == 0 - wasDeprecated := addrState.ref.deprecated - addrState.ref.deprecated = deprecated - - // Only send the deprecation event if the deprecated status for addr just - // changed from non-deprecated to deprecated. - if !wasDeprecated && deprecated { - ndp.notifyAutoGenAddressDeprecated(addr) + if deprecated { + ndp.deprecateSLAACAddress(prefixState.ref) + } else { + prefixState.ref.deprecated = false } - // If addr was preferred for some finite lifetime before, stop the deprecation - // timer so it can be reset. - addrState.deprecationTimer.StopLocked() + // If prefix was preferred for some finite lifetime before, stop the + // deprecation timer so it can be reset. + prefixState.deprecationTimer.StopLocked() - // Reset the deprecation timer if addr has a finite preferred lifetime. + // Reset the deprecation timer if prefix has a finite preferred lifetime. if !deprecated && pl < header.NDPInfiniteLifetime { - addrState.deprecationTimer.Reset(pl) + prefixState.deprecationTimer.Reset(pl) } - // As per RFC 4862 section 5.5.3.e, the valid lifetime of the address - // + // As per RFC 4862 section 5.5.3.e, update the valid lifetime for prefix: // // 1) If the received Valid Lifetime is greater than 2 hours or greater than - // RemainingLifetime, set the valid lifetime of the address to the + // RemainingLifetime, set the valid lifetime of the prefix to the // advertised Valid Lifetime. // // 2) If RemainingLifetime is less than or equal to 2 hours, ignore the // advertised Valid Lifetime. // - // 3) Otherwise, reset the valid lifetime of the address to 2 hours. + // 3) Otherwise, reset the valid lifetime of the prefix to 2 hours. // Handle the infinite valid lifetime separately as we do not keep a timer in // this case. if vl >= header.NDPInfiniteLifetime { - addrState.invalidationTimer.StopLocked() - addrState.validUntil = time.Time{} + prefixState.invalidationTimer.StopLocked() + prefixState.validUntil = time.Time{} return } var effectiveVl time.Duration var rl time.Duration - // If the address was originally set to be valid forever, assume the remaining + // If the prefix was originally set to be valid forever, assume the remaining // time to be the maximum possible value. - if addrState.validUntil == (time.Time{}) { + if prefixState.validUntil == (time.Time{}) { rl = header.NDPInfiniteLifetime } else { - rl = time.Until(addrState.validUntil) + rl = time.Until(prefixState.validUntil) } if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl { @@ -1094,58 +1102,66 @@ func (ndp *ndpState) refreshAutoGenAddressLifetimes(addr tcpip.Address, pl, vl t effectiveVl = MinPrefixInformationValidLifetimeForUpdate } - addrState.invalidationTimer.StopLocked() - addrState.invalidationTimer.Reset(effectiveVl) - addrState.validUntil = time.Now().Add(effectiveVl) -} - -// notifyAutoGenAddressDeprecated notifies the stack's NDP dispatcher that addr -// has been deprecated. -func (ndp *ndpState) notifyAutoGenAddressDeprecated(addr tcpip.Address) { - if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { - ndpDisp.OnAutoGenAddressDeprecated(ndp.nic.ID(), tcpip.AddressWithPrefix{ - Address: addr, - PrefixLen: validPrefixLenForAutoGen, - }) - } + prefixState.invalidationTimer.StopLocked() + prefixState.invalidationTimer.Reset(effectiveVl) + prefixState.validUntil = time.Now().Add(effectiveVl) } -// invalidateAutoGenAddress invalidates an auto-generated address. +// deprecateSLAACAddress marks ref as deprecated and notifies the stack's NDP +// dispatcher that ref has been deprecated. +// +// deprecateSLAACAddress does nothing if ref is already deprecated. // // The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) invalidateAutoGenAddress(addr tcpip.Address) { - if !ndp.cleanupAutoGenAddrResourcesAndNotify(addr) { +func (ndp *ndpState) deprecateSLAACAddress(ref *referencedNetworkEndpoint) { + if ref.deprecated { return } - ndp.nic.removePermanentAddressLocked(addr) + ref.deprecated = true + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnAutoGenAddressDeprecated(ndp.nic.ID(), tcpip.AddressWithPrefix{ + Address: ref.ep.ID().LocalAddress, + PrefixLen: ref.ep.PrefixLen(), + }) + } } -// cleanupAutoGenAddrResourcesAndNotify cleans up an invalidated auto-generated -// address's resources from ndp. If the stack has an NDP dispatcher, it will -// be notified that addr has been invalidated. -// -// Returns true if ndp had resources for addr to cleanup. +// invalidateSLAACPrefix invalidates a SLAAC prefix. // // The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bool { - state, ok := ndp.autoGenAddresses[addr] +func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, removeAddr bool) { + state, ok := ndp.slaacPrefixes[prefix] if !ok { - return false + return } state.deprecationTimer.StopLocked() state.invalidationTimer.StopLocked() - delete(ndp.autoGenAddresses, addr) + delete(ndp.slaacPrefixes, prefix) + + addr := state.ref.ep.ID().LocalAddress + + if removeAddr { + if err := ndp.nic.removePermanentAddressLocked(addr); err != nil { + log.Fatalf("ndp: removePermanentAddressLocked(%s): %s", addr, err) + } + } if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), tcpip.AddressWithPrefix{ Address: addr, - PrefixLen: validPrefixLenForAutoGen, + PrefixLen: state.ref.ep.PrefixLen(), }) } +} - return true +// cleanupSLAACAddrResourcesAndNotify cleans up an invalidated SLAAC +// address's resources from ndp. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix) { + ndp.invalidateSLAACPrefix(addr.Subnet(), false) } // cleanupState cleans up ndp's state. @@ -1163,21 +1179,21 @@ func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bo // The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) cleanupState(hostOnly bool) { linkLocalSubnet := header.IPv6LinkLocalPrefix.Subnet() - linkLocalAddrs := 0 - for addr := range ndp.autoGenAddresses { + linkLocalPrefixes := 0 + for prefix := range ndp.slaacPrefixes { // RFC 4862 section 5 states that routers are also expected to generate a // link-local address so we do not invalidate them if we are cleaning up // host-only state. - if hostOnly && linkLocalSubnet.Contains(addr) { - linkLocalAddrs++ + if hostOnly && prefix == linkLocalSubnet { + linkLocalPrefixes++ continue } - ndp.invalidateAutoGenAddress(addr) + ndp.invalidateSLAACPrefix(prefix, true) } - if got := len(ndp.autoGenAddresses); got != linkLocalAddrs { - log.Fatalf("ndp: still have non-linklocal auto-generated addresses after cleaning up; found = %d prefixes, of which %d are link-local", got, linkLocalAddrs) + if got := len(ndp.slaacPrefixes); got != linkLocalPrefixes { + log.Fatalf("ndp: still have non-linklocal SLAAC prefixes after cleaning up; found = %d prefixes, of which %d are link-local", got, linkLocalPrefixes) } for prefix := range ndp.onLinkPrefixes { diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 11eaa6a2c..9dcb1d52c 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -126,12 +126,12 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC nic.mu.mcastJoins = make(map[NetworkEndpointID]uint32) nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint) nic.mu.ndp = ndpState{ - nic: nic, - configs: stack.ndpConfigs, - dad: make(map[tcpip.Address]dadState), - defaultRouters: make(map[tcpip.Address]defaultRouterState), - onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState), - autoGenAddresses: make(map[tcpip.Address]autoGenAddressState), + nic: nic, + configs: stack.ndpConfigs, + dad: make(map[tcpip.Address]dadState), + defaultRouters: make(map[tcpip.Address]defaultRouterState), + onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState), + slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState), } // Register supported packet endpoint protocols. @@ -1017,8 +1017,7 @@ func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { isIPv6Unicast := r.protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(addr) if isIPv6Unicast { - // If we are removing a tentative IPv6 unicast address, stop - // DAD. + // If we are removing a tentative IPv6 unicast address, stop DAD. if kind == permanentTentative { n.mu.ndp.stopDuplicateAddressDetection(addr) } @@ -1026,7 +1025,10 @@ func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { // If we are removing an address generated via SLAAC, cleanup // its SLAAC resources and notify the integrator. if r.configType == slaac { - n.mu.ndp.cleanupAutoGenAddrResourcesAndNotify(addr) + n.mu.ndp.cleanupSLAACAddrResourcesAndNotify(tcpip.AddressWithPrefix{ + Address: addr, + PrefixLen: r.ep.PrefixLen(), + }) } } -- cgit v1.2.3 From 7e4073af12bed2c76bc5757ef3e5fbfba75308a0 Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Tue, 24 Mar 2020 09:05:06 -0700 Subject: Move tcpip.PacketBuffer and IPTables to stack package. This is a precursor to be being able to build an intrusive list of PacketBuffers for use in queuing disciplines being implemented. Updates #2214 PiperOrigin-RevId: 302677662 --- pkg/sentry/socket/netfilter/BUILD | 1 - pkg/sentry/socket/netfilter/extensions.go | 14 +- pkg/sentry/socket/netfilter/netfilter.go | 121 ++++---- pkg/sentry/socket/netfilter/targets.go | 11 +- pkg/sentry/socket/netfilter/tcp_matcher.go | 11 +- pkg/sentry/socket/netfilter/udp_matcher.go | 13 +- pkg/sentry/socket/netstack/BUILD | 1 - pkg/sentry/socket/netstack/stack.go | 3 +- pkg/tcpip/BUILD | 2 - pkg/tcpip/iptables/BUILD | 18 -- pkg/tcpip/iptables/iptables.go | 314 --------------------- pkg/tcpip/iptables/targets.go | 144 ---------- pkg/tcpip/iptables/types.go | 180 ------------ pkg/tcpip/link/channel/channel.go | 14 +- pkg/tcpip/link/fdbased/endpoint.go | 6 +- pkg/tcpip/link/fdbased/endpoint_test.go | 10 +- pkg/tcpip/link/fdbased/mmap.go | 3 +- pkg/tcpip/link/fdbased/packet_dispatchers.go | 4 +- pkg/tcpip/link/loopback/loopback.go | 8 +- pkg/tcpip/link/muxed/injectable.go | 6 +- pkg/tcpip/link/muxed/injectable_test.go | 4 +- pkg/tcpip/link/sharedmem/sharedmem.go | 6 +- pkg/tcpip/link/sharedmem/sharedmem_test.go | 26 +- pkg/tcpip/link/sniffer/sniffer.go | 10 +- pkg/tcpip/link/tun/device.go | 2 +- pkg/tcpip/link/waitable/waitable.go | 6 +- pkg/tcpip/link/waitable/waitable_test.go | 18 +- pkg/tcpip/network/arp/arp.go | 12 +- pkg/tcpip/network/arp/arp_test.go | 2 +- pkg/tcpip/network/ip_test.go | 24 +- pkg/tcpip/network/ipv4/BUILD | 1 - pkg/tcpip/network/ipv4/icmp.go | 9 +- pkg/tcpip/network/ipv4/ipv4.go | 21 +- pkg/tcpip/network/ipv4/ipv4_test.go | 18 +- pkg/tcpip/network/ipv6/icmp.go | 10 +- pkg/tcpip/network/ipv6/icmp_test.go | 14 +- pkg/tcpip/network/ipv6/ipv6.go | 10 +- pkg/tcpip/network/ipv6/ipv6_test.go | 4 +- pkg/tcpip/network/ipv6/ndp_test.go | 8 +- pkg/tcpip/packet_buffer.go | 67 ----- pkg/tcpip/packet_buffer_state.go | 27 -- pkg/tcpip/stack/BUILD | 8 +- pkg/tcpip/stack/forwarder.go | 4 +- pkg/tcpip/stack/forwarder_test.go | 36 +-- pkg/tcpip/stack/iptables.go | 311 ++++++++++++++++++++ pkg/tcpip/stack/iptables_targets.go | 144 ++++++++++ pkg/tcpip/stack/iptables_types.go | 180 ++++++++++++ pkg/tcpip/stack/ndp.go | 4 +- pkg/tcpip/stack/ndp_test.go | 14 +- pkg/tcpip/stack/nic.go | 13 +- pkg/tcpip/stack/nic_test.go | 3 +- pkg/tcpip/stack/packet_buffer.go | 66 +++++ pkg/tcpip/stack/packet_buffer_state.go | 26 ++ pkg/tcpip/stack/registration.go | 32 +-- pkg/tcpip/stack/route.go | 6 +- pkg/tcpip/stack/stack.go | 11 +- pkg/tcpip/stack/stack_test.go | 28 +- pkg/tcpip/stack/transport_demuxer.go | 14 +- pkg/tcpip/stack/transport_demuxer_test.go | 2 +- pkg/tcpip/stack/transport_test.go | 27 +- pkg/tcpip/transport/icmp/BUILD | 1 - pkg/tcpip/transport/icmp/endpoint.go | 11 +- pkg/tcpip/transport/icmp/protocol.go | 2 +- pkg/tcpip/transport/packet/BUILD | 1 - pkg/tcpip/transport/packet/endpoint.go | 9 +- pkg/tcpip/transport/raw/BUILD | 1 - pkg/tcpip/transport/raw/endpoint.go | 9 +- pkg/tcpip/transport/tcp/BUILD | 1 - pkg/tcpip/transport/tcp/connect.go | 6 +- pkg/tcpip/transport/tcp/dispatcher.go | 3 +- pkg/tcpip/transport/tcp/endpoint.go | 7 +- pkg/tcpip/transport/tcp/forwarder.go | 2 +- pkg/tcpip/transport/tcp/protocol.go | 4 +- pkg/tcpip/transport/tcp/segment.go | 3 +- pkg/tcpip/transport/tcp/testing/context/context.go | 10 +- pkg/tcpip/transport/udp/BUILD | 1 - pkg/tcpip/transport/udp/endpoint.go | 9 +- pkg/tcpip/transport/udp/forwarder.go | 4 +- pkg/tcpip/transport/udp/protocol.go | 6 +- pkg/tcpip/transport/udp/udp_test.go | 4 +- 80 files changed, 1080 insertions(+), 1126 deletions(-) delete mode 100644 pkg/tcpip/iptables/BUILD delete mode 100644 pkg/tcpip/iptables/iptables.go delete mode 100644 pkg/tcpip/iptables/targets.go delete mode 100644 pkg/tcpip/iptables/types.go delete mode 100644 pkg/tcpip/packet_buffer.go delete mode 100644 pkg/tcpip/packet_buffer_state.go create mode 100644 pkg/tcpip/stack/iptables.go create mode 100644 pkg/tcpip/stack/iptables_targets.go create mode 100644 pkg/tcpip/stack/iptables_types.go create mode 100644 pkg/tcpip/stack/packet_buffer.go create mode 100644 pkg/tcpip/stack/packet_buffer_state.go (limited to 'pkg/tcpip/stack') diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD index 7cd2ce55b..e801abeb8 100644 --- a/pkg/sentry/socket/netfilter/BUILD +++ b/pkg/sentry/socket/netfilter/BUILD @@ -22,7 +22,6 @@ go_library( "//pkg/syserr", "//pkg/tcpip", "//pkg/tcpip/header", - "//pkg/tcpip/iptables", "//pkg/tcpip/stack", "//pkg/usermem", ], diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go index b4b244abf..0336a32d8 100644 --- a/pkg/sentry/socket/netfilter/extensions.go +++ b/pkg/sentry/socket/netfilter/extensions.go @@ -19,7 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/tcpip/iptables" + "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/usermem" ) @@ -37,12 +37,12 @@ type matchMaker interface { // name is the matcher name as stored in the xt_entry_match struct. name() string - // marshal converts from an iptables.Matcher to an ABI struct. - marshal(matcher iptables.Matcher) []byte + // marshal converts from an stack.Matcher to an ABI struct. + marshal(matcher stack.Matcher) []byte // unmarshal converts from the ABI matcher struct to an - // iptables.Matcher. - unmarshal(buf []byte, filter iptables.IPHeaderFilter) (iptables.Matcher, error) + // stack.Matcher. + unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) } // matchMakers maps the name of supported matchers to the matchMaker that @@ -58,7 +58,7 @@ func registerMatchMaker(mm matchMaker) { matchMakers[mm.name()] = mm } -func marshalMatcher(matcher iptables.Matcher) []byte { +func marshalMatcher(matcher stack.Matcher) []byte { matchMaker, ok := matchMakers[matcher.Name()] if !ok { panic(fmt.Sprintf("Unknown matcher of type %T.", matcher)) @@ -86,7 +86,7 @@ func marshalEntryMatch(name string, data []byte) []byte { return append(buf, make([]byte, size-len(buf))...) } -func unmarshalMatcher(match linux.XTEntryMatch, filter iptables.IPHeaderFilter, buf []byte) (iptables.Matcher, error) { +func unmarshalMatcher(match linux.XTEntryMatch, filter stack.IPHeaderFilter, buf []byte) (stack.Matcher, error) { matchMaker, ok := matchMakers[match.Name.String()] if !ok { return nil, fmt.Errorf("unsupported matcher with name %q", match.Name.String()) diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index b5b9be46f..55bcc3ace 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -27,7 +27,6 @@ import ( "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/usermem" ) @@ -129,19 +128,19 @@ func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen return entries, nil } -func findTable(stack *stack.Stack, tablename linux.TableName) (iptables.Table, error) { - ipt := stack.IPTables() +func findTable(stk *stack.Stack, tablename linux.TableName) (stack.Table, error) { + ipt := stk.IPTables() table, ok := ipt.Tables[tablename.String()] if !ok { - return iptables.Table{}, fmt.Errorf("couldn't find table %q", tablename) + return stack.Table{}, fmt.Errorf("couldn't find table %q", tablename) } return table, nil } // FillDefaultIPTables sets stack's IPTables to the default tables and // populates them with metadata. -func FillDefaultIPTables(stack *stack.Stack) { - ipt := iptables.DefaultTables() +func FillDefaultIPTables(stk *stack.Stack) { + ipt := stack.DefaultTables() // In order to fill in the metadata, we have to translate ipt from its // netstack format to Linux's giant-binary-blob format. @@ -154,14 +153,14 @@ func FillDefaultIPTables(stack *stack.Stack) { ipt.Tables[name] = table } - stack.SetIPTables(ipt) + stk.SetIPTables(ipt) } // convertNetstackToBinary converts the iptables as stored in netstack to the // format expected by the iptables tool. Linux stores each table as a binary // blob that can only be traversed by parsing a bit, reading some offsets, // jumping to those offsets, parsing again, etc. -func convertNetstackToBinary(tablename string, table iptables.Table) (linux.KernelIPTGetEntries, metadata, error) { +func convertNetstackToBinary(tablename string, table stack.Table) (linux.KernelIPTGetEntries, metadata, error) { // Return values. var entries linux.KernelIPTGetEntries var meta metadata @@ -234,19 +233,19 @@ func convertNetstackToBinary(tablename string, table iptables.Table) (linux.Kern return entries, meta, nil } -func marshalTarget(target iptables.Target) []byte { +func marshalTarget(target stack.Target) []byte { switch tg := target.(type) { - case iptables.AcceptTarget: - return marshalStandardTarget(iptables.RuleAccept) - case iptables.DropTarget: - return marshalStandardTarget(iptables.RuleDrop) - case iptables.ErrorTarget: + case stack.AcceptTarget: + return marshalStandardTarget(stack.RuleAccept) + case stack.DropTarget: + return marshalStandardTarget(stack.RuleDrop) + case stack.ErrorTarget: return marshalErrorTarget(errorTargetName) - case iptables.UserChainTarget: + case stack.UserChainTarget: return marshalErrorTarget(tg.Name) - case iptables.ReturnTarget: - return marshalStandardTarget(iptables.RuleReturn) - case iptables.RedirectTarget: + case stack.ReturnTarget: + return marshalStandardTarget(stack.RuleReturn) + case stack.RedirectTarget: return marshalRedirectTarget() case JumpTarget: return marshalJumpTarget(tg) @@ -255,7 +254,7 @@ func marshalTarget(target iptables.Target) []byte { } } -func marshalStandardTarget(verdict iptables.RuleVerdict) []byte { +func marshalStandardTarget(verdict stack.RuleVerdict) []byte { nflog("convert to binary: marshalling standard target") // The target's name will be the empty string. @@ -316,13 +315,13 @@ func marshalJumpTarget(jt JumpTarget) []byte { // translateFromStandardVerdict translates verdicts the same way as the iptables // tool. -func translateFromStandardVerdict(verdict iptables.RuleVerdict) int32 { +func translateFromStandardVerdict(verdict stack.RuleVerdict) int32 { switch verdict { - case iptables.RuleAccept: + case stack.RuleAccept: return -linux.NF_ACCEPT - 1 - case iptables.RuleDrop: + case stack.RuleDrop: return -linux.NF_DROP - 1 - case iptables.RuleReturn: + case stack.RuleReturn: return linux.NF_RETURN default: // TODO(gvisor.dev/issue/170): Support Jump. @@ -331,18 +330,18 @@ func translateFromStandardVerdict(verdict iptables.RuleVerdict) int32 { } // translateToStandardTarget translates from the value in a -// linux.XTStandardTarget to an iptables.Verdict. -func translateToStandardTarget(val int32) (iptables.Target, error) { +// linux.XTStandardTarget to an stack.Verdict. +func translateToStandardTarget(val int32) (stack.Target, error) { // TODO(gvisor.dev/issue/170): Support other verdicts. switch val { case -linux.NF_ACCEPT - 1: - return iptables.AcceptTarget{}, nil + return stack.AcceptTarget{}, nil case -linux.NF_DROP - 1: - return iptables.DropTarget{}, nil + return stack.DropTarget{}, nil case -linux.NF_QUEUE - 1: return nil, errors.New("unsupported iptables verdict QUEUE") case linux.NF_RETURN: - return iptables.ReturnTarget{}, nil + return stack.ReturnTarget{}, nil default: return nil, fmt.Errorf("unknown iptables verdict %d", val) } @@ -350,7 +349,7 @@ func translateToStandardTarget(val int32) (iptables.Target, error) { // SetEntries sets iptables rules for a single table. See // net/ipv4/netfilter/ip_tables.c:translate_table for reference. -func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { +func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { // Get the basic rules data (struct ipt_replace). if len(optVal) < linux.SizeOfIPTReplace { nflog("optVal has insufficient size for replace %d", len(optVal)) @@ -362,12 +361,12 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { binary.Unmarshal(replaceBuf, usermem.ByteOrder, &replace) // TODO(gvisor.dev/issue/170): Support other tables. - var table iptables.Table + var table stack.Table switch replace.Name.String() { - case iptables.TablenameFilter: - table = iptables.EmptyFilterTable() - case iptables.TablenameNat: - table = iptables.EmptyNatTable() + case stack.TablenameFilter: + table = stack.EmptyFilterTable() + case stack.TablenameNat: + table = stack.EmptyNatTable() default: nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String()) return syserr.ErrInvalidArgument @@ -434,7 +433,7 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { } optVal = optVal[targetSize:] - table.Rules = append(table.Rules, iptables.Rule{ + table.Rules = append(table.Rules, stack.Rule{ Filter: filter, Target: target, Matchers: matchers, @@ -465,11 +464,11 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { table.Underflows[hk] = ruleIdx } } - if ruleIdx := table.BuiltinChains[hk]; ruleIdx == iptables.HookUnset { + if ruleIdx := table.BuiltinChains[hk]; ruleIdx == stack.HookUnset { nflog("hook %v is unset.", hk) return syserr.ErrInvalidArgument } - if ruleIdx := table.Underflows[hk]; ruleIdx == iptables.HookUnset { + if ruleIdx := table.Underflows[hk]; ruleIdx == stack.HookUnset { nflog("underflow %v is unset.", hk) return syserr.ErrInvalidArgument } @@ -478,7 +477,7 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { // Add the user chains. for ruleIdx, rule := range table.Rules { - target, ok := rule.Target.(iptables.UserChainTarget) + target, ok := rule.Target.(stack.UserChainTarget) if !ok { continue } @@ -522,8 +521,8 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { // PREROUTING chain right now, make sure all other chains point to // ACCEPT rules. for hook, ruleIdx := range table.BuiltinChains { - if hook != iptables.Input && hook != iptables.Prerouting { - if _, ok := table.Rules[ruleIdx].Target.(iptables.AcceptTarget); !ok { + if hook != stack.Input && hook != stack.Prerouting { + if _, ok := table.Rules[ruleIdx].Target.(stack.AcceptTarget); !ok { nflog("hook %d is unsupported.", hook) return syserr.ErrInvalidArgument } @@ -535,7 +534,7 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { // - There are no chains without an unconditional final rule. // - There are no chains without an unconditional underflow rule. - ipt := stack.IPTables() + ipt := stk.IPTables() table.SetMetadata(metadata{ HookEntry: replace.HookEntry, Underflow: replace.Underflow, @@ -543,16 +542,16 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { Size: replace.Size, }) ipt.Tables[replace.Name.String()] = table - stack.SetIPTables(ipt) + stk.SetIPTables(ipt) return nil } // parseMatchers parses 0 or more matchers from optVal. optVal should contain // only the matchers. -func parseMatchers(filter iptables.IPHeaderFilter, optVal []byte) ([]iptables.Matcher, error) { +func parseMatchers(filter stack.IPHeaderFilter, optVal []byte) ([]stack.Matcher, error) { nflog("set entries: parsing matchers of size %d", len(optVal)) - var matchers []iptables.Matcher + var matchers []stack.Matcher for len(optVal) > 0 { nflog("set entries: optVal has len %d", len(optVal)) @@ -594,7 +593,7 @@ func parseMatchers(filter iptables.IPHeaderFilter, optVal []byte) ([]iptables.Ma // parseTarget parses a target from optVal. optVal should contain only the // target. -func parseTarget(filter iptables.IPHeaderFilter, optVal []byte) (iptables.Target, error) { +func parseTarget(filter stack.IPHeaderFilter, optVal []byte) (stack.Target, error) { nflog("set entries: parsing target of size %d", len(optVal)) if len(optVal) < linux.SizeOfXTEntryTarget { return nil, fmt.Errorf("optVal has insufficient size for entry target %d", len(optVal)) @@ -638,11 +637,11 @@ func parseTarget(filter iptables.IPHeaderFilter, optVal []byte) (iptables.Target switch name := errorTarget.Name.String(); name { case errorTargetName: nflog("set entries: error target") - return iptables.ErrorTarget{}, nil + return stack.ErrorTarget{}, nil default: // User defined chain. nflog("set entries: user-defined target %q", name) - return iptables.UserChainTarget{Name: name}, nil + return stack.UserChainTarget{Name: name}, nil } case redirectTargetName: @@ -659,8 +658,8 @@ func parseTarget(filter iptables.IPHeaderFilter, optVal []byte) (iptables.Target buf = optVal[:linux.SizeOfXTRedirectTarget] binary.Unmarshal(buf, usermem.ByteOrder, &redirectTarget) - // Copy linux.XTRedirectTarget to iptables.RedirectTarget. - var target iptables.RedirectTarget + // Copy linux.XTRedirectTarget to stack.RedirectTarget. + var target stack.RedirectTarget nfRange := redirectTarget.NfRange // RangeSize should be 1. @@ -699,14 +698,14 @@ func parseTarget(filter iptables.IPHeaderFilter, optVal []byte) (iptables.Target return nil, fmt.Errorf("unknown target %q doesn't exist or isn't supported yet.", target.Name.String()) } -func filterFromIPTIP(iptip linux.IPTIP) (iptables.IPHeaderFilter, error) { +func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) { if containsUnsupportedFields(iptip) { - return iptables.IPHeaderFilter{}, fmt.Errorf("unsupported fields in struct iptip: %+v", iptip) + return stack.IPHeaderFilter{}, fmt.Errorf("unsupported fields in struct iptip: %+v", iptip) } if len(iptip.Dst) != header.IPv4AddressSize || len(iptip.DstMask) != header.IPv4AddressSize { - return iptables.IPHeaderFilter{}, fmt.Errorf("incorrect length of destination (%d) and/or destination mask (%d) fields", len(iptip.Dst), len(iptip.DstMask)) + return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of destination (%d) and/or destination mask (%d) fields", len(iptip.Dst), len(iptip.DstMask)) } - return iptables.IPHeaderFilter{ + return stack.IPHeaderFilter{ Protocol: tcpip.TransportProtocolNumber(iptip.Protocol), Dst: tcpip.Address(iptip.Dst[:]), DstMask: tcpip.Address(iptip.DstMask[:]), @@ -733,30 +732,30 @@ func containsUnsupportedFields(iptip linux.IPTIP) bool { iptip.InverseFlags&^inverseMask != 0 } -func validUnderflow(rule iptables.Rule) bool { +func validUnderflow(rule stack.Rule) bool { if len(rule.Matchers) != 0 { return false } switch rule.Target.(type) { - case iptables.AcceptTarget, iptables.DropTarget: + case stack.AcceptTarget, stack.DropTarget: return true default: return false } } -func hookFromLinux(hook int) iptables.Hook { +func hookFromLinux(hook int) stack.Hook { switch hook { case linux.NF_INET_PRE_ROUTING: - return iptables.Prerouting + return stack.Prerouting case linux.NF_INET_LOCAL_IN: - return iptables.Input + return stack.Input case linux.NF_INET_FORWARD: - return iptables.Forward + return stack.Forward case linux.NF_INET_LOCAL_OUT: - return iptables.Output + return stack.Output case linux.NF_INET_POST_ROUTING: - return iptables.Postrouting + return stack.Postrouting } panic(fmt.Sprintf("Unknown hook %d does not correspond to a builtin chain", hook)) } diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go index c421b87cf..c948de876 100644 --- a/pkg/sentry/socket/netfilter/targets.go +++ b/pkg/sentry/socket/netfilter/targets.go @@ -15,11 +15,10 @@ package netfilter import ( - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/iptables" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) -// JumpTarget implements iptables.Target. +// JumpTarget implements stack.Target. type JumpTarget struct { // Offset is the byte offset of the rule to jump to. It is used for // marshaling and unmarshaling. @@ -29,7 +28,7 @@ type JumpTarget struct { RuleNum int } -// Action implements iptables.Target.Action. -func (jt JumpTarget) Action(tcpip.PacketBuffer) (iptables.RuleVerdict, int) { - return iptables.RuleJump, jt.RuleNum +// Action implements stack.Target.Action. +func (jt JumpTarget) Action(stack.PacketBuffer) (stack.RuleVerdict, int) { + return stack.RuleJump, jt.RuleNum } diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index f9945e214..ff1cfd8f6 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -19,9 +19,8 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/iptables" + "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/usermem" ) @@ -40,7 +39,7 @@ func (tcpMarshaler) name() string { } // marshal implements matchMaker.marshal. -func (tcpMarshaler) marshal(mr iptables.Matcher) []byte { +func (tcpMarshaler) marshal(mr stack.Matcher) []byte { matcher := mr.(*TCPMatcher) xttcp := linux.XTTCP{ SourcePortStart: matcher.sourcePortStart, @@ -53,7 +52,7 @@ func (tcpMarshaler) marshal(mr iptables.Matcher) []byte { } // unmarshal implements matchMaker.unmarshal. -func (tcpMarshaler) unmarshal(buf []byte, filter iptables.IPHeaderFilter) (iptables.Matcher, error) { +func (tcpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) { if len(buf) < linux.SizeOfXTTCP { return nil, fmt.Errorf("buf has insufficient size for TCP match: %d", len(buf)) } @@ -97,7 +96,7 @@ func (*TCPMatcher) Name() string { } // Match implements Matcher.Match. -func (tm *TCPMatcher) Match(hook iptables.Hook, pkt tcpip.PacketBuffer, interfaceName string) (bool, bool) { +func (tm *TCPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceName string) (bool, bool) { netHeader := header.IPv4(pkt.NetworkHeader) if netHeader.TransportProtocol() != header.TCPProtocolNumber { @@ -115,7 +114,7 @@ func (tm *TCPMatcher) Match(hook iptables.Hook, pkt tcpip.PacketBuffer, interfac // Now we need the transport header. However, this may not have been set // yet. // TODO(gvisor.dev/issue/170): Parsing the transport header should - // ultimately be moved into the iptables.Check codepath as matchers are + // ultimately be moved into the stack.Check codepath as matchers are // added. var tcpHeader header.TCP if pkt.TransportHeader != nil { diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index 86aa11696..3359418c1 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -19,9 +19,8 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/iptables" + "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/usermem" ) @@ -40,7 +39,7 @@ func (udpMarshaler) name() string { } // marshal implements matchMaker.marshal. -func (udpMarshaler) marshal(mr iptables.Matcher) []byte { +func (udpMarshaler) marshal(mr stack.Matcher) []byte { matcher := mr.(*UDPMatcher) xtudp := linux.XTUDP{ SourcePortStart: matcher.sourcePortStart, @@ -53,7 +52,7 @@ func (udpMarshaler) marshal(mr iptables.Matcher) []byte { } // unmarshal implements matchMaker.unmarshal. -func (udpMarshaler) unmarshal(buf []byte, filter iptables.IPHeaderFilter) (iptables.Matcher, error) { +func (udpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) { if len(buf) < linux.SizeOfXTUDP { return nil, fmt.Errorf("buf has insufficient size for UDP match: %d", len(buf)) } @@ -94,11 +93,11 @@ func (*UDPMatcher) Name() string { } // Match implements Matcher.Match. -func (um *UDPMatcher) Match(hook iptables.Hook, pkt tcpip.PacketBuffer, interfaceName string) (bool, bool) { +func (um *UDPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceName string) (bool, bool) { netHeader := header.IPv4(pkt.NetworkHeader) // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved - // into the iptables.Check codepath as matchers are added. + // into the stack.Check codepath as matchers are added. if netHeader.TransportProtocol() != header.UDPProtocolNumber { return false, false } @@ -114,7 +113,7 @@ func (um *UDPMatcher) Match(hook iptables.Hook, pkt tcpip.PacketBuffer, interfac // Now we need the transport header. However, this may not have been set // yet. // TODO(gvisor.dev/issue/170): Parsing the transport header should - // ultimately be moved into the iptables.Check codepath as matchers are + // ultimately be moved into the stack.Check codepath as matchers are // added. var udpHeader header.UDP if pkt.TransportHeader != nil { diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index ab01cb4fa..cbf46b1e9 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -38,7 +38,6 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", - "//pkg/tcpip/iptables", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index a8e2e8c24..f5fa18136 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -23,7 +23,6 @@ import ( "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -363,7 +362,7 @@ func (s *Stack) RouteTable() []inet.Route { } // IPTables returns the stack's iptables. -func (s *Stack) IPTables() (iptables.IPTables, error) { +func (s *Stack) IPTables() (stack.IPTables, error) { return s.Stack.IPTables(), nil } diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index 26f7ba86b..454e07662 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -5,8 +5,6 @@ package(licenses = ["notice"]) go_library( name = "tcpip", srcs = [ - "packet_buffer.go", - "packet_buffer_state.go", "tcpip.go", "time_unsafe.go", "timer.go", diff --git a/pkg/tcpip/iptables/BUILD b/pkg/tcpip/iptables/BUILD deleted file mode 100644 index d1b73cfdf..000000000 --- a/pkg/tcpip/iptables/BUILD +++ /dev/null @@ -1,18 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "iptables", - srcs = [ - "iptables.go", - "targets.go", - "types.go", - ], - visibility = ["//visibility:public"], - deps = [ - "//pkg/log", - "//pkg/tcpip", - "//pkg/tcpip/header", - ], -) diff --git a/pkg/tcpip/iptables/iptables.go b/pkg/tcpip/iptables/iptables.go deleted file mode 100644 index d30571c74..000000000 --- a/pkg/tcpip/iptables/iptables.go +++ /dev/null @@ -1,314 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package iptables supports packet filtering and manipulation via the iptables -// tool. -package iptables - -import ( - "fmt" - - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" -) - -// Table names. -const ( - TablenameNat = "nat" - TablenameMangle = "mangle" - TablenameFilter = "filter" -) - -// Chain names as defined by net/ipv4/netfilter/ip_tables.c. -const ( - ChainNamePrerouting = "PREROUTING" - ChainNameInput = "INPUT" - ChainNameForward = "FORWARD" - ChainNameOutput = "OUTPUT" - ChainNamePostrouting = "POSTROUTING" -) - -// HookUnset indicates that there is no hook set for an entrypoint or -// underflow. -const HookUnset = -1 - -// DefaultTables returns a default set of tables. Each chain is set to accept -// all packets. -func DefaultTables() IPTables { - // TODO(gvisor.dev/issue/170): We may be able to swap out some strings for - // iotas. - return IPTables{ - Tables: map[string]Table{ - TablenameNat: Table{ - Rules: []Rule{ - Rule{Target: AcceptTarget{}}, - Rule{Target: AcceptTarget{}}, - Rule{Target: AcceptTarget{}}, - Rule{Target: AcceptTarget{}}, - Rule{Target: ErrorTarget{}}, - }, - BuiltinChains: map[Hook]int{ - Prerouting: 0, - Input: 1, - Output: 2, - Postrouting: 3, - }, - Underflows: map[Hook]int{ - Prerouting: 0, - Input: 1, - Output: 2, - Postrouting: 3, - }, - UserChains: map[string]int{}, - }, - TablenameMangle: Table{ - Rules: []Rule{ - Rule{Target: AcceptTarget{}}, - Rule{Target: AcceptTarget{}}, - Rule{Target: ErrorTarget{}}, - }, - BuiltinChains: map[Hook]int{ - Prerouting: 0, - Output: 1, - }, - Underflows: map[Hook]int{ - Prerouting: 0, - Output: 1, - }, - UserChains: map[string]int{}, - }, - TablenameFilter: Table{ - Rules: []Rule{ - Rule{Target: AcceptTarget{}}, - Rule{Target: AcceptTarget{}}, - Rule{Target: AcceptTarget{}}, - Rule{Target: ErrorTarget{}}, - }, - BuiltinChains: map[Hook]int{ - Input: 0, - Forward: 1, - Output: 2, - }, - Underflows: map[Hook]int{ - Input: 0, - Forward: 1, - Output: 2, - }, - UserChains: map[string]int{}, - }, - }, - Priorities: map[Hook][]string{ - Input: []string{TablenameNat, TablenameFilter}, - Prerouting: []string{TablenameMangle, TablenameNat}, - Output: []string{TablenameMangle, TablenameNat, TablenameFilter}, - }, - } -} - -// EmptyFilterTable returns a Table with no rules and the filter table chains -// mapped to HookUnset. -func EmptyFilterTable() Table { - return Table{ - Rules: []Rule{}, - BuiltinChains: map[Hook]int{ - Input: HookUnset, - Forward: HookUnset, - Output: HookUnset, - }, - Underflows: map[Hook]int{ - Input: HookUnset, - Forward: HookUnset, - Output: HookUnset, - }, - UserChains: map[string]int{}, - } -} - -// EmptyNatTable returns a Table with no rules and the filter table chains -// mapped to HookUnset. -func EmptyNatTable() Table { - return Table{ - Rules: []Rule{}, - BuiltinChains: map[Hook]int{ - Prerouting: HookUnset, - Input: HookUnset, - Output: HookUnset, - Postrouting: HookUnset, - }, - Underflows: map[Hook]int{ - Prerouting: HookUnset, - Input: HookUnset, - Output: HookUnset, - Postrouting: HookUnset, - }, - UserChains: map[string]int{}, - } -} - -// A chainVerdict is what a table decides should be done with a packet. -type chainVerdict int - -const ( - // chainAccept indicates the packet should continue through netstack. - chainAccept chainVerdict = iota - - // chainAccept indicates the packet should be dropped. - chainDrop - - // chainReturn indicates the packet should return to the calling chain - // or the underflow rule of a builtin chain. - chainReturn -) - -// Check runs pkt through the rules for hook. It returns true when the packet -// should continue traversing the network stack and false when it should be -// dropped. -// -// Precondition: pkt.NetworkHeader is set. -func (it *IPTables) Check(hook Hook, pkt tcpip.PacketBuffer) bool { - // Go through each table containing the hook. - for _, tablename := range it.Priorities[hook] { - table := it.Tables[tablename] - ruleIdx := table.BuiltinChains[hook] - switch verdict := it.checkChain(hook, pkt, table, ruleIdx); verdict { - // If the table returns Accept, move on to the next table. - case chainAccept: - continue - // The Drop verdict is final. - case chainDrop: - return false - case chainReturn: - // Any Return from a built-in chain means we have to - // call the underflow. - underflow := table.Rules[table.Underflows[hook]] - switch v, _ := underflow.Target.Action(pkt); v { - case RuleAccept: - continue - case RuleDrop: - return false - case RuleJump, RuleReturn: - panic("Underflows should only return RuleAccept or RuleDrop.") - default: - panic(fmt.Sprintf("Unknown verdict: %d", v)) - } - - default: - panic(fmt.Sprintf("Unknown verdict %v.", verdict)) - } - } - - // Every table returned Accept. - return true -} - -// Precondition: pkt.NetworkHeader is set. -func (it *IPTables) checkChain(hook Hook, pkt tcpip.PacketBuffer, table Table, ruleIdx int) chainVerdict { - // Start from ruleIdx and walk the list of rules until a rule gives us - // a verdict. - for ruleIdx < len(table.Rules) { - switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx); verdict { - case RuleAccept: - return chainAccept - - case RuleDrop: - return chainDrop - - case RuleReturn: - return chainReturn - - case RuleJump: - // "Jumping" to the next rule just means we're - // continuing on down the list. - if jumpTo == ruleIdx+1 { - ruleIdx++ - continue - } - switch verdict := it.checkChain(hook, pkt, table, jumpTo); verdict { - case chainAccept: - return chainAccept - case chainDrop: - return chainDrop - case chainReturn: - ruleIdx++ - continue - default: - panic(fmt.Sprintf("Unknown verdict: %d", verdict)) - } - - default: - panic(fmt.Sprintf("Unknown verdict: %d", verdict)) - } - - } - - // We got through the entire table without a decision. Default to DROP - // for safety. - return chainDrop -} - -// Precondition: pk.NetworkHeader is set. -func (it *IPTables) checkRule(hook Hook, pkt tcpip.PacketBuffer, table Table, ruleIdx int) (RuleVerdict, int) { - rule := table.Rules[ruleIdx] - - // If pkt.NetworkHeader hasn't been set yet, it will be contained in - // pkt.Data.First(). - if pkt.NetworkHeader == nil { - pkt.NetworkHeader = pkt.Data.First() - } - - // Check whether the packet matches the IP header filter. - if !filterMatch(rule.Filter, header.IPv4(pkt.NetworkHeader)) { - // Continue on to the next rule. - return RuleJump, ruleIdx + 1 - } - - // Go through each rule matcher. If they all match, run - // the rule target. - for _, matcher := range rule.Matchers { - matches, hotdrop := matcher.Match(hook, pkt, "") - if hotdrop { - return RuleDrop, 0 - } - if !matches { - // Continue on to the next rule. - return RuleJump, ruleIdx + 1 - } - } - - // All the matchers matched, so run the target. - return rule.Target.Action(pkt) -} - -func filterMatch(filter IPHeaderFilter, hdr header.IPv4) bool { - // TODO(gvisor.dev/issue/170): Support other fields of the filter. - // Check the transport protocol. - if filter.Protocol != 0 && filter.Protocol != hdr.TransportProtocol() { - return false - } - - // Check the destination IP. - dest := hdr.DestinationAddress() - matches := true - for i := range filter.Dst { - if dest[i]&filter.DstMask[i] != filter.Dst[i] { - matches = false - break - } - } - if matches == filter.DstInvert { - return false - } - - return true -} diff --git a/pkg/tcpip/iptables/targets.go b/pkg/tcpip/iptables/targets.go deleted file mode 100644 index e457f2349..000000000 --- a/pkg/tcpip/iptables/targets.go +++ /dev/null @@ -1,144 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package iptables - -import ( - "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" -) - -// AcceptTarget accepts packets. -type AcceptTarget struct{} - -// Action implements Target.Action. -func (AcceptTarget) Action(packet tcpip.PacketBuffer) (RuleVerdict, int) { - return RuleAccept, 0 -} - -// DropTarget drops packets. -type DropTarget struct{} - -// Action implements Target.Action. -func (DropTarget) Action(packet tcpip.PacketBuffer) (RuleVerdict, int) { - return RuleDrop, 0 -} - -// ErrorTarget logs an error and drops the packet. It represents a target that -// should be unreachable. -type ErrorTarget struct{} - -// Action implements Target.Action. -func (ErrorTarget) Action(packet tcpip.PacketBuffer) (RuleVerdict, int) { - log.Debugf("ErrorTarget triggered.") - return RuleDrop, 0 -} - -// UserChainTarget marks a rule as the beginning of a user chain. -type UserChainTarget struct { - Name string -} - -// Action implements Target.Action. -func (UserChainTarget) Action(tcpip.PacketBuffer) (RuleVerdict, int) { - panic("UserChainTarget should never be called.") -} - -// ReturnTarget returns from the current chain. If the chain is a built-in, the -// hook's underflow should be called. -type ReturnTarget struct{} - -// Action implements Target.Action. -func (ReturnTarget) Action(tcpip.PacketBuffer) (RuleVerdict, int) { - return RuleReturn, 0 -} - -// RedirectTarget redirects the packet by modifying the destination port/IP. -// Min and Max values for IP and Ports in the struct indicate the range of -// values which can be used to redirect. -type RedirectTarget struct { - // TODO(gvisor.dev/issue/170): Other flags need to be added after - // we support them. - // RangeProtoSpecified flag indicates single port is specified to - // redirect. - RangeProtoSpecified bool - - // Min address used to redirect. - MinIP tcpip.Address - - // Max address used to redirect. - MaxIP tcpip.Address - - // Min port used to redirect. - MinPort uint16 - - // Max port used to redirect. - MaxPort uint16 -} - -// Action implements Target.Action. -// TODO(gvisor.dev/issue/170): Parse headers without copying. The current -// implementation only works for PREROUTING and calls pkt.Clone(), neither -// of which should be the case. -func (rt RedirectTarget) Action(pkt tcpip.PacketBuffer) (RuleVerdict, int) { - newPkt := pkt.Clone() - - // Set network header. - headerView := newPkt.Data.First() - netHeader := header.IPv4(headerView) - newPkt.NetworkHeader = headerView[:header.IPv4MinimumSize] - - hlen := int(netHeader.HeaderLength()) - tlen := int(netHeader.TotalLength()) - newPkt.Data.TrimFront(hlen) - newPkt.Data.CapLength(tlen - hlen) - - // TODO(gvisor.dev/issue/170): Change destination address to - // loopback or interface address on which the packet was - // received. - - // TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if - // we need to change dest address (for OUTPUT chain) or ports. - switch protocol := netHeader.TransportProtocol(); protocol { - case header.UDPProtocolNumber: - var udpHeader header.UDP - if newPkt.TransportHeader != nil { - udpHeader = header.UDP(newPkt.TransportHeader) - } else { - if len(pkt.Data.First()) < header.UDPMinimumSize { - return RuleDrop, 0 - } - udpHeader = header.UDP(newPkt.Data.First()) - } - udpHeader.SetDestinationPort(rt.MinPort) - case header.TCPProtocolNumber: - var tcpHeader header.TCP - if newPkt.TransportHeader != nil { - tcpHeader = header.TCP(newPkt.TransportHeader) - } else { - if len(pkt.Data.First()) < header.TCPMinimumSize { - return RuleDrop, 0 - } - tcpHeader = header.TCP(newPkt.TransportHeader) - } - // TODO(gvisor.dev/issue/170): Need to recompute checksum - // and implement nat connection tracking to support TCP. - tcpHeader.SetDestinationPort(rt.MinPort) - default: - return RuleDrop, 0 - } - - return RuleAccept, 0 -} diff --git a/pkg/tcpip/iptables/types.go b/pkg/tcpip/iptables/types.go deleted file mode 100644 index e7fcf6bff..000000000 --- a/pkg/tcpip/iptables/types.go +++ /dev/null @@ -1,180 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package iptables - -import ( - "gvisor.dev/gvisor/pkg/tcpip" -) - -// A Hook specifies one of the hooks built into the network stack. -// -// Userspace app Userspace app -// ^ | -// | v -// [Input] [Output] -// ^ | -// | v -// | routing -// | | -// | v -// ----->[Prerouting]----->routing----->[Forward]---------[Postrouting]-----> -type Hook uint - -// These values correspond to values in include/uapi/linux/netfilter.h. -const ( - // Prerouting happens before a packet is routed to applications or to - // be forwarded. - Prerouting Hook = iota - - // Input happens before a packet reaches an application. - Input - - // Forward happens once it's decided that a packet should be forwarded - // to another host. - Forward - - // Output happens after a packet is written by an application to be - // sent out. - Output - - // Postrouting happens just before a packet goes out on the wire. - Postrouting - - // The total number of hooks. - NumHooks -) - -// A RuleVerdict is what a rule decides should be done with a packet. -type RuleVerdict int - -const ( - // RuleAccept indicates the packet should continue through netstack. - RuleAccept RuleVerdict = iota - - // RuleDrop indicates the packet should be dropped. - RuleDrop - - // RuleJump indicates the packet should jump to another chain. - RuleJump - - // RuleReturn indicates the packet should return to the previous chain. - RuleReturn -) - -// IPTables holds all the tables for a netstack. -type IPTables struct { - // Tables maps table names to tables. User tables have arbitrary names. - Tables map[string]Table - - // Priorities maps each hook to a list of table names. The order of the - // list is the order in which each table should be visited for that - // hook. - Priorities map[Hook][]string -} - -// A Table defines a set of chains and hooks into the network stack. It is -// really just a list of rules with some metadata for entrypoints and such. -type Table struct { - // Rules holds the rules that make up the table. - Rules []Rule - - // BuiltinChains maps builtin chains to their entrypoint rule in Rules. - BuiltinChains map[Hook]int - - // Underflows maps builtin chains to their underflow rule in Rules - // (i.e. the rule to execute if the chain returns without a verdict). - Underflows map[Hook]int - - // UserChains holds user-defined chains for the keyed by name. Users - // can give their chains arbitrary names. - UserChains map[string]int - - // Metadata holds information about the Table that is useful to users - // of IPTables, but not to the netstack IPTables code itself. - metadata interface{} -} - -// ValidHooks returns a bitmap of the builtin hooks for the given table. -func (table *Table) ValidHooks() uint32 { - hooks := uint32(0) - for hook := range table.BuiltinChains { - hooks |= 1 << hook - } - return hooks -} - -// Metadata returns the metadata object stored in table. -func (table *Table) Metadata() interface{} { - return table.metadata -} - -// SetMetadata sets the metadata object stored in table. -func (table *Table) SetMetadata(metadata interface{}) { - table.metadata = metadata -} - -// A Rule is a packet processing rule. It consists of two pieces. First it -// contains zero or more matchers, each of which is a specification of which -// packets this rule applies to. If there are no matchers in the rule, it -// applies to any packet. -type Rule struct { - // Filter holds basic IP filtering fields common to every rule. - Filter IPHeaderFilter - - // Matchers is the list of matchers for this rule. - Matchers []Matcher - - // Target is the action to invoke if all the matchers match the packet. - Target Target -} - -// IPHeaderFilter holds basic IP filtering data common to every rule. -type IPHeaderFilter struct { - // Protocol matches the transport protocol. - Protocol tcpip.TransportProtocolNumber - - // Dst matches the destination IP address. - Dst tcpip.Address - - // DstMask masks bits of the destination IP address when comparing with - // Dst. - DstMask tcpip.Address - - // DstInvert inverts the meaning of the destination IP check, i.e. when - // true the filter will match packets that fail the destination - // comparison. - DstInvert bool -} - -// A Matcher is the interface for matching packets. -type Matcher interface { - // Name returns the name of the Matcher. - Name() string - - // Match returns whether the packet matches and whether the packet - // should be "hotdropped", i.e. dropped immediately. This is usually - // used for suspicious packets. - // - // Precondition: packet.NetworkHeader is set. - Match(hook Hook, packet tcpip.PacketBuffer, interfaceName string) (matches bool, hotdrop bool) -} - -// A Target is the interface for taking an action for a packet. -type Target interface { - // Action takes an action on the packet and returns a verdict on how - // traversal should (or should not) continue. If the return value is - // Jump, it also returns the index of the rule to jump to. - Action(packet tcpip.PacketBuffer) (RuleVerdict, int) -} diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index 5944ba190..a8d6653ce 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -28,7 +28,7 @@ import ( // PacketInfo holds all the information about an outbound packet. type PacketInfo struct { - Pkt tcpip.PacketBuffer + Pkt stack.PacketBuffer Proto tcpip.NetworkProtocolNumber GSO *stack.GSO Route stack.Route @@ -203,12 +203,12 @@ func (e *Endpoint) NumQueued() int { } // InjectInbound injects an inbound packet. -func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { +func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { e.InjectLinkAddr(protocol, "", pkt) } // InjectLinkAddr injects an inbound packet with a remote link address. -func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt tcpip.PacketBuffer) { +func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt stack.PacketBuffer) { e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, pkt) } @@ -251,7 +251,7 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { } // WritePacket stores outbound packets into the channel. -func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { // Clone r then release its resource so we only get the relevant fields from // stack.Route without holding a reference to a NIC's endpoint. route := r.Clone() @@ -269,7 +269,7 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne } // WritePackets stores outbound packets into the channel. -func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { // Clone r then release its resource so we only get the relevant fields from // stack.Route without holding a reference to a NIC's endpoint. route := r.Clone() @@ -280,7 +280,7 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.Pac off := pkt.DataOffset size := pkt.DataSize p := PacketInfo{ - Pkt: tcpip.PacketBuffer{ + Pkt: stack.PacketBuffer{ Header: pkt.Header, Data: buffer.NewViewFromBytes(payloadView[off : off+size]).ToVectorisedView(), }, @@ -301,7 +301,7 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.Pac // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { p := PacketInfo{ - Pkt: tcpip.PacketBuffer{Data: vv}, + Pkt: stack.PacketBuffer{Data: vv}, Proto: 0, GSO: nil, } diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index 3b36b9673..235e647ff 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -386,7 +386,7 @@ const ( // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { if e.hdrSize > 0 { // Add ethernet header if needed. eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) @@ -440,7 +440,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne // WritePackets writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { var ethHdrBuf []byte // hdr + data iovLen := 2 @@ -610,7 +610,7 @@ func (e *InjectableEndpoint) Attach(dispatcher stack.NetworkDispatcher) { } // InjectInbound injects an inbound packet. -func (e *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { +func (e *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, pkt) } diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index 2066987eb..c7dbbbc6b 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -45,7 +45,7 @@ const ( type packetInfo struct { raddr tcpip.LinkAddress proto tcpip.NetworkProtocolNumber - contents tcpip.PacketBuffer + contents stack.PacketBuffer } type context struct { @@ -92,7 +92,7 @@ func (c *context) cleanup() { syscall.Close(c.fds[1]) } -func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { +func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { c.ch <- packetInfo{remote, protocol, pkt} } @@ -168,7 +168,7 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32) { L3HdrLen: header.IPv4MaximumHeaderSize, } } - if err := c.ep.WritePacket(r, gso, proto, tcpip.PacketBuffer{ + if err := c.ep.WritePacket(r, gso, proto, stack.PacketBuffer{ Header: hdr, Data: payload.ToVectorisedView(), }); err != nil { @@ -261,7 +261,7 @@ func TestPreserveSrcAddress(t *testing.T) { // WritePacket panics given a prependable with anything less than // the minimum size of the ethernet header. hdr := buffer.NewPrependable(header.EthernetMinimumSize) - if err := c.ep.WritePacket(r, nil /* gso */, proto, tcpip.PacketBuffer{ + if err := c.ep.WritePacket(r, nil /* gso */, proto, stack.PacketBuffer{ Header: hdr, Data: buffer.VectorisedView{}, }); err != nil { @@ -324,7 +324,7 @@ func TestDeliverPacket(t *testing.T) { want := packetInfo{ raddr: raddr, proto: proto, - contents: tcpip.PacketBuffer{ + contents: stack.PacketBuffer{ Data: buffer.View(b).ToVectorisedView(), LinkHeader: buffer.View(hdr), }, diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go index 62ed1e569..fe2bf3b0b 100644 --- a/pkg/tcpip/link/fdbased/mmap.go +++ b/pkg/tcpip/link/fdbased/mmap.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) const ( @@ -190,7 +191,7 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) { } pkt = pkt[d.e.hdrSize:] - d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, tcpip.PacketBuffer{ + d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, stack.PacketBuffer{ Data: buffer.View(pkt).ToVectorisedView(), LinkHeader: buffer.View(eth), }) diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go index c67d684ce..cb4cbea69 100644 --- a/pkg/tcpip/link/fdbased/packet_dispatchers.go +++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go @@ -139,7 +139,7 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { } used := d.capViews(n, BufConfig) - pkt := tcpip.PacketBuffer{ + pkt := stack.PacketBuffer{ Data: buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)), LinkHeader: buffer.View(eth), } @@ -296,7 +296,7 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { } used := d.capViews(k, int(n), BufConfig) - pkt := tcpip.PacketBuffer{ + pkt := stack.PacketBuffer{ Data: buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)), LinkHeader: buffer.View(eth), } diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 499cc608f..4039753b7 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -76,7 +76,7 @@ func (*endpoint) Wait() {} // WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound // packets to the network-layer dispatcher. -func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) views[0] = pkt.Header.View() views = append(views, pkt.Data.Views()...) @@ -84,7 +84,7 @@ func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Netw // Because we're immediately turning around and writing the packet back // to the rx path, we intentionally don't preserve the remote and local // link addresses from the stack.Route we're passed. - e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, tcpip.PacketBuffer{ + e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, stack.PacketBuffer{ Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), }) @@ -92,7 +92,7 @@ func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Netw } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []tcpip.PacketBuffer, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []stack.PacketBuffer, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { panic("not implemented") } @@ -106,7 +106,7 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { // There should be an ethernet header at the beginning of vv. linkHeader := header.Ethernet(vv.First()[:header.EthernetMinimumSize]) vv.TrimFront(len(linkHeader)) - e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), tcpip.PacketBuffer{ + e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), stack.PacketBuffer{ Data: vv, LinkHeader: buffer.View(linkHeader), }) diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index 445b22c17..f5973066d 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -80,14 +80,14 @@ func (m *InjectableEndpoint) IsAttached() bool { } // InjectInbound implements stack.InjectableLinkEndpoint. -func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { +func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { m.dispatcher.DeliverNetworkPacket(m, "" /* remote */, "" /* local */, protocol, pkt) } // WritePackets writes outbound packets to the appropriate // LinkInjectableEndpoint based on the RemoteAddress. HandleLocal only works if // r.RemoteAddress has a route registered in this endpoint. -func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { endpoint, ok := m.routes[r.RemoteAddress] if !ok { return 0, tcpip.ErrNoRoute @@ -98,7 +98,7 @@ func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts [ // WritePacket writes outbound packets to the appropriate LinkInjectableEndpoint // based on the RemoteAddress. HandleLocal only works if r.RemoteAddress has a // route registered in this endpoint. -func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { +func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { if endpoint, ok := m.routes[r.RemoteAddress]; ok { return endpoint.WritePacket(r, gso, protocol, pkt) } diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go index 63b249837..87c734c1f 100644 --- a/pkg/tcpip/link/muxed/injectable_test.go +++ b/pkg/tcpip/link/muxed/injectable_test.go @@ -50,7 +50,7 @@ func TestInjectableEndpointDispatch(t *testing.T) { hdr.Prepend(1)[0] = 0xFA packetRoute := stack.Route{RemoteAddress: dstIP} - endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, tcpip.PacketBuffer{ + endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, stack.PacketBuffer{ Header: hdr, Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(), }) @@ -70,7 +70,7 @@ func TestInjectableEndpointDispatchHdrOnly(t *testing.T) { hdr := buffer.NewPrependable(1) hdr.Prepend(1)[0] = 0xFA packetRoute := stack.Route{RemoteAddress: dstIP} - endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, tcpip.PacketBuffer{ + endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, stack.PacketBuffer{ Header: hdr, Data: buffer.NewView(0).ToVectorisedView(), }) diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 655e537c4..6461d0108 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -185,7 +185,7 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress { // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { // Add the ethernet header here. eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) pkt.LinkHeader = buffer.View(eth) @@ -214,7 +214,7 @@ func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.Netw } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { panic("not implemented") } @@ -275,7 +275,7 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) { // Send packet up the stack. eth := header.Ethernet(b[:header.EthernetMinimumSize]) - d.DeliverNetworkPacket(e, eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), tcpip.PacketBuffer{ + d.DeliverNetworkPacket(e, eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), stack.PacketBuffer{ Data: buffer.View(b[header.EthernetMinimumSize:]).ToVectorisedView(), LinkHeader: buffer.View(eth), }) diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 5c729a439..27ea3f531 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -131,7 +131,7 @@ func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress return c } -func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { +func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { c.mu.Lock() c.packets = append(c.packets, packetInfo{ addr: remoteLinkAddr, @@ -273,7 +273,7 @@ func TestSimpleSend(t *testing.T) { randomFill(buf) proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, nil /* gso */, proto, tcpip.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, proto, stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { @@ -345,7 +345,7 @@ func TestPreserveSrcAddressInSend(t *testing.T) { hdr := buffer.NewPrependable(header.EthernetMinimumSize) proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, nil /* gso */, proto, tcpip.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, proto, stack.PacketBuffer{ Header: hdr, }); err != nil { t.Fatalf("WritePacket failed: %v", err) @@ -401,7 +401,7 @@ func TestFillTxQueue(t *testing.T) { for i := queuePipeSize / 40; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { @@ -419,7 +419,7 @@ func TestFillTxQueue(t *testing.T) { // Next attempt to write must fail. hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != want { @@ -447,7 +447,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { // Send two packets so that the id slice has at least two slots. for i := 2; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { @@ -470,7 +470,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { ids := make(map[uint64]struct{}) for i := queuePipeSize / 40; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { @@ -488,7 +488,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { // Next attempt to write must fail. hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != want { @@ -514,7 +514,7 @@ func TestFillTxMemory(t *testing.T) { ids := make(map[uint64]struct{}) for i := queueDataSize / bufferSize; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { @@ -533,7 +533,7 @@ func TestFillTxMemory(t *testing.T) { // Next attempt to write must fail. hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }) @@ -561,7 +561,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { // until there is only one buffer left. for i := queueDataSize/bufferSize - 1; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { @@ -577,7 +577,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) uu := buffer.NewView(bufferSize).ToVectorisedView() - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ Header: hdr, Data: uu, }); err != want { @@ -588,7 +588,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { // Attempt to write the one-buffer packet again. It must succeed. { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, tcpip.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 3392b7edd..0a6b8945c 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -123,7 +123,7 @@ func NewWithFile(lower stack.LinkEndpoint, file *os.File, snapLen uint32) (stack // DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is // called by the link-layer endpoint being wrapped when a packet arrives, and // logs the packet before forwarding to the actual dispatcher. -func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { +func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil { logPacket("recv", protocol, pkt.Data.First(), nil) } @@ -200,7 +200,7 @@ func (e *endpoint) GSOMaxSize() uint32 { return 0 } -func (e *endpoint) dumpPacket(gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { +func (e *endpoint) dumpPacket(gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil { logPacket("send", protocol, pkt.Header.View(), gso) } @@ -232,7 +232,7 @@ func (e *endpoint) dumpPacket(gso *stack.GSO, protocol tcpip.NetworkProtocolNumb // WritePacket implements the stack.LinkEndpoint interface. It is called by // higher-level protocols to write packets; it just logs the packet and // forwards the request to the lower endpoint. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { e.dumpPacket(gso, protocol, pkt) return e.lower.WritePacket(r, gso, protocol, pkt) } @@ -240,10 +240,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne // WritePackets implements the stack.LinkEndpoint interface. It is called by // higher-level protocols to write packets; it just logs the packet and // forwards the request to the lower endpoint. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { view := pkts[0].Data.ToView() for _, pkt := range pkts { - e.dumpPacket(gso, protocol, tcpip.PacketBuffer{ + e.dumpPacket(gso, protocol, stack.PacketBuffer{ Header: pkt.Header, Data: view[pkt.DataOffset:][:pkt.DataSize].ToVectorisedView(), }) diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index f6e301304..617446ea2 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -213,7 +213,7 @@ func (d *Device) Write(data []byte) (int64, error) { remote = tcpip.LinkAddress(zeroMAC[:]) } - pkt := tcpip.PacketBuffer{ + pkt := stack.PacketBuffer{ Data: buffer.View(data).ToVectorisedView(), } if ethHdr != nil { diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index a8de38979..52fe397bf 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -50,7 +50,7 @@ func New(lower stack.LinkEndpoint) *Endpoint { // It is called by the link-layer endpoint being wrapped when a packet arrives, // and only forwards to the actual dispatcher if Wait or WaitDispatch haven't // been called. -func (e *Endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { +func (e *Endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { if !e.dispatchGate.Enter() { return } @@ -99,7 +99,7 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { // WritePacket implements stack.LinkEndpoint.WritePacket. It is called by // higher-level protocols to write packets. It only forwards packets to the // lower endpoint if Wait or WaitWrite haven't been called. -func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { if !e.writeGate.Enter() { return nil } @@ -112,7 +112,7 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne // WritePackets implements stack.LinkEndpoint.WritePackets. It is called by // higher-level protocols to write packets. It only forwards packets to the // lower endpoint if Wait or WaitWrite haven't been called. -func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { if !e.writeGate.Enter() { return len(pkts), nil } diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index 31b11a27a..88224e494 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -35,7 +35,7 @@ type countedEndpoint struct { dispatcher stack.NetworkDispatcher } -func (e *countedEndpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { +func (e *countedEndpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { e.dispatchCount++ } @@ -65,13 +65,13 @@ func (e *countedEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr } -func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { e.writeCount++ return nil } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { e.writeCount += len(pkts) return len(pkts), nil } @@ -89,21 +89,21 @@ func TestWaitWrite(t *testing.T) { wep := New(ep) // Write and check that it goes through. - wep.WritePacket(nil, nil /* gso */, 0, tcpip.PacketBuffer{}) + wep.WritePacket(nil, nil /* gso */, 0, stack.PacketBuffer{}) if want := 1; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } // Wait on dispatches, then try to write. It must go through. wep.WaitDispatch() - wep.WritePacket(nil, nil /* gso */, 0, tcpip.PacketBuffer{}) + wep.WritePacket(nil, nil /* gso */, 0, stack.PacketBuffer{}) if want := 2; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } // Wait on writes, then try to write. It must not go through. wep.WaitWrite() - wep.WritePacket(nil, nil /* gso */, 0, tcpip.PacketBuffer{}) + wep.WritePacket(nil, nil /* gso */, 0, stack.PacketBuffer{}) if want := 2; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } @@ -120,21 +120,21 @@ func TestWaitDispatch(t *testing.T) { } // Dispatch and check that it goes through. - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, tcpip.PacketBuffer{}) + ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, stack.PacketBuffer{}) if want := 1; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } // Wait on writes, then try to dispatch. It must go through. wep.WaitWrite() - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, tcpip.PacketBuffer{}) + ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, stack.PacketBuffer{}) if want := 2; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } // Wait on dispatches, then try to dispatch. It must not go through. wep.WaitDispatch() - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, tcpip.PacketBuffer{}) + ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, stack.PacketBuffer{}) if want := 2; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index e9fcc89a8..255098372 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -79,20 +79,20 @@ func (e *endpoint) MaxHeaderLength() uint16 { func (e *endpoint) Close() {} -func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, tcpip.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, stack.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } // WritePackets implements stack.NetworkEndpoint.WritePackets. -func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []tcpip.PacketBuffer, stack.NetworkHeaderParams) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []stack.PacketBuffer, stack.NetworkHeaderParams) (int, *tcpip.Error) { return 0, tcpip.ErrNotSupported } -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } -func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { +func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { v := pkt.Data.First() h := header.ARP(v) if !h.IsValid() { @@ -113,7 +113,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget()) copy(packet.HardwareAddressTarget(), h.HardwareAddressSender()) copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender()) - e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, tcpip.PacketBuffer{ + e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{ Header: hdr, }) fallthrough // also fill the cache from requests @@ -167,7 +167,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. copy(h.ProtocolAddressSender(), localAddr) copy(h.ProtocolAddressTarget(), addr) - return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, tcpip.PacketBuffer{ + return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{ Header: hdr, }) } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 03cf03b6d..b3e239ac7 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -103,7 +103,7 @@ func TestDirectRequest(t *testing.T) { inject := func(addr tcpip.Address) { copy(h.ProtocolAddressTarget(), addr) - c.linkEP.InjectInbound(arp.ProtocolNumber, tcpip.PacketBuffer{ + c.linkEP.InjectInbound(arp.ProtocolNumber, stack.PacketBuffer{ Data: v.ToVectorisedView(), }) } diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index f4d78f8c6..4950d69fc 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -96,7 +96,7 @@ func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buff // DeliverTransportPacket is called by network endpoints after parsing incoming // packets. This is used by the test object to verify that the results of the // parsing are expected. -func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) { +func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt stack.PacketBuffer) { t.checkValues(protocol, pkt.Data, r.RemoteAddress, r.LocalAddress) t.dataCalls++ } @@ -104,7 +104,7 @@ func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.Trans // DeliverTransportControlPacket is called by network endpoints after parsing // incoming control (ICMP) packets. This is used by the test object to verify // that the results of the parsing are expected. -func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { +func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { t.checkValues(trans, pkt.Data, remote, local) if typ != t.typ { t.t.Errorf("typ = %v, want %v", typ, t.typ) @@ -150,7 +150,7 @@ func (*testObject) Wait() {} // WritePacket is called by network endpoints after producing a packet and // writing it to the link endpoint. This is used by the test object to verify // that the produced packet is as expected. -func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { +func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { var prot tcpip.TransportProtocolNumber var srcAddr tcpip.Address var dstAddr tcpip.Address @@ -172,7 +172,7 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Ne } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { panic("not implemented") } @@ -246,7 +246,7 @@ func TestIPv4Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{ Header: hdr, Data: payload.ToVectorisedView(), }); err != nil { @@ -289,7 +289,7 @@ func TestIPv4Receive(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - ep.HandlePacket(&r, tcpip.PacketBuffer{ + ep.HandlePacket(&r, stack.PacketBuffer{ Data: view.ToVectorisedView(), }) if o.dataCalls != 1 { @@ -379,7 +379,7 @@ func TestIPv4ReceiveControl(t *testing.T) { o.extra = c.expectedExtra vv := view[:len(view)-c.trunc].ToVectorisedView() - ep.HandlePacket(&r, tcpip.PacketBuffer{ + ep.HandlePacket(&r, stack.PacketBuffer{ Data: vv, }) if want := c.expectedCount; o.controlCalls != want { @@ -444,7 +444,7 @@ func TestIPv4FragmentationReceive(t *testing.T) { } // Send first segment. - ep.HandlePacket(&r, tcpip.PacketBuffer{ + ep.HandlePacket(&r, stack.PacketBuffer{ Data: frag1.ToVectorisedView(), }) if o.dataCalls != 0 { @@ -452,7 +452,7 @@ func TestIPv4FragmentationReceive(t *testing.T) { } // Send second segment. - ep.HandlePacket(&r, tcpip.PacketBuffer{ + ep.HandlePacket(&r, stack.PacketBuffer{ Data: frag2.ToVectorisedView(), }) if o.dataCalls != 1 { @@ -487,7 +487,7 @@ func TestIPv6Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{ Header: hdr, Data: payload.ToVectorisedView(), }); err != nil { @@ -530,7 +530,7 @@ func TestIPv6Receive(t *testing.T) { t.Fatalf("could not find route: %v", err) } - ep.HandlePacket(&r, tcpip.PacketBuffer{ + ep.HandlePacket(&r, stack.PacketBuffer{ Data: view.ToVectorisedView(), }) if o.dataCalls != 1 { @@ -644,7 +644,7 @@ func TestIPv6ReceiveControl(t *testing.T) { // Set ICMPv6 checksum. icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{})) - ep.HandlePacket(&r, tcpip.PacketBuffer{ + ep.HandlePacket(&r, stack.PacketBuffer{ Data: view[:len(view)-c.trunc].ToVectorisedView(), }) if want := c.expectedCount; o.controlCalls != want { diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 0fef2b1f1..880ea7de2 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -13,7 +13,6 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", - "//pkg/tcpip/iptables", "//pkg/tcpip/network/fragmentation", "//pkg/tcpip/network/hash", "//pkg/tcpip/stack", diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 32bf39e43..c4bf1ba5c 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -15,7 +15,6 @@ package ipv4 import ( - "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -25,7 +24,7 @@ import ( // the original packet that caused the ICMP one to be sent. This information is // used to find out which transport endpoint must be notified about the ICMP // packet. -func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { +func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { h := header.IPv4(pkt.Data.First()) // We don't use IsValid() here because ICMP only requires that the IP @@ -53,7 +52,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt tcpip. e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } -func (e *endpoint) handleICMP(r *stack.Route, pkt tcpip.PacketBuffer) { +func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) { stats := r.Stats() received := stats.ICMP.V4PacketsReceived v := pkt.Data.First() @@ -85,7 +84,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt tcpip.PacketBuffer) { // It's possible that a raw socket expects to receive this. h.SetChecksum(wantChecksum) - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, tcpip.PacketBuffer{ + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, stack.PacketBuffer{ Data: pkt.Data.Clone(nil), NetworkHeader: append(buffer.View(nil), pkt.NetworkHeader...), }) @@ -99,7 +98,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt tcpip.PacketBuffer) { pkt.SetChecksum(0) pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0))) sent := stats.ICMP.V4PacketsSent - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ Header: hdr, Data: vv, TransportHeader: buffer.View(pkt), diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 4f1742938..b3ee6000e 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -26,7 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation" "gvisor.dev/gvisor/pkg/tcpip/network/hash" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -125,7 +124,7 @@ func (e *endpoint) GSOMaxSize() uint32 { // packet's stated length matches the length of the header+payload. mtu // includes the IP header and options. This does not support the DontFragment // IP flag. -func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt stack.PacketBuffer) *tcpip.Error { // This packet is too big, it needs to be fragmented. ip := header.IPv4(pkt.Header.View()) flags := ip.Flags() @@ -165,7 +164,7 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, if i > 0 { newPayload := pkt.Data.Clone(nil) newPayload.CapLength(innerMTU) - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, tcpip.PacketBuffer{ + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{ Header: pkt.Header, Data: newPayload, NetworkHeader: buffer.View(h), @@ -184,7 +183,7 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, newPayload := pkt.Data.Clone(nil) newPayloadLength := outerMTU - pkt.Header.UsedLength() newPayload.CapLength(newPayloadLength) - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, tcpip.PacketBuffer{ + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{ Header: pkt.Header, Data: newPayload, NetworkHeader: buffer.View(h), @@ -198,7 +197,7 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, startOfHdr := pkt.Header startOfHdr.TrimBack(pkt.Header.UsedLength() - outerMTU) emptyVV := buffer.NewVectorisedView(0, []buffer.View{}) - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, tcpip.PacketBuffer{ + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{ Header: startOfHdr, Data: emptyVV, NetworkHeader: buffer.View(h), @@ -241,7 +240,7 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error { ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) pkt.NetworkHeader = buffer.View(ip) @@ -253,7 +252,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw views = append(views, pkt.Data.Views()...) loopedR := r.MakeLoopedRoute() - e.HandlePacket(&loopedR, tcpip.PacketBuffer{ + e.HandlePacket(&loopedR, stack.PacketBuffer{ Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), }) @@ -273,7 +272,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw } // WritePackets implements stack.NetworkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) { if r.Loop&stack.PacketLoop != 0 { panic("multiple packets in local loop") } @@ -292,7 +291,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.Pac // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required // checks. ip := header.IPv4(pkt.Data.First()) @@ -344,7 +343,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuf // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { +func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { headerView := pkt.Data.First() h := header.IPv4(headerView) if !h.IsValid(pkt.Data.Size()) { @@ -361,7 +360,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. ipt := e.stack.IPTables() - if ok := ipt.Check(iptables.Input, pkt); !ok { + if ok := ipt.Check(stack.Input, pkt); !ok { // iptables is telling us to drop the packet. return } diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index e900f1b45..5a864d832 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -113,7 +113,7 @@ func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer. // comparePayloads compared the contents of all the packets against the contents // of the source packet. -func compareFragments(t *testing.T, packets []tcpip.PacketBuffer, sourcePacketInfo tcpip.PacketBuffer, mtu uint32) { +func compareFragments(t *testing.T, packets []stack.PacketBuffer, sourcePacketInfo stack.PacketBuffer, mtu uint32) { t.Helper() // Make a complete array of the sourcePacketInfo packet. source := header.IPv4(packets[0].Header.View()[:header.IPv4MinimumSize]) @@ -173,7 +173,7 @@ func compareFragments(t *testing.T, packets []tcpip.PacketBuffer, sourcePacketIn type errorChannel struct { *channel.Endpoint - Ch chan tcpip.PacketBuffer + Ch chan stack.PacketBuffer packetCollectorErrors []*tcpip.Error } @@ -183,7 +183,7 @@ type errorChannel struct { func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel { return &errorChannel{ Endpoint: channel.New(size, mtu, linkAddr), - Ch: make(chan tcpip.PacketBuffer, size), + Ch: make(chan stack.PacketBuffer, size), packetCollectorErrors: packetCollectorErrors, } } @@ -202,7 +202,7 @@ func (e *errorChannel) Drain() int { } // WritePacket stores outbound packets into the channel. -func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { select { case e.Ch <- pkt: default: @@ -281,13 +281,13 @@ func TestFragmentation(t *testing.T) { for _, ft := range fragTests { t.Run(ft.description, func(t *testing.T) { hdr, payload := makeHdrAndPayload(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes) - source := tcpip.PacketBuffer{ + source := stack.PacketBuffer{ Header: hdr, // Save the source payload because WritePacket will modify it. Data: payload.Clone(nil), } c := buildContext(t, nil, ft.mtu) - err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, stack.PacketBuffer{ Header: hdr, Data: payload, }) @@ -295,7 +295,7 @@ func TestFragmentation(t *testing.T) { t.Errorf("err got %v, want %v", err, nil) } - var results []tcpip.PacketBuffer + var results []stack.PacketBuffer L: for { select { @@ -337,7 +337,7 @@ func TestFragmentationErrors(t *testing.T) { t.Run(ft.description, func(t *testing.T) { hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes) c := buildContext(t, ft.packetCollectorErrors, ft.mtu) - err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, stack.PacketBuffer{ Header: hdr, Data: payload, }) @@ -459,7 +459,7 @@ func TestInvalidFragments(t *testing.T) { s.CreateNIC(nicID, sniffer.New(ep)) for _, pkt := range tc.packets { - ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, tcpip.PacketBuffer{ + ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, stack.PacketBuffer{ Data: buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}), }) } diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 45dc757c7..8640feffc 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -27,7 +27,7 @@ import ( // the original packet that caused the ICMP one to be sent. This information is // used to find out which transport endpoint must be notified about the ICMP // packet. -func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { +func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { h := header.IPv6(pkt.Data.First()) // We don't use IsValid() here because ICMP only requires that up to @@ -62,7 +62,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt tcpip. e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } -func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.PacketBuffer) { +func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.PacketBuffer) { stats := r.Stats().ICMP sent := stats.V6PacketsSent received := stats.V6PacketsReceived @@ -243,7 +243,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P // // The IP Hop Limit field has a value of 255, i.e., the packet // could not possibly have been forwarded by a router. - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, stack.PacketBuffer{ Header: hdr, }); err != nil { sent.Dropped.Increment() @@ -330,7 +330,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P copy(packet, h) packet.SetType(header.ICMPv6EchoReply) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ Header: hdr, Data: pkt.Data, }); err != nil { @@ -463,7 +463,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. }) // TODO(stijlist): count this in ICMP stats. - return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, tcpip.PacketBuffer{ + return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{ Header: hdr, }) } diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 50c4b6474..bae09ed94 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -56,7 +56,7 @@ func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" } -func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, tcpip.PacketBuffer) *tcpip.Error { +func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, stack.PacketBuffer) *tcpip.Error { return nil } @@ -66,7 +66,7 @@ type stubDispatcher struct { stack.TransportDispatcher } -func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, tcpip.PacketBuffer) { +func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, stack.PacketBuffer) { } type stubLinkAddressCache struct { @@ -187,7 +187,7 @@ func TestICMPCounts(t *testing.T) { SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - ep.HandlePacket(&r, tcpip.PacketBuffer{ + ep.HandlePacket(&r, stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) } @@ -326,7 +326,7 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header. views := []buffer.View{pi.Pkt.Header.View(), pi.Pkt.Data.ToView()} size := pi.Pkt.Header.UsedLength() + pi.Pkt.Data.Size() vv := buffer.NewVectorisedView(size, views) - args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), tcpip.PacketBuffer{ + args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), stack.PacketBuffer{ Data: vv, }) } @@ -561,7 +561,7 @@ func TestICMPChecksumValidationSimple(t *testing.T) { SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) } @@ -738,7 +738,7 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) } @@ -916,7 +916,7 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}), }) } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 9aef5234b..29e597002 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -112,7 +112,7 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error { ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) pkt.NetworkHeader = buffer.View(ip) @@ -124,7 +124,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw views = append(views, pkt.Data.Views()...) loopedR := r.MakeLoopedRoute() - e.HandlePacket(&loopedR, tcpip.PacketBuffer{ + e.HandlePacket(&loopedR, stack.PacketBuffer{ Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), }) @@ -139,7 +139,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) { if r.Loop&stack.PacketLoop != 0 { panic("not implemented") } @@ -161,14 +161,14 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.Pac // WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet // supported by IPv6. -func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error { +func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { // TODO(b/146666412): Support IPv6 header-included packets. return tcpip.ErrNotSupported } // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { +func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { headerView := pkt.Data.First() h := header.IPv6(headerView) if !h.IsValid(pkt.Data.Size()) { diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 1cbfa7278..ed98ef22a 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -55,7 +55,7 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst DstAddr: dst, }) - e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -113,7 +113,7 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst DstAddr: dst, }) - e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index c9395de52..f924ed9e1 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -135,7 +135,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -238,7 +238,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -304,7 +304,7 @@ func TestHopLimitValidation(t *testing.T) { SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - ep.HandlePacket(r, tcpip.PacketBuffer{ + ep.HandlePacket(r, stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) } @@ -588,7 +588,7 @@ func TestRouterAdvertValidation(t *testing.T) { t.Fatalf("got rxRA = %d, want = 0", got) } - e.InjectInbound(header.IPv6ProtocolNumber, tcpip.PacketBuffer{ + e.InjectInbound(header.IPv6ProtocolNumber, stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) diff --git a/pkg/tcpip/packet_buffer.go b/pkg/tcpip/packet_buffer.go deleted file mode 100644 index ab24372e7..000000000 --- a/pkg/tcpip/packet_buffer.go +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcpip - -import "gvisor.dev/gvisor/pkg/tcpip/buffer" - -// A PacketBuffer contains all the data of a network packet. -// -// As a PacketBuffer traverses up the stack, it may be necessary to pass it to -// multiple endpoints. Clone() should be called in such cases so that -// modifications to the Data field do not affect other copies. -// -// +stateify savable -type PacketBuffer struct { - // Data holds the payload of the packet. For inbound packets, it also - // holds the headers, which are consumed as the packet moves up the - // stack. Headers are guaranteed not to be split across views. - // - // The bytes backing Data are immutable, but Data itself may be trimmed - // or otherwise modified. - Data buffer.VectorisedView - - // DataOffset is used for GSO output. It is the offset into the Data - // field where the payload of this packet starts. - DataOffset int - - // DataSize is used for GSO output. It is the size of this packet's - // payload. - DataSize int - - // Header holds the headers of outbound packets. As a packet is passed - // down the stack, each layer adds to Header. - Header buffer.Prependable - - // These fields are used by both inbound and outbound packets. They - // typically overlap with the Data and Header fields. - // - // The bytes backing these views are immutable. Each field may be nil - // if either it has not been set yet or no such header exists (e.g. - // packets sent via loopback may not have a link header). - // - // These fields may be Views into other slices (either Data or Header). - // SR dosen't support this, so deep copies are necessary in some cases. - LinkHeader buffer.View - NetworkHeader buffer.View - TransportHeader buffer.View -} - -// Clone makes a copy of pk. It clones the Data field, which creates a new -// VectorisedView but does not deep copy the underlying bytes. -// -// Clone also does not deep copy any of its other fields. -func (pk PacketBuffer) Clone() PacketBuffer { - pk.Data = pk.Data.Clone(nil) - return pk -} diff --git a/pkg/tcpip/packet_buffer_state.go b/pkg/tcpip/packet_buffer_state.go deleted file mode 100644 index ad3cc24fa..000000000 --- a/pkg/tcpip/packet_buffer_state.go +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcpip - -import "gvisor.dev/gvisor/pkg/tcpip/buffer" - -// beforeSave is invoked by stateify. -func (pk *PacketBuffer) beforeSave() { - // Non-Data fields may be slices of the Data field. This causes - // problems for SR, so during save we make each header independent. - pk.Header = pk.Header.DeepCopy() - pk.LinkHeader = append(buffer.View(nil), pk.LinkHeader...) - pk.NetworkHeader = append(buffer.View(nil), pk.NetworkHeader...) - pk.TransportHeader = append(buffer.View(nil), pk.TransportHeader...) -} diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 6c029b2fb..7a43a1d4e 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -21,10 +21,15 @@ go_library( "dhcpv6configurationfromndpra_string.go", "forwarder.go", "icmp_rate_limit.go", + "iptables.go", + "iptables_targets.go", + "iptables_types.go", "linkaddrcache.go", "linkaddrentry_list.go", "ndp.go", "nic.go", + "packet_buffer.go", + "packet_buffer_state.go", "registration.go", "route.go", "stack.go", @@ -34,6 +39,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/ilist", + "//pkg/log", "//pkg/rand", "//pkg/sleep", "//pkg/sync", @@ -41,7 +47,6 @@ go_library( "//pkg/tcpip/buffer", "//pkg/tcpip/hash/jenkins", "//pkg/tcpip/header", - "//pkg/tcpip/iptables", "//pkg/tcpip/ports", "//pkg/tcpip/seqnum", "//pkg/waiter", @@ -65,7 +70,6 @@ go_test( "//pkg/tcpip/buffer", "//pkg/tcpip/checker", "//pkg/tcpip/header", - "//pkg/tcpip/iptables", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", "//pkg/tcpip/network/ipv4", diff --git a/pkg/tcpip/stack/forwarder.go b/pkg/tcpip/stack/forwarder.go index 631953935..6b64cd37f 100644 --- a/pkg/tcpip/stack/forwarder.go +++ b/pkg/tcpip/stack/forwarder.go @@ -32,7 +32,7 @@ type pendingPacket struct { nic *NIC route *Route proto tcpip.NetworkProtocolNumber - pkt tcpip.PacketBuffer + pkt PacketBuffer } type forwardQueue struct { @@ -50,7 +50,7 @@ func newForwardQueue() *forwardQueue { return &forwardQueue{packets: make(map[<-chan struct{}][]*pendingPacket)} } -func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { +func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { shouldWait := false f.Lock() diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index 321b7524d..c45c43d21 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -68,7 +68,7 @@ func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID { return &f.id } -func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt tcpip.PacketBuffer) { +func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt PacketBuffer) { // Consume the network header. b := pkt.Data.First() pkt.Data.TrimFront(fwdTestNetHeaderLen) @@ -89,7 +89,7 @@ func (f *fwdTestNetworkEndpoint) Capabilities() LinkEndpointCapabilities { return f.ep.Capabilities() } -func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error { +func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error { // Add the protocol's header to the packet and send it to the link // endpoint. b := pkt.Header.Prepend(fwdTestNetHeaderLen) @@ -101,11 +101,11 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkH } // WritePackets implements LinkEndpoint.WritePackets. -func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) { +func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts []PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) { panic("not implemented") } -func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt tcpip.PacketBuffer) *tcpip.Error { +func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } @@ -183,7 +183,7 @@ func (f *fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumb type fwdTestPacketInfo struct { RemoteLinkAddress tcpip.LinkAddress LocalLinkAddress tcpip.LinkAddress - Pkt tcpip.PacketBuffer + Pkt PacketBuffer } type fwdTestLinkEndpoint struct { @@ -196,12 +196,12 @@ type fwdTestLinkEndpoint struct { } // InjectInbound injects an inbound packet. -func (e *fwdTestLinkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { +func (e *fwdTestLinkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { e.InjectLinkAddr(protocol, "", pkt) } // InjectLinkAddr injects an inbound packet with a remote link address. -func (e *fwdTestLinkEndpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt tcpip.PacketBuffer) { +func (e *fwdTestLinkEndpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt PacketBuffer) { e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, pkt) } @@ -244,7 +244,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr } -func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) *tcpip.Error { p := fwdTestPacketInfo{ RemoteLinkAddress: r.RemoteLinkAddress, LocalLinkAddress: r.LocalLinkAddress, @@ -260,7 +260,7 @@ func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.Netw } // WritePackets stores outbound packets into the channel. -func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts []PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { n := 0 for _, pkt := range pkts { e.WritePacket(r, gso, protocol, pkt) @@ -273,7 +273,7 @@ func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts []tcpip.Pack // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { p := fwdTestPacketInfo{ - Pkt: tcpip.PacketBuffer{Data: vv}, + Pkt: PacketBuffer{Data: vv}, } select { @@ -355,7 +355,7 @@ func TestForwardingWithStaticResolver(t *testing.T) { // forwarded to NIC 2. buf := buffer.NewView(30) buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -392,7 +392,7 @@ func TestForwardingWithFakeResolver(t *testing.T) { // forwarded to NIC 2. buf := buffer.NewView(30) buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -423,7 +423,7 @@ func TestForwardingWithNoResolver(t *testing.T) { // forwarded to NIC 2. buf := buffer.NewView(30) buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -453,7 +453,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { // not be forwarded. buf := buffer.NewView(30) buf[0] = 4 - ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -461,7 +461,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { // forwarded to NIC 2. buf = buffer.NewView(30) buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -503,7 +503,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { for i := 0; i < 2; i++ { buf := buffer.NewView(30) buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ Data: buf.ToVectorisedView(), }) } @@ -550,7 +550,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { buf[0] = 3 // Set the packet sequence number. binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i)) - ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ Data: buf.ToVectorisedView(), }) } @@ -603,7 +603,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // maxPendingResolutions + 7). buf := buffer.NewView(30) buf[0] = byte(3 + i) - ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ Data: buf.ToVectorisedView(), }) } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go new file mode 100644 index 000000000..37907ae24 --- /dev/null +++ b/pkg/tcpip/stack/iptables.go @@ -0,0 +1,311 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +// Table names. +const ( + TablenameNat = "nat" + TablenameMangle = "mangle" + TablenameFilter = "filter" +) + +// Chain names as defined by net/ipv4/netfilter/ip_tables.c. +const ( + ChainNamePrerouting = "PREROUTING" + ChainNameInput = "INPUT" + ChainNameForward = "FORWARD" + ChainNameOutput = "OUTPUT" + ChainNamePostrouting = "POSTROUTING" +) + +// HookUnset indicates that there is no hook set for an entrypoint or +// underflow. +const HookUnset = -1 + +// DefaultTables returns a default set of tables. Each chain is set to accept +// all packets. +func DefaultTables() IPTables { + // TODO(gvisor.dev/issue/170): We may be able to swap out some strings for + // iotas. + return IPTables{ + Tables: map[string]Table{ + TablenameNat: Table{ + Rules: []Rule{ + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: ErrorTarget{}}, + }, + BuiltinChains: map[Hook]int{ + Prerouting: 0, + Input: 1, + Output: 2, + Postrouting: 3, + }, + Underflows: map[Hook]int{ + Prerouting: 0, + Input: 1, + Output: 2, + Postrouting: 3, + }, + UserChains: map[string]int{}, + }, + TablenameMangle: Table{ + Rules: []Rule{ + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: ErrorTarget{}}, + }, + BuiltinChains: map[Hook]int{ + Prerouting: 0, + Output: 1, + }, + Underflows: map[Hook]int{ + Prerouting: 0, + Output: 1, + }, + UserChains: map[string]int{}, + }, + TablenameFilter: Table{ + Rules: []Rule{ + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: ErrorTarget{}}, + }, + BuiltinChains: map[Hook]int{ + Input: 0, + Forward: 1, + Output: 2, + }, + Underflows: map[Hook]int{ + Input: 0, + Forward: 1, + Output: 2, + }, + UserChains: map[string]int{}, + }, + }, + Priorities: map[Hook][]string{ + Input: []string{TablenameNat, TablenameFilter}, + Prerouting: []string{TablenameMangle, TablenameNat}, + Output: []string{TablenameMangle, TablenameNat, TablenameFilter}, + }, + } +} + +// EmptyFilterTable returns a Table with no rules and the filter table chains +// mapped to HookUnset. +func EmptyFilterTable() Table { + return Table{ + Rules: []Rule{}, + BuiltinChains: map[Hook]int{ + Input: HookUnset, + Forward: HookUnset, + Output: HookUnset, + }, + Underflows: map[Hook]int{ + Input: HookUnset, + Forward: HookUnset, + Output: HookUnset, + }, + UserChains: map[string]int{}, + } +} + +// EmptyNatTable returns a Table with no rules and the filter table chains +// mapped to HookUnset. +func EmptyNatTable() Table { + return Table{ + Rules: []Rule{}, + BuiltinChains: map[Hook]int{ + Prerouting: HookUnset, + Input: HookUnset, + Output: HookUnset, + Postrouting: HookUnset, + }, + Underflows: map[Hook]int{ + Prerouting: HookUnset, + Input: HookUnset, + Output: HookUnset, + Postrouting: HookUnset, + }, + UserChains: map[string]int{}, + } +} + +// A chainVerdict is what a table decides should be done with a packet. +type chainVerdict int + +const ( + // chainAccept indicates the packet should continue through netstack. + chainAccept chainVerdict = iota + + // chainAccept indicates the packet should be dropped. + chainDrop + + // chainReturn indicates the packet should return to the calling chain + // or the underflow rule of a builtin chain. + chainReturn +) + +// Check runs pkt through the rules for hook. It returns true when the packet +// should continue traversing the network stack and false when it should be +// dropped. +// +// Precondition: pkt.NetworkHeader is set. +func (it *IPTables) Check(hook Hook, pkt PacketBuffer) bool { + // Go through each table containing the hook. + for _, tablename := range it.Priorities[hook] { + table := it.Tables[tablename] + ruleIdx := table.BuiltinChains[hook] + switch verdict := it.checkChain(hook, pkt, table, ruleIdx); verdict { + // If the table returns Accept, move on to the next table. + case chainAccept: + continue + // The Drop verdict is final. + case chainDrop: + return false + case chainReturn: + // Any Return from a built-in chain means we have to + // call the underflow. + underflow := table.Rules[table.Underflows[hook]] + switch v, _ := underflow.Target.Action(pkt); v { + case RuleAccept: + continue + case RuleDrop: + return false + case RuleJump, RuleReturn: + panic("Underflows should only return RuleAccept or RuleDrop.") + default: + panic(fmt.Sprintf("Unknown verdict: %d", v)) + } + + default: + panic(fmt.Sprintf("Unknown verdict %v.", verdict)) + } + } + + // Every table returned Accept. + return true +} + +// Precondition: pkt.NetworkHeader is set. +func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) chainVerdict { + // Start from ruleIdx and walk the list of rules until a rule gives us + // a verdict. + for ruleIdx < len(table.Rules) { + switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx); verdict { + case RuleAccept: + return chainAccept + + case RuleDrop: + return chainDrop + + case RuleReturn: + return chainReturn + + case RuleJump: + // "Jumping" to the next rule just means we're + // continuing on down the list. + if jumpTo == ruleIdx+1 { + ruleIdx++ + continue + } + switch verdict := it.checkChain(hook, pkt, table, jumpTo); verdict { + case chainAccept: + return chainAccept + case chainDrop: + return chainDrop + case chainReturn: + ruleIdx++ + continue + default: + panic(fmt.Sprintf("Unknown verdict: %d", verdict)) + } + + default: + panic(fmt.Sprintf("Unknown verdict: %d", verdict)) + } + + } + + // We got through the entire table without a decision. Default to DROP + // for safety. + return chainDrop +} + +// Precondition: pk.NetworkHeader is set. +func (it *IPTables) checkRule(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) (RuleVerdict, int) { + rule := table.Rules[ruleIdx] + + // If pkt.NetworkHeader hasn't been set yet, it will be contained in + // pkt.Data.First(). + if pkt.NetworkHeader == nil { + pkt.NetworkHeader = pkt.Data.First() + } + + // Check whether the packet matches the IP header filter. + if !filterMatch(rule.Filter, header.IPv4(pkt.NetworkHeader)) { + // Continue on to the next rule. + return RuleJump, ruleIdx + 1 + } + + // Go through each rule matcher. If they all match, run + // the rule target. + for _, matcher := range rule.Matchers { + matches, hotdrop := matcher.Match(hook, pkt, "") + if hotdrop { + return RuleDrop, 0 + } + if !matches { + // Continue on to the next rule. + return RuleJump, ruleIdx + 1 + } + } + + // All the matchers matched, so run the target. + return rule.Target.Action(pkt) +} + +func filterMatch(filter IPHeaderFilter, hdr header.IPv4) bool { + // TODO(gvisor.dev/issue/170): Support other fields of the filter. + // Check the transport protocol. + if filter.Protocol != 0 && filter.Protocol != hdr.TransportProtocol() { + return false + } + + // Check the destination IP. + dest := hdr.DestinationAddress() + matches := true + for i := range filter.Dst { + if dest[i]&filter.DstMask[i] != filter.Dst[i] { + matches = false + break + } + } + if matches == filter.DstInvert { + return false + } + + return true +} diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go new file mode 100644 index 000000000..7b4543caf --- /dev/null +++ b/pkg/tcpip/stack/iptables_targets.go @@ -0,0 +1,144 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +// AcceptTarget accepts packets. +type AcceptTarget struct{} + +// Action implements Target.Action. +func (AcceptTarget) Action(packet PacketBuffer) (RuleVerdict, int) { + return RuleAccept, 0 +} + +// DropTarget drops packets. +type DropTarget struct{} + +// Action implements Target.Action. +func (DropTarget) Action(packet PacketBuffer) (RuleVerdict, int) { + return RuleDrop, 0 +} + +// ErrorTarget logs an error and drops the packet. It represents a target that +// should be unreachable. +type ErrorTarget struct{} + +// Action implements Target.Action. +func (ErrorTarget) Action(packet PacketBuffer) (RuleVerdict, int) { + log.Debugf("ErrorTarget triggered.") + return RuleDrop, 0 +} + +// UserChainTarget marks a rule as the beginning of a user chain. +type UserChainTarget struct { + Name string +} + +// Action implements Target.Action. +func (UserChainTarget) Action(PacketBuffer) (RuleVerdict, int) { + panic("UserChainTarget should never be called.") +} + +// ReturnTarget returns from the current chain. If the chain is a built-in, the +// hook's underflow should be called. +type ReturnTarget struct{} + +// Action implements Target.Action. +func (ReturnTarget) Action(PacketBuffer) (RuleVerdict, int) { + return RuleReturn, 0 +} + +// RedirectTarget redirects the packet by modifying the destination port/IP. +// Min and Max values for IP and Ports in the struct indicate the range of +// values which can be used to redirect. +type RedirectTarget struct { + // TODO(gvisor.dev/issue/170): Other flags need to be added after + // we support them. + // RangeProtoSpecified flag indicates single port is specified to + // redirect. + RangeProtoSpecified bool + + // Min address used to redirect. + MinIP tcpip.Address + + // Max address used to redirect. + MaxIP tcpip.Address + + // Min port used to redirect. + MinPort uint16 + + // Max port used to redirect. + MaxPort uint16 +} + +// Action implements Target.Action. +// TODO(gvisor.dev/issue/170): Parse headers without copying. The current +// implementation only works for PREROUTING and calls pkt.Clone(), neither +// of which should be the case. +func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { + newPkt := pkt.Clone() + + // Set network header. + headerView := newPkt.Data.First() + netHeader := header.IPv4(headerView) + newPkt.NetworkHeader = headerView[:header.IPv4MinimumSize] + + hlen := int(netHeader.HeaderLength()) + tlen := int(netHeader.TotalLength()) + newPkt.Data.TrimFront(hlen) + newPkt.Data.CapLength(tlen - hlen) + + // TODO(gvisor.dev/issue/170): Change destination address to + // loopback or interface address on which the packet was + // received. + + // TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if + // we need to change dest address (for OUTPUT chain) or ports. + switch protocol := netHeader.TransportProtocol(); protocol { + case header.UDPProtocolNumber: + var udpHeader header.UDP + if newPkt.TransportHeader != nil { + udpHeader = header.UDP(newPkt.TransportHeader) + } else { + if len(pkt.Data.First()) < header.UDPMinimumSize { + return RuleDrop, 0 + } + udpHeader = header.UDP(newPkt.Data.First()) + } + udpHeader.SetDestinationPort(rt.MinPort) + case header.TCPProtocolNumber: + var tcpHeader header.TCP + if newPkt.TransportHeader != nil { + tcpHeader = header.TCP(newPkt.TransportHeader) + } else { + if len(pkt.Data.First()) < header.TCPMinimumSize { + return RuleDrop, 0 + } + tcpHeader = header.TCP(newPkt.TransportHeader) + } + // TODO(gvisor.dev/issue/170): Need to recompute checksum + // and implement nat connection tracking to support TCP. + tcpHeader.SetDestinationPort(rt.MinPort) + default: + return RuleDrop, 0 + } + + return RuleAccept, 0 +} diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go new file mode 100644 index 000000000..2ffb55f2a --- /dev/null +++ b/pkg/tcpip/stack/iptables_types.go @@ -0,0 +1,180 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "gvisor.dev/gvisor/pkg/tcpip" +) + +// A Hook specifies one of the hooks built into the network stack. +// +// Userspace app Userspace app +// ^ | +// | v +// [Input] [Output] +// ^ | +// | v +// | routing +// | | +// | v +// ----->[Prerouting]----->routing----->[Forward]---------[Postrouting]-----> +type Hook uint + +// These values correspond to values in include/uapi/linux/netfilter.h. +const ( + // Prerouting happens before a packet is routed to applications or to + // be forwarded. + Prerouting Hook = iota + + // Input happens before a packet reaches an application. + Input + + // Forward happens once it's decided that a packet should be forwarded + // to another host. + Forward + + // Output happens after a packet is written by an application to be + // sent out. + Output + + // Postrouting happens just before a packet goes out on the wire. + Postrouting + + // The total number of hooks. + NumHooks +) + +// A RuleVerdict is what a rule decides should be done with a packet. +type RuleVerdict int + +const ( + // RuleAccept indicates the packet should continue through netstack. + RuleAccept RuleVerdict = iota + + // RuleDrop indicates the packet should be dropped. + RuleDrop + + // RuleJump indicates the packet should jump to another chain. + RuleJump + + // RuleReturn indicates the packet should return to the previous chain. + RuleReturn +) + +// IPTables holds all the tables for a netstack. +type IPTables struct { + // Tables maps table names to tables. User tables have arbitrary names. + Tables map[string]Table + + // Priorities maps each hook to a list of table names. The order of the + // list is the order in which each table should be visited for that + // hook. + Priorities map[Hook][]string +} + +// A Table defines a set of chains and hooks into the network stack. It is +// really just a list of rules with some metadata for entrypoints and such. +type Table struct { + // Rules holds the rules that make up the table. + Rules []Rule + + // BuiltinChains maps builtin chains to their entrypoint rule in Rules. + BuiltinChains map[Hook]int + + // Underflows maps builtin chains to their underflow rule in Rules + // (i.e. the rule to execute if the chain returns without a verdict). + Underflows map[Hook]int + + // UserChains holds user-defined chains for the keyed by name. Users + // can give their chains arbitrary names. + UserChains map[string]int + + // Metadata holds information about the Table that is useful to users + // of IPTables, but not to the netstack IPTables code itself. + metadata interface{} +} + +// ValidHooks returns a bitmap of the builtin hooks for the given table. +func (table *Table) ValidHooks() uint32 { + hooks := uint32(0) + for hook := range table.BuiltinChains { + hooks |= 1 << hook + } + return hooks +} + +// Metadata returns the metadata object stored in table. +func (table *Table) Metadata() interface{} { + return table.metadata +} + +// SetMetadata sets the metadata object stored in table. +func (table *Table) SetMetadata(metadata interface{}) { + table.metadata = metadata +} + +// A Rule is a packet processing rule. It consists of two pieces. First it +// contains zero or more matchers, each of which is a specification of which +// packets this rule applies to. If there are no matchers in the rule, it +// applies to any packet. +type Rule struct { + // Filter holds basic IP filtering fields common to every rule. + Filter IPHeaderFilter + + // Matchers is the list of matchers for this rule. + Matchers []Matcher + + // Target is the action to invoke if all the matchers match the packet. + Target Target +} + +// IPHeaderFilter holds basic IP filtering data common to every rule. +type IPHeaderFilter struct { + // Protocol matches the transport protocol. + Protocol tcpip.TransportProtocolNumber + + // Dst matches the destination IP address. + Dst tcpip.Address + + // DstMask masks bits of the destination IP address when comparing with + // Dst. + DstMask tcpip.Address + + // DstInvert inverts the meaning of the destination IP check, i.e. when + // true the filter will match packets that fail the destination + // comparison. + DstInvert bool +} + +// A Matcher is the interface for matching packets. +type Matcher interface { + // Name returns the name of the Matcher. + Name() string + + // Match returns whether the packet matches and whether the packet + // should be "hotdropped", i.e. dropped immediately. This is usually + // used for suspicious packets. + // + // Precondition: packet.NetworkHeader is set. + Match(hook Hook, packet PacketBuffer, interfaceName string) (matches bool, hotdrop bool) +} + +// A Target is the interface for taking an action for a packet. +type Target interface { + // Action takes an action on the packet and returns a verdict on how + // traversal should (or should not) continue. If the return value is + // Jump, it also returns the index of the rule to jump to. + Action(packet PacketBuffer) (RuleVerdict, int) +} diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index d689a006d..630fdefc5 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -564,7 +564,7 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address) *tcpip.Error { Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS, - }, tcpip.PacketBuffer{Header: hdr}, + }, PacketBuffer{Header: hdr}, ); err != nil { sent.Dropped.Increment() return err @@ -1283,7 +1283,7 @@ func (ndp *ndpState) startSolicitingRouters() { Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS, - }, tcpip.PacketBuffer{Header: hdr}, + }, PacketBuffer{Header: hdr}, ); err != nil { sent.Dropped.Increment() log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.nic.ID(), err) diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 4368c236c..06edd05b6 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -602,7 +602,7 @@ func TestDADFail(t *testing.T) { // Receive a packet to simulate multiple nodes owning or // attempting to own the same address. hdr := test.makeBuf(addr1) - e.InjectInbound(header.IPv6ProtocolNumber, tcpip.PacketBuffer{ + e.InjectInbound(header.IPv6ProtocolNumber, stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -918,7 +918,7 @@ func TestSetNDPConfigurations(t *testing.T) { // raBufWithOptsAndDHCPv6 returns a valid NDP Router Advertisement with options // and DHCPv6 configurations specified. -func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) tcpip.PacketBuffer { +func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) stack.PacketBuffer { icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + int(optSer.Length()) hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) pkt := header.ICMPv6(hdr.Prepend(icmpSize)) @@ -953,14 +953,14 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo DstAddr: header.IPv6AllNodesMulticastAddress, }) - return tcpip.PacketBuffer{Data: hdr.View().ToVectorisedView()} + return stack.PacketBuffer{Data: hdr.View().ToVectorisedView()} } // raBufWithOpts returns a valid NDP Router Advertisement with options. // // Note, raBufWithOpts does not populate any of the RA fields other than the // Router Lifetime. -func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) tcpip.PacketBuffer { +func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) stack.PacketBuffer { return raBufWithOptsAndDHCPv6(ip, rl, false, false, optSer) } @@ -969,7 +969,7 @@ func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializ // // Note, raBufWithDHCPv6 does not populate any of the RA fields other than the // DHCPv6 related ones. -func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) tcpip.PacketBuffer { +func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) stack.PacketBuffer { return raBufWithOptsAndDHCPv6(ip, 0, managedAddresses, otherConfiguratiosns, header.NDPOptionsSerializer{}) } @@ -977,7 +977,7 @@ func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bo // // Note, raBuf does not populate any of the RA fields other than the // Router Lifetime. -func raBuf(ip tcpip.Address, rl uint16) tcpip.PacketBuffer { +func raBuf(ip tcpip.Address, rl uint16) stack.PacketBuffer { return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{}) } @@ -986,7 +986,7 @@ func raBuf(ip tcpip.Address, rl uint16) tcpip.PacketBuffer { // // Note, raBufWithPI does not populate any of the RA fields other than the // Router Lifetime. -func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) tcpip.PacketBuffer { +func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) stack.PacketBuffer { flags := uint8(0) if onLink { // The OnLink flag is the 7th bit in the flags byte. diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 9dcb1d52c..b6fa647ea 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -26,7 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/iptables" ) var ipv4BroadcastAddr = tcpip.ProtocolAddress{ @@ -1144,7 +1143,7 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool { return joins != 0 } -func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt tcpip.PacketBuffer) { +func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt PacketBuffer) { r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */) r.RemoteLinkAddress = remotelinkAddr @@ -1158,7 +1157,7 @@ func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, // Note that the ownership of the slice backing vv is retained by the caller. // This rule applies only to the slice itself, not to the items of the slice; // the ownership of the items is not retained by the caller. -func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { +func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { n.mu.RLock() enabled := n.mu.enabled // If the NIC is not yet enabled, don't receive any packets. @@ -1222,7 +1221,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link // TODO(gvisor.dev/issue/170): Not supporting iptables for IPv6 yet. if protocol == header.IPv4ProtocolNumber { ipt := n.stack.IPTables() - if ok := ipt.Check(iptables.Prerouting, pkt); !ok { + if ok := ipt.Check(Prerouting, pkt); !ok { // iptables is telling us to drop the packet. return } @@ -1287,7 +1286,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link } } -func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { +func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { // TODO(b/143425874) Decrease the TTL field in forwarded packets. firstData := pkt.Data.First() @@ -1318,7 +1317,7 @@ func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. -func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) { +func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer) { state, ok := n.stack.transportProtocols[protocol] if !ok { n.stack.stats.UnknownProtocolRcvdPackets.Increment() @@ -1364,7 +1363,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // DeliverTransportControlPacket delivers control packets to the appropriate // transport protocol endpoint. -func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) { +func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer) { state, ok := n.stack.transportProtocols[trans] if !ok { return diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index edaee3b86..d672fc157 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -17,7 +17,6 @@ package stack import ( "testing" - "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" ) @@ -45,7 +44,7 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { t.FailNow() } - nic.DeliverNetworkPacket(nil, "", "", 0, tcpip.PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()}) + nic.DeliverNetworkPacket(nil, "", "", 0, PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()}) if got := nic.stats.DisabledRx.Packets.Value(); got != 1 { t.Errorf("got DisabledRx.Packets = %d, want = 1", got) diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go new file mode 100644 index 000000000..1850fa8c3 --- /dev/null +++ b/pkg/tcpip/stack/packet_buffer.go @@ -0,0 +1,66 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at // +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package stack + +import "gvisor.dev/gvisor/pkg/tcpip/buffer" + +// A PacketBuffer contains all the data of a network packet. +// +// As a PacketBuffer traverses up the stack, it may be necessary to pass it to +// multiple endpoints. Clone() should be called in such cases so that +// modifications to the Data field do not affect other copies. +// +// +stateify savable +type PacketBuffer struct { + // Data holds the payload of the packet. For inbound packets, it also + // holds the headers, which are consumed as the packet moves up the + // stack. Headers are guaranteed not to be split across views. + // + // The bytes backing Data are immutable, but Data itself may be trimmed + // or otherwise modified. + Data buffer.VectorisedView + + // DataOffset is used for GSO output. It is the offset into the Data + // field where the payload of this packet starts. + DataOffset int + + // DataSize is used for GSO output. It is the size of this packet's + // payload. + DataSize int + + // Header holds the headers of outbound packets. As a packet is passed + // down the stack, each layer adds to Header. + Header buffer.Prependable + + // These fields are used by both inbound and outbound packets. They + // typically overlap with the Data and Header fields. + // + // The bytes backing these views are immutable. Each field may be nil + // if either it has not been set yet or no such header exists (e.g. + // packets sent via loopback may not have a link header). + // + // These fields may be Views into other slices (either Data or Header). + // SR dosen't support this, so deep copies are necessary in some cases. + LinkHeader buffer.View + NetworkHeader buffer.View + TransportHeader buffer.View +} + +// Clone makes a copy of pk. It clones the Data field, which creates a new +// VectorisedView but does not deep copy the underlying bytes. +// +// Clone also does not deep copy any of its other fields. +func (pk PacketBuffer) Clone() PacketBuffer { + pk.Data = pk.Data.Clone(nil) + return pk +} diff --git a/pkg/tcpip/stack/packet_buffer_state.go b/pkg/tcpip/stack/packet_buffer_state.go new file mode 100644 index 000000000..76602549e --- /dev/null +++ b/pkg/tcpip/stack/packet_buffer_state.go @@ -0,0 +1,26 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package stack + +import "gvisor.dev/gvisor/pkg/tcpip/buffer" + +// beforeSave is invoked by stateify. +func (pk *PacketBuffer) beforeSave() { + // Non-Data fields may be slices of the Data field. This causes + // problems for SR, so during save we make each header independent. + pk.Header = pk.Header.DeepCopy() + pk.LinkHeader = append(buffer.View(nil), pk.LinkHeader...) + pk.NetworkHeader = append(buffer.View(nil), pk.NetworkHeader...) + pk.TransportHeader = append(buffer.View(nil), pk.TransportHeader...) +} diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index fa28b46b1..ac043b722 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -67,12 +67,12 @@ type TransportEndpoint interface { // this transport endpoint. It sets pkt.TransportHeader. // // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) + HandlePacket(r *Route, id TransportEndpointID, pkt PacketBuffer) // HandleControlPacket is called by the stack when new control (e.g. // ICMP) packets arrive to this transport endpoint. // HandleControlPacket takes ownership of pkt. - HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) + HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt PacketBuffer) // Abort initiates an expedited endpoint teardown. It puts the endpoint // in a closed state and frees all resources associated with it. This @@ -100,7 +100,7 @@ type RawTransportEndpoint interface { // layer up. // // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, pkt tcpip.PacketBuffer) + HandlePacket(r *Route, pkt PacketBuffer) } // PacketEndpoint is the interface that needs to be implemented by packet @@ -118,7 +118,7 @@ type PacketEndpoint interface { // should construct its own ethernet header for applications. // // HandlePacket takes ownership of pkt. - HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) + HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt PacketBuffer) } // TransportProtocol is the interface that needs to be implemented by transport @@ -150,7 +150,7 @@ type TransportProtocol interface { // stats purposes only). // // HandleUnknownDestinationPacket takes ownership of pkt. - HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) bool + HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt PacketBuffer) bool // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the @@ -180,7 +180,7 @@ type TransportDispatcher interface { // pkt.NetworkHeader must be set before calling DeliverTransportPacket. // // DeliverTransportPacket takes ownership of pkt. - DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) + DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer) // DeliverTransportControlPacket delivers control packets to the // appropriate transport protocol endpoint. @@ -189,7 +189,7 @@ type TransportDispatcher interface { // DeliverTransportControlPacket. // // DeliverTransportControlPacket takes ownership of pkt. - DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) + DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer) } // PacketLooping specifies where an outbound packet should be sent. @@ -242,15 +242,15 @@ type NetworkEndpoint interface { // WritePacket writes a packet to the given destination address and // protocol. It sets pkt.NetworkHeader. pkt.TransportHeader must have // already been set. - WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error + WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error // WritePackets writes packets to the given destination address and // protocol. pkts must not be zero length. - WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) + WritePackets(r *Route, gso *GSO, pkts []PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) // WriteHeaderIncludedPacket writes a packet that includes a network // header to the given destination address. - WriteHeaderIncludedPacket(r *Route, pkt tcpip.PacketBuffer) *tcpip.Error + WriteHeaderIncludedPacket(r *Route, pkt PacketBuffer) *tcpip.Error // ID returns the network protocol endpoint ID. ID() *NetworkEndpointID @@ -265,7 +265,7 @@ type NetworkEndpoint interface { // this network endpoint. It sets pkt.NetworkHeader. // // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, pkt tcpip.PacketBuffer) + HandlePacket(r *Route, pkt PacketBuffer) // Close is called when the endpoint is reomved from a stack. Close() @@ -322,7 +322,7 @@ type NetworkDispatcher interface { // packets sent via loopback), and won't have the field set. // // DeliverNetworkPacket takes ownership of pkt. - DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) + DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) } // LinkEndpointCapabilities is the type associated with the capabilities @@ -354,7 +354,7 @@ const ( // LinkEndpoint is the interface implemented by data link layer protocols (e.g., // ethernet, loopback, raw) and used by network layer protocols to send packets // out through the implementer's data link endpoint. When a link header exists, -// it sets each tcpip.PacketBuffer's LinkHeader field before passing it up the +// it sets each PacketBuffer's LinkHeader field before passing it up the // stack. type LinkEndpoint interface { // MTU is the maximum transmission unit for this endpoint. This is @@ -385,7 +385,7 @@ type LinkEndpoint interface { // To participate in transparent bridging, a LinkEndpoint implementation // should call eth.Encode with header.EthernetFields.SrcAddr set to // r.LocalLinkAddress if it is provided. - WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error + WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) *tcpip.Error // WritePackets writes packets with the given protocol through the // given route. pkts must not be zero length. @@ -393,7 +393,7 @@ type LinkEndpoint interface { // Right now, WritePackets is used only when the software segmentation // offload is enabled. If it will be used for something else, it may // require to change syscall filters. - WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) + WritePackets(r *Route, gso *GSO, pkts []PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) // WriteRawPacket writes a packet directly to the link. The packet // should already have an ethernet header. @@ -426,7 +426,7 @@ type InjectableLinkEndpoint interface { LinkEndpoint // InjectInbound injects an inbound packet. - InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) + InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) // InjectOutbound writes a fully formed outbound packet directly to the // link. diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index f565aafb2..9fbe8a411 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -153,7 +153,7 @@ func (r *Route) IsResolutionRequired() bool { } // WritePacket writes the packet through the given route. -func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error { +func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error { if !r.ref.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } @@ -169,7 +169,7 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt tcpip.Pack } // WritePackets writes the set of packets through the given route. -func (r *Route) WritePackets(gso *GSO, pkts []tcpip.PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) { +func (r *Route) WritePackets(gso *GSO, pkts []PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) { if !r.ref.isValidForOutgoing() { return 0, tcpip.ErrInvalidEndpointState } @@ -190,7 +190,7 @@ func (r *Route) WritePackets(gso *GSO, pkts []tcpip.PacketBuffer, params Network // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. -func (r *Route) WriteHeaderIncludedPacket(pkt tcpip.PacketBuffer) *tcpip.Error { +func (r *Route) WriteHeaderIncludedPacket(pkt PacketBuffer) *tcpip.Error { if !r.ref.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 6f423874a..a9584d636 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -31,7 +31,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/waiter" @@ -51,7 +50,7 @@ const ( type transportProtocolState struct { proto TransportProtocol - defaultHandler func(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) bool + defaultHandler func(r *Route, id TransportEndpointID, pkt PacketBuffer) bool } // TCPProbeFunc is the expected function type for a TCP probe function to be @@ -428,7 +427,7 @@ type Stack struct { // tables are the iptables packet filtering and manipulation rules. The are // protected by tablesMu.` - tables iptables.IPTables + tables IPTables // resumableEndpoints is a list of endpoints that need to be resumed if the // stack is being restored. @@ -738,7 +737,7 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, // // It must be called only during initialization of the stack. Changing it as the // stack is operating is not supported. -func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, tcpip.PacketBuffer) bool) { +func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, PacketBuffer) bool) { state := s.transportProtocols[p] if state != nil { state.defaultHandler = h @@ -1701,7 +1700,7 @@ func (s *Stack) IsInGroup(nicID tcpip.NICID, multicastAddr tcpip.Address) (bool, } // IPTables returns the stack's iptables. -func (s *Stack) IPTables() iptables.IPTables { +func (s *Stack) IPTables() IPTables { s.tablesMu.RLock() t := s.tables s.tablesMu.RUnlock() @@ -1709,7 +1708,7 @@ func (s *Stack) IPTables() iptables.IPTables { } // SetIPTables sets the stack's iptables. -func (s *Stack) SetIPTables(ipt iptables.IPTables) { +func (s *Stack) SetIPTables(ipt IPTables) { s.tablesMu.Lock() s.tables = ipt s.tablesMu.Unlock() diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 9836b340f..555fcd92f 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -90,7 +90,7 @@ func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID { return &f.id } -func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { +func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { // Increment the received packet count in the protocol descriptor. f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ @@ -126,7 +126,7 @@ func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities { return f.ep.Capabilities() } -func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error { +func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error { // Increment the sent packet count in the protocol descriptor. f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++ @@ -141,7 +141,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) views[0] = pkt.Header.View() views = append(views, pkt.Data.Views()...) - f.HandlePacket(r, tcpip.PacketBuffer{ + f.HandlePacket(r, stack.PacketBuffer{ Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), }) } @@ -153,11 +153,11 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) { +func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) { panic("not implemented") } -func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error { +func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } @@ -287,7 +287,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet with wrong address is not delivered. buf[0] = 3 - ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 0 { @@ -299,7 +299,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to first endpoint. buf[0] = 1 - ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 1 { @@ -311,7 +311,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to second endpoint. buf[0] = 2 - ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 1 { @@ -322,7 +322,7 @@ func TestNetworkReceive(t *testing.T) { } // Make sure packet is not delivered if protocol number is wrong. - ep.InjectInbound(fakeNetNumber-1, tcpip.PacketBuffer{ + ep.InjectInbound(fakeNetNumber-1, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 1 { @@ -334,7 +334,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet that is too small is dropped. buf.CapLength(2) - ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 1 { @@ -356,7 +356,7 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro func send(r stack.Route, payload buffer.View) *tcpip.Error { hdr := buffer.NewPrependable(int(r.MaxHeaderLength())) - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{ Header: hdr, Data: payload.ToVectorisedView(), }) @@ -414,7 +414,7 @@ func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte b func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) { t.Helper() - ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if got := fakeNet.PacketCount(localAddrByte); got != want { @@ -2257,7 +2257,7 @@ func TestNICStats(t *testing.T) { // Send a packet to address 1. buf := buffer.NewView(30) - ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + ep1.InjectInbound(fakeNetNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want { @@ -2339,7 +2339,7 @@ func TestNICForwarding(t *testing.T) { // Send a packet to dstAddr. buf := buffer.NewView(30) buf[0] = dstAddr[0] - ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + ep1.InjectInbound(fakeNetNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index d4c0359e8..c55e3e8bc 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -85,7 +85,7 @@ func (epsByNic *endpointsByNic) transportEndpoints() []TransportEndpoint { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) { +func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, pkt PacketBuffer) { epsByNic.mu.RLock() mpep, ok := epsByNic.endpoints[r.ref.nic.ID()] @@ -116,7 +116,7 @@ func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, p } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) { +func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt PacketBuffer) { epsByNic.mu.RLock() defer epsByNic.mu.RUnlock() @@ -184,7 +184,7 @@ type transportDemuxer struct { // the dispatcher to delivery packets to the QueuePacket method instead of // calling HandlePacket directly on the endpoint. type queuedTransportProtocol interface { - QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt tcpip.PacketBuffer) + QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt PacketBuffer) } func newTransportDemuxer(stack *Stack) *transportDemuxer { @@ -312,7 +312,7 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32 return mpep.endpoints[idx] } -func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) { +func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt PacketBuffer) { ep.mu.RLock() queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}] // HandlePacket takes ownership of pkt, so each endpoint needs @@ -403,7 +403,7 @@ func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolN // deliverPacket attempts to find one or more matching transport endpoints, and // then, if matches are found, delivers the packet to them. Returns true if // the packet no longer needs to be handled. -func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer, id TransportEndpointID) bool { +func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] if !ok { return false @@ -453,7 +453,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // deliverRawPacket attempts to deliver the given packet and returns whether it // was delivered successfully. -func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) bool { +func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer) bool { eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] if !ok { return false @@ -477,7 +477,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr // deliverControlPacket attempts to deliver the given control packet. Returns // true if it found an endpoint, false otherwise. -func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt tcpip.PacketBuffer, id TransportEndpointID) bool { +func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{net, trans}] if !ok { return false diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 0e3e239c5..84311bcc8 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -150,7 +150,7 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ + c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) } diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 5d1da2f8b..8ca9ac3cf 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -19,7 +19,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -87,7 +86,7 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions if err != nil { return 0, nil, err } - if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{ Header: hdr, Data: buffer.View(v).ToVectorisedView(), }); err != nil { @@ -214,7 +213,7 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro return tcpip.FullAddress{}, nil } -func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ tcpip.PacketBuffer) { +func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ stack.PacketBuffer) { // Increment the number of received packets. f.proto.packetCount++ if f.acceptQueue != nil { @@ -231,7 +230,7 @@ func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportE } } -func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, tcpip.PacketBuffer) { +func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, stack.PacketBuffer) { // Increment the number of received control packets. f.proto.controlCount++ } @@ -242,8 +241,8 @@ func (f *fakeTransportEndpoint) State() uint32 { func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) {} -func (f *fakeTransportEndpoint) IPTables() (iptables.IPTables, error) { - return iptables.IPTables{}, nil +func (f *fakeTransportEndpoint) IPTables() (stack.IPTables, error) { + return stack.IPTables{}, nil } func (f *fakeTransportEndpoint) Resume(*stack.Stack) {} @@ -288,7 +287,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp return 0, 0, nil } -func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool { +func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, stack.PacketBuffer) bool { return true } @@ -368,7 +367,7 @@ func TestTransportReceive(t *testing.T) { // Make sure packet with wrong protocol is not delivered. buf[0] = 1 buf[2] = 0 - linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.packetCount != 0 { @@ -379,7 +378,7 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 3 buf[2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.packetCount != 0 { @@ -390,7 +389,7 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 2 buf[2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.packetCount != 1 { @@ -445,7 +444,7 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 0 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = 0 - linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.controlCount != 0 { @@ -456,7 +455,7 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 3 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.controlCount != 0 { @@ -467,7 +466,7 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 2 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.controlCount != 1 { @@ -622,7 +621,7 @@ func TestTransportForwarding(t *testing.T) { req[0] = 1 req[1] = 3 req[2] = byte(fakeTransNumber) - ep2.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + ep2.InjectInbound(fakeNetNumber, stack.PacketBuffer{ Data: req.ToVectorisedView(), }) diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD index ac18ec5b1..9ce625c17 100644 --- a/pkg/tcpip/transport/icmp/BUILD +++ b/pkg/tcpip/transport/icmp/BUILD @@ -31,7 +31,6 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", - "//pkg/tcpip/iptables", "//pkg/tcpip/stack", "//pkg/tcpip/transport/raw", "//pkg/tcpip/transport/tcp", diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 2a396e9bc..613b12ead 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -19,7 +19,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" ) @@ -135,7 +134,7 @@ func (e *endpoint) Close() { func (e *endpoint) ModerateRecvBuf(copied int) {} // IPTables implements tcpip.Endpoint.IPTables. -func (e *endpoint) IPTables() (iptables.IPTables, error) { +func (e *endpoint) IPTables() (stack.IPTables, error) { return e.stack.IPTables(), nil } @@ -441,7 +440,7 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err if ttl == 0 { ttl = r.DefaultTTL() } - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, stack.PacketBuffer{ Header: hdr, Data: data.ToVectorisedView(), TransportHeader: buffer.View(icmpv4), @@ -471,7 +470,7 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err if ttl == 0 { ttl = r.DefaultTTL() } - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, stack.PacketBuffer{ Header: hdr, Data: dataVV, TransportHeader: buffer.View(icmpv6), @@ -733,7 +732,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) { // Only accept echo replies. switch e.NetProto { case header.IPv4ProtocolNumber: @@ -795,7 +794,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { } // State implements tcpip.Endpoint.State. The ICMP endpoint currently doesn't diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 113d92901..3c47692b2 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -104,7 +104,7 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool { +func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, stack.PacketBuffer) bool { return true } diff --git a/pkg/tcpip/transport/packet/BUILD b/pkg/tcpip/transport/packet/BUILD index d22de6b26..b989b1209 100644 --- a/pkg/tcpip/transport/packet/BUILD +++ b/pkg/tcpip/transport/packet/BUILD @@ -31,7 +31,6 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", - "//pkg/tcpip/iptables", "//pkg/tcpip/stack", "//pkg/waiter", ], diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 09a1cd436..df49d0995 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -29,7 +29,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" ) @@ -100,8 +99,8 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb } // Abort implements stack.TransportEndpoint.Abort. -func (e *endpoint) Abort() { - e.Close() +func (ep *endpoint) Abort() { + ep.Close() } // Close implements tcpip.Endpoint.Close. @@ -134,7 +133,7 @@ func (ep *endpoint) Close() { func (ep *endpoint) ModerateRecvBuf(copied int) {} // IPTables implements tcpip.Endpoint.IPTables. -func (ep *endpoint) IPTables() (iptables.IPTables, error) { +func (ep *endpoint) IPTables() (stack.IPTables, error) { return ep.stack.IPTables(), nil } @@ -299,7 +298,7 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { } // HandlePacket implements stack.PacketEndpoint.HandlePacket. -func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { +func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { ep.rcvMu.Lock() // Drop the packet if our buffer is currently full. diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD index c9baf4600..2eab09088 100644 --- a/pkg/tcpip/transport/raw/BUILD +++ b/pkg/tcpip/transport/raw/BUILD @@ -32,7 +32,6 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", - "//pkg/tcpip/iptables", "//pkg/tcpip/stack", "//pkg/tcpip/transport/packet", "//pkg/waiter", diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 2ef5fac76..536dafd1e 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -30,7 +30,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" ) @@ -161,7 +160,7 @@ func (e *endpoint) Close() { func (e *endpoint) ModerateRecvBuf(copied int) {} // IPTables implements tcpip.Endpoint.IPTables. -func (e *endpoint) IPTables() (iptables.IPTables, error) { +func (e *endpoint) IPTables() (stack.IPTables, error) { return e.stack.IPTables(), nil } @@ -342,7 +341,7 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, switch e.NetProto { case header.IPv4ProtocolNumber: if !e.associated { - if err := route.WriteHeaderIncludedPacket(tcpip.PacketBuffer{ + if err := route.WriteHeaderIncludedPacket(stack.PacketBuffer{ Data: buffer.View(payloadBytes).ToVectorisedView(), }); err != nil { return 0, nil, err @@ -350,7 +349,7 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, break } hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength())) - if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ Header: hdr, Data: buffer.View(payloadBytes).ToVectorisedView(), }); err != nil { @@ -574,7 +573,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { } // HandlePacket implements stack.RawTransportEndpoint.HandlePacket. -func (e *endpoint) HandlePacket(route *stack.Route, pkt tcpip.PacketBuffer) { +func (e *endpoint) HandlePacket(route *stack.Route, pkt stack.PacketBuffer) { e.rcvMu.Lock() // Drop the packet if our buffer is currently full. diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 2fdf6c0a5..7f94f9646 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -66,7 +66,6 @@ go_library( "//pkg/tcpip/buffer", "//pkg/tcpip/hash/jenkins", "//pkg/tcpip/header", - "//pkg/tcpip/iptables", "//pkg/tcpip/ports", "//pkg/tcpip/seqnum", "//pkg/tcpip/stack", diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 53193afc6..79552fc61 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -705,7 +705,7 @@ func (e *endpoint) sendTCP(r *stack.Route, id stack.TransportEndpointID, data bu return nil } -func buildTCPHdr(r *stack.Route, id stack.TransportEndpointID, pkt *tcpip.PacketBuffer, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) { +func buildTCPHdr(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) { optLen := len(opts) hdr := &pkt.Header packetSize := pkt.DataSize @@ -752,7 +752,7 @@ func sendTCPBatch(r *stack.Route, id stack.TransportEndpointID, data buffer.Vect // Allocate one big slice for all the headers. hdrSize := header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen buf := make([]byte, n*hdrSize) - pkts := make([]tcpip.PacketBuffer, n) + pkts := make([]stack.PacketBuffer, n) for i := range pkts { pkts[i].Header = buffer.NewEmptyPrependableFromView(buf[i*hdrSize:][:hdrSize]) } @@ -795,7 +795,7 @@ func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.Vectorise return sendTCPBatch(r, id, data, ttl, tos, flags, seq, ack, rcvWnd, opts, gso) } - pkt := tcpip.PacketBuffer{ + pkt := stack.PacketBuffer{ Header: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen), DataOffset: 0, DataSize: data.Size(), diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go index 90ac956a9..6062ca916 100644 --- a/pkg/tcpip/transport/tcp/dispatcher.go +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -18,7 +18,6 @@ import ( "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -187,7 +186,7 @@ func (d *dispatcher) wait() { } } -func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { +func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt stack.PacketBuffer) { ep := stackEP.(*endpoint) s := newSegment(r, id, pkt) if !s.parse() { diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index eb8a9d73e..594efaa11 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -30,7 +30,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -1120,7 +1119,7 @@ func (e *endpoint) ModerateRecvBuf(copied int) { } // IPTables implements tcpip.Endpoint.IPTables. -func (e *endpoint) IPTables() (iptables.IPTables, error) { +func (e *endpoint) IPTables() (stack.IPTables, error) { return e.stack.IPTables(), nil } @@ -2388,7 +2387,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { }, nil } -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) { // TCP HandlePacket is not required anymore as inbound packets first // land at the Dispatcher which then can either delivery using the // worker go routine or directly do the invoke the tcp processing inline @@ -2407,7 +2406,7 @@ func (e *endpoint) enqueueSegment(s *segment) bool { } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { switch typ { case stack.ControlPacketTooBig: e.sndBufMu.Lock() diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index c9ee5bf06..a094471b8 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -61,7 +61,7 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward // // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. -func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool { +func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { s := newSegment(r, id, pkt) defer s.decRef() diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index b0f918bb4..57985b85d 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -140,7 +140,7 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // to a specific processing queue. Each queue is serviced by its own processor // goroutine which is responsible for dequeuing and doing full TCP dispatch of // the packet. -func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { +func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt stack.PacketBuffer) { p.dispatcher.queuePacket(r, ep, id, pkt) } @@ -151,7 +151,7 @@ func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id st // a reset is sent in response to any incoming segment except another reset. In // particular, SYNs addressed to a non-existent connection are rejected by this // means." -func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool { +func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { s := newSegment(r, id, pkt) defer s.decRef() diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 5d0bc4f72..e6fe7985d 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -18,7 +18,6 @@ import ( "sync/atomic" "time" - "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" @@ -61,7 +60,7 @@ type segment struct { xmitCount uint32 } -func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) *segment { +func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) *segment { s := &segment{ refCnt: 1, id: id, diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 8cea20fb5..d4f6bc635 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -307,7 +307,7 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt copy(icmp[header.ICMPv4PayloadOffset:], p2) // Inject packet. - c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) } @@ -363,7 +363,7 @@ func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcp // SendSegment sends a TCP segment that has already been built and written to a // buffer.VectorisedView. func (c *Context) SendSegment(s buffer.VectorisedView) { - c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ Data: s, }) } @@ -371,7 +371,7 @@ func (c *Context) SendSegment(s buffer.VectorisedView) { // SendPacket builds and sends a TCP segment(with the provided payload & TCP // headers) in an IPv4 packet via the link layer endpoint. func (c *Context) SendPacket(payload []byte, h *Headers) { - c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ Data: c.BuildSegment(payload, h), }) } @@ -380,7 +380,7 @@ func (c *Context) SendPacket(payload []byte, h *Headers) { // & TCPheaders) in an IPv4 packet via the link layer endpoint using the // provided source and destination IPv4 addresses. func (c *Context) SendPacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) { - c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ Data: c.BuildSegmentWithAddrs(payload, h, src, dst), }) } @@ -548,7 +548,7 @@ func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcp t.SetChecksum(^t.CalculateChecksum(xsum)) // Inject packet. - c.linkEP.InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ + c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) } diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index adc908e24..b5d2d0ba6 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -32,7 +32,6 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", - "//pkg/tcpip/iptables", "//pkg/tcpip/ports", "//pkg/tcpip/stack", "//pkg/tcpip/transport/raw", diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 0af4514e1..a3372ac58 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -19,7 +19,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" @@ -234,7 +233,7 @@ func (e *endpoint) Close() { func (e *endpoint) ModerateRecvBuf(copied int) {} // IPTables implements tcpip.Endpoint.IPTables. -func (e *endpoint) IPTables() (iptables.IPTables, error) { +func (e *endpoint) IPTables() (stack.IPTables, error) { return e.stack.IPTables(), nil } @@ -913,7 +912,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u if useDefaultTTL { ttl = r.DefaultTTL() } - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}, tcpip.PacketBuffer{ + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}, stack.PacketBuffer{ Header: hdr, Data: data, TransportHeader: buffer.View(udp), @@ -1260,7 +1259,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) { // Get the header then trim it from the view. hdr := header.UDP(pkt.Data.First()) if int(hdr.Length()) > pkt.Data.Size() { @@ -1327,7 +1326,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt tcpip.PacketBuffer) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { } // State implements tcpip.Endpoint.State. diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index fc706ede2..a674ceb68 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -43,7 +43,7 @@ func NewForwarder(s *stack.Stack, handler func(*ForwarderRequest)) *Forwarder { // // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. -func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool { +func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { f.handler(&ForwarderRequest{ stack: f.stack, route: r, @@ -61,7 +61,7 @@ type ForwarderRequest struct { stack *stack.Stack route *stack.Route id stack.TransportEndpointID - pkt tcpip.PacketBuffer + pkt stack.PacketBuffer } // ID returns the 4-tuple (src address, src port, dst address, dst port) that diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 8df089d22..6e31a9bac 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -66,7 +66,7 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) bool { +func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { // Get the header then trim it from the view. hdr := header.UDP(pkt.Data.First()) if int(hdr.Length()) > pkt.Data.Size() { @@ -135,7 +135,7 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans pkt.SetType(header.ICMPv4DstUnreachable) pkt.SetCode(header.ICMPv4PortUnreachable) pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload)) - r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ Header: hdr, Data: payload, }) @@ -172,7 +172,7 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans pkt.SetType(header.ICMPv6DstUnreachable) pkt.SetCode(header.ICMPv6PortUnreachable) pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload)) - r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ Header: hdr, Data: payload, }) diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 34b7c2360..0905726c1 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -439,7 +439,7 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEP.InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{ + c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), NetworkHeader: buffer.View(ip), TransportHeader: buffer.View(u), @@ -486,7 +486,7 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool // Inject packet. - c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ Data: buf.ToVectorisedView(), NetworkHeader: buffer.View(ip), TransportHeader: buffer.View(u), -- cgit v1.2.3 From c8eeedcc1d6b1ee25532ae630a7efd7aa4656bdc Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Tue, 24 Mar 2020 15:33:16 -0700 Subject: Add support for setting TCP segment hash. This allows the link layer endpoints to consistenly hash a TCP segment to a single underlying queue in case a link layer endpoint does support multiple underlying queues. Updates #231 PiperOrigin-RevId: 302760664 --- pkg/tcpip/link/fdbased/endpoint.go | 10 ++- pkg/tcpip/link/fdbased/endpoint_test.go | 76 +++++++++++++----- pkg/tcpip/stack/BUILD | 1 + pkg/tcpip/stack/packet_buffer.go | 5 ++ pkg/tcpip/stack/packet_buffer_state.go | 1 + pkg/tcpip/stack/rand.go | 40 +++++++++ pkg/tcpip/stack/stack.go | 44 +++++++++- pkg/tcpip/transport/tcp/accept.go | 20 ++++- pkg/tcpip/transport/tcp/connect.go | 138 +++++++++++++++++++++++--------- pkg/tcpip/transport/tcp/endpoint.go | 5 ++ pkg/tcpip/transport/tcp/protocol.go | 10 ++- 11 files changed, 280 insertions(+), 70 deletions(-) create mode 100644 pkg/tcpip/stack/rand.go (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index 235e647ff..3b3b6909b 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -405,6 +405,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne eth.Encode(ethHdr) } + fd := e.fds[pkt.Hash%uint32(len(e.fds))] if e.Capabilities()&stack.CapabilityHardwareGSO != 0 { vnetHdr := virtioNetHdr{} if gso != nil { @@ -428,14 +429,14 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne } vnetHdrBuf := vnetHdrToByteSlice(&vnetHdr) - return rawfile.NonBlockingWrite3(e.fds[0], vnetHdrBuf, pkt.Header.View(), pkt.Data.ToView()) + return rawfile.NonBlockingWrite3(fd, vnetHdrBuf, pkt.Header.View(), pkt.Data.ToView()) } if pkt.Data.Size() == 0 { - return rawfile.NonBlockingWrite(e.fds[0], pkt.Header.View()) + return rawfile.NonBlockingWrite(fd, pkt.Header.View()) } - return rawfile.NonBlockingWrite3(e.fds[0], pkt.Header.View(), pkt.Data.ToView(), nil) + return rawfile.NonBlockingWrite3(fd, pkt.Header.View(), pkt.Data.ToView(), nil) } // WritePackets writes outbound packets to the file descriptor. If it is not @@ -551,7 +552,8 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.Pac packets := 0 for packets < n { - sent, err := rawfile.NonBlockingSendMMsg(e.fds[0], mmsgHdrs) + fd := e.fds[pkts[packets].Hash%uint32(len(e.fds))] + sent, err := rawfile.NonBlockingSendMMsg(fd, mmsgHdrs) if err != nil { return packets, err } diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index c7dbbbc6b..3bfb15a8e 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -49,36 +49,42 @@ type packetInfo struct { } type context struct { - t *testing.T - fds [2]int - ep stack.LinkEndpoint - ch chan packetInfo - done chan struct{} + t *testing.T + readFDs []int + writeFDs []int + ep stack.LinkEndpoint + ch chan packetInfo + done chan struct{} } func newContext(t *testing.T, opt *Options) *context { - fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0) + firstFDPair, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0) + if err != nil { + t.Fatalf("Socketpair failed: %v", err) + } + secondFDPair, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0) if err != nil { t.Fatalf("Socketpair failed: %v", err) } - done := make(chan struct{}, 1) + done := make(chan struct{}, 2) opt.ClosedFunc = func(*tcpip.Error) { done <- struct{}{} } - opt.FDs = []int{fds[1]} + opt.FDs = []int{firstFDPair[1], secondFDPair[1]} ep, err := New(opt) if err != nil { t.Fatalf("Failed to create FD endpoint: %v", err) } c := &context{ - t: t, - fds: fds, - ep: ep, - ch: make(chan packetInfo, 100), - done: done, + t: t, + readFDs: []int{firstFDPair[0], secondFDPair[0]}, + writeFDs: opt.FDs, + ep: ep, + ch: make(chan packetInfo, 100), + done: done, } ep.Attach(c) @@ -87,9 +93,14 @@ func newContext(t *testing.T, opt *Options) *context { } func (c *context) cleanup() { - syscall.Close(c.fds[0]) + for _, fd := range c.readFDs { + syscall.Close(fd) + } + <-c.done <-c.done - syscall.Close(c.fds[1]) + for _, fd := range c.writeFDs { + syscall.Close(fd) + } } func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { @@ -136,7 +147,7 @@ func TestAddress(t *testing.T) { } } -func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32) { +func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash uint32) { c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth, GSOMaxSize: gsoMaxSize}) defer c.cleanup() @@ -171,13 +182,15 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32) { if err := c.ep.WritePacket(r, gso, proto, stack.PacketBuffer{ Header: hdr, Data: payload.ToVectorisedView(), + Hash: hash, }); err != nil { t.Fatalf("WritePacket failed: %v", err) } - // Read from fd, then compare with what we wrote. + // Read from the corresponding FD, then compare with what we wrote. b = make([]byte, mtu) - n, err := syscall.Read(c.fds[0], b) + fd := c.readFDs[hash%uint32(len(c.readFDs))] + n, err := syscall.Read(fd, b) if err != nil { t.Fatalf("Read failed: %v", err) } @@ -238,7 +251,7 @@ func TestWritePacket(t *testing.T) { t.Run( fmt.Sprintf("Eth=%v,PayloadLen=%v,GSOMaxSize=%v", eth, plen, gso), func(t *testing.T) { - testWritePacket(t, plen, eth, gso) + testWritePacket(t, plen, eth, gso, 0) }, ) } @@ -246,6 +259,27 @@ func TestWritePacket(t *testing.T) { } } +func TestHashedWritePacket(t *testing.T) { + lengths := []int{0, 100, 1000} + eths := []bool{true, false} + gsos := []uint32{0, 32768} + hashes := []uint32{0, 1} + for _, eth := range eths { + for _, plen := range lengths { + for _, gso := range gsos { + for _, hash := range hashes { + t.Run( + fmt.Sprintf("Eth=%v,PayloadLen=%v,GSOMaxSize=%v,Hash=%d", eth, plen, gso, hash), + func(t *testing.T) { + testWritePacket(t, plen, eth, gso, hash) + }, + ) + } + } + } + } +} + func TestPreserveSrcAddress(t *testing.T) { baddr := tcpip.LinkAddress("\xcc\xbb\xaa\x77\x88\x99") @@ -270,7 +304,7 @@ func TestPreserveSrcAddress(t *testing.T) { // Read from the FD, then compare with what we wrote. b := make([]byte, mtu) - n, err := syscall.Read(c.fds[0], b) + n, err := syscall.Read(c.readFDs[0], b) if err != nil { t.Fatalf("Read failed: %v", err) } @@ -314,7 +348,7 @@ func TestDeliverPacket(t *testing.T) { } // Write packet via the file descriptor. - if _, err := syscall.Write(c.fds[0], all); err != nil { + if _, err := syscall.Write(c.readFDs[0], all); err != nil { t.Fatalf("Write failed: %v", err) } diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 7a43a1d4e..8d80e9cee 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -30,6 +30,7 @@ go_library( "nic.go", "packet_buffer.go", "packet_buffer_state.go", + "rand.go", "registration.go", "route.go", "stack.go", diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 1850fa8c3..9505a4e92 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -10,6 +10,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + package stack import "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -54,6 +55,10 @@ type PacketBuffer struct { LinkHeader buffer.View NetworkHeader buffer.View TransportHeader buffer.View + + // Hash is the transport layer hash of this packet. A value of zero + // indicates no valid hash has been set. + Hash uint32 } // Clone makes a copy of pk. It clones the Data field, which creates a new diff --git a/pkg/tcpip/stack/packet_buffer_state.go b/pkg/tcpip/stack/packet_buffer_state.go index 76602549e..0c6b7924c 100644 --- a/pkg/tcpip/stack/packet_buffer_state.go +++ b/pkg/tcpip/stack/packet_buffer_state.go @@ -11,6 +11,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + package stack import "gvisor.dev/gvisor/pkg/tcpip/buffer" diff --git a/pkg/tcpip/stack/rand.go b/pkg/tcpip/stack/rand.go new file mode 100644 index 000000000..421fb5c15 --- /dev/null +++ b/pkg/tcpip/stack/rand.go @@ -0,0 +1,40 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + mathrand "math/rand" + + "gvisor.dev/gvisor/pkg/sync" +) + +// lockedRandomSource provides a threadsafe rand.Source. +type lockedRandomSource struct { + mu sync.Mutex + src mathrand.Source +} + +func (r *lockedRandomSource) Int63() (n int64) { + r.mu.Lock() + n = r.src.Int63() + r.mu.Unlock() + return n +} + +func (r *lockedRandomSource) Seed(seed int64) { + r.mu.Lock() + r.src.Seed(seed) + r.mu.Unlock() +} diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index a9584d636..41398a1b6 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -20,7 +20,9 @@ package stack import ( + "bytes" "encoding/binary" + mathrand "math/rand" "sync/atomic" "time" @@ -465,6 +467,10 @@ type Stack struct { // forwarder holds the packets that wait for their link-address resolutions // to complete, and forwards them when each resolution is done. forwarder *forwardQueue + + // randomGenerator is an injectable pseudo random generator that can be + // used when a random number is required. + randomGenerator *mathrand.Rand } // UniqueID is an abstract generator of unique identifiers. @@ -525,9 +531,16 @@ type Options struct { // this is non-nil. RawFactory RawFactory - // OpaqueIIDOpts hold the options for generating opaque interface identifiers - // (IIDs) as outlined by RFC 7217. + // OpaqueIIDOpts hold the options for generating opaque interface + // identifiers (IIDs) as outlined by RFC 7217. OpaqueIIDOpts OpaqueInterfaceIdentifierOptions + + // RandSource is an optional source to use to generate random + // numbers. If omitted it defaults to a Source seeded by the data + // returned by rand.Read(). + // + // RandSource must be thread-safe. + RandSource mathrand.Source } // TransportEndpointInfo holds useful information about a transport endpoint @@ -623,6 +636,13 @@ func New(opts Options) *Stack { opts.UniqueID = new(uniqueIDGenerator) } + randSrc := opts.RandSource + if randSrc == nil { + // Source provided by mathrand.NewSource is not thread-safe so + // we wrap it in a simple thread-safe version. + randSrc = &lockedRandomSource{src: mathrand.NewSource(generateRandInt64())} + } + // Make sure opts.NDPConfigs contains valid values only. opts.NDPConfigs.validate() @@ -645,6 +665,7 @@ func New(opts Options) *Stack { ndpDisp: opts.NDPDisp, opaqueIIDOpts: opts.OpaqueIIDOpts, forwarder: newForwardQueue(), + randomGenerator: mathrand.New(randSrc), } // Add specified network protocols. @@ -1818,6 +1839,12 @@ func (s *Stack) Seed() uint32 { return s.seed } +// Rand returns a reference to a pseudo random generator that can be used +// to generate random numbers as required. +func (s *Stack) Rand() *mathrand.Rand { + return s.randomGenerator +} + func generateRandUint32() uint32 { b := make([]byte, 4) if _, err := rand.Read(b); err != nil { @@ -1825,3 +1852,16 @@ func generateRandUint32() uint32 { } return binary.LittleEndian.Uint32(b) } + +func generateRandInt64() int64 { + b := make([]byte, 8) + if _, err := rand.Read(b); err != nil { + panic(err) + } + buf := bytes.NewReader(b) + var v int64 + if err := binary.Read(buf, binary.LittleEndian, &v); err != nil { + panic(err) + } + return v +} diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 3f80995f3..b4c4c8ab1 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -445,7 +445,15 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // RFC 793 section 3.4 page 35 (figure 12) outlines that a RST // must be sent in response to a SYN-ACK while in the listen // state to prevent completing a handshake from an old SYN. - e.sendTCP(&s.route, s.id, buffer.VectorisedView{}, e.ttl, e.sendTOS, header.TCPFlagRst, s.ackNumber, 0, 0, nil, nil) + e.sendTCP(&s.route, tcpFields{ + id: s.id, + ttl: e.ttl, + tos: e.sendTOS, + flags: header.TCPFlagRst, + seq: s.ackNumber, + ack: 0, + rcvWnd: 0, + }, buffer.VectorisedView{}, nil) return } @@ -493,7 +501,15 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { TSEcr: opts.TSVal, MSS: mssForRoute(&s.route), } - e.sendSynTCP(&s.route, s.id, e.ttl, e.sendTOS, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts) + e.sendSynTCP(&s.route, tcpFields{ + id: s.id, + ttl: e.ttl, + tos: e.sendTOS, + flags: header.TCPFlagSyn | header.TCPFlagAck, + seq: cookie, + ack: s.sequenceNumber + 1, + rcvWnd: ctx.rcvWnd, + }, synOpts) e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment() } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 79552fc61..1d245c2c6 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -308,7 +308,15 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { if ttl == 0 { ttl = s.route.DefaultTTL() } - h.ep.sendSynTCP(&s.route, h.ep.ID, ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + h.ep.sendSynTCP(&s.route, tcpFields{ + id: h.ep.ID, + ttl: ttl, + tos: h.ep.sendTOS, + flags: h.flags, + seq: h.iss, + ack: h.ackNum, + rcvWnd: h.rcvWnd, + }, synOpts) return nil } @@ -361,7 +369,15 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { SACKPermitted: h.ep.sackPermitted, MSS: h.ep.amss, } - h.ep.sendSynTCP(&s.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + h.ep.sendSynTCP(&s.route, tcpFields{ + id: h.ep.ID, + ttl: h.ep.ttl, + tos: h.ep.sendTOS, + flags: h.flags, + seq: h.iss, + ack: h.ackNum, + rcvWnd: h.rcvWnd, + }, synOpts) return nil } @@ -550,7 +566,16 @@ func (h *handshake) execute() *tcpip.Error { synOpts.WS = -1 } } - h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + + h.ep.sendSynTCP(&h.ep.route, tcpFields{ + id: h.ep.ID, + ttl: h.ep.ttl, + tos: h.ep.sendTOS, + flags: h.flags, + seq: h.iss, + ack: h.ackNum, + rcvWnd: h.rcvWnd, + }, synOpts) for h.state != handshakeCompleted { h.ep.mu.Unlock() @@ -573,7 +598,15 @@ func (h *handshake) execute() *tcpip.Error { // the connection with another ACK or data (as ACKs are never // retransmitted on their own). if h.active || !h.acked || h.deferAccept != 0 && time.Since(h.startTime) > h.deferAccept { - h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + h.ep.sendSynTCP(&h.ep.route, tcpFields{ + id: h.ep.ID, + ttl: h.ep.ttl, + tos: h.ep.sendTOS, + flags: h.flags, + seq: h.iss, + ack: h.ackNum, + rcvWnd: h.rcvWnd, + }, synOpts) } case wakerForNotification: @@ -686,18 +719,33 @@ func makeSynOptions(opts header.TCPSynOptions) []byte { return options[:offset] } -func (e *endpoint) sendSynTCP(r *stack.Route, id stack.TransportEndpointID, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) *tcpip.Error { - options := makeSynOptions(opts) +// tcpFields is a struct to carry different parameters required by the +// send*TCP variant functions below. +type tcpFields struct { + id stack.TransportEndpointID + ttl uint8 + tos uint8 + flags byte + seq seqnum.Value + ack seqnum.Value + rcvWnd seqnum.Size + opts []byte + txHash uint32 +} + +func (e *endpoint) sendSynTCP(r *stack.Route, tf tcpFields, opts header.TCPSynOptions) *tcpip.Error { + tf.opts = makeSynOptions(opts) // We ignore SYN send errors and let the callers re-attempt send. - if err := e.sendTCP(r, id, buffer.VectorisedView{}, ttl, tos, flags, seq, ack, rcvWnd, options, nil); err != nil { + if err := e.sendTCP(r, tf, buffer.VectorisedView{}, nil); err != nil { e.stats.SendErrors.SynSendToNetworkFailed.Increment() } - putOptions(options) + putOptions(tf.opts) return nil } -func (e *endpoint) sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error { - if err := sendTCP(r, id, data, ttl, tos, flags, seq, ack, rcvWnd, opts, gso); err != nil { +func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO) *tcpip.Error { + tf.txHash = e.txHash + if err := sendTCP(r, tf, data, gso); err != nil { e.stats.SendErrors.SegmentSendToNetworkFailed.Increment() return err } @@ -705,8 +753,8 @@ func (e *endpoint) sendTCP(r *stack.Route, id stack.TransportEndpointID, data bu return nil } -func buildTCPHdr(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) { - optLen := len(opts) +func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *stack.GSO) { + optLen := len(tf.opts) hdr := &pkt.Header packetSize := pkt.DataSize off := pkt.DataOffset @@ -714,15 +762,15 @@ func buildTCPHdr(r *stack.Route, id stack.TransportEndpointID, pkt *stack.Packet tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen)) pkt.TransportHeader = buffer.View(tcp) tcp.Encode(&header.TCPFields{ - SrcPort: id.LocalPort, - DstPort: id.RemotePort, - SeqNum: uint32(seq), - AckNum: uint32(ack), + SrcPort: tf.id.LocalPort, + DstPort: tf.id.RemotePort, + SeqNum: uint32(tf.seq), + AckNum: uint32(tf.ack), DataOffset: uint8(header.TCPMinimumSize + optLen), - Flags: flags, - WindowSize: uint16(rcvWnd), + Flags: tf.flags, + WindowSize: uint16(tf.rcvWnd), }) - copy(tcp[header.TCPMinimumSize:], opts) + copy(tcp[header.TCPMinimumSize:], tf.opts) length := uint16(hdr.UsedLength() + packetSize) xsum := r.PseudoHeaderChecksum(ProtocolNumber, length) @@ -737,13 +785,12 @@ func buildTCPHdr(r *stack.Route, id stack.TransportEndpointID, pkt *stack.Packet xsum = header.ChecksumVVWithOffset(pkt.Data, xsum, off, packetSize) tcp.SetChecksum(^tcp.CalculateChecksum(xsum)) } - } -func sendTCPBatch(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error { - optLen := len(opts) - if rcvWnd > 0xffff { - rcvWnd = 0xffff +func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO) *tcpip.Error { + optLen := len(tf.opts) + if tf.rcvWnd > 0xffff { + tf.rcvWnd = 0xffff } mss := int(gso.MSS) @@ -768,14 +815,15 @@ func sendTCPBatch(r *stack.Route, id stack.TransportEndpointID, data buffer.Vect pkts[i].DataOffset = off pkts[i].DataSize = packetSize pkts[i].Data = data - buildTCPHdr(r, id, &pkts[i], flags, seq, ack, rcvWnd, opts, gso) + pkts[i].Hash = tf.txHash + buildTCPHdr(r, tf, &pkts[i], gso) off += packetSize - seq = seq.Add(seqnum.Size(packetSize)) + tf.seq = tf.seq.Add(seqnum.Size(packetSize)) } - if ttl == 0 { - ttl = r.DefaultTTL() + if tf.ttl == 0 { + tf.ttl = r.DefaultTTL() } - sent, err := r.WritePackets(gso, pkts, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}) + sent, err := r.WritePackets(gso, pkts, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: tf.ttl, TOS: tf.tos}) if err != nil { r.Stats().TCP.SegmentSendErrors.IncrementBy(uint64(n - sent)) } @@ -785,14 +833,14 @@ func sendTCPBatch(r *stack.Route, id stack.TransportEndpointID, data buffer.Vect // sendTCP sends a TCP segment with the provided options via the provided // network endpoint and under the provided identity. -func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error { - optLen := len(opts) - if rcvWnd > 0xffff { - rcvWnd = 0xffff +func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO) *tcpip.Error { + optLen := len(tf.opts) + if tf.rcvWnd > 0xffff { + tf.rcvWnd = 0xffff } if r.Loop&stack.PacketLoop == 0 && gso != nil && gso.Type == stack.GSOSW && int(gso.MSS) < data.Size() { - return sendTCPBatch(r, id, data, ttl, tos, flags, seq, ack, rcvWnd, opts, gso) + return sendTCPBatch(r, tf, data, gso) } pkt := stack.PacketBuffer{ @@ -800,18 +848,19 @@ func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.Vectorise DataOffset: 0, DataSize: data.Size(), Data: data, + Hash: tf.txHash, } - buildTCPHdr(r, id, &pkt, flags, seq, ack, rcvWnd, opts, gso) + buildTCPHdr(r, tf, &pkt, gso) - if ttl == 0 { - ttl = r.DefaultTTL() + if tf.ttl == 0 { + tf.ttl = r.DefaultTTL() } - if err := r.WritePacket(gso, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}, pkt); err != nil { + if err := r.WritePacket(gso, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: tf.ttl, TOS: tf.tos}, pkt); err != nil { r.Stats().TCP.SegmentSendErrors.Increment() return err } r.Stats().TCP.SegmentsSent.Increment() - if (flags & header.TCPFlagRst) != 0 { + if (tf.flags & header.TCPFlagRst) != 0 { r.Stats().TCP.ResetsSent.Increment() } return nil @@ -863,7 +912,16 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqn sackBlocks = e.sack.Blocks[:e.sack.NumBlocks] } options := e.makeOptions(sackBlocks) - err := e.sendTCP(&e.route, e.ID, data, e.ttl, e.sendTOS, flags, seq, ack, rcvWnd, options, e.gso) + err := e.sendTCP(&e.route, tcpFields{ + id: e.ID, + ttl: e.ttl, + tos: e.sendTOS, + flags: flags, + seq: seq, + ack: ack, + rcvWnd: rcvWnd, + opts: options, + }, data, e.gso) putOptions(options) return err } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 594efaa11..b6e571361 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -581,6 +581,10 @@ type endpoint struct { // endpoint and at this point the endpoint is only around // to complete the TCP shutdown. closed bool + + // txHash is the transport layer hash to be set on outbound packets + // emitted by this endpoint. + txHash uint32 } // UniqueID implements stack.TransportEndpoint.UniqueID. @@ -771,6 +775,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue count: 9, }, uniqueID: s.UniqueID(), + txHash: s.Rand().Uint32(), } var ss SendBufferSizeOption diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 57985b85d..1377107ca 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -191,7 +191,15 @@ func replyWithReset(s *segment) { flags |= header.TCPFlagAck ack = s.sequenceNumber.Add(s.logicalLen()) } - sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, flags, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */) + sendTCP(&s.route, tcpFields{ + id: s.id, + ttl: s.route.DefaultTTL(), + tos: stack.DefaultTOS, + flags: flags, + seq: seq, + ack: ack, + rcvWnd: 0, + }, buffer.VectorisedView{}, nil /* gso */) } // SetOption implements stack.TransportProtocol.SetOption. -- cgit v1.2.3 From c64796748c735af8b304e62d7833648b691d5a72 Mon Sep 17 00:00:00 2001 From: Jay Zhuang Date: Thu, 26 Mar 2020 08:46:33 -0700 Subject: Clean up transport_demuxer.go and test - Change receiver of endpoint lookup functions - Remove unused struct fields and functions in test - s/%v/%s/ for errors - Capitalize NIC https://github.com/golang/go/wiki/CodeReviewComments#initialisms PiperOrigin-RevId: 303119580 --- pkg/tcpip/stack/transport_demuxer.go | 239 +++++++++++++++--------------- pkg/tcpip/stack/transport_demuxer_test.go | 115 +++++--------- 2 files changed, 160 insertions(+), 194 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index c55e3e8bc..9a33ed375 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -35,7 +35,7 @@ type protocolIDs struct { type transportEndpoints struct { // mu protects all fields of the transportEndpoints. mu sync.RWMutex - endpoints map[TransportEndpointID]*endpointsByNic + endpoints map[TransportEndpointID]*endpointsByNIC // rawEndpoints contains endpoints for raw sockets, which receive all // traffic of a given protocol regardless of port. rawEndpoints []RawTransportEndpoint @@ -46,11 +46,11 @@ type transportEndpoints struct { func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { eps.mu.Lock() defer eps.mu.Unlock() - epsByNic, ok := eps.endpoints[id] + epsByNIC, ok := eps.endpoints[id] if !ok { return } - if !epsByNic.unregisterEndpoint(bindToDevice, ep) { + if !epsByNIC.unregisterEndpoint(bindToDevice, ep) { return } delete(eps.endpoints, id) @@ -66,18 +66,85 @@ func (eps *transportEndpoints) transportEndpoints() []TransportEndpoint { return es } -type endpointsByNic struct { +// iterEndpointsLocked yields all endpointsByNIC in eps that match id, in +// descending order of match quality. If a call to yield returns false, +// iterEndpointsLocked stops iteration and returns immediately. +// +// Preconditions: eps.mu must be locked. +func (eps *transportEndpoints) iterEndpointsLocked(id TransportEndpointID, yield func(*endpointsByNIC) bool) { + // Try to find a match with the id as provided. + if ep, ok := eps.endpoints[id]; ok { + if !yield(ep) { + return + } + } + + // Try to find a match with the id minus the local address. + nid := id + + nid.LocalAddress = "" + if ep, ok := eps.endpoints[nid]; ok { + if !yield(ep) { + return + } + } + + // Try to find a match with the id minus the remote part. + nid.LocalAddress = id.LocalAddress + nid.RemoteAddress = "" + nid.RemotePort = 0 + if ep, ok := eps.endpoints[nid]; ok { + if !yield(ep) { + return + } + } + + // Try to find a match with only the local port. + nid.LocalAddress = "" + if ep, ok := eps.endpoints[nid]; ok { + if !yield(ep) { + return + } + } +} + +// findAllEndpointsLocked returns all endpointsByNIC in eps that match id, in +// descending order of match quality. +// +// Preconditions: eps.mu must be locked. +func (eps *transportEndpoints) findAllEndpointsLocked(id TransportEndpointID) []*endpointsByNIC { + var matchedEPs []*endpointsByNIC + eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool { + matchedEPs = append(matchedEPs, ep) + return true + }) + return matchedEPs +} + +// findEndpointLocked returns the endpoint that most closely matches the given id. +// +// Preconditions: eps.mu must be locked. +func (eps *transportEndpoints) findEndpointLocked(id TransportEndpointID) *endpointsByNIC { + var matchedEP *endpointsByNIC + eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool { + matchedEP = ep + return false + }) + return matchedEP +} + +type endpointsByNIC struct { mu sync.RWMutex endpoints map[tcpip.NICID]*multiPortEndpoint // seed is a random secret for a jenkins hash. seed uint32 } -func (epsByNic *endpointsByNic) transportEndpoints() []TransportEndpoint { - epsByNic.mu.RLock() - defer epsByNic.mu.RUnlock() +func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint { + epsByNIC.mu.RLock() + defer epsByNIC.mu.RUnlock() var eps []TransportEndpoint - for _, ep := range epsByNic.endpoints { + for _, ep := range epsByNIC.endpoints { eps = append(eps, ep.transportEndpoints()...) } return eps @@ -85,13 +152,13 @@ func (epsByNic *endpointsByNic) transportEndpoints() []TransportEndpoint { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, pkt PacketBuffer) { - epsByNic.mu.RLock() +func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt PacketBuffer) { + epsByNIC.mu.RLock() - mpep, ok := epsByNic.endpoints[r.ref.nic.ID()] + mpep, ok := epsByNIC.endpoints[r.ref.nic.ID()] if !ok { - if mpep, ok = epsByNic.endpoints[0]; !ok { - epsByNic.mu.RUnlock() // Don't use defer for performance reasons. + if mpep, ok = epsByNIC.endpoints[0]; !ok { + epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. return } } @@ -100,29 +167,29 @@ func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, p // endpoints bound to the right device. if isMulticastOrBroadcast(id.LocalAddress) { mpep.handlePacketAll(r, id, pkt) - epsByNic.mu.RUnlock() // Don't use defer for performance reasons. + epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. return } // multiPortEndpoints are guaranteed to have at least one element. - transEP := selectEndpoint(id, mpep, epsByNic.seed) + transEP := selectEndpoint(id, mpep, epsByNIC.seed) if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue { queuedProtocol.QueuePacket(r, transEP, id, pkt) - epsByNic.mu.RUnlock() + epsByNIC.mu.RUnlock() return } transEP.HandlePacket(r, id, pkt) - epsByNic.mu.RUnlock() // Don't use defer for performance reasons. + epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt PacketBuffer) { - epsByNic.mu.RLock() - defer epsByNic.mu.RUnlock() +func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt PacketBuffer) { + epsByNIC.mu.RLock() + defer epsByNIC.mu.RUnlock() - mpep, ok := epsByNic.endpoints[n.ID()] + mpep, ok := epsByNIC.endpoints[n.ID()] if !ok { - mpep, ok = epsByNic.endpoints[0] + mpep, ok = epsByNIC.endpoints[0] } if !ok { return @@ -132,16 +199,16 @@ func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpoint // broadcast like we are doing with handlePacket above? // multiPortEndpoints are guaranteed to have at least one element. - selectEndpoint(id, mpep, epsByNic.seed).HandleControlPacket(id, typ, extra, pkt) + selectEndpoint(id, mpep, epsByNIC.seed).HandleControlPacket(id, typ, extra, pkt) } // registerEndpoint returns true if it succeeds. It fails and returns // false if ep already has an element with the same key. -func (epsByNic *endpointsByNic) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { - epsByNic.mu.Lock() - defer epsByNic.mu.Unlock() +func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { + epsByNIC.mu.Lock() + defer epsByNIC.mu.Unlock() - multiPortEp, ok := epsByNic.endpoints[bindToDevice] + multiPortEp, ok := epsByNIC.endpoints[bindToDevice] if !ok { multiPortEp = &multiPortEndpoint{ demux: d, @@ -149,24 +216,24 @@ func (epsByNic *endpointsByNic) registerEndpoint(d *transportDemuxer, netProto t transProto: transProto, reuse: reusePort, } - epsByNic.endpoints[bindToDevice] = multiPortEp + epsByNIC.endpoints[bindToDevice] = multiPortEp } return multiPortEp.singleRegisterEndpoint(t, reusePort) } -// unregisterEndpoint returns true if endpointsByNic has to be unregistered. -func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint) bool { - epsByNic.mu.Lock() - defer epsByNic.mu.Unlock() - multiPortEp, ok := epsByNic.endpoints[bindToDevice] +// unregisterEndpoint returns true if endpointsByNIC has to be unregistered. +func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint) bool { + epsByNIC.mu.Lock() + defer epsByNIC.mu.Unlock() + multiPortEp, ok := epsByNIC.endpoints[bindToDevice] if !ok { return false } if multiPortEp.unregisterEndpoint(t) { - delete(epsByNic.endpoints, bindToDevice) + delete(epsByNIC.endpoints, bindToDevice) } - return len(epsByNic.endpoints) == 0 + return len(epsByNIC.endpoints) == 0 } // transportDemuxer demultiplexes packets targeted at a transport endpoint @@ -198,7 +265,7 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer { for proto := range stack.transportProtocols { protoIDs := protocolIDs{netProto, proto} d.protocol[protoIDs] = &transportEndpoints{ - endpoints: make(map[TransportEndpointID]*endpointsByNic), + endpoints: make(map[TransportEndpointID]*endpointsByNIC), } qTransProto, isQueued := (stack.transportProtocols[proto].proto).(queuedTransportProtocol) if isQueued { @@ -378,16 +445,16 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol eps.mu.Lock() defer eps.mu.Unlock() - epsByNic, ok := eps.endpoints[id] + epsByNIC, ok := eps.endpoints[id] if !ok { - epsByNic = &endpointsByNic{ + epsByNIC = &endpointsByNIC{ endpoints: make(map[tcpip.NICID]*multiPortEndpoint), seed: rand.Uint32(), } - eps.endpoints[id] = epsByNic + eps.endpoints[id] = epsByNIC } - return epsByNic.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice) + return epsByNIC.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice) } // unregisterEndpoint unregisters the endpoint with the given id such that it @@ -413,7 +480,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // transport endpoints. if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) { eps.mu.RLock() - destEPs := d.findAllEndpointsLocked(eps, id) + destEPs := eps.findAllEndpointsLocked(id) eps.mu.RUnlock() // Fail if we didn't find at least one matching transport endpoint. if len(destEPs) == 0 { @@ -439,7 +506,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto } eps.mu.RLock() - ep := d.findEndpointLocked(eps, id) + ep := eps.findEndpointLocked(id) eps.mu.RUnlock() if ep == nil { if protocol == header.UDPProtocolNumber { @@ -483,115 +550,47 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco return false } - // Try to find the endpoint. eps.mu.RLock() - ep := d.findEndpointLocked(eps, id) + ep := eps.findEndpointLocked(id) eps.mu.RUnlock() - - // Fail if we didn't find one. if ep == nil { return false } - // Deliver the packet. ep.handleControlPacket(n, id, typ, extra, pkt) - return true } -// iterEndpointsLocked yields all endpointsByNic in eps that match id, in -// descending order of match quality. If a call to yield returns false, -// iterEndpointsLocked stops iteration and returns immediately. -// -// Preconditions: eps.mu must be locked. -func (d *transportDemuxer) iterEndpointsLocked(eps *transportEndpoints, id TransportEndpointID, yield func(*endpointsByNic) bool) { - // Try to find a match with the id as provided. - if ep, ok := eps.endpoints[id]; ok { - if !yield(ep) { - return - } - } - - // Try to find a match with the id minus the local address. - nid := id - - nid.LocalAddress = "" - if ep, ok := eps.endpoints[nid]; ok { - if !yield(ep) { - return - } - } - - // Try to find a match with the id minus the remote part. - nid.LocalAddress = id.LocalAddress - nid.RemoteAddress = "" - nid.RemotePort = 0 - if ep, ok := eps.endpoints[nid]; ok { - if !yield(ep) { - return - } - } - - // Try to find a match with only the local port. - nid.LocalAddress = "" - if ep, ok := eps.endpoints[nid]; ok { - if !yield(ep) { - return - } - } -} - -func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id TransportEndpointID) []*endpointsByNic { - var matchedEPs []*endpointsByNic - d.iterEndpointsLocked(eps, id, func(ep *endpointsByNic) bool { - matchedEPs = append(matchedEPs, ep) - return true - }) - return matchedEPs -} - // findTransportEndpoint find a single endpoint that most closely matches the provided id. func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint { eps, ok := d.protocol[protocolIDs{netProto, transProto}] if !ok { return nil } - // Try to find the endpoint. + eps.mu.RLock() - epsByNic := d.findEndpointLocked(eps, id) - // Fail if we didn't find one. - if epsByNic == nil { + epsByNIC := eps.findEndpointLocked(id) + if epsByNIC == nil { eps.mu.RUnlock() return nil } - epsByNic.mu.RLock() + epsByNIC.mu.RLock() eps.mu.RUnlock() - mpep, ok := epsByNic.endpoints[r.ref.nic.ID()] + mpep, ok := epsByNIC.endpoints[r.ref.nic.ID()] if !ok { - if mpep, ok = epsByNic.endpoints[0]; !ok { - epsByNic.mu.RUnlock() // Don't use defer for performance reasons. + if mpep, ok = epsByNIC.endpoints[0]; !ok { + epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. return nil } } - ep := selectEndpoint(id, mpep, epsByNic.seed) - epsByNic.mu.RUnlock() + ep := selectEndpoint(id, mpep, epsByNIC.seed) + epsByNIC.mu.RUnlock() return ep } -// findEndpointLocked returns the endpoint that most closely matches the given -// id. -func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, id TransportEndpointID) *endpointsByNic { - var matchedEP *endpointsByNic - d.iterEndpointsLocked(eps, id, func(ep *endpointsByNic) bool { - matchedEP = ep - return false - }) - return matchedEP -} - // registerRawEndpoint registers the given endpoint with the dispatcher such // that packets of the appropriate protocol are delivered to it. A single // packet can be sent to one or more raw endpoints along with a non-raw diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 84311bcc8..75c119c99 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -40,75 +40,47 @@ const ( ) type testContext struct { - t *testing.T linkEps map[tcpip.NICID]*channel.Endpoint s *stack.Stack - - ep tcpip.Endpoint - wq waiter.Queue -} - -func (c *testContext) cleanup() { - if c.ep != nil { - c.ep.Close() - } -} - -func (c *testContext) createV6Endpoint(v6only bool) { - var err *tcpip.Error - c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, v6only); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) - } + wq waiter.Queue } // newDualTestContextMultiNIC creates the testing context and also linkEpIDs NICs. func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICID) *testContext { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + }) linkEps := make(map[tcpip.NICID]*channel.Endpoint) for _, linkEpID := range linkEpIDs { channelEp := channel.New(256, mtu, "") if err := s.CreateNIC(linkEpID, channelEp); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + t.Fatalf("CreateNIC failed: %s", err) } linkEps[linkEpID] = channelEp if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, stackAddr); err != nil { - t.Fatalf("AddAddress IPv4 failed: %v", err) + t.Fatalf("AddAddress IPv4 failed: %s", err) } if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, stackV6Addr); err != nil { - t.Fatalf("AddAddress IPv6 failed: %v", err) + t.Fatalf("AddAddress IPv6 failed: %s", err) } } s.SetRouteTable([]tcpip.Route{ - { - Destination: header.IPv4EmptySubnet, - NIC: 1, - }, - { - Destination: header.IPv6EmptySubnet, - NIC: 1, - }, + {Destination: header.IPv4EmptySubnet, NIC: 1}, + {Destination: header.IPv6EmptySubnet, NIC: 1}, }) return &testContext{ - t: t, s: s, linkEps: linkEps, } } type headers struct { - srcPort uint16 - dstPort uint16 + srcPort, dstPort uint16 } func newPayload() []byte { @@ -179,15 +151,15 @@ func TestTransportDemuxerRegister(t *testing.T) { t.Fatalf("%T does not implement stack.TransportEndpoint", ep) } if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, false, 0), test.want; got != want { - t.Fatalf("s.RegisterTransportEndpoint(...) = %v, want %v", got, want) + t.Fatalf("s.RegisterTransportEndpoint(...) = %s, want %s", got, want) } }) } } -// TestReuseBindToDevice injects varied packets on input devices and checks that +// TestBindToDeviceDistribution injects varied packets on input devices and checks that // the distribution of packets received matches expectations. -func TestDistribution(t *testing.T) { +func TestBindToDeviceDistribution(t *testing.T) { type endpointSockopts struct { reuse int bindToDevice tcpip.NICID @@ -196,19 +168,19 @@ func TestDistribution(t *testing.T) { name string // endpoints will received the inject packets. endpoints []endpointSockopts - // wantedDistribution is the wanted ratio of packets received on each + // wantDistributions is the want ratio of packets received on each // endpoint for each NIC on which packets are injected. - wantedDistributions map[tcpip.NICID][]float64 + wantDistributions map[tcpip.NICID][]float64 }{ { "BindPortReuse", // 5 endpoints that all have reuse set. []endpointSockopts{ - {1, 0}, - {1, 0}, - {1, 0}, - {1, 0}, - {1, 0}, + {reuse: 1, bindToDevice: 0}, + {reuse: 1, bindToDevice: 0}, + {reuse: 1, bindToDevice: 0}, + {reuse: 1, bindToDevice: 0}, + {reuse: 1, bindToDevice: 0}, }, map[tcpip.NICID][]float64{ // Injected packets on dev0 get distributed evenly. @@ -219,9 +191,9 @@ func TestDistribution(t *testing.T) { "BindToDevice", // 3 endpoints with various bindings. []endpointSockopts{ - {0, 1}, - {0, 2}, - {0, 3}, + {reuse: 0, bindToDevice: 1}, + {reuse: 0, bindToDevice: 2}, + {reuse: 0, bindToDevice: 3}, }, map[tcpip.NICID][]float64{ // Injected packets on dev0 go only to the endpoint bound to dev0. @@ -236,12 +208,12 @@ func TestDistribution(t *testing.T) { "ReuseAndBindToDevice", // 6 endpoints with various bindings. []endpointSockopts{ - {1, 1}, - {1, 1}, - {1, 2}, - {1, 2}, - {1, 2}, - {1, 0}, + {reuse: 1, bindToDevice: 1}, + {reuse: 1, bindToDevice: 1}, + {reuse: 1, bindToDevice: 2}, + {reuse: 1, bindToDevice: 2}, + {reuse: 1, bindToDevice: 2}, + {reuse: 1, bindToDevice: 0}, }, map[tcpip.NICID][]float64{ // Injected packets on dev0 get distributed among endpoints bound to @@ -256,16 +228,13 @@ func TestDistribution(t *testing.T) { }, } { t.Run(test.name, func(t *testing.T) { - for device, wantedDistribution := range test.wantedDistributions { + for device, wantDistribution := range test.wantDistributions { t.Run(string(device), func(t *testing.T) { var devices []tcpip.NICID - for d := range test.wantedDistributions { + for d := range test.wantDistributions { devices = append(devices, d) } c := newDualTestContextMultiNIC(t, defaultMTU, devices) - defer c.cleanup() - - c.createV6Endpoint(false) eps := make(map[tcpip.Endpoint]int) @@ -281,7 +250,7 @@ func TestDistribution(t *testing.T) { var err *tcpip.Error ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq) if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } eps[ep] = i @@ -294,20 +263,20 @@ func TestDistribution(t *testing.T) { defer ep.Close() reusePortOption := tcpip.ReusePortOption(endpoint.reuse) if err := ep.SetSockOpt(reusePortOption); err != nil { - c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", reusePortOption, i, err) + t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %s", reusePortOption, i, err) } bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice) if err := ep.SetSockOpt(bindToDeviceOption); err != nil { - c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", bindToDeviceOption, i, err) + t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %s", bindToDeviceOption, i, err) } if err := ep.Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil { - t.Fatalf("ep.Bind(...) on endpoint %d failed: %v", i, err) + t.Fatalf("ep.Bind(...) on endpoint %d failed: %s", i, err) } } npackets := 100000 nports := 10000 - if got, want := len(test.endpoints), len(wantedDistribution); got != want { + if got, want := len(test.endpoints), len(wantDistribution); got != want { t.Fatalf("got len(test.endpoints) = %d, want %d", got, want) } ports := make(map[uint16]tcpip.Endpoint) @@ -322,11 +291,9 @@ func TestDistribution(t *testing.T) { dstPort: stackPort}, device) - var addr tcpip.FullAddress ep := <-pollChannel - _, _, err := ep.Read(&addr) - if err != nil { - c.t.Fatalf("Read on endpoint %d failed: %v", eps[ep], err) + if _, _, err := ep.Read(nil); err != nil { + t.Fatalf("Read on endpoint %d failed: %s", eps[ep], err) } stats[ep]++ if i < nports { @@ -342,13 +309,13 @@ func TestDistribution(t *testing.T) { // Check that a packet distribution is as expected. for ep, i := range eps { - wantedRatio := wantedDistribution[i] - wantedRecv := wantedRatio * float64(npackets) + wantRatio := wantDistribution[i] + wantRecv := wantRatio * float64(npackets) actualRecv := stats[ep] actualRatio := float64(stats[ep]) / float64(npackets) // The deviation is less than 10%. - if math.Abs(actualRatio-wantedRatio) > 0.05 { - t.Errorf("wanted about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantedRatio*100, wantedRecv, npackets, i, actualRatio*100, actualRecv, npackets) + if math.Abs(actualRatio-wantRatio) > 0.05 { + t.Errorf("want about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantRatio*100, wantRecv, npackets, i, actualRatio*100, actualRecv, npackets) } } }) -- cgit v1.2.3 From d5ef8091b4aab16639116e64469db16fc36386c7 Mon Sep 17 00:00:00 2001 From: Jay Zhuang Date: Thu, 26 Mar 2020 11:28:05 -0700 Subject: Add IPv4 to bind_to_device distribution test PiperOrigin-RevId: 303156734 --- pkg/tcpip/stack/transport_demuxer_test.go | 107 ++++++++++++++++++++++++------ 1 file changed, 86 insertions(+), 21 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 75c119c99..c65b0c632 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -31,12 +31,14 @@ import ( ) const ( - stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + testSrcAddrV6 = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + testDstAddrV6 = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - stackAddr = "\x0a\x00\x00\x01" - stackPort = 1234 - testPort = 4096 + testSrcAddrV4 = "\x0a\x00\x00\x01" + testDstAddrV4 = "\x0a\x00\x00\x02" + + testDstPort = 1234 + testSrcPort = 4096 ) type testContext struct { @@ -59,11 +61,11 @@ func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICI } linkEps[linkEpID] = channelEp - if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, stackAddr); err != nil { + if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, testDstAddrV4); err != nil { t.Fatalf("AddAddress IPv4 failed: %s", err) } - if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, stackV6Addr); err != nil { + if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, testDstAddrV6); err != nil { t.Fatalf("AddAddress IPv6 failed: %s", err) } } @@ -91,6 +93,47 @@ func newPayload() []byte { return b } +func (c *testContext) sendV4Packet(payload []byte, h *headers, linkEpID tcpip.NICID) { + buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload)) + payloadStart := len(buf) - len(payload) + copy(buf[payloadStart:], payload) + + // Initialize the IP header. + ip := header.IPv4(buf) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: 0x80, + TotalLength: uint16(len(buf)), + TTL: 65, + Protocol: uint8(udp.ProtocolNumber), + SrcAddr: testSrcAddrV4, + DstAddr: testDstAddrV4, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + // Initialize the UDP header. + u := header.UDP(buf[header.IPv4MinimumSize:]) + u.Encode(&header.UDPFields{ + SrcPort: h.srcPort, + DstPort: h.dstPort, + Length: uint16(header.UDPMinimumSize + len(payload)), + }) + + // Calculate the UDP pseudo-header checksum. + xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testSrcAddrV4, testDstAddrV4, uint16(len(u))) + + // Calculate the UDP checksum and set it. + xsum = header.Checksum(payload, xsum) + u.SetChecksum(^u.CalculateChecksum(xsum)) + + // Inject packet. + c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + NetworkHeader: buffer.View(ip), + TransportHeader: buffer.View(u), + }) +} + func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NICID) { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) @@ -102,8 +145,8 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI PayloadLength: uint16(header.UDPMinimumSize + len(payload)), NextHeader: uint8(udp.ProtocolNumber), HopLimit: 65, - SrcAddr: testV6Addr, - DstAddr: stackV6Addr, + SrcAddr: testSrcAddrV6, + DstAddr: testDstAddrV6, }) // Initialize the UDP header. @@ -115,7 +158,7 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI }) // Calculate the UDP pseudo-header checksum. - xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testV6Addr, stackV6Addr, uint16(len(u))) + xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testSrcAddrV6, testDstAddrV6, uint16(len(u))) // Calculate the UDP checksum and set it. xsum = header.Checksum(payload, xsum) @@ -123,7 +166,9 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI // Inject packet. c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{ - Data: buf.ToVectorisedView(), + Data: buf.ToVectorisedView(), + NetworkHeader: buffer.View(ip), + TransportHeader: buffer.View(u), }) } @@ -227,9 +272,12 @@ func TestBindToDeviceDistribution(t *testing.T) { }, }, } { - t.Run(test.name, func(t *testing.T) { + for protoName, netProtoNum := range map[string]tcpip.NetworkProtocolNumber{ + "IPv4": ipv4.ProtocolNumber, + "IPv6": ipv6.ProtocolNumber, + } { for device, wantDistribution := range test.wantDistributions { - t.Run(string(device), func(t *testing.T) { + t.Run(test.name+protoName+string(device), func(t *testing.T) { var devices []tcpip.NICID for d := range test.wantDistributions { devices = append(devices, d) @@ -248,7 +296,7 @@ func TestBindToDeviceDistribution(t *testing.T) { defer close(ch) var err *tcpip.Error - ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq) + ep, err := c.s.NewEndpoint(udp.ProtocolNumber, netProtoNum, &wq) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) } @@ -269,7 +317,17 @@ func TestBindToDeviceDistribution(t *testing.T) { if err := ep.SetSockOpt(bindToDeviceOption); err != nil { t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %s", bindToDeviceOption, i, err) } - if err := ep.Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil { + + var dstAddr tcpip.Address + switch netProtoNum { + case ipv4.ProtocolNumber: + dstAddr = testDstAddrV4 + case ipv6.ProtocolNumber: + dstAddr = testDstAddrV6 + default: + t.Fatalf("unexpected protocol number: %d", netProtoNum) + } + if err := ep.Bind(tcpip.FullAddress{Addr: dstAddr, Port: testDstPort}); err != nil { t.Fatalf("ep.Bind(...) on endpoint %d failed: %s", i, err) } } @@ -285,11 +343,18 @@ func TestBindToDeviceDistribution(t *testing.T) { // Send a packet. port := uint16(i % nports) payload := newPayload() - c.sendV6Packet(payload, - &headers{ - srcPort: testPort + port, - dstPort: stackPort}, - device) + hdrs := &headers{ + srcPort: testSrcPort + port, + dstPort: testDstPort, + } + switch netProtoNum { + case ipv4.ProtocolNumber: + c.sendV4Packet(payload, hdrs, device) + case ipv6.ProtocolNumber: + c.sendV6Packet(payload, hdrs, device) + default: + t.Fatalf("unexpected protocol number: %d", netProtoNum) + } ep := <-pollChannel if _, _, err := ep.Read(nil); err != nil { @@ -320,6 +385,6 @@ func TestBindToDeviceDistribution(t *testing.T) { } }) } - }) + } } } -- cgit v1.2.3 From 92b9069b67b927cef25a1490ebd142ad6d65690d Mon Sep 17 00:00:00 2001 From: Nayana Bidari Date: Fri, 20 Mar 2020 12:00:21 -0700 Subject: Support owner matching for iptables. This feature will match UID and GID of the packet creator, for locally generated packets. This match is only valid in the OUTPUT and POSTROUTING chains. Forwarded packets do not have any socket associated with them. Packets from kernel threads do have a socket, but usually no owner. --- pkg/abi/linux/netfilter.go | 41 ++++++++ pkg/abi/linux/netfilter_test.go | 1 + pkg/sentry/kernel/task.go | 12 +++ pkg/sentry/socket/netfilter/BUILD | 1 + pkg/sentry/socket/netfilter/netfilter.go | 7 +- pkg/sentry/socket/netfilter/owner_matcher.go | 128 ++++++++++++++++++++++++ pkg/sentry/socket/netstack/provider.go | 6 ++ pkg/tcpip/network/ipv4/ipv4.go | 15 +++ pkg/tcpip/stack/packet_buffer.go | 9 +- pkg/tcpip/stack/transport_test.go | 2 + pkg/tcpip/tcpip.go | 12 +++ pkg/tcpip/transport/icmp/endpoint.go | 12 ++- pkg/tcpip/transport/packet/endpoint.go | 2 + pkg/tcpip/transport/raw/endpoint.go | 9 ++ pkg/tcpip/transport/tcp/accept.go | 5 +- pkg/tcpip/transport/tcp/connect.go | 10 +- pkg/tcpip/transport/tcp/endpoint.go | 7 ++ pkg/tcpip/transport/tcp/forwarder.go | 2 +- pkg/tcpip/transport/tcp/protocol.go | 2 +- pkg/tcpip/transport/udp/endpoint.go | 12 ++- test/iptables/filter_output.go | 143 +++++++++++++++++++++++++++ test/iptables/iptables_test.go | 30 ++++++ 22 files changed, 451 insertions(+), 17 deletions(-) create mode 100644 pkg/sentry/socket/netfilter/owner_matcher.go (limited to 'pkg/tcpip/stack') diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go index 80dc09aa9..a8d4f9d69 100644 --- a/pkg/abi/linux/netfilter.go +++ b/pkg/abi/linux/netfilter.go @@ -509,3 +509,44 @@ const ( // Enable all flags. XT_UDP_INV_MASK = 0x03 ) + +// IPTOwnerInfo holds data for matching packets with owner. It corresponds +// to struct ipt_owner_info in libxt_owner.c of iptables binary. +type IPTOwnerInfo struct { + // UID is user id which created the packet. + UID uint32 + + // GID is group id which created the packet. + GID uint32 + + // PID is process id of the process which created the packet. + PID uint32 + + // SID is session id which created the packet. + SID uint32 + + // Comm is the command name which created the packet. + Comm [16]byte + + // Match is used to match UID/GID of the socket. See the + // XT_OWNER_* flags below. + Match uint8 + + // Invert flips the meaning of Match field. + Invert uint8 +} + +// SizeOfIPTOwnerInfo is the size of an XTOwnerMatchInfo. +const SizeOfIPTOwnerInfo = 34 + +// Flags in IPTOwnerInfo.Match. Corresponding constants are in +// include/uapi/linux/netfilter/xt_owner.h. +const ( + // Match the UID of the packet. + XT_OWNER_UID = 1 << 0 + // Match the GID of the packet. + XT_OWNER_GID = 1 << 1 + // Match if the socket exists for the packet. Forwarded + // packets do not have an associated socket. + XT_OWNER_SOCKET = 1 << 2 +) diff --git a/pkg/abi/linux/netfilter_test.go b/pkg/abi/linux/netfilter_test.go index 21e237f92..565dd550e 100644 --- a/pkg/abi/linux/netfilter_test.go +++ b/pkg/abi/linux/netfilter_test.go @@ -29,6 +29,7 @@ func TestSizes(t *testing.T) { {IPTGetEntries{}, SizeOfIPTGetEntries}, {IPTGetinfo{}, SizeOfIPTGetinfo}, {IPTIP{}, SizeOfIPTIP}, + {IPTOwnerInfo{}, SizeOfIPTOwnerInfo}, {IPTReplace{}, SizeOfIPTReplace}, {XTCounters{}, SizeOfXTCounters}, {XTEntryMatch{}, SizeOfXTEntryMatch}, diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index 8452ddf5b..d6546735e 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -863,3 +863,15 @@ func (t *Task) SetOOMScoreAdj(adj int32) error { atomic.StoreInt32(&t.tg.oomScoreAdj, adj) return nil } + +// UID returns t's uid. +// TODO(gvisor.dev/issue/170): This method is not namespaced yet. +func (t *Task) UID() uint32 { + return uint32(t.Credentials().EffectiveKUID) +} + +// GID returns t's gid. +// TODO(gvisor.dev/issue/170): This method is not namespaced yet. +func (t *Task) GID() uint32 { + return uint32(t.Credentials().EffectiveKGID) +} diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD index e801abeb8..721094bbf 100644 --- a/pkg/sentry/socket/netfilter/BUILD +++ b/pkg/sentry/socket/netfilter/BUILD @@ -7,6 +7,7 @@ go_library( srcs = [ "extensions.go", "netfilter.go", + "owner_matcher.go", "targets.go", "tcp_matcher.go", "udp_matcher.go", diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 55bcc3ace..878f81fd5 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -517,11 +517,10 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { } // TODO(gvisor.dev/issue/170): Support other chains. - // Since we only support modifying the INPUT chain and redirect for - // PREROUTING chain right now, make sure all other chains point to - // ACCEPT rules. + // Since we only support modifying the INPUT, PREROUTING and OUTPUT chain right now, + // make sure all other chains point to ACCEPT rules. for hook, ruleIdx := range table.BuiltinChains { - if hook != stack.Input && hook != stack.Prerouting { + if hook == stack.Forward || hook == stack.Postrouting { if _, ok := table.Rules[ruleIdx].Target.(stack.AcceptTarget); !ok { nflog("hook %d is unsupported.", hook) return syserr.ErrInvalidArgument diff --git a/pkg/sentry/socket/netfilter/owner_matcher.go b/pkg/sentry/socket/netfilter/owner_matcher.go new file mode 100644 index 000000000..5949a7c29 --- /dev/null +++ b/pkg/sentry/socket/netfilter/owner_matcher.go @@ -0,0 +1,128 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package netfilter + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/usermem" +) + +const matcherNameOwner = "owner" + +func init() { + registerMatchMaker(ownerMarshaler{}) +} + +// ownerMarshaler implements matchMaker for owner matching. +type ownerMarshaler struct{} + +// name implements matchMaker.name. +func (ownerMarshaler) name() string { + return matcherNameOwner +} + +// marshal implements matchMaker.marshal. +func (ownerMarshaler) marshal(mr stack.Matcher) []byte { + matcher := mr.(*OwnerMatcher) + iptOwnerInfo := linux.IPTOwnerInfo{ + UID: matcher.uid, + GID: matcher.gid, + } + + // Support for UID match. + // TODO(gvisor.dev/issue/170): Need to support gid match. + if matcher.matchUID { + iptOwnerInfo.Match = linux.XT_OWNER_UID + } else if matcher.matchGID { + panic("GID match is not supported.") + } else { + panic("UID match is not set.") + } + + buf := make([]byte, 0, linux.SizeOfIPTOwnerInfo) + return marshalEntryMatch(matcherNameOwner, binary.Marshal(buf, usermem.ByteOrder, iptOwnerInfo)) +} + +// unmarshal implements matchMaker.unmarshal. +func (ownerMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) { + if len(buf) < linux.SizeOfIPTOwnerInfo { + return nil, fmt.Errorf("buf has insufficient size for owner match: %d", len(buf)) + } + + // For alignment reasons, the match's total size may + // exceed what's strictly necessary to hold matchData. + var matchData linux.IPTOwnerInfo + binary.Unmarshal(buf[:linux.SizeOfIPTOwnerInfo], usermem.ByteOrder, &matchData) + nflog("parseMatchers: parsed IPTOwnerInfo: %+v", matchData) + + if matchData.Invert != 0 { + return nil, fmt.Errorf("invert flag is not supported for owner match") + } + + // Support for UID match. + // TODO(gvisor.dev/issue/170): Need to support gid match. + if matchData.Match&linux.XT_OWNER_UID != linux.XT_OWNER_UID { + return nil, fmt.Errorf("owner match is only supported for uid") + } + + // Check Flags. + var owner OwnerMatcher + owner.uid = matchData.UID + owner.gid = matchData.GID + owner.matchUID = true + + return &owner, nil +} + +type OwnerMatcher struct { + uid uint32 + gid uint32 + matchUID bool + matchGID bool + invert uint8 +} + +// Name implements Matcher.Name. +func (*OwnerMatcher) Name() string { + return matcherNameOwner +} + +// Match implements Matcher.Match. +func (om *OwnerMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceName string) (bool, bool) { + // Support only for OUTPUT chain. + // TODO(gvisor.dev/issue/170): Need to support for POSTROUTING chain also. + if hook != stack.Output { + return false, true + } + + // If the packet owner is not set, drop the packet. + // Support for uid match. + // TODO(gvisor.dev/issue/170): Need to support gid match. + if pkt.Owner == nil || !om.matchUID { + return false, true + } + + // TODO(gvisor.dev/issue/170): Need to add tests to verify + // drop rule when packet UID does not match owner matcher UID. + if pkt.Owner.UID() != om.uid { + return false, false + } + + return true, false +} diff --git a/pkg/sentry/socket/netstack/provider.go b/pkg/sentry/socket/netstack/provider.go index 5f181f017..eb090e79b 100644 --- a/pkg/sentry/socket/netstack/provider.go +++ b/pkg/sentry/socket/netstack/provider.go @@ -126,6 +126,12 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (* ep, e = eps.Stack.NewRawEndpoint(transProto, p.netProto, wq, associated) } else { ep, e = eps.Stack.NewEndpoint(transProto, p.netProto, wq) + + // Assign task to PacketOwner interface to get the UID and GID for + // iptables owner matching. + if e == nil { + ep.SetOwner(t) + } } if e != nil { return nil, syserr.TranslateNetstackError(e) diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index b3ee6000e..a7d9a8b25 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -244,6 +244,14 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) pkt.NetworkHeader = buffer.View(ip) + // iptables filtering. All packets that reach here are locally + // generated. + ipt := e.stack.IPTables() + if ok := ipt.Check(stack.Output, pkt); !ok { + // iptables is telling us to drop the packet. + return nil + } + if r.Loop&stack.PacketLoop != 0 { // The inbound path expects the network header to still be in // the PacketBuffer's Data field. @@ -280,7 +288,14 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.Pac return len(pkts), nil } + // iptables filtering. All packets that reach here are locally + // generated. + ipt := e.stack.IPTables() for i := range pkts { + if ok := ipt.Check(stack.Output, pkts[i]); !ok { + // iptables is telling us to drop the packet. + continue + } ip := e.addIPHeader(r, &pkts[i].Header, pkts[i].DataSize, params) pkts[i].NetworkHeader = buffer.View(ip) } diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 9505a4e92..9367de180 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -13,7 +13,10 @@ package stack -import "gvisor.dev/gvisor/pkg/tcpip/buffer" +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) // A PacketBuffer contains all the data of a network packet. // @@ -59,6 +62,10 @@ type PacketBuffer struct { // Hash is the transport layer hash of this packet. A value of zero // indicates no valid hash has been set. Hash uint32 + + // Owner is implemented by task to get the uid and gid. + // Only set for locally generated packets. + Owner tcpip.PacketOwner } // Clone makes a copy of pk. It clones the Data field, which creates a new diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 8ca9ac3cf..3084e6593 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -56,6 +56,8 @@ func (f *fakeTransportEndpoint) Stats() tcpip.EndpointStats { return nil } +func (f *fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {} + func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint { return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 3dc5d87d6..2ef3271f1 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -336,6 +336,15 @@ type ControlMessages struct { PacketInfo IPPacketInfo } +// PacketOwner is used to get UID and GID of the packet. +type PacketOwner interface { + // UID returns UID of the packet. + UID() uint32 + + // GID returns GID of the packet. + GID() uint32 +} + // Endpoint is the interface implemented by transport protocols (e.g., tcp, udp) // that exposes functionality like read, write, connect, etc. to users of the // networking stack. @@ -470,6 +479,9 @@ type Endpoint interface { // Stats returns a reference to the endpoint stats. Stats() EndpointStats + + // SetOwner sets the task owner to the endpoint owner. + SetOwner(owner PacketOwner) } // EndpointInfo is the interface implemented by each endpoint info struct. diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 613b12ead..b007302fb 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -73,6 +73,9 @@ type endpoint struct { route stack.Route `state:"manual"` ttl uint8 stats tcpip.TransportEndpointStats `state:"nosave"` + + // owner is used to get uid and gid of the packet. + owner tcpip.PacketOwner } func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { @@ -133,6 +136,10 @@ func (e *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (e *endpoint) ModerateRecvBuf(copied int) {} +func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { + e.owner = owner +} + // IPTables implements tcpip.Endpoint.IPTables. func (e *endpoint) IPTables() (stack.IPTables, error) { return e.stack.IPTables(), nil @@ -321,7 +328,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c switch e.NetProto { case header.IPv4ProtocolNumber: - err = send4(route, e.ID.LocalPort, v, e.ttl) + err = send4(route, e.ID.LocalPort, v, e.ttl, e.owner) case header.IPv6ProtocolNumber: err = send6(route, e.ID.LocalPort, v, e.ttl) @@ -415,7 +422,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } } -func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error { +func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) *tcpip.Error { if len(data) < header.ICMPv4MinimumSize { return tcpip.ErrInvalidEndpointState } @@ -444,6 +451,7 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err Header: hdr, Data: data.ToVectorisedView(), TransportHeader: buffer.View(icmpv4), + Owner: owner, }) } diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index df49d0995..23158173d 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -392,3 +392,5 @@ func (ep *endpoint) Info() tcpip.EndpointInfo { func (ep *endpoint) Stats() tcpip.EndpointStats { return &ep.stats } + +func (ep *endpoint) SetOwner(owner tcpip.PacketOwner) {} diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 536dafd1e..337bc1c71 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -80,6 +80,9 @@ type endpoint struct { // Connect(), and is valid only when conneted is true. route stack.Route `state:"manual"` stats tcpip.TransportEndpointStats `state:"nosave"` + + // owner is used to get uid and gid of the packet. + owner tcpip.PacketOwner } // NewEndpoint returns a raw endpoint for the given protocols. @@ -159,6 +162,10 @@ func (e *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (e *endpoint) ModerateRecvBuf(copied int) {} +func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { + e.owner = owner +} + // IPTables implements tcpip.Endpoint.IPTables. func (e *endpoint) IPTables() (stack.IPTables, error) { return e.stack.IPTables(), nil @@ -348,10 +355,12 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, } break } + hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength())) if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ Header: hdr, Data: buffer.View(payloadBytes).ToVectorisedView(), + Owner: e.owner, }); err != nil { return 0, nil, err } diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 375ca21f6..7a9dea4ac 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -276,7 +276,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i // and then performs the TCP 3-way handshake. // // The new endpoint is returned with e.mu held. -func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) { +func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, *tcpip.Error) { // Create new endpoint. irs := s.sequenceNumber isn := generateSecureISN(s.id, l.stack.Seed()) @@ -284,6 +284,7 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head if err != nil { return nil, err } + ep.owner = owner // listenEP is nil when listenContext is used by tcp.Forwarder. deferAccept := time.Duration(0) @@ -414,7 +415,7 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header }() defer s.decRef() - n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{}) + n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{}, e.owner) if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 1d245c2c6..3239a5911 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -745,7 +745,7 @@ func (e *endpoint) sendSynTCP(r *stack.Route, tf tcpFields, opts header.TCPSynOp func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO) *tcpip.Error { tf.txHash = e.txHash - if err := sendTCP(r, tf, data, gso); err != nil { + if err := sendTCP(r, tf, data, gso, e.owner); err != nil { e.stats.SendErrors.SegmentSendToNetworkFailed.Increment() return err } @@ -787,7 +787,7 @@ func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *sta } } -func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO) *tcpip.Error { +func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) *tcpip.Error { optLen := len(tf.opts) if tf.rcvWnd > 0xffff { tf.rcvWnd = 0xffff @@ -816,6 +816,7 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso pkts[i].DataSize = packetSize pkts[i].Data = data pkts[i].Hash = tf.txHash + pkts[i].Owner = owner buildTCPHdr(r, tf, &pkts[i], gso) off += packetSize tf.seq = tf.seq.Add(seqnum.Size(packetSize)) @@ -833,14 +834,14 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso // sendTCP sends a TCP segment with the provided options via the provided // network endpoint and under the provided identity. -func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO) *tcpip.Error { +func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) *tcpip.Error { optLen := len(tf.opts) if tf.rcvWnd > 0xffff { tf.rcvWnd = 0xffff } if r.Loop&stack.PacketLoop == 0 && gso != nil && gso.Type == stack.GSOSW && int(gso.MSS) < data.Size() { - return sendTCPBatch(r, tf, data, gso) + return sendTCPBatch(r, tf, data, gso, owner) } pkt := stack.PacketBuffer{ @@ -849,6 +850,7 @@ func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stac DataSize: data.Size(), Data: data, Hash: tf.txHash, + Owner: owner, } buildTCPHdr(r, tf, &pkt, gso) diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 1ebee0cfe..9b123e968 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -603,6 +603,9 @@ type endpoint struct { // txHash is the transport layer hash to be set on outbound packets // emitted by this endpoint. txHash uint32 + + // owner is used to get uid and gid of the packet. + owner tcpip.PacketOwner } // UniqueID implements stack.TransportEndpoint.UniqueID. @@ -1132,6 +1135,10 @@ func (e *endpoint) ModerateRecvBuf(copied int) { e.rcvListMu.Unlock() } +func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { + e.owner = owner +} + // IPTables implements tcpip.Endpoint.IPTables. func (e *endpoint) IPTables() (stack.IPTables, error) { return e.stack.IPTables(), nil diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index a094471b8..808410c92 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -157,7 +157,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, TSVal: r.synOptions.TSVal, TSEcr: r.synOptions.TSEcr, SACKPermitted: r.synOptions.SACKPermitted, - }, queue) + }, queue, nil) if err != nil { return nil, err } diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 1377107ca..dce9a1652 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -199,7 +199,7 @@ func replyWithReset(s *segment) { seq: seq, ack: ack, rcvWnd: 0, - }, buffer.VectorisedView{}, nil /* gso */) + }, buffer.VectorisedView{}, nil /* gso */, nil /* PacketOwner */) } // SetOption implements stack.TransportProtocol.SetOption. diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index a3372ac58..120d3baa3 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -143,6 +143,9 @@ type endpoint struct { // TODO(b/142022063): Add ability to save and restore per endpoint stats. stats tcpip.TransportEndpointStats `state:"nosave"` + + // owner is used to get uid and gid of the packet. + owner tcpip.PacketOwner } // +stateify savable @@ -484,7 +487,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c useDefaultTTL = false } - if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS); err != nil { + if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS, e.owner); err != nil { return 0, nil, err } return int64(len(v)), nil, nil @@ -886,7 +889,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { // sendUDP sends a UDP segment via the provided network endpoint and under the // provided identity. -func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8) *tcpip.Error { +func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner) *tcpip.Error { // Allocate a buffer for the UDP header. hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength())) @@ -916,6 +919,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u Header: hdr, Data: data, TransportHeader: buffer.View(udp), + Owner: owner, }); err != nil { r.Stats().UDP.PacketSendErrors.Increment() return err @@ -1356,3 +1360,7 @@ func (*endpoint) Wait() {} func isBroadcastOrMulticast(a tcpip.Address) bool { return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a) } + +func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { + e.owner = owner +} diff --git a/test/iptables/filter_output.go b/test/iptables/filter_output.go index 4582d514c..f6d974b85 100644 --- a/test/iptables/filter_output.go +++ b/test/iptables/filter_output.go @@ -24,6 +24,11 @@ func init() { RegisterTestCase(FilterOutputDropTCPSrcPort{}) RegisterTestCase(FilterOutputDestination{}) RegisterTestCase(FilterOutputInvertDestination{}) + RegisterTestCase(FilterOutputAcceptTCPOwner{}) + RegisterTestCase(FilterOutputDropTCPOwner{}) + RegisterTestCase(FilterOutputAcceptUDPOwner{}) + RegisterTestCase(FilterOutputDropUDPOwner{}) + RegisterTestCase(FilterOutputOwnerFail{}) } // FilterOutputDropTCPDestPort tests that connections are not accepted on @@ -90,6 +95,144 @@ func (FilterOutputDropTCPSrcPort) LocalAction(ip net.IP) error { return nil } +// FilterOutputAcceptTCPOwner tests that TCP connections from uid owner are accepted. +type FilterOutputAcceptTCPOwner struct{} + +// Name implements TestCase.Name. +func (FilterOutputAcceptTCPOwner) Name() string { + return "FilterOutputAcceptTCPOwner" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputAcceptTCPOwner) ContainerAction(ip net.IP) error { + if err := filterTable("-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--uid-owner", "root", "-j", "ACCEPT"); err != nil { + return err + } + + // Listen for TCP packets on accept port. + if err := listenTCP(acceptPort, sendloopDuration); err != nil { + return fmt.Errorf("connection on port %d should be accepted, but got dropped", acceptPort) + } + + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputAcceptTCPOwner) LocalAction(ip net.IP) error { + if err := connectTCP(ip, acceptPort, sendloopDuration); err != nil { + return fmt.Errorf("connection destined to port %d should be accepted, but got dropped", acceptPort) + } + + return nil +} + +// FilterOutputDropTCPOwner tests that TCP connections from uid owner are dropped. +type FilterOutputDropTCPOwner struct{} + +// Name implements TestCase.Name. +func (FilterOutputDropTCPOwner) Name() string { + return "FilterOutputDropTCPOwner" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputDropTCPOwner) ContainerAction(ip net.IP) error { + if err := filterTable("-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--uid-owner", "root", "-j", "DROP"); err != nil { + return err + } + + // Listen for TCP packets on accept port. + if err := listenTCP(acceptPort, sendloopDuration); err == nil { + return fmt.Errorf("connection on port %d should be dropped, but got accepted", acceptPort) + } + + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputDropTCPOwner) LocalAction(ip net.IP) error { + if err := connectTCP(ip, acceptPort, sendloopDuration); err == nil { + return fmt.Errorf("connection destined to port %d should be dropped, but got accepted", acceptPort) + } + + return nil +} + +// FilterOutputAcceptUDPOwner tests that UDP packets from uid owner are accepted. +type FilterOutputAcceptUDPOwner struct{} + +// Name implements TestCase.Name. +func (FilterOutputAcceptUDPOwner) Name() string { + return "FilterOutputAcceptUDPOwner" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputAcceptUDPOwner) ContainerAction(ip net.IP) error { + if err := filterTable("-A", "OUTPUT", "-p", "udp", "-m", "owner", "--uid-owner", "root", "-j", "ACCEPT"); err != nil { + return err + } + + // Send UDP packets on acceptPort. + return sendUDPLoop(ip, acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputAcceptUDPOwner) LocalAction(ip net.IP) error { + // Listen for UDP packets on acceptPort. + return listenUDP(acceptPort, sendloopDuration) +} + +// FilterOutputDropUDPOwner tests that UDP packets from uid owner are dropped. +type FilterOutputDropUDPOwner struct{} + +// Name implements TestCase.Name. +func (FilterOutputDropUDPOwner) Name() string { + return "FilterOutputDropUDPOwner" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputDropUDPOwner) ContainerAction(ip net.IP) error { + if err := filterTable("-A", "OUTPUT", "-p", "udp", "-m", "owner", "--uid-owner", "root", "-j", "DROP"); err != nil { + return err + } + + // Send UDP packets on dropPort. + return sendUDPLoop(ip, dropPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputDropUDPOwner) LocalAction(ip net.IP) error { + // Listen for UDP packets on dropPort. + if err := listenUDP(dropPort, sendloopDuration); err == nil { + return fmt.Errorf("packets should not be received") + } + + return nil +} + +// FilterOutputOwnerFail tests that without uid/gid option, owner rule +// will fail. +type FilterOutputOwnerFail struct{} + +// Name implements TestCase.Name. +func (FilterOutputOwnerFail) Name() string { + return "FilterOutputOwnerFail" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputOwnerFail) ContainerAction(ip net.IP) error { + if err := filterTable("-A", "OUTPUT", "-p", "udp", "-m", "owner", "-j", "ACCEPT"); err == nil { + return fmt.Errorf("Invalid argument") + } + + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputOwnerFail) LocalAction(ip net.IP) error { + // no-op. + return nil +} + // FilterOutputDestination tests that we can selectively allow packets to // certain destinations. type FilterOutputDestination struct{} diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go index 7f1f70606..493d69052 100644 --- a/test/iptables/iptables_test.go +++ b/test/iptables/iptables_test.go @@ -274,6 +274,36 @@ func TestFilterOutputDropTCPSrcPort(t *testing.T) { } } +func TestFilterOutputAcceptTCPOwner(t *testing.T) { + if err := singleTest(FilterOutputAcceptTCPOwner{}); err != nil { + t.Fatal(err) + } +} + +func TestFilterOutputDropTCPOwner(t *testing.T) { + if err := singleTest(FilterOutputDropTCPOwner{}); err != nil { + t.Fatal(err) + } +} + +func TestFilterOutputAcceptUDPOwner(t *testing.T) { + if err := singleTest(FilterOutputAcceptUDPOwner{}); err != nil { + t.Fatal(err) + } +} + +func TestFilterOutputDropUDPOwner(t *testing.T) { + if err := singleTest(FilterOutputDropUDPOwner{}); err != nil { + t.Fatal(err) + } +} + +func TestFilterOutputOwnerFail(t *testing.T) { + if err := singleTest(FilterOutputOwnerFail{}); err != nil { + t.Fatal(err) + } +} + func TestJumpSerialize(t *testing.T) { if err := singleTest(FilterInputSerializeJump{}); err != nil { t.Fatal(err) -- cgit v1.2.3 From edc3c049eb553fcbf32f4a6b515141a26c5609d4 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Thu, 26 Mar 2020 15:59:41 -0700 Subject: Use panic instead of log.Fatalf PiperOrigin-RevId: 303212189 --- pkg/tcpip/network/ipv6/icmp.go | 6 +++--- pkg/tcpip/stack/ndp.go | 29 +++++++++++++++-------------- pkg/tcpip/stack/nic.go | 5 ++--- 3 files changed, 20 insertions(+), 20 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 8640feffc..e0dd5afd3 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -15,7 +15,7 @@ package ipv6 import ( - "log" + "fmt" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -199,7 +199,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P opt, done, err := it.Next() if err != nil { // This should never happen as Iter(true) above did not return an error. - log.Fatalf("unexpected error when iterating over NDP options: %s", err) + panic(fmt.Sprintf("unexpected error when iterating over NDP options: %s", err)) } if done { break @@ -306,7 +306,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P opt, done, err := it.Next() if err != nil { // This should never happen as Iter(true) above did not return an error. - log.Fatalf("unexpected error when iterating over NDP options: %s", err) + panic(fmt.Sprintf("unexpected error when iterating over NDP options: %s", err)) } if done { break diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 630fdefc5..7c9fc48d1 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -15,6 +15,7 @@ package stack import ( + "fmt" "log" "math/rand" "time" @@ -428,7 +429,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref if ref.getKind() != permanentTentative { // The endpoint should be marked as tentative since we are starting DAD. - log.Fatalf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.nic.ID()) + panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.nic.ID())) } // Should not attempt to perform DAD on an address that is currently in the @@ -440,7 +441,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref // address, or its reference count would have been increased without doing // the work that would have been done for an address that was brand new. // See NIC.addAddressLocked. - log.Fatalf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.nic.ID()) + panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.nic.ID())) } remaining := ndp.configs.DupAddrDetectTransmits @@ -476,7 +477,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref if ref.getKind() != permanentTentative { // The endpoint should still be marked as tentative since we are still // performing DAD on it. - log.Fatalf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.nic.ID()) + panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.nic.ID())) } dadDone := remaining == 0 @@ -546,9 +547,9 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address) *tcpip.Error { // Route should resolve immediately since snmc is a multicast address so a // remote link address can be calculated without a resolution process. if c, err := r.Resolve(nil); err != nil { - log.Fatalf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.nic.ID(), err) + panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.nic.ID(), err)) } else if c != nil { - log.Fatalf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.nic.ID()) + panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.nic.ID())) } hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborSolicitMinimumSize) @@ -949,7 +950,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { deprecationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() { prefixState, ok := ndp.slaacPrefixes[prefix] if !ok { - log.Fatalf("ndp: must have a slaacPrefixes entry for the SLAAC prefix %s", prefix) + panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the SLAAC prefix %s", prefix)) } ndp.deprecateSLAACAddress(prefixState.ref) @@ -1029,7 +1030,7 @@ func (ndp *ndpState) addSLAACAddr(prefix tcpip.Subnet, deprecated bool) *referen ref, err := ndp.nic.addAddressLocked(generatedAddr, FirstPrimaryEndpoint, permanent, slaac, deprecated) if err != nil { - log.Fatalf("ndp: error when adding address %+v: %s", generatedAddr, err) + panic(fmt.Sprintf("ndp: error when adding address %+v: %s", generatedAddr, err)) } return ref @@ -1043,7 +1044,7 @@ func (ndp *ndpState) addSLAACAddr(prefix tcpip.Subnet, deprecated bool) *referen func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, pl, vl time.Duration) { prefixState, ok := ndp.slaacPrefixes[prefix] if !ok { - log.Fatalf("ndp: SLAAC prefix state not found to refresh lifetimes for %s", prefix) + panic(fmt.Sprintf("ndp: SLAAC prefix state not found to refresh lifetimes for %s", prefix)) } defer func() { ndp.slaacPrefixes[prefix] = prefixState }() @@ -1144,7 +1145,7 @@ func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, removeAddr bool) if removeAddr { if err := ndp.nic.removePermanentAddressLocked(addr); err != nil { - log.Fatalf("ndp: removePermanentAddressLocked(%s): %s", addr, err) + panic(fmt.Sprintf("ndp: removePermanentAddressLocked(%s): %s", addr, err)) } } @@ -1193,7 +1194,7 @@ func (ndp *ndpState) cleanupState(hostOnly bool) { } if got := len(ndp.slaacPrefixes); got != linkLocalPrefixes { - log.Fatalf("ndp: still have non-linklocal SLAAC prefixes after cleaning up; found = %d prefixes, of which %d are link-local", got, linkLocalPrefixes) + panic(fmt.Sprintf("ndp: still have non-linklocal SLAAC prefixes after cleaning up; found = %d prefixes, of which %d are link-local", got, linkLocalPrefixes)) } for prefix := range ndp.onLinkPrefixes { @@ -1201,7 +1202,7 @@ func (ndp *ndpState) cleanupState(hostOnly bool) { } if got := len(ndp.onLinkPrefixes); got != 0 { - log.Fatalf("ndp: still have discovered on-link prefixes after cleaning up; found = %d", got) + panic(fmt.Sprintf("ndp: still have discovered on-link prefixes after cleaning up; found = %d", got)) } for router := range ndp.defaultRouters { @@ -1209,7 +1210,7 @@ func (ndp *ndpState) cleanupState(hostOnly bool) { } if got := len(ndp.defaultRouters); got != 0 { - log.Fatalf("ndp: still have discovered default routers after cleaning up; found = %d", got) + panic(fmt.Sprintf("ndp: still have discovered default routers after cleaning up; found = %d", got)) } } @@ -1251,9 +1252,9 @@ func (ndp *ndpState) startSolicitingRouters() { // header.IPv6AllRoutersMulticastAddress is a multicast address so a // remote link address can be calculated without a resolution process. if c, err := r.Resolve(nil); err != nil { - log.Fatalf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID(), err) + panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID(), err)) } else if c != nil { - log.Fatalf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID()) + panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID())) } // As per RFC 4861 section 4.1, an NDP RS SHOULD include the source diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index b6fa647ea..4835251bc 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -16,7 +16,6 @@ package stack import ( "fmt" - "log" "reflect" "sort" "strings" @@ -480,7 +479,7 @@ func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEn // Should never happen as we got r from the primary IPv6 endpoint list and // ScopeForIPv6Address only returns an error if addr is not an IPv6 // address. - log.Fatalf("header.ScopeForIPv6Address(%s): %s", addr, err) + panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", addr, err)) } cs = append(cs, ipv6AddrCandidate{ @@ -492,7 +491,7 @@ func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEn remoteScope, err := header.ScopeForIPv6Address(remoteAddr) if err != nil { // primaryIPv6Endpoint should never be called with an invalid IPv6 address. - log.Fatalf("header.ScopeForIPv6Address(%s): %s", remoteAddr, err) + panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", remoteAddr, err)) } // Sort the addresses as per RFC 6724 section 5 rules 1-3. -- cgit v1.2.3 From aecd3a25a99bc04fe2c032ea3422f10b2dba3256 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Wed, 1 Apr 2020 16:40:08 -0700 Subject: Deflake tcpip/stack:stack_x_test Timeouts were increased to deflake pkg/tcpip/stack:stack_x_test tests that depend on timers. Some timeouts used previously were intended for tests that do not depend on timers, so this change updates those timeouts to give more time for a timer-based event to occur. This change also de-parallelizes non-subtests to reduce the number of active timers. Test: bazel test //pkg/tcpip/stack:stack_x_test --runs_per_test=500 PiperOrigin-RevId: 304287622 --- pkg/tcpip/stack/ndp_test.go | 126 +++++++++++++++--------------------------- pkg/tcpip/stack/stack_test.go | 2 - 2 files changed, 44 insertions(+), 84 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 06edd05b6..598468bdd 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -406,8 +406,7 @@ func TestDADResolve(t *testing.T) { t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) } - // Address should not be considered bound to the NIC yet - // (DAD ongoing). + // Address should not be considered bound to the NIC yet (DAD ongoing). addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) if err != nil { t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) @@ -416,10 +415,9 @@ func TestDADResolve(t *testing.T) { t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) } - // Wait for the remaining time - some delta (500ms), to - // make sure the address is still not resolved. - const delta = 500 * time.Millisecond - time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - delta) + // Make sure the address does not resolve before the resolution time has + // passed. + time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - defaultAsyncEventTimeout) addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) if err != nil { t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) @@ -430,13 +428,7 @@ func TestDADResolve(t *testing.T) { // Wait for DAD to resolve. select { - case <-time.After(2 * delta): - // We should get a resolution event after 500ms - // (delta) since we wait for 500ms less than the - // expected resolution time above to make sure - // that the address did not yet resolve. Waiting - // for 1s (2x delta) without a resolution event - // means something is wrong. + case <-time.After(2 * defaultAsyncEventTimeout): t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr1, true, nil); diff != "" { @@ -1034,8 +1026,6 @@ func TestNoRouterDiscovery(t *testing.T) { forwarding := i&4 == 0 t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverDefaultRouters(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) { - t.Parallel() - ndpDisp := ndpDispatcher{ routerC: make(chan ndpRouterEvent, 1), } @@ -1074,8 +1064,6 @@ func checkRouterEvent(e ndpRouterEvent, addr tcpip.Address, discovered bool) str // TestRouterDiscoveryDispatcherNoRemember tests that the stack does not // remember a discovered router when the dispatcher asks it not to. func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { - t.Parallel() - ndpDisp := ndpDispatcher{ routerC: make(chan ndpRouterEvent, 1), } @@ -1116,8 +1104,6 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { } func TestRouterDiscovery(t *testing.T) { - t.Parallel() - ndpDisp := ndpDispatcher{ routerC: make(chan ndpRouterEvent, 1), rememberRouter: true, @@ -1219,8 +1205,6 @@ func TestRouterDiscovery(t *testing.T) { // TestRouterDiscoveryMaxRouters tests that only // stack.MaxDiscoveredDefaultRouters discovered routers are remembered. func TestRouterDiscoveryMaxRouters(t *testing.T) { - t.Parallel() - ndpDisp := ndpDispatcher{ routerC: make(chan ndpRouterEvent, 1), rememberRouter: true, @@ -1287,8 +1271,6 @@ func TestNoPrefixDiscovery(t *testing.T) { forwarding := i&4 == 0 t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverOnLinkPrefixes(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) { - t.Parallel() - ndpDisp := ndpDispatcher{ prefixC: make(chan ndpPrefixEvent, 1), } @@ -1328,8 +1310,6 @@ func checkPrefixEvent(e ndpPrefixEvent, prefix tcpip.Subnet, discovered bool) st // TestPrefixDiscoveryDispatcherNoRemember tests that the stack does not // remember a discovered on-link prefix when the dispatcher asks it not to. func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { - t.Parallel() - prefix, subnet, _ := prefixSubnetAddr(0, "") ndpDisp := ndpDispatcher{ @@ -1373,8 +1353,6 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { } func TestPrefixDiscovery(t *testing.T) { - t.Parallel() - prefix1, subnet1, _ := prefixSubnetAddr(0, "") prefix2, subnet2, _ := prefixSubnetAddr(1, "") prefix3, subnet3, _ := prefixSubnetAddr(2, "") @@ -1563,8 +1541,6 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { // TestPrefixDiscoveryMaxRouters tests that only // stack.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered. func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { - t.Parallel() - ndpDisp := ndpDispatcher{ prefixC: make(chan ndpPrefixEvent, stack.MaxDiscoveredOnLinkPrefixes+3), rememberPrefix: true, @@ -1659,8 +1635,6 @@ func TestNoAutoGenAddr(t *testing.T) { forwarding := i&4 == 0 t.Run(fmt.Sprintf("HandleRAs(%t), AutoGenAddr(%t), Forwarding(%t)", handle, autogen, forwarding), func(t *testing.T) { - t.Parallel() - ndpDisp := ndpDispatcher{ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } @@ -2410,8 +2384,6 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { }, } - const delta = 500 * time.Millisecond - // This Run will not return until the parallel tests finish. // // We need this because we need to do some teardown work after the @@ -2464,24 +2436,21 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { // to test.evl. // - // Make sure we do not get any invalidation - // events until atleast 500ms (delta) before - // test.evl. + // The address should not be invalidated until the effective valid + // lifetime has passed. select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly received an auto gen addr event") - case <-time.After(time.Duration(test.evl)*time.Second - delta): + case <-time.After(time.Duration(test.evl)*time.Second - defaultAsyncEventTimeout): } - // Wait for another second (2x delta), but now - // we expect the invalidation event. + // Wait for the invalidation event. select { case e := <-ndpDisp.autoGenAddrC: if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - - case <-time.After(2 * delta): + case <-time.After(2 * defaultAsyncEventTimeout): t.Fatal("timeout waiting for addr auto gen event") } }) @@ -2493,8 +2462,6 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { // by the user, its resources will be cleaned up and an invalidation event will // be sent to the integrator. func TestAutoGenAddrRemoval(t *testing.T) { - t.Parallel() - prefix, _, addr := prefixSubnetAddr(0, linkAddr1) ndpDisp := ndpDispatcher{ @@ -2551,8 +2518,6 @@ func TestAutoGenAddrRemoval(t *testing.T) { // TestAutoGenAddrAfterRemoval tests adding a SLAAC address that was previously // assigned to the NIC but is in the permanentExpired state. func TestAutoGenAddrAfterRemoval(t *testing.T) { - t.Parallel() - const nicID = 1 prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) @@ -2664,8 +2629,6 @@ func TestAutoGenAddrAfterRemoval(t *testing.T) { // TestAutoGenAddrStaticConflict tests that if SLAAC generates an address that // is already assigned to the NIC, the static address remains. func TestAutoGenAddrStaticConflict(t *testing.T) { - t.Parallel() - prefix, _, addr := prefixSubnetAddr(0, linkAddr1) ndpDisp := ndpDispatcher{ @@ -2721,8 +2684,6 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { // TestAutoGenAddrWithOpaqueIID tests that SLAAC generated addresses will use // opaque interface identifiers when configured to do so. func TestAutoGenAddrWithOpaqueIID(t *testing.T) { - t.Parallel() - const nicID = 1 const nicName = "nic1" var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte @@ -2826,8 +2787,6 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { // to the integrator when an RA is received with the NDP Recursive DNS Server // option with at least one valid address. func TestNDPRecursiveDNSServerDispatch(t *testing.T) { - t.Parallel() - tests := []struct { name string opt header.NDPRecursiveDNSServer @@ -2919,11 +2878,7 @@ func TestNDPRecursiveDNSServerDispatch(t *testing.T) { } for _, test := range tests { - test := test - t.Run(test.name, func(t *testing.T) { - t.Parallel() - ndpDisp := ndpDispatcher{ // We do not expect more than a single RDNSS // event at any time for this test. @@ -2973,8 +2928,6 @@ func TestNDPRecursiveDNSServerDispatch(t *testing.T) { // TestCleanupNDPState tests that all discovered routers and prefixes, and // auto-generated addresses are invalidated when a NIC becomes a router. func TestCleanupNDPState(t *testing.T) { - t.Parallel() - const ( lifetimeSeconds = 5 maxRouterAndPrefixEvents = 4 @@ -3417,8 +3370,6 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { // TestRouterSolicitation tests the initial Router Solicitations that are sent // when a NIC newly becomes enabled. func TestRouterSolicitation(t *testing.T) { - t.Parallel() - const nicID = 1 tests := []struct { @@ -3435,13 +3386,22 @@ func TestRouterSolicitation(t *testing.T) { effectiveMaxRtrSolicitDelay time.Duration }{ { - name: "Single RS with delay", + name: "Single RS with 2s delay and interval", expectedSrcAddr: header.IPv6Any, maxRtrSolicit: 1, - rtrSolicitInt: time.Second, - effectiveRtrSolicitInt: time.Second, - maxRtrSolicitDelay: time.Second, - effectiveMaxRtrSolicitDelay: time.Second, + rtrSolicitInt: 2 * time.Second, + effectiveRtrSolicitInt: 2 * time.Second, + maxRtrSolicitDelay: 2 * time.Second, + effectiveMaxRtrSolicitDelay: 2 * time.Second, + }, + { + name: "Single RS with 4s delay and interval", + expectedSrcAddr: header.IPv6Any, + maxRtrSolicit: 1, + rtrSolicitInt: 4 * time.Second, + effectiveRtrSolicitInt: 4 * time.Second, + maxRtrSolicitDelay: 4 * time.Second, + effectiveMaxRtrSolicitDelay: 4 * time.Second, }, { name: "Two RS with delay", @@ -3449,8 +3409,8 @@ func TestRouterSolicitation(t *testing.T) { nicAddr: llAddr1, expectedSrcAddr: llAddr1, maxRtrSolicit: 2, - rtrSolicitInt: time.Second, - effectiveRtrSolicitInt: time.Second, + rtrSolicitInt: 2 * time.Second, + effectiveRtrSolicitInt: 2 * time.Second, maxRtrSolicitDelay: 500 * time.Millisecond, effectiveMaxRtrSolicitDelay: 500 * time.Millisecond, }, @@ -3464,8 +3424,8 @@ func TestRouterSolicitation(t *testing.T) { header.NDPSourceLinkLayerAddressOption(linkAddr1), }, maxRtrSolicit: 1, - rtrSolicitInt: time.Second, - effectiveRtrSolicitInt: time.Second, + rtrSolicitInt: 2 * time.Second, + effectiveRtrSolicitInt: 2 * time.Second, maxRtrSolicitDelay: 0, effectiveMaxRtrSolicitDelay: 0, }, @@ -3515,6 +3475,7 @@ func TestRouterSolicitation(t *testing.T) { t.Run(test.name, func(t *testing.T) { t.Parallel() + e := channelLinkWithHeaderLength{ Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr), headerLength: test.linkHeaderLen, @@ -3583,15 +3544,19 @@ func TestRouterSolicitation(t *testing.T) { } for ; remaining > 0; remaining-- { - waitForNothing(test.effectiveRtrSolicitInt - defaultTimeout) - waitForPkt(defaultAsyncEventTimeout) + if test.effectiveRtrSolicitInt > defaultAsyncEventTimeout { + waitForNothing(test.effectiveRtrSolicitInt - defaultAsyncEventTimeout) + waitForPkt(2 * defaultAsyncEventTimeout) + } else { + waitForPkt(test.effectiveRtrSolicitInt * defaultAsyncEventTimeout) + } } // Make sure no more RS. if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay { - waitForNothing(test.effectiveRtrSolicitInt + defaultTimeout) + waitForNothing(test.effectiveRtrSolicitInt + defaultAsyncEventTimeout) } else { - waitForNothing(test.effectiveMaxRtrSolicitDelay + defaultTimeout) + waitForNothing(test.effectiveMaxRtrSolicitDelay + defaultAsyncEventTimeout) } // Make sure the counter got properly @@ -3605,11 +3570,9 @@ func TestRouterSolicitation(t *testing.T) { } func TestStopStartSolicitingRouters(t *testing.T) { - t.Parallel() - const nicID = 1 + const delay = 0 const interval = 500 * time.Millisecond - const delay = time.Second const maxRtrSolicitations = 3 tests := []struct { @@ -3684,7 +3647,6 @@ func TestStopStartSolicitingRouters(t *testing.T) { p, ok := e.ReadContext(ctx) if !ok { t.Fatal("timed out waiting for packet") - return } if p.Proto != header.IPv6ProtocolNumber { @@ -3710,11 +3672,11 @@ func TestStopStartSolicitingRouters(t *testing.T) { // Stop soliciting routers. test.stopFn(t, s, true /* first */) - ctx, cancel := context.WithTimeout(context.Background(), delay+defaultTimeout) + ctx, cancel := context.WithTimeout(context.Background(), delay+defaultAsyncEventTimeout) defer cancel() if _, ok := e.ReadContext(ctx); ok { - // A single RS may have been sent before forwarding was enabled. - ctx, cancel := context.WithTimeout(context.Background(), interval+defaultTimeout) + // A single RS may have been sent before solicitations were stopped. + ctx, cancel := context.WithTimeout(context.Background(), interval+defaultAsyncEventTimeout) defer cancel() if _, ok = e.ReadContext(ctx); ok { t.Fatal("should not have sent more than one RS message") @@ -3724,7 +3686,7 @@ func TestStopStartSolicitingRouters(t *testing.T) { // Stopping router solicitations after it has already been stopped should // do nothing. test.stopFn(t, s, false /* first */) - ctx, cancel = context.WithTimeout(context.Background(), delay+defaultTimeout) + ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncEventTimeout) defer cancel() if _, ok := e.ReadContext(ctx); ok { t.Fatal("unexpectedly got a packet after router solicitation has been stopepd") @@ -3740,7 +3702,7 @@ func TestStopStartSolicitingRouters(t *testing.T) { waitForPkt(delay + defaultAsyncEventTimeout) waitForPkt(interval + defaultAsyncEventTimeout) waitForPkt(interval + defaultAsyncEventTimeout) - ctx, cancel = context.WithTimeout(context.Background(), interval+defaultTimeout) + ctx, cancel = context.WithTimeout(context.Background(), interval+defaultAsyncEventTimeout) defer cancel() if _, ok := e.ReadContext(ctx); ok { t.Fatal("unexpectedly got an extra packet after sending out the expected RSs") @@ -3749,7 +3711,7 @@ func TestStopStartSolicitingRouters(t *testing.T) { // Starting router solicitations after it has already completed should do // nothing. test.startFn(t, s) - ctx, cancel = context.WithTimeout(context.Background(), delay+defaultTimeout) + ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncEventTimeout) defer cancel() if _, ok := e.ReadContext(ctx); ok { t.Fatal("unexpectedly got a packet after finishing router solicitations") diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 555fcd92f..b8543b71e 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -3176,8 +3176,6 @@ func TestJoinLeaveAllNodesMulticastOnNICEnableDisable(t *testing.T) { // TestDoDADWhenNICEnabled tests that IPv6 endpoints that were added while a NIC // was disabled have DAD performed on them when the NIC is enabled. func TestDoDADWhenNICEnabled(t *testing.T) { - t.Parallel() - const dadTransmits = 1 const retransmitTimer = time.Second const nicID = 1 -- cgit v1.2.3 From fc99a7ebf0c24b6f7b3cfd6351436373ed54548b Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Fri, 3 Apr 2020 18:34:48 -0700 Subject: Refactor software GSO code. Software GSO implementation currently has a complicated code path with implicit assumptions that all packets to WritePackets carry same Data and it does this to avoid allocations on the path etc. But this makes it hard to reuse the WritePackets API. This change breaks all such assumptions by introducing a new Vectorised View API ReadToVV which can be used to cleanly split a VV into multiple independent VVs. Further this change also makes packet buffers linkable to form an intrusive list. This allows us to get rid of the array of packet buffers that are passed in the WritePackets API call and replace it with a list of packet buffers. While this code does introduce some more allocations in the benchmarks it doesn't cause any degradation. Updates #231 PiperOrigin-RevId: 304731742 --- pkg/ilist/list.go | 13 ++- pkg/sentry/kernel/kernel.go | 22 +++-- pkg/tcpip/buffer/view.go | 53 +++++++++- pkg/tcpip/buffer/view_test.go | 137 ++++++++++++++++++++++++++ pkg/tcpip/link/channel/channel.go | 18 ++-- pkg/tcpip/link/fdbased/endpoint.go | 162 ++++++++++++++----------------- pkg/tcpip/link/loopback/loopback.go | 2 +- pkg/tcpip/link/muxed/injectable.go | 2 +- pkg/tcpip/link/sharedmem/sharedmem.go | 2 +- pkg/tcpip/link/sniffer/sniffer.go | 14 +-- pkg/tcpip/link/waitable/waitable.go | 4 +- pkg/tcpip/link/waitable/waitable_test.go | 6 +- pkg/tcpip/network/arp/arp.go | 2 +- pkg/tcpip/network/ip_test.go | 2 +- pkg/tcpip/network/ipv4/ipv4.go | 37 +++++-- pkg/tcpip/network/ipv6/icmp.go | 2 +- pkg/tcpip/network/ipv6/ipv6.go | 12 +-- pkg/tcpip/stack/BUILD | 14 ++- pkg/tcpip/stack/forwarder_test.go | 8 +- pkg/tcpip/stack/iptables.go | 17 ++++ pkg/tcpip/stack/ndp_test.go | 2 +- pkg/tcpip/stack/packet_buffer.go | 14 +-- pkg/tcpip/stack/packet_buffer_state.go | 27 ------ pkg/tcpip/stack/registration.go | 4 +- pkg/tcpip/stack/route.go | 19 ++-- pkg/tcpip/stack/stack_test.go | 2 +- pkg/tcpip/transport/tcp/connect.go | 47 +++++---- pkg/tcpip/transport/tcp/segment.go | 6 +- 28 files changed, 420 insertions(+), 230 deletions(-) delete mode 100644 pkg/tcpip/stack/packet_buffer_state.go (limited to 'pkg/tcpip/stack') diff --git a/pkg/ilist/list.go b/pkg/ilist/list.go index 8f93e4d6d..0d07da3b1 100644 --- a/pkg/ilist/list.go +++ b/pkg/ilist/list.go @@ -86,12 +86,21 @@ func (l *List) Back() Element { return l.tail } +// Len returns the number of elements in the list. +// +// NOTE: This is an O(n) operation. +func (l *List) Len() (count int) { + for e := l.Front(); e != nil; e = e.Next() { + count++ + } + return count +} + // PushFront inserts the element e at the front of list l. func (l *List) PushFront(e Element) { linker := ElementMapper{}.linkerFor(e) linker.SetNext(l.head) linker.SetPrev(nil) - if l.head != nil { ElementMapper{}.linkerFor(l.head).SetPrev(e) } else { @@ -106,7 +115,6 @@ func (l *List) PushBack(e Element) { linker := ElementMapper{}.linkerFor(e) linker.SetNext(nil) linker.SetPrev(l.tail) - if l.tail != nil { ElementMapper{}.linkerFor(l.tail).SetNext(e) } else { @@ -127,7 +135,6 @@ func (l *List) PushBackList(m *List) { l.tail = m.tail } - m.head = nil m.tail = nil } diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 0a448b57c..2e6f42b92 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -564,15 +564,25 @@ func (ts *TaskSet) unregisterEpollWaiters() { ts.mu.RLock() defer ts.mu.RUnlock() + + // Tasks that belong to the same process could potentially point to the + // same FDTable. So we retain a map of processed ones to avoid + // processing the same FDTable multiple times. + processed := make(map[*FDTable]struct{}) for t := range ts.Root.tids { // We can skip locking Task.mu here since the kernel is paused. - if t.fdTable != nil { - t.fdTable.forEach(func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) { - if e, ok := file.FileOperations.(*epoll.EventPoll); ok { - e.UnregisterEpollWaiters() - } - }) + if t.fdTable == nil { + continue + } + if _, ok := processed[t.fdTable]; ok { + continue } + t.fdTable.forEach(func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) { + if e, ok := file.FileOperations.(*epoll.EventPoll); ok { + e.UnregisterEpollWaiters() + } + }) + processed[t.fdTable] = struct{}{} } } diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index 8d42cd066..8ec5d5d5c 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -17,6 +17,7 @@ package buffer import ( "bytes" + "io" ) // View is a slice of a buffer, with convenience methods. @@ -89,6 +90,47 @@ func (vv *VectorisedView) TrimFront(count int) { } } +// Read implements io.Reader. +func (vv *VectorisedView) Read(v View) (copied int, err error) { + count := len(v) + for count > 0 && len(vv.views) > 0 { + if count < len(vv.views[0]) { + vv.size -= count + copy(v[copied:], vv.views[0][:count]) + vv.views[0].TrimFront(count) + copied += count + return copied, nil + } + count -= len(vv.views[0]) + copy(v[copied:], vv.views[0]) + copied += len(vv.views[0]) + vv.RemoveFirst() + } + if copied == 0 { + return 0, io.EOF + } + return copied, nil +} + +// ReadToVV reads up to n bytes from vv to dstVV and removes them from vv. It +// returns the number of bytes copied. +func (vv *VectorisedView) ReadToVV(dstVV *VectorisedView, count int) (copied int) { + for count > 0 && len(vv.views) > 0 { + if count < len(vv.views[0]) { + vv.size -= count + dstVV.AppendView(vv.views[0][:count]) + vv.views[0].TrimFront(count) + copied += count + return + } + count -= len(vv.views[0]) + dstVV.AppendView(vv.views[0]) + copied += len(vv.views[0]) + vv.RemoveFirst() + } + return copied +} + // CapLength irreversibly reduces the length of the vectorised view. func (vv *VectorisedView) CapLength(length int) { if length < 0 { @@ -116,12 +158,12 @@ func (vv *VectorisedView) CapLength(length int) { // Clone returns a clone of this VectorisedView. // If the buffer argument is large enough to contain all the Views of this VectorisedView, // the method will avoid allocations and use the buffer to store the Views of the clone. -func (vv VectorisedView) Clone(buffer []View) VectorisedView { +func (vv *VectorisedView) Clone(buffer []View) VectorisedView { return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size} } // First returns the first view of the vectorised view. -func (vv VectorisedView) First() View { +func (vv *VectorisedView) First() View { if len(vv.views) == 0 { return nil } @@ -134,11 +176,12 @@ func (vv *VectorisedView) RemoveFirst() { return } vv.size -= len(vv.views[0]) + vv.views[0] = nil vv.views = vv.views[1:] } // Size returns the size in bytes of the entire content stored in the vectorised view. -func (vv VectorisedView) Size() int { +func (vv *VectorisedView) Size() int { return vv.size } @@ -146,7 +189,7 @@ func (vv VectorisedView) Size() int { // // If the vectorised view contains a single view, that view will be returned // directly. -func (vv VectorisedView) ToView() View { +func (vv *VectorisedView) ToView() View { if len(vv.views) == 1 { return vv.views[0] } @@ -158,7 +201,7 @@ func (vv VectorisedView) ToView() View { } // Views returns the slice containing the all views. -func (vv VectorisedView) Views() []View { +func (vv *VectorisedView) Views() []View { return vv.views } diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go index ebc3a17b7..106e1994c 100644 --- a/pkg/tcpip/buffer/view_test.go +++ b/pkg/tcpip/buffer/view_test.go @@ -233,3 +233,140 @@ func TestToClone(t *testing.T) { }) } } + +func TestVVReadToVV(t *testing.T) { + testCases := []struct { + comment string + vv VectorisedView + bytesToRead int + wantBytes string + leftVV VectorisedView + }{ + { + comment: "large VV, short read", + vv: vv(30, "012345678901234567890123456789"), + bytesToRead: 10, + wantBytes: "0123456789", + leftVV: vv(20, "01234567890123456789"), + }, + { + comment: "largeVV, multiple views, short read", + vv: vv(13, "123", "345", "567", "8910"), + bytesToRead: 6, + wantBytes: "123345", + leftVV: vv(7, "567", "8910"), + }, + { + comment: "smallVV (multiple views), large read", + vv: vv(3, "1", "2", "3"), + bytesToRead: 10, + wantBytes: "123", + leftVV: vv(0, ""), + }, + { + comment: "smallVV (single view), large read", + vv: vv(1, "1"), + bytesToRead: 10, + wantBytes: "1", + leftVV: vv(0, ""), + }, + { + comment: "emptyVV, large read", + vv: vv(0, ""), + bytesToRead: 10, + wantBytes: "", + leftVV: vv(0, ""), + }, + } + + for _, tc := range testCases { + t.Run(tc.comment, func(t *testing.T) { + var readTo VectorisedView + inSize := tc.vv.Size() + copied := tc.vv.ReadToVV(&readTo, tc.bytesToRead) + if got, want := copied, len(tc.wantBytes); got != want { + t.Errorf("incorrect number of bytes copied returned in ReadToVV got: %d, want: %d, tc: %+v", got, want, tc) + } + if got, want := string(readTo.ToView()), tc.wantBytes; got != want { + t.Errorf("unexpected content in readTo got: %s, want: %s", got, want) + } + if got, want := tc.vv.Size(), inSize-copied; got != want { + t.Errorf("test VV has incorrect size after reading got: %d, want: %d, tc.vv: %+v", got, want, tc.vv) + } + if got, want := string(tc.vv.ToView()), string(tc.leftVV.ToView()); got != want { + t.Errorf("unexpected data left in vv after read got: %+v, want: %+v", got, want) + } + }) + } +} + +func TestVVRead(t *testing.T) { + testCases := []struct { + comment string + vv VectorisedView + bytesToRead int + readBytes string + leftBytes string + wantError bool + }{ + { + comment: "large VV, short read", + vv: vv(30, "012345678901234567890123456789"), + bytesToRead: 10, + readBytes: "0123456789", + leftBytes: "01234567890123456789", + }, + { + comment: "largeVV, multiple buffers, short read", + vv: vv(13, "123", "345", "567", "8910"), + bytesToRead: 6, + readBytes: "123345", + leftBytes: "5678910", + }, + { + comment: "smallVV, large read", + vv: vv(3, "1", "2", "3"), + bytesToRead: 10, + readBytes: "123", + leftBytes: "", + }, + { + comment: "smallVV, large read", + vv: vv(1, "1"), + bytesToRead: 10, + readBytes: "1", + leftBytes: "", + }, + { + comment: "emptyVV, large read", + vv: vv(0, ""), + bytesToRead: 10, + readBytes: "", + wantError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.comment, func(t *testing.T) { + readTo := NewView(tc.bytesToRead) + inSize := tc.vv.Size() + copied, err := tc.vv.Read(readTo) + if !tc.wantError && err != nil { + t.Fatalf("unexpected error in tc.vv.Read(..) = %s", err) + } + readTo = readTo[:copied] + if got, want := copied, len(tc.readBytes); got != want { + t.Errorf("incorrect number of bytes copied returned in ReadToVV got: %d, want: %d, tc.vv: %+v", got, want, tc.vv) + } + if got, want := string(readTo), tc.readBytes; got != want { + t.Errorf("unexpected data in readTo got: %s, want: %s", got, want) + } + if got, want := tc.vv.Size(), inSize-copied; got != want { + t.Errorf("test VV has incorrect size after reading got: %d, want: %d, tc.vv: %+v", got, want, tc.vv) + } + if got, want := string(tc.vv.ToView()), tc.leftBytes; got != want { + t.Errorf("vv has incorrect data after Read got: %s, want: %s", got, want) + } + }) + } +} diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index a8d6653ce..b4a0ae53d 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -28,7 +28,7 @@ import ( // PacketInfo holds all the information about an outbound packet. type PacketInfo struct { - Pkt stack.PacketBuffer + Pkt *stack.PacketBuffer Proto tcpip.NetworkProtocolNumber GSO *stack.GSO Route stack.Route @@ -257,7 +257,7 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne route := r.Clone() route.Release() p := PacketInfo{ - Pkt: pkt, + Pkt: &pkt, Proto: protocol, GSO: gso, Route: route, @@ -269,21 +269,15 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne } // WritePackets stores outbound packets into the channel. -func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { // Clone r then release its resource so we only get the relevant fields from // stack.Route without holding a reference to a NIC's endpoint. route := r.Clone() route.Release() - payloadView := pkts[0].Data.ToView() n := 0 - for _, pkt := range pkts { - off := pkt.DataOffset - size := pkt.DataSize + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { p := PacketInfo{ - Pkt: stack.PacketBuffer{ - Header: pkt.Header, - Data: buffer.NewViewFromBytes(payloadView[off : off+size]).ToVectorisedView(), - }, + Pkt: pkt, Proto: protocol, GSO: gso, Route: route, @@ -301,7 +295,7 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.Pac // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { p := PacketInfo{ - Pkt: stack.PacketBuffer{Data: vv}, + Pkt: &stack.PacketBuffer{Data: vv}, Proto: 0, GSO: nil, } diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index 3b3b6909b..7198742b7 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -441,118 +441,106 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne // WritePackets writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - var ethHdrBuf []byte - // hdr + data - iovLen := 2 - if e.hdrSize > 0 { - // Add ethernet header if needed. - ethHdrBuf = make([]byte, header.EthernetMinimumSize) - eth := header.Ethernet(ethHdrBuf) - ethHdr := &header.EthernetFields{ - DstAddr: r.RemoteLinkAddress, - Type: protocol, - } - - // Preserve the src address if it's set in the route. - if r.LocalLinkAddress != "" { - ethHdr.SrcAddr = r.LocalLinkAddress - } else { - ethHdr.SrcAddr = e.addr - } - eth.Encode(ethHdr) - iovLen++ - } +// +// NOTE: This API uses sendmmsg to batch packets. As a result the underlying FD +// picked to write the packet out has to be the same for all packets in the +// list. In other words all packets in the batch should belong to the same +// flow. +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + n := pkts.Len() - n := len(pkts) - - views := pkts[0].Data.Views() - /* - * Each boundary in views can add one more iovec. - * - * payload | | | | - * ----------------------------- - * packets | | | | | | | - * ----------------------------- - * iovecs | | | | | | | | | - */ - iovec := make([]syscall.Iovec, n*iovLen+len(views)-1) mmsgHdrs := make([]rawfile.MMsgHdr, n) + i := 0 + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + var ethHdrBuf []byte + iovLen := 0 + if e.hdrSize > 0 { + // Add ethernet header if needed. + ethHdrBuf = make([]byte, header.EthernetMinimumSize) + eth := header.Ethernet(ethHdrBuf) + ethHdr := &header.EthernetFields{ + DstAddr: r.RemoteLinkAddress, + Type: protocol, + } - iovecIdx := 0 - viewIdx := 0 - viewOff := 0 - off := 0 - nextOff := 0 - for i := range pkts { - // TODO(b/134618279): Different packets may have different data - // in the future. We should handle this. - if !viewsEqual(pkts[i].Data.Views(), views) { - panic("All packets in pkts should have the same Data.") + // Preserve the src address if it's set in the route. + if r.LocalLinkAddress != "" { + ethHdr.SrcAddr = r.LocalLinkAddress + } else { + ethHdr.SrcAddr = e.addr + } + eth.Encode(ethHdr) + iovLen++ } - prevIovecIdx := iovecIdx - mmsgHdr := &mmsgHdrs[i] - mmsgHdr.Msg.Iov = &iovec[iovecIdx] - packetSize := pkts[i].DataSize - hdr := &pkts[i].Header - - off = pkts[i].DataOffset - if off != nextOff { - // We stop in a different point last time. - size := packetSize - viewIdx = 0 - viewOff = 0 - for size > 0 { - if size >= len(views[viewIdx]) { - viewIdx++ - viewOff = 0 - size -= len(views[viewIdx]) - } else { - viewOff = size - size = 0 + var vnetHdrBuf []byte + vnetHdr := virtioNetHdr{} + if e.Capabilities()&stack.CapabilityHardwareGSO != 0 { + if gso != nil { + vnetHdr.hdrLen = uint16(pkt.Header.UsedLength()) + if gso.NeedsCsum { + vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM + vnetHdr.csumStart = header.EthernetMinimumSize + gso.L3HdrLen + vnetHdr.csumOffset = gso.CsumOffset + } + if gso.Type != stack.GSONone && uint16(pkt.Data.Size()) > gso.MSS { + switch gso.Type { + case stack.GSOTCPv4: + vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4 + case stack.GSOTCPv6: + vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV6 + default: + panic(fmt.Sprintf("Unknown gso type: %v", gso.Type)) + } + vnetHdr.gsoSize = gso.MSS } } + vnetHdrBuf = vnetHdrToByteSlice(&vnetHdr) + iovLen++ } - nextOff = off + packetSize + iovecs := make([]syscall.Iovec, iovLen+1+len(pkt.Data.Views())) + mmsgHdr := &mmsgHdrs[i] + mmsgHdr.Msg.Iov = &iovecs[0] + iovecIdx := 0 + if vnetHdrBuf != nil { + v := &iovecs[iovecIdx] + v.Base = &vnetHdrBuf[0] + v.Len = uint64(len(vnetHdrBuf)) + iovecIdx++ + } if ethHdrBuf != nil { - v := &iovec[iovecIdx] + v := &iovecs[iovecIdx] v.Base = ðHdrBuf[0] v.Len = uint64(len(ethHdrBuf)) iovecIdx++ } - - v := &iovec[iovecIdx] + pktSize := uint64(0) + // Encode L3 Header + v := &iovecs[iovecIdx] + hdr := &pkt.Header hdrView := hdr.View() v.Base = &hdrView[0] v.Len = uint64(len(hdrView)) + pktSize += v.Len iovecIdx++ - for packetSize > 0 { - vec := &iovec[iovecIdx] + // Now encode the Transport Payload. + pktViews := pkt.Data.Views() + for i := range pktViews { + vec := &iovecs[iovecIdx] iovecIdx++ - - v := views[viewIdx] - vec.Base = &v[viewOff] - s := len(v) - viewOff - if s <= packetSize { - viewIdx++ - viewOff = 0 - } else { - s = packetSize - viewOff += s - } - vec.Len = uint64(s) - packetSize -= s + vec.Base = &pktViews[i][0] + vec.Len = uint64(len(pktViews[i])) + pktSize += vec.Len } - - mmsgHdr.Msg.Iovlen = uint64(iovecIdx - prevIovecIdx) + mmsgHdr.Msg.Iovlen = uint64(iovecIdx) + i++ } packets := 0 for packets < n { - fd := e.fds[pkts[packets].Hash%uint32(len(e.fds))] + fd := e.fds[pkts.Front().Hash%uint32(len(e.fds))] sent, err := rawfile.NonBlockingSendMMsg(fd, mmsgHdrs) if err != nil { return packets, err diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 4039753b7..1e2255bfa 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -92,7 +92,7 @@ func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Netw } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []stack.PacketBuffer, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { panic("not implemented") } diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index f5973066d..a5478ce17 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -87,7 +87,7 @@ func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, // WritePackets writes outbound packets to the appropriate // LinkInjectableEndpoint based on the RemoteAddress. HandleLocal only works if // r.RemoteAddress has a route registered in this endpoint. -func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { endpoint, ok := m.routes[r.RemoteAddress] if !ok { return 0, tcpip.ErrNoRoute diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 6461d0108..0796d717e 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -214,7 +214,7 @@ func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.Netw } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { panic("not implemented") } diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 0a6b8945c..062388f4d 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -200,7 +200,7 @@ func (e *endpoint) GSOMaxSize() uint32 { return 0 } -func (e *endpoint) dumpPacket(gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (e *endpoint) dumpPacket(gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil { logPacket("send", protocol, pkt.Header.View(), gso) } @@ -233,20 +233,16 @@ func (e *endpoint) dumpPacket(gso *stack.GSO, protocol tcpip.NetworkProtocolNumb // higher-level protocols to write packets; it just logs the packet and // forwards the request to the lower endpoint. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { - e.dumpPacket(gso, protocol, pkt) + e.dumpPacket(gso, protocol, &pkt) return e.lower.WritePacket(r, gso, protocol, pkt) } // WritePackets implements the stack.LinkEndpoint interface. It is called by // higher-level protocols to write packets; it just logs the packet and // forwards the request to the lower endpoint. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - view := pkts[0].Data.ToView() - for _, pkt := range pkts { - e.dumpPacket(gso, protocol, stack.PacketBuffer{ - Header: pkt.Header, - Data: view[pkt.DataOffset:][:pkt.DataSize].ToVectorisedView(), - }) +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + e.dumpPacket(gso, protocol, pkt) } return e.lower.WritePackets(r, gso, pkts, protocol) } diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index 52fe397bf..2b3741276 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -112,9 +112,9 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne // WritePackets implements stack.LinkEndpoint.WritePackets. It is called by // higher-level protocols to write packets. It only forwards packets to the // lower endpoint if Wait or WaitWrite haven't been called. -func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { if !e.writeGate.Enter() { - return len(pkts), nil + return pkts.Len(), nil } n, err := e.lower.WritePackets(r, gso, pkts, protocol) diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index 88224e494..54eb5322b 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -71,9 +71,9 @@ func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcp } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - e.writeCount += len(pkts) - return len(pkts), nil +func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + e.writeCount += pkts.Len() + return pkts.Len(), nil } func (e *countedEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 255098372..7acbfa0a8 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -84,7 +84,7 @@ func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderPara } // WritePackets implements stack.NetworkEndpoint.WritePackets. -func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []stack.PacketBuffer, stack.NetworkHeaderParams) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, stack.NetworkHeaderParams) (int, *tcpip.Error) { return 0, tcpip.ErrNotSupported } diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 4950d69fc..4c20301c6 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -172,7 +172,7 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Ne } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt []stack.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { panic("not implemented") } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index a7d9a8b25..104aafbed 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -280,28 +280,47 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw } // WritePackets implements stack.NetworkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) { if r.Loop&stack.PacketLoop != 0 { panic("multiple packets in local loop") } if r.Loop&stack.PacketOut == 0 { - return len(pkts), nil + return pkts.Len(), nil + } + + for pkt := pkts.Front(); pkt != nil; { + ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) + pkt.NetworkHeader = buffer.View(ip) + pkt = pkt.Next() } // iptables filtering. All packets that reach here are locally // generated. ipt := e.stack.IPTables() - for i := range pkts { - if ok := ipt.Check(stack.Output, pkts[i]); !ok { - // iptables is telling us to drop the packet. + dropped := ipt.CheckPackets(stack.Output, pkts) + if len(dropped) == 0 { + // Fast path: If no packets are to be dropped then we can just invoke the + // faster WritePackets API directly. + n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) + r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + return n, err + } + + // Slow Path as we are dropping some packets in the batch degrade to + // emitting one packet at a time. + n := 0 + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if _, ok := dropped[pkt]; ok { continue } - ip := e.addIPHeader(r, &pkts[i].Header, pkts[i].DataSize, params) - pkts[i].NetworkHeader = buffer.View(ip) + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, *pkt); err != nil { + r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + return n, err + } + n++ } - n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) - return n, err + return n, nil } // WriteHeaderIncludedPacket writes a packet already containing a network diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 6d2d2c034..f91180aa3 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -79,7 +79,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P // Only the first view in vv is accounted for by h. To account for the // rest of vv, a shallow copy is made and the first view is removed. // This copy is used as extra payload during the checksum calculation. - payload := pkt.Data + payload := pkt.Data.Clone(nil) payload.RemoveFirst() if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want { received.Invalid.Increment() diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index b462b8604..a815b4d9b 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -143,19 +143,17 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) { +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) { if r.Loop&stack.PacketLoop != 0 { panic("not implemented") } if r.Loop&stack.PacketOut == 0 { - return len(pkts), nil + return pkts.Len(), nil } - for i := range pkts { - hdr := &pkts[i].Header - size := pkts[i].DataSize - ip := e.addIPHeader(r, hdr, size, params) - pkts[i].NetworkHeader = buffer.View(ip) + for pb := pkts.Front(); pb != nil; pb = pb.Next() { + ip := e.addIPHeader(r, &pb.Header, pb.Data.Size(), params) + pb.NetworkHeader = buffer.View(ip) } n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 8d80e9cee..5e963a4af 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -15,6 +15,18 @@ go_template_instance( }, ) +go_template_instance( + name = "packet_buffer_list", + out = "packet_buffer_list.go", + package = "stack", + prefix = "PacketBuffer", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*PacketBuffer", + "Linker": "*PacketBuffer", + }, +) + go_library( name = "stack", srcs = [ @@ -29,7 +41,7 @@ go_library( "ndp.go", "nic.go", "packet_buffer.go", - "packet_buffer_state.go", + "packet_buffer_list.go", "rand.go", "registration.go", "route.go", diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index c45c43d21..e9c652042 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -101,7 +101,7 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkH } // WritePackets implements LinkEndpoint.WritePackets. -func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts []PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) { +func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) { panic("not implemented") } @@ -260,10 +260,10 @@ func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.Netw } // WritePackets stores outbound packets into the channel. -func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts []PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { n := 0 - for _, pkt := range pkts { - e.WritePacket(r, gso, protocol, pkt) + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + e.WritePacket(r, gso, protocol, *pkt) n++ } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 37907ae24..6c0a4b24d 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -209,6 +209,23 @@ func (it *IPTables) Check(hook Hook, pkt PacketBuffer) bool { return true } +// CheckPackets runs pkts through the rules for hook and returns a map of packets that +// should not go forward. +// +// NOTE: unlike the Check API the returned map contains packets that should be +// dropped. +func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*PacketBuffer]struct{}) { + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if ok := it.Check(hook, *pkt); !ok { + if drop == nil { + drop = make(map[*PacketBuffer]struct{}) + } + drop[pkt] = struct{}{} + } + } + return drop +} + // Precondition: pkt.NetworkHeader is set. func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 598468bdd..27dc8baf9 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -468,7 +468,7 @@ func TestDADResolve(t *testing.T) { // As per RFC 4861 section 4.3, a possible option is the Source Link // Layer option, but this option MUST NOT be included when the source // address of the packet is the unspecified address. - checker.IPv6(t, p.Pkt.Header.View().ToVectorisedView().First(), + checker.IPv6(t, p.Pkt.Header.View(), checker.SrcAddr(header.IPv6Any), checker.DstAddr(snmc), checker.TTL(header.NDPHopLimit), diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 9367de180..dc125f25e 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -23,9 +23,11 @@ import ( // As a PacketBuffer traverses up the stack, it may be necessary to pass it to // multiple endpoints. Clone() should be called in such cases so that // modifications to the Data field do not affect other copies. -// -// +stateify savable type PacketBuffer struct { + // PacketBufferEntry is used to build an intrusive list of + // PacketBuffers. + PacketBufferEntry + // Data holds the payload of the packet. For inbound packets, it also // holds the headers, which are consumed as the packet moves up the // stack. Headers are guaranteed not to be split across views. @@ -34,14 +36,6 @@ type PacketBuffer struct { // or otherwise modified. Data buffer.VectorisedView - // DataOffset is used for GSO output. It is the offset into the Data - // field where the payload of this packet starts. - DataOffset int - - // DataSize is used for GSO output. It is the size of this packet's - // payload. - DataSize int - // Header holds the headers of outbound packets. As a packet is passed // down the stack, each layer adds to Header. Header buffer.Prependable diff --git a/pkg/tcpip/stack/packet_buffer_state.go b/pkg/tcpip/stack/packet_buffer_state.go deleted file mode 100644 index 0c6b7924c..000000000 --- a/pkg/tcpip/stack/packet_buffer_state.go +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack - -import "gvisor.dev/gvisor/pkg/tcpip/buffer" - -// beforeSave is invoked by stateify. -func (pk *PacketBuffer) beforeSave() { - // Non-Data fields may be slices of the Data field. This causes - // problems for SR, so during save we make each header independent. - pk.Header = pk.Header.DeepCopy() - pk.LinkHeader = append(buffer.View(nil), pk.LinkHeader...) - pk.NetworkHeader = append(buffer.View(nil), pk.NetworkHeader...) - pk.TransportHeader = append(buffer.View(nil), pk.TransportHeader...) -} diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index ac043b722..23ca9ee03 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -246,7 +246,7 @@ type NetworkEndpoint interface { // WritePackets writes packets to the given destination address and // protocol. pkts must not be zero length. - WritePackets(r *Route, gso *GSO, pkts []PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) + WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) // WriteHeaderIncludedPacket writes a packet that includes a network // header to the given destination address. @@ -393,7 +393,7 @@ type LinkEndpoint interface { // Right now, WritePackets is used only when the software segmentation // offload is enabled. If it will be used for something else, it may // require to change syscall filters. - WritePackets(r *Route, gso *GSO, pkts []PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) + WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) // WriteRawPacket writes a packet directly to the link. The packet // should already have an ethernet header. diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 9fbe8a411..a0e5e0300 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -168,23 +168,26 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt PacketBuff return err } -// WritePackets writes the set of packets through the given route. -func (r *Route) WritePackets(gso *GSO, pkts []PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) { +// WritePackets writes a list of n packets through the given route and returns +// the number of packets written. +func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) { if !r.ref.isValidForOutgoing() { return 0, tcpip.ErrInvalidEndpointState } n, err := r.ref.ep.WritePackets(r, gso, pkts, params) if err != nil { - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(len(pkts) - n)) + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) } r.ref.nic.stats.Tx.Packets.IncrementBy(uint64(n)) - payloadSize := 0 - for i := 0; i < n; i++ { - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(pkts[i].Header.UsedLength())) - payloadSize += pkts[i].DataSize + + writtenBytes := 0 + for i, pb := 0, pkts.Front(); i < n && pb != nil; i, pb = i+1, pb.Next() { + writtenBytes += pb.Header.UsedLength() + writtenBytes += pb.Data.Size() } - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(payloadSize)) + + r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes)) return n, err } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index b8543b71e..3f8a2a095 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -153,7 +153,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) { +func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) { panic("not implemented") } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 3239a5911..2ca3fb809 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -756,8 +756,7 @@ func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedV func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *stack.GSO) { optLen := len(tf.opts) hdr := &pkt.Header - packetSize := pkt.DataSize - off := pkt.DataOffset + packetSize := pkt.Data.Size() // Initialize the header. tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen)) pkt.TransportHeader = buffer.View(tcp) @@ -782,12 +781,18 @@ func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *sta // header and data and get the right sum of the TCP packet. tcp.SetChecksum(xsum) } else if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 { - xsum = header.ChecksumVVWithOffset(pkt.Data, xsum, off, packetSize) + xsum = header.ChecksumVV(pkt.Data, xsum) tcp.SetChecksum(^tcp.CalculateChecksum(xsum)) } } func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) *tcpip.Error { + // We need to shallow clone the VectorisedView here as ReadToView will + // split the VectorisedView and Trim underlying views as it splits. Not + // doing the clone here will cause the underlying views of data itself + // to be altered. + data = data.Clone(nil) + optLen := len(tf.opts) if tf.rcvWnd > 0xffff { tf.rcvWnd = 0xffff @@ -796,31 +801,25 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso mss := int(gso.MSS) n := (data.Size() + mss - 1) / mss - // Allocate one big slice for all the headers. - hdrSize := header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen - buf := make([]byte, n*hdrSize) - pkts := make([]stack.PacketBuffer, n) - for i := range pkts { - pkts[i].Header = buffer.NewEmptyPrependableFromView(buf[i*hdrSize:][:hdrSize]) - } - size := data.Size() - off := 0 + hdrSize := header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen + var pkts stack.PacketBufferList for i := 0; i < n; i++ { packetSize := mss if packetSize > size { packetSize = size } size -= packetSize - pkts[i].DataOffset = off - pkts[i].DataSize = packetSize - pkts[i].Data = data - pkts[i].Hash = tf.txHash - pkts[i].Owner = owner - buildTCPHdr(r, tf, &pkts[i], gso) - off += packetSize + var pkt stack.PacketBuffer + pkt.Header = buffer.NewPrependable(hdrSize) + pkt.Hash = tf.txHash + pkt.Owner = owner + data.ReadToVV(&pkt.Data, packetSize) + buildTCPHdr(r, tf, &pkt, gso) tf.seq = tf.seq.Add(seqnum.Size(packetSize)) + pkts.PushBack(&pkt) } + if tf.ttl == 0 { tf.ttl = r.DefaultTTL() } @@ -845,12 +844,10 @@ func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stac } pkt := stack.PacketBuffer{ - Header: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen), - DataOffset: 0, - DataSize: data.Size(), - Data: data, - Hash: tf.txHash, - Owner: owner, + Header: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen), + Data: data, + Hash: tf.txHash, + Owner: owner, } buildTCPHdr(r, tf, &pkt, gso) diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index e6fe7985d..40461fd31 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -77,9 +77,11 @@ func newSegmentFromView(r *stack.Route, id stack.TransportEndpointID, v buffer.V id: id, route: r.Clone(), } - s.views[0] = v - s.data = buffer.NewVectorisedView(len(v), s.views[:1]) s.rcvdTime = time.Now() + if len(v) != 0 { + s.views[0] = v + s.data = buffer.NewVectorisedView(len(v), s.views[:1]) + } return s } -- cgit v1.2.3 From 928a7c60b8f02811e9c0fcbed0077efd55471cc4 Mon Sep 17 00:00:00 2001 From: Adin Scannell Date: Wed, 25 Mar 2020 17:15:05 -0700 Subject: Fix all printf formatting errors. Updates #2243 --- pkg/eventchannel/event_test.go | 4 ++-- pkg/segment/test/segment_test.go | 2 +- pkg/sentry/fs/fdpipe/pipe_test.go | 4 ++-- pkg/sentry/fsimpl/tmpfs/stat_test.go | 8 ++++---- pkg/tcpip/header/eth_test.go | 2 +- pkg/tcpip/header/ndp_test.go | 2 +- pkg/tcpip/link/fdbased/endpoint.go | 2 +- pkg/tcpip/stack/ndp_test.go | 12 ++++++------ pkg/tcpip/stack/stack_test.go | 32 ++++++++++++++++---------------- pkg/tcpip/tcpip_test.go | 2 +- pkg/tcpip/transport/tcp/tcp_test.go | 4 ++-- pkg/tcpip/transport/udp/udp_test.go | 2 +- runsc/container/test_app/test_app.go | 2 +- test/root/cgroup_test.go | 4 ++-- test/runtimes/blacklist_test.go | 2 +- test/runtimes/runner.go | 2 +- tools/nogo.json | 31 ------------------------------- 17 files changed, 43 insertions(+), 74 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/eventchannel/event_test.go b/pkg/eventchannel/event_test.go index 7f41b4a27..43750360b 100644 --- a/pkg/eventchannel/event_test.go +++ b/pkg/eventchannel/event_test.go @@ -78,7 +78,7 @@ func TestMultiEmitter(t *testing.T) { for _, name := range names { m := testMessage{name: name} if _, err := me.Emit(m); err != nil { - t.Fatal("me.Emit(%v) failed: %v", m, err) + t.Fatalf("me.Emit(%v) failed: %v", m, err) } } @@ -96,7 +96,7 @@ func TestMultiEmitter(t *testing.T) { // Close multiEmitter. if err := me.Close(); err != nil { - t.Fatal("me.Close() failed: %v", err) + t.Fatalf("me.Close() failed: %v", err) } // All testEmitters should be closed. diff --git a/pkg/segment/test/segment_test.go b/pkg/segment/test/segment_test.go index f19a005f3..97b16c158 100644 --- a/pkg/segment/test/segment_test.go +++ b/pkg/segment/test/segment_test.go @@ -63,7 +63,7 @@ func checkSet(s *Set, expectedSegments int) error { return fmt.Errorf("incorrect order: key %d (segment %d) >= key %d (segment %d)", prev, nrSegments-1, next, nrSegments) } if got, want := seg.Value(), seg.Start()+valueOffset; got != want { - return fmt.Errorf("segment %d has key %d, value %d (expected %d)", nrSegments, seg.Start, got, want) + return fmt.Errorf("segment %d has key %d, value %d (expected %d)", nrSegments, seg.Start(), got, want) } prev = next havePrev = true diff --git a/pkg/sentry/fs/fdpipe/pipe_test.go b/pkg/sentry/fs/fdpipe/pipe_test.go index 5aff0cc95..a0082ecca 100644 --- a/pkg/sentry/fs/fdpipe/pipe_test.go +++ b/pkg/sentry/fs/fdpipe/pipe_test.go @@ -119,7 +119,7 @@ func TestNewPipe(t *testing.T) { continue } if flags := p.flags; test.flags != flags { - t.Errorf("%s: got file flags %s, want %s", test.desc, flags, test.flags) + t.Errorf("%s: got file flags %v, want %v", test.desc, flags, test.flags) continue } if len(test.readAheadBuffer) != len(p.readAheadBuffer) { @@ -136,7 +136,7 @@ func TestNewPipe(t *testing.T) { continue } if !fdnotifier.HasFD(int32(f.FD())) { - t.Errorf("%s: pipe fd %d is not registered for events", test.desc, f.FD) + t.Errorf("%s: pipe fd %d is not registered for events", test.desc, f.FD()) } } } diff --git a/pkg/sentry/fsimpl/tmpfs/stat_test.go b/pkg/sentry/fsimpl/tmpfs/stat_test.go index 3e02e7190..d4f59ee5b 100644 --- a/pkg/sentry/fsimpl/tmpfs/stat_test.go +++ b/pkg/sentry/fsimpl/tmpfs/stat_test.go @@ -140,7 +140,7 @@ func TestSetStatAtime(t *testing.T) { Mask: 0, Atime: linux.NsecToStatxTimestamp(100), }}); err != nil { - t.Errorf("SetStat atime without mask failed: %v") + t.Errorf("SetStat atime without mask failed: %v", err) } // Atime should be unchanged. if gotStat, err := fd.Stat(ctx, allStatOptions); err != nil { @@ -155,7 +155,7 @@ func TestSetStatAtime(t *testing.T) { Atime: linux.NsecToStatxTimestamp(100), } if err := fd.SetStat(ctx, vfs.SetStatOptions{Stat: setStat}); err != nil { - t.Errorf("SetStat atime with mask failed: %v") + t.Errorf("SetStat atime with mask failed: %v", err) } if gotStat, err := fd.Stat(ctx, allStatOptions); err != nil { t.Errorf("Stat got error: %v", err) @@ -205,7 +205,7 @@ func TestSetStat(t *testing.T) { Mask: 0, Atime: linux.NsecToStatxTimestamp(100), }}); err != nil { - t.Errorf("SetStat atime without mask failed: %v") + t.Errorf("SetStat atime without mask failed: %v", err) } // Atime should be unchanged. if gotStat, err := fd.Stat(ctx, allStatOptions); err != nil { @@ -220,7 +220,7 @@ func TestSetStat(t *testing.T) { Atime: linux.NsecToStatxTimestamp(100), } if err := fd.SetStat(ctx, vfs.SetStatOptions{Stat: setStat}); err != nil { - t.Errorf("SetStat atime with mask failed: %v") + t.Errorf("SetStat atime with mask failed: %v", err) } if gotStat, err := fd.Stat(ctx, allStatOptions); err != nil { t.Errorf("Stat got error: %v", err) diff --git a/pkg/tcpip/header/eth_test.go b/pkg/tcpip/header/eth_test.go index 7a0014ad9..14413f2ce 100644 --- a/pkg/tcpip/header/eth_test.go +++ b/pkg/tcpip/header/eth_test.go @@ -88,7 +88,7 @@ func TestEthernetAddressFromMulticastIPv4Address(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { if got := EthernetAddressFromMulticastIPv4Address(test.addr); got != test.expectedLinkAddr { - t.Fatalf("got EthernetAddressFromMulticastIPv4Address(%s) = %s, want = %s", got, test.expectedLinkAddr) + t.Fatalf("got EthernetAddressFromMulticastIPv4Address(%s) = %s, want = %s", test.addr, got, test.expectedLinkAddr) } }) } diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go index 1cb9f5dc8..969f8f1e8 100644 --- a/pkg/tcpip/header/ndp_test.go +++ b/pkg/tcpip/header/ndp_test.go @@ -115,7 +115,7 @@ func TestNDPNeighborAdvert(t *testing.T) { // Make sure flags got updated in the backing buffer. if got := b[ndpNAFlagsOffset]; got != 64 { - t.Errorf("got flags byte = %d, want = 64") + t.Errorf("got flags byte = %d, want = 64", got) } } diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index 7198742b7..b857ce9d0 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -91,7 +91,7 @@ func (p PacketDispatchMode) String() string { case PacketMMap: return "PacketMMap" default: - return fmt.Sprintf("unknown packet dispatch mode %v", p) + return fmt.Sprintf("unknown packet dispatch mode '%d'", p) } } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 27dc8baf9..bc822e3b1 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -1959,7 +1959,7 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) { // addr2 is deprecated but if explicitly requested, it should be used. fullAddr2 := tcpip.FullAddress{Addr: addr2.Address, NIC: nicID} if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr2.Address) + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address) } // Another PI w/ 0 preferred lifetime should not result in a deprecation @@ -1972,7 +1972,7 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) { } expectPrimaryAddr(addr1) if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr2.Address) + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address) } // Refresh lifetimes of addr generated from prefix2. @@ -2084,7 +2084,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { // addr1 is deprecated but if explicitly requested, it should be used. fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID} if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr1.Address) + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) } // Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make @@ -2097,7 +2097,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { } expectPrimaryAddr(addr2) if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr1.Address) + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) } // Refresh lifetimes for addr of prefix1. @@ -2121,7 +2121,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { // addr2 should be the primary endpoint now since it is not deprecated. expectPrimaryAddr(addr2) if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", got, addr1.Address) + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) } // Wait for addr of prefix1 to be invalidated. @@ -2564,7 +2564,7 @@ func TestAutoGenAddrAfterRemoval(t *testing.T) { AddressWithPrefix: addr2, } if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil { - t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d, %s) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err) + t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err) } // addr2 should be more preferred now since it is at the front of the primary // list. diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 3f8a2a095..c7634ceb1 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -1445,19 +1445,19 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Any, 0}} if err := s.AddProtocolAddress(1, protoAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %s) failed: %s", protoAddr, err) + t.Fatalf("AddProtocolAddress(1, %v) failed: %v", protoAddr, err) } r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) if err != nil { - t.Fatalf("FindRoute(1, %s, %s, %d) failed: %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) + t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) } if err := verifyRoute(r, stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil { - t.Errorf("FindRoute(1, %s, %s, %d) returned unexpected Route: %s)", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) + t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) } // If the NIC doesn't exist, it won't work. if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable { - t.Fatalf("got FindRoute(2, %s, %s, %d) = %s want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) + t.Fatalf("got FindRoute(2, %v, %v, %d) = %v want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable) } } @@ -1483,12 +1483,12 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { } nic1ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic1Addr} if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %s) failed: %s", nic1ProtoAddr, err) + t.Fatalf("AddProtocolAddress(1, %v) failed: %v", nic1ProtoAddr, err) } nic2ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic2Addr} if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil { - t.Fatalf("AddAddress(2, %s) failed: %s", nic2ProtoAddr, err) + t.Fatalf("AddAddress(2, %v) failed: %v", nic2ProtoAddr, err) } // Set the initial route table. @@ -1503,10 +1503,10 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { // When an interface is given, the route for a broadcast goes through it. r, err := s.FindRoute(1, nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */) if err != nil { - t.Fatalf("FindRoute(1, %s, %s, %d) failed: %s", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) + t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) } if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { - t.Errorf("FindRoute(1, %s, %s, %d) returned unexpected Route: %s)", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) + t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) } // When an interface is not given, it consults the route table. @@ -2399,7 +2399,7 @@ func TestNICContextPreservation(t *testing.T) { t.Fatalf("got nicinfos[%d] = _, %t, want _, true; nicinfos = %+v", id, ok, nicinfos) } if got, want := nicinfo.Context == test.want, true; got != want { - t.Fatal("got nicinfo.Context == ctx = %t, want %t; nicinfo.Context = %p, ctx = %p", got, want, nicinfo.Context, test.want) + t.Fatalf("got nicinfo.Context == ctx = %t, want %t; nicinfo.Context = %p, ctx = %p", got, want, nicinfo.Context, test.want) } }) } @@ -2768,7 +2768,7 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { { subnet, err := tcpip.NewSubnet("\x00", "\x00") if err != nil { - t.Fatalf("NewSubnet failed:", err) + t.Fatalf("NewSubnet failed: %v", err) } s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) } @@ -2782,11 +2782,11 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { // permanentExpired kind. r, err := s.FindRoute(1, "\x01", "\x02", fakeNetNumber, false) if err != nil { - t.Fatal("FindRoute failed:", err) + t.Fatalf("FindRoute failed: %v", err) } defer r.Release() if err := s.RemoveAddress(1, "\x01"); err != nil { - t.Fatalf("RemoveAddress failed:", err) + t.Fatalf("RemoveAddress failed: %v", err) } // @@ -2798,7 +2798,7 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { // Add some other address with peb set to // FirstPrimaryEndpoint. if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x03", stack.FirstPrimaryEndpoint); err != nil { - t.Fatal("AddAddressWithOptions failed:", err) + t.Fatalf("AddAddressWithOptions failed: %v", err) } @@ -2806,7 +2806,7 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { // make sure the new peb was respected. // (The address should just be promoted now). if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", ps); err != nil { - t.Fatal("AddAddressWithOptions failed:", err) + t.Fatalf("AddAddressWithOptions failed: %v", err) } var primaryAddrs []tcpip.Address for _, pa := range s.NICInfo()[1].ProtocolAddresses { @@ -2839,11 +2839,11 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { // GetMainNICAddress; else, our original address // should be returned. if err := s.RemoveAddress(1, "\x03"); err != nil { - t.Fatalf("RemoveAddress failed:", err) + t.Fatalf("RemoveAddress failed: %v", err) } addr, err = s.GetMainNICAddress(1, fakeNetNumber) if err != nil { - t.Fatal("s.GetMainNICAddress failed:", err) + t.Fatalf("s.GetMainNICAddress failed: %v", err) } if ps == stack.NeverPrimaryEndpoint { if want := (tcpip.AddressWithPrefix{}); addr != want { diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go index 8c0aacffa..1c8e2bc34 100644 --- a/pkg/tcpip/tcpip_test.go +++ b/pkg/tcpip/tcpip_test.go @@ -218,7 +218,7 @@ func TestAddressWithPrefixSubnet(t *testing.T) { gotSubnet := ap.Subnet() wantSubnet, err := NewSubnet(tt.subnetAddr, tt.subnetMask) if err != nil { - t.Error("NewSubnet(%q, %q) failed: %s", tt.subnetAddr, tt.subnetMask, err) + t.Errorf("NewSubnet(%q, %q) failed: %s", tt.subnetAddr, tt.subnetMask, err) continue } if gotSubnet != wantSubnet { diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index ce3df7478..8a87fa0a8 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -284,7 +284,7 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { // are released instantly on Close. tcpTW := tcpip.TCPTimeWaitTimeoutOption(1 * time.Millisecond) if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpTW); err != nil { - t.Fatalf("e.stack.SetTransportProtocolOption(%d, %s) = %s", tcp.ProtocolNumber, tcpTW, err) + t.Fatalf("e.stack.SetTransportProtocolOption(%d, %v) = %v", tcp.ProtocolNumber, tcpTW, err) } c.EP.Close() @@ -5609,7 +5609,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { return } if w := tcp.WindowSize(); w == 0 || w > uint16(wantRcvWnd) { - t.Errorf("expected a non-zero window: got %d, want <= wantRcvWnd", w, wantRcvWnd) + t.Errorf("expected a non-zero window: got %d, want <= wantRcvWnd", w) } }, )) diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 0905726c1..b3fe557ca 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -607,7 +607,7 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe // Check the peer address. h := flow.header4Tuple(incoming) if addr.Addr != h.srcAddr.Addr { - c.t.Fatalf("unexpected remote address: got %s, want %s", addr.Addr, h.srcAddr) + c.t.Fatalf("unexpected remote address: got %s, want %v", addr.Addr, h.srcAddr) } // Check the payload. diff --git a/runsc/container/test_app/test_app.go b/runsc/container/test_app/test_app.go index 01c47c79f..5f1c4b7d6 100644 --- a/runsc/container/test_app/test_app.go +++ b/runsc/container/test_app/test_app.go @@ -96,7 +96,7 @@ func (c *uds) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) listener, err := net.Listen("unix", c.socketPath) if err != nil { - log.Fatal("error listening on socket %q:", c.socketPath, err) + log.Fatalf("error listening on socket %q: %v", c.socketPath, err) } go server(listener, outputFile) diff --git a/test/root/cgroup_test.go b/test/root/cgroup_test.go index 4038661cb..679342def 100644 --- a/test/root/cgroup_test.go +++ b/test/root/cgroup_test.go @@ -53,7 +53,7 @@ func verifyPid(pid int, path string) error { if scanner.Err() != nil { return scanner.Err() } - return fmt.Errorf("got: %s, want: %d", gots, pid) + return fmt.Errorf("got: %v, want: %d", gots, pid) } // TestCgroup sets cgroup options and checks that cgroup was properly configured. @@ -106,7 +106,7 @@ func TestMemCGroup(t *testing.T) { time.Sleep(100 * time.Millisecond) } - t.Fatalf("%vMB is less than %vMB: %v", memUsage>>20, allocMemSize>>20) + t.Fatalf("%vMB is less than %vMB", memUsage>>20, allocMemSize>>20) } // TestCgroup sets cgroup options and checks that cgroup was properly configured. diff --git a/test/runtimes/blacklist_test.go b/test/runtimes/blacklist_test.go index 52f49b984..0ff69ab18 100644 --- a/test/runtimes/blacklist_test.go +++ b/test/runtimes/blacklist_test.go @@ -32,6 +32,6 @@ func TestBlacklists(t *testing.T) { t.Fatalf("error parsing blacklist: %v", err) } if *blacklistFile != "" && len(bl) == 0 { - t.Errorf("got empty blacklist for file %q", blacklistFile) + t.Errorf("got empty blacklist for file %q", *blacklistFile) } } diff --git a/test/runtimes/runner.go b/test/runtimes/runner.go index ddb890dbc..3c98f4570 100644 --- a/test/runtimes/runner.go +++ b/test/runtimes/runner.go @@ -114,7 +114,7 @@ func getTests(d dockerutil.Docker, blacklist map[string]struct{}) ([]testing.Int F: func(t *testing.T) { // Is the test blacklisted? if _, ok := blacklist[tc]; ok { - t.Skip("SKIP: blacklisted test %q", tc) + t.Skipf("SKIP: blacklisted test %q", tc) } var ( diff --git a/tools/nogo.json b/tools/nogo.json index cc05ba027..f69999e50 100644 --- a/tools/nogo.json +++ b/tools/nogo.json @@ -27,37 +27,6 @@ "/external/io_opencensus_go/tag/map_codec.go": "allowed: false positive" } }, - "printf": { - "exclude_files": { - ".*_abi_autogen_test.go": "fix: Sprintf format has insufficient args", - "/pkg/segment/test/segment_test.go": "fix: Errorf format %d arg seg.Start is a func value, not called", - "/pkg/tcpip/tcpip_test.go": "fix: Error call has possible formatting directive %q", - "/pkg/tcpip/header/eth_test.go": "fix: Fatalf format %s reads arg #3, but call has 2 args", - "/pkg/tcpip/header/ndp_test.go": "fix: Errorf format %d reads arg #1, but call has 0 args", - "/pkg/eventchannel/event_test.go": "fix: Fatal call has possible formatting directive %v", - "/pkg/tcpip/stack/ndp.go": "fix: Fatalf format %s has arg protocolAddr of wrong type gvisor.dev/gvisor/pkg/tcpip.ProtocolAddress", - "/pkg/sentry/fs/fdpipe/pipe_test.go": "fix: Errorf format %s has arg flags of wrong type gvisor.dev/gvisor/pkg/sentry/fs.FileFlags", - "/pkg/sentry/fs/fdpipe/pipe_test.go": "fix: Errorf format %d arg f.FD is a func value, not called", - "/pkg/tcpip/link/fdbased/endpoint.go": "fix: Sprintf format %v with arg p causes recursive String method call", - "/pkg/tcpip/transport/udp/udp_test.go": "fix: Fatalf format %s has arg h.srcAddr of wrong type gvisor.dev/gvisor/pkg/tcpip.FullAddress", - "/pkg/tcpip/transport/tcp/tcp_test.go": "fix: Fatalf format %s has arg tcpTW of wrong type gvisor.dev/gvisor/pkg/tcpip.TCPTimeWaitTimeoutOption", - "/pkg/tcpip/transport/tcp/tcp_test.go": "fix: Errorf call needs 1 arg but has 2 args", - "/pkg/tcpip/stack/ndp_test.go": "fix: Errorf format %s reads arg #3, but call has 2 args", - "/pkg/tcpip/stack/ndp_test.go": "fix: Fatalf format %s reads arg #5, but call has 4 args", - "/pkg/tcpip/stack/stack_test.go": "fix: Fatalf format %s has arg protoAddr of wrong type gvisor.dev/gvisor/pkg/tcpip.ProtocolAddress", - "/pkg/tcpip/stack/stack_test.go": "fix: Fatalf format %s has arg nic1ProtoAddr of wrong type gvisor.dev/gvisor/pkg/tcpip.ProtocolAddress", - "/pkg/tcpip/stack/stack_test.go": "fix: Fatalf format %s has arg nic2ProtoAddr of wrong type gvisor.dev/gvisor/pkg/tcpip.ProtocolAddress", - "/pkg/tcpip/stack/stack_test.go": "fix: Fatal call has possible formatting directive %t", - "/pkg/tcpip/stack/stack_test.go": "fix: Fatalf call has arguments but no formatting directives", - "/pkg/tcpip/link/fdbased/endpoint.go": "fix: Sprintf format %v with arg p causes recursive String method call", - "/pkg/sentry/fsimpl/tmpfs/stat_test.go": "fix: Errorf format %v reads arg #1, but call has 0 args", - "/runsc/container/test_app/test_app.go": "fix: Fatal call has possible formatting directive %q", - "/test/root/cgroup_test.go": "fix: Errorf format %s has arg gots of wrong type []int", - "/test/root/cgroup_test.go": "fix: Fatalf format %v reads arg #3, but call has 2 args", - "/test/runtimes/runner.go": "fix: Skip call has possible formatting directive %q", - "/test/runtimes/blacklist_test.go": "fix: Errorf format %q has arg blacklistFile of wrong type *string" - } - }, "structtag": { "exclude_files": { "/external/": "allowed: may use arbitrary tags" -- cgit v1.2.3 From 867eeb18d8c65019fb755194d5bdf768837f07a8 Mon Sep 17 00:00:00 2001 From: Adin Scannell Date: Wed, 25 Mar 2020 17:20:16 -0700 Subject: Remove lostcancel warnings. Updates #2243 --- pkg/tcpip/network/arp/arp_test.go | 3 ++- pkg/tcpip/stack/ndp_test.go | 6 ++++-- pkg/tcpip/transport/tcp/testing/context/context.go | 9 ++++++--- pkg/tcpip/transport/udp/udp_test.go | 15 ++++++++++----- tools/nogo.json | 8 -------- 5 files changed, 22 insertions(+), 19 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index b3e239ac7..1646d9cde 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -138,7 +138,8 @@ func TestDirectRequest(t *testing.T) { // Sleep tests are gross, but this will only potentially flake // if there's a bug. If there is no bug this will reliably // succeed. - ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() if pkt, ok := c.linkEP.ReadContext(ctx); ok { t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto) } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index bc822e3b1..acb2d4731 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -3483,7 +3483,8 @@ func TestRouterSolicitation(t *testing.T) { e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired waitForPkt := func(timeout time.Duration) { t.Helper() - ctx, _ := context.WithTimeout(context.Background(), timeout) + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() p, ok := e.ReadContext(ctx) if !ok { t.Fatal("timed out waiting for packet") @@ -3513,7 +3514,8 @@ func TestRouterSolicitation(t *testing.T) { } waitForNothing := func(timeout time.Duration) { t.Helper() - ctx, _ := context.WithTimeout(context.Background(), timeout) + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() if _, ok := e.ReadContext(ctx); ok { t.Fatal("unexpectedly got a packet") } diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index d4f6bc635..431ab4e6b 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -217,7 +217,8 @@ func (c *Context) Stack() *stack.Stack { func (c *Context) CheckNoPacketTimeout(errMsg string, wait time.Duration) { c.t.Helper() - ctx, _ := context.WithTimeout(context.Background(), wait) + ctx, cancel := context.WithTimeout(context.Background(), wait) + defer cancel() if _, ok := c.linkEP.ReadContext(ctx); ok { c.t.Fatal(errMsg) } @@ -235,7 +236,8 @@ func (c *Context) CheckNoPacket(errMsg string) { func (c *Context) GetPacket() []byte { c.t.Helper() - ctx, _ := context.WithTimeout(context.Background(), 2*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() p, ok := c.linkEP.ReadContext(ctx) if !ok { c.t.Fatalf("Packet wasn't written out") @@ -486,7 +488,8 @@ func (c *Context) CreateV6Endpoint(v6only bool) { func (c *Context) GetV6Packet() []byte { c.t.Helper() - ctx, _ := context.WithTimeout(context.Background(), 2*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() p, ok := c.linkEP.ReadContext(ctx) if !ok { c.t.Fatalf("Packet wasn't written out") diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index b3fe557ca..fd818243f 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -358,7 +358,8 @@ func (c *testContext) createEndpointForFlow(flow testFlow) { func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte { c.t.Helper() - ctx, _ := context.WithTimeout(context.Background(), 2*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() p, ok := c.linkEP.ReadContext(ctx) if !ok { c.t.Fatalf("Packet wasn't written out") @@ -1563,7 +1564,8 @@ func TestV4UnknownDestination(t *testing.T) { } c.injectPacket(tc.flow, payload) if !tc.icmpRequired { - ctx, _ := context.WithTimeout(context.Background(), time.Second) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() if p, ok := c.linkEP.ReadContext(ctx); ok { t.Fatalf("unexpected packet received: %+v", p) } @@ -1571,7 +1573,8 @@ func TestV4UnknownDestination(t *testing.T) { } // ICMP required. - ctx, _ := context.WithTimeout(context.Background(), time.Second) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() p, ok := c.linkEP.ReadContext(ctx) if !ok { t.Fatalf("packet wasn't written out") @@ -1639,7 +1642,8 @@ func TestV6UnknownDestination(t *testing.T) { } c.injectPacket(tc.flow, payload) if !tc.icmpRequired { - ctx, _ := context.WithTimeout(context.Background(), time.Second) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() if p, ok := c.linkEP.ReadContext(ctx); ok { t.Fatalf("unexpected packet received: %+v", p) } @@ -1647,7 +1651,8 @@ func TestV6UnknownDestination(t *testing.T) { } // ICMP required. - ctx, _ := context.WithTimeout(context.Background(), time.Second) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() p, ok := c.linkEP.ReadContext(ctx) if !ok { t.Fatalf("packet wasn't written out") diff --git a/tools/nogo.json b/tools/nogo.json index 09bda9212..85fac8604 100644 --- a/tools/nogo.json +++ b/tools/nogo.json @@ -9,14 +9,6 @@ "/external/": "allowed: not subject to unsafe naming rules" } }, - "lostcancel": { - "exclude_files": { - "/pkg/tcpip/network/arp/arp_test.go": "fix: the cancel function returned by context.WithTimeout should be called, not discarded, to avoid a context leak", - "/pkg/tcpip/stack/ndp_test.go": "fix: the cancel function returned by context.WithTimeout should be called, not discarded, to avoid a context leak", - "/pkg/tcpip/transport/udp/udp_test.go": "fix: the cancel function returned by context.WithTimeout should be called, not discarded, to avoid a context leak", - "/pkg/tcpip/transport/tcp/testing/context/context.go": "fix: the cancel function returned by context.WithTimeout should be called, not discarded, to avoid a context leak" - } - }, "nilness": { "exclude_files": { "/com_github_vishvananda_netlink/route_linux.go": "allowed: false positive", -- cgit v1.2.3 From 7928aa345e334f2c68f8f03b71d8cabe79e8db7e Mon Sep 17 00:00:00 2001 From: Andrei Vagin Date: Thu, 9 Apr 2020 09:30:39 -0700 Subject: Convert int and bool socket options to use GetSockOptInt and GetSockOptBool PiperOrigin-RevId: 305699233 --- pkg/sentry/socket/netstack/netstack.go | 155 +++++------- pkg/sentry/socket/unix/transport/BUILD | 1 + pkg/sentry/socket/unix/transport/unix.go | 50 ++-- pkg/tcpip/stack/transport_demuxer_test.go | 35 ++- pkg/tcpip/tcpip.go | 145 ++++++----- pkg/tcpip/transport/icmp/endpoint.go | 47 ++-- pkg/tcpip/transport/raw/endpoint.go | 18 +- pkg/tcpip/transport/tcp/endpoint.go | 387 +++++++++++++----------------- pkg/tcpip/transport/tcp/tcp_test.go | 110 +++++---- pkg/tcpip/transport/udp/endpoint.go | 231 +++++++++--------- pkg/tcpip/transport/udp/udp_test.go | 60 ++--- 11 files changed, 583 insertions(+), 656 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 5d0085462..20e3fa0d2 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -300,7 +300,7 @@ type SocketOperations struct { // New creates a new endpoint socket. func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) { if skType == linux.SOCK_STREAM { - if err := endpoint.SetSockOptInt(tcpip.DelayOption, 1); err != nil { + if err := endpoint.SetSockOptBool(tcpip.DelayOption, true); err != nil { return nil, syserr.TranslateNetstackError(err) } } @@ -965,6 +965,13 @@ func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family in return nil, syserr.ErrProtocolNotAvailable } +func boolToInt32(v bool) int32 { + if v { + return 1 + } + return 0 +} + // getSockOptSocket implements GetSockOpt when level is SOL_SOCKET. func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (interface{}, *syserr.Error) { // TODO(b/124056281): Stop rejecting short optLen values in getsockopt. @@ -998,12 +1005,11 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - var v tcpip.PasscredOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.PasscredOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + return boolToInt32(v), nil case linux.SO_SNDBUF: if outLen < sizeOfInt32 { @@ -1042,24 +1048,22 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - var v tcpip.ReuseAddressOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.ReuseAddressOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + return boolToInt32(v), nil case linux.SO_REUSEPORT: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var v tcpip.ReusePortOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.ReusePortOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + return boolToInt32(v), nil case linux.SO_BINDTODEVICE: var v tcpip.BindToDeviceOption @@ -1089,24 +1093,22 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - var v tcpip.BroadcastOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.BroadcastOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + return boolToInt32(v), nil case linux.SO_KEEPALIVE: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var v tcpip.KeepaliveEnabledOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.KeepaliveEnabledOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + return boolToInt32(v), nil case linux.SO_LINGER: if outLen < linux.SizeOfLinger { @@ -1156,47 +1158,41 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptInt(tcpip.DelayOption) + v, err := ep.GetSockOptBool(tcpip.DelayOption) if err != nil { return nil, syserr.TranslateNetstackError(err) } - - if v == 0 { - return int32(1), nil - } - return int32(0), nil + return boolToInt32(!v), nil case linux.TCP_CORK: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var v tcpip.CorkOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.CorkOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + return boolToInt32(v), nil case linux.TCP_QUICKACK: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var v tcpip.QuickAckOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.QuickAckOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + return boolToInt32(v), nil case linux.TCP_MAXSEG: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var v tcpip.MaxSegOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptInt(tcpip.MaxSegOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } @@ -1328,11 +1324,7 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf if err != nil { return nil, syserr.TranslateNetstackError(err) } - var o int32 - if v { - o = 1 - } - return o, nil + return boolToInt32(v), nil case linux.IPV6_PATHMTU: t.Kernel().EmitUnimplementedEvent(t) @@ -1342,8 +1334,8 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf if outLen == 0 { return make([]byte, 0), nil } - var v tcpip.IPv6TrafficClassOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptInt(tcpip.IPv6TrafficClassOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } @@ -1365,12 +1357,7 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf if err != nil { return nil, syserr.TranslateNetstackError(err) } - - var o int32 - if v { - o = 1 - } - return o, nil + return boolToInt32(v), nil default: emitUnimplementedEventIPv6(t, name) @@ -1386,8 +1373,8 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in return nil, syserr.ErrInvalidArgument } - var v tcpip.TTLOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptInt(tcpip.TTLOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } @@ -1403,8 +1390,8 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in return nil, syserr.ErrInvalidArgument } - var v tcpip.MulticastTTLOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptInt(tcpip.MulticastTTLOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } @@ -1429,23 +1416,19 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in return nil, syserr.ErrInvalidArgument } - var v tcpip.MulticastLoopOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.MulticastLoopOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - - if v { - return int32(1), nil - } - return int32(0), nil + return boolToInt32(v), nil case linux.IP_TOS: // Length handling for parity with Linux. if outLen == 0 { return []byte(nil), nil } - var v tcpip.IPv4TOSOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptInt(tcpip.IPv4TOSOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } if outLen < sizeOfInt32 { @@ -1462,11 +1445,7 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in if err != nil { return nil, syserr.TranslateNetstackError(err) } - var o int32 - if v { - o = 1 - } - return o, nil + return boolToInt32(v), nil case linux.IP_PKTINFO: if outLen < sizeOfInt32 { @@ -1477,11 +1456,7 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in if err != nil { return nil, syserr.TranslateNetstackError(err) } - var o int32 - if v { - o = 1 - } - return o, nil + return boolToInt32(v), nil default: emitUnimplementedEventIP(t, name) @@ -1592,7 +1567,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReuseAddressOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReuseAddressOption, v != 0)) case linux.SO_REUSEPORT: if len(optVal) < sizeOfInt32 { @@ -1600,7 +1575,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReusePortOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReusePortOption, v != 0)) case linux.SO_BINDTODEVICE: n := bytes.IndexByte(optVal, 0) @@ -1628,7 +1603,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BroadcastOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.BroadcastOption, v != 0)) case linux.SO_PASSCRED: if len(optVal) < sizeOfInt32 { @@ -1636,7 +1611,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.PasscredOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.PasscredOption, v != 0)) case linux.SO_KEEPALIVE: if len(optVal) < sizeOfInt32 { @@ -1644,7 +1619,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveEnabledOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.KeepaliveEnabledOption, v != 0)) case linux.SO_SNDTIMEO: if len(optVal) < linux.SizeOfTimeval { @@ -1716,11 +1691,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - var o int - if v == 0 { - o = 1 - } - return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.DelayOption, o)) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.DelayOption, v == 0)) case linux.TCP_CORK: if len(optVal) < sizeOfInt32 { @@ -1728,7 +1699,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.CorkOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.CorkOption, v != 0)) case linux.TCP_QUICKACK: if len(optVal) < sizeOfInt32 { @@ -1736,7 +1707,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.QuickAckOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.QuickAckOption, v != 0)) case linux.TCP_MAXSEG: if len(optVal) < sizeOfInt32 { @@ -1744,7 +1715,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.MaxSegOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.MaxSegOption, int(v))) case linux.TCP_KEEPIDLE: if len(optVal) < sizeOfInt32 { @@ -1855,7 +1826,7 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) if v == -1 { v = 0 } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv6TrafficClassOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.IPv6TrafficClassOption, int(v))) case linux.IPV6_RECVTCLASS: v, err := parseIntOrChar(optVal) @@ -1940,7 +1911,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s if v < 0 || v > 255 { return syserr.ErrInvalidArgument } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.MulticastTTLOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.MulticastTTLOption, int(v))) case linux.IP_ADD_MEMBERSHIP: req, err := copyInMulticastRequest(optVal, false /* allowAddr */) @@ -1987,9 +1958,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s return err } - return syserr.TranslateNetstackError(ep.SetSockOpt( - tcpip.MulticastLoopOption(v != 0), - )) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.MulticastLoopOption, v != 0)) case linux.MCAST_JOIN_GROUP: // FIXME(b/124219304): Implement MCAST_JOIN_GROUP. @@ -2008,7 +1977,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s } else if v < 1 || v > 255 { return syserr.ErrInvalidArgument } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TTLOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.TTLOption, int(v))) case linux.IP_TOS: if len(optVal) == 0 { @@ -2018,7 +1987,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s if err != nil { return err } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv4TOSOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.IPv4TOSOption, int(v))) case linux.IP_RECVTOS: v, err := parseIntOrChar(optVal) diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD index 74bcd6300..c708b6030 100644 --- a/pkg/sentry/socket/unix/transport/BUILD +++ b/pkg/sentry/socket/unix/transport/BUILD @@ -30,6 +30,7 @@ go_library( "//pkg/abi/linux", "//pkg/context", "//pkg/ilist", + "//pkg/log", "//pkg/refs", "//pkg/sync", "//pkg/syserr", diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 2ef654235..1f3880cc5 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" @@ -838,24 +839,45 @@ func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMess // SetSockOpt sets a socket option. Currently not supported. func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error { - switch v := opt.(type) { - case tcpip.PasscredOption: - e.setPasscred(v != 0) - return nil - } return nil } func (e *baseEndpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { + switch opt { + case tcpip.BroadcastOption: + case tcpip.PasscredOption: + e.setPasscred(v) + case tcpip.ReuseAddressOption: + default: + log.Warningf("Unsupported socket option: %d", opt) + return tcpip.ErrUnknownProtocolOption + } return nil } func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { + switch opt { + case tcpip.SendBufferSizeOption: + case tcpip.ReceiveBufferSizeOption: + default: + log.Warningf("Unsupported socket option: %d", opt) + return tcpip.ErrUnknownProtocolOption + } return nil } func (e *baseEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - return false, tcpip.ErrUnknownProtocolOption + switch opt { + case tcpip.KeepaliveEnabledOption: + return false, nil + + case tcpip.PasscredOption: + return e.Passcred(), nil + + default: + log.Warningf("Unsupported socket option: %d", opt) + return false, tcpip.ErrUnknownProtocolOption + } } func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { @@ -914,29 +936,19 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { return int(v), nil default: + log.Warningf("Unsupported socket option: %d", opt) return -1, tcpip.ErrUnknownProtocolOption } } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch o := opt.(type) { + switch opt.(type) { case tcpip.ErrorOption: return nil - case *tcpip.PasscredOption: - if e.Passcred() { - *o = tcpip.PasscredOption(1) - } else { - *o = tcpip.PasscredOption(0) - } - return nil - - case *tcpip.KeepaliveEnabledOption: - *o = 0 - return nil - default: + log.Warningf("Unsupported socket option: %T", opt) return tcpip.ErrUnknownProtocolOption } } diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index c65b0c632..2474a7db3 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -206,7 +206,7 @@ func TestTransportDemuxerRegister(t *testing.T) { // the distribution of packets received matches expectations. func TestBindToDeviceDistribution(t *testing.T) { type endpointSockopts struct { - reuse int + reuse bool bindToDevice tcpip.NICID } for _, test := range []struct { @@ -221,11 +221,11 @@ func TestBindToDeviceDistribution(t *testing.T) { "BindPortReuse", // 5 endpoints that all have reuse set. []endpointSockopts{ - {reuse: 1, bindToDevice: 0}, - {reuse: 1, bindToDevice: 0}, - {reuse: 1, bindToDevice: 0}, - {reuse: 1, bindToDevice: 0}, - {reuse: 1, bindToDevice: 0}, + {reuse: true, bindToDevice: 0}, + {reuse: true, bindToDevice: 0}, + {reuse: true, bindToDevice: 0}, + {reuse: true, bindToDevice: 0}, + {reuse: true, bindToDevice: 0}, }, map[tcpip.NICID][]float64{ // Injected packets on dev0 get distributed evenly. @@ -236,9 +236,9 @@ func TestBindToDeviceDistribution(t *testing.T) { "BindToDevice", // 3 endpoints with various bindings. []endpointSockopts{ - {reuse: 0, bindToDevice: 1}, - {reuse: 0, bindToDevice: 2}, - {reuse: 0, bindToDevice: 3}, + {reuse: false, bindToDevice: 1}, + {reuse: false, bindToDevice: 2}, + {reuse: false, bindToDevice: 3}, }, map[tcpip.NICID][]float64{ // Injected packets on dev0 go only to the endpoint bound to dev0. @@ -253,12 +253,12 @@ func TestBindToDeviceDistribution(t *testing.T) { "ReuseAndBindToDevice", // 6 endpoints with various bindings. []endpointSockopts{ - {reuse: 1, bindToDevice: 1}, - {reuse: 1, bindToDevice: 1}, - {reuse: 1, bindToDevice: 2}, - {reuse: 1, bindToDevice: 2}, - {reuse: 1, bindToDevice: 2}, - {reuse: 1, bindToDevice: 0}, + {reuse: true, bindToDevice: 1}, + {reuse: true, bindToDevice: 1}, + {reuse: true, bindToDevice: 2}, + {reuse: true, bindToDevice: 2}, + {reuse: true, bindToDevice: 2}, + {reuse: true, bindToDevice: 0}, }, map[tcpip.NICID][]float64{ // Injected packets on dev0 get distributed among endpoints bound to @@ -309,9 +309,8 @@ func TestBindToDeviceDistribution(t *testing.T) { }(ep) defer ep.Close() - reusePortOption := tcpip.ReusePortOption(endpoint.reuse) - if err := ep.SetSockOpt(reusePortOption); err != nil { - t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %s", reusePortOption, i, err) + if err := ep.SetSockOptBool(tcpip.ReusePortOption, endpoint.reuse); err != nil { + t.Fatalf("SetSockOptBool(ReusePortOption, %t) on endpoint %d failed: %s", endpoint.reuse, i, err) } bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice) if err := ep.SetSockOpt(bindToDeviceOption); err != nil { diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 2ef3271f1..aec7126ff 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -520,34 +520,90 @@ type WriteOptions struct { type SockOptBool int const ( + // BroadcastOption is used by SetSockOpt/GetSockOpt to specify whether + // datagram sockets are allowed to send packets to a broadcast address. + BroadcastOption SockOptBool = iota + + // CorkOption is used by SetSockOpt/GetSockOpt to specify if data should be + // held until segments are full by the TCP transport protocol. + CorkOption + + // DelayOption is used by SetSockOpt/GetSockOpt to specify if data + // should be sent out immediately by the transport protocol. For TCP, + // it determines if the Nagle algorithm is on or off. + DelayOption + + // KeepaliveEnabledOption is used by SetSockOpt/GetSockOpt to specify whether + // TCP keepalive is enabled for this socket. + KeepaliveEnabledOption + + // MulticastLoopOption is used by SetSockOpt/GetSockOpt to specify whether + // multicast packets sent over a non-loopback interface will be looped back. + MulticastLoopOption + + // PasscredOption is used by SetSockOpt/GetSockOpt to specify whether + // SCM_CREDENTIALS socket control messages are enabled. + // + // Only supported on Unix sockets. + PasscredOption + + // QuickAckOption is stubbed out in SetSockOpt/GetSockOpt. + QuickAckOption + // ReceiveTClassOption is used by SetSockOpt/GetSockOpt to specify if the // IPV6_TCLASS ancillary message is passed with incoming packets. - ReceiveTClassOption SockOptBool = iota + ReceiveTClassOption // ReceiveTOSOption is used by SetSockOpt/GetSockOpt to specify if the TOS // ancillary message is passed with incoming packets. ReceiveTOSOption - // V6OnlyOption is used by {G,S}etSockOptBool to specify whether an IPv6 - // socket is to be restricted to sending and receiving IPv6 packets only. - V6OnlyOption - // ReceiveIPPacketInfoOption is used by {G,S}etSockOptBool to specify // if more inforamtion is provided with incoming packets such // as interface index and address. ReceiveIPPacketInfoOption - // TODO(b/146901447): convert existing bool socket options to be handled via - // Get/SetSockOptBool + // ReuseAddressOption is used by SetSockOpt/GetSockOpt to specify whether Bind() + // should allow reuse of local address. + ReuseAddressOption + + // ReusePortOption is used by SetSockOpt/GetSockOpt to permit multiple sockets + // to be bound to an identical socket address. + ReusePortOption + + // V6OnlyOption is used by {G,S}etSockOptBool to specify whether an IPv6 + // socket is to be restricted to sending and receiving IPv6 packets only. + V6OnlyOption ) // SockOptInt represents socket options which values have the int type. type SockOptInt int const ( + // KeepaliveCountOption is used by SetSockOpt/GetSockOpt to specify the number + // of un-ACKed TCP keepalives that will be sent before the connection is + // closed. + KeepaliveCountOption SockOptInt = iota + + // IPv4TOSOption is used by SetSockOpt/GetSockOpt to specify TOS + // for all subsequent outgoing IPv4 packets from the endpoint. + IPv4TOSOption + + // IPv6TrafficClassOption is used by SetSockOpt/GetSockOpt to specify TOS + // for all subsequent outgoing IPv6 packets from the endpoint. + IPv6TrafficClassOption + + // MaxSegOption is used by SetSockOpt/GetSockOpt to set/get the current + // Maximum Segment Size(MSS) value as specified using the TCP_MAXSEG option. + MaxSegOption + + // MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default + // TTL value for multicast messages. The default is 1. + MulticastTTLOption + // ReceiveQueueSizeOption is used in GetSockOptInt to specify that the // number of unread bytes in the input buffer should be returned. - ReceiveQueueSizeOption SockOptInt = iota + ReceiveQueueSizeOption // SendBufferSizeOption is used by SetSockOptInt/GetSockOptInt to // specify the send buffer size option. @@ -561,44 +617,21 @@ const ( // number of unread bytes in the output buffer should be returned. SendQueueSizeOption - // DelayOption is used by SetSockOpt/GetSockOpt to specify if data - // should be sent out immediately by the transport protocol. For TCP, - // it determines if the Nagle algorithm is on or off. - DelayOption - - // TODO(b/137664753): convert all int socket options to be handled via - // GetSockOptInt. + // TTLOption is used by SetSockOpt/GetSockOpt to control the default TTL/hop + // limit value for unicast messages. The default is protocol specific. + // + // A zero value indicates the default. + TTLOption ) // ErrorOption is used in GetSockOpt to specify that the last error reported by // the endpoint should be cleared and returned. type ErrorOption struct{} -// CorkOption is used by SetSockOpt/GetSockOpt to specify if data should be -// held until segments are full by the TCP transport protocol. -type CorkOption int - -// ReuseAddressOption is used by SetSockOpt/GetSockOpt to specify whether Bind() -// should allow reuse of local address. -type ReuseAddressOption int - -// ReusePortOption is used by SetSockOpt/GetSockOpt to permit multiple sockets -// to be bound to an identical socket address. -type ReusePortOption int - // BindToDeviceOption is used by SetSockOpt/GetSockOpt to specify that sockets // should bind only on a specific NIC. type BindToDeviceOption NICID -// QuickAckOption is stubbed out in SetSockOpt/GetSockOpt. -type QuickAckOption int - -// PasscredOption is used by SetSockOpt/GetSockOpt to specify whether -// SCM_CREDENTIALS socket control messages are enabled. -// -// Only supported on Unix sockets. -type PasscredOption int - // TCPInfoOption is used by GetSockOpt to expose TCP statistics. // // TODO(b/64800844): Add and populate stat fields. @@ -607,10 +640,6 @@ type TCPInfoOption struct { RTTVar time.Duration } -// KeepaliveEnabledOption is used by SetSockOpt/GetSockOpt to specify whether -// TCP keepalive is enabled for this socket. -type KeepaliveEnabledOption int - // KeepaliveIdleOption is used by SetSockOpt/GetSockOpt to specify the time a // connection must remain idle before the first TCP keepalive packet is sent. // Once this time is reached, KeepaliveIntervalOption is used instead. @@ -620,11 +649,6 @@ type KeepaliveIdleOption time.Duration // interval between sending TCP keepalive packets. type KeepaliveIntervalOption time.Duration -// KeepaliveCountOption is used by SetSockOpt/GetSockOpt to specify the number -// of un-ACKed TCP keepalives that will be sent before the connection is -// closed. -type KeepaliveCountOption int - // TCPUserTimeoutOption is used by SetSockOpt/GetSockOpt to specify a user // specified timeout for a given TCP connection. // See: RFC5482 for details. @@ -638,20 +662,9 @@ type CongestionControlOption string // control algorithms. type AvailableCongestionControlOption string -// ModerateReceiveBufferOption allows the caller to enable/disable TCP receive // buffer moderation. type ModerateReceiveBufferOption bool -// MaxSegOption is used by SetSockOpt/GetSockOpt to set/get the current -// Maximum Segment Size(MSS) value as specified using the TCP_MAXSEG option. -type MaxSegOption int - -// TTLOption is used by SetSockOpt/GetSockOpt to control the default TTL/hop -// limit value for unicast messages. The default is protocol specific. -// -// A zero value indicates the default. -type TTLOption uint8 - // TCPLingerTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the // maximum duration for which a socket lingers in the TCP_FIN_WAIT_2 state // before being marked closed. @@ -668,10 +681,6 @@ type TCPTimeWaitTimeoutOption time.Duration // for a handshake till the specified timeout until a segment with data arrives. type TCPDeferAcceptOption time.Duration -// MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default -// TTL value for multicast messages. The default is 1. -type MulticastTTLOption uint8 - // MulticastInterfaceOption is used by SetSockOpt/GetSockOpt to specify a // default interface for multicast. type MulticastInterfaceOption struct { @@ -679,10 +688,6 @@ type MulticastInterfaceOption struct { InterfaceAddr Address } -// MulticastLoopOption is used by SetSockOpt/GetSockOpt to specify whether -// multicast packets sent over a non-loopback interface will be looped back. -type MulticastLoopOption bool - // MembershipOption is used by SetSockOpt/GetSockOpt as an argument to // AddMembershipOption and RemoveMembershipOption. type MembershipOption struct { @@ -705,22 +710,10 @@ type RemoveMembershipOption MembershipOption // TCP out-of-band data is delivered along with the normal in-band data. type OutOfBandInlineOption int -// BroadcastOption is used by SetSockOpt/GetSockOpt to specify whether -// datagram sockets are allowed to send packets to a broadcast address. -type BroadcastOption int - // DefaultTTLOption is used by stack.(*Stack).NetworkProtocolOption to specify // a default TTL. type DefaultTTLOption uint8 -// IPv4TOSOption is used by SetSockOpt/GetSockOpt to specify TOS -// for all subsequent outgoing IPv4 packets from the endpoint. -type IPv4TOSOption uint8 - -// IPv6TrafficClassOption is used by SetSockOpt/GetSockOpt to specify TOS -// for all subsequent outgoing IPv6 packets from the endpoint. -type IPv6TrafficClassOption uint8 - // IPPacketInfo is the message struture for IP_PKTINFO. // // +stateify savable diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index b007302fb..3a133eef9 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -348,29 +348,37 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { // SetSockOpt sets a socket option. func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - switch o := opt.(type) { - case tcpip.TTLOption: - e.mu.Lock() - e.ttl = uint8(o) - e.mu.Unlock() - } - return nil } // SetSockOptBool sets a socket option. Currently not supported. func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { - return nil + return tcpip.ErrUnknownProtocolOption } // SetSockOptInt sets a socket option. Currently not supported. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { + switch opt { + case tcpip.TTLOption: + e.mu.Lock() + e.ttl = uint8(v) + e.mu.Unlock() + + default: + return tcpip.ErrUnknownProtocolOption + } return nil } // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - return false, tcpip.ErrUnknownProtocolOption + switch opt { + case tcpip.KeepaliveEnabledOption: + return false, nil + + default: + return false, tcpip.ErrUnknownProtocolOption + } } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. @@ -397,26 +405,23 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { e.rcvMu.Unlock() return v, nil + case tcpip.TTLOption: + e.rcvMu.Lock() + v := int(e.ttl) + e.rcvMu.Unlock() + return v, nil + + default: + return -1, tcpip.ErrUnknownProtocolOption } - return -1, tcpip.ErrUnknownProtocolOption } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch o := opt.(type) { + switch opt.(type) { case tcpip.ErrorOption: return nil - case *tcpip.KeepaliveEnabledOption: - *o = 0 - return nil - - case *tcpip.TTLOption: - e.rcvMu.Lock() - *o = tcpip.TTLOption(e.ttl) - e.rcvMu.Unlock() - return nil - default: return tcpip.ErrUnknownProtocolOption } diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 337bc1c71..eee754a5a 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -533,14 +533,10 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch o := opt.(type) { + switch opt.(type) { case tcpip.ErrorOption: return nil - case *tcpip.KeepaliveEnabledOption: - *o = 0 - return nil - default: return tcpip.ErrUnknownProtocolOption } @@ -548,7 +544,13 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - return false, tcpip.ErrUnknownProtocolOption + switch opt { + case tcpip.KeepaliveEnabledOption: + return false, nil + + default: + return false, tcpip.ErrUnknownProtocolOption + } } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. @@ -576,9 +578,9 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { e.rcvMu.Unlock() return v, nil + default: + return -1, tcpip.ErrUnknownProtocolOption } - - return -1, tcpip.ErrUnknownProtocolOption } // HandlePacket implements stack.RawTransportEndpoint.HandlePacket. diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 9b123e968..a8d443f73 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -821,7 +821,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue var de DelayEnabled if err := s.TransportProtocolOption(ProtocolNumber, &de); err == nil && de { - e.SetSockOptInt(tcpip.DelayOption, 1) + e.SetSockOptBool(tcpip.DelayOption, true) } var tcpLT tcpip.TCPLingerTimeoutOption @@ -1409,10 +1409,60 @@ func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed boo // SetSockOptBool sets a socket option. func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { - e.LockUser() - defer e.UnlockUser() - switch opt { + + case tcpip.BroadcastOption: + e.LockUser() + e.broadcast = v + e.UnlockUser() + + case tcpip.CorkOption: + e.LockUser() + if !v { + atomic.StoreUint32(&e.cork, 0) + + // Handle the corked data. + e.sndWaker.Assert() + } else { + atomic.StoreUint32(&e.cork, 1) + } + e.UnlockUser() + + case tcpip.DelayOption: + if v { + atomic.StoreUint32(&e.delay, 1) + } else { + atomic.StoreUint32(&e.delay, 0) + + // Handle delayed data. + e.sndWaker.Assert() + } + + case tcpip.KeepaliveEnabledOption: + e.keepalive.Lock() + e.keepalive.enabled = v + e.keepalive.Unlock() + e.notifyProtocolGoroutine(notifyKeepaliveChanged) + + case tcpip.QuickAckOption: + o := uint32(1) + if v { + o = 0 + } + atomic.StoreUint32(&e.slowAck, o) + + case tcpip.ReuseAddressOption: + e.LockUser() + e.reuseAddr = v + e.UnlockUser() + return nil + + case tcpip.ReusePortOption: + e.LockUser() + e.reusePort = v + e.UnlockUser() + return nil + case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. if e.NetProto != header.IPv6ProtocolNumber { @@ -1424,7 +1474,11 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { return tcpip.ErrInvalidEndpointState } + e.LockUser() e.v6only = v + e.UnlockUser() + default: + return tcpip.ErrUnknownProtocolOption } return nil @@ -1432,7 +1486,40 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { // SetSockOptInt sets a socket option. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { + // Lower 2 bits represents ECN bits. RFC 3168, section 23.1 + const inetECNMask = 3 + switch opt { + case tcpip.KeepaliveCountOption: + e.keepalive.Lock() + e.keepalive.count = int(v) + e.keepalive.Unlock() + e.notifyProtocolGoroutine(notifyKeepaliveChanged) + + case tcpip.IPv4TOSOption: + e.LockUser() + // TODO(gvisor.dev/issue/995): ECN is not currently supported, + // ignore the bits for now. + e.sendTOS = uint8(v) & ^uint8(inetECNMask) + e.UnlockUser() + + case tcpip.IPv6TrafficClassOption: + e.LockUser() + // TODO(gvisor.dev/issue/995): ECN is not currently supported, + // ignore the bits for now. + e.sendTOS = uint8(v) & ^uint8(inetECNMask) + e.UnlockUser() + + case tcpip.MaxSegOption: + userMSS := v + if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS { + return tcpip.ErrInvalidOptionValue + } + e.LockUser() + e.userMSS = uint16(userMSS) + e.UnlockUser() + e.notifyProtocolGoroutine(notifyMSSChanged) + case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. @@ -1483,7 +1570,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { e.rcvListMu.Unlock() e.UnlockUser() e.notifyProtocolGoroutine(mask) - return nil case tcpip.SendBufferSizeOption: // Make sure the send buffer size is within the min and max @@ -1502,52 +1588,21 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { e.sndBufMu.Lock() e.sndBufSize = size e.sndBufMu.Unlock() - return nil - case tcpip.DelayOption: - if v == 0 { - atomic.StoreUint32(&e.delay, 0) - - // Handle delayed data. - e.sndWaker.Assert() - } else { - atomic.StoreUint32(&e.delay, 1) - } - return nil + case tcpip.TTLOption: + e.LockUser() + e.ttl = uint8(v) + e.UnlockUser() default: - return nil + return tcpip.ErrUnknownProtocolOption } + return nil } // SetSockOpt sets a socket option. func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - // Lower 2 bits represents ECN bits. RFC 3168, section 23.1 - const inetECNMask = 3 switch v := opt.(type) { - case tcpip.CorkOption: - if v == 0 { - atomic.StoreUint32(&e.cork, 0) - - // Handle the corked data. - e.sndWaker.Assert() - } else { - atomic.StoreUint32(&e.cork, 1) - } - return nil - - case tcpip.ReuseAddressOption: - e.LockUser() - e.reuseAddr = v != 0 - e.UnlockUser() - return nil - - case tcpip.ReusePortOption: - e.LockUser() - e.reusePort = v != 0 - e.UnlockUser() - return nil - case tcpip.BindToDeviceOption: id := tcpip.NICID(v) if id != 0 && !e.stack.HasNIC(id) { @@ -1556,72 +1611,26 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.LockUser() e.bindToDevice = id e.UnlockUser() - return nil - - case tcpip.QuickAckOption: - if v == 0 { - atomic.StoreUint32(&e.slowAck, 1) - } else { - atomic.StoreUint32(&e.slowAck, 0) - } - return nil - - case tcpip.MaxSegOption: - userMSS := v - if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS { - return tcpip.ErrInvalidOptionValue - } - e.LockUser() - e.userMSS = uint16(userMSS) - e.UnlockUser() - e.notifyProtocolGoroutine(notifyMSSChanged) - return nil - - case tcpip.TTLOption: - e.LockUser() - e.ttl = uint8(v) - e.UnlockUser() - return nil - - case tcpip.KeepaliveEnabledOption: - e.keepalive.Lock() - e.keepalive.enabled = v != 0 - e.keepalive.Unlock() - e.notifyProtocolGoroutine(notifyKeepaliveChanged) - return nil case tcpip.KeepaliveIdleOption: e.keepalive.Lock() e.keepalive.idle = time.Duration(v) e.keepalive.Unlock() e.notifyProtocolGoroutine(notifyKeepaliveChanged) - return nil case tcpip.KeepaliveIntervalOption: e.keepalive.Lock() e.keepalive.interval = time.Duration(v) e.keepalive.Unlock() e.notifyProtocolGoroutine(notifyKeepaliveChanged) - return nil - case tcpip.KeepaliveCountOption: - e.keepalive.Lock() - e.keepalive.count = int(v) - e.keepalive.Unlock() - e.notifyProtocolGoroutine(notifyKeepaliveChanged) - return nil + case tcpip.OutOfBandInlineOption: + // We don't currently support disabling this option. case tcpip.TCPUserTimeoutOption: e.LockUser() e.userTimeout = time.Duration(v) e.UnlockUser() - return nil - - case tcpip.BroadcastOption: - e.LockUser() - e.broadcast = v != 0 - e.UnlockUser() - return nil case tcpip.CongestionControlOption: // Query the available cc algorithms in the stack and @@ -1652,22 +1661,6 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { // control algorithm is specified. return tcpip.ErrNoSuchFile - case tcpip.IPv4TOSOption: - e.LockUser() - // TODO(gvisor.dev/issue/995): ECN is not currently supported, - // ignore the bits for now. - e.sendTOS = uint8(v) & ^uint8(inetECNMask) - e.UnlockUser() - return nil - - case tcpip.IPv6TrafficClassOption: - e.LockUser() - // TODO(gvisor.dev/issue/995): ECN is not currently supported, - // ignore the bits for now. - e.sendTOS = uint8(v) & ^uint8(inetECNMask) - e.UnlockUser() - return nil - case tcpip.TCPLingerTimeoutOption: e.LockUser() if v < 0 { @@ -1688,7 +1681,6 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } e.tcpLingerTimeout = time.Duration(v) e.UnlockUser() - return nil case tcpip.TCPDeferAcceptOption: e.LockUser() @@ -1697,11 +1689,11 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } e.deferAccept = time.Duration(v) e.UnlockUser() - return nil default: return nil } + return nil } // readyReceiveSize returns the number of bytes ready to be received. @@ -1723,6 +1715,43 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) { // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { switch opt { + case tcpip.BroadcastOption: + e.LockUser() + v := e.broadcast + e.UnlockUser() + return v, nil + + case tcpip.CorkOption: + return atomic.LoadUint32(&e.cork) != 0, nil + + case tcpip.DelayOption: + return atomic.LoadUint32(&e.delay) != 0, nil + + case tcpip.KeepaliveEnabledOption: + e.keepalive.Lock() + v := e.keepalive.enabled + e.keepalive.Unlock() + + return v, nil + + case tcpip.QuickAckOption: + v := atomic.LoadUint32(&e.slowAck) == 0 + return v, nil + + case tcpip.ReuseAddressOption: + e.LockUser() + v := e.reuseAddr + e.UnlockUser() + + return v, nil + + case tcpip.ReusePortOption: + e.LockUser() + v := e.reusePort + e.UnlockUser() + + return v, nil + case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. if e.NetProto != header.IPv6ProtocolNumber { @@ -1734,14 +1763,41 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { e.UnlockUser() return v, nil - } - return false, tcpip.ErrUnknownProtocolOption + default: + return false, tcpip.ErrUnknownProtocolOption + } } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { + case tcpip.KeepaliveCountOption: + e.keepalive.Lock() + v := e.keepalive.count + e.keepalive.Unlock() + return v, nil + + case tcpip.IPv4TOSOption: + e.LockUser() + v := int(e.sendTOS) + e.UnlockUser() + return v, nil + + case tcpip.IPv6TrafficClassOption: + e.LockUser() + v := int(e.sendTOS) + e.UnlockUser() + return v, nil + + case tcpip.MaxSegOption: + // This is just stubbed out. Linux never returns the user_mss + // value as it either returns the defaultMSS or returns the + // actual current MSS. Netstack just returns the defaultMSS + // always for now. + v := header.TCPDefaultMSS + return v, nil + case tcpip.ReceiveQueueSizeOption: return e.readyReceiveSize() @@ -1757,12 +1813,11 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { e.rcvListMu.Unlock() return v, nil - case tcpip.DelayOption: - var o int - if v := atomic.LoadUint32(&e.delay); v != 0 { - o = 1 - } - return o, nil + case tcpip.TTLOption: + e.LockUser() + v := int(e.ttl) + e.UnlockUser() + return v, nil default: return -1, tcpip.ErrUnknownProtocolOption @@ -1779,61 +1834,10 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.lastErrorMu.Unlock() return err - case *tcpip.MaxSegOption: - // This is just stubbed out. Linux never returns the user_mss - // value as it either returns the defaultMSS or returns the - // actual current MSS. Netstack just returns the defaultMSS - // always for now. - *o = header.TCPDefaultMSS - return nil - - case *tcpip.CorkOption: - *o = 0 - if v := atomic.LoadUint32(&e.cork); v != 0 { - *o = 1 - } - return nil - - case *tcpip.ReuseAddressOption: - e.LockUser() - v := e.reuseAddr - e.UnlockUser() - - *o = 0 - if v { - *o = 1 - } - return nil - - case *tcpip.ReusePortOption: - e.LockUser() - v := e.reusePort - e.UnlockUser() - - *o = 0 - if v { - *o = 1 - } - return nil - case *tcpip.BindToDeviceOption: e.LockUser() *o = tcpip.BindToDeviceOption(e.bindToDevice) e.UnlockUser() - return nil - - case *tcpip.QuickAckOption: - *o = 1 - if v := atomic.LoadUint32(&e.slowAck); v != 0 { - *o = 0 - } - return nil - - case *tcpip.TTLOption: - e.LockUser() - *o = tcpip.TTLOption(e.ttl) - e.UnlockUser() - return nil case *tcpip.TCPInfoOption: *o = tcpip.TCPInfoOption{} @@ -1846,92 +1850,45 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { o.RTTVar = snd.rtt.rttvar snd.rtt.Unlock() } - return nil - - case *tcpip.KeepaliveEnabledOption: - e.keepalive.Lock() - v := e.keepalive.enabled - e.keepalive.Unlock() - - *o = 0 - if v { - *o = 1 - } - return nil case *tcpip.KeepaliveIdleOption: e.keepalive.Lock() *o = tcpip.KeepaliveIdleOption(e.keepalive.idle) e.keepalive.Unlock() - return nil case *tcpip.KeepaliveIntervalOption: e.keepalive.Lock() *o = tcpip.KeepaliveIntervalOption(e.keepalive.interval) e.keepalive.Unlock() - return nil - - case *tcpip.KeepaliveCountOption: - e.keepalive.Lock() - *o = tcpip.KeepaliveCountOption(e.keepalive.count) - e.keepalive.Unlock() - return nil case *tcpip.TCPUserTimeoutOption: e.LockUser() *o = tcpip.TCPUserTimeoutOption(e.userTimeout) e.UnlockUser() - return nil case *tcpip.OutOfBandInlineOption: // We don't currently support disabling this option. *o = 1 - return nil - - case *tcpip.BroadcastOption: - e.LockUser() - v := e.broadcast - e.UnlockUser() - - *o = 0 - if v { - *o = 1 - } - return nil case *tcpip.CongestionControlOption: e.LockUser() *o = e.cc e.UnlockUser() - return nil - - case *tcpip.IPv4TOSOption: - e.LockUser() - *o = tcpip.IPv4TOSOption(e.sendTOS) - e.UnlockUser() - return nil - - case *tcpip.IPv6TrafficClassOption: - e.LockUser() - *o = tcpip.IPv6TrafficClassOption(e.sendTOS) - e.UnlockUser() - return nil case *tcpip.TCPLingerTimeoutOption: e.LockUser() *o = tcpip.TCPLingerTimeoutOption(e.tcpLingerTimeout) e.UnlockUser() - return nil case *tcpip.TCPDeferAcceptOption: e.LockUser() *o = tcpip.TCPDeferAcceptOption(e.deferAccept) e.UnlockUser() - return nil default: return tcpip.ErrUnknownProtocolOption } + return nil } // checkV4MappedLocked determines the effective network protocol and converts diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index ce3df7478..32d0af6c4 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -728,7 +728,7 @@ func TestUserSuppliedMSSOnConnectV4(t *testing.T) { const maxMSS = mtu - header.IPv4MinimumSize - header.TCPMinimumSize tests := []struct { name string - setMSS uint16 + setMSS int expMSS uint16 }{ { @@ -756,15 +756,14 @@ func TestUserSuppliedMSSOnConnectV4(t *testing.T) { c.Create(-1) // Set the MSS socket option. - opt := tcpip.MaxSegOption(test.setMSS) - if err := c.EP.SetSockOpt(opt); err != nil { - t.Fatalf("SetSockOpt(%#v) failed: %s", opt, err) + if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, test.setMSS); err != nil { + t.Fatalf("SetSockOptInt(MaxSegOption, %d) failed: %s", test.setMSS, err) } // Get expected window size. rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption) if err != nil { - t.Fatalf("GetSockOpt(%v) failed: %s", tcpip.ReceiveBufferSizeOption, err) + t.Fatalf("GetSockOptInt(ReceiveBufferSizeOption) failed: %s", err) } ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) @@ -818,15 +817,14 @@ func TestUserSuppliedMSSOnConnectV6(t *testing.T) { c.CreateV6Endpoint(true) // Set the MSS socket option. - opt := tcpip.MaxSegOption(test.setMSS) - if err := c.EP.SetSockOpt(opt); err != nil { - t.Fatalf("SetSockOpt(%#v) failed: %s", opt, err) + if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil { + t.Fatalf("SetSockOptInt(MaxSegOption, %d) failed: %s", test.setMSS, err) } // Get expected window size. rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption) if err != nil { - t.Fatalf("GetSockOpt(%v) failed: %s", tcpip.ReceiveBufferSizeOption, err) + t.Fatalf("GetSockOptInt(ReceiveBufferSizeOption) failed: %s", err) } ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) @@ -1077,17 +1075,17 @@ func TestTOSV4(t *testing.T) { c.EP = ep const tos = 0xC0 - if err := c.EP.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil { - t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv4TOSOption(tos), err) + if err := c.EP.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil { + t.Errorf("SetSockOptInt(IPv4TOSOption, %d) failed: %s", tos, err) } - var v tcpip.IPv4TOSOption - if err := c.EP.GetSockOpt(&v); err != nil { - t.Errorf("GetSockopt failed: %s", err) + v, err := c.EP.GetSockOptInt(tcpip.IPv4TOSOption) + if err != nil { + t.Errorf("GetSockoptInt(IPv4TOSOption) failed: %s", err) } - if want := tcpip.IPv4TOSOption(tos); v != want { - t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want) + if v != tos { + t.Errorf("got GetSockOptInt(IPv4TOSOption) = %d, want = %d", v, tos) } testV4Connect(t, c, checker.TOS(tos, 0)) @@ -1125,17 +1123,17 @@ func TestTrafficClassV6(t *testing.T) { c.CreateV6Endpoint(false) const tos = 0xC0 - if err := c.EP.SetSockOpt(tcpip.IPv6TrafficClassOption(tos)); err != nil { - t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv6TrafficClassOption(tos), err) + if err := c.EP.SetSockOptInt(tcpip.IPv6TrafficClassOption, tos); err != nil { + t.Errorf("SetSockOpInt(IPv6TrafficClassOption, %d) failed: %s", tos, err) } - var v tcpip.IPv6TrafficClassOption - if err := c.EP.GetSockOpt(&v); err != nil { - t.Fatalf("GetSockopt failed: %s", err) + v, err := c.EP.GetSockOptInt(tcpip.IPv6TrafficClassOption) + if err != nil { + t.Fatalf("GetSockoptInt(IPv6TrafficClassOption) failed: %s", err) } - if want := tcpip.IPv6TrafficClassOption(tos); v != want { - t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want) + if v != tos { + t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = %d, want = %d", v, tos) } // Test the connection request. @@ -1711,7 +1709,7 @@ func TestNoWindowShrinking(t *testing.T) { c.CreateConnected(789, 30000, 10) if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 5); err != nil { - t.Fatalf("SetSockOpt failed: %v", err) + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %v", err) } we, ch := waiter.NewChannelEntry(nil) @@ -1984,7 +1982,7 @@ func TestScaledWindowAccept(t *testing.T) { // Set the window size greater than the maximum non-scaled window. if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil { - t.Fatalf("SetSockOpt failed failed: %v", err) + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %v", err) } if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { @@ -2057,7 +2055,7 @@ func TestNonScaledWindowAccept(t *testing.T) { // Set the window size greater than the maximum non-scaled window. if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil { - t.Fatalf("SetSockOpt failed failed: %v", err) + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %v", err) } if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { @@ -2221,10 +2219,10 @@ func TestSegmentMerging(t *testing.T) { { "cork", func(ep tcpip.Endpoint) { - ep.SetSockOpt(tcpip.CorkOption(1)) + ep.SetSockOptBool(tcpip.CorkOption, true) }, func(ep tcpip.Endpoint) { - ep.SetSockOpt(tcpip.CorkOption(0)) + ep.SetSockOptBool(tcpip.CorkOption, false) }, }, } @@ -2316,7 +2314,7 @@ func TestDelay(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - c.EP.SetSockOptInt(tcpip.DelayOption, 1) + c.EP.SetSockOptBool(tcpip.DelayOption, true) var allData []byte for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { @@ -2364,7 +2362,7 @@ func TestUndelay(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - c.EP.SetSockOptInt(tcpip.DelayOption, 1) + c.EP.SetSockOptBool(tcpip.DelayOption, true) allData := [][]byte{{0}, {1, 2, 3}} for i, data := range allData { @@ -2397,7 +2395,7 @@ func TestUndelay(t *testing.T) { // Check that we don't get the second packet yet. c.CheckNoPacketTimeout("delayed second packet transmitted", 100*time.Millisecond) - c.EP.SetSockOptInt(tcpip.DelayOption, 0) + c.EP.SetSockOptBool(tcpip.DelayOption, false) // Check that data is received. second := c.GetPacket() @@ -2434,8 +2432,8 @@ func TestMSSNotDelayed(t *testing.T) { fn func(tcpip.Endpoint) }{ {"no-op", func(tcpip.Endpoint) {}}, - {"delay", func(ep tcpip.Endpoint) { ep.SetSockOptInt(tcpip.DelayOption, 1) }}, - {"cork", func(ep tcpip.Endpoint) { ep.SetSockOpt(tcpip.CorkOption(1)) }}, + {"delay", func(ep tcpip.Endpoint) { ep.SetSockOptBool(tcpip.DelayOption, true) }}, + {"cork", func(ep tcpip.Endpoint) { ep.SetSockOptBool(tcpip.CorkOption, true) }}, } for _, test := range tests { @@ -2576,12 +2574,12 @@ func TestSetTTL(t *testing.T) { t.Fatalf("NewEndpoint failed: %v", err) } - if err := c.EP.SetSockOpt(tcpip.TTLOption(wantTTL)); err != nil { - t.Fatalf("SetSockOpt failed: %v", err) + if err := c.EP.SetSockOptInt(tcpip.TTLOption, int(wantTTL)); err != nil { + t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err) } if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { - t.Fatalf("Unexpected return value from Connect: %v", err) + t.Fatalf("Unexpected return value from Connect: %s", err) } // Receive SYN packet. @@ -2621,7 +2619,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) { // window scaling option. const rcvBufferSize = 0x20000 if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil { - t.Fatalf("SetSockOpt failed failed: %v", err) + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed failed: %s", rcvBufferSize, err) } if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { @@ -2765,7 +2763,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { const rcvBufferSize = 0x20000 const wndScale = 2 if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil { - t.Fatalf("SetSockOpt failed failed: %v", err) + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed failed: %s", rcvBufferSize, err) } // Start connection attempt. @@ -3882,26 +3880,26 @@ func TestMinMaxBufferSizes(t *testing.T) { // Set values below the min. if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 199); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 199) failed: %s", err) } checkRecvBufferSize(t, ep, 200) if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 299); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) + t.Fatalf("SetSockOptInt(SendBufferSizeOption, 299) failed: %s", err) } checkSendBufferSize(t, ep, 300) // Set values above the max. if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 1+tcp.DefaultReceiveBufferSize*20); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption) failed: %s", err) } checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20) if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1+tcp.DefaultSendBufferSize*30); err != nil { - t.Fatalf("GetSockOpt failed: %v", err) + t.Fatalf("SetSockOptInt(SendBufferSizeOption) failed: %s", err) } checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30) @@ -4147,11 +4145,11 @@ func TestConnectAvoidsBoundPorts(t *testing.T) { case "ipv4": case "ipv6": if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { - t.Fatalf("SetSockOpt(V6OnlyOption(true)) failed: %v", err) + t.Fatalf("SetSockOptBool(V6OnlyOption(true)) failed: %s", err) } case "dual": if err := ep.SetSockOptBool(tcpip.V6OnlyOption, false); err != nil { - t.Fatalf("SetSockOpt(V6OnlyOption(false)) failed: %v", err) + t.Fatalf("SetSockOptBool(V6OnlyOption(false)) failed: %s", err) } default: t.Fatalf("unknown network: '%s'", network) @@ -4477,8 +4475,8 @@ func TestKeepalive(t *testing.T) { const keepAliveInterval = 10 * time.Millisecond c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond)) c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval)) - c.EP.SetSockOpt(tcpip.KeepaliveCountOption(5)) - c.EP.SetSockOpt(tcpip.KeepaliveEnabledOption(1)) + c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5) + c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true) // 5 unacked keepalives are sent. ACK each one, and check that the // connection stays alive after 5. @@ -5770,14 +5768,14 @@ func TestReceiveBufferAutoTuning(t *testing.T) { func TestDelayEnabled(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - checkDelayOption(t, c, false, 0) // Delay is disabled by default. + checkDelayOption(t, c, false, false) // Delay is disabled by default. for _, v := range []struct { delayEnabled tcp.DelayEnabled - wantDelayOption int + wantDelayOption bool }{ - {delayEnabled: false, wantDelayOption: 0}, - {delayEnabled: true, wantDelayOption: 1}, + {delayEnabled: false, wantDelayOption: false}, + {delayEnabled: true, wantDelayOption: true}, } { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -5788,7 +5786,7 @@ func TestDelayEnabled(t *testing.T) { } } -func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.DelayEnabled, wantDelayOption int) { +func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.DelayEnabled, wantDelayOption bool) { t.Helper() var gotDelayEnabled tcp.DelayEnabled @@ -5803,12 +5801,12 @@ func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.Del if err != nil { t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %v", err) } - gotDelayOption, err := ep.GetSockOptInt(tcpip.DelayOption) + gotDelayOption, err := ep.GetSockOptBool(tcpip.DelayOption) if err != nil { - t.Fatalf("ep.GetSockOptInt(tcpip.DelayOption) failed: %v", err) + t.Fatalf("ep.GetSockOptBool(tcpip.DelayOption) failed: %s", err) } if gotDelayOption != wantDelayOption { - t.Errorf("ep.GetSockOptInt(tcpip.DelayOption) got: %d, want: %d", gotDelayOption, wantDelayOption) + t.Errorf("ep.GetSockOptBool(tcpip.DelayOption) got: %t, want: %t", gotDelayOption, wantDelayOption) } } @@ -6620,8 +6618,8 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { const keepAliveInterval = 10 * time.Millisecond c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond)) c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval)) - c.EP.SetSockOpt(tcpip.KeepaliveCountOption(10)) - c.EP.SetSockOpt(tcpip.KeepaliveEnabledOption(1)) + c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10) + c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true) // Set userTimeout to be the duration for 3 keepalive probes. userTimeout := 30 * time.Millisecond diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 120d3baa3..492cc1fcb 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -501,11 +501,20 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { switch opt { + case tcpip.BroadcastOption: + e.mu.Lock() + e.broadcast = v + e.mu.Unlock() + + case tcpip.MulticastLoopOption: + e.mu.Lock() + e.multicastLoop = v + e.mu.Unlock() + case tcpip.ReceiveTOSOption: e.mu.Lock() e.receiveTOS = v e.mu.Unlock() - return nil case tcpip.ReceiveTClassOption: // We only support this option on v6 endpoints. @@ -516,7 +525,18 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { e.mu.Lock() e.receiveTClass = v e.mu.Unlock() - return nil + + case tcpip.ReceiveIPPacketInfoOption: + e.mu.Lock() + e.receiveIPPacketInfo = v + e.mu.Unlock() + + case tcpip.ReuseAddressOption: + + case tcpip.ReusePortOption: + e.mu.Lock() + e.reusePort = v + e.mu.Unlock() case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. @@ -533,13 +553,8 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { } e.v6only = v - return nil - - case tcpip.ReceiveIPPacketInfoOption: - e.mu.Lock() - e.receiveIPPacketInfo = v - e.mu.Unlock() - return nil + default: + return tcpip.ErrUnknownProtocolOption } return nil @@ -547,22 +562,40 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { - return nil -} + switch opt { + case tcpip.MulticastTTLOption: + e.mu.Lock() + e.multicastTTL = uint8(v) + e.mu.Unlock() -// SetSockOpt implements tcpip.Endpoint.SetSockOpt. -func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - switch v := opt.(type) { case tcpip.TTLOption: e.mu.Lock() e.ttl = uint8(v) e.mu.Unlock() - case tcpip.MulticastTTLOption: + case tcpip.IPv4TOSOption: e.mu.Lock() - e.multicastTTL = uint8(v) + e.sendTOS = uint8(v) + e.mu.Unlock() + + case tcpip.IPv6TrafficClassOption: + e.mu.Lock() + e.sendTOS = uint8(v) e.mu.Unlock() + case tcpip.ReceiveBufferSizeOption: + case tcpip.SendBufferSizeOption: + + default: + return tcpip.ErrUnknownProtocolOption + } + + return nil +} + +// SetSockOpt implements tcpip.Endpoint.SetSockOpt. +func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + switch v := opt.(type) { case tcpip.MulticastInterfaceOption: e.mu.Lock() defer e.mu.Unlock() @@ -686,16 +719,6 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.multicastMemberships[memToRemoveIndex] = e.multicastMemberships[len(e.multicastMemberships)-1] e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1] - case tcpip.MulticastLoopOption: - e.mu.Lock() - e.multicastLoop = bool(v) - e.mu.Unlock() - - case tcpip.ReusePortOption: - e.mu.Lock() - e.reusePort = v != 0 - e.mu.Unlock() - case tcpip.BindToDeviceOption: id := tcpip.NICID(v) if id != 0 && !e.stack.HasNIC(id) { @@ -704,26 +727,6 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.mu.Lock() e.bindToDevice = id e.mu.Unlock() - return nil - - case tcpip.BroadcastOption: - e.mu.Lock() - e.broadcast = v != 0 - e.mu.Unlock() - - return nil - - case tcpip.IPv4TOSOption: - e.mu.Lock() - e.sendTOS = uint8(v) - e.mu.Unlock() - return nil - - case tcpip.IPv6TrafficClassOption: - e.mu.Lock() - e.sendTOS = uint8(v) - e.mu.Unlock() - return nil } return nil } @@ -731,6 +734,21 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { switch opt { + case tcpip.BroadcastOption: + e.mu.RLock() + v := e.broadcast + e.mu.RUnlock() + return v, nil + + case tcpip.KeepaliveEnabledOption: + return false, nil + + case tcpip.MulticastLoopOption: + e.mu.RLock() + v := e.multicastLoop + e.mu.RUnlock() + return v, nil + case tcpip.ReceiveTOSOption: e.mu.RLock() v := e.receiveTOS @@ -748,6 +766,22 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { e.mu.RUnlock() return v, nil + case tcpip.ReceiveIPPacketInfoOption: + e.mu.RLock() + v := e.receiveIPPacketInfo + e.mu.RUnlock() + return v, nil + + case tcpip.ReuseAddressOption: + return false, nil + + case tcpip.ReusePortOption: + e.mu.RLock() + v := e.reusePort + e.mu.RUnlock() + + return v, nil + case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. if e.NetProto != header.IPv6ProtocolNumber { @@ -760,19 +794,32 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { return v, nil - case tcpip.ReceiveIPPacketInfoOption: - e.mu.RLock() - v := e.receiveIPPacketInfo - e.mu.RUnlock() - return v, nil + default: + return false, tcpip.ErrUnknownProtocolOption } - - return false, tcpip.ErrUnknownProtocolOption } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { + case tcpip.IPv4TOSOption: + e.mu.RLock() + v := int(e.sendTOS) + e.mu.RUnlock() + return v, nil + + case tcpip.IPv6TrafficClassOption: + e.mu.RLock() + v := int(e.sendTOS) + e.mu.RUnlock() + return v, nil + + case tcpip.MulticastTTLOption: + e.mu.Lock() + v := int(e.multicastTTL) + e.mu.Unlock() + return v, nil + case tcpip.ReceiveQueueSizeOption: v := 0 e.rcvMu.Lock() @@ -794,29 +841,22 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { v := e.rcvBufSizeMax e.rcvMu.Unlock() return v, nil - } - return -1, tcpip.ErrUnknownProtocolOption + case tcpip.TTLOption: + e.mu.Lock() + v := int(e.ttl) + e.mu.Unlock() + return v, nil + + default: + return -1, tcpip.ErrUnknownProtocolOption + } } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { switch o := opt.(type) { case tcpip.ErrorOption: - return nil - - case *tcpip.TTLOption: - e.mu.Lock() - *o = tcpip.TTLOption(e.ttl) - e.mu.Unlock() - return nil - - case *tcpip.MulticastTTLOption: - e.mu.Lock() - *o = tcpip.MulticastTTLOption(e.multicastTTL) - e.mu.Unlock() - return nil - case *tcpip.MulticastInterfaceOption: e.mu.Lock() *o = tcpip.MulticastInterfaceOption{ @@ -824,67 +864,16 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.multicastAddr, } e.mu.Unlock() - return nil - - case *tcpip.MulticastLoopOption: - e.mu.RLock() - v := e.multicastLoop - e.mu.RUnlock() - - *o = tcpip.MulticastLoopOption(v) - return nil - - case *tcpip.ReuseAddressOption: - *o = 0 - return nil - - case *tcpip.ReusePortOption: - e.mu.RLock() - v := e.reusePort - e.mu.RUnlock() - - *o = 0 - if v { - *o = 1 - } - return nil case *tcpip.BindToDeviceOption: e.mu.RLock() *o = tcpip.BindToDeviceOption(e.bindToDevice) e.mu.RUnlock() - return nil - - case *tcpip.KeepaliveEnabledOption: - *o = 0 - return nil - - case *tcpip.BroadcastOption: - e.mu.RLock() - v := e.broadcast - e.mu.RUnlock() - - *o = 0 - if v { - *o = 1 - } - return nil - - case *tcpip.IPv4TOSOption: - e.mu.RLock() - *o = tcpip.IPv4TOSOption(e.sendTOS) - e.mu.RUnlock() - return nil - - case *tcpip.IPv6TrafficClassOption: - e.mu.RLock() - *o = tcpip.IPv6TrafficClassOption(e.sendTOS) - e.mu.RUnlock() - return nil default: return tcpip.ErrUnknownProtocolOption } + return nil } // sendUDP sends a UDP segment via the provided network endpoint and under the diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 0905726c1..b3ee688b7 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -343,11 +343,11 @@ func (c *testContext) createEndpointForFlow(flow testFlow) { c.createEndpoint(flow.sockProto()) if flow.isV6Only() { if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) + c.t.Fatalf("SetSockOptBool failed: %s", err) } } else if flow.isBroadcast() { - if err := c.ep.SetSockOpt(tcpip.BroadcastOption(1)); err != nil { - c.t.Fatal("SetSockOpt failed:", err) + if err := c.ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil { + c.t.Fatalf("SetSockOptBool failed: %s", err) } } } @@ -1271,8 +1271,8 @@ func TestTTL(t *testing.T) { c.createEndpointForFlow(flow) const multicastTTL = 42 - if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) + if err := c.ep.SetSockOptInt(tcpip.MulticastTTLOption, multicastTTL); err != nil { + c.t.Fatalf("SetSockOptInt failed: %s", err) } var wantTTL uint8 @@ -1311,8 +1311,8 @@ func TestSetTTL(t *testing.T) { c.createEndpointForFlow(flow) - if err := c.ep.SetSockOpt(tcpip.TTLOption(wantTTL)); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) + if err := c.ep.SetSockOptInt(tcpip.TTLOption, int(wantTTL)); err != nil { + c.t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err) } var p stack.NetworkProtocol @@ -1346,25 +1346,26 @@ func TestSetTOS(t *testing.T) { c.createEndpointForFlow(flow) const tos = testTOS - var v tcpip.IPv4TOSOption - if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt(%T) failed: %s", v, err) + v, err := c.ep.GetSockOptInt(tcpip.IPv4TOSOption) + if err != nil { + c.t.Errorf("GetSockOptInt(IPv4TOSOption) failed: %s", err) } // Test for expected default value. if v != 0 { - c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, 0) + c.t.Errorf("got GetSockOpt(IPv4TOSOption) = 0x%x, want = 0x%x", v, 0) } - if err := c.ep.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil { - c.t.Errorf("SetSockOpt(%T, 0x%x) failed: %s", v, tcpip.IPv4TOSOption(tos), err) + if err := c.ep.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil { + c.t.Errorf("SetSockOptInt(IPv4TOSOption, 0x%x) failed: %s", tos, err) } - if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt(%T) failed: %s", v, err) + v, err = c.ep.GetSockOptInt(tcpip.IPv4TOSOption) + if err != nil { + c.t.Errorf("GetSockOptInt(IPv4TOSOption) failed: %s", err) } - if want := tcpip.IPv4TOSOption(tos); v != want { - c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, want) + if v != tos { + c.t.Errorf("got GetSockOptInt(IPv4TOSOption) = 0x%x, want = 0x%x", v, tos) } testWrite(c, flow, checker.TOS(tos, 0)) @@ -1381,25 +1382,26 @@ func TestSetTClass(t *testing.T) { c.createEndpointForFlow(flow) const tClass = testTOS - var v tcpip.IPv6TrafficClassOption - if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt(%T) failed: %s", v, err) + v, err := c.ep.GetSockOptInt(tcpip.IPv6TrafficClassOption) + if err != nil { + c.t.Errorf("GetSockOptInt(IPv6TrafficClassOption) failed: %s", err) } // Test for expected default value. if v != 0 { - c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, 0) + c.t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = 0x%x, want = 0x%x", v, 0) } - if err := c.ep.SetSockOpt(tcpip.IPv6TrafficClassOption(tClass)); err != nil { - c.t.Errorf("SetSockOpt(%T, 0x%x) failed: %s", v, tcpip.IPv6TrafficClassOption(tClass), err) + if err := c.ep.SetSockOptInt(tcpip.IPv6TrafficClassOption, tClass); err != nil { + c.t.Errorf("SetSockOptInt(IPv6TrafficClassOption, 0x%x) failed: %s", tClass, err) } - if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt(%T) failed: %s", v, err) + v, err = c.ep.GetSockOptInt(tcpip.IPv6TrafficClassOption) + if err != nil { + c.t.Errorf("GetSockOptInt(IPv6TrafficClassOption) failed: %s", err) } - if want := tcpip.IPv6TrafficClassOption(tClass); v != want { - c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, want) + if v != tClass { + c.t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = 0x%x, want = 0x%x", v, tClass) } // The header getter for TClass is called TOS, so use that checker. @@ -1430,7 +1432,7 @@ func TestReceiveTosTClass(t *testing.T) { // Verify that setting and reading the option works. v, err := c.ep.GetSockOptBool(option) if err != nil { - c.t.Errorf("GetSockoptBool(%s) failed: %s", name, err) + c.t.Errorf("GetSockOptBool(%s) failed: %s", name, err) } // Test for expected default value. if v != false { @@ -1444,7 +1446,7 @@ func TestReceiveTosTClass(t *testing.T) { got, err := c.ep.GetSockOptBool(option) if err != nil { - c.t.Errorf("GetSockoptBool(%s) failed: %s", name, err) + c.t.Errorf("GetSockOptBool(%s) failed: %s", name, err) } if got != want { -- cgit v1.2.3 From 36fbaac5201365ffec4c323956f8465492c8a32c Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Tue, 14 Apr 2020 18:31:20 -0700 Subject: Attempt SLAAC address regeneration on DAD conflicts As per RFC 7217 section 6, attempt to regenerate IPv6 SLAAC address in response to a DAD conflict if the address was generated with an opaque IID as outlined in RFC 7217 section 5. Test: - stack_test.TestAutoGenAddrWithOpaqueIIDDADRetries - stack_test.TestAutoGenAddrWithEUI64IIDNoDADRetries - stack_test.TestAutoGenAddrContinuesLifetimesAfterRetry PiperOrigin-RevId: 306555645 --- pkg/tcpip/network/ipv6/ipv6_test.go | 66 ++++-- pkg/tcpip/stack/ndp.go | 210 +++++++++++----- pkg/tcpip/stack/ndp_test.go | 461 ++++++++++++++++++++++++++++++++++++ pkg/tcpip/stack/nic.go | 73 ++++-- 4 files changed, 703 insertions(+), 107 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 95e5dbf8e..841a0cb7a 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -34,6 +34,7 @@ const ( // The least significant 3 bytes are the same as addr2 so both addr2 and // addr3 will have the same solicited-node address. addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02" + addr4 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x03" // Tests use the extension header identifier values as uint8 instead of // header.IPv6ExtensionHeaderIdentifier. @@ -167,6 +168,8 @@ func TestReceiveOnAllNodesMulticastAddr(t *testing.T) { // packets destined to the IPv6 solicited-node address of an assigned IPv6 // address. func TestReceiveOnSolicitedNodeAddr(t *testing.T) { + const nicID = 1 + tests := []struct { name string protocolFactory stack.TransportProtocol @@ -184,50 +187,61 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) { NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, TransportProtocols: []stack.TransportProtocol{test.protocolFactory}, }) - e := channel.New(10, 1280, linkAddr1) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) + e := channel.New(1, 1280, linkAddr1) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - // Should not receive a packet destined to the solicited - // node address of addr2/addr3 yet as we haven't added - // those addresses. + s.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: header.IPv6EmptySubnet, + NIC: nicID, + }, + }) + + // Should not receive a packet destined to the solicited node address of + // addr2/addr3 yet as we haven't added those addresses. test.rxf(t, s, e, addr1, snmc, 0) - if err := s.AddAddress(1, ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, addr2, err) + if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) } - // Should receive a packet destined to the solicited - // node address of addr2/addr3 now that we have added - // added addr2. + // Should receive a packet destined to the solicited node address of + // addr2/addr3 now that we have added added addr2. test.rxf(t, s, e, addr1, snmc, 1) - if err := s.AddAddress(1, ProtocolNumber, addr3); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, addr3, err) + if err := s.AddAddress(nicID, ProtocolNumber, addr3); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr3, err) } - // Should still receive a packet destined to the - // solicited node address of addr2/addr3 now that we - // have added addr3. + // Should still receive a packet destined to the solicited node address of + // addr2/addr3 now that we have added addr3. test.rxf(t, s, e, addr1, snmc, 2) - if err := s.RemoveAddress(1, addr2); err != nil { - t.Fatalf("RemoveAddress(_, %s) = %s", addr2, err) + if err := s.RemoveAddress(nicID, addr2); err != nil { + t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr2, err) } - // Should still receive a packet destined to the - // solicited node address of addr2/addr3 now that we - // have removed addr2. + // Should still receive a packet destined to the solicited node address of + // addr2/addr3 now that we have removed addr2. test.rxf(t, s, e, addr1, snmc, 3) - if err := s.RemoveAddress(1, addr3); err != nil { - t.Fatalf("RemoveAddress(_, %s) = %s", addr3, err) + // Make sure addr3's endpoint does not get removed from the NIC by + // incrementing its reference count with a route. + r, err := s.FindRoute(nicID, addr3, addr4, ProtocolNumber, false) + if err != nil { + t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr3, addr4, ProtocolNumber, err) + } + defer r.Release() + + if err := s.RemoveAddress(nicID, addr3); err != nil { + t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr3, err) } - // Should not receive a packet destined to the solicited - // node address of addr2/addr3 yet as both of them got - // removed. + // Should not receive a packet destined to the solicited node address of + // addr2/addr3 yet as both of them got removed, even though a route using + // addr3 exists. test.rxf(t, s, e, addr1, snmc, 3) }) } diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 7c9fc48d1..7f66c6c09 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -305,6 +305,15 @@ type NDPConfigurations struct { // lifetime(s) of the generated address changes; this option only // affects the generation of new addresses as part of SLAAC. AutoGenGlobalAddresses bool + + // AutoGenAddressConflictRetries determines how many times to attempt to retry + // generation of a permanent auto-generated address in response to DAD + // conflicts. + // + // If the method used to generate the address does not support creating + // alternative addresses (e.g. IIDs based on the modified EUI64 of a NIC's + // MAC address), then no attempt will be made to resolve the conflict. + AutoGenAddressConflictRetries uint8 } // DefaultNDPConfigurations returns an NDPConfigurations populated with @@ -411,8 +420,23 @@ type slaacPrefixState struct { // Nonzero only when the address is not valid forever. validUntil time.Time + // Nonzero only when the address is not preferred forever. + preferredUntil time.Time + // The prefix's permanent address endpoint. + // + // May only be nil when a SLAAC address is being (re-)generated. Otherwise, + // must not be nil as all SLAAC prefixes must have a SLAAC address. ref *referencedNetworkEndpoint + + // The number of times a permanent address has been generated for the prefix. + // + // Addresses may be regenerated in reseponse to a DAD conflicts. + generationAttempts uint8 + + // The maximum number of times to attempt regeneration of a permanent SLAAC + // address in response to DAD conflicts. + maxGenerationAttempts uint8 } // startDuplicateAddressDetection performs Duplicate Address Detection. @@ -935,60 +959,83 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { return } - // If the preferred lifetime is zero, then the prefix should be considered - // deprecated. - deprecated := pl == 0 - ref := ndp.addSLAACAddr(prefix, deprecated) - if ref == nil { - // We were unable to generate a permanent address for prefix so do nothing - // further as there is no reason to maintain state for a SLAAC prefix we - // cannot generate a permanent address for. - return - } - state := slaacPrefixState{ deprecationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() { - prefixState, ok := ndp.slaacPrefixes[prefix] + state, ok := ndp.slaacPrefixes[prefix] if !ok { - panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the SLAAC prefix %s", prefix)) + panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the deprecated SLAAC prefix %s", prefix)) } - ndp.deprecateSLAACAddress(prefixState.ref) + ndp.deprecateSLAACAddress(state.ref) }), invalidationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() { - ndp.invalidateSLAACPrefix(prefix, true) + state, ok := ndp.slaacPrefixes[prefix] + if !ok { + panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the invalidated SLAAC prefix %s", prefix)) + } + + ndp.invalidateSLAACPrefix(prefix, state) }), - ref: ref, + maxGenerationAttempts: ndp.configs.AutoGenAddressConflictRetries + 1, + } + + now := time.Now() + + // The time an address is preferred until is needed to properly generate the + // address. + if pl < header.NDPInfiniteLifetime { + state.preferredUntil = now.Add(pl) + } + + if !ndp.generateSLAACAddr(prefix, &state) { + // We were unable to generate an address for the prefix, we do not nothing + // further as there is no reason to maintain state or timers for a prefix we + // do not have an address for. + return } // Setup the initial timers to deprecate and invalidate prefix. - if !deprecated && pl < header.NDPInfiniteLifetime { + if pl < header.NDPInfiniteLifetime && pl != 0 { state.deprecationTimer.Reset(pl) } if vl < header.NDPInfiniteLifetime { state.invalidationTimer.Reset(vl) - state.validUntil = time.Now().Add(vl) + state.validUntil = now.Add(vl) } ndp.slaacPrefixes[prefix] = state } -// addSLAACAddr adds a SLAAC address for prefix. +// generateSLAACAddr generates a SLAAC address for prefix. +// +// Returns true if an address was successfully generated. +// +// Panics if the prefix is not a SLAAC prefix or it already has an address. // // The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) addSLAACAddr(prefix tcpip.Subnet, deprecated bool) *referencedNetworkEndpoint { +func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixState) bool { + if r := state.ref; r != nil { + panic(fmt.Sprintf("ndp: SLAAC prefix %s already has a permenant address %s", prefix, r.addrWithPrefix())) + } + + // If we have already reached the maximum address generation attempts for the + // prefix, do not generate another address. + if state.generationAttempts == state.maxGenerationAttempts { + return false + } + addrBytes := []byte(prefix.ID()) if oIID := ndp.nic.stack.opaqueIIDOpts; oIID.NICNameFromID != nil { addrBytes = header.AppendOpaqueInterfaceIdentifier( addrBytes[:header.IIDOffsetInIPv6Address], prefix, oIID.NICNameFromID(ndp.nic.ID(), ndp.nic.name), - 0, /* dadCounter */ + state.generationAttempts, oIID.SecretKey, ) - } else { + } else if state.generationAttempts == 0 { // Only attempt to generate an interface-specific IID if we have a valid // link address. // @@ -996,12 +1043,16 @@ func (ndp *ndpState) addSLAACAddr(prefix tcpip.Subnet, deprecated bool) *referen // LinkEndpoint.LinkAddress) before reaching this point. linkAddr := ndp.nic.linkEP.LinkAddress() if !header.IsValidUnicastEthernetAddress(linkAddr) { - return nil + return false } // Generate an address within prefix from the modified EUI-64 of ndp's NIC's // Ethernet MAC address. header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:]) + } else { + // We have no way to regenerate an address when addresses are not generated + // with opaque IIDs. + return false } generatedAddr := tcpip.ProtocolAddress{ @@ -1014,26 +1065,52 @@ func (ndp *ndpState) addSLAACAddr(prefix tcpip.Subnet, deprecated bool) *referen // If the nic already has this address, do nothing further. if ndp.nic.hasPermanentAddrLocked(generatedAddr.AddressWithPrefix.Address) { - return nil + return false } // Inform the integrator that we have a new SLAAC address. ndpDisp := ndp.nic.stack.ndpDisp if ndpDisp == nil { - return nil + return false } if !ndpDisp.OnAutoGenAddress(ndp.nic.ID(), generatedAddr.AddressWithPrefix) { // Informed by the integrator not to add the address. - return nil + return false } + deprecated := time.Since(state.preferredUntil) >= 0 ref, err := ndp.nic.addAddressLocked(generatedAddr, FirstPrimaryEndpoint, permanent, slaac, deprecated) if err != nil { panic(fmt.Sprintf("ndp: error when adding address %+v: %s", generatedAddr, err)) } - return ref + state.generationAttempts++ + state.ref = ref + return true +} + +// regenerateSLAACAddr regenerates an address for a SLAAC prefix. +// +// If generating a new address for the prefix fails, the prefix will be +// invalidated. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) regenerateSLAACAddr(prefix tcpip.Subnet) { + state, ok := ndp.slaacPrefixes[prefix] + if !ok { + panic(fmt.Sprintf("ndp: SLAAC prefix state not found to regenerate address for %s", prefix)) + } + + if ndp.generateSLAACAddr(prefix, &state) { + ndp.slaacPrefixes[prefix] = state + return + } + + // We were unable to generate a permanent address for the SLAAC prefix so + // invalidate the prefix as there is no reason to maintain state for a + // SLAAC prefix we do not have an address for. + ndp.invalidateSLAACPrefix(prefix, state) } // refreshSLAACPrefixLifetimes refreshes the lifetimes of a SLAAC prefix. @@ -1060,9 +1137,16 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, pl, vl tim // deprecation timer so it can be reset. prefixState.deprecationTimer.StopLocked() + now := time.Now() + // Reset the deprecation timer if prefix has a finite preferred lifetime. - if !deprecated && pl < header.NDPInfiniteLifetime { - prefixState.deprecationTimer.Reset(pl) + if pl < header.NDPInfiniteLifetime { + if !deprecated { + prefixState.deprecationTimer.Reset(pl) + } + prefixState.preferredUntil = now.Add(pl) + } else { + prefixState.preferredUntil = time.Time{} } // As per RFC 4862 section 5.5.3.e, update the valid lifetime for prefix: @@ -1105,7 +1189,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, pl, vl tim prefixState.invalidationTimer.StopLocked() prefixState.invalidationTimer.Reset(effectiveVl) - prefixState.validUntil = time.Now().Add(effectiveVl) + prefixState.validUntil = now.Add(effectiveVl) } // deprecateSLAACAddress marks ref as deprecated and notifies the stack's NDP @@ -1121,48 +1205,60 @@ func (ndp *ndpState) deprecateSLAACAddress(ref *referencedNetworkEndpoint) { ref.deprecated = true if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { - ndpDisp.OnAutoGenAddressDeprecated(ndp.nic.ID(), tcpip.AddressWithPrefix{ - Address: ref.ep.ID().LocalAddress, - PrefixLen: ref.ep.PrefixLen(), - }) + ndpDisp.OnAutoGenAddressDeprecated(ndp.nic.ID(), ref.addrWithPrefix()) } } // invalidateSLAACPrefix invalidates a SLAAC prefix. // // The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, removeAddr bool) { - state, ok := ndp.slaacPrefixes[prefix] - if !ok { - return +func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, state slaacPrefixState) { + if r := state.ref; r != nil { + // Since we are already invalidating the prefix, do not invalidate the + // prefix when removing the address. + if err := ndp.nic.removePermanentIPv6EndpointLocked(r, false /* allowSLAACPrefixInvalidation */); err != nil { + panic(fmt.Sprintf("ndp: removePermanentIPv6EndpointLocked(%s, false): %s", r.addrWithPrefix(), err)) + } } - state.deprecationTimer.StopLocked() - state.invalidationTimer.StopLocked() - delete(ndp.slaacPrefixes, prefix) + ndp.cleanupSLAACPrefixResources(prefix, state) +} - addr := state.ref.ep.ID().LocalAddress +// cleanupSLAACAddrResourcesAndNotify cleans up an invalidated SLAAC address's +// resources. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidatePrefix bool) { + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), addr) + } - if removeAddr { - if err := ndp.nic.removePermanentAddressLocked(addr); err != nil { - panic(fmt.Sprintf("ndp: removePermanentAddressLocked(%s): %s", addr, err)) - } + prefix := addr.Subnet() + state, ok := ndp.slaacPrefixes[prefix] + if !ok || state.ref == nil || addr.Address != state.ref.ep.ID().LocalAddress { + return } - if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { - ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), tcpip.AddressWithPrefix{ - Address: addr, - PrefixLen: state.ref.ep.PrefixLen(), - }) + if !invalidatePrefix { + // If the prefix is not being invalidated, disassociate the address from the + // prefix and do nothing further. + state.ref = nil + ndp.slaacPrefixes[prefix] = state + return } + + ndp.cleanupSLAACPrefixResources(prefix, state) } -// cleanupSLAACAddrResourcesAndNotify cleans up an invalidated SLAAC -// address's resources from ndp. +// cleanupSLAACPrefixResources cleansup a SLAAC prefix's timers and entry. +// +// Panics if the SLAAC prefix is not known. // // The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix) { - ndp.invalidateSLAACPrefix(addr.Subnet(), false) +func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaacPrefixState) { + state.deprecationTimer.StopLocked() + state.invalidationTimer.StopLocked() + delete(ndp.slaacPrefixes, prefix) } // cleanupState cleans up ndp's state. @@ -1181,7 +1277,7 @@ func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPr func (ndp *ndpState) cleanupState(hostOnly bool) { linkLocalSubnet := header.IPv6LinkLocalPrefix.Subnet() linkLocalPrefixes := 0 - for prefix := range ndp.slaacPrefixes { + for prefix, state := range ndp.slaacPrefixes { // RFC 4862 section 5 states that routers are also expected to generate a // link-local address so we do not invalidate them if we are cleaning up // host-only state. @@ -1190,7 +1286,7 @@ func (ndp *ndpState) cleanupState(hostOnly bool) { continue } - ndp.invalidateSLAACPrefix(prefix, true) + ndp.invalidateSLAACPrefix(prefix, state) } if got := len(ndp.slaacPrefixes); got != linkLocalPrefixes { diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index acb2d4731..6562a2d22 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -623,6 +623,12 @@ func TestDADFail(t *testing.T) { if want := (tcpip.AddressWithPrefix{}); addr != want { t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) } + + // Attempting to add the address again should not fail if the address's + // state was cleaned up when DAD failed. + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) + } }) } } @@ -2783,6 +2789,461 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { } } +// TestAutoGenAddrWithOpaqueIIDDADRetries tests the regeneration of an +// auto-generated IPv6 address in response to a DAD conflict. +func TestAutoGenAddrWithOpaqueIIDDADRetries(t *testing.T) { + const nicID = 1 + const nicName = "nic" + const dadTransmits = 1 + const retransmitTimer = time.Second + const maxMaxRetries = 3 + const lifetimeSeconds = 10 + + var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte + secretKey := secretKeyBuf[:] + n, err := rand.Read(secretKey) + if err != nil { + t.Fatalf("rand.Read(_): %s", err) + } + if n != header.OpaqueIIDSecretKeyMinBytes { + t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes) + } + + prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1) + + for maxRetries := uint8(0); maxRetries <= maxMaxRetries; maxRetries++ { + for numFailures := uint8(0); numFailures <= maxRetries+1; numFailures++ { + addrTypes := []struct { + name string + ndpConfigs stack.NDPConfigurations + autoGenLinkLocal bool + subnet tcpip.Subnet + triggerSLAACFn func(e *channel.Endpoint) + }{ + { + name: "Global address", + ndpConfigs: stack.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenAddressConflictRetries: maxRetries, + }, + subnet: subnet, + triggerSLAACFn: func(e *channel.Endpoint) { + // Receive an RA with prefix1 in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds)) + + }, + }, + { + name: "LinkLocal address", + ndpConfigs: stack.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + AutoGenAddressConflictRetries: maxRetries, + }, + autoGenLinkLocal: true, + subnet: header.IPv6LinkLocalPrefix.Subnet(), + triggerSLAACFn: func(e *channel.Endpoint) {}, + }, + } + + for _, addrType := range addrTypes { + maxRetries := maxRetries + numFailures := numFailures + addrType := addrType + + t.Run(fmt.Sprintf("%s with %d max retries and %d failures", addrType.name, maxRetries, numFailures), func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal, + NDPConfigs: addrType.ndpConfigs, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(_ tcpip.NICID, nicName string) string { + return nicName + }, + SecretKey: secretKey, + }, + }) + opts := stack.NICOptions{Name: nicName} + if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err) + } + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + addrType.triggerSLAACFn(e) + + // Simulate DAD conflicts so the address is regenerated. + for i := uint8(0); i < numFailures; i++ { + addrBytes := []byte(addrType.subnet.ID()) + addr := tcpip.AddressWithPrefix{ + Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], addrType.subnet, nicName, i, secretKey)), + PrefixLen: 64, + } + expectAutoGenAddrEvent(addr, newAddr) + + // Should not have any addresses assigned to the NIC. + mainAddr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err) + } + if want := (tcpip.AddressWithPrefix{}); mainAddr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", mainAddr, want) + } + + // Simulate a DAD conflict. + if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil { + t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err) + } + expectAutoGenAddrEvent(addr, invalidatedAddr) + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected DAD event") + } + + // Attempting to add the address manually should not fail if the + // address's state was cleaned up when DAD failed. + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr.Address); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr.Address, err) + } + if err := s.RemoveAddress(nicID, addr.Address); err != nil { + t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err) + } + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected DAD event") + } + } + + // Should not have any addresses assigned to the NIC. + mainAddr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err) + } + if want := (tcpip.AddressWithPrefix{}); mainAddr != want { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", mainAddr, want) + } + + // If we had less failures than generation attempts, we should have an + // address after DAD resolves. + if maxRetries+1 > numFailures { + addrBytes := []byte(addrType.subnet.ID()) + addr := tcpip.AddressWithPrefix{ + Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], addrType.subnet, nicName, numFailures, secretKey)), + PrefixLen: 64, + } + expectAutoGenAddrEvent(addr, newAddr) + + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout): + t.Fatal("timed out waiting for DAD event") + } + + mainAddr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err) + } + if mainAddr != addr { + t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", mainAddr, addr) + } + } + + // Should not attempt address regeneration again. + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly got an auto-generated address event = %+v", e) + case <-time.After(defaultAsyncEventTimeout): + } + }) + } + } + } +} + +// TestAutoGenAddrWithEUI64IIDNoDADRetries tests that a regeneration attempt is +// not made for SLAAC addresses generated with an IID based on the NIC's link +// address. +func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { + const nicID = 1 + const dadTransmits = 1 + const retransmitTimer = time.Second + const maxRetries = 3 + const lifetimeSeconds = 10 + + prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1) + + addrTypes := []struct { + name string + ndpConfigs stack.NDPConfigurations + autoGenLinkLocal bool + subnet tcpip.Subnet + triggerSLAACFn func(e *channel.Endpoint) + }{ + { + name: "Global address", + ndpConfigs: stack.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenAddressConflictRetries: maxRetries, + }, + subnet: subnet, + triggerSLAACFn: func(e *channel.Endpoint) { + // Receive an RA with prefix1 in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds)) + + }, + }, + { + name: "LinkLocal address", + ndpConfigs: stack.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + AutoGenAddressConflictRetries: maxRetries, + }, + autoGenLinkLocal: true, + subnet: header.IPv6LinkLocalPrefix.Subnet(), + triggerSLAACFn: func(e *channel.Endpoint) {}, + }, + } + + for _, addrType := range addrTypes { + addrType := addrType + + t.Run(addrType.name, func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal, + NDPConfigs: addrType.ndpConfigs, + NDPDisp: &ndpDisp, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + addrType.triggerSLAACFn(e) + + addrBytes := []byte(addrType.subnet.ID()) + header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr1, addrBytes[header.IIDOffsetInIPv6Address:]) + addr := tcpip.AddressWithPrefix{ + Address: tcpip.Address(addrBytes), + PrefixLen: 64, + } + expectAutoGenAddrEvent(addr, newAddr) + + // Simulate a DAD conflict. + if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil { + t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err) + } + expectAutoGenAddrEvent(addr, invalidatedAddr) + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected DAD event") + } + + // Should not attempt address regeneration. + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly got an auto-generated address event = %+v", e) + case <-time.After(defaultAsyncEventTimeout): + } + }) + } +} + +// TestAutoGenAddrContinuesLifetimesAfterRetry tests that retrying address +// generation in response to DAD conflicts does not refresh the lifetimes. +func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { + const nicID = 1 + const nicName = "nic" + const dadTransmits = 1 + const retransmitTimer = 2 * time.Second + const failureTimer = time.Second + const maxRetries = 1 + const lifetimeSeconds = 5 + + var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte + secretKey := secretKeyBuf[:] + n, err := rand.Read(secretKey) + if err != nil { + t.Fatalf("rand.Read(_): %s", err) + } + if n != header.OpaqueIIDSecretKeyMinBytes { + t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes) + } + + prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1) + + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenAddressConflictRetries: maxRetries, + }, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(_ tcpip.NICID, nicName string) string { + return nicName + }, + SecretKey: secretKey, + }, + }) + opts := stack.NICOptions{Name: nicName} + if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err) + } + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + // Receive an RA with prefix in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds)) + + addrBytes := []byte(subnet.ID()) + addr := tcpip.AddressWithPrefix{ + Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, 0, secretKey)), + PrefixLen: 64, + } + expectAutoGenAddrEvent(addr, newAddr) + + // Simulate a DAD conflict after some time has passed. + time.Sleep(failureTimer) + if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil { + t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err) + } + expectAutoGenAddrEvent(addr, invalidatedAddr) + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected DAD event") + } + + // Let the next address resolve. + addr.Address = tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, 1, secretKey)) + expectAutoGenAddrEvent(addr, newAddr) + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout): + t.Fatal("timed out waiting for DAD event") + } + + // Address should be deprecated/invalidated after the lifetime expires. + // + // Note, the remaining lifetime is calculated from when the PI was first + // processed. Since we wait for some time before simulating a DAD conflict + // and more time for the new address to resolve, the new address is only + // expected to be valid for the remaining time. The DAD conflict should + // not have reset the lifetimes. + // + // We expect either just the invalidation event or the deprecation event + // followed by the invalidation event. + select { + case e := <-ndpDisp.autoGenAddrC: + if e.eventType == deprecatedAddr { + if diff := checkAutoGenAddrEvent(e, addr, deprecatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(defaultAsyncEventTimeout): + t.Fatal("timed out waiting for invalidated auto gen addr event after deprecation") + } + } else { + if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + } + case <-time.After(lifetimeSeconds*time.Second - failureTimer - dadTransmits*retransmitTimer + defaultAsyncEventTimeout): + t.Fatal("timed out waiting for auto gen addr event") + } +} + // TestNDPRecursiveDNSServerDispatch tests that we properly dispatch an event // to the integrator when an RA is received with the NDP Recursive DNS Server // option with at least one valid address. diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 4835251bc..016dbe15e 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1012,29 +1012,31 @@ func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { return tcpip.ErrBadLocalAddress } - isIPv6Unicast := r.protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(addr) + switch r.protocol { + case header.IPv6ProtocolNumber: + return n.removePermanentIPv6EndpointLocked(r, true /* allowSLAAPrefixInvalidation */) + default: + r.expireLocked() + return nil + } +} + +func (n *NIC) removePermanentIPv6EndpointLocked(r *referencedNetworkEndpoint, allowSLAACPrefixInvalidation bool) *tcpip.Error { + addr := r.addrWithPrefix() + + isIPv6Unicast := header.IsV6UnicastAddress(addr.Address) if isIPv6Unicast { - // If we are removing a tentative IPv6 unicast address, stop DAD. - if kind == permanentTentative { - n.mu.ndp.stopDuplicateAddressDetection(addr) - } + n.mu.ndp.stopDuplicateAddressDetection(addr.Address) // If we are removing an address generated via SLAAC, cleanup // its SLAAC resources and notify the integrator. if r.configType == slaac { - n.mu.ndp.cleanupSLAACAddrResourcesAndNotify(tcpip.AddressWithPrefix{ - Address: addr, - PrefixLen: r.ep.PrefixLen(), - }) + n.mu.ndp.cleanupSLAACAddrResourcesAndNotify(addr, allowSLAACPrefixInvalidation) } } - r.setKind(permanentExpired) - if !r.decRefLocked() { - // The endpoint still has references to it. - return nil - } + r.expireLocked() // At this point the endpoint is deleted. @@ -1044,7 +1046,7 @@ func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { // We ignore the tcpip.ErrBadLocalAddress error because the solicited-node // multicast group may be left by user action. if isIPv6Unicast { - snmc := header.SolicitedNodeAddr(addr) + snmc := header.SolicitedNodeAddr(addr.Address) if err := n.leaveGroupLocked(snmc, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress { return err } @@ -1425,10 +1427,12 @@ func (n *NIC) isAddrTentative(addr tcpip.Address) bool { return ref.getKind() == permanentTentative } -// dupTentativeAddrDetected attempts to inform n that a tentative addr -// is a duplicate on a link. +// dupTentativeAddrDetected attempts to inform n that a tentative addr is a +// duplicate on a link. // -// dupTentativeAddrDetected will delete the tentative address if it exists. +// dupTentativeAddrDetected will remove the tentative address if it exists. If +// the address was generated via SLAAC, an attempt will be made to generate a +// new address. func (n *NIC) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error { n.mu.Lock() defer n.mu.Unlock() @@ -1442,7 +1446,17 @@ func (n *NIC) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - return n.removePermanentAddressLocked(addr) + // If the address is a SLAAC address, do not invalidate its SLAAC prefix as a + // new address will be generated for it. + if err := n.removePermanentIPv6EndpointLocked(ref, false /* allowSLAACPrefixInvalidation */); err != nil { + return err + } + + if ref.configType == slaac { + n.mu.ndp.regenerateSLAACAddr(ref.addrWithPrefix().Subnet()) + } + + return nil } // setNDPConfigs sets the NDP configurations for n. @@ -1570,6 +1584,13 @@ type referencedNetworkEndpoint struct { deprecated bool } +func (r *referencedNetworkEndpoint) addrWithPrefix() tcpip.AddressWithPrefix { + return tcpip.AddressWithPrefix{ + Address: r.ep.ID().LocalAddress, + PrefixLen: r.ep.PrefixLen(), + } +} + func (r *referencedNetworkEndpoint) getKind() networkEndpointKind { return networkEndpointKind(atomic.LoadInt32((*int32)(&r.kind))) } @@ -1597,6 +1618,13 @@ func (r *referencedNetworkEndpoint) isValidForOutgoingRLocked() bool { return r.nic.mu.enabled && (r.getKind() != permanentExpired || r.nic.mu.spoofing) } +// expireLocked decrements the reference count and marks the permanent endpoint +// as expired. +func (r *referencedNetworkEndpoint) expireLocked() { + r.setKind(permanentExpired) + r.decRefLocked() +} + // decRef decrements the ref count and cleans up the endpoint once it reaches // zero. func (r *referencedNetworkEndpoint) decRef() { @@ -1606,14 +1634,11 @@ func (r *referencedNetworkEndpoint) decRef() { } // decRefLocked is the same as decRef but assumes that the NIC.mu mutex is -// locked. Returns true if the endpoint was removed. -func (r *referencedNetworkEndpoint) decRefLocked() bool { +// locked. +func (r *referencedNetworkEndpoint) decRefLocked() { if atomic.AddInt32(&r.refs, -1) == 0 { r.nic.removeEndpointLocked(r) - return true } - - return false } // incRef increments the ref count. It must only be called when the caller is -- cgit v1.2.3 From b33c3bb4a73974bbae4274da5100a3cd3f5deef8 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Thu, 16 Apr 2020 17:26:12 -0700 Subject: Return detailed errors when iterating NDP options Test: header_test.TestNDPOptionsIterCheck PiperOrigin-RevId: 306953867 --- pkg/tcpip/header/BUILD | 1 + pkg/tcpip/header/ndp_options.go | 196 ++++++++++++++----------- pkg/tcpip/header/ndp_test.go | 118 ++++++++------- pkg/tcpip/header/ndpoptionidentifier_string.go | 50 +++++++ pkg/tcpip/stack/ndp.go | 3 +- 5 files changed, 228 insertions(+), 140 deletions(-) create mode 100644 pkg/tcpip/header/ndpoptionidentifier_string.go (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD index 7094f3f0b..0cde694dc 100644 --- a/pkg/tcpip/header/BUILD +++ b/pkg/tcpip/header/BUILD @@ -21,6 +21,7 @@ go_library( "ndp_options.go", "ndp_router_advert.go", "ndp_router_solicit.go", + "ndpoptionidentifier_string.go", "tcp.go", "udp.go", ], diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go index e6a6ad39b..444e90820 100644 --- a/pkg/tcpip/header/ndp_options.go +++ b/pkg/tcpip/header/ndp_options.go @@ -15,32 +15,43 @@ package header import ( + "bytes" "encoding/binary" "errors" "fmt" + "io" "math" "time" "gvisor.dev/gvisor/pkg/tcpip" ) +// NDPOptionIdentifier is an NDP option type identifier. +type NDPOptionIdentifier uint8 + const ( // NDPSourceLinkLayerAddressOptionType is the type of the Source Link Layer // Address option, as per RFC 4861 section 4.6.1. - NDPSourceLinkLayerAddressOptionType = 1 + NDPSourceLinkLayerAddressOptionType NDPOptionIdentifier = 1 // NDPTargetLinkLayerAddressOptionType is the type of the Target Link Layer // Address option, as per RFC 4861 section 4.6.1. - NDPTargetLinkLayerAddressOptionType = 2 + NDPTargetLinkLayerAddressOptionType NDPOptionIdentifier = 2 + + // NDPPrefixInformationType is the type of the Prefix Information + // option, as per RFC 4861 section 4.6.2. + NDPPrefixInformationType NDPOptionIdentifier = 3 + + // NDPRecursiveDNSServerOptionType is the type of the Recursive DNS + // Server option, as per RFC 8106 section 5.1. + NDPRecursiveDNSServerOptionType NDPOptionIdentifier = 25 +) +const ( // NDPLinkLayerAddressSize is the size of a Source or Target Link Layer // Address option for an Ethernet address. NDPLinkLayerAddressSize = 8 - // NDPPrefixInformationType is the type of the Prefix Information - // option, as per RFC 4861 section 4.6.2. - NDPPrefixInformationType = 3 - // ndpPrefixInformationLength is the expected length, in bytes, of the // body of an NDP Prefix Information option, as per RFC 4861 section // 4.6.2 which specifies that the Length field is 4. Given this, the @@ -91,10 +102,6 @@ const ( // within an NDPPrefixInformation. ndpPrefixInformationPrefixOffset = 14 - // NDPRecursiveDNSServerOptionType is the type of the Recursive DNS - // Server option, as per RFC 8106 section 5.1. - NDPRecursiveDNSServerOptionType = 25 - // ndpRecursiveDNSServerLifetimeOffset is the start of the 4-byte // Lifetime field within an NDPRecursiveDNSServer. ndpRecursiveDNSServerLifetimeOffset = 2 @@ -103,10 +110,10 @@ const ( // for IPv6 Recursive DNS Servers within an NDPRecursiveDNSServer. ndpRecursiveDNSServerAddressesOffset = 6 - // minNDPRecursiveDNSServerLength is the minimum NDP Recursive DNS - // Server option's length field value when it contains at least one - // IPv6 address. - minNDPRecursiveDNSServerLength = 3 + // minNDPRecursiveDNSServerLength is the minimum NDP Recursive DNS Server + // option's body size when it contains at least one IPv6 address, as per + // RFC 8106 section 5.3.1. + minNDPRecursiveDNSServerBodySize = 22 // lengthByteUnits is the multiplier factor for the Length field of an // NDP option. That is, the length field for NDP options is in units of @@ -132,16 +139,13 @@ var ( // few NDPOption then modify the backing NDPOptions so long as the // NDPOptionIterator obtained before modification is no longer used. type NDPOptionIterator struct { - // The NDPOptions this NDPOptionIterator is iterating over. - opts NDPOptions + opts *bytes.Buffer } // Potential errors when iterating over an NDPOptions. var ( - ErrNDPOptBufExhausted = errors.New("Buffer unexpectedly exhausted") - ErrNDPOptZeroLength = errors.New("NDP option has zero-valued Length field") - ErrNDPOptMalformedBody = errors.New("NDP option has a malformed body") - ErrNDPInvalidLength = errors.New("NDP option's Length value is invalid as per relevant RFC") + ErrNDPOptMalformedBody = errors.New("NDP option has a malformed body") + ErrNDPOptMalformedHeader = errors.New("NDP option has a malformed header") ) // Next returns the next element in the backing NDPOptions, or true if we are @@ -152,48 +156,50 @@ var ( func (i *NDPOptionIterator) Next() (NDPOption, bool, error) { for { // Do we still have elements to look at? - if len(i.opts) == 0 { + if i.opts.Len() == 0 { return nil, true, nil } - // Do we have enough bytes for an NDP option that has a Length - // field of at least 1? Note, 0 in the Length field is invalid. - if len(i.opts) < lengthByteUnits { - return nil, true, ErrNDPOptBufExhausted - } - // Get the Type field. - t := i.opts[0] - - // Get the Length field. - l := i.opts[1] + temp, err := i.opts.ReadByte() + if err != nil { + if err != io.EOF { + // ReadByte should only ever return nil or io.EOF. + panic(fmt.Sprintf("unexpected error when reading the option's Type field: %s", err)) + } - // This would indicate an erroneous NDP option as the Length - // field should never be 0. - if l == 0 { - return nil, true, ErrNDPOptZeroLength + // We use io.ErrUnexpectedEOF as exhausting the buffer is unexpected once + // we start parsing an option; we expect the buffer to contain enough + // bytes for the whole option. + return nil, true, fmt.Errorf("unexpectedly exhausted buffer when reading the option's Type field: %w", io.ErrUnexpectedEOF) } + kind := NDPOptionIdentifier(temp) - // How many bytes are in the option body? - numBytes := int(l) * lengthByteUnits - numBodyBytes := numBytes - 2 - - potentialBody := i.opts[2:] + // Get the Length field. + length, err := i.opts.ReadByte() + if err != nil { + if err != io.EOF { + panic(fmt.Sprintf("unexpected error when reading the option's Length field for %s: %s", kind, err)) + } - // This would indicate an erroenous NDPOptions buffer as we ran - // out of the buffer in the middle of an NDP option. - if left := len(potentialBody); left < numBodyBytes { - return nil, true, ErrNDPOptBufExhausted + return nil, true, fmt.Errorf("unexpectedly exhausted buffer when reading the option's Length field for %s: %w", kind, io.ErrUnexpectedEOF) } - // Get only the options body, leaving the rest of the options - // buffer alone. - body := potentialBody[:numBodyBytes] + // This would indicate an erroneous NDP option as the Length field should + // never be 0. + if length == 0 { + return nil, true, fmt.Errorf("zero valued Length field for %s: %w", kind, ErrNDPOptMalformedHeader) + } - // Update opts with the remaining options body. - i.opts = i.opts[numBytes:] + // Get the body. + numBytes := int(length) * lengthByteUnits + numBodyBytes := numBytes - 2 + body := i.opts.Next(numBodyBytes) + if len(body) < numBodyBytes { + return nil, true, fmt.Errorf("unexpectedly exhausted buffer when reading the option's Body for %s: %w", kind, io.ErrUnexpectedEOF) + } - switch t { + switch kind { case NDPSourceLinkLayerAddressOptionType: return NDPSourceLinkLayerAddressOption(body), false, nil @@ -205,22 +211,15 @@ func (i *NDPOptionIterator) Next() (NDPOption, bool, error) { // body is ndpPrefixInformationLength, as per RFC 4861 // section 4.6.2. if numBodyBytes != ndpPrefixInformationLength { - return nil, true, ErrNDPOptMalformedBody + return nil, true, fmt.Errorf("got %d bytes for NDP Prefix Information option's body, expected %d bytes: %w", numBodyBytes, ndpPrefixInformationLength, ErrNDPOptMalformedBody) } return NDPPrefixInformation(body), false, nil case NDPRecursiveDNSServerOptionType: - // RFC 8106 section 5.3.1 outlines that the RDNSS option - // must have a minimum length of 3 so it contains at - // least one IPv6 address. - if l < minNDPRecursiveDNSServerLength { - return nil, true, ErrNDPInvalidLength - } - opt := NDPRecursiveDNSServer(body) - if len(opt.Addresses()) == 0 { - return nil, true, ErrNDPOptMalformedBody + if err := opt.checkAddresses(); err != nil { + return nil, true, err } return opt, false, nil @@ -247,10 +246,16 @@ type NDPOptions []byte // // See NDPOptionIterator for more information. func (b NDPOptions) Iter(check bool) (NDPOptionIterator, error) { - it := NDPOptionIterator{opts: b} + it := NDPOptionIterator{ + opts: bytes.NewBuffer(b), + } if check { - for it2 := it; true; { + it2 := NDPOptionIterator{ + opts: bytes.NewBuffer(b), + } + + for { if _, done, err := it2.Next(); err != nil || done { return it, err } @@ -278,7 +283,7 @@ func (b NDPOptions) Serialize(s NDPOptionsSerializer) int { continue } - b[0] = o.Type() + b[0] = byte(o.Type()) // We know this safe because paddedLength would have returned // 0 if o had an invalid length (> 255 * lengthByteUnits). @@ -304,7 +309,7 @@ type NDPOption interface { fmt.Stringer // Type returns the type of the receiver. - Type() uint8 + Type() NDPOptionIdentifier // Length returns the length of the body of the receiver, in bytes. Length() int @@ -386,7 +391,7 @@ func (b NDPOptionsSerializer) Length() int { type NDPSourceLinkLayerAddressOption tcpip.LinkAddress // Type implements NDPOption.Type. -func (o NDPSourceLinkLayerAddressOption) Type() uint8 { +func (o NDPSourceLinkLayerAddressOption) Type() NDPOptionIdentifier { return NDPSourceLinkLayerAddressOptionType } @@ -426,7 +431,7 @@ func (o NDPSourceLinkLayerAddressOption) EthernetAddress() tcpip.LinkAddress { type NDPTargetLinkLayerAddressOption tcpip.LinkAddress // Type implements NDPOption.Type. -func (o NDPTargetLinkLayerAddressOption) Type() uint8 { +func (o NDPTargetLinkLayerAddressOption) Type() NDPOptionIdentifier { return NDPTargetLinkLayerAddressOptionType } @@ -466,7 +471,7 @@ func (o NDPTargetLinkLayerAddressOption) EthernetAddress() tcpip.LinkAddress { type NDPPrefixInformation []byte // Type implements NDPOption.Type. -func (o NDPPrefixInformation) Type() uint8 { +func (o NDPPrefixInformation) Type() NDPOptionIdentifier { return NDPPrefixInformationType } @@ -590,7 +595,7 @@ type NDPRecursiveDNSServer []byte // Type returns the type of an NDP Recursive DNS Server option. // // Type implements NDPOption.Type. -func (NDPRecursiveDNSServer) Type() uint8 { +func (NDPRecursiveDNSServer) Type() NDPOptionIdentifier { return NDPRecursiveDNSServerOptionType } @@ -613,7 +618,12 @@ func (o NDPRecursiveDNSServer) serializeInto(b []byte) int { // String implements fmt.Stringer.String. func (o NDPRecursiveDNSServer) String() string { - return fmt.Sprintf("%T(%s valid for %s)", o, o.Addresses(), o.Lifetime()) + lt := o.Lifetime() + addrs, err := o.Addresses() + if err != nil { + return fmt.Sprintf("%T([] valid for %s; err = %s)", o, lt, err) + } + return fmt.Sprintf("%T(%s valid for %s)", o, addrs, lt) } // Lifetime returns the length of time that the DNS server addresses @@ -632,29 +642,45 @@ func (o NDPRecursiveDNSServer) Lifetime() time.Duration { // Addresses returns the recursive DNS server IPv6 addresses that may be // used for name resolution. // -// Note, some of the addresses returned MAY be link-local addresses. +// Note, the addresses MAY be link-local addresses. +func (o NDPRecursiveDNSServer) Addresses() ([]tcpip.Address, error) { + var addrs []tcpip.Address + return addrs, o.iterAddresses(func(addr tcpip.Address) { addrs = append(addrs, addr) }) +} + +// checkAddresses iterates over the addresses in an NDP Recursive DNS Server +// option and returns any error it encounters. +func (o NDPRecursiveDNSServer) checkAddresses() error { + return o.iterAddresses(nil) +} + +// iterAddresses iterates over the addresses in an NDP Recursive DNS Server +// option and calls a function with each valid unicast IPv6 address. // -// Addresses may panic if o does not hold valid IPv6 addresses. -func (o NDPRecursiveDNSServer) Addresses() []tcpip.Address { - l := len(o) - if l < ndpRecursiveDNSServerAddressesOffset { - return nil +// Note, the addresses MAY be link-local addresses. +func (o NDPRecursiveDNSServer) iterAddresses(fn func(tcpip.Address)) error { + if l := len(o); l < minNDPRecursiveDNSServerBodySize { + return fmt.Errorf("got %d bytes for NDP Recursive DNS Server option's body, expected at least %d bytes: %w", l, minNDPRecursiveDNSServerBodySize, io.ErrUnexpectedEOF) } - l -= ndpRecursiveDNSServerAddressesOffset + o = o[ndpRecursiveDNSServerAddressesOffset:] + l := len(o) if l%IPv6AddressSize != 0 { - return nil + return fmt.Errorf("NDP Recursive DNS Server option's body ends in the middle of an IPv6 address (addresses body size = %d bytes): %w", l, ErrNDPOptMalformedBody) } - buf := o[ndpRecursiveDNSServerAddressesOffset:] - var addrs []tcpip.Address - for len(buf) > 0 { - addr := tcpip.Address(buf[:IPv6AddressSize]) + for i := 0; len(o) != 0; i++ { + addr := tcpip.Address(o[:IPv6AddressSize]) if !IsV6UnicastAddress(addr) { - return nil + return fmt.Errorf("%d-th address (%s) in NDP Recursive DNS Server option is not a valid unicast IPv6 address: %w", i, addr, ErrNDPOptMalformedBody) + } + + if fn != nil { + fn(addr) } - addrs = append(addrs, addr) - buf = buf[IPv6AddressSize:] + + o = o[IPv6AddressSize:] } - return addrs + + return nil } diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go index 969f8f1e8..2341329c4 100644 --- a/pkg/tcpip/header/ndp_test.go +++ b/pkg/tcpip/header/ndp_test.go @@ -16,6 +16,8 @@ package header import ( "bytes" + "errors" + "io" "testing" "time" @@ -543,8 +545,12 @@ func TestNDPRecursiveDNSServerOptionSerialize(t *testing.T) { want := []tcpip.Address{ "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f", } - if got := opt.Addresses(); !cmp.Equal(got, want) { - t.Errorf("got Addresses = %v, want = %v", got, want) + addrs, err := opt.Addresses() + if err != nil { + t.Errorf("opt.Addresses() = %s", err) + } + if diff := cmp.Diff(addrs, want); diff != "" { + t.Errorf("mismatched addresses (-want +got):\n%s", diff) } // Iterator should not return anything else. @@ -638,8 +644,12 @@ func TestNDPRecursiveDNSServerOption(t *testing.T) { if got := opt.Lifetime(); got != test.lifetime { t.Errorf("got Lifetime = %d, want = %d", got, test.lifetime) } - if got := opt.Addresses(); !cmp.Equal(got, test.addrs) { - t.Errorf("got Addresses = %v, want = %v", got, test.addrs) + addrs, err := opt.Addresses() + if err != nil { + t.Errorf("opt.Addresses() = %s", err) + } + if diff := cmp.Diff(addrs, test.addrs); diff != "" { + t.Errorf("mismatched addresses (-want +got):\n%s", diff) } // Iterator should not return anything else. @@ -661,38 +671,38 @@ func TestNDPRecursiveDNSServerOption(t *testing.T) { // the iterator was returned for is malformed. func TestNDPOptionsIterCheck(t *testing.T) { tests := []struct { - name string - buf []byte - expected error + name string + buf []byte + expectedErr error }{ { - "ZeroLengthField", - []byte{0, 0, 0, 0, 0, 0, 0, 0}, - ErrNDPOptZeroLength, + name: "ZeroLengthField", + buf: []byte{0, 0, 0, 0, 0, 0, 0, 0}, + expectedErr: ErrNDPOptMalformedHeader, }, { - "ValidSourceLinkLayerAddressOption", - []byte{1, 1, 1, 2, 3, 4, 5, 6}, - nil, + name: "ValidSourceLinkLayerAddressOption", + buf: []byte{1, 1, 1, 2, 3, 4, 5, 6}, + expectedErr: nil, }, { - "TooSmallSourceLinkLayerAddressOption", - []byte{1, 1, 1, 2, 3, 4, 5}, - ErrNDPOptBufExhausted, + name: "TooSmallSourceLinkLayerAddressOption", + buf: []byte{1, 1, 1, 2, 3, 4, 5}, + expectedErr: io.ErrUnexpectedEOF, }, { - "ValidTargetLinkLayerAddressOption", - []byte{2, 1, 1, 2, 3, 4, 5, 6}, - nil, + name: "ValidTargetLinkLayerAddressOption", + buf: []byte{2, 1, 1, 2, 3, 4, 5, 6}, + expectedErr: nil, }, { - "TooSmallTargetLinkLayerAddressOption", - []byte{2, 1, 1, 2, 3, 4, 5}, - ErrNDPOptBufExhausted, + name: "TooSmallTargetLinkLayerAddressOption", + buf: []byte{2, 1, 1, 2, 3, 4, 5}, + expectedErr: io.ErrUnexpectedEOF, }, { - "ValidPrefixInformation", - []byte{ + name: "ValidPrefixInformation", + buf: []byte{ 3, 4, 43, 64, 1, 2, 3, 4, 5, 6, 7, 8, @@ -702,11 +712,11 @@ func TestNDPOptionsIterCheck(t *testing.T) { 17, 18, 19, 20, 21, 22, 23, 24, }, - nil, + expectedErr: nil, }, { - "TooSmallPrefixInformation", - []byte{ + name: "TooSmallPrefixInformation", + buf: []byte{ 3, 4, 43, 64, 1, 2, 3, 4, 5, 6, 7, 8, @@ -716,11 +726,11 @@ func TestNDPOptionsIterCheck(t *testing.T) { 17, 18, 19, 20, 21, 22, 23, }, - ErrNDPOptBufExhausted, + expectedErr: io.ErrUnexpectedEOF, }, { - "InvalidPrefixInformationLength", - []byte{ + name: "InvalidPrefixInformationLength", + buf: []byte{ 3, 3, 43, 64, 1, 2, 3, 4, 5, 6, 7, 8, @@ -728,11 +738,11 @@ func TestNDPOptionsIterCheck(t *testing.T) { 9, 10, 11, 12, 13, 14, 15, 16, }, - ErrNDPOptMalformedBody, + expectedErr: ErrNDPOptMalformedBody, }, { - "ValidSourceAndTargetLinkLayerAddressWithPrefixInformation", - []byte{ + name: "ValidSourceAndTargetLinkLayerAddressWithPrefixInformation", + buf: []byte{ // Source Link-Layer Address. 1, 1, 1, 2, 3, 4, 5, 6, @@ -749,11 +759,11 @@ func TestNDPOptionsIterCheck(t *testing.T) { 17, 18, 19, 20, 21, 22, 23, 24, }, - nil, + expectedErr: nil, }, { - "ValidSourceAndTargetLinkLayerAddressWithPrefixInformationWithUnrecognized", - []byte{ + name: "ValidSourceAndTargetLinkLayerAddressWithPrefixInformationWithUnrecognized", + buf: []byte{ // Source Link-Layer Address. 1, 1, 1, 2, 3, 4, 5, 6, @@ -775,52 +785,52 @@ func TestNDPOptionsIterCheck(t *testing.T) { 17, 18, 19, 20, 21, 22, 23, 24, }, - nil, + expectedErr: nil, }, { - "InvalidRecursiveDNSServerCutsOffAddress", - []byte{ + name: "InvalidRecursiveDNSServerCutsOffAddress", + buf: []byte{ 25, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 2, 3, 4, 5, 6, 7, }, - ErrNDPOptMalformedBody, + expectedErr: ErrNDPOptMalformedBody, }, { - "InvalidRecursiveDNSServerInvalidLengthField", - []byte{ + name: "InvalidRecursiveDNSServerInvalidLengthField", + buf: []byte{ 25, 2, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, }, - ErrNDPInvalidLength, + expectedErr: io.ErrUnexpectedEOF, }, { - "RecursiveDNSServerTooSmall", - []byte{ + name: "RecursiveDNSServerTooSmall", + buf: []byte{ 25, 1, 0, 0, 0, 0, 0, }, - ErrNDPOptBufExhausted, + expectedErr: io.ErrUnexpectedEOF, }, { - "RecursiveDNSServerMulticast", - []byte{ + name: "RecursiveDNSServerMulticast", + buf: []byte{ 25, 3, 0, 0, 0, 0, 0, 0, 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, }, - ErrNDPOptMalformedBody, + expectedErr: ErrNDPOptMalformedBody, }, { - "RecursiveDNSServerUnspecified", - []byte{ + name: "RecursiveDNSServerUnspecified", + buf: []byte{ 25, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, }, - ErrNDPOptMalformedBody, + expectedErr: ErrNDPOptMalformedBody, }, } @@ -828,8 +838,8 @@ func TestNDPOptionsIterCheck(t *testing.T) { t.Run(test.name, func(t *testing.T) { opts := NDPOptions(test.buf) - if _, err := opts.Iter(true); err != test.expected { - t.Fatalf("got Iter(true) = (_, %v), want = (_, %v)", err, test.expected) + if _, err := opts.Iter(true); !errors.Is(err, test.expectedErr) { + t.Fatalf("got Iter(true) = (_, %v), want = (_, %v)", err, test.expectedErr) } // test.buf may be malformed but we chose not to check diff --git a/pkg/tcpip/header/ndpoptionidentifier_string.go b/pkg/tcpip/header/ndpoptionidentifier_string.go new file mode 100644 index 000000000..6fe9a336b --- /dev/null +++ b/pkg/tcpip/header/ndpoptionidentifier_string.go @@ -0,0 +1,50 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by "stringer -type NDPOptionIdentifier ."; DO NOT EDIT. + +package header + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[NDPSourceLinkLayerAddressOptionType-1] + _ = x[NDPTargetLinkLayerAddressOptionType-2] + _ = x[NDPPrefixInformationType-3] + _ = x[NDPRecursiveDNSServerOptionType-25] +} + +const ( + _NDPOptionIdentifier_name_0 = "NDPSourceLinkLayerAddressOptionTypeNDPTargetLinkLayerAddressOptionTypeNDPPrefixInformationType" + _NDPOptionIdentifier_name_1 = "NDPRecursiveDNSServerOptionType" +) + +var ( + _NDPOptionIdentifier_index_0 = [...]uint8{0, 35, 70, 94} +) + +func (i NDPOptionIdentifier) String() string { + switch { + case 1 <= i && i <= 3: + i -= 1 + return _NDPOptionIdentifier_name_0[_NDPOptionIdentifier_index_0[i]:_NDPOptionIdentifier_index_0[i+1]] + case i == 25: + return _NDPOptionIdentifier_name_1 + default: + return "NDPOptionIdentifier(" + strconv.FormatInt(int64(i), 10) + ")" + } +} diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 7f66c6c09..8140c6dd4 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -711,7 +711,8 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { continue } - ndp.nic.stack.ndpDisp.OnRecursiveDNSServerOption(ndp.nic.ID(), opt.Addresses(), opt.Lifetime()) + addrs, _ := opt.Addresses() + ndp.nic.stack.ndpDisp.OnRecursiveDNSServerOption(ndp.nic.ID(), addrs, opt.Lifetime()) case header.NDPPrefixInformation: prefix := opt.Subnet() -- cgit v1.2.3 From a551add5d8a5bf631cd9859c761e579fdb33ec82 Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Mon, 13 Apr 2020 17:37:21 -0700 Subject: Remove View.First() and View.RemoveFirst() These methods let users eaily break the VectorisedView abstraction, and allowed netstack to slip into pseudo-enforcement of the "all headers are in the first View" invariant. Removing them and replacing with PullUp(n) breaks this reliance and will make it easier to add iptables support and rework network buffer management. The new View.PullUp(n) method is low cost in the common case, when when all the headers fit in the first View. --- pkg/sentry/socket/netfilter/tcp_matcher.go | 5 +- pkg/sentry/socket/netfilter/udp_matcher.go | 5 +- pkg/tcpip/buffer/view.go | 55 ++++++++++---- pkg/tcpip/buffer/view_test.go | 113 +++++++++++++++++++++++++++++ pkg/tcpip/link/loopback/loopback.go | 10 +-- pkg/tcpip/link/sharedmem/sharedmem_test.go | 2 +- pkg/tcpip/link/sniffer/sniffer.go | 65 +++++++++++++---- pkg/tcpip/network/arp/arp.go | 5 +- pkg/tcpip/network/ipv4/icmp.go | 20 +++-- pkg/tcpip/network/ipv4/ipv4.go | 12 ++- pkg/tcpip/network/ipv6/icmp.go | 74 ++++++++++++------- pkg/tcpip/network/ipv6/icmp_test.go | 3 +- pkg/tcpip/network/ipv6/ipv6.go | 6 +- pkg/tcpip/stack/forwarder_test.go | 13 ++-- pkg/tcpip/stack/iptables.go | 22 +++++- pkg/tcpip/stack/iptables_targets.go | 23 ++++-- pkg/tcpip/stack/nic.go | 34 +++------ pkg/tcpip/stack/packet_buffer.go | 4 +- pkg/tcpip/stack/stack_test.go | 10 ++- pkg/tcpip/stack/transport_test.go | 5 +- pkg/tcpip/transport/icmp/endpoint.go | 8 +- pkg/tcpip/transport/tcp/segment.go | 29 +++++--- pkg/tcpip/transport/tcp/tcp_test.go | 4 +- pkg/tcpip/transport/udp/endpoint.go | 6 +- pkg/tcpip/transport/udp/protocol.go | 9 ++- 25 files changed, 395 insertions(+), 147 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index ff1cfd8f6..55c0f04f3 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -121,12 +121,13 @@ func (tm *TCPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa tcpHeader = header.TCP(pkt.TransportHeader) } else { // The TCP header hasn't been parsed yet. We have to do it here. - if len(pkt.Data.First()) < header.TCPMinimumSize { + hdr, ok := pkt.Data.PullUp(header.TCPMinimumSize) + if !ok { // There's no valid TCP header here, so we hotdrop the // packet. return false, true } - tcpHeader = header.TCP(pkt.Data.First()) + tcpHeader = header.TCP(hdr) } // Check whether the source and destination ports are within the diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index 3359418c1..04d03d494 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -120,12 +120,13 @@ func (um *UDPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa udpHeader = header.UDP(pkt.TransportHeader) } else { // The UDP header hasn't been parsed yet. We have to do it here. - if len(pkt.Data.First()) < header.UDPMinimumSize { + hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize) + if !ok { // There's no valid UDP header here, so we hotdrop the // packet. return false, true } - udpHeader = header.UDP(pkt.Data.First()) + udpHeader = header.UDP(hdr) } // Check whether the source and destination ports are within the diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index 8ec5d5d5c..f01217c91 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -77,7 +77,8 @@ func NewVectorisedView(size int, views []View) VectorisedView { return VectorisedView{views: views, size: size} } -// TrimFront removes the first "count" bytes of the vectorised view. +// TrimFront removes the first "count" bytes of the vectorised view. It panics +// if count > vv.Size(). func (vv *VectorisedView) TrimFront(count int) { for count > 0 && len(vv.views) > 0 { if count < len(vv.views[0]) { @@ -86,7 +87,7 @@ func (vv *VectorisedView) TrimFront(count int) { return } count -= len(vv.views[0]) - vv.RemoveFirst() + vv.removeFirst() } } @@ -104,7 +105,7 @@ func (vv *VectorisedView) Read(v View) (copied int, err error) { count -= len(vv.views[0]) copy(v[copied:], vv.views[0]) copied += len(vv.views[0]) - vv.RemoveFirst() + vv.removeFirst() } if copied == 0 { return 0, io.EOF @@ -126,7 +127,7 @@ func (vv *VectorisedView) ReadToVV(dstVV *VectorisedView, count int) (copied int count -= len(vv.views[0]) dstVV.AppendView(vv.views[0]) copied += len(vv.views[0]) - vv.RemoveFirst() + vv.removeFirst() } return copied } @@ -162,22 +163,37 @@ func (vv *VectorisedView) Clone(buffer []View) VectorisedView { return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size} } -// First returns the first view of the vectorised view. -func (vv *VectorisedView) First() View { +// PullUp returns the first "count" bytes of the vectorised view. If those +// bytes aren't already contiguous inside the vectorised view, PullUp will +// reallocate as needed to make them contiguous. PullUp fails and returns false +// when count > vv.Size(). +func (vv *VectorisedView) PullUp(count int) (View, bool) { if len(vv.views) == 0 { - return nil + return nil, count == 0 + } + if count <= len(vv.views[0]) { + return vv.views[0][:count], true + } + if count > vv.size { + return nil, false } - return vv.views[0] -} -// RemoveFirst removes the first view of the vectorised view. -func (vv *VectorisedView) RemoveFirst() { - if len(vv.views) == 0 { - return + newFirst := NewView(count) + i := 0 + for offset := 0; offset < count; i++ { + copy(newFirst[offset:], vv.views[i]) + if count-offset < len(vv.views[i]) { + vv.views[i].TrimFront(count - offset) + break + } + offset += len(vv.views[i]) + vv.views[i] = nil } - vv.size -= len(vv.views[0]) - vv.views[0] = nil - vv.views = vv.views[1:] + // We're guaranteed that i > 0, since count is too large for the first + // view. + vv.views[i-1] = newFirst + vv.views = vv.views[i-1:] + return newFirst, true } // Size returns the size in bytes of the entire content stored in the vectorised view. @@ -225,3 +241,10 @@ func (vv *VectorisedView) Readers() []bytes.Reader { } return readers } + +// removeFirst panics when len(vv.views) < 1. +func (vv *VectorisedView) removeFirst() { + vv.size -= len(vv.views[0]) + vv.views[0] = nil + vv.views = vv.views[1:] +} diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go index 106e1994c..c56795c7b 100644 --- a/pkg/tcpip/buffer/view_test.go +++ b/pkg/tcpip/buffer/view_test.go @@ -16,6 +16,7 @@ package buffer import ( + "bytes" "reflect" "testing" ) @@ -370,3 +371,115 @@ func TestVVRead(t *testing.T) { }) } } + +var pullUpTestCases = []struct { + comment string + in VectorisedView + count int + want []byte + result VectorisedView + ok bool +}{ + { + comment: "simple case", + in: vv(2, "12"), + count: 1, + want: []byte("1"), + result: vv(2, "12"), + ok: true, + }, + { + comment: "entire View", + in: vv(2, "1", "2"), + count: 1, + want: []byte("1"), + result: vv(2, "1", "2"), + ok: true, + }, + { + comment: "spanning across two Views", + in: vv(3, "1", "23"), + count: 2, + want: []byte("12"), + result: vv(3, "12", "3"), + ok: true, + }, + { + comment: "spanning across all Views", + in: vv(5, "1", "23", "45"), + count: 5, + want: []byte("12345"), + result: vv(5, "12345"), + ok: true, + }, + { + comment: "count = 0", + in: vv(1, "1"), + count: 0, + want: []byte{}, + result: vv(1, "1"), + ok: true, + }, + { + comment: "count = size", + in: vv(1, "1"), + count: 1, + want: []byte("1"), + result: vv(1, "1"), + ok: true, + }, + { + comment: "count too large", + in: vv(3, "1", "23"), + count: 4, + want: nil, + result: vv(3, "1", "23"), + ok: false, + }, + { + comment: "empty vv", + in: vv(0, ""), + count: 1, + want: nil, + result: vv(0, ""), + ok: false, + }, + { + comment: "empty vv, count = 0", + in: vv(0, ""), + count: 0, + want: nil, + result: vv(0, ""), + ok: true, + }, + { + comment: "empty views", + in: vv(3, "", "1", "", "23"), + count: 2, + want: []byte("12"), + result: vv(3, "12", "3"), + ok: true, + }, +} + +func TestPullUp(t *testing.T) { + for _, c := range pullUpTestCases { + got, ok := c.in.PullUp(c.count) + + // Is the return value right? + if ok != c.ok { + t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got an ok of %t. Want %t", + c.comment, c.count, c.in, ok, c.ok) + } + if bytes.Compare(got, View(c.want)) != 0 { + t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got %v. Want %v", + c.comment, c.count, c.in, got, c.want) + } + + // Is the underlying structure right? + if !reflect.DeepEqual(c.in, c.result) { + t.Errorf("Test %q failed when calling PullUp(%d). Got vv with structure %v. Wanted %v", + c.comment, c.count, c.in, c.result) + } + } +} diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 1e2255bfa..073c84ef9 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -98,13 +98,13 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - // Reject the packet if it's shorter than an ethernet header. - if vv.Size() < header.EthernetMinimumSize { + // There should be an ethernet header at the beginning of vv. + hdr, ok := vv.PullUp(header.EthernetMinimumSize) + if !ok { + // Reject the packet if it's shorter than an ethernet header. return tcpip.ErrBadAddress } - - // There should be an ethernet header at the beginning of vv. - linkHeader := header.Ethernet(vv.First()[:header.EthernetMinimumSize]) + linkHeader := header.Ethernet(hdr) vv.TrimFront(len(linkHeader)) e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), stack.PacketBuffer{ Data: vv, diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 27ea3f531..33f640b85 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -674,7 +674,7 @@ func TestSimpleReceive(t *testing.T) { // Wait for packet to be received, then check it. c.waitForPackets(1, time.After(5*time.Second), "Timeout waiting for packet") c.mu.Lock() - rcvd := []byte(c.packets[0].vv.First()) + rcvd := []byte(c.packets[0].vv.ToView()) c.packets = c.packets[:0] c.mu.Unlock() diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index be2537a82..0799c8f4d 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -171,11 +171,7 @@ func (e *endpoint) GSOMaxSize() uint32 { func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { writer := e.writer if writer == nil && atomic.LoadUint32(&LogPackets) == 1 { - first := pkt.Header.View() - if len(first) == 0 { - first = pkt.Data.First() - } - logPacket(prefix, protocol, first, gso) + logPacket(prefix, protocol, pkt, gso) } if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 { totalLength := pkt.Header.UsedLength() + pkt.Data.Size() @@ -238,7 +234,7 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { // Wait implements stack.LinkEndpoint.Wait. func (e *endpoint) Wait() { e.lower.Wait() } -func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View, gso *stack.GSO) { +func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) { // Figure out the network layer info. var transProto uint8 src := tcpip.Address("unknown") @@ -247,28 +243,49 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie size := uint16(0) var fragmentOffset uint16 var moreFragments bool + + // Create a clone of pkt, including any headers if present. Avoid allocating + // backing memory for the clone. + views := [8]buffer.View{} + vv := buffer.NewVectorisedView(0, views[:0]) + vv.AppendView(pkt.Header.View()) + vv.Append(pkt.Data) + switch protocol { case header.IPv4ProtocolNumber: - ipv4 := header.IPv4(b) + hdr, ok := vv.PullUp(header.IPv4MinimumSize) + if !ok { + return + } + ipv4 := header.IPv4(hdr) fragmentOffset = ipv4.FragmentOffset() moreFragments = ipv4.Flags()&header.IPv4FlagMoreFragments == header.IPv4FlagMoreFragments src = ipv4.SourceAddress() dst = ipv4.DestinationAddress() transProto = ipv4.Protocol() size = ipv4.TotalLength() - uint16(ipv4.HeaderLength()) - b = b[ipv4.HeaderLength():] + vv.TrimFront(int(ipv4.HeaderLength())) id = int(ipv4.ID()) case header.IPv6ProtocolNumber: - ipv6 := header.IPv6(b) + hdr, ok := vv.PullUp(header.IPv6MinimumSize) + if !ok { + return + } + ipv6 := header.IPv6(hdr) src = ipv6.SourceAddress() dst = ipv6.DestinationAddress() transProto = ipv6.NextHeader() size = ipv6.PayloadLength() - b = b[header.IPv6MinimumSize:] + vv.TrimFront(header.IPv6MinimumSize) case header.ARPProtocolNumber: - arp := header.ARP(b) + hdr, ok := vv.PullUp(header.ARPSize) + if !ok { + return + } + vv.TrimFront(header.ARPSize) + arp := header.ARP(hdr) log.Infof( "%s arp %v (%v) -> %v (%v) valid:%v", prefix, @@ -284,7 +301,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie // We aren't guaranteed to have a transport header - it's possible for // writes via raw endpoints to contain only network headers. - if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && len(b) < minSize { + if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && vv.Size() < minSize { log.Infof("%s %v -> %v transport protocol: %d, but no transport header found (possible raw packet)", prefix, src, dst, transProto) return } @@ -297,7 +314,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie switch tcpip.TransportProtocolNumber(transProto) { case header.ICMPv4ProtocolNumber: transName = "icmp" - icmp := header.ICMPv4(b) + hdr, ok := vv.PullUp(header.ICMPv4MinimumSize) + if !ok { + break + } + icmp := header.ICMPv4(hdr) icmpType := "unknown" if fragmentOffset == 0 { switch icmp.Type() { @@ -330,7 +351,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie case header.ICMPv6ProtocolNumber: transName = "icmp" - icmp := header.ICMPv6(b) + hdr, ok := vv.PullUp(header.ICMPv6MinimumSize) + if !ok { + break + } + icmp := header.ICMPv6(hdr) icmpType := "unknown" switch icmp.Type() { case header.ICMPv6DstUnreachable: @@ -361,7 +386,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie case header.UDPProtocolNumber: transName = "udp" - udp := header.UDP(b) + hdr, ok := vv.PullUp(header.UDPMinimumSize) + if !ok { + break + } + udp := header.UDP(hdr) if fragmentOffset == 0 && len(udp) >= header.UDPMinimumSize { srcPort = udp.SourcePort() dstPort = udp.DestinationPort() @@ -371,7 +400,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie case header.TCPProtocolNumber: transName = "tcp" - tcp := header.TCP(b) + hdr, ok := vv.PullUp(header.TCPMinimumSize) + if !ok { + break + } + tcp := header.TCP(hdr) if fragmentOffset == 0 && len(tcp) >= header.TCPMinimumSize { offset := int(tcp.DataOffset()) if offset < header.TCPMinimumSize { diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 7acbfa0a8..cf73a939e 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -93,7 +93,10 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf } func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - v := pkt.Data.First() + v, ok := pkt.Data.PullUp(header.ARPSize) + if !ok { + return + } h := header.ARP(v) if !h.IsValid() { return diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index c4bf1ba5c..4cbefe5ab 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -25,7 +25,11 @@ import ( // used to find out which transport endpoint must be notified about the ICMP // packet. func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { - h := header.IPv4(pkt.Data.First()) + h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return + } + hdr := header.IPv4(h) // We don't use IsValid() here because ICMP only requires that the IP // header plus 8 bytes of the transport header be included. So it's @@ -34,12 +38,12 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // // Drop packet if it doesn't have the basic IPv4 header or if the // original source address doesn't match the endpoint's address. - if len(h) < header.IPv4MinimumSize || h.SourceAddress() != e.id.LocalAddress { + if hdr.SourceAddress() != e.id.LocalAddress { return } - hlen := int(h.HeaderLength()) - if pkt.Data.Size() < hlen || h.FragmentOffset() != 0 { + hlen := int(hdr.HeaderLength()) + if pkt.Data.Size() < hlen || hdr.FragmentOffset() != 0 { // We won't be able to handle this if it doesn't contain the // full IPv4 header, or if it's a fragment not at offset 0 // (because it won't have the transport header). @@ -48,15 +52,15 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // Skip the ip header, then deliver control message. pkt.Data.TrimFront(hlen) - p := h.TransportProtocol() - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + p := hdr.TransportProtocol() + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) { stats := r.Stats() received := stats.ICMP.V4PacketsReceived - v := pkt.Data.First() - if len(v) < header.ICMPv4MinimumSize { + v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) + if !ok { received.Invalid.Increment() return } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 104aafbed..17202cc7a 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -328,7 +328,11 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required // checks. - ip := header.IPv4(pkt.Data.First()) + h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return tcpip.ErrInvalidOptionValue + } + ip := header.IPv4(h) if !ip.IsValid(pkt.Data.Size()) { return tcpip.ErrInvalidOptionValue } @@ -378,7 +382,11 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - headerView := pkt.Data.First() + headerView, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + r.Stats().IP.MalformedPacketsReceived.Increment() + return + } h := header.IPv4(headerView) if !h.IsValid(pkt.Data.Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index b68983d10..bdf3a0d25 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -28,7 +28,11 @@ import ( // used to find out which transport endpoint must be notified about the ICMP // packet. func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { - h := header.IPv6(pkt.Data.First()) + h, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + if !ok { + return + } + hdr := header.IPv6(h) // We don't use IsValid() here because ICMP only requires that up to // 1280 bytes of the original packet be included. So it's likely that it @@ -36,17 +40,21 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // // Drop packet if it doesn't have the basic IPv6 header or if the // original source address doesn't match the endpoint's address. - if len(h) < header.IPv6MinimumSize || h.SourceAddress() != e.id.LocalAddress { + if hdr.SourceAddress() != e.id.LocalAddress { return } // Skip the IP header, then handle the fragmentation header if there // is one. pkt.Data.TrimFront(header.IPv6MinimumSize) - p := h.TransportProtocol() + p := hdr.TransportProtocol() if p == header.IPv6FragmentHeader { - f := header.IPv6Fragment(pkt.Data.First()) - if !f.IsValid() || f.FragmentOffset() != 0 { + f, ok := pkt.Data.PullUp(header.IPv6FragmentHeaderSize) + if !ok { + return + } + fragHdr := header.IPv6Fragment(f) + if !fragHdr.IsValid() || fragHdr.FragmentOffset() != 0 { // We can't handle fragments that aren't at offset 0 // because they don't have the transport headers. return @@ -55,19 +63,19 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // Skip fragmentation header and find out the actual protocol // number. pkt.Data.TrimFront(header.IPv6FragmentHeaderSize) - p = f.TransportProtocol() + p = fragHdr.TransportProtocol() } // Deliver the control packet to the transport endpoint. - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.PacketBuffer, hasFragmentHeader bool) { stats := r.Stats().ICMP sent := stats.V6PacketsSent received := stats.V6PacketsReceived - v := pkt.Data.First() - if len(v) < header.ICMPv6MinimumSize { + v, ok := pkt.Data.PullUp(header.ICMPv6HeaderSize) + if !ok { received.Invalid.Increment() return } @@ -76,11 +84,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P // Validate ICMPv6 checksum before processing the packet. // - // Only the first view in vv is accounted for by h. To account for the - // rest of vv, a shallow copy is made and the first view is removed. // This copy is used as extra payload during the checksum calculation. payload := pkt.Data.Clone(nil) - payload.RemoveFirst() + payload.TrimFront(len(h)) if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want { received.Invalid.Increment() return @@ -101,34 +107,40 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P switch h.Type() { case header.ICMPv6PacketTooBig: received.PacketTooBig.Increment() - if len(v) < header.ICMPv6PacketTooBigMinimumSize { + hdr, ok := pkt.Data.PullUp(header.ICMPv6PacketTooBigMinimumSize) + if !ok { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize) - mtu := h.MTU() + mtu := header.ICMPv6(hdr).MTU() e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) case header.ICMPv6DstUnreachable: received.DstUnreachable.Increment() - if len(v) < header.ICMPv6DstUnreachableMinimumSize { + hdr, ok := pkt.Data.PullUp(header.ICMPv6DstUnreachableMinimumSize) + if !ok { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize) - switch h.Code() { + switch header.ICMPv6(hdr).Code() { case header.ICMPv6PortUnreachable: e.handleControl(stack.ControlPortUnreachable, 0, pkt) } case header.ICMPv6NeighborSolicit: received.NeighborSolicit.Increment() - if len(v) < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() { + if pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() { received.Invalid.Increment() return } - ns := header.NDPNeighborSolicit(h.NDPPayload()) + // The remainder of payload must be only the neighbor solicitation, so + // payload.ToView() always returns the solicitation. Per RFC 6980 section 5, + // NDP messages cannot be fragmented. Also note that in the common case NDP + // datagrams are very small and ToView() will not incur allocations. + ns := header.NDPNeighborSolicit(payload.ToView()) it, err := ns.Options().Iter(true) if err != nil { // If we have a malformed NDP NS option, drop the packet. @@ -286,12 +298,16 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6NeighborAdvert: received.NeighborAdvert.Increment() - if len(v) < header.ICMPv6NeighborAdvertSize || !isNDPValid() { + if pkt.Data.Size() < header.ICMPv6NeighborAdvertSize || !isNDPValid() { received.Invalid.Increment() return } - na := header.NDPNeighborAdvert(h.NDPPayload()) + // The remainder of payload must be only the neighbor advertisement, so + // payload.ToView() always returns the advertisement. Per RFC 6980 section + // 5, NDP messages cannot be fragmented. Also note that in the common case + // NDP datagrams are very small and ToView() will not incur allocations. + na := header.NDPNeighborAdvert(payload.ToView()) it, err := na.Options().Iter(true) if err != nil { // If we have a malformed NDP NA option, drop the packet. @@ -363,14 +379,15 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6EchoRequest: received.EchoRequest.Increment() - if len(v) < header.ICMPv6EchoMinimumSize { + icmpHdr, ok := pkt.Data.PullUp(header.ICMPv6EchoMinimumSize) + if !ok { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6EchoMinimumSize) hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize) packet := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) - copy(packet, h) + copy(packet, icmpHdr) packet.SetType(header.ICMPv6EchoReply) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ @@ -384,7 +401,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6EchoReply: received.EchoReply.Increment() - if len(v) < header.ICMPv6EchoMinimumSize { + if pkt.Data.Size() < header.ICMPv6EchoMinimumSize { received.Invalid.Increment() return } @@ -406,8 +423,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6RouterAdvert: received.RouterAdvert.Increment() - p := h.NDPPayload() - if len(p) < header.NDPRAMinimumSize || !isNDPValid() { + // Is the NDP payload of sufficient size to hold a Router + // Advertisement? + if pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize || !isNDPValid() { received.Invalid.Increment() return } @@ -425,7 +443,11 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P return } - ra := header.NDPRouterAdvert(p) + // The remainder of payload must be only the router advertisement, so + // payload.ToView() always returns the advertisement. Per RFC 6980 section + // 5, NDP messages cannot be fragmented. Also note that in the common case + // NDP datagrams are very small and ToView() will not incur allocations. + ra := header.NDPRouterAdvert(payload.ToView()) opts := ra.Options() // Are options valid as per the wire format? diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index bd099a7f8..d412ff688 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -166,7 +166,8 @@ func TestICMPCounts(t *testing.T) { }, { typ: header.ICMPv6NeighborSolicit, - size: header.ICMPv6NeighborSolicitMinimumSize}, + size: header.ICMPv6NeighborSolicitMinimumSize, + }, { typ: header.ICMPv6NeighborAdvert, size: header.ICMPv6NeighborAdvertMinimumSize, diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 331b0817b..486725131 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -171,7 +171,11 @@ func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffe // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - headerView := pkt.Data.First() + headerView, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + if !ok { + r.Stats().IP.MalformedPacketsReceived.Increment() + return + } h := header.IPv6(headerView) if !h.IsValid(pkt.Data.Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index e9c652042..c7c663498 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -70,7 +70,10 @@ func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID { func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt PacketBuffer) { // Consume the network header. - b := pkt.Data.First() + b, ok := pkt.Data.PullUp(fwdTestNetHeaderLen) + if !ok { + return + } pkt.Data.TrimFront(fwdTestNetHeaderLen) // Dispatch the packet to the transport protocol. @@ -473,7 +476,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Header.View() + b := p.Pkt.Data.ToView() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -517,7 +520,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Header.View() + b := p.Pkt.Data.ToView() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -564,7 +567,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Header.View() + b := p.Pkt.Data.ToView() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -619,7 +622,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // The first 5 packets (address 3 to 7) should not be forwarded // because their address resolutions are interrupted. - b := p.Pkt.Header.View() + b := p.Pkt.Data.ToView() if b[0] < 8 { t.Fatalf("got b[0] = %d, want b[0] >= 8", b[0]) } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 6c0a4b24d..6b91159d4 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -212,6 +212,11 @@ func (it *IPTables) Check(hook Hook, pkt PacketBuffer) bool { // CheckPackets runs pkts through the rules for hook and returns a map of packets that // should not go forward. // +// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// +// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a +// precondition. +// // NOTE: unlike the Check API the returned map contains packets that should be // dropped. func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*PacketBuffer]struct{}) { @@ -226,7 +231,9 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*Pa return drop } -// Precondition: pkt.NetworkHeader is set. +// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a +// precondition. func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. @@ -271,14 +278,21 @@ func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx return chainDrop } -// Precondition: pk.NetworkHeader is set. +// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a +// precondition. func (it *IPTables) checkRule(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) (RuleVerdict, int) { rule := table.Rules[ruleIdx] // If pkt.NetworkHeader hasn't been set yet, it will be contained in - // pkt.Data.First(). + // pkt.Data. if pkt.NetworkHeader == nil { - pkt.NetworkHeader = pkt.Data.First() + var ok bool + pkt.NetworkHeader, ok = pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + // Precondition has been violated. + panic(fmt.Sprintf("iptables checks require IPv4 headers of at least %d bytes", header.IPv4MinimumSize)) + } } // Check whether the packet matches the IP header filter. diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 7b4543caf..8be61f4b1 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -96,9 +96,12 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { newPkt := pkt.Clone() // Set network header. - headerView := newPkt.Data.First() + headerView, ok := newPkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return RuleDrop, 0 + } netHeader := header.IPv4(headerView) - newPkt.NetworkHeader = headerView[:header.IPv4MinimumSize] + newPkt.NetworkHeader = headerView hlen := int(netHeader.HeaderLength()) tlen := int(netHeader.TotalLength()) @@ -117,10 +120,14 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { if newPkt.TransportHeader != nil { udpHeader = header.UDP(newPkt.TransportHeader) } else { - if len(pkt.Data.First()) < header.UDPMinimumSize { + if pkt.Data.Size() < header.UDPMinimumSize { + return RuleDrop, 0 + } + hdr, ok := newPkt.Data.PullUp(header.UDPMinimumSize) + if !ok { return RuleDrop, 0 } - udpHeader = header.UDP(newPkt.Data.First()) + udpHeader = header.UDP(hdr) } udpHeader.SetDestinationPort(rt.MinPort) case header.TCPProtocolNumber: @@ -128,10 +135,14 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { if newPkt.TransportHeader != nil { tcpHeader = header.TCP(newPkt.TransportHeader) } else { - if len(pkt.Data.First()) < header.TCPMinimumSize { + if pkt.Data.Size() < header.TCPMinimumSize { return RuleDrop, 0 } - tcpHeader = header.TCP(newPkt.TransportHeader) + hdr, ok := newPkt.Data.PullUp(header.TCPMinimumSize) + if !ok { + return RuleDrop, 0 + } + tcpHeader = header.TCP(hdr) } // TODO(gvisor.dev/issue/170): Need to recompute checksum // and implement nat connection tracking to support TCP. diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 016dbe15e..0c2b1f36a 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1203,12 +1203,12 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link n.stack.stats.IP.PacketsReceived.Increment() } - if len(pkt.Data.First()) < netProto.MinimumPacketSize() { + netHeader, ok := pkt.Data.PullUp(netProto.MinimumPacketSize()) + if !ok { n.stack.stats.MalformedRcvdPackets.Increment() return } - - src, dst := netProto.ParseAddresses(pkt.Data.First()) + src, dst := netProto.ParseAddresses(netHeader) if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil { // The source address is one of our own, so we never should have gotten a @@ -1289,22 +1289,8 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { // TODO(b/143425874) Decrease the TTL field in forwarded packets. - - firstData := pkt.Data.First() - pkt.Data.RemoveFirst() - - if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen == 0 { - pkt.Header = buffer.NewPrependableFromView(firstData) - } else { - firstDataLen := len(firstData) - - // pkt.Header should have enough capacity to hold n.linkEP's headers. - pkt.Header = buffer.NewPrependable(firstDataLen + linkHeaderLen) - - // TODO(b/151227689): avoid copying the packet when forwarding - if n := copy(pkt.Header.Prepend(firstDataLen), firstData); n != firstDataLen { - panic(fmt.Sprintf("copied %d bytes, expected %d", n, firstDataLen)) - } + if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen != 0 { + pkt.Header = buffer.NewPrependable(linkHeaderLen) } if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil { @@ -1332,12 +1318,13 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // validly formed. n.stack.demux.deliverRawPacket(r, protocol, pkt) - if len(pkt.Data.First()) < transProto.MinimumPacketSize() { + transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize()) + if !ok { n.stack.stats.MalformedRcvdPackets.Increment() return } - srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) + srcPort, dstPort, err := transProto.ParsePorts(transHeader) if err != nil { n.stack.stats.MalformedRcvdPackets.Increment() return @@ -1375,11 +1362,12 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp // ICMPv4 only guarantees that 8 bytes of the transport protocol will // be present in the payload. We know that the ports are within the // first 8 bytes for all known transport protocols. - if len(pkt.Data.First()) < 8 { + transHeader, ok := pkt.Data.PullUp(8) + if !ok { return } - srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) + srcPort, dstPort, err := transProto.ParsePorts(transHeader) if err != nil { return } diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index dc125f25e..e954a8b7e 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -37,7 +37,9 @@ type PacketBuffer struct { Data buffer.VectorisedView // Header holds the headers of outbound packets. As a packet is passed - // down the stack, each layer adds to Header. + // down the stack, each layer adds to Header. Note that forwarded + // packets don't populate Headers on their way out -- their headers and + // payload are never parsed out and remain in Data. Header buffer.Prependable // These fields are used by both inbound and outbound packets. They diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index c7634ceb1..d45d2cc1f 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -95,16 +95,18 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffe f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ // Consume the network header. - b := pkt.Data.First() + b, ok := pkt.Data.PullUp(fakeNetHeaderLen) + if !ok { + return + } pkt.Data.TrimFront(fakeNetHeaderLen) // Handle control packets. if b[2] == uint8(fakeControlProtocol) { - nb := pkt.Data.First() - if len(nb) < fakeNetHeaderLen { + nb, ok := pkt.Data.PullUp(fakeNetHeaderLen) + if !ok { return } - pkt.Data.TrimFront(fakeNetHeaderLen) f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, pkt) return diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 3084e6593..a611e44ab 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -642,10 +642,11 @@ func TestTransportForwarding(t *testing.T) { t.Fatal("Response packet not forwarded") } - if dst := p.Pkt.Header.View()[0]; dst != 3 { + hdrs := p.Pkt.Data.ToView() + if dst := hdrs[0]; dst != 3 { t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) } - if src := p.Pkt.Header.View()[1]; src != 1 { + if src := hdrs[1]; src != 1 { t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) } } diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index feef8dca0..b1d820372 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -747,15 +747,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // Only accept echo replies. switch e.NetProto { case header.IPv4ProtocolNumber: - h := header.ICMPv4(pkt.Data.First()) - if h.Type() != header.ICMPv4EchoReply { + h, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) + if !ok || header.ICMPv4(h).Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } case header.IPv6ProtocolNumber: - h := header.ICMPv6(pkt.Data.First()) - if h.Type() != header.ICMPv6EchoReply { + h, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize) + if !ok || header.ICMPv6(h).Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 40461fd31..7712ce652 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -144,7 +144,11 @@ func (s *segment) logicalLen() seqnum.Size { // TCP checksum and stores the checksum and result of checksum verification in // the csum and csumValid fields of the segment. func (s *segment) parse() bool { - h := header.TCP(s.data.First()) + h, ok := s.data.PullUp(header.TCPMinimumSize) + if !ok { + return false + } + hdr := header.TCP(h) // h is the header followed by the payload. We check that the offset to // the data respects the following constraints: @@ -156,12 +160,16 @@ func (s *segment) parse() bool { // N.B. The segment has already been validated as having at least the // minimum TCP size before reaching here, so it's safe to read the // fields. - offset := int(h.DataOffset()) - if offset < header.TCPMinimumSize || offset > len(h) { + offset := int(hdr.DataOffset()) + if offset < header.TCPMinimumSize { + return false + } + hdrWithOpts, ok := s.data.PullUp(offset) + if !ok { return false } - s.options = []byte(h[header.TCPMinimumSize:offset]) + s.options = []byte(hdrWithOpts[header.TCPMinimumSize:]) s.parsedOptions = header.ParseTCPOptions(s.options) // Query the link capabilities to decide if checksum validation is @@ -173,18 +181,19 @@ func (s *segment) parse() bool { s.data.TrimFront(offset) } if verifyChecksum { - s.csum = h.Checksum() + hdr = header.TCP(hdrWithOpts) + s.csum = hdr.Checksum() xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size())) - xsum = h.CalculateChecksum(xsum) + xsum = hdr.CalculateChecksum(xsum) s.data.TrimFront(offset) xsum = header.ChecksumVV(s.data, xsum) s.csumValid = xsum == 0xffff } - s.sequenceNumber = seqnum.Value(h.SequenceNumber()) - s.ackNumber = seqnum.Value(h.AckNumber()) - s.flags = h.Flags() - s.window = seqnum.Size(h.WindowSize()) + s.sequenceNumber = seqnum.Value(hdr.SequenceNumber()) + s.ackNumber = seqnum.Value(hdr.AckNumber()) + s.flags = hdr.Flags() + s.window = seqnum.Size(hdr.WindowSize()) return true } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index ab1014c7f..286c66cf5 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -3548,7 +3548,7 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { AckNum: c.IRS.Add(1), RcvWnd: 30000, }) - tcpbuf := vv.First()[header.IPv4MinimumSize:] + tcpbuf := vv.ToView()[header.IPv4MinimumSize:] tcpbuf[header.TCPDataOffset] = ((header.TCPMinimumSize - 1) / 4) << 4 c.SendSegment(vv) @@ -3575,7 +3575,7 @@ func TestReceivedIncorrectChecksumIncrement(t *testing.T) { AckNum: c.IRS.Add(1), RcvWnd: 30000, }) - tcpbuf := vv.First()[header.IPv4MinimumSize:] + tcpbuf := vv.ToView()[header.IPv4MinimumSize:] // Overwrite a byte in the payload which should cause checksum // verification to fail. tcpbuf[(tcpbuf[header.TCPDataOffset]>>4)*4] = 0x4 diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index edb54f0be..756ab913a 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1250,8 +1250,8 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // endpoint. func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) { // Get the header then trim it from the view. - hdr := header.UDP(pkt.Data.First()) - if int(hdr.Length()) > pkt.Data.Size() { + hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize) + if !ok || int(header.UDP(hdr).Length()) > pkt.Data.Size() { // Malformed packet. e.stack.Stats().UDP.MalformedPacketsReceived.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() @@ -1286,7 +1286,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk senderAddress: tcpip.FullAddress{ NIC: r.NICID(), Addr: id.RemoteAddress, - Port: hdr.SourcePort(), + Port: header.UDP(hdr).SourcePort(), }, } packet.data = pkt.Data diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 6e31a9bac..52af6de22 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -68,8 +68,13 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // that don't match any existing endpoint. func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { // Get the header then trim it from the view. - hdr := header.UDP(pkt.Data.First()) - if int(hdr.Length()) > pkt.Data.Size() { + h, ok := pkt.Data.PullUp(header.UDPMinimumSize) + if !ok { + // Malformed packet. + r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() + return true + } + if int(header.UDP(h).Length()) > pkt.Data.Size() { // Malformed packet. r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() return true -- cgit v1.2.3 From 486759a37d61b770019a81af00e0e733c655b3a8 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Fri, 17 Apr 2020 13:54:47 -0700 Subject: Support NDP DNS Search List option Inform the netstack integrator when the netstack receives an NDP Router Advertisement message with the NDP DNS Search List option with at least one domain name. The stack will not maintain any state related to the search list - the integrator is expected to maintain any required state and invalidate domain names after their lifetime expires, or refresh the lifetime when a new one is received for a known domain name. Test: - header_test.TestNDPDNSSearchListOption - header_test.TestNDPDNSSearchListOptionSerialize - header_test.TestNDPSearchListOptionDomainNameLabelInvalidSymbols - header_test.TestNDPOptionsIterCheck - stack_test.TestNDPDNSSearchListDispatch PiperOrigin-RevId: 307109375 --- pkg/tcpip/header/ndp_options.go | 213 +++++++++++++++ pkg/tcpip/header/ndp_test.go | 574 ++++++++++++++++++++++++++++++++++++++++ pkg/tcpip/stack/ndp.go | 18 ++ pkg/tcpip/stack/ndp_test.go | 125 +++++++++ 4 files changed, 930 insertions(+) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go index 444e90820..5d3975c56 100644 --- a/pkg/tcpip/header/ndp_options.go +++ b/pkg/tcpip/header/ndp_options.go @@ -45,6 +45,10 @@ const ( // NDPRecursiveDNSServerOptionType is the type of the Recursive DNS // Server option, as per RFC 8106 section 5.1. NDPRecursiveDNSServerOptionType NDPOptionIdentifier = 25 + + // NDPDNSSearchListOptionType is the type of the DNS Search List option, + // as per RFC 8106 section 5.2. + NDPDNSSearchListOptionType = 31 ) const ( @@ -115,6 +119,27 @@ const ( // RFC 8106 section 5.3.1. minNDPRecursiveDNSServerBodySize = 22 + // ndpDNSSearchListLifetimeOffset is the start of the 4-byte + // Lifetime field within an NDPDNSSearchList. + ndpDNSSearchListLifetimeOffset = 2 + + // ndpDNSSearchListDomainNamesOffset is the start of the DNS search list + // domain names within an NDPDNSSearchList. + ndpDNSSearchListDomainNamesOffset = 6 + + // minNDPDNSSearchListBodySize is the minimum NDP DNS Search List option's + // body size when it contains at least one domain name, as per RFC 8106 + // section 5.3.1. + minNDPDNSSearchListBodySize = 14 + + // maxDomainNameLabelLength is the maximum length of a domain name + // label, as per RFC 1035 section 3.1. + maxDomainNameLabelLength = 63 + + // maxDomainNameLength is the maximum length of a domain name, including + // label AND label length octet, as per RFC 1035 section 3.1. + maxDomainNameLength = 255 + // lengthByteUnits is the multiplier factor for the Length field of an // NDP option. That is, the length field for NDP options is in units of // 8 octets, as per RFC 4861 section 4.6. @@ -224,6 +249,14 @@ func (i *NDPOptionIterator) Next() (NDPOption, bool, error) { return opt, false, nil + case NDPDNSSearchListOptionType: + opt := NDPDNSSearchList(body) + if err := opt.checkDomainNames(); err != nil { + return nil, true, err + } + + return opt, false, nil + default: // We do not yet recognize the option, just skip for // now. This is okay because RFC 4861 allows us to @@ -684,3 +717,183 @@ func (o NDPRecursiveDNSServer) iterAddresses(fn func(tcpip.Address)) error { return nil } + +// NDPDNSSearchList is the NDP DNS Search List option, as defined by +// RFC 8106 section 5.2. +type NDPDNSSearchList []byte + +// Type implements NDPOption.Type. +func (o NDPDNSSearchList) Type() NDPOptionIdentifier { + return NDPDNSSearchListOptionType +} + +// Length implements NDPOption.Length. +func (o NDPDNSSearchList) Length() int { + return len(o) +} + +// serializeInto implements NDPOption.serializeInto. +func (o NDPDNSSearchList) serializeInto(b []byte) int { + used := copy(b, o) + + // Zero out the reserved bytes that are before the Lifetime field. + for i := 0; i < ndpDNSSearchListLifetimeOffset; i++ { + b[i] = 0 + } + + return used +} + +// String implements fmt.Stringer.String. +func (o NDPDNSSearchList) String() string { + lt := o.Lifetime() + domainNames, err := o.DomainNames() + if err != nil { + return fmt.Sprintf("%T([] valid for %s; err = %s)", o, lt, err) + } + return fmt.Sprintf("%T(%s valid for %s)", o, domainNames, lt) +} + +// Lifetime returns the length of time that the DNS search list of domain names +// in this option may be used for name resolution. +// +// Note, a value of 0 implies the domain names should no longer be used, +// and a value of infinity/forever is represented by NDPInfiniteLifetime. +func (o NDPDNSSearchList) Lifetime() time.Duration { + // The field is the time in seconds, as per RFC 8106 section 5.1. + return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpDNSSearchListLifetimeOffset:])) +} + +// DomainNames returns a DNS search list of domain names. +// +// DomainNames will parse the backing buffer as outlined by RFC 1035 section +// 3.1 and return a list of strings, with all domain names in lower case. +func (o NDPDNSSearchList) DomainNames() ([]string, error) { + var domainNames []string + return domainNames, o.iterDomainNames(func(domainName string) { domainNames = append(domainNames, domainName) }) +} + +// checkDomainNames iterates over the domain names in an NDP DNS Search List +// option and returns any error it encounters. +func (o NDPDNSSearchList) checkDomainNames() error { + return o.iterDomainNames(nil) +} + +// iterDomainNames iterates over the domain names in an NDP DNS Search List +// option and calls a function with each valid domain name. +func (o NDPDNSSearchList) iterDomainNames(fn func(string)) error { + if l := len(o); l < minNDPDNSSearchListBodySize { + return fmt.Errorf("got %d bytes for NDP DNS Search List option's body, expected at least %d bytes: %w", l, minNDPDNSSearchListBodySize, io.ErrUnexpectedEOF) + } + + var searchList bytes.Reader + searchList.Reset(o[ndpDNSSearchListDomainNamesOffset:]) + + var scratch [maxDomainNameLength]byte + domainName := bytes.NewBuffer(scratch[:]) + + // Parse the domain names, as per RFC 1035 section 3.1. + for searchList.Len() != 0 { + domainName.Reset() + + // Parse a label within a domain name, as per RFC 1035 section 3.1. + for { + // The first byte is the label length. + labelLenByte, err := searchList.ReadByte() + if err != nil { + if err != io.EOF { + // ReadByte should only ever return nil or io.EOF. + panic(fmt.Sprintf("unexpected error when reading a label's length: %s", err)) + } + + // We use io.ErrUnexpectedEOF as exhausting the buffer is unexpected + // once we start parsing a domain name; we expect the buffer to contain + // enough bytes for the whole domain name. + return fmt.Errorf("unexpected exhausted buffer while parsing a new label for a domain from NDP Search List option: %w", io.ErrUnexpectedEOF) + } + labelLen := int(labelLenByte) + + // A zero-length label implies the end of a domain name. + if labelLen == 0 { + // If the domain name is empty or we have no callback function, do + // nothing further with the current domain name. + if domainName.Len() == 0 || fn == nil { + break + } + + // Ignore the trailing period in the parsed domain name. + domainName.Truncate(domainName.Len() - 1) + fn(domainName.String()) + break + } + + // The label's length must not exceed the maximum length for a label. + if labelLen > maxDomainNameLabelLength { + return fmt.Errorf("label length of %d bytes is greater than the max label length of %d bytes for an NDP Search List option: %w", labelLen, maxDomainNameLabelLength, ErrNDPOptMalformedBody) + } + + // The label (and trailing period) must not make the domain name too long. + if labelLen+1 > domainName.Cap()-domainName.Len() { + return fmt.Errorf("label would make an NDP Search List option's domain name longer than the max domain name length of %d bytes: %w", maxDomainNameLength, ErrNDPOptMalformedBody) + } + + // Copy the label and add a trailing period. + for i := 0; i < labelLen; i++ { + b, err := searchList.ReadByte() + if err != nil { + if err != io.EOF { + panic(fmt.Sprintf("unexpected error when reading domain name's label: %s", err)) + } + + return fmt.Errorf("read %d out of %d bytes for a domain name's label from NDP Search List option: %w", i, labelLen, io.ErrUnexpectedEOF) + } + + // As per RFC 1035 section 2.3.1: + // 1) the label must only contain ASCII include letters, digits and + // hyphens + // 2) the first character in a label must be a letter + // 3) the last letter in a label must be a letter or digit + + if !isLetter(b) { + if i == 0 { + return fmt.Errorf("first character of a domain name's label in an NDP Search List option must be a letter, got character code = %d: %w", b, ErrNDPOptMalformedBody) + } + + if b == '-' { + if i == labelLen-1 { + return fmt.Errorf("last character of a domain name's label in an NDP Search List option must not be a hyphen (-): %w", ErrNDPOptMalformedBody) + } + } else if !isDigit(b) { + return fmt.Errorf("domain name's label in an NDP Search List option may only contain letters, digits and hyphens, got character code = %d: %w", b, ErrNDPOptMalformedBody) + } + } + + // If b is an upper case character, make it lower case. + if isUpperLetter(b) { + b = b - 'A' + 'a' + } + + if err := domainName.WriteByte(b); err != nil { + panic(fmt.Sprintf("unexpected error writing label to domain name buffer: %s", err)) + } + } + if err := domainName.WriteByte('.'); err != nil { + panic(fmt.Sprintf("unexpected error writing trailing period to domain name buffer: %s", err)) + } + } + } + + return nil +} + +func isLetter(b byte) bool { + return b >= 'a' && b <= 'z' || isUpperLetter(b) +} + +func isUpperLetter(b byte) bool { + return b >= 'A' && b <= 'Z' +} + +func isDigit(b byte) bool { + return b >= '0' && b <= '9' +} diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go index 2341329c4..dc4591253 100644 --- a/pkg/tcpip/header/ndp_test.go +++ b/pkg/tcpip/header/ndp_test.go @@ -17,7 +17,9 @@ package header import ( "bytes" "errors" + "fmt" "io" + "regexp" "testing" "time" @@ -667,6 +669,477 @@ func TestNDPRecursiveDNSServerOption(t *testing.T) { } } +// TestNDPDNSSearchListOption tests the getters of NDPDNSSearchList. +func TestNDPDNSSearchListOption(t *testing.T) { + tests := []struct { + name string + buf []byte + lifetime time.Duration + domainNames []string + err error + }{ + { + name: "Valid1Label", + buf: []byte{ + 0, 0, + 0, 0, 0, 1, + 3, 'a', 'b', 'c', + 0, + 0, 0, 0, + }, + lifetime: time.Second, + domainNames: []string{ + "abc", + }, + err: nil, + }, + { + name: "Valid2Label", + buf: []byte{ + 0, 0, + 0, 0, 0, 5, + 3, 'a', 'b', 'c', + 4, 'a', 'b', 'c', 'd', + 0, + 0, 0, 0, 0, 0, 0, + }, + lifetime: 5 * time.Second, + domainNames: []string{ + "abc.abcd", + }, + err: nil, + }, + { + name: "Valid3Label", + buf: []byte{ + 0, 0, + 1, 0, 0, 0, + 3, 'a', 'b', 'c', + 4, 'a', 'b', 'c', 'd', + 1, 'e', + 0, + 0, 0, 0, 0, + }, + lifetime: 16777216 * time.Second, + domainNames: []string{ + "abc.abcd.e", + }, + err: nil, + }, + { + name: "Valid2Domains", + buf: []byte{ + 0, 0, + 1, 2, 3, 4, + 3, 'a', 'b', 'c', + 0, + 2, 'd', 'e', + 3, 'x', 'y', 'z', + 0, + 0, 0, 0, + }, + lifetime: 16909060 * time.Second, + domainNames: []string{ + "abc", + "de.xyz", + }, + err: nil, + }, + { + name: "Valid3DomainsMixedCase", + buf: []byte{ + 0, 0, + 0, 0, 0, 0, + 3, 'a', 'B', 'c', + 0, + 2, 'd', 'E', + 3, 'X', 'y', 'z', + 0, + 1, 'J', + 0, + }, + lifetime: 0, + domainNames: []string{ + "abc", + "de.xyz", + "j", + }, + err: nil, + }, + { + name: "ValidDomainAfterNULL", + buf: []byte{ + 0, 0, + 0, 0, 0, 0, + 3, 'a', 'B', 'c', + 0, 0, 0, 0, + 2, 'd', 'E', + 3, 'X', 'y', 'z', + 0, + }, + lifetime: 0, + domainNames: []string{ + "abc", + "de.xyz", + }, + err: nil, + }, + { + name: "Valid0Domains", + buf: []byte{ + 0, 0, + 0, 0, 0, 0, + 0, + 0, 0, 0, 0, 0, 0, 0, + }, + lifetime: 0, + domainNames: nil, + err: nil, + }, + { + name: "NoTrailingNull", + buf: []byte{ + 0, 0, + 0, 0, 0, 0, + 7, 'a', 'b', 'c', 'd', 'e', 'f', 'g', + }, + lifetime: 0, + domainNames: nil, + err: io.ErrUnexpectedEOF, + }, + { + name: "IncorrectLength", + buf: []byte{ + 0, 0, + 0, 0, 0, 0, + 8, 'a', 'b', 'c', 'd', 'e', 'f', 'g', + }, + lifetime: 0, + domainNames: nil, + err: io.ErrUnexpectedEOF, + }, + { + name: "IncorrectLengthWithNULL", + buf: []byte{ + 0, 0, + 0, 0, 0, 0, + 7, 'a', 'b', 'c', 'd', 'e', 'f', + 0, + }, + lifetime: 0, + domainNames: nil, + err: ErrNDPOptMalformedBody, + }, + { + name: "LabelOfLength63", + buf: []byte{ + 0, 0, + 0, 0, 0, 0, + 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 0, + }, + lifetime: 0, + domainNames: []string{ + "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijk", + }, + err: nil, + }, + { + name: "LabelOfLength64", + buf: []byte{ + 0, 0, + 0, 0, 0, 0, + 64, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', + 0, + }, + lifetime: 0, + domainNames: nil, + err: ErrNDPOptMalformedBody, + }, + { + name: "DomainNameOfLength255", + buf: []byte{ + 0, 0, + 0, 0, 0, 0, + 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 62, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', + 0, + }, + lifetime: 0, + domainNames: []string{ + "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijk.abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijk.abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijk.abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghij", + }, + err: nil, + }, + { + name: "DomainNameOfLength256", + buf: []byte{ + 0, 0, + 0, 0, 0, 0, + 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 0, + }, + lifetime: 0, + domainNames: nil, + err: ErrNDPOptMalformedBody, + }, + { + name: "StartingDigitForLabel", + buf: []byte{ + 0, 0, + 0, 0, 0, 1, + 3, '9', 'b', 'c', + 0, + 0, 0, 0, + }, + lifetime: time.Second, + domainNames: nil, + err: ErrNDPOptMalformedBody, + }, + { + name: "StartingHyphenForLabel", + buf: []byte{ + 0, 0, + 0, 0, 0, 1, + 3, '-', 'b', 'c', + 0, + 0, 0, 0, + }, + lifetime: time.Second, + domainNames: nil, + err: ErrNDPOptMalformedBody, + }, + { + name: "EndingHyphenForLabel", + buf: []byte{ + 0, 0, + 0, 0, 0, 1, + 3, 'a', 'b', '-', + 0, + 0, 0, 0, + }, + lifetime: time.Second, + domainNames: nil, + err: ErrNDPOptMalformedBody, + }, + { + name: "EndingDigitForLabel", + buf: []byte{ + 0, 0, + 0, 0, 0, 1, + 3, 'a', 'b', '9', + 0, + 0, 0, 0, + }, + lifetime: time.Second, + domainNames: []string{ + "ab9", + }, + err: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + opt := NDPDNSSearchList(test.buf) + + if got := opt.Lifetime(); got != test.lifetime { + t.Errorf("got Lifetime = %d, want = %d", got, test.lifetime) + } + domainNames, err := opt.DomainNames() + if !errors.Is(err, test.err) { + t.Errorf("opt.DomainNames() = %s", err) + } + if diff := cmp.Diff(domainNames, test.domainNames); diff != "" { + t.Errorf("mismatched domain names (-want +got):\n%s", diff) + } + }) + } +} + +func TestNDPSearchListOptionDomainNameLabelInvalidSymbols(t *testing.T) { + for r := rune(0); r <= 255; r++ { + t.Run(fmt.Sprintf("RuneVal=%d", r), func(t *testing.T) { + buf := []byte{ + 0, 0, + 0, 0, 0, 0, + 3, 'a', 0 /* will be replaced */, 'c', + 0, + 0, 0, 0, + } + buf[8] = uint8(r) + opt := NDPDNSSearchList(buf) + + // As per RFC 1035 section 2.3.1, the label must only include ASCII + // letters, digits and hyphens (a-z, A-Z, 0-9, -). + var expectedErr error + re := regexp.MustCompile(`[a-zA-Z0-9-]`) + if !re.Match([]byte{byte(r)}) { + expectedErr = ErrNDPOptMalformedBody + } + + if domainNames, err := opt.DomainNames(); !errors.Is(err, expectedErr) { + t.Errorf("got opt.DomainNames() = (%s, %v), want = (_, %v)", domainNames, err, ErrNDPOptMalformedBody) + } + }) + } +} + +func TestNDPDNSSearchListOptionSerialize(t *testing.T) { + b := []byte{ + 9, 8, + 1, 0, 0, 0, + 3, 'a', 'b', 'c', + 4, 'a', 'b', 'c', 'd', + 1, 'e', + 0, + } + targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1} + expected := []byte{ + 31, 3, 0, 0, + 1, 0, 0, 0, + 3, 'a', 'b', 'c', + 4, 'a', 'b', 'c', 'd', + 1, 'e', + 0, + 0, 0, 0, 0, + } + opts := NDPOptions(targetBuf) + serializer := NDPOptionsSerializer{ + NDPDNSSearchList(b), + } + if got, want := opts.Serialize(serializer), len(expected); got != want { + t.Errorf("got Serialize = %d, want = %d", got, want) + } + if !bytes.Equal(targetBuf, expected) { + t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expected) + } + + it, err := opts.Iter(true) + if err != nil { + t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) + } + + next, done, err := it.Next() + if err != nil { + t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if done { + t.Fatal("got Next = (_, true, _), want = (_, false, _)") + } + if got := next.Type(); got != NDPDNSSearchListOptionType { + t.Errorf("got Type = %d, want = %d", got, NDPDNSSearchListOptionType) + } + + opt, ok := next.(NDPDNSSearchList) + if !ok { + t.Fatalf("next (type = %T) cannot be casted to an NDPDNSSearchList", next) + } + if got := opt.Type(); got != 31 { + t.Errorf("got Type = %d, want = 31", got) + } + if got := opt.Length(); got != 22 { + t.Errorf("got Length = %d, want = 22", got) + } + if got, want := opt.Lifetime(), 16777216*time.Second; got != want { + t.Errorf("got Lifetime = %s, want = %s", got, want) + } + domainNames, err := opt.DomainNames() + if err != nil { + t.Errorf("opt.DomainNames() = %s", err) + } + if diff := cmp.Diff(domainNames, []string{"abc.abcd.e"}); diff != "" { + t.Errorf("domain names mismatch (-want +got):\n%s", diff) + } + + // Iterator should not return anything else. + next, done, err = it.Next() + if err != nil { + t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if !done { + t.Error("got Next = (_, false, _), want = (_, true, _)") + } + if next != nil { + t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) + } +} + // TestNDPOptionsIterCheck tests that Iter will return false if the NDPOptions // the iterator was returned for is malformed. func TestNDPOptionsIterCheck(t *testing.T) { @@ -832,6 +1305,107 @@ func TestNDPOptionsIterCheck(t *testing.T) { }, expectedErr: ErrNDPOptMalformedBody, }, + { + name: "DNSSearchListLargeCompliantRFC1035", + buf: []byte{ + 31, 33, 0, 0, + 0, 0, 0, 0, + 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 62, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', + 0, + }, + expectedErr: nil, + }, + { + name: "DNSSearchListNonCompliantRFC1035", + buf: []byte{ + 31, 33, 0, 0, + 0, 0, 0, 0, + 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', + 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'j', 'k', + 0, + 0, 0, 0, 0, 0, 0, 0, 0, + }, + expectedErr: ErrNDPOptMalformedBody, + }, + { + name: "DNSSearchListValidSmall", + buf: []byte{ + 31, 2, 0, 0, + 0, 0, 0, 0, + 6, 'a', 'b', 'c', 'd', 'e', 'f', + 0, + }, + expectedErr: nil, + }, + { + name: "DNSSearchListTooSmall", + buf: []byte{ + 31, 1, 0, 0, + 0, 0, 0, + }, + expectedErr: io.ErrUnexpectedEOF, + }, } for _, test := range tests { diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 8140c6dd4..193a9dfde 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -241,6 +241,16 @@ type NDPDispatcher interface { // call functions on the stack itself. OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) + // OnDNSSearchListOption will be called when an NDP option with a DNS + // search list has been received. + // + // It is up to the caller to use the domain names in the search list + // for only their valid lifetime. OnDNSSearchListOption may be called + // with new or already known domain names. If called with known domain + // names, their valid lifetimes must be refreshed to lifetime (it may + // be increased, decreased or completely invalidated when lifetime = 0. + OnDNSSearchListOption(nicID tcpip.NICID, domainNames []string, lifetime time.Duration) + // OnDHCPv6Configuration will be called with an updated configuration that is // available via DHCPv6 for a specified NIC. // @@ -714,6 +724,14 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { addrs, _ := opt.Addresses() ndp.nic.stack.ndpDisp.OnRecursiveDNSServerOption(ndp.nic.ID(), addrs, opt.Lifetime()) + case header.NDPDNSSearchList: + if ndp.nic.stack.ndpDisp == nil { + continue + } + + domainNames, _ := opt.DomainNames() + ndp.nic.stack.ndpDisp.OnDNSSearchListOption(ndp.nic.ID(), domainNames, opt.Lifetime()) + case header.NDPPrefixInformation: prefix := opt.Subnet() diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 6562a2d22..6dd460984 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -133,6 +133,12 @@ type ndpRDNSSEvent struct { rdnss ndpRDNSS } +type ndpDNSSLEvent struct { + nicID tcpip.NICID + domainNames []string + lifetime time.Duration +} + type ndpDHCPv6Event struct { nicID tcpip.NICID configuration stack.DHCPv6ConfigurationFromNDPRA @@ -150,6 +156,8 @@ type ndpDispatcher struct { rememberPrefix bool autoGenAddrC chan ndpAutoGenAddrEvent rdnssC chan ndpRDNSSEvent + dnsslC chan ndpDNSSLEvent + routeTable []tcpip.Route dhcpv6ConfigurationC chan ndpDHCPv6Event } @@ -257,6 +265,17 @@ func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tc } } +// Implements stack.NDPDispatcher.OnDNSSearchListOption. +func (n *ndpDispatcher) OnDNSSearchListOption(nicID tcpip.NICID, domainNames []string, lifetime time.Duration) { + if n.dnsslC != nil { + n.dnsslC <- ndpDNSSLEvent{ + nicID, + domainNames, + lifetime, + } + } +} + // Implements stack.NDPDispatcher.OnDHCPv6Configuration. func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration stack.DHCPv6ConfigurationFromNDPRA) { if c := n.dhcpv6ConfigurationC; c != nil { @@ -3386,6 +3405,112 @@ func TestNDPRecursiveDNSServerDispatch(t *testing.T) { } } +// TestNDPDNSSearchListDispatch tests that the integrator is informed when an +// NDP DNS Search List option is received with at least one domain name in the +// search list. +func TestNDPDNSSearchListDispatch(t *testing.T) { + const nicID = 1 + + ndpDisp := ndpDispatcher{ + dnsslC: make(chan ndpDNSSLEvent, 3), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + }, + NDPDisp: &ndpDisp, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + optSer := header.NDPOptionsSerializer{ + header.NDPDNSSearchList([]byte{ + 0, 0, + 0, 0, 0, 0, + 2, 'h', 'i', + 0, + }), + header.NDPDNSSearchList([]byte{ + 0, 0, + 0, 0, 0, 1, + 1, 'i', + 0, + 2, 'a', 'm', + 2, 'm', 'e', + 0, + }), + header.NDPDNSSearchList([]byte{ + 0, 0, + 0, 0, 1, 0, + 3, 'x', 'y', 'z', + 0, + 5, 'h', 'e', 'l', 'l', 'o', + 5, 'w', 'o', 'r', 'l', 'd', + 0, + 4, 't', 'h', 'i', 's', + 2, 'i', 's', + 1, 'a', + 4, 't', 'e', 's', 't', + 0, + }), + } + expected := []struct { + domainNames []string + lifetime time.Duration + }{ + { + domainNames: []string{ + "hi", + }, + lifetime: 0, + }, + { + domainNames: []string{ + "i", + "am.me", + }, + lifetime: time.Second, + }, + { + domainNames: []string{ + "xyz", + "hello.world", + "this.is.a.test", + }, + lifetime: 256 * time.Second, + }, + } + + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, optSer)) + + for i, expected := range expected { + select { + case dnssl := <-ndpDisp.dnsslC: + if dnssl.nicID != nicID { + t.Errorf("got %d-th dnssl nicID = %d, want = %d", i, dnssl.nicID, nicID) + } + if diff := cmp.Diff(dnssl.domainNames, expected.domainNames); diff != "" { + t.Errorf("%d-th dnssl domain names mismatch (-want +got):\n%s", i, diff) + } + if dnssl.lifetime != expected.lifetime { + t.Errorf("got %d-th dnssl lifetime = %s, want = %s", i, dnssl.lifetime, expected.lifetime) + } + default: + t.Fatal("expected a DNSSL event") + } + } + + // Should have no more DNSSL options. + select { + case <-ndpDisp.dnsslC: + t.Fatal("unexpectedly got a DNSSL event") + default: + } +} + // TestCleanupNDPState tests that all discovered routers and prefixes, and // auto-generated addresses are invalidated when a NIC becomes a router. func TestCleanupNDPState(t *testing.T) { -- cgit v1.2.3 From 120d3b50f4875824ec69f0cc39a09ac84fced35c Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Tue, 21 Apr 2020 07:15:25 -0700 Subject: Automated rollback of changelist 307477185 PiperOrigin-RevId: 307598974 --- pkg/sentry/socket/netfilter/tcp_matcher.go | 5 +- pkg/sentry/socket/netfilter/udp_matcher.go | 5 +- pkg/tcpip/buffer/view.go | 55 ++++---------- pkg/tcpip/buffer/view_test.go | 113 ----------------------------- pkg/tcpip/link/loopback/loopback.go | 10 +-- pkg/tcpip/link/sharedmem/sharedmem_test.go | 2 +- pkg/tcpip/link/sniffer/sniffer.go | 65 ++++------------- pkg/tcpip/network/arp/arp.go | 5 +- pkg/tcpip/network/ipv4/icmp.go | 20 ++--- pkg/tcpip/network/ipv4/ipv4.go | 12 +-- pkg/tcpip/network/ipv6/icmp.go | 74 +++++++------------ pkg/tcpip/network/ipv6/icmp_test.go | 3 +- pkg/tcpip/network/ipv6/ipv6.go | 6 +- pkg/tcpip/stack/forwarder_test.go | 13 ++-- pkg/tcpip/stack/iptables.go | 22 +----- pkg/tcpip/stack/iptables_targets.go | 23 ++---- pkg/tcpip/stack/nic.go | 34 ++++++--- pkg/tcpip/stack/packet_buffer.go | 4 +- pkg/tcpip/stack/stack_test.go | 10 +-- pkg/tcpip/stack/transport_test.go | 5 +- pkg/tcpip/transport/icmp/endpoint.go | 8 +- pkg/tcpip/transport/tcp/segment.go | 29 +++----- pkg/tcpip/transport/tcp/tcp_test.go | 4 +- pkg/tcpip/transport/udp/endpoint.go | 6 +- pkg/tcpip/transport/udp/protocol.go | 9 +-- 25 files changed, 147 insertions(+), 395 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index 55c0f04f3..ff1cfd8f6 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -121,13 +121,12 @@ func (tm *TCPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa tcpHeader = header.TCP(pkt.TransportHeader) } else { // The TCP header hasn't been parsed yet. We have to do it here. - hdr, ok := pkt.Data.PullUp(header.TCPMinimumSize) - if !ok { + if len(pkt.Data.First()) < header.TCPMinimumSize { // There's no valid TCP header here, so we hotdrop the // packet. return false, true } - tcpHeader = header.TCP(hdr) + tcpHeader = header.TCP(pkt.Data.First()) } // Check whether the source and destination ports are within the diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index 04d03d494..3359418c1 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -120,13 +120,12 @@ func (um *UDPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa udpHeader = header.UDP(pkt.TransportHeader) } else { // The UDP header hasn't been parsed yet. We have to do it here. - hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize) - if !ok { + if len(pkt.Data.First()) < header.UDPMinimumSize { // There's no valid UDP header here, so we hotdrop the // packet. return false, true } - udpHeader = header.UDP(hdr) + udpHeader = header.UDP(pkt.Data.First()) } // Check whether the source and destination ports are within the diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index f01217c91..8ec5d5d5c 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -77,8 +77,7 @@ func NewVectorisedView(size int, views []View) VectorisedView { return VectorisedView{views: views, size: size} } -// TrimFront removes the first "count" bytes of the vectorised view. It panics -// if count > vv.Size(). +// TrimFront removes the first "count" bytes of the vectorised view. func (vv *VectorisedView) TrimFront(count int) { for count > 0 && len(vv.views) > 0 { if count < len(vv.views[0]) { @@ -87,7 +86,7 @@ func (vv *VectorisedView) TrimFront(count int) { return } count -= len(vv.views[0]) - vv.removeFirst() + vv.RemoveFirst() } } @@ -105,7 +104,7 @@ func (vv *VectorisedView) Read(v View) (copied int, err error) { count -= len(vv.views[0]) copy(v[copied:], vv.views[0]) copied += len(vv.views[0]) - vv.removeFirst() + vv.RemoveFirst() } if copied == 0 { return 0, io.EOF @@ -127,7 +126,7 @@ func (vv *VectorisedView) ReadToVV(dstVV *VectorisedView, count int) (copied int count -= len(vv.views[0]) dstVV.AppendView(vv.views[0]) copied += len(vv.views[0]) - vv.removeFirst() + vv.RemoveFirst() } return copied } @@ -163,37 +162,22 @@ func (vv *VectorisedView) Clone(buffer []View) VectorisedView { return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size} } -// PullUp returns the first "count" bytes of the vectorised view. If those -// bytes aren't already contiguous inside the vectorised view, PullUp will -// reallocate as needed to make them contiguous. PullUp fails and returns false -// when count > vv.Size(). -func (vv *VectorisedView) PullUp(count int) (View, bool) { +// First returns the first view of the vectorised view. +func (vv *VectorisedView) First() View { if len(vv.views) == 0 { - return nil, count == 0 - } - if count <= len(vv.views[0]) { - return vv.views[0][:count], true - } - if count > vv.size { - return nil, false + return nil } + return vv.views[0] +} - newFirst := NewView(count) - i := 0 - for offset := 0; offset < count; i++ { - copy(newFirst[offset:], vv.views[i]) - if count-offset < len(vv.views[i]) { - vv.views[i].TrimFront(count - offset) - break - } - offset += len(vv.views[i]) - vv.views[i] = nil +// RemoveFirst removes the first view of the vectorised view. +func (vv *VectorisedView) RemoveFirst() { + if len(vv.views) == 0 { + return } - // We're guaranteed that i > 0, since count is too large for the first - // view. - vv.views[i-1] = newFirst - vv.views = vv.views[i-1:] - return newFirst, true + vv.size -= len(vv.views[0]) + vv.views[0] = nil + vv.views = vv.views[1:] } // Size returns the size in bytes of the entire content stored in the vectorised view. @@ -241,10 +225,3 @@ func (vv *VectorisedView) Readers() []bytes.Reader { } return readers } - -// removeFirst panics when len(vv.views) < 1. -func (vv *VectorisedView) removeFirst() { - vv.size -= len(vv.views[0]) - vv.views[0] = nil - vv.views = vv.views[1:] -} diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go index c56795c7b..106e1994c 100644 --- a/pkg/tcpip/buffer/view_test.go +++ b/pkg/tcpip/buffer/view_test.go @@ -16,7 +16,6 @@ package buffer import ( - "bytes" "reflect" "testing" ) @@ -371,115 +370,3 @@ func TestVVRead(t *testing.T) { }) } } - -var pullUpTestCases = []struct { - comment string - in VectorisedView - count int - want []byte - result VectorisedView - ok bool -}{ - { - comment: "simple case", - in: vv(2, "12"), - count: 1, - want: []byte("1"), - result: vv(2, "12"), - ok: true, - }, - { - comment: "entire View", - in: vv(2, "1", "2"), - count: 1, - want: []byte("1"), - result: vv(2, "1", "2"), - ok: true, - }, - { - comment: "spanning across two Views", - in: vv(3, "1", "23"), - count: 2, - want: []byte("12"), - result: vv(3, "12", "3"), - ok: true, - }, - { - comment: "spanning across all Views", - in: vv(5, "1", "23", "45"), - count: 5, - want: []byte("12345"), - result: vv(5, "12345"), - ok: true, - }, - { - comment: "count = 0", - in: vv(1, "1"), - count: 0, - want: []byte{}, - result: vv(1, "1"), - ok: true, - }, - { - comment: "count = size", - in: vv(1, "1"), - count: 1, - want: []byte("1"), - result: vv(1, "1"), - ok: true, - }, - { - comment: "count too large", - in: vv(3, "1", "23"), - count: 4, - want: nil, - result: vv(3, "1", "23"), - ok: false, - }, - { - comment: "empty vv", - in: vv(0, ""), - count: 1, - want: nil, - result: vv(0, ""), - ok: false, - }, - { - comment: "empty vv, count = 0", - in: vv(0, ""), - count: 0, - want: nil, - result: vv(0, ""), - ok: true, - }, - { - comment: "empty views", - in: vv(3, "", "1", "", "23"), - count: 2, - want: []byte("12"), - result: vv(3, "12", "3"), - ok: true, - }, -} - -func TestPullUp(t *testing.T) { - for _, c := range pullUpTestCases { - got, ok := c.in.PullUp(c.count) - - // Is the return value right? - if ok != c.ok { - t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got an ok of %t. Want %t", - c.comment, c.count, c.in, ok, c.ok) - } - if bytes.Compare(got, View(c.want)) != 0 { - t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got %v. Want %v", - c.comment, c.count, c.in, got, c.want) - } - - // Is the underlying structure right? - if !reflect.DeepEqual(c.in, c.result) { - t.Errorf("Test %q failed when calling PullUp(%d). Got vv with structure %v. Wanted %v", - c.comment, c.count, c.in, c.result) - } - } -} diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 073c84ef9..1e2255bfa 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -98,13 +98,13 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - // There should be an ethernet header at the beginning of vv. - hdr, ok := vv.PullUp(header.EthernetMinimumSize) - if !ok { - // Reject the packet if it's shorter than an ethernet header. + // Reject the packet if it's shorter than an ethernet header. + if vv.Size() < header.EthernetMinimumSize { return tcpip.ErrBadAddress } - linkHeader := header.Ethernet(hdr) + + // There should be an ethernet header at the beginning of vv. + linkHeader := header.Ethernet(vv.First()[:header.EthernetMinimumSize]) vv.TrimFront(len(linkHeader)) e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), stack.PacketBuffer{ Data: vv, diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 33f640b85..27ea3f531 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -674,7 +674,7 @@ func TestSimpleReceive(t *testing.T) { // Wait for packet to be received, then check it. c.waitForPackets(1, time.After(5*time.Second), "Timeout waiting for packet") c.mu.Lock() - rcvd := []byte(c.packets[0].vv.ToView()) + rcvd := []byte(c.packets[0].vv.First()) c.packets = c.packets[:0] c.mu.Unlock() diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 0799c8f4d..be2537a82 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -171,7 +171,11 @@ func (e *endpoint) GSOMaxSize() uint32 { func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { writer := e.writer if writer == nil && atomic.LoadUint32(&LogPackets) == 1 { - logPacket(prefix, protocol, pkt, gso) + first := pkt.Header.View() + if len(first) == 0 { + first = pkt.Data.First() + } + logPacket(prefix, protocol, first, gso) } if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 { totalLength := pkt.Header.UsedLength() + pkt.Data.Size() @@ -234,7 +238,7 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { // Wait implements stack.LinkEndpoint.Wait. func (e *endpoint) Wait() { e.lower.Wait() } -func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) { +func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View, gso *stack.GSO) { // Figure out the network layer info. var transProto uint8 src := tcpip.Address("unknown") @@ -243,49 +247,28 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P size := uint16(0) var fragmentOffset uint16 var moreFragments bool - - // Create a clone of pkt, including any headers if present. Avoid allocating - // backing memory for the clone. - views := [8]buffer.View{} - vv := buffer.NewVectorisedView(0, views[:0]) - vv.AppendView(pkt.Header.View()) - vv.Append(pkt.Data) - switch protocol { case header.IPv4ProtocolNumber: - hdr, ok := vv.PullUp(header.IPv4MinimumSize) - if !ok { - return - } - ipv4 := header.IPv4(hdr) + ipv4 := header.IPv4(b) fragmentOffset = ipv4.FragmentOffset() moreFragments = ipv4.Flags()&header.IPv4FlagMoreFragments == header.IPv4FlagMoreFragments src = ipv4.SourceAddress() dst = ipv4.DestinationAddress() transProto = ipv4.Protocol() size = ipv4.TotalLength() - uint16(ipv4.HeaderLength()) - vv.TrimFront(int(ipv4.HeaderLength())) + b = b[ipv4.HeaderLength():] id = int(ipv4.ID()) case header.IPv6ProtocolNumber: - hdr, ok := vv.PullUp(header.IPv6MinimumSize) - if !ok { - return - } - ipv6 := header.IPv6(hdr) + ipv6 := header.IPv6(b) src = ipv6.SourceAddress() dst = ipv6.DestinationAddress() transProto = ipv6.NextHeader() size = ipv6.PayloadLength() - vv.TrimFront(header.IPv6MinimumSize) + b = b[header.IPv6MinimumSize:] case header.ARPProtocolNumber: - hdr, ok := vv.PullUp(header.ARPSize) - if !ok { - return - } - vv.TrimFront(header.ARPSize) - arp := header.ARP(hdr) + arp := header.ARP(b) log.Infof( "%s arp %v (%v) -> %v (%v) valid:%v", prefix, @@ -301,7 +284,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P // We aren't guaranteed to have a transport header - it's possible for // writes via raw endpoints to contain only network headers. - if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && vv.Size() < minSize { + if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && len(b) < minSize { log.Infof("%s %v -> %v transport protocol: %d, but no transport header found (possible raw packet)", prefix, src, dst, transProto) return } @@ -314,11 +297,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P switch tcpip.TransportProtocolNumber(transProto) { case header.ICMPv4ProtocolNumber: transName = "icmp" - hdr, ok := vv.PullUp(header.ICMPv4MinimumSize) - if !ok { - break - } - icmp := header.ICMPv4(hdr) + icmp := header.ICMPv4(b) icmpType := "unknown" if fragmentOffset == 0 { switch icmp.Type() { @@ -351,11 +330,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P case header.ICMPv6ProtocolNumber: transName = "icmp" - hdr, ok := vv.PullUp(header.ICMPv6MinimumSize) - if !ok { - break - } - icmp := header.ICMPv6(hdr) + icmp := header.ICMPv6(b) icmpType := "unknown" switch icmp.Type() { case header.ICMPv6DstUnreachable: @@ -386,11 +361,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P case header.UDPProtocolNumber: transName = "udp" - hdr, ok := vv.PullUp(header.UDPMinimumSize) - if !ok { - break - } - udp := header.UDP(hdr) + udp := header.UDP(b) if fragmentOffset == 0 && len(udp) >= header.UDPMinimumSize { srcPort = udp.SourcePort() dstPort = udp.DestinationPort() @@ -400,11 +371,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P case header.TCPProtocolNumber: transName = "tcp" - hdr, ok := vv.PullUp(header.TCPMinimumSize) - if !ok { - break - } - tcp := header.TCP(hdr) + tcp := header.TCP(b) if fragmentOffset == 0 && len(tcp) >= header.TCPMinimumSize { offset := int(tcp.DataOffset()) if offset < header.TCPMinimumSize { diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index cf73a939e..7acbfa0a8 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -93,10 +93,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf } func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - v, ok := pkt.Data.PullUp(header.ARPSize) - if !ok { - return - } + v := pkt.Data.First() h := header.ARP(v) if !h.IsValid() { return diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 4cbefe5ab..c4bf1ba5c 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -25,11 +25,7 @@ import ( // used to find out which transport endpoint must be notified about the ICMP // packet. func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { - h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - return - } - hdr := header.IPv4(h) + h := header.IPv4(pkt.Data.First()) // We don't use IsValid() here because ICMP only requires that the IP // header plus 8 bytes of the transport header be included. So it's @@ -38,12 +34,12 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // // Drop packet if it doesn't have the basic IPv4 header or if the // original source address doesn't match the endpoint's address. - if hdr.SourceAddress() != e.id.LocalAddress { + if len(h) < header.IPv4MinimumSize || h.SourceAddress() != e.id.LocalAddress { return } - hlen := int(hdr.HeaderLength()) - if pkt.Data.Size() < hlen || hdr.FragmentOffset() != 0 { + hlen := int(h.HeaderLength()) + if pkt.Data.Size() < hlen || h.FragmentOffset() != 0 { // We won't be able to handle this if it doesn't contain the // full IPv4 header, or if it's a fragment not at offset 0 // (because it won't have the transport header). @@ -52,15 +48,15 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // Skip the ip header, then deliver control message. pkt.Data.TrimFront(hlen) - p := hdr.TransportProtocol() - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + p := h.TransportProtocol() + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) { stats := r.Stats() received := stats.ICMP.V4PacketsReceived - v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) - if !ok { + v := pkt.Data.First() + if len(v) < header.ICMPv4MinimumSize { received.Invalid.Increment() return } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 17202cc7a..104aafbed 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -328,11 +328,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required // checks. - h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - return tcpip.ErrInvalidOptionValue - } - ip := header.IPv4(h) + ip := header.IPv4(pkt.Data.First()) if !ip.IsValid(pkt.Data.Size()) { return tcpip.ErrInvalidOptionValue } @@ -382,11 +378,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - headerView, ok := pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - r.Stats().IP.MalformedPacketsReceived.Increment() - return - } + headerView := pkt.Data.First() h := header.IPv4(headerView) if !h.IsValid(pkt.Data.Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index bdf3a0d25..b68983d10 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -28,11 +28,7 @@ import ( // used to find out which transport endpoint must be notified about the ICMP // packet. func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { - h, ok := pkt.Data.PullUp(header.IPv6MinimumSize) - if !ok { - return - } - hdr := header.IPv6(h) + h := header.IPv6(pkt.Data.First()) // We don't use IsValid() here because ICMP only requires that up to // 1280 bytes of the original packet be included. So it's likely that it @@ -40,21 +36,17 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // // Drop packet if it doesn't have the basic IPv6 header or if the // original source address doesn't match the endpoint's address. - if hdr.SourceAddress() != e.id.LocalAddress { + if len(h) < header.IPv6MinimumSize || h.SourceAddress() != e.id.LocalAddress { return } // Skip the IP header, then handle the fragmentation header if there // is one. pkt.Data.TrimFront(header.IPv6MinimumSize) - p := hdr.TransportProtocol() + p := h.TransportProtocol() if p == header.IPv6FragmentHeader { - f, ok := pkt.Data.PullUp(header.IPv6FragmentHeaderSize) - if !ok { - return - } - fragHdr := header.IPv6Fragment(f) - if !fragHdr.IsValid() || fragHdr.FragmentOffset() != 0 { + f := header.IPv6Fragment(pkt.Data.First()) + if !f.IsValid() || f.FragmentOffset() != 0 { // We can't handle fragments that aren't at offset 0 // because they don't have the transport headers. return @@ -63,19 +55,19 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // Skip fragmentation header and find out the actual protocol // number. pkt.Data.TrimFront(header.IPv6FragmentHeaderSize) - p = fragHdr.TransportProtocol() + p = f.TransportProtocol() } // Deliver the control packet to the transport endpoint. - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.PacketBuffer, hasFragmentHeader bool) { stats := r.Stats().ICMP sent := stats.V6PacketsSent received := stats.V6PacketsReceived - v, ok := pkt.Data.PullUp(header.ICMPv6HeaderSize) - if !ok { + v := pkt.Data.First() + if len(v) < header.ICMPv6MinimumSize { received.Invalid.Increment() return } @@ -84,9 +76,11 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P // Validate ICMPv6 checksum before processing the packet. // + // Only the first view in vv is accounted for by h. To account for the + // rest of vv, a shallow copy is made and the first view is removed. // This copy is used as extra payload during the checksum calculation. payload := pkt.Data.Clone(nil) - payload.TrimFront(len(h)) + payload.RemoveFirst() if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want { received.Invalid.Increment() return @@ -107,40 +101,34 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P switch h.Type() { case header.ICMPv6PacketTooBig: received.PacketTooBig.Increment() - hdr, ok := pkt.Data.PullUp(header.ICMPv6PacketTooBigMinimumSize) - if !ok { + if len(v) < header.ICMPv6PacketTooBigMinimumSize { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize) - mtu := header.ICMPv6(hdr).MTU() + mtu := h.MTU() e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) case header.ICMPv6DstUnreachable: received.DstUnreachable.Increment() - hdr, ok := pkt.Data.PullUp(header.ICMPv6DstUnreachableMinimumSize) - if !ok { + if len(v) < header.ICMPv6DstUnreachableMinimumSize { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize) - switch header.ICMPv6(hdr).Code() { + switch h.Code() { case header.ICMPv6PortUnreachable: e.handleControl(stack.ControlPortUnreachable, 0, pkt) } case header.ICMPv6NeighborSolicit: received.NeighborSolicit.Increment() - if pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() { + if len(v) < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() { received.Invalid.Increment() return } - // The remainder of payload must be only the neighbor solicitation, so - // payload.ToView() always returns the solicitation. Per RFC 6980 section 5, - // NDP messages cannot be fragmented. Also note that in the common case NDP - // datagrams are very small and ToView() will not incur allocations. - ns := header.NDPNeighborSolicit(payload.ToView()) + ns := header.NDPNeighborSolicit(h.NDPPayload()) it, err := ns.Options().Iter(true) if err != nil { // If we have a malformed NDP NS option, drop the packet. @@ -298,16 +286,12 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6NeighborAdvert: received.NeighborAdvert.Increment() - if pkt.Data.Size() < header.ICMPv6NeighborAdvertSize || !isNDPValid() { + if len(v) < header.ICMPv6NeighborAdvertSize || !isNDPValid() { received.Invalid.Increment() return } - // The remainder of payload must be only the neighbor advertisement, so - // payload.ToView() always returns the advertisement. Per RFC 6980 section - // 5, NDP messages cannot be fragmented. Also note that in the common case - // NDP datagrams are very small and ToView() will not incur allocations. - na := header.NDPNeighborAdvert(payload.ToView()) + na := header.NDPNeighborAdvert(h.NDPPayload()) it, err := na.Options().Iter(true) if err != nil { // If we have a malformed NDP NA option, drop the packet. @@ -379,15 +363,14 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6EchoRequest: received.EchoRequest.Increment() - icmpHdr, ok := pkt.Data.PullUp(header.ICMPv6EchoMinimumSize) - if !ok { + if len(v) < header.ICMPv6EchoMinimumSize { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6EchoMinimumSize) hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize) packet := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) - copy(packet, icmpHdr) + copy(packet, h) packet.SetType(header.ICMPv6EchoReply) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ @@ -401,7 +384,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6EchoReply: received.EchoReply.Increment() - if pkt.Data.Size() < header.ICMPv6EchoMinimumSize { + if len(v) < header.ICMPv6EchoMinimumSize { received.Invalid.Increment() return } @@ -423,9 +406,8 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6RouterAdvert: received.RouterAdvert.Increment() - // Is the NDP payload of sufficient size to hold a Router - // Advertisement? - if pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize || !isNDPValid() { + p := h.NDPPayload() + if len(p) < header.NDPRAMinimumSize || !isNDPValid() { received.Invalid.Increment() return } @@ -443,11 +425,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P return } - // The remainder of payload must be only the router advertisement, so - // payload.ToView() always returns the advertisement. Per RFC 6980 section - // 5, NDP messages cannot be fragmented. Also note that in the common case - // NDP datagrams are very small and ToView() will not incur allocations. - ra := header.NDPRouterAdvert(payload.ToView()) + ra := header.NDPRouterAdvert(p) opts := ra.Options() // Are options valid as per the wire format? diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index d412ff688..bd099a7f8 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -166,8 +166,7 @@ func TestICMPCounts(t *testing.T) { }, { typ: header.ICMPv6NeighborSolicit, - size: header.ICMPv6NeighborSolicitMinimumSize, - }, + size: header.ICMPv6NeighborSolicitMinimumSize}, { typ: header.ICMPv6NeighborAdvert, size: header.ICMPv6NeighborAdvertMinimumSize, diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 486725131..331b0817b 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -171,11 +171,7 @@ func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffe // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - headerView, ok := pkt.Data.PullUp(header.IPv6MinimumSize) - if !ok { - r.Stats().IP.MalformedPacketsReceived.Increment() - return - } + headerView := pkt.Data.First() h := header.IPv6(headerView) if !h.IsValid(pkt.Data.Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index c7c663498..e9c652042 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -70,10 +70,7 @@ func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID { func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt PacketBuffer) { // Consume the network header. - b, ok := pkt.Data.PullUp(fwdTestNetHeaderLen) - if !ok { - return - } + b := pkt.Data.First() pkt.Data.TrimFront(fwdTestNetHeaderLen) // Dispatch the packet to the transport protocol. @@ -476,7 +473,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Data.ToView() + b := p.Pkt.Header.View() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -520,7 +517,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Data.ToView() + b := p.Pkt.Header.View() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -567,7 +564,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Data.ToView() + b := p.Pkt.Header.View() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -622,7 +619,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // The first 5 packets (address 3 to 7) should not be forwarded // because their address resolutions are interrupted. - b := p.Pkt.Data.ToView() + b := p.Pkt.Header.View() if b[0] < 8 { t.Fatalf("got b[0] = %d, want b[0] >= 8", b[0]) } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 6b91159d4..6c0a4b24d 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -212,11 +212,6 @@ func (it *IPTables) Check(hook Hook, pkt PacketBuffer) bool { // CheckPackets runs pkts through the rules for hook and returns a map of packets that // should not go forward. // -// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// -// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a -// precondition. -// // NOTE: unlike the Check API the returned map contains packets that should be // dropped. func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*PacketBuffer]struct{}) { @@ -231,9 +226,7 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*Pa return drop } -// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a -// precondition. +// Precondition: pkt.NetworkHeader is set. func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. @@ -278,21 +271,14 @@ func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx return chainDrop } -// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a -// precondition. +// Precondition: pk.NetworkHeader is set. func (it *IPTables) checkRule(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) (RuleVerdict, int) { rule := table.Rules[ruleIdx] // If pkt.NetworkHeader hasn't been set yet, it will be contained in - // pkt.Data. + // pkt.Data.First(). if pkt.NetworkHeader == nil { - var ok bool - pkt.NetworkHeader, ok = pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - // Precondition has been violated. - panic(fmt.Sprintf("iptables checks require IPv4 headers of at least %d bytes", header.IPv4MinimumSize)) - } + pkt.NetworkHeader = pkt.Data.First() } // Check whether the packet matches the IP header filter. diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 8be61f4b1..7b4543caf 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -96,12 +96,9 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { newPkt := pkt.Clone() // Set network header. - headerView, ok := newPkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - return RuleDrop, 0 - } + headerView := newPkt.Data.First() netHeader := header.IPv4(headerView) - newPkt.NetworkHeader = headerView + newPkt.NetworkHeader = headerView[:header.IPv4MinimumSize] hlen := int(netHeader.HeaderLength()) tlen := int(netHeader.TotalLength()) @@ -120,14 +117,10 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { if newPkt.TransportHeader != nil { udpHeader = header.UDP(newPkt.TransportHeader) } else { - if pkt.Data.Size() < header.UDPMinimumSize { - return RuleDrop, 0 - } - hdr, ok := newPkt.Data.PullUp(header.UDPMinimumSize) - if !ok { + if len(pkt.Data.First()) < header.UDPMinimumSize { return RuleDrop, 0 } - udpHeader = header.UDP(hdr) + udpHeader = header.UDP(newPkt.Data.First()) } udpHeader.SetDestinationPort(rt.MinPort) case header.TCPProtocolNumber: @@ -135,14 +128,10 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { if newPkt.TransportHeader != nil { tcpHeader = header.TCP(newPkt.TransportHeader) } else { - if pkt.Data.Size() < header.TCPMinimumSize { + if len(pkt.Data.First()) < header.TCPMinimumSize { return RuleDrop, 0 } - hdr, ok := newPkt.Data.PullUp(header.TCPMinimumSize) - if !ok { - return RuleDrop, 0 - } - tcpHeader = header.TCP(hdr) + tcpHeader = header.TCP(newPkt.TransportHeader) } // TODO(gvisor.dev/issue/170): Need to recompute checksum // and implement nat connection tracking to support TCP. diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 0c2b1f36a..016dbe15e 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1203,12 +1203,12 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link n.stack.stats.IP.PacketsReceived.Increment() } - netHeader, ok := pkt.Data.PullUp(netProto.MinimumPacketSize()) - if !ok { + if len(pkt.Data.First()) < netProto.MinimumPacketSize() { n.stack.stats.MalformedRcvdPackets.Increment() return } - src, dst := netProto.ParseAddresses(netHeader) + + src, dst := netProto.ParseAddresses(pkt.Data.First()) if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil { // The source address is one of our own, so we never should have gotten a @@ -1289,8 +1289,22 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { // TODO(b/143425874) Decrease the TTL field in forwarded packets. - if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen != 0 { - pkt.Header = buffer.NewPrependable(linkHeaderLen) + + firstData := pkt.Data.First() + pkt.Data.RemoveFirst() + + if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen == 0 { + pkt.Header = buffer.NewPrependableFromView(firstData) + } else { + firstDataLen := len(firstData) + + // pkt.Header should have enough capacity to hold n.linkEP's headers. + pkt.Header = buffer.NewPrependable(firstDataLen + linkHeaderLen) + + // TODO(b/151227689): avoid copying the packet when forwarding + if n := copy(pkt.Header.Prepend(firstDataLen), firstData); n != firstDataLen { + panic(fmt.Sprintf("copied %d bytes, expected %d", n, firstDataLen)) + } } if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil { @@ -1318,13 +1332,12 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // validly formed. n.stack.demux.deliverRawPacket(r, protocol, pkt) - transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize()) - if !ok { + if len(pkt.Data.First()) < transProto.MinimumPacketSize() { n.stack.stats.MalformedRcvdPackets.Increment() return } - srcPort, dstPort, err := transProto.ParsePorts(transHeader) + srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) if err != nil { n.stack.stats.MalformedRcvdPackets.Increment() return @@ -1362,12 +1375,11 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp // ICMPv4 only guarantees that 8 bytes of the transport protocol will // be present in the payload. We know that the ports are within the // first 8 bytes for all known transport protocols. - transHeader, ok := pkt.Data.PullUp(8) - if !ok { + if len(pkt.Data.First()) < 8 { return } - srcPort, dstPort, err := transProto.ParsePorts(transHeader) + srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) if err != nil { return } diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index e954a8b7e..dc125f25e 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -37,9 +37,7 @@ type PacketBuffer struct { Data buffer.VectorisedView // Header holds the headers of outbound packets. As a packet is passed - // down the stack, each layer adds to Header. Note that forwarded - // packets don't populate Headers on their way out -- their headers and - // payload are never parsed out and remain in Data. + // down the stack, each layer adds to Header. Header buffer.Prependable // These fields are used by both inbound and outbound packets. They diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index d45d2cc1f..c7634ceb1 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -95,18 +95,16 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffe f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ // Consume the network header. - b, ok := pkt.Data.PullUp(fakeNetHeaderLen) - if !ok { - return - } + b := pkt.Data.First() pkt.Data.TrimFront(fakeNetHeaderLen) // Handle control packets. if b[2] == uint8(fakeControlProtocol) { - nb, ok := pkt.Data.PullUp(fakeNetHeaderLen) - if !ok { + nb := pkt.Data.First() + if len(nb) < fakeNetHeaderLen { return } + pkt.Data.TrimFront(fakeNetHeaderLen) f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, pkt) return diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index a611e44ab..3084e6593 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -642,11 +642,10 @@ func TestTransportForwarding(t *testing.T) { t.Fatal("Response packet not forwarded") } - hdrs := p.Pkt.Data.ToView() - if dst := hdrs[0]; dst != 3 { + if dst := p.Pkt.Header.View()[0]; dst != 3 { t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) } - if src := hdrs[1]; src != 1 { + if src := p.Pkt.Header.View()[1]; src != 1 { t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) } } diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index b1d820372..feef8dca0 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -747,15 +747,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // Only accept echo replies. switch e.NetProto { case header.IPv4ProtocolNumber: - h, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) - if !ok || header.ICMPv4(h).Type() != header.ICMPv4EchoReply { + h := header.ICMPv4(pkt.Data.First()) + if h.Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } case header.IPv6ProtocolNumber: - h, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize) - if !ok || header.ICMPv6(h).Type() != header.ICMPv6EchoReply { + h := header.ICMPv6(pkt.Data.First()) + if h.Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 7712ce652..40461fd31 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -144,11 +144,7 @@ func (s *segment) logicalLen() seqnum.Size { // TCP checksum and stores the checksum and result of checksum verification in // the csum and csumValid fields of the segment. func (s *segment) parse() bool { - h, ok := s.data.PullUp(header.TCPMinimumSize) - if !ok { - return false - } - hdr := header.TCP(h) + h := header.TCP(s.data.First()) // h is the header followed by the payload. We check that the offset to // the data respects the following constraints: @@ -160,16 +156,12 @@ func (s *segment) parse() bool { // N.B. The segment has already been validated as having at least the // minimum TCP size before reaching here, so it's safe to read the // fields. - offset := int(hdr.DataOffset()) - if offset < header.TCPMinimumSize { - return false - } - hdrWithOpts, ok := s.data.PullUp(offset) - if !ok { + offset := int(h.DataOffset()) + if offset < header.TCPMinimumSize || offset > len(h) { return false } - s.options = []byte(hdrWithOpts[header.TCPMinimumSize:]) + s.options = []byte(h[header.TCPMinimumSize:offset]) s.parsedOptions = header.ParseTCPOptions(s.options) // Query the link capabilities to decide if checksum validation is @@ -181,19 +173,18 @@ func (s *segment) parse() bool { s.data.TrimFront(offset) } if verifyChecksum { - hdr = header.TCP(hdrWithOpts) - s.csum = hdr.Checksum() + s.csum = h.Checksum() xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size())) - xsum = hdr.CalculateChecksum(xsum) + xsum = h.CalculateChecksum(xsum) s.data.TrimFront(offset) xsum = header.ChecksumVV(s.data, xsum) s.csumValid = xsum == 0xffff } - s.sequenceNumber = seqnum.Value(hdr.SequenceNumber()) - s.ackNumber = seqnum.Value(hdr.AckNumber()) - s.flags = hdr.Flags() - s.window = seqnum.Size(hdr.WindowSize()) + s.sequenceNumber = seqnum.Value(h.SequenceNumber()) + s.ackNumber = seqnum.Value(h.AckNumber()) + s.flags = h.Flags() + s.window = seqnum.Size(h.WindowSize()) return true } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 286c66cf5..ab1014c7f 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -3548,7 +3548,7 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { AckNum: c.IRS.Add(1), RcvWnd: 30000, }) - tcpbuf := vv.ToView()[header.IPv4MinimumSize:] + tcpbuf := vv.First()[header.IPv4MinimumSize:] tcpbuf[header.TCPDataOffset] = ((header.TCPMinimumSize - 1) / 4) << 4 c.SendSegment(vv) @@ -3575,7 +3575,7 @@ func TestReceivedIncorrectChecksumIncrement(t *testing.T) { AckNum: c.IRS.Add(1), RcvWnd: 30000, }) - tcpbuf := vv.ToView()[header.IPv4MinimumSize:] + tcpbuf := vv.First()[header.IPv4MinimumSize:] // Overwrite a byte in the payload which should cause checksum // verification to fail. tcpbuf[(tcpbuf[header.TCPDataOffset]>>4)*4] = 0x4 diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 756ab913a..edb54f0be 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1250,8 +1250,8 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // endpoint. func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) { // Get the header then trim it from the view. - hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize) - if !ok || int(header.UDP(hdr).Length()) > pkt.Data.Size() { + hdr := header.UDP(pkt.Data.First()) + if int(hdr.Length()) > pkt.Data.Size() { // Malformed packet. e.stack.Stats().UDP.MalformedPacketsReceived.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() @@ -1286,7 +1286,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk senderAddress: tcpip.FullAddress{ NIC: r.NICID(), Addr: id.RemoteAddress, - Port: header.UDP(hdr).SourcePort(), + Port: hdr.SourcePort(), }, } packet.data = pkt.Data diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 52af6de22..6e31a9bac 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -68,13 +68,8 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // that don't match any existing endpoint. func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { // Get the header then trim it from the view. - h, ok := pkt.Data.PullUp(header.UDPMinimumSize) - if !ok { - // Malformed packet. - r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() - return true - } - if int(header.UDP(h).Length()) > pkt.Data.Size() { + hdr := header.UDP(pkt.Data.First()) + if int(hdr.Length()) > pkt.Data.Size() { // Malformed packet. r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() return true -- cgit v1.2.3 From eccae0f77d3708d591119488f427eca90de7c711 Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Thu, 23 Apr 2020 17:27:24 -0700 Subject: Remove View.First() and View.RemoveFirst() These methods let users eaily break the VectorisedView abstraction, and allowed netstack to slip into pseudo-enforcement of the "all headers are in the first View" invariant. Removing them and replacing with PullUp(n) breaks this reliance and will make it easier to add iptables support and rework network buffer management. The new View.PullUp(n) method is low cost in the common case, when when all the headers fit in the first View. PiperOrigin-RevId: 308163542 --- pkg/sentry/socket/netfilter/tcp_matcher.go | 5 +- pkg/sentry/socket/netfilter/udp_matcher.go | 5 +- pkg/tcpip/buffer/view.go | 55 ++++++++++---- pkg/tcpip/buffer/view_test.go | 113 +++++++++++++++++++++++++++++ pkg/tcpip/link/loopback/loopback.go | 10 +-- pkg/tcpip/link/rawfile/BUILD | 9 ++- pkg/tcpip/link/rawfile/rawfile_test.go | 46 ++++++++++++ pkg/tcpip/link/rawfile/rawfile_unsafe.go | 6 +- pkg/tcpip/link/sharedmem/sharedmem_test.go | 2 +- pkg/tcpip/link/sniffer/sniffer.go | 65 +++++++++++++---- pkg/tcpip/network/arp/arp.go | 5 +- pkg/tcpip/network/ipv4/icmp.go | 20 +++-- pkg/tcpip/network/ipv4/ipv4.go | 12 ++- pkg/tcpip/network/ipv6/icmp.go | 74 ++++++++++++------- pkg/tcpip/network/ipv6/icmp_test.go | 3 +- pkg/tcpip/network/ipv6/ipv6.go | 6 +- pkg/tcpip/stack/forwarder_test.go | 13 ++-- pkg/tcpip/stack/iptables.go | 22 +++++- pkg/tcpip/stack/iptables_targets.go | 23 ++++-- pkg/tcpip/stack/nic.go | 34 +++------ pkg/tcpip/stack/packet_buffer.go | 8 +- pkg/tcpip/stack/stack_test.go | 10 ++- pkg/tcpip/stack/transport_test.go | 5 +- pkg/tcpip/transport/icmp/endpoint.go | 8 +- pkg/tcpip/transport/tcp/segment.go | 29 +++++--- pkg/tcpip/transport/tcp/tcp_test.go | 4 +- pkg/tcpip/transport/udp/endpoint.go | 6 +- pkg/tcpip/transport/udp/protocol.go | 9 ++- 28 files changed, 458 insertions(+), 149 deletions(-) create mode 100644 pkg/tcpip/link/rawfile/rawfile_test.go (limited to 'pkg/tcpip/stack') diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index ff1cfd8f6..55c0f04f3 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -121,12 +121,13 @@ func (tm *TCPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa tcpHeader = header.TCP(pkt.TransportHeader) } else { // The TCP header hasn't been parsed yet. We have to do it here. - if len(pkt.Data.First()) < header.TCPMinimumSize { + hdr, ok := pkt.Data.PullUp(header.TCPMinimumSize) + if !ok { // There's no valid TCP header here, so we hotdrop the // packet. return false, true } - tcpHeader = header.TCP(pkt.Data.First()) + tcpHeader = header.TCP(hdr) } // Check whether the source and destination ports are within the diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index 3359418c1..04d03d494 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -120,12 +120,13 @@ func (um *UDPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa udpHeader = header.UDP(pkt.TransportHeader) } else { // The UDP header hasn't been parsed yet. We have to do it here. - if len(pkt.Data.First()) < header.UDPMinimumSize { + hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize) + if !ok { // There's no valid UDP header here, so we hotdrop the // packet. return false, true } - udpHeader = header.UDP(pkt.Data.First()) + udpHeader = header.UDP(hdr) } // Check whether the source and destination ports are within the diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index 8ec5d5d5c..f01217c91 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -77,7 +77,8 @@ func NewVectorisedView(size int, views []View) VectorisedView { return VectorisedView{views: views, size: size} } -// TrimFront removes the first "count" bytes of the vectorised view. +// TrimFront removes the first "count" bytes of the vectorised view. It panics +// if count > vv.Size(). func (vv *VectorisedView) TrimFront(count int) { for count > 0 && len(vv.views) > 0 { if count < len(vv.views[0]) { @@ -86,7 +87,7 @@ func (vv *VectorisedView) TrimFront(count int) { return } count -= len(vv.views[0]) - vv.RemoveFirst() + vv.removeFirst() } } @@ -104,7 +105,7 @@ func (vv *VectorisedView) Read(v View) (copied int, err error) { count -= len(vv.views[0]) copy(v[copied:], vv.views[0]) copied += len(vv.views[0]) - vv.RemoveFirst() + vv.removeFirst() } if copied == 0 { return 0, io.EOF @@ -126,7 +127,7 @@ func (vv *VectorisedView) ReadToVV(dstVV *VectorisedView, count int) (copied int count -= len(vv.views[0]) dstVV.AppendView(vv.views[0]) copied += len(vv.views[0]) - vv.RemoveFirst() + vv.removeFirst() } return copied } @@ -162,22 +163,37 @@ func (vv *VectorisedView) Clone(buffer []View) VectorisedView { return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size} } -// First returns the first view of the vectorised view. -func (vv *VectorisedView) First() View { +// PullUp returns the first "count" bytes of the vectorised view. If those +// bytes aren't already contiguous inside the vectorised view, PullUp will +// reallocate as needed to make them contiguous. PullUp fails and returns false +// when count > vv.Size(). +func (vv *VectorisedView) PullUp(count int) (View, bool) { if len(vv.views) == 0 { - return nil + return nil, count == 0 + } + if count <= len(vv.views[0]) { + return vv.views[0][:count], true + } + if count > vv.size { + return nil, false } - return vv.views[0] -} -// RemoveFirst removes the first view of the vectorised view. -func (vv *VectorisedView) RemoveFirst() { - if len(vv.views) == 0 { - return + newFirst := NewView(count) + i := 0 + for offset := 0; offset < count; i++ { + copy(newFirst[offset:], vv.views[i]) + if count-offset < len(vv.views[i]) { + vv.views[i].TrimFront(count - offset) + break + } + offset += len(vv.views[i]) + vv.views[i] = nil } - vv.size -= len(vv.views[0]) - vv.views[0] = nil - vv.views = vv.views[1:] + // We're guaranteed that i > 0, since count is too large for the first + // view. + vv.views[i-1] = newFirst + vv.views = vv.views[i-1:] + return newFirst, true } // Size returns the size in bytes of the entire content stored in the vectorised view. @@ -225,3 +241,10 @@ func (vv *VectorisedView) Readers() []bytes.Reader { } return readers } + +// removeFirst panics when len(vv.views) < 1. +func (vv *VectorisedView) removeFirst() { + vv.size -= len(vv.views[0]) + vv.views[0] = nil + vv.views = vv.views[1:] +} diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go index 106e1994c..c56795c7b 100644 --- a/pkg/tcpip/buffer/view_test.go +++ b/pkg/tcpip/buffer/view_test.go @@ -16,6 +16,7 @@ package buffer import ( + "bytes" "reflect" "testing" ) @@ -370,3 +371,115 @@ func TestVVRead(t *testing.T) { }) } } + +var pullUpTestCases = []struct { + comment string + in VectorisedView + count int + want []byte + result VectorisedView + ok bool +}{ + { + comment: "simple case", + in: vv(2, "12"), + count: 1, + want: []byte("1"), + result: vv(2, "12"), + ok: true, + }, + { + comment: "entire View", + in: vv(2, "1", "2"), + count: 1, + want: []byte("1"), + result: vv(2, "1", "2"), + ok: true, + }, + { + comment: "spanning across two Views", + in: vv(3, "1", "23"), + count: 2, + want: []byte("12"), + result: vv(3, "12", "3"), + ok: true, + }, + { + comment: "spanning across all Views", + in: vv(5, "1", "23", "45"), + count: 5, + want: []byte("12345"), + result: vv(5, "12345"), + ok: true, + }, + { + comment: "count = 0", + in: vv(1, "1"), + count: 0, + want: []byte{}, + result: vv(1, "1"), + ok: true, + }, + { + comment: "count = size", + in: vv(1, "1"), + count: 1, + want: []byte("1"), + result: vv(1, "1"), + ok: true, + }, + { + comment: "count too large", + in: vv(3, "1", "23"), + count: 4, + want: nil, + result: vv(3, "1", "23"), + ok: false, + }, + { + comment: "empty vv", + in: vv(0, ""), + count: 1, + want: nil, + result: vv(0, ""), + ok: false, + }, + { + comment: "empty vv, count = 0", + in: vv(0, ""), + count: 0, + want: nil, + result: vv(0, ""), + ok: true, + }, + { + comment: "empty views", + in: vv(3, "", "1", "", "23"), + count: 2, + want: []byte("12"), + result: vv(3, "12", "3"), + ok: true, + }, +} + +func TestPullUp(t *testing.T) { + for _, c := range pullUpTestCases { + got, ok := c.in.PullUp(c.count) + + // Is the return value right? + if ok != c.ok { + t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got an ok of %t. Want %t", + c.comment, c.count, c.in, ok, c.ok) + } + if bytes.Compare(got, View(c.want)) != 0 { + t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got %v. Want %v", + c.comment, c.count, c.in, got, c.want) + } + + // Is the underlying structure right? + if !reflect.DeepEqual(c.in, c.result) { + t.Errorf("Test %q failed when calling PullUp(%d). Got vv with structure %v. Wanted %v", + c.comment, c.count, c.in, c.result) + } + } +} diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 1e2255bfa..073c84ef9 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -98,13 +98,13 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - // Reject the packet if it's shorter than an ethernet header. - if vv.Size() < header.EthernetMinimumSize { + // There should be an ethernet header at the beginning of vv. + hdr, ok := vv.PullUp(header.EthernetMinimumSize) + if !ok { + // Reject the packet if it's shorter than an ethernet header. return tcpip.ErrBadAddress } - - // There should be an ethernet header at the beginning of vv. - linkHeader := header.Ethernet(vv.First()[:header.EthernetMinimumSize]) + linkHeader := header.Ethernet(hdr) vv.TrimFront(len(linkHeader)) e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), stack.PacketBuffer{ Data: vv, diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD index 14b527bc2..9cc08d0e2 100644 --- a/pkg/tcpip/link/rawfile/BUILD +++ b/pkg/tcpip/link/rawfile/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -18,3 +18,10 @@ go_library( "@org_golang_x_sys//unix:go_default_library", ], ) + +go_test( + name = "rawfile_test", + size = "small", + srcs = ["rawfile_test.go"], + library = ":rawfile", +) diff --git a/pkg/tcpip/link/rawfile/rawfile_test.go b/pkg/tcpip/link/rawfile/rawfile_test.go new file mode 100644 index 000000000..8f14ba761 --- /dev/null +++ b/pkg/tcpip/link/rawfile/rawfile_test.go @@ -0,0 +1,46 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build linux + +package rawfile + +import ( + "syscall" + "testing" +) + +func TestNonBlockingWrite3ZeroLength(t *testing.T) { + fd, err := syscall.Open("/dev/null", syscall.O_WRONLY, 0) + if err != nil { + t.Fatalf("failed to open /dev/null: %v", err) + } + defer syscall.Close(fd) + + if err := NonBlockingWrite3(fd, []byte{}, []byte{0}, nil); err != nil { + t.Fatalf("failed to write: %v", err) + } +} + +func TestNonBlockingWrite3Nil(t *testing.T) { + fd, err := syscall.Open("/dev/null", syscall.O_WRONLY, 0) + if err != nil { + t.Fatalf("failed to open /dev/null: %v", err) + } + defer syscall.Close(fd) + + if err := NonBlockingWrite3(fd, nil, []byte{0}, nil); err != nil { + t.Fatalf("failed to write: %v", err) + } +} diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go index 44e25d475..92efd0bf8 100644 --- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go +++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go @@ -76,9 +76,13 @@ func NonBlockingWrite3(fd int, b1, b2, b3 []byte) *tcpip.Error { // We have two buffers. Build the iovec that represents them and issue // a writev syscall. + var base *byte + if len(b1) > 0 { + base = &b1[0] + } iovec := [3]syscall.Iovec{ { - Base: &b1[0], + Base: base, Len: uint64(len(b1)), }, { diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 27ea3f531..33f640b85 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -674,7 +674,7 @@ func TestSimpleReceive(t *testing.T) { // Wait for packet to be received, then check it. c.waitForPackets(1, time.After(5*time.Second), "Timeout waiting for packet") c.mu.Lock() - rcvd := []byte(c.packets[0].vv.First()) + rcvd := []byte(c.packets[0].vv.ToView()) c.packets = c.packets[:0] c.mu.Unlock() diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index be2537a82..0799c8f4d 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -171,11 +171,7 @@ func (e *endpoint) GSOMaxSize() uint32 { func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { writer := e.writer if writer == nil && atomic.LoadUint32(&LogPackets) == 1 { - first := pkt.Header.View() - if len(first) == 0 { - first = pkt.Data.First() - } - logPacket(prefix, protocol, first, gso) + logPacket(prefix, protocol, pkt, gso) } if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 { totalLength := pkt.Header.UsedLength() + pkt.Data.Size() @@ -238,7 +234,7 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { // Wait implements stack.LinkEndpoint.Wait. func (e *endpoint) Wait() { e.lower.Wait() } -func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View, gso *stack.GSO) { +func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) { // Figure out the network layer info. var transProto uint8 src := tcpip.Address("unknown") @@ -247,28 +243,49 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie size := uint16(0) var fragmentOffset uint16 var moreFragments bool + + // Create a clone of pkt, including any headers if present. Avoid allocating + // backing memory for the clone. + views := [8]buffer.View{} + vv := buffer.NewVectorisedView(0, views[:0]) + vv.AppendView(pkt.Header.View()) + vv.Append(pkt.Data) + switch protocol { case header.IPv4ProtocolNumber: - ipv4 := header.IPv4(b) + hdr, ok := vv.PullUp(header.IPv4MinimumSize) + if !ok { + return + } + ipv4 := header.IPv4(hdr) fragmentOffset = ipv4.FragmentOffset() moreFragments = ipv4.Flags()&header.IPv4FlagMoreFragments == header.IPv4FlagMoreFragments src = ipv4.SourceAddress() dst = ipv4.DestinationAddress() transProto = ipv4.Protocol() size = ipv4.TotalLength() - uint16(ipv4.HeaderLength()) - b = b[ipv4.HeaderLength():] + vv.TrimFront(int(ipv4.HeaderLength())) id = int(ipv4.ID()) case header.IPv6ProtocolNumber: - ipv6 := header.IPv6(b) + hdr, ok := vv.PullUp(header.IPv6MinimumSize) + if !ok { + return + } + ipv6 := header.IPv6(hdr) src = ipv6.SourceAddress() dst = ipv6.DestinationAddress() transProto = ipv6.NextHeader() size = ipv6.PayloadLength() - b = b[header.IPv6MinimumSize:] + vv.TrimFront(header.IPv6MinimumSize) case header.ARPProtocolNumber: - arp := header.ARP(b) + hdr, ok := vv.PullUp(header.ARPSize) + if !ok { + return + } + vv.TrimFront(header.ARPSize) + arp := header.ARP(hdr) log.Infof( "%s arp %v (%v) -> %v (%v) valid:%v", prefix, @@ -284,7 +301,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie // We aren't guaranteed to have a transport header - it's possible for // writes via raw endpoints to contain only network headers. - if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && len(b) < minSize { + if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && vv.Size() < minSize { log.Infof("%s %v -> %v transport protocol: %d, but no transport header found (possible raw packet)", prefix, src, dst, transProto) return } @@ -297,7 +314,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie switch tcpip.TransportProtocolNumber(transProto) { case header.ICMPv4ProtocolNumber: transName = "icmp" - icmp := header.ICMPv4(b) + hdr, ok := vv.PullUp(header.ICMPv4MinimumSize) + if !ok { + break + } + icmp := header.ICMPv4(hdr) icmpType := "unknown" if fragmentOffset == 0 { switch icmp.Type() { @@ -330,7 +351,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie case header.ICMPv6ProtocolNumber: transName = "icmp" - icmp := header.ICMPv6(b) + hdr, ok := vv.PullUp(header.ICMPv6MinimumSize) + if !ok { + break + } + icmp := header.ICMPv6(hdr) icmpType := "unknown" switch icmp.Type() { case header.ICMPv6DstUnreachable: @@ -361,7 +386,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie case header.UDPProtocolNumber: transName = "udp" - udp := header.UDP(b) + hdr, ok := vv.PullUp(header.UDPMinimumSize) + if !ok { + break + } + udp := header.UDP(hdr) if fragmentOffset == 0 && len(udp) >= header.UDPMinimumSize { srcPort = udp.SourcePort() dstPort = udp.DestinationPort() @@ -371,7 +400,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie case header.TCPProtocolNumber: transName = "tcp" - tcp := header.TCP(b) + hdr, ok := vv.PullUp(header.TCPMinimumSize) + if !ok { + break + } + tcp := header.TCP(hdr) if fragmentOffset == 0 && len(tcp) >= header.TCPMinimumSize { offset := int(tcp.DataOffset()) if offset < header.TCPMinimumSize { diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 7acbfa0a8..cf73a939e 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -93,7 +93,10 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf } func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - v := pkt.Data.First() + v, ok := pkt.Data.PullUp(header.ARPSize) + if !ok { + return + } h := header.ARP(v) if !h.IsValid() { return diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index c4bf1ba5c..4cbefe5ab 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -25,7 +25,11 @@ import ( // used to find out which transport endpoint must be notified about the ICMP // packet. func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { - h := header.IPv4(pkt.Data.First()) + h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return + } + hdr := header.IPv4(h) // We don't use IsValid() here because ICMP only requires that the IP // header plus 8 bytes of the transport header be included. So it's @@ -34,12 +38,12 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // // Drop packet if it doesn't have the basic IPv4 header or if the // original source address doesn't match the endpoint's address. - if len(h) < header.IPv4MinimumSize || h.SourceAddress() != e.id.LocalAddress { + if hdr.SourceAddress() != e.id.LocalAddress { return } - hlen := int(h.HeaderLength()) - if pkt.Data.Size() < hlen || h.FragmentOffset() != 0 { + hlen := int(hdr.HeaderLength()) + if pkt.Data.Size() < hlen || hdr.FragmentOffset() != 0 { // We won't be able to handle this if it doesn't contain the // full IPv4 header, or if it's a fragment not at offset 0 // (because it won't have the transport header). @@ -48,15 +52,15 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // Skip the ip header, then deliver control message. pkt.Data.TrimFront(hlen) - p := h.TransportProtocol() - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + p := hdr.TransportProtocol() + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) { stats := r.Stats() received := stats.ICMP.V4PacketsReceived - v := pkt.Data.First() - if len(v) < header.ICMPv4MinimumSize { + v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) + if !ok { received.Invalid.Increment() return } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 104aafbed..17202cc7a 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -328,7 +328,11 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required // checks. - ip := header.IPv4(pkt.Data.First()) + h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return tcpip.ErrInvalidOptionValue + } + ip := header.IPv4(h) if !ip.IsValid(pkt.Data.Size()) { return tcpip.ErrInvalidOptionValue } @@ -378,7 +382,11 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - headerView := pkt.Data.First() + headerView, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + r.Stats().IP.MalformedPacketsReceived.Increment() + return + } h := header.IPv4(headerView) if !h.IsValid(pkt.Data.Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index b68983d10..bdf3a0d25 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -28,7 +28,11 @@ import ( // used to find out which transport endpoint must be notified about the ICMP // packet. func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { - h := header.IPv6(pkt.Data.First()) + h, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + if !ok { + return + } + hdr := header.IPv6(h) // We don't use IsValid() here because ICMP only requires that up to // 1280 bytes of the original packet be included. So it's likely that it @@ -36,17 +40,21 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // // Drop packet if it doesn't have the basic IPv6 header or if the // original source address doesn't match the endpoint's address. - if len(h) < header.IPv6MinimumSize || h.SourceAddress() != e.id.LocalAddress { + if hdr.SourceAddress() != e.id.LocalAddress { return } // Skip the IP header, then handle the fragmentation header if there // is one. pkt.Data.TrimFront(header.IPv6MinimumSize) - p := h.TransportProtocol() + p := hdr.TransportProtocol() if p == header.IPv6FragmentHeader { - f := header.IPv6Fragment(pkt.Data.First()) - if !f.IsValid() || f.FragmentOffset() != 0 { + f, ok := pkt.Data.PullUp(header.IPv6FragmentHeaderSize) + if !ok { + return + } + fragHdr := header.IPv6Fragment(f) + if !fragHdr.IsValid() || fragHdr.FragmentOffset() != 0 { // We can't handle fragments that aren't at offset 0 // because they don't have the transport headers. return @@ -55,19 +63,19 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // Skip fragmentation header and find out the actual protocol // number. pkt.Data.TrimFront(header.IPv6FragmentHeaderSize) - p = f.TransportProtocol() + p = fragHdr.TransportProtocol() } // Deliver the control packet to the transport endpoint. - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.PacketBuffer, hasFragmentHeader bool) { stats := r.Stats().ICMP sent := stats.V6PacketsSent received := stats.V6PacketsReceived - v := pkt.Data.First() - if len(v) < header.ICMPv6MinimumSize { + v, ok := pkt.Data.PullUp(header.ICMPv6HeaderSize) + if !ok { received.Invalid.Increment() return } @@ -76,11 +84,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P // Validate ICMPv6 checksum before processing the packet. // - // Only the first view in vv is accounted for by h. To account for the - // rest of vv, a shallow copy is made and the first view is removed. // This copy is used as extra payload during the checksum calculation. payload := pkt.Data.Clone(nil) - payload.RemoveFirst() + payload.TrimFront(len(h)) if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want { received.Invalid.Increment() return @@ -101,34 +107,40 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P switch h.Type() { case header.ICMPv6PacketTooBig: received.PacketTooBig.Increment() - if len(v) < header.ICMPv6PacketTooBigMinimumSize { + hdr, ok := pkt.Data.PullUp(header.ICMPv6PacketTooBigMinimumSize) + if !ok { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize) - mtu := h.MTU() + mtu := header.ICMPv6(hdr).MTU() e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) case header.ICMPv6DstUnreachable: received.DstUnreachable.Increment() - if len(v) < header.ICMPv6DstUnreachableMinimumSize { + hdr, ok := pkt.Data.PullUp(header.ICMPv6DstUnreachableMinimumSize) + if !ok { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize) - switch h.Code() { + switch header.ICMPv6(hdr).Code() { case header.ICMPv6PortUnreachable: e.handleControl(stack.ControlPortUnreachable, 0, pkt) } case header.ICMPv6NeighborSolicit: received.NeighborSolicit.Increment() - if len(v) < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() { + if pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() { received.Invalid.Increment() return } - ns := header.NDPNeighborSolicit(h.NDPPayload()) + // The remainder of payload must be only the neighbor solicitation, so + // payload.ToView() always returns the solicitation. Per RFC 6980 section 5, + // NDP messages cannot be fragmented. Also note that in the common case NDP + // datagrams are very small and ToView() will not incur allocations. + ns := header.NDPNeighborSolicit(payload.ToView()) it, err := ns.Options().Iter(true) if err != nil { // If we have a malformed NDP NS option, drop the packet. @@ -286,12 +298,16 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6NeighborAdvert: received.NeighborAdvert.Increment() - if len(v) < header.ICMPv6NeighborAdvertSize || !isNDPValid() { + if pkt.Data.Size() < header.ICMPv6NeighborAdvertSize || !isNDPValid() { received.Invalid.Increment() return } - na := header.NDPNeighborAdvert(h.NDPPayload()) + // The remainder of payload must be only the neighbor advertisement, so + // payload.ToView() always returns the advertisement. Per RFC 6980 section + // 5, NDP messages cannot be fragmented. Also note that in the common case + // NDP datagrams are very small and ToView() will not incur allocations. + na := header.NDPNeighborAdvert(payload.ToView()) it, err := na.Options().Iter(true) if err != nil { // If we have a malformed NDP NA option, drop the packet. @@ -363,14 +379,15 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6EchoRequest: received.EchoRequest.Increment() - if len(v) < header.ICMPv6EchoMinimumSize { + icmpHdr, ok := pkt.Data.PullUp(header.ICMPv6EchoMinimumSize) + if !ok { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6EchoMinimumSize) hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize) packet := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) - copy(packet, h) + copy(packet, icmpHdr) packet.SetType(header.ICMPv6EchoReply) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ @@ -384,7 +401,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6EchoReply: received.EchoReply.Increment() - if len(v) < header.ICMPv6EchoMinimumSize { + if pkt.Data.Size() < header.ICMPv6EchoMinimumSize { received.Invalid.Increment() return } @@ -406,8 +423,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6RouterAdvert: received.RouterAdvert.Increment() - p := h.NDPPayload() - if len(p) < header.NDPRAMinimumSize || !isNDPValid() { + // Is the NDP payload of sufficient size to hold a Router + // Advertisement? + if pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize || !isNDPValid() { received.Invalid.Increment() return } @@ -425,7 +443,11 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P return } - ra := header.NDPRouterAdvert(p) + // The remainder of payload must be only the router advertisement, so + // payload.ToView() always returns the advertisement. Per RFC 6980 section + // 5, NDP messages cannot be fragmented. Also note that in the common case + // NDP datagrams are very small and ToView() will not incur allocations. + ra := header.NDPRouterAdvert(payload.ToView()) opts := ra.Options() // Are options valid as per the wire format? diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index bd099a7f8..d412ff688 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -166,7 +166,8 @@ func TestICMPCounts(t *testing.T) { }, { typ: header.ICMPv6NeighborSolicit, - size: header.ICMPv6NeighborSolicitMinimumSize}, + size: header.ICMPv6NeighborSolicitMinimumSize, + }, { typ: header.ICMPv6NeighborAdvert, size: header.ICMPv6NeighborAdvertMinimumSize, diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 331b0817b..486725131 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -171,7 +171,11 @@ func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffe // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - headerView := pkt.Data.First() + headerView, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + if !ok { + r.Stats().IP.MalformedPacketsReceived.Increment() + return + } h := header.IPv6(headerView) if !h.IsValid(pkt.Data.Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index e9c652042..c7c663498 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -70,7 +70,10 @@ func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID { func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt PacketBuffer) { // Consume the network header. - b := pkt.Data.First() + b, ok := pkt.Data.PullUp(fwdTestNetHeaderLen) + if !ok { + return + } pkt.Data.TrimFront(fwdTestNetHeaderLen) // Dispatch the packet to the transport protocol. @@ -473,7 +476,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Header.View() + b := p.Pkt.Data.ToView() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -517,7 +520,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Header.View() + b := p.Pkt.Data.ToView() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -564,7 +567,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Header.View() + b := p.Pkt.Data.ToView() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -619,7 +622,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // The first 5 packets (address 3 to 7) should not be forwarded // because their address resolutions are interrupted. - b := p.Pkt.Header.View() + b := p.Pkt.Data.ToView() if b[0] < 8 { t.Fatalf("got b[0] = %d, want b[0] >= 8", b[0]) } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 6c0a4b24d..6b91159d4 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -212,6 +212,11 @@ func (it *IPTables) Check(hook Hook, pkt PacketBuffer) bool { // CheckPackets runs pkts through the rules for hook and returns a map of packets that // should not go forward. // +// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// +// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a +// precondition. +// // NOTE: unlike the Check API the returned map contains packets that should be // dropped. func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*PacketBuffer]struct{}) { @@ -226,7 +231,9 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*Pa return drop } -// Precondition: pkt.NetworkHeader is set. +// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a +// precondition. func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. @@ -271,14 +278,21 @@ func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx return chainDrop } -// Precondition: pk.NetworkHeader is set. +// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a +// precondition. func (it *IPTables) checkRule(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) (RuleVerdict, int) { rule := table.Rules[ruleIdx] // If pkt.NetworkHeader hasn't been set yet, it will be contained in - // pkt.Data.First(). + // pkt.Data. if pkt.NetworkHeader == nil { - pkt.NetworkHeader = pkt.Data.First() + var ok bool + pkt.NetworkHeader, ok = pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + // Precondition has been violated. + panic(fmt.Sprintf("iptables checks require IPv4 headers of at least %d bytes", header.IPv4MinimumSize)) + } } // Check whether the packet matches the IP header filter. diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 7b4543caf..8be61f4b1 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -96,9 +96,12 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { newPkt := pkt.Clone() // Set network header. - headerView := newPkt.Data.First() + headerView, ok := newPkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return RuleDrop, 0 + } netHeader := header.IPv4(headerView) - newPkt.NetworkHeader = headerView[:header.IPv4MinimumSize] + newPkt.NetworkHeader = headerView hlen := int(netHeader.HeaderLength()) tlen := int(netHeader.TotalLength()) @@ -117,10 +120,14 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { if newPkt.TransportHeader != nil { udpHeader = header.UDP(newPkt.TransportHeader) } else { - if len(pkt.Data.First()) < header.UDPMinimumSize { + if pkt.Data.Size() < header.UDPMinimumSize { + return RuleDrop, 0 + } + hdr, ok := newPkt.Data.PullUp(header.UDPMinimumSize) + if !ok { return RuleDrop, 0 } - udpHeader = header.UDP(newPkt.Data.First()) + udpHeader = header.UDP(hdr) } udpHeader.SetDestinationPort(rt.MinPort) case header.TCPProtocolNumber: @@ -128,10 +135,14 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { if newPkt.TransportHeader != nil { tcpHeader = header.TCP(newPkt.TransportHeader) } else { - if len(pkt.Data.First()) < header.TCPMinimumSize { + if pkt.Data.Size() < header.TCPMinimumSize { return RuleDrop, 0 } - tcpHeader = header.TCP(newPkt.TransportHeader) + hdr, ok := newPkt.Data.PullUp(header.TCPMinimumSize) + if !ok { + return RuleDrop, 0 + } + tcpHeader = header.TCP(hdr) } // TODO(gvisor.dev/issue/170): Need to recompute checksum // and implement nat connection tracking to support TCP. diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 016dbe15e..0c2b1f36a 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1203,12 +1203,12 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link n.stack.stats.IP.PacketsReceived.Increment() } - if len(pkt.Data.First()) < netProto.MinimumPacketSize() { + netHeader, ok := pkt.Data.PullUp(netProto.MinimumPacketSize()) + if !ok { n.stack.stats.MalformedRcvdPackets.Increment() return } - - src, dst := netProto.ParseAddresses(pkt.Data.First()) + src, dst := netProto.ParseAddresses(netHeader) if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil { // The source address is one of our own, so we never should have gotten a @@ -1289,22 +1289,8 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { // TODO(b/143425874) Decrease the TTL field in forwarded packets. - - firstData := pkt.Data.First() - pkt.Data.RemoveFirst() - - if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen == 0 { - pkt.Header = buffer.NewPrependableFromView(firstData) - } else { - firstDataLen := len(firstData) - - // pkt.Header should have enough capacity to hold n.linkEP's headers. - pkt.Header = buffer.NewPrependable(firstDataLen + linkHeaderLen) - - // TODO(b/151227689): avoid copying the packet when forwarding - if n := copy(pkt.Header.Prepend(firstDataLen), firstData); n != firstDataLen { - panic(fmt.Sprintf("copied %d bytes, expected %d", n, firstDataLen)) - } + if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen != 0 { + pkt.Header = buffer.NewPrependable(linkHeaderLen) } if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil { @@ -1332,12 +1318,13 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // validly formed. n.stack.demux.deliverRawPacket(r, protocol, pkt) - if len(pkt.Data.First()) < transProto.MinimumPacketSize() { + transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize()) + if !ok { n.stack.stats.MalformedRcvdPackets.Increment() return } - srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) + srcPort, dstPort, err := transProto.ParsePorts(transHeader) if err != nil { n.stack.stats.MalformedRcvdPackets.Increment() return @@ -1375,11 +1362,12 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp // ICMPv4 only guarantees that 8 bytes of the transport protocol will // be present in the payload. We know that the ports are within the // first 8 bytes for all known transport protocols. - if len(pkt.Data.First()) < 8 { + transHeader, ok := pkt.Data.PullUp(8) + if !ok { return } - srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) + srcPort, dstPort, err := transProto.ParsePorts(transHeader) if err != nil { return } diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index dc125f25e..7d36f8e84 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -37,7 +37,13 @@ type PacketBuffer struct { Data buffer.VectorisedView // Header holds the headers of outbound packets. As a packet is passed - // down the stack, each layer adds to Header. + // down the stack, each layer adds to Header. Note that forwarded + // packets don't populate Headers on their way out -- their headers and + // payload are never parsed out and remain in Data. + // + // TODO(gvisor.dev/issue/170): Forwarded packets don't currently + // populate Header, but should. This will be doable once early parsing + // (https://github.com/google/gvisor/pull/1995) is supported. Header buffer.Prependable // These fields are used by both inbound and outbound packets. They diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index c7634ceb1..d45d2cc1f 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -95,16 +95,18 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffe f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ // Consume the network header. - b := pkt.Data.First() + b, ok := pkt.Data.PullUp(fakeNetHeaderLen) + if !ok { + return + } pkt.Data.TrimFront(fakeNetHeaderLen) // Handle control packets. if b[2] == uint8(fakeControlProtocol) { - nb := pkt.Data.First() - if len(nb) < fakeNetHeaderLen { + nb, ok := pkt.Data.PullUp(fakeNetHeaderLen) + if !ok { return } - pkt.Data.TrimFront(fakeNetHeaderLen) f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, pkt) return diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 3084e6593..a611e44ab 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -642,10 +642,11 @@ func TestTransportForwarding(t *testing.T) { t.Fatal("Response packet not forwarded") } - if dst := p.Pkt.Header.View()[0]; dst != 3 { + hdrs := p.Pkt.Data.ToView() + if dst := hdrs[0]; dst != 3 { t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) } - if src := p.Pkt.Header.View()[1]; src != 1 { + if src := hdrs[1]; src != 1 { t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) } } diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index feef8dca0..b1d820372 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -747,15 +747,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // Only accept echo replies. switch e.NetProto { case header.IPv4ProtocolNumber: - h := header.ICMPv4(pkt.Data.First()) - if h.Type() != header.ICMPv4EchoReply { + h, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) + if !ok || header.ICMPv4(h).Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } case header.IPv6ProtocolNumber: - h := header.ICMPv6(pkt.Data.First()) - if h.Type() != header.ICMPv6EchoReply { + h, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize) + if !ok || header.ICMPv6(h).Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 40461fd31..7712ce652 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -144,7 +144,11 @@ func (s *segment) logicalLen() seqnum.Size { // TCP checksum and stores the checksum and result of checksum verification in // the csum and csumValid fields of the segment. func (s *segment) parse() bool { - h := header.TCP(s.data.First()) + h, ok := s.data.PullUp(header.TCPMinimumSize) + if !ok { + return false + } + hdr := header.TCP(h) // h is the header followed by the payload. We check that the offset to // the data respects the following constraints: @@ -156,12 +160,16 @@ func (s *segment) parse() bool { // N.B. The segment has already been validated as having at least the // minimum TCP size before reaching here, so it's safe to read the // fields. - offset := int(h.DataOffset()) - if offset < header.TCPMinimumSize || offset > len(h) { + offset := int(hdr.DataOffset()) + if offset < header.TCPMinimumSize { + return false + } + hdrWithOpts, ok := s.data.PullUp(offset) + if !ok { return false } - s.options = []byte(h[header.TCPMinimumSize:offset]) + s.options = []byte(hdrWithOpts[header.TCPMinimumSize:]) s.parsedOptions = header.ParseTCPOptions(s.options) // Query the link capabilities to decide if checksum validation is @@ -173,18 +181,19 @@ func (s *segment) parse() bool { s.data.TrimFront(offset) } if verifyChecksum { - s.csum = h.Checksum() + hdr = header.TCP(hdrWithOpts) + s.csum = hdr.Checksum() xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size())) - xsum = h.CalculateChecksum(xsum) + xsum = hdr.CalculateChecksum(xsum) s.data.TrimFront(offset) xsum = header.ChecksumVV(s.data, xsum) s.csumValid = xsum == 0xffff } - s.sequenceNumber = seqnum.Value(h.SequenceNumber()) - s.ackNumber = seqnum.Value(h.AckNumber()) - s.flags = h.Flags() - s.window = seqnum.Size(h.WindowSize()) + s.sequenceNumber = seqnum.Value(hdr.SequenceNumber()) + s.ackNumber = seqnum.Value(hdr.AckNumber()) + s.flags = hdr.Flags() + s.window = seqnum.Size(hdr.WindowSize()) return true } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index ab1014c7f..286c66cf5 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -3548,7 +3548,7 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { AckNum: c.IRS.Add(1), RcvWnd: 30000, }) - tcpbuf := vv.First()[header.IPv4MinimumSize:] + tcpbuf := vv.ToView()[header.IPv4MinimumSize:] tcpbuf[header.TCPDataOffset] = ((header.TCPMinimumSize - 1) / 4) << 4 c.SendSegment(vv) @@ -3575,7 +3575,7 @@ func TestReceivedIncorrectChecksumIncrement(t *testing.T) { AckNum: c.IRS.Add(1), RcvWnd: 30000, }) - tcpbuf := vv.First()[header.IPv4MinimumSize:] + tcpbuf := vv.ToView()[header.IPv4MinimumSize:] // Overwrite a byte in the payload which should cause checksum // verification to fail. tcpbuf[(tcpbuf[header.TCPDataOffset]>>4)*4] = 0x4 diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index edb54f0be..756ab913a 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1250,8 +1250,8 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // endpoint. func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) { // Get the header then trim it from the view. - hdr := header.UDP(pkt.Data.First()) - if int(hdr.Length()) > pkt.Data.Size() { + hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize) + if !ok || int(header.UDP(hdr).Length()) > pkt.Data.Size() { // Malformed packet. e.stack.Stats().UDP.MalformedPacketsReceived.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() @@ -1286,7 +1286,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk senderAddress: tcpip.FullAddress{ NIC: r.NICID(), Addr: id.RemoteAddress, - Port: hdr.SourcePort(), + Port: header.UDP(hdr).SourcePort(), }, } packet.data = pkt.Data diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 6e31a9bac..52af6de22 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -68,8 +68,13 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // that don't match any existing endpoint. func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { // Get the header then trim it from the view. - hdr := header.UDP(pkt.Data.First()) - if int(hdr.Length()) > pkt.Data.Size() { + h, ok := pkt.Data.PullUp(header.UDPMinimumSize) + if !ok { + // Malformed packet. + r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() + return true + } + if int(header.UDP(h).Length()) > pkt.Data.Size() { // Malformed packet. r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() return true -- cgit v1.2.3 From 1ceee045294a6059093851645968f5a7e00a58f3 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Fri, 24 Apr 2020 12:45:33 -0700 Subject: Do not copy tcpip.CancellableTimer A CancellableTimer's AfterFunc timer instance creates a closure over the CancellableTimer's address. This closure makes a CancellableTimer unsafe to copy. No behaviour change, existing tests pass. PiperOrigin-RevId: 308306664 --- pkg/tcpip/stack/ndp.go | 31 +++++++++++++++++++++---------- pkg/tcpip/timer.go | 25 ++++++++++++++++++++++--- pkg/tcpip/timer_test.go | 16 ++++++++-------- 3 files changed, 51 insertions(+), 21 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 193a9dfde..c11d62f97 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -412,20 +412,33 @@ type dadState struct { // defaultRouterState holds data associated with a default router discovered by // a Router Advertisement (RA). type defaultRouterState struct { - invalidationTimer tcpip.CancellableTimer + // Timer to invalidate the default router. + // + // May not be nil. + invalidationTimer *tcpip.CancellableTimer } // onLinkPrefixState holds data associated with an on-link prefix discovered by // a Router Advertisement's Prefix Information option (PI) when the NDP // configurations was configured to do so. type onLinkPrefixState struct { - invalidationTimer tcpip.CancellableTimer + // Timer to invalidate the on-link prefix. + // + // May not be nil. + invalidationTimer *tcpip.CancellableTimer } // slaacPrefixState holds state associated with a SLAAC prefix. type slaacPrefixState struct { - deprecationTimer tcpip.CancellableTimer - invalidationTimer tcpip.CancellableTimer + // Timer to deprecate the prefix. + // + // May not be nil. + deprecationTimer *tcpip.CancellableTimer + + // Timer to invalidate the prefix. + // + // May not be nil. + invalidationTimer *tcpip.CancellableTimer // Nonzero only when the address is not valid forever. validUntil time.Time @@ -775,7 +788,6 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { } rtr.invalidationTimer.StopLocked() - delete(ndp.defaultRouters, ip) // Let the integrator know a discovered default router is invalidated. @@ -804,7 +816,7 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { } state := defaultRouterState{ - invalidationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() { + invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { ndp.invalidateDefaultRouter(ip) }), } @@ -834,7 +846,7 @@ func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) } state := onLinkPrefixState{ - invalidationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() { + invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { ndp.invalidateOnLinkPrefix(prefix) }), } @@ -859,7 +871,6 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) { } s.invalidationTimer.StopLocked() - delete(ndp.onLinkPrefixes, prefix) // Let the integrator know a discovered on-link prefix is invalidated. @@ -979,7 +990,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { } state := slaacPrefixState{ - deprecationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() { + deprecationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { state, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the deprecated SLAAC prefix %s", prefix)) @@ -987,7 +998,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { ndp.deprecateSLAACAddress(state.ref) }), - invalidationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() { + invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { state, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the invalidated SLAAC prefix %s", prefix)) diff --git a/pkg/tcpip/timer.go b/pkg/tcpip/timer.go index 67f66fc72..59f3b391f 100644 --- a/pkg/tcpip/timer.go +++ b/pkg/tcpip/timer.go @@ -88,6 +88,9 @@ func (t *cancellableTimerInstance) stop() { // // The term "related work" is defined as some work that needs to be done while // holding some lock that the timer must also hold while doing some work. +// +// Note, it is not safe to copy a CancellableTimer as its timer instance creates +// a closure over the address of the CancellableTimer. type CancellableTimer struct { // The active instance of a cancellable timer. instance cancellableTimerInstance @@ -154,12 +157,28 @@ func (t *CancellableTimer) Reset(d time.Duration) { } } -// MakeCancellableTimer returns an unscheduled CancellableTimer with the given +// Lock is a no-op used by the copylocks checker from go vet. +// +// See CancellableTimer for details about why it shouldn't be copied. +// +// See https://github.com/golang/go/issues/8005#issuecomment-190753527 for more +// details about the copylocks checker. +func (*CancellableTimer) Lock() {} + +// Unlock is a no-op used by the copylocks checker from go vet. +// +// See CancellableTimer for details about why it shouldn't be copied. +// +// See https://github.com/golang/go/issues/8005#issuecomment-190753527 for more +// details about the copylocks checker. +func (*CancellableTimer) Unlock() {} + +// NewCancellableTimer returns an unscheduled CancellableTimer with the given // locker and fn. // // fn MUST NOT attempt to lock locker. // // Callers must call Reset to schedule the timer to fire. -func MakeCancellableTimer(locker sync.Locker, fn func()) CancellableTimer { - return CancellableTimer{locker: locker, fn: fn} +func NewCancellableTimer(locker sync.Locker, fn func()) *CancellableTimer { + return &CancellableTimer{locker: locker, fn: fn} } diff --git a/pkg/tcpip/timer_test.go b/pkg/tcpip/timer_test.go index 730134906..b4940e397 100644 --- a/pkg/tcpip/timer_test.go +++ b/pkg/tcpip/timer_test.go @@ -43,7 +43,7 @@ func TestCancellableTimerReassignment(t *testing.T) { // that has an active timer (even if it has been stopped as a stopped // timer may be blocked on a lock before it can check if it has been // stopped while another goroutine holds the same lock). - timer = tcpip.MakeCancellableTimer(&lock, func() { + timer = *tcpip.NewCancellableTimer(&lock, func() { wg.Done() }) timer.Reset(shortDuration) @@ -59,7 +59,7 @@ func TestCancellableTimerFire(t *testing.T) { ch := make(chan struct{}) var lock sync.Mutex - timer := tcpip.MakeCancellableTimer(&lock, func() { + timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} }) timer.Reset(shortDuration) @@ -85,7 +85,7 @@ func TestCancellableTimerResetFromLongDuration(t *testing.T) { ch := make(chan struct{}) var lock sync.Mutex - timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} }) + timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} }) timer.Reset(middleDuration) lock.Lock() @@ -116,7 +116,7 @@ func TestCancellableTimerResetFromShortDuration(t *testing.T) { var lock sync.Mutex lock.Lock() - timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} }) + timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} }) timer.Reset(shortDuration) timer.StopLocked() lock.Unlock() @@ -153,7 +153,7 @@ func TestCancellableTimerImmediatelyStop(t *testing.T) { for i := 0; i < 1000; i++ { lock.Lock() - timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} }) + timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} }) timer.Reset(shortDuration) timer.StopLocked() lock.Unlock() @@ -174,7 +174,7 @@ func TestCancellableTimerStoppedResetWithoutLock(t *testing.T) { var lock sync.Mutex lock.Lock() - timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} }) + timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} }) timer.Reset(shortDuration) timer.StopLocked() lock.Unlock() @@ -205,7 +205,7 @@ func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) { var lock sync.Mutex lock.Lock() - timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} }) + timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} }) timer.Reset(shortDuration) for i := 0; i < 10; i++ { // Sleep until the timer fires and gets blocked trying to take the lock. @@ -237,7 +237,7 @@ func TestManyCancellableTimerResetUnderLock(t *testing.T) { var lock sync.Mutex lock.Lock() - timer := tcpip.MakeCancellableTimer(&lock, func() { ch <- struct{}{} }) + timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} }) timer.Reset(shortDuration) for i := 0; i < 10; i++ { timer.StopLocked() -- cgit v1.2.3 From 55f0c3316af8ea2a1fcc16511efc580f307623f6 Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Mon, 27 Apr 2020 12:25:10 -0700 Subject: Automated rollback of changelist 308163542 PiperOrigin-RevId: 308674219 --- pkg/sentry/socket/netfilter/tcp_matcher.go | 5 +- pkg/sentry/socket/netfilter/udp_matcher.go | 5 +- pkg/tcpip/buffer/view.go | 55 ++++---------- pkg/tcpip/buffer/view_test.go | 113 ----------------------------- pkg/tcpip/link/loopback/loopback.go | 10 +-- pkg/tcpip/link/rawfile/BUILD | 9 +-- pkg/tcpip/link/rawfile/rawfile_test.go | 46 ------------ pkg/tcpip/link/rawfile/rawfile_unsafe.go | 6 +- pkg/tcpip/link/sharedmem/sharedmem_test.go | 2 +- pkg/tcpip/link/sniffer/sniffer.go | 65 ++++------------- pkg/tcpip/network/arp/arp.go | 5 +- pkg/tcpip/network/ipv4/icmp.go | 20 ++--- pkg/tcpip/network/ipv4/ipv4.go | 12 +-- pkg/tcpip/network/ipv6/icmp.go | 74 +++++++------------ pkg/tcpip/network/ipv6/icmp_test.go | 3 +- pkg/tcpip/network/ipv6/ipv6.go | 6 +- pkg/tcpip/stack/forwarder_test.go | 13 ++-- pkg/tcpip/stack/iptables.go | 22 +----- pkg/tcpip/stack/iptables_targets.go | 23 ++---- pkg/tcpip/stack/nic.go | 34 ++++++--- pkg/tcpip/stack/packet_buffer.go | 8 +- pkg/tcpip/stack/stack_test.go | 10 +-- pkg/tcpip/stack/transport_test.go | 5 +- pkg/tcpip/transport/icmp/endpoint.go | 8 +- pkg/tcpip/transport/tcp/segment.go | 29 +++----- pkg/tcpip/transport/tcp/tcp_test.go | 4 +- pkg/tcpip/transport/udp/endpoint.go | 6 +- pkg/tcpip/transport/udp/protocol.go | 9 +-- 28 files changed, 149 insertions(+), 458 deletions(-) delete mode 100644 pkg/tcpip/link/rawfile/rawfile_test.go (limited to 'pkg/tcpip/stack') diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index 55c0f04f3..ff1cfd8f6 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -121,13 +121,12 @@ func (tm *TCPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa tcpHeader = header.TCP(pkt.TransportHeader) } else { // The TCP header hasn't been parsed yet. We have to do it here. - hdr, ok := pkt.Data.PullUp(header.TCPMinimumSize) - if !ok { + if len(pkt.Data.First()) < header.TCPMinimumSize { // There's no valid TCP header here, so we hotdrop the // packet. return false, true } - tcpHeader = header.TCP(hdr) + tcpHeader = header.TCP(pkt.Data.First()) } // Check whether the source and destination ports are within the diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index 04d03d494..3359418c1 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -120,13 +120,12 @@ func (um *UDPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa udpHeader = header.UDP(pkt.TransportHeader) } else { // The UDP header hasn't been parsed yet. We have to do it here. - hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize) - if !ok { + if len(pkt.Data.First()) < header.UDPMinimumSize { // There's no valid UDP header here, so we hotdrop the // packet. return false, true } - udpHeader = header.UDP(hdr) + udpHeader = header.UDP(pkt.Data.First()) } // Check whether the source and destination ports are within the diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index f01217c91..8ec5d5d5c 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -77,8 +77,7 @@ func NewVectorisedView(size int, views []View) VectorisedView { return VectorisedView{views: views, size: size} } -// TrimFront removes the first "count" bytes of the vectorised view. It panics -// if count > vv.Size(). +// TrimFront removes the first "count" bytes of the vectorised view. func (vv *VectorisedView) TrimFront(count int) { for count > 0 && len(vv.views) > 0 { if count < len(vv.views[0]) { @@ -87,7 +86,7 @@ func (vv *VectorisedView) TrimFront(count int) { return } count -= len(vv.views[0]) - vv.removeFirst() + vv.RemoveFirst() } } @@ -105,7 +104,7 @@ func (vv *VectorisedView) Read(v View) (copied int, err error) { count -= len(vv.views[0]) copy(v[copied:], vv.views[0]) copied += len(vv.views[0]) - vv.removeFirst() + vv.RemoveFirst() } if copied == 0 { return 0, io.EOF @@ -127,7 +126,7 @@ func (vv *VectorisedView) ReadToVV(dstVV *VectorisedView, count int) (copied int count -= len(vv.views[0]) dstVV.AppendView(vv.views[0]) copied += len(vv.views[0]) - vv.removeFirst() + vv.RemoveFirst() } return copied } @@ -163,37 +162,22 @@ func (vv *VectorisedView) Clone(buffer []View) VectorisedView { return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size} } -// PullUp returns the first "count" bytes of the vectorised view. If those -// bytes aren't already contiguous inside the vectorised view, PullUp will -// reallocate as needed to make them contiguous. PullUp fails and returns false -// when count > vv.Size(). -func (vv *VectorisedView) PullUp(count int) (View, bool) { +// First returns the first view of the vectorised view. +func (vv *VectorisedView) First() View { if len(vv.views) == 0 { - return nil, count == 0 - } - if count <= len(vv.views[0]) { - return vv.views[0][:count], true - } - if count > vv.size { - return nil, false + return nil } + return vv.views[0] +} - newFirst := NewView(count) - i := 0 - for offset := 0; offset < count; i++ { - copy(newFirst[offset:], vv.views[i]) - if count-offset < len(vv.views[i]) { - vv.views[i].TrimFront(count - offset) - break - } - offset += len(vv.views[i]) - vv.views[i] = nil +// RemoveFirst removes the first view of the vectorised view. +func (vv *VectorisedView) RemoveFirst() { + if len(vv.views) == 0 { + return } - // We're guaranteed that i > 0, since count is too large for the first - // view. - vv.views[i-1] = newFirst - vv.views = vv.views[i-1:] - return newFirst, true + vv.size -= len(vv.views[0]) + vv.views[0] = nil + vv.views = vv.views[1:] } // Size returns the size in bytes of the entire content stored in the vectorised view. @@ -241,10 +225,3 @@ func (vv *VectorisedView) Readers() []bytes.Reader { } return readers } - -// removeFirst panics when len(vv.views) < 1. -func (vv *VectorisedView) removeFirst() { - vv.size -= len(vv.views[0]) - vv.views[0] = nil - vv.views = vv.views[1:] -} diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go index c56795c7b..106e1994c 100644 --- a/pkg/tcpip/buffer/view_test.go +++ b/pkg/tcpip/buffer/view_test.go @@ -16,7 +16,6 @@ package buffer import ( - "bytes" "reflect" "testing" ) @@ -371,115 +370,3 @@ func TestVVRead(t *testing.T) { }) } } - -var pullUpTestCases = []struct { - comment string - in VectorisedView - count int - want []byte - result VectorisedView - ok bool -}{ - { - comment: "simple case", - in: vv(2, "12"), - count: 1, - want: []byte("1"), - result: vv(2, "12"), - ok: true, - }, - { - comment: "entire View", - in: vv(2, "1", "2"), - count: 1, - want: []byte("1"), - result: vv(2, "1", "2"), - ok: true, - }, - { - comment: "spanning across two Views", - in: vv(3, "1", "23"), - count: 2, - want: []byte("12"), - result: vv(3, "12", "3"), - ok: true, - }, - { - comment: "spanning across all Views", - in: vv(5, "1", "23", "45"), - count: 5, - want: []byte("12345"), - result: vv(5, "12345"), - ok: true, - }, - { - comment: "count = 0", - in: vv(1, "1"), - count: 0, - want: []byte{}, - result: vv(1, "1"), - ok: true, - }, - { - comment: "count = size", - in: vv(1, "1"), - count: 1, - want: []byte("1"), - result: vv(1, "1"), - ok: true, - }, - { - comment: "count too large", - in: vv(3, "1", "23"), - count: 4, - want: nil, - result: vv(3, "1", "23"), - ok: false, - }, - { - comment: "empty vv", - in: vv(0, ""), - count: 1, - want: nil, - result: vv(0, ""), - ok: false, - }, - { - comment: "empty vv, count = 0", - in: vv(0, ""), - count: 0, - want: nil, - result: vv(0, ""), - ok: true, - }, - { - comment: "empty views", - in: vv(3, "", "1", "", "23"), - count: 2, - want: []byte("12"), - result: vv(3, "12", "3"), - ok: true, - }, -} - -func TestPullUp(t *testing.T) { - for _, c := range pullUpTestCases { - got, ok := c.in.PullUp(c.count) - - // Is the return value right? - if ok != c.ok { - t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got an ok of %t. Want %t", - c.comment, c.count, c.in, ok, c.ok) - } - if bytes.Compare(got, View(c.want)) != 0 { - t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got %v. Want %v", - c.comment, c.count, c.in, got, c.want) - } - - // Is the underlying structure right? - if !reflect.DeepEqual(c.in, c.result) { - t.Errorf("Test %q failed when calling PullUp(%d). Got vv with structure %v. Wanted %v", - c.comment, c.count, c.in, c.result) - } - } -} diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 073c84ef9..1e2255bfa 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -98,13 +98,13 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - // There should be an ethernet header at the beginning of vv. - hdr, ok := vv.PullUp(header.EthernetMinimumSize) - if !ok { - // Reject the packet if it's shorter than an ethernet header. + // Reject the packet if it's shorter than an ethernet header. + if vv.Size() < header.EthernetMinimumSize { return tcpip.ErrBadAddress } - linkHeader := header.Ethernet(hdr) + + // There should be an ethernet header at the beginning of vv. + linkHeader := header.Ethernet(vv.First()[:header.EthernetMinimumSize]) vv.TrimFront(len(linkHeader)) e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), stack.PacketBuffer{ Data: vv, diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD index 9cc08d0e2..14b527bc2 100644 --- a/pkg/tcpip/link/rawfile/BUILD +++ b/pkg/tcpip/link/rawfile/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -18,10 +18,3 @@ go_library( "@org_golang_x_sys//unix:go_default_library", ], ) - -go_test( - name = "rawfile_test", - size = "small", - srcs = ["rawfile_test.go"], - library = ":rawfile", -) diff --git a/pkg/tcpip/link/rawfile/rawfile_test.go b/pkg/tcpip/link/rawfile/rawfile_test.go deleted file mode 100644 index 8f14ba761..000000000 --- a/pkg/tcpip/link/rawfile/rawfile_test.go +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build linux - -package rawfile - -import ( - "syscall" - "testing" -) - -func TestNonBlockingWrite3ZeroLength(t *testing.T) { - fd, err := syscall.Open("/dev/null", syscall.O_WRONLY, 0) - if err != nil { - t.Fatalf("failed to open /dev/null: %v", err) - } - defer syscall.Close(fd) - - if err := NonBlockingWrite3(fd, []byte{}, []byte{0}, nil); err != nil { - t.Fatalf("failed to write: %v", err) - } -} - -func TestNonBlockingWrite3Nil(t *testing.T) { - fd, err := syscall.Open("/dev/null", syscall.O_WRONLY, 0) - if err != nil { - t.Fatalf("failed to open /dev/null: %v", err) - } - defer syscall.Close(fd) - - if err := NonBlockingWrite3(fd, nil, []byte{0}, nil); err != nil { - t.Fatalf("failed to write: %v", err) - } -} diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go index 92efd0bf8..44e25d475 100644 --- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go +++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go @@ -76,13 +76,9 @@ func NonBlockingWrite3(fd int, b1, b2, b3 []byte) *tcpip.Error { // We have two buffers. Build the iovec that represents them and issue // a writev syscall. - var base *byte - if len(b1) > 0 { - base = &b1[0] - } iovec := [3]syscall.Iovec{ { - Base: base, + Base: &b1[0], Len: uint64(len(b1)), }, { diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 33f640b85..27ea3f531 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -674,7 +674,7 @@ func TestSimpleReceive(t *testing.T) { // Wait for packet to be received, then check it. c.waitForPackets(1, time.After(5*time.Second), "Timeout waiting for packet") c.mu.Lock() - rcvd := []byte(c.packets[0].vv.ToView()) + rcvd := []byte(c.packets[0].vv.First()) c.packets = c.packets[:0] c.mu.Unlock() diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 0799c8f4d..be2537a82 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -171,7 +171,11 @@ func (e *endpoint) GSOMaxSize() uint32 { func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { writer := e.writer if writer == nil && atomic.LoadUint32(&LogPackets) == 1 { - logPacket(prefix, protocol, pkt, gso) + first := pkt.Header.View() + if len(first) == 0 { + first = pkt.Data.First() + } + logPacket(prefix, protocol, first, gso) } if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 { totalLength := pkt.Header.UsedLength() + pkt.Data.Size() @@ -234,7 +238,7 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { // Wait implements stack.LinkEndpoint.Wait. func (e *endpoint) Wait() { e.lower.Wait() } -func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) { +func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View, gso *stack.GSO) { // Figure out the network layer info. var transProto uint8 src := tcpip.Address("unknown") @@ -243,49 +247,28 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P size := uint16(0) var fragmentOffset uint16 var moreFragments bool - - // Create a clone of pkt, including any headers if present. Avoid allocating - // backing memory for the clone. - views := [8]buffer.View{} - vv := buffer.NewVectorisedView(0, views[:0]) - vv.AppendView(pkt.Header.View()) - vv.Append(pkt.Data) - switch protocol { case header.IPv4ProtocolNumber: - hdr, ok := vv.PullUp(header.IPv4MinimumSize) - if !ok { - return - } - ipv4 := header.IPv4(hdr) + ipv4 := header.IPv4(b) fragmentOffset = ipv4.FragmentOffset() moreFragments = ipv4.Flags()&header.IPv4FlagMoreFragments == header.IPv4FlagMoreFragments src = ipv4.SourceAddress() dst = ipv4.DestinationAddress() transProto = ipv4.Protocol() size = ipv4.TotalLength() - uint16(ipv4.HeaderLength()) - vv.TrimFront(int(ipv4.HeaderLength())) + b = b[ipv4.HeaderLength():] id = int(ipv4.ID()) case header.IPv6ProtocolNumber: - hdr, ok := vv.PullUp(header.IPv6MinimumSize) - if !ok { - return - } - ipv6 := header.IPv6(hdr) + ipv6 := header.IPv6(b) src = ipv6.SourceAddress() dst = ipv6.DestinationAddress() transProto = ipv6.NextHeader() size = ipv6.PayloadLength() - vv.TrimFront(header.IPv6MinimumSize) + b = b[header.IPv6MinimumSize:] case header.ARPProtocolNumber: - hdr, ok := vv.PullUp(header.ARPSize) - if !ok { - return - } - vv.TrimFront(header.ARPSize) - arp := header.ARP(hdr) + arp := header.ARP(b) log.Infof( "%s arp %v (%v) -> %v (%v) valid:%v", prefix, @@ -301,7 +284,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P // We aren't guaranteed to have a transport header - it's possible for // writes via raw endpoints to contain only network headers. - if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && vv.Size() < minSize { + if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && len(b) < minSize { log.Infof("%s %v -> %v transport protocol: %d, but no transport header found (possible raw packet)", prefix, src, dst, transProto) return } @@ -314,11 +297,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P switch tcpip.TransportProtocolNumber(transProto) { case header.ICMPv4ProtocolNumber: transName = "icmp" - hdr, ok := vv.PullUp(header.ICMPv4MinimumSize) - if !ok { - break - } - icmp := header.ICMPv4(hdr) + icmp := header.ICMPv4(b) icmpType := "unknown" if fragmentOffset == 0 { switch icmp.Type() { @@ -351,11 +330,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P case header.ICMPv6ProtocolNumber: transName = "icmp" - hdr, ok := vv.PullUp(header.ICMPv6MinimumSize) - if !ok { - break - } - icmp := header.ICMPv6(hdr) + icmp := header.ICMPv6(b) icmpType := "unknown" switch icmp.Type() { case header.ICMPv6DstUnreachable: @@ -386,11 +361,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P case header.UDPProtocolNumber: transName = "udp" - hdr, ok := vv.PullUp(header.UDPMinimumSize) - if !ok { - break - } - udp := header.UDP(hdr) + udp := header.UDP(b) if fragmentOffset == 0 && len(udp) >= header.UDPMinimumSize { srcPort = udp.SourcePort() dstPort = udp.DestinationPort() @@ -400,11 +371,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P case header.TCPProtocolNumber: transName = "tcp" - hdr, ok := vv.PullUp(header.TCPMinimumSize) - if !ok { - break - } - tcp := header.TCP(hdr) + tcp := header.TCP(b) if fragmentOffset == 0 && len(tcp) >= header.TCPMinimumSize { offset := int(tcp.DataOffset()) if offset < header.TCPMinimumSize { diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index cf73a939e..7acbfa0a8 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -93,10 +93,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf } func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - v, ok := pkt.Data.PullUp(header.ARPSize) - if !ok { - return - } + v := pkt.Data.First() h := header.ARP(v) if !h.IsValid() { return diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 4cbefe5ab..c4bf1ba5c 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -25,11 +25,7 @@ import ( // used to find out which transport endpoint must be notified about the ICMP // packet. func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { - h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - return - } - hdr := header.IPv4(h) + h := header.IPv4(pkt.Data.First()) // We don't use IsValid() here because ICMP only requires that the IP // header plus 8 bytes of the transport header be included. So it's @@ -38,12 +34,12 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // // Drop packet if it doesn't have the basic IPv4 header or if the // original source address doesn't match the endpoint's address. - if hdr.SourceAddress() != e.id.LocalAddress { + if len(h) < header.IPv4MinimumSize || h.SourceAddress() != e.id.LocalAddress { return } - hlen := int(hdr.HeaderLength()) - if pkt.Data.Size() < hlen || hdr.FragmentOffset() != 0 { + hlen := int(h.HeaderLength()) + if pkt.Data.Size() < hlen || h.FragmentOffset() != 0 { // We won't be able to handle this if it doesn't contain the // full IPv4 header, or if it's a fragment not at offset 0 // (because it won't have the transport header). @@ -52,15 +48,15 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // Skip the ip header, then deliver control message. pkt.Data.TrimFront(hlen) - p := hdr.TransportProtocol() - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + p := h.TransportProtocol() + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) { stats := r.Stats() received := stats.ICMP.V4PacketsReceived - v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) - if !ok { + v := pkt.Data.First() + if len(v) < header.ICMPv4MinimumSize { received.Invalid.Increment() return } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 17202cc7a..104aafbed 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -328,11 +328,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required // checks. - h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - return tcpip.ErrInvalidOptionValue - } - ip := header.IPv4(h) + ip := header.IPv4(pkt.Data.First()) if !ip.IsValid(pkt.Data.Size()) { return tcpip.ErrInvalidOptionValue } @@ -382,11 +378,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - headerView, ok := pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - r.Stats().IP.MalformedPacketsReceived.Increment() - return - } + headerView := pkt.Data.First() h := header.IPv4(headerView) if !h.IsValid(pkt.Data.Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index bdf3a0d25..b68983d10 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -28,11 +28,7 @@ import ( // used to find out which transport endpoint must be notified about the ICMP // packet. func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { - h, ok := pkt.Data.PullUp(header.IPv6MinimumSize) - if !ok { - return - } - hdr := header.IPv6(h) + h := header.IPv6(pkt.Data.First()) // We don't use IsValid() here because ICMP only requires that up to // 1280 bytes of the original packet be included. So it's likely that it @@ -40,21 +36,17 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // // Drop packet if it doesn't have the basic IPv6 header or if the // original source address doesn't match the endpoint's address. - if hdr.SourceAddress() != e.id.LocalAddress { + if len(h) < header.IPv6MinimumSize || h.SourceAddress() != e.id.LocalAddress { return } // Skip the IP header, then handle the fragmentation header if there // is one. pkt.Data.TrimFront(header.IPv6MinimumSize) - p := hdr.TransportProtocol() + p := h.TransportProtocol() if p == header.IPv6FragmentHeader { - f, ok := pkt.Data.PullUp(header.IPv6FragmentHeaderSize) - if !ok { - return - } - fragHdr := header.IPv6Fragment(f) - if !fragHdr.IsValid() || fragHdr.FragmentOffset() != 0 { + f := header.IPv6Fragment(pkt.Data.First()) + if !f.IsValid() || f.FragmentOffset() != 0 { // We can't handle fragments that aren't at offset 0 // because they don't have the transport headers. return @@ -63,19 +55,19 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // Skip fragmentation header and find out the actual protocol // number. pkt.Data.TrimFront(header.IPv6FragmentHeaderSize) - p = fragHdr.TransportProtocol() + p = f.TransportProtocol() } // Deliver the control packet to the transport endpoint. - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.PacketBuffer, hasFragmentHeader bool) { stats := r.Stats().ICMP sent := stats.V6PacketsSent received := stats.V6PacketsReceived - v, ok := pkt.Data.PullUp(header.ICMPv6HeaderSize) - if !ok { + v := pkt.Data.First() + if len(v) < header.ICMPv6MinimumSize { received.Invalid.Increment() return } @@ -84,9 +76,11 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P // Validate ICMPv6 checksum before processing the packet. // + // Only the first view in vv is accounted for by h. To account for the + // rest of vv, a shallow copy is made and the first view is removed. // This copy is used as extra payload during the checksum calculation. payload := pkt.Data.Clone(nil) - payload.TrimFront(len(h)) + payload.RemoveFirst() if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want { received.Invalid.Increment() return @@ -107,40 +101,34 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P switch h.Type() { case header.ICMPv6PacketTooBig: received.PacketTooBig.Increment() - hdr, ok := pkt.Data.PullUp(header.ICMPv6PacketTooBigMinimumSize) - if !ok { + if len(v) < header.ICMPv6PacketTooBigMinimumSize { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize) - mtu := header.ICMPv6(hdr).MTU() + mtu := h.MTU() e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) case header.ICMPv6DstUnreachable: received.DstUnreachable.Increment() - hdr, ok := pkt.Data.PullUp(header.ICMPv6DstUnreachableMinimumSize) - if !ok { + if len(v) < header.ICMPv6DstUnreachableMinimumSize { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize) - switch header.ICMPv6(hdr).Code() { + switch h.Code() { case header.ICMPv6PortUnreachable: e.handleControl(stack.ControlPortUnreachable, 0, pkt) } case header.ICMPv6NeighborSolicit: received.NeighborSolicit.Increment() - if pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() { + if len(v) < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() { received.Invalid.Increment() return } - // The remainder of payload must be only the neighbor solicitation, so - // payload.ToView() always returns the solicitation. Per RFC 6980 section 5, - // NDP messages cannot be fragmented. Also note that in the common case NDP - // datagrams are very small and ToView() will not incur allocations. - ns := header.NDPNeighborSolicit(payload.ToView()) + ns := header.NDPNeighborSolicit(h.NDPPayload()) it, err := ns.Options().Iter(true) if err != nil { // If we have a malformed NDP NS option, drop the packet. @@ -298,16 +286,12 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6NeighborAdvert: received.NeighborAdvert.Increment() - if pkt.Data.Size() < header.ICMPv6NeighborAdvertSize || !isNDPValid() { + if len(v) < header.ICMPv6NeighborAdvertSize || !isNDPValid() { received.Invalid.Increment() return } - // The remainder of payload must be only the neighbor advertisement, so - // payload.ToView() always returns the advertisement. Per RFC 6980 section - // 5, NDP messages cannot be fragmented. Also note that in the common case - // NDP datagrams are very small and ToView() will not incur allocations. - na := header.NDPNeighborAdvert(payload.ToView()) + na := header.NDPNeighborAdvert(h.NDPPayload()) it, err := na.Options().Iter(true) if err != nil { // If we have a malformed NDP NA option, drop the packet. @@ -379,15 +363,14 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6EchoRequest: received.EchoRequest.Increment() - icmpHdr, ok := pkt.Data.PullUp(header.ICMPv6EchoMinimumSize) - if !ok { + if len(v) < header.ICMPv6EchoMinimumSize { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6EchoMinimumSize) hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize) packet := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) - copy(packet, icmpHdr) + copy(packet, h) packet.SetType(header.ICMPv6EchoReply) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ @@ -401,7 +384,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6EchoReply: received.EchoReply.Increment() - if pkt.Data.Size() < header.ICMPv6EchoMinimumSize { + if len(v) < header.ICMPv6EchoMinimumSize { received.Invalid.Increment() return } @@ -423,9 +406,8 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6RouterAdvert: received.RouterAdvert.Increment() - // Is the NDP payload of sufficient size to hold a Router - // Advertisement? - if pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize || !isNDPValid() { + p := h.NDPPayload() + if len(p) < header.NDPRAMinimumSize || !isNDPValid() { received.Invalid.Increment() return } @@ -443,11 +425,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P return } - // The remainder of payload must be only the router advertisement, so - // payload.ToView() always returns the advertisement. Per RFC 6980 section - // 5, NDP messages cannot be fragmented. Also note that in the common case - // NDP datagrams are very small and ToView() will not incur allocations. - ra := header.NDPRouterAdvert(payload.ToView()) + ra := header.NDPRouterAdvert(p) opts := ra.Options() // Are options valid as per the wire format? diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index d412ff688..bd099a7f8 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -166,8 +166,7 @@ func TestICMPCounts(t *testing.T) { }, { typ: header.ICMPv6NeighborSolicit, - size: header.ICMPv6NeighborSolicitMinimumSize, - }, + size: header.ICMPv6NeighborSolicitMinimumSize}, { typ: header.ICMPv6NeighborAdvert, size: header.ICMPv6NeighborAdvertMinimumSize, diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 486725131..331b0817b 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -171,11 +171,7 @@ func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffe // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - headerView, ok := pkt.Data.PullUp(header.IPv6MinimumSize) - if !ok { - r.Stats().IP.MalformedPacketsReceived.Increment() - return - } + headerView := pkt.Data.First() h := header.IPv6(headerView) if !h.IsValid(pkt.Data.Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index c7c663498..e9c652042 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -70,10 +70,7 @@ func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID { func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt PacketBuffer) { // Consume the network header. - b, ok := pkt.Data.PullUp(fwdTestNetHeaderLen) - if !ok { - return - } + b := pkt.Data.First() pkt.Data.TrimFront(fwdTestNetHeaderLen) // Dispatch the packet to the transport protocol. @@ -476,7 +473,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Data.ToView() + b := p.Pkt.Header.View() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -520,7 +517,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Data.ToView() + b := p.Pkt.Header.View() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -567,7 +564,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Data.ToView() + b := p.Pkt.Header.View() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -622,7 +619,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // The first 5 packets (address 3 to 7) should not be forwarded // because their address resolutions are interrupted. - b := p.Pkt.Data.ToView() + b := p.Pkt.Header.View() if b[0] < 8 { t.Fatalf("got b[0] = %d, want b[0] >= 8", b[0]) } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 6b91159d4..6c0a4b24d 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -212,11 +212,6 @@ func (it *IPTables) Check(hook Hook, pkt PacketBuffer) bool { // CheckPackets runs pkts through the rules for hook and returns a map of packets that // should not go forward. // -// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// -// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a -// precondition. -// // NOTE: unlike the Check API the returned map contains packets that should be // dropped. func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*PacketBuffer]struct{}) { @@ -231,9 +226,7 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*Pa return drop } -// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a -// precondition. +// Precondition: pkt.NetworkHeader is set. func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. @@ -278,21 +271,14 @@ func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx return chainDrop } -// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a -// precondition. +// Precondition: pk.NetworkHeader is set. func (it *IPTables) checkRule(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) (RuleVerdict, int) { rule := table.Rules[ruleIdx] // If pkt.NetworkHeader hasn't been set yet, it will be contained in - // pkt.Data. + // pkt.Data.First(). if pkt.NetworkHeader == nil { - var ok bool - pkt.NetworkHeader, ok = pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - // Precondition has been violated. - panic(fmt.Sprintf("iptables checks require IPv4 headers of at least %d bytes", header.IPv4MinimumSize)) - } + pkt.NetworkHeader = pkt.Data.First() } // Check whether the packet matches the IP header filter. diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 8be61f4b1..7b4543caf 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -96,12 +96,9 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { newPkt := pkt.Clone() // Set network header. - headerView, ok := newPkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - return RuleDrop, 0 - } + headerView := newPkt.Data.First() netHeader := header.IPv4(headerView) - newPkt.NetworkHeader = headerView + newPkt.NetworkHeader = headerView[:header.IPv4MinimumSize] hlen := int(netHeader.HeaderLength()) tlen := int(netHeader.TotalLength()) @@ -120,14 +117,10 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { if newPkt.TransportHeader != nil { udpHeader = header.UDP(newPkt.TransportHeader) } else { - if pkt.Data.Size() < header.UDPMinimumSize { - return RuleDrop, 0 - } - hdr, ok := newPkt.Data.PullUp(header.UDPMinimumSize) - if !ok { + if len(pkt.Data.First()) < header.UDPMinimumSize { return RuleDrop, 0 } - udpHeader = header.UDP(hdr) + udpHeader = header.UDP(newPkt.Data.First()) } udpHeader.SetDestinationPort(rt.MinPort) case header.TCPProtocolNumber: @@ -135,14 +128,10 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { if newPkt.TransportHeader != nil { tcpHeader = header.TCP(newPkt.TransportHeader) } else { - if pkt.Data.Size() < header.TCPMinimumSize { + if len(pkt.Data.First()) < header.TCPMinimumSize { return RuleDrop, 0 } - hdr, ok := newPkt.Data.PullUp(header.TCPMinimumSize) - if !ok { - return RuleDrop, 0 - } - tcpHeader = header.TCP(hdr) + tcpHeader = header.TCP(newPkt.TransportHeader) } // TODO(gvisor.dev/issue/170): Need to recompute checksum // and implement nat connection tracking to support TCP. diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 0c2b1f36a..016dbe15e 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1203,12 +1203,12 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link n.stack.stats.IP.PacketsReceived.Increment() } - netHeader, ok := pkt.Data.PullUp(netProto.MinimumPacketSize()) - if !ok { + if len(pkt.Data.First()) < netProto.MinimumPacketSize() { n.stack.stats.MalformedRcvdPackets.Increment() return } - src, dst := netProto.ParseAddresses(netHeader) + + src, dst := netProto.ParseAddresses(pkt.Data.First()) if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil { // The source address is one of our own, so we never should have gotten a @@ -1289,8 +1289,22 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { // TODO(b/143425874) Decrease the TTL field in forwarded packets. - if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen != 0 { - pkt.Header = buffer.NewPrependable(linkHeaderLen) + + firstData := pkt.Data.First() + pkt.Data.RemoveFirst() + + if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen == 0 { + pkt.Header = buffer.NewPrependableFromView(firstData) + } else { + firstDataLen := len(firstData) + + // pkt.Header should have enough capacity to hold n.linkEP's headers. + pkt.Header = buffer.NewPrependable(firstDataLen + linkHeaderLen) + + // TODO(b/151227689): avoid copying the packet when forwarding + if n := copy(pkt.Header.Prepend(firstDataLen), firstData); n != firstDataLen { + panic(fmt.Sprintf("copied %d bytes, expected %d", n, firstDataLen)) + } } if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil { @@ -1318,13 +1332,12 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // validly formed. n.stack.demux.deliverRawPacket(r, protocol, pkt) - transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize()) - if !ok { + if len(pkt.Data.First()) < transProto.MinimumPacketSize() { n.stack.stats.MalformedRcvdPackets.Increment() return } - srcPort, dstPort, err := transProto.ParsePorts(transHeader) + srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) if err != nil { n.stack.stats.MalformedRcvdPackets.Increment() return @@ -1362,12 +1375,11 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp // ICMPv4 only guarantees that 8 bytes of the transport protocol will // be present in the payload. We know that the ports are within the // first 8 bytes for all known transport protocols. - transHeader, ok := pkt.Data.PullUp(8) - if !ok { + if len(pkt.Data.First()) < 8 { return } - srcPort, dstPort, err := transProto.ParsePorts(transHeader) + srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) if err != nil { return } diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 7d36f8e84..dc125f25e 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -37,13 +37,7 @@ type PacketBuffer struct { Data buffer.VectorisedView // Header holds the headers of outbound packets. As a packet is passed - // down the stack, each layer adds to Header. Note that forwarded - // packets don't populate Headers on their way out -- their headers and - // payload are never parsed out and remain in Data. - // - // TODO(gvisor.dev/issue/170): Forwarded packets don't currently - // populate Header, but should. This will be doable once early parsing - // (https://github.com/google/gvisor/pull/1995) is supported. + // down the stack, each layer adds to Header. Header buffer.Prependable // These fields are used by both inbound and outbound packets. They diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index d45d2cc1f..c7634ceb1 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -95,18 +95,16 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffe f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ // Consume the network header. - b, ok := pkt.Data.PullUp(fakeNetHeaderLen) - if !ok { - return - } + b := pkt.Data.First() pkt.Data.TrimFront(fakeNetHeaderLen) // Handle control packets. if b[2] == uint8(fakeControlProtocol) { - nb, ok := pkt.Data.PullUp(fakeNetHeaderLen) - if !ok { + nb := pkt.Data.First() + if len(nb) < fakeNetHeaderLen { return } + pkt.Data.TrimFront(fakeNetHeaderLen) f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, pkt) return diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index a611e44ab..3084e6593 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -642,11 +642,10 @@ func TestTransportForwarding(t *testing.T) { t.Fatal("Response packet not forwarded") } - hdrs := p.Pkt.Data.ToView() - if dst := hdrs[0]; dst != 3 { + if dst := p.Pkt.Header.View()[0]; dst != 3 { t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) } - if src := hdrs[1]; src != 1 { + if src := p.Pkt.Header.View()[1]; src != 1 { t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) } } diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index b1d820372..feef8dca0 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -747,15 +747,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // Only accept echo replies. switch e.NetProto { case header.IPv4ProtocolNumber: - h, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) - if !ok || header.ICMPv4(h).Type() != header.ICMPv4EchoReply { + h := header.ICMPv4(pkt.Data.First()) + if h.Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } case header.IPv6ProtocolNumber: - h, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize) - if !ok || header.ICMPv6(h).Type() != header.ICMPv6EchoReply { + h := header.ICMPv6(pkt.Data.First()) + if h.Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 7712ce652..40461fd31 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -144,11 +144,7 @@ func (s *segment) logicalLen() seqnum.Size { // TCP checksum and stores the checksum and result of checksum verification in // the csum and csumValid fields of the segment. func (s *segment) parse() bool { - h, ok := s.data.PullUp(header.TCPMinimumSize) - if !ok { - return false - } - hdr := header.TCP(h) + h := header.TCP(s.data.First()) // h is the header followed by the payload. We check that the offset to // the data respects the following constraints: @@ -160,16 +156,12 @@ func (s *segment) parse() bool { // N.B. The segment has already been validated as having at least the // minimum TCP size before reaching here, so it's safe to read the // fields. - offset := int(hdr.DataOffset()) - if offset < header.TCPMinimumSize { - return false - } - hdrWithOpts, ok := s.data.PullUp(offset) - if !ok { + offset := int(h.DataOffset()) + if offset < header.TCPMinimumSize || offset > len(h) { return false } - s.options = []byte(hdrWithOpts[header.TCPMinimumSize:]) + s.options = []byte(h[header.TCPMinimumSize:offset]) s.parsedOptions = header.ParseTCPOptions(s.options) // Query the link capabilities to decide if checksum validation is @@ -181,19 +173,18 @@ func (s *segment) parse() bool { s.data.TrimFront(offset) } if verifyChecksum { - hdr = header.TCP(hdrWithOpts) - s.csum = hdr.Checksum() + s.csum = h.Checksum() xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size())) - xsum = hdr.CalculateChecksum(xsum) + xsum = h.CalculateChecksum(xsum) s.data.TrimFront(offset) xsum = header.ChecksumVV(s.data, xsum) s.csumValid = xsum == 0xffff } - s.sequenceNumber = seqnum.Value(hdr.SequenceNumber()) - s.ackNumber = seqnum.Value(hdr.AckNumber()) - s.flags = hdr.Flags() - s.window = seqnum.Size(hdr.WindowSize()) + s.sequenceNumber = seqnum.Value(h.SequenceNumber()) + s.ackNumber = seqnum.Value(h.AckNumber()) + s.flags = h.Flags() + s.window = seqnum.Size(h.WindowSize()) return true } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 286c66cf5..ab1014c7f 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -3548,7 +3548,7 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { AckNum: c.IRS.Add(1), RcvWnd: 30000, }) - tcpbuf := vv.ToView()[header.IPv4MinimumSize:] + tcpbuf := vv.First()[header.IPv4MinimumSize:] tcpbuf[header.TCPDataOffset] = ((header.TCPMinimumSize - 1) / 4) << 4 c.SendSegment(vv) @@ -3575,7 +3575,7 @@ func TestReceivedIncorrectChecksumIncrement(t *testing.T) { AckNum: c.IRS.Add(1), RcvWnd: 30000, }) - tcpbuf := vv.ToView()[header.IPv4MinimumSize:] + tcpbuf := vv.First()[header.IPv4MinimumSize:] // Overwrite a byte in the payload which should cause checksum // verification to fail. tcpbuf[(tcpbuf[header.TCPDataOffset]>>4)*4] = 0x4 diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 756ab913a..edb54f0be 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1250,8 +1250,8 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // endpoint. func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) { // Get the header then trim it from the view. - hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize) - if !ok || int(header.UDP(hdr).Length()) > pkt.Data.Size() { + hdr := header.UDP(pkt.Data.First()) + if int(hdr.Length()) > pkt.Data.Size() { // Malformed packet. e.stack.Stats().UDP.MalformedPacketsReceived.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() @@ -1286,7 +1286,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk senderAddress: tcpip.FullAddress{ NIC: r.NICID(), Addr: id.RemoteAddress, - Port: header.UDP(hdr).SourcePort(), + Port: hdr.SourcePort(), }, } packet.data = pkt.Data diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 52af6de22..6e31a9bac 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -68,13 +68,8 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // that don't match any existing endpoint. func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { // Get the header then trim it from the view. - h, ok := pkt.Data.PullUp(header.UDPMinimumSize) - if !ok { - // Malformed packet. - r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() - return true - } - if int(header.UDP(h).Length()) > pkt.Data.Size() { + hdr := header.UDP(pkt.Data.First()) + if int(hdr.Length()) > pkt.Data.Size() { // Malformed packet. r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() return true -- cgit v1.2.3 From 37a59bc76da7e0b20be3cef1fcb1d5cb8fbc839d Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Tue, 28 Apr 2020 16:01:11 -0700 Subject: Support IPv6 Privacy Extensions for SLAAC Support generating temporary (short-lived) IPv6 SLAAC addresses to address privacy concerns outlined in RFC 4941. Tests: - stack_test.TestAutoGenTempAddr - stack_test.TestNoAutoGenTempAddrForLinkLocal - stack_test.TestAutoGenTempAddrRegen - stack_test.TestAutoGenTempAddrRegenTimerUpdates - stack_test.TestNoAutoGenTempAddrWithoutStableAddr - stack_test.TestAutoGenAddrInResponseToDADConflicts PiperOrigin-RevId: 308915566 --- pkg/tcpip/header/ipv6.go | 52 +++ pkg/tcpip/stack/ndp.go | 645 +++++++++++++++++++++---- pkg/tcpip/stack/ndp_test.go | 1089 ++++++++++++++++++++++++++++++++++++------- pkg/tcpip/stack/nic.go | 34 +- pkg/tcpip/stack/stack.go | 20 + 5 files changed, 1571 insertions(+), 269 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index ba80b64a8..4f367fe4c 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -17,6 +17,7 @@ package header import ( "crypto/sha256" "encoding/binary" + "fmt" "strings" "gvisor.dev/gvisor/pkg/tcpip" @@ -445,3 +446,54 @@ func ScopeForIPv6Address(addr tcpip.Address) (IPv6AddressScope, *tcpip.Error) { return GlobalScope, nil } } + +// InitialTempIID generates the initial temporary IID history value to generate +// temporary SLAAC addresses with. +// +// Panics if initialTempIIDHistory is not at least IIDSize bytes. +func InitialTempIID(initialTempIIDHistory []byte, seed []byte, nicID tcpip.NICID) { + h := sha256.New() + // h.Write never returns an error. + h.Write(seed) + var nicIDBuf [4]byte + binary.BigEndian.PutUint32(nicIDBuf[:], uint32(nicID)) + h.Write(nicIDBuf[:]) + + var sumBuf [sha256.Size]byte + sum := h.Sum(sumBuf[:0]) + + if n := copy(initialTempIIDHistory, sum[sha256.Size-IIDSize:]); n != IIDSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, IIDSize)) + } +} + +// GenerateTempIPv6SLAACAddr generates a temporary SLAAC IPv6 address for an +// associated stable/permanent SLAAC address. +// +// GenerateTempIPv6SLAACAddr will update the temporary IID history value to be +// used when generating a new temporary IID. +// +// Panics if tempIIDHistory is not at least IIDSize bytes. +func GenerateTempIPv6SLAACAddr(tempIIDHistory []byte, stableAddr tcpip.Address) tcpip.AddressWithPrefix { + addrBytes := []byte(stableAddr) + h := sha256.New() + h.Write(tempIIDHistory) + h.Write(addrBytes[IIDOffsetInIPv6Address:]) + var sumBuf [sha256.Size]byte + sum := h.Sum(sumBuf[:0]) + + // The rightmost 64 bits of sum are saved for the next iteration. + if n := copy(tempIIDHistory, sum[sha256.Size-IIDSize:]); n != IIDSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, IIDSize)) + } + + // The leftmost 64 bits of sum is used as the IID. + if n := copy(addrBytes[IIDOffsetInIPv6Address:], sum); n != IIDSize { + panic(fmt.Sprintf("copied %d IID bytes, expected %d bytes", n, IIDSize)) + } + + return tcpip.AddressWithPrefix{ + Address: tcpip.Address(addrBytes), + PrefixLen: IIDOffsetInIPv6Address * 8, + } +} diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index c11d62f97..fbc0e3f55 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -119,6 +119,32 @@ const ( // identifier (IID) is 64 bits and an IPv6 address is 128 bits, so // 128 - 64 = 64. validPrefixLenForAutoGen = 64 + + // defaultAutoGenTempGlobalAddresses is the default configuration for whether + // or not to generate temporary SLAAC addresses. + defaultAutoGenTempGlobalAddresses = true + + // defaultMaxTempAddrValidLifetime is the default maximum valid lifetime + // for temporary SLAAC addresses generated as part of RFC 4941. + // + // Default = 7 days (from RFC 4941 section 5). + defaultMaxTempAddrValidLifetime = 7 * 24 * time.Hour + + // defaultMaxTempAddrPreferredLifetime is the default preferred lifetime + // for temporary SLAAC addresses generated as part of RFC 4941. + // + // Default = 1 day (from RFC 4941 section 5). + defaultMaxTempAddrPreferredLifetime = 24 * time.Hour + + // defaultRegenAdvanceDuration is the default duration before the deprecation + // of a temporary address when a new address will be generated. + // + // Default = 5s (from RFC 4941 section 5). + defaultRegenAdvanceDuration = 5 * time.Second + + // minRegenAdvanceDuration is the minimum duration before the deprecation + // of a temporary address when a new address will be generated. + minRegenAdvanceDuration = time.Duration(0) ) var ( @@ -131,6 +157,37 @@ var ( // // Min = 2hrs. MinPrefixInformationValidLifetimeForUpdate = 2 * time.Hour + + // MaxDesyncFactor is the upper bound for the preferred lifetime's desync + // factor for temporary SLAAC addresses. + // + // This is exported as a variable (instead of a constant) so tests + // can update it to a smaller value. + // + // Must be greater than 0. + // + // Max = 10m (from RFC 4941 section 5). + MaxDesyncFactor = 10 * time.Minute + + // MinMaxTempAddrPreferredLifetime is the minimum value allowed for the + // maximum preferred lifetime for temporary SLAAC addresses. + // + // This is exported as a variable (instead of a constant) so tests + // can update it to a smaller value. + // + // This value guarantees that a temporary address will be preferred for at + // least 1hr if the SLAAC prefix is valid for at least that time. + MinMaxTempAddrPreferredLifetime = defaultRegenAdvanceDuration + MaxDesyncFactor + time.Hour + + // MinMaxTempAddrValidLifetime is the minimum value allowed for the + // maximum valid lifetime for temporary SLAAC addresses. + // + // This is exported as a variable (instead of a constant) so tests + // can update it to a smaller value. + // + // This value guarantees that a temporary address will be valid for at least + // 2hrs if the SLAAC prefix is valid for at least that time. + MinMaxTempAddrValidLifetime = 2 * time.Hour ) // DHCPv6ConfigurationFromNDPRA is a configuration available via DHCPv6 that an @@ -324,35 +381,49 @@ type NDPConfigurations struct { // alternative addresses (e.g. IIDs based on the modified EUI64 of a NIC's // MAC address), then no attempt will be made to resolve the conflict. AutoGenAddressConflictRetries uint8 + + // AutoGenTempGlobalAddresses determines whether or not temporary SLAAC + // addresses will be generated for a NIC as part of SLAAC privacy extensions, + // RFC 4941. + // + // Ignored if AutoGenGlobalAddresses is false. + AutoGenTempGlobalAddresses bool + + // MaxTempAddrValidLifetime is the maximum valid lifetime for temporary + // SLAAC addresses. + MaxTempAddrValidLifetime time.Duration + + // MaxTempAddrPreferredLifetime is the maximum preferred lifetime for + // temporary SLAAC addresses. + MaxTempAddrPreferredLifetime time.Duration + + // RegenAdvanceDuration is the duration before the deprecation of a temporary + // address when a new address will be generated. + RegenAdvanceDuration time.Duration } // DefaultNDPConfigurations returns an NDPConfigurations populated with // default values. func DefaultNDPConfigurations() NDPConfigurations { return NDPConfigurations{ - DupAddrDetectTransmits: defaultDupAddrDetectTransmits, - RetransmitTimer: defaultRetransmitTimer, - MaxRtrSolicitations: defaultMaxRtrSolicitations, - RtrSolicitationInterval: defaultRtrSolicitationInterval, - MaxRtrSolicitationDelay: defaultMaxRtrSolicitationDelay, - HandleRAs: defaultHandleRAs, - DiscoverDefaultRouters: defaultDiscoverDefaultRouters, - DiscoverOnLinkPrefixes: defaultDiscoverOnLinkPrefixes, - AutoGenGlobalAddresses: defaultAutoGenGlobalAddresses, + DupAddrDetectTransmits: defaultDupAddrDetectTransmits, + RetransmitTimer: defaultRetransmitTimer, + MaxRtrSolicitations: defaultMaxRtrSolicitations, + RtrSolicitationInterval: defaultRtrSolicitationInterval, + MaxRtrSolicitationDelay: defaultMaxRtrSolicitationDelay, + HandleRAs: defaultHandleRAs, + DiscoverDefaultRouters: defaultDiscoverDefaultRouters, + DiscoverOnLinkPrefixes: defaultDiscoverOnLinkPrefixes, + AutoGenGlobalAddresses: defaultAutoGenGlobalAddresses, + AutoGenTempGlobalAddresses: defaultAutoGenTempGlobalAddresses, + MaxTempAddrValidLifetime: defaultMaxTempAddrValidLifetime, + MaxTempAddrPreferredLifetime: defaultMaxTempAddrPreferredLifetime, + RegenAdvanceDuration: defaultRegenAdvanceDuration, } } // validate modifies an NDPConfigurations with valid values. If invalid values // are present in c, the corresponding default values will be used instead. -// -// If RetransmitTimer is less than minimumRetransmitTimer, then a value of -// defaultRetransmitTimer will be used. -// -// If RtrSolicitationInterval is less than minimumRtrSolicitationInterval, then -// a value of defaultRtrSolicitationInterval will be used. -// -// If MaxRtrSolicitationDelay is less than minimumMaxRtrSolicitationDelay, then -// a value of defaultMaxRtrSolicitationDelay will be used. func (c *NDPConfigurations) validate() { if c.RetransmitTimer < minimumRetransmitTimer { c.RetransmitTimer = defaultRetransmitTimer @@ -365,6 +436,18 @@ func (c *NDPConfigurations) validate() { if c.MaxRtrSolicitationDelay < minimumMaxRtrSolicitationDelay { c.MaxRtrSolicitationDelay = defaultMaxRtrSolicitationDelay } + + if c.MaxTempAddrValidLifetime < MinMaxTempAddrValidLifetime { + c.MaxTempAddrValidLifetime = MinMaxTempAddrValidLifetime + } + + if c.MaxTempAddrPreferredLifetime < MinMaxTempAddrPreferredLifetime || c.MaxTempAddrPreferredLifetime > c.MaxTempAddrValidLifetime { + c.MaxTempAddrPreferredLifetime = MinMaxTempAddrPreferredLifetime + } + + if c.RegenAdvanceDuration < minRegenAdvanceDuration { + c.RegenAdvanceDuration = minRegenAdvanceDuration + } } // ndpState is the per-interface NDP state. @@ -394,6 +477,14 @@ type ndpState struct { // The last learned DHCPv6 configuration from an NDP RA. dhcpv6Configuration DHCPv6ConfigurationFromNDPRA + + // temporaryIIDHistory is the history value used to generate a new temporary + // IID. + temporaryIIDHistory [header.IIDSize]byte + + // temporaryAddressDesyncFactor is the preferred lifetime's desync factor for + // temporary SLAAC addresses. + temporaryAddressDesyncFactor time.Duration } // dadState holds the Duplicate Address Detection timer and channel to signal @@ -414,7 +505,7 @@ type dadState struct { type defaultRouterState struct { // Timer to invalidate the default router. // - // May not be nil. + // Must not be nil. invalidationTimer *tcpip.CancellableTimer } @@ -424,20 +515,48 @@ type defaultRouterState struct { type onLinkPrefixState struct { // Timer to invalidate the on-link prefix. // - // May not be nil. + // Must not be nil. + invalidationTimer *tcpip.CancellableTimer +} + +// tempSLAACAddrState holds state associated with a temporary SLAAC address. +type tempSLAACAddrState struct { + // Timer to deprecate the temporary SLAAC address. + // + // Must not be nil. + deprecationTimer *tcpip.CancellableTimer + + // Timer to invalidate the temporary SLAAC address. + // + // Must not be nil. invalidationTimer *tcpip.CancellableTimer + + // Timer to regenerate the temporary SLAAC address. + // + // Must not be nil. + regenTimer *tcpip.CancellableTimer + + createdAt time.Time + + // The address's endpoint. + // + // Must not be nil. + ref *referencedNetworkEndpoint + + // Has a new temporary SLAAC address already been regenerated? + regenerated bool } // slaacPrefixState holds state associated with a SLAAC prefix. type slaacPrefixState struct { // Timer to deprecate the prefix. // - // May not be nil. + // Must not be nil. deprecationTimer *tcpip.CancellableTimer // Timer to invalidate the prefix. // - // May not be nil. + // Must not be nil. invalidationTimer *tcpip.CancellableTimer // Nonzero only when the address is not valid forever. @@ -446,19 +565,27 @@ type slaacPrefixState struct { // Nonzero only when the address is not preferred forever. preferredUntil time.Time - // The prefix's permanent address endpoint. + // The endpoint for the stable address generated for a SLAAC prefix. // // May only be nil when a SLAAC address is being (re-)generated. Otherwise, - // must not be nil as all SLAAC prefixes must have a SLAAC address. - ref *referencedNetworkEndpoint + // must not be nil as all SLAAC prefixes must have a stable address. + stableAddrRef *referencedNetworkEndpoint + + // The temporary (short-lived) addresses generated for the SLAAC prefix. + tempAddrs map[tcpip.Address]tempSLAACAddrState - // The number of times a permanent address has been generated for the prefix. + // The next two fields are used by both stable and temporary addresses + // generated for a SLAAC prefix. This is safe as only 1 address will be + // in the generation and DAD process at any time. That is, no two addresses + // will be generated at the same time for a given SLAAC prefix. + + // The number of times an address has been generated. // // Addresses may be regenerated in reseponse to a DAD conflicts. generationAttempts uint8 - // The maximum number of times to attempt regeneration of a permanent SLAAC - // address in response to DAD conflicts. + // The maximum number of times to attempt regeneration of a SLAAC address + // in response to DAD conflicts. maxGenerationAttempts uint8 } @@ -536,10 +663,10 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref } ndp.nic.mu.Lock() + defer ndp.nic.mu.Unlock() if done { // If we reach this point, it means that DAD was stopped after we released // the NIC's read lock and before we obtained the write lock. - ndp.nic.mu.Unlock() return } @@ -551,8 +678,6 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref // schedule the next DAD timer. remaining-- timer.Reset(ndp.nic.stack.ndpConfigs.RetransmitTimer) - - ndp.nic.mu.Unlock() return } @@ -560,15 +685,18 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref // the last NDP NS. Either way, clean up addr's DAD state and let the // integrator know DAD has completed. delete(ndp.dad, addr) - ndp.nic.mu.Unlock() - - if err != nil { - log.Printf("ndpdad: error occured during DAD iteration for addr (%s) on NIC(%d); err = %s", addr, ndp.nic.ID(), err) - } if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, dadDone, err) } + + // If DAD resolved for a stable SLAAC address, attempt generation of a + // temporary SLAAC address. + if dadDone && ref.configType == slaac { + // Reset the generation attempts counter as we are starting the generation + // of a new address for the SLAAC prefix. + ndp.regenerateTempSLAACAddr(ref.addrWithPrefix().Subnet(), true /* resetGenAttempts */) + } }) ndp.dad[addr] = dadState{ @@ -953,9 +1081,10 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform prefix := pi.Subnet() // Check if we already maintain SLAAC state for prefix. - if _, ok := ndp.slaacPrefixes[prefix]; ok { + if state, ok := ndp.slaacPrefixes[prefix]; ok { // As per RFC 4862 section 5.5.3.e, refresh prefix's SLAAC lifetimes. - ndp.refreshSLAACPrefixLifetimes(prefix, pl, vl) + ndp.refreshSLAACPrefixLifetimes(prefix, &state, pl, vl) + ndp.slaacPrefixes[prefix] = state return } @@ -996,7 +1125,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the deprecated SLAAC prefix %s", prefix)) } - ndp.deprecateSLAACAddress(state.ref) + ndp.deprecateSLAACAddress(state.stableAddrRef) }), invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { state, ok := ndp.slaacPrefixes[prefix] @@ -1006,6 +1135,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { ndp.invalidateSLAACPrefix(prefix, state) }), + tempAddrs: make(map[tcpip.Address]tempSLAACAddrState), maxGenerationAttempts: ndp.configs.AutoGenAddressConflictRetries + 1, } @@ -1035,9 +1165,49 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { state.validUntil = now.Add(vl) } + // If the address is assigned (DAD resolved), generate a temporary address. + if state.stableAddrRef.getKind() == permanent { + // Reset the generation attempts counter as we are starting the generation + // of a new address for the SLAAC prefix. + ndp.generateTempSLAACAddr(prefix, &state, true /* resetGenAttempts */) + } + ndp.slaacPrefixes[prefix] = state } +// addSLAACAddr adds a SLAAC address to the NIC. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) addSLAACAddr(addr tcpip.AddressWithPrefix, configType networkEndpointConfigType, deprecated bool) *referencedNetworkEndpoint { + // If the nic already has this address, do nothing further. + if ndp.nic.hasPermanentAddrLocked(addr.Address) { + return nil + } + + // Inform the integrator that we have a new SLAAC address. + ndpDisp := ndp.nic.stack.ndpDisp + if ndpDisp == nil { + return nil + } + + if !ndpDisp.OnAutoGenAddress(ndp.nic.ID(), addr) { + // Informed by the integrator not to add the address. + return nil + } + + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addr, + } + + ref, err := ndp.nic.addAddressLocked(protocolAddr, FirstPrimaryEndpoint, permanent, configType, deprecated) + if err != nil { + panic(fmt.Sprintf("ndp: error when adding SLAAC address %+v: %s", protocolAddr, err)) + } + + return ref +} + // generateSLAACAddr generates a SLAAC address for prefix. // // Returns true if an address was successfully generated. @@ -1046,7 +1216,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { // // The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixState) bool { - if r := state.ref; r != nil { + if r := state.stableAddrRef; r != nil { panic(fmt.Sprintf("ndp: SLAAC prefix %s already has a permenant address %s", prefix, r.addrWithPrefix())) } @@ -1085,39 +1255,18 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt return false } - generatedAddr := tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: tcpip.Address(addrBytes), - PrefixLen: validPrefixLenForAutoGen, - }, - } - - // If the nic already has this address, do nothing further. - if ndp.nic.hasPermanentAddrLocked(generatedAddr.AddressWithPrefix.Address) { - return false - } - - // Inform the integrator that we have a new SLAAC address. - ndpDisp := ndp.nic.stack.ndpDisp - if ndpDisp == nil { - return false + generatedAddr := tcpip.AddressWithPrefix{ + Address: tcpip.Address(addrBytes), + PrefixLen: validPrefixLenForAutoGen, } - if !ndpDisp.OnAutoGenAddress(ndp.nic.ID(), generatedAddr.AddressWithPrefix) { - // Informed by the integrator not to add the address. - return false + if ref := ndp.addSLAACAddr(generatedAddr, slaac, time.Since(state.preferredUntil) >= 0 /* deprecated */); ref != nil { + state.stableAddrRef = ref + state.generationAttempts++ + return true } - deprecated := time.Since(state.preferredUntil) >= 0 - ref, err := ndp.nic.addAddressLocked(generatedAddr, FirstPrimaryEndpoint, permanent, slaac, deprecated) - if err != nil { - panic(fmt.Sprintf("ndp: error when adding address %+v: %s", generatedAddr, err)) - } - - state.generationAttempts++ - state.ref = ref - return true + return false } // regenerateSLAACAddr regenerates an address for a SLAAC prefix. @@ -1143,24 +1292,167 @@ func (ndp *ndpState) regenerateSLAACAddr(prefix tcpip.Subnet) { ndp.invalidateSLAACPrefix(prefix, state) } -// refreshSLAACPrefixLifetimes refreshes the lifetimes of a SLAAC prefix. +// generateTempSLAACAddr generates a new temporary SLAAC address. // -// pl is the new preferred lifetime. vl is the new valid lifetime. +// If resetGenAttempts is true, the prefix's generation counter will be reset. +// +// Returns true if a new address was generated. +func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *slaacPrefixState, resetGenAttempts bool) bool { + // Are we configured to auto-generate new temporary global addresses for the + // prefix? + if !ndp.configs.AutoGenTempGlobalAddresses || prefix == header.IPv6LinkLocalPrefix.Subnet() { + return false + } + + if resetGenAttempts { + prefixState.generationAttempts = 0 + prefixState.maxGenerationAttempts = ndp.configs.AutoGenAddressConflictRetries + 1 + } + + // If we have already reached the maximum address generation attempts for the + // prefix, do not generate another address. + if prefixState.generationAttempts == prefixState.maxGenerationAttempts { + return false + } + + stableAddr := prefixState.stableAddrRef.ep.ID().LocalAddress + now := time.Now() + + // As per RFC 4941 section 3.3 step 4, the valid lifetime of a temporary + // address is the lower of the valid lifetime of the stable address or the + // maximum temporary address valid lifetime. + vl := ndp.configs.MaxTempAddrValidLifetime + if prefixState.validUntil != (time.Time{}) { + if prefixVL := prefixState.validUntil.Sub(now); vl > prefixVL { + vl = prefixVL + } + } + + if vl <= 0 { + // Cannot create an address without a valid lifetime. + return false + } + + // As per RFC 4941 section 3.3 step 4, the preferred lifetime of a temporary + // address is the lower of the preferred lifetime of the stable address or the + // maximum temporary address preferred lifetime - the temporary address desync + // factor. + pl := ndp.configs.MaxTempAddrPreferredLifetime - ndp.temporaryAddressDesyncFactor + if prefixState.preferredUntil != (time.Time{}) { + if prefixPL := prefixState.preferredUntil.Sub(now); pl > prefixPL { + // Respect the preferred lifetime of the prefix, as per RFC 4941 section + // 3.3 step 4. + pl = prefixPL + } + } + + // As per RFC 4941 section 3.3 step 5, a temporary address is created only if + // the calculated preferred lifetime is greater than the advance regeneration + // duration. In particular, we MUST NOT create a temporary address with a zero + // Preferred Lifetime. + if pl <= ndp.configs.RegenAdvanceDuration { + return false + } + + generatedAddr := header.GenerateTempIPv6SLAACAddr(ndp.temporaryIIDHistory[:], stableAddr) + + // As per RFC RFC 4941 section 3.3 step 5, we MUST NOT create a temporary + // address with a zero preferred lifetime. The checks above ensure this + // so we know the address is not deprecated. + ref := ndp.addSLAACAddr(generatedAddr, slaacTemp, false /* deprecated */) + if ref == nil { + return false + } + + state := tempSLAACAddrState{ + deprecationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + prefixState, ok := ndp.slaacPrefixes[prefix] + if !ok { + panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to deprecate temporary address %s", prefix, generatedAddr)) + } + + tempAddrState, ok := prefixState.tempAddrs[generatedAddr.Address] + if !ok { + panic(fmt.Sprintf("ndp: must have a tempAddr entry to deprecate temporary address %s", generatedAddr)) + } + + ndp.deprecateSLAACAddress(tempAddrState.ref) + }), + invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + prefixState, ok := ndp.slaacPrefixes[prefix] + if !ok { + panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to invalidate temporary address %s", prefix, generatedAddr)) + } + + tempAddrState, ok := prefixState.tempAddrs[generatedAddr.Address] + if !ok { + panic(fmt.Sprintf("ndp: must have a tempAddr entry to invalidate temporary address %s", generatedAddr)) + } + + ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, generatedAddr.Address, tempAddrState) + }), + regenTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + prefixState, ok := ndp.slaacPrefixes[prefix] + if !ok { + panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to regenerate temporary address after %s", prefix, generatedAddr)) + } + + tempAddrState, ok := prefixState.tempAddrs[generatedAddr.Address] + if !ok { + panic(fmt.Sprintf("ndp: must have a tempAddr entry to regenerate temporary address after %s", generatedAddr)) + } + + // If an address has already been regenerated for this address, don't + // regenerate another address. + if tempAddrState.regenerated { + return + } + + // Reset the generation attempts counter as we are starting the generation + // of a new address for the SLAAC prefix. + tempAddrState.regenerated = ndp.generateTempSLAACAddr(prefix, &prefixState, true /* resetGenAttempts */) + prefixState.tempAddrs[generatedAddr.Address] = tempAddrState + ndp.slaacPrefixes[prefix] = prefixState + }), + createdAt: now, + ref: ref, + } + + state.deprecationTimer.Reset(pl) + state.invalidationTimer.Reset(vl) + state.regenTimer.Reset(pl - ndp.configs.RegenAdvanceDuration) + + prefixState.generationAttempts++ + prefixState.tempAddrs[generatedAddr.Address] = state + + return true +} + +// regenerateTempSLAACAddr regenerates a temporary address for a SLAAC prefix. // // The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, pl, vl time.Duration) { - prefixState, ok := ndp.slaacPrefixes[prefix] +func (ndp *ndpState) regenerateTempSLAACAddr(prefix tcpip.Subnet, resetGenAttempts bool) { + state, ok := ndp.slaacPrefixes[prefix] if !ok { - panic(fmt.Sprintf("ndp: SLAAC prefix state not found to refresh lifetimes for %s", prefix)) + panic(fmt.Sprintf("ndp: SLAAC prefix state not found to regenerate temporary address for %s", prefix)) } - defer func() { ndp.slaacPrefixes[prefix] = prefixState }() + ndp.generateTempSLAACAddr(prefix, &state, resetGenAttempts) + ndp.slaacPrefixes[prefix] = state +} + +// refreshSLAACPrefixLifetimes refreshes the lifetimes of a SLAAC prefix. +// +// pl is the new preferred lifetime. vl is the new valid lifetime. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixState *slaacPrefixState, pl, vl time.Duration) { // If the preferred lifetime is zero, then the prefix should be deprecated. deprecated := pl == 0 if deprecated { - ndp.deprecateSLAACAddress(prefixState.ref) + ndp.deprecateSLAACAddress(prefixState.stableAddrRef) } else { - prefixState.ref.deprecated = false + prefixState.stableAddrRef.deprecated = false } // If prefix was preferred for some finite lifetime before, stop the @@ -1190,36 +1482,118 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, pl, vl tim // // 3) Otherwise, reset the valid lifetime of the prefix to 2 hours. - // Handle the infinite valid lifetime separately as we do not keep a timer in - // this case. if vl >= header.NDPInfiniteLifetime { + // Handle the infinite valid lifetime separately as we do not keep a timer + // in this case. prefixState.invalidationTimer.StopLocked() prefixState.validUntil = time.Time{} - return - } + } else { + var effectiveVl time.Duration + var rl time.Duration + + // If the prefix was originally set to be valid forever, assume the + // remaining time to be the maximum possible value. + if prefixState.validUntil == (time.Time{}) { + rl = header.NDPInfiniteLifetime + } else { + rl = time.Until(prefixState.validUntil) + } - var effectiveVl time.Duration - var rl time.Duration + if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl { + effectiveVl = vl + } else if rl > MinPrefixInformationValidLifetimeForUpdate { + effectiveVl = MinPrefixInformationValidLifetimeForUpdate + } - // If the prefix was originally set to be valid forever, assume the remaining - // time to be the maximum possible value. - if prefixState.validUntil == (time.Time{}) { - rl = header.NDPInfiniteLifetime - } else { - rl = time.Until(prefixState.validUntil) + if effectiveVl != 0 { + prefixState.invalidationTimer.StopLocked() + prefixState.invalidationTimer.Reset(effectiveVl) + prefixState.validUntil = now.Add(effectiveVl) + } } - if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl { - effectiveVl = vl - } else if rl <= MinPrefixInformationValidLifetimeForUpdate { + // If DAD is not yet complete on the stable address, there is no need to do + // work with temporary addresses. + if prefixState.stableAddrRef.getKind() != permanent { return - } else { - effectiveVl = MinPrefixInformationValidLifetimeForUpdate } - prefixState.invalidationTimer.StopLocked() - prefixState.invalidationTimer.Reset(effectiveVl) - prefixState.validUntil = now.Add(effectiveVl) + // Note, we do not need to update the entries in the temporary address map + // after updating the timers because the timers are held as pointers. + var regenForAddr tcpip.Address + allAddressesRegenerated := true + for tempAddr, tempAddrState := range prefixState.tempAddrs { + // As per RFC 4941 section 3.3 step 4, the valid lifetime of a temporary + // address is the lower of the valid lifetime of the stable address or the + // maximum temporary address valid lifetime. Note, the valid lifetime of a + // temporary address is relative to the address's creation time. + validUntil := tempAddrState.createdAt.Add(ndp.configs.MaxTempAddrValidLifetime) + if prefixState.validUntil != (time.Time{}) && validUntil.Sub(prefixState.validUntil) > 0 { + validUntil = prefixState.validUntil + } + + // If the address is no longer valid, invalidate it immediately. Otherwise, + // reset the invalidation timer. + newValidLifetime := validUntil.Sub(now) + if newValidLifetime <= 0 { + ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, tempAddr, tempAddrState) + continue + } + tempAddrState.invalidationTimer.StopLocked() + tempAddrState.invalidationTimer.Reset(newValidLifetime) + + // As per RFC 4941 section 3.3 step 4, the preferred lifetime of a temporary + // address is the lower of the preferred lifetime of the stable address or + // the maximum temporary address preferred lifetime - the temporary address + // desync factor. Note, the preferred lifetime of a temporary address is + // relative to the address's creation time. + preferredUntil := tempAddrState.createdAt.Add(ndp.configs.MaxTempAddrPreferredLifetime - ndp.temporaryAddressDesyncFactor) + if prefixState.preferredUntil != (time.Time{}) && preferredUntil.Sub(prefixState.preferredUntil) > 0 { + preferredUntil = prefixState.preferredUntil + } + + // If the address is no longer preferred, deprecate it immediately. + // Otherwise, reset the deprecation timer. + newPreferredLifetime := preferredUntil.Sub(now) + tempAddrState.deprecationTimer.StopLocked() + if newPreferredLifetime <= 0 { + ndp.deprecateSLAACAddress(tempAddrState.ref) + } else { + tempAddrState.ref.deprecated = false + tempAddrState.deprecationTimer.Reset(newPreferredLifetime) + } + + tempAddrState.regenTimer.StopLocked() + if tempAddrState.regenerated { + } else { + allAddressesRegenerated = false + + if newPreferredLifetime <= ndp.configs.RegenAdvanceDuration { + // The new preferred lifetime is less than the advance regeneration + // duration so regenerate an address for this temporary address + // immediately after we finish iterating over the temporary addresses. + regenForAddr = tempAddr + } else { + tempAddrState.regenTimer.Reset(newPreferredLifetime - ndp.configs.RegenAdvanceDuration) + } + } + } + + // Generate a new temporary address if all of the existing temporary addresses + // have been regenerated, or we need to immediately regenerate an address + // due to an update in preferred lifetime. + // + // If each temporay address has already been regenerated, no new temporary + // address will be generated. To ensure continuation of temporary SLAAC + // addresses, we manually try to regenerate an address here. + if len(regenForAddr) != 0 || allAddressesRegenerated { + // Reset the generation attempts counter as we are starting the generation + // of a new address for the SLAAC prefix. + if state, ok := prefixState.tempAddrs[regenForAddr]; ndp.generateTempSLAACAddr(prefix, prefixState, true /* resetGenAttempts */) && ok { + state.regenerated = true + prefixState.tempAddrs[regenForAddr] = state + } + } } // deprecateSLAACAddress marks ref as deprecated and notifies the stack's NDP @@ -1243,11 +1617,11 @@ func (ndp *ndpState) deprecateSLAACAddress(ref *referencedNetworkEndpoint) { // // The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, state slaacPrefixState) { - if r := state.ref; r != nil { + if r := state.stableAddrRef; r != nil { // Since we are already invalidating the prefix, do not invalidate the // prefix when removing the address. - if err := ndp.nic.removePermanentIPv6EndpointLocked(r, false /* allowSLAACPrefixInvalidation */); err != nil { - panic(fmt.Sprintf("ndp: removePermanentIPv6EndpointLocked(%s, false): %s", r.addrWithPrefix(), err)) + if err := ndp.nic.removePermanentIPv6EndpointLocked(r, false /* allowSLAACInvalidation */); err != nil { + panic(fmt.Sprintf("ndp: error removing stable SLAAC address %s: %s", r.addrWithPrefix(), err)) } } @@ -1265,14 +1639,14 @@ func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPr prefix := addr.Subnet() state, ok := ndp.slaacPrefixes[prefix] - if !ok || state.ref == nil || addr.Address != state.ref.ep.ID().LocalAddress { + if !ok || state.stableAddrRef == nil || addr.Address != state.stableAddrRef.ep.ID().LocalAddress { return } if !invalidatePrefix { // If the prefix is not being invalidated, disassociate the address from the // prefix and do nothing further. - state.ref = nil + state.stableAddrRef = nil ndp.slaacPrefixes[prefix] = state return } @@ -1286,11 +1660,68 @@ func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPr // // The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaacPrefixState) { + // Invalidate all temporary addresses. + for tempAddr, tempAddrState := range state.tempAddrs { + ndp.invalidateTempSLAACAddr(state.tempAddrs, tempAddr, tempAddrState) + } + + state.stableAddrRef = nil state.deprecationTimer.StopLocked() state.invalidationTimer.StopLocked() delete(ndp.slaacPrefixes, prefix) } +// invalidateTempSLAACAddr invalidates a temporary SLAAC address. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) invalidateTempSLAACAddr(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) { + // Since we are already invalidating the address, do not invalidate the + // address when removing the address. + if err := ndp.nic.removePermanentIPv6EndpointLocked(tempAddrState.ref, false /* allowSLAACInvalidation */); err != nil { + panic(fmt.Sprintf("error removing temporary SLAAC address %s: %s", tempAddrState.ref.addrWithPrefix(), err)) + } + + ndp.cleanupTempSLAACAddrResources(tempAddrs, tempAddr, tempAddrState) +} + +// cleanupTempSLAACAddrResourcesAndNotify cleans up an invalidated temporary +// SLAAC address's resources from ndp. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidateAddr bool) { + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), addr) + } + + if !invalidateAddr { + return + } + + prefix := addr.Subnet() + state, ok := ndp.slaacPrefixes[prefix] + if !ok { + panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry to clean up temp addr %s resources", addr)) + } + + tempAddrState, ok := state.tempAddrs[addr.Address] + if !ok { + panic(fmt.Sprintf("ndp: must have a tempAddr entry to clean up temp addr %s resources", addr)) + } + + ndp.cleanupTempSLAACAddrResources(state.tempAddrs, addr.Address, tempAddrState) +} + +// cleanupTempSLAACAddrResourcesAndNotify cleans up a temporary SLAAC address's +// timers and entry. +// +// The NIC that ndp belongs to MUST be locked. +func (ndp *ndpState) cleanupTempSLAACAddrResources(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) { + tempAddrState.deprecationTimer.StopLocked() + tempAddrState.invalidationTimer.StopLocked() + tempAddrState.regenTimer.StopLocked() + delete(tempAddrs, tempAddr) +} + // cleanupState cleans up ndp's state. // // If hostOnly is true, then only host-specific state will be cleaned up. @@ -1450,3 +1881,13 @@ func (ndp *ndpState) stopSolicitingRouters() { ndp.rtrSolicitTimer.Stop() ndp.rtrSolicitTimer = nil } + +// initializeTempAddrState initializes state related to temporary SLAAC +// addresses. +func (ndp *ndpState) initializeTempAddrState() { + header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.nic.stack.tempIIDSeed, ndp.nic.ID()) + + if MaxDesyncFactor != 0 { + ndp.temporaryAddressDesyncFactor = time.Duration(rand.Int63n(int64(MaxDesyncFactor))) + } +} diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 6dd460984..421df674f 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -1801,6 +1801,726 @@ func TestAutoGenAddr(t *testing.T) { } } +func addressCheck(addrs []tcpip.ProtocolAddress, containList, notContainList []tcpip.AddressWithPrefix) string { + ret := "" + for _, c := range containList { + if !containsV6Addr(addrs, c) { + ret += fmt.Sprintf("should have %s in the list of addresses\n", c) + } + } + for _, c := range notContainList { + if containsV6Addr(addrs, c) { + ret += fmt.Sprintf("should not have %s in the list of addresses\n", c) + } + } + return ret +} + +// TestAutoGenTempAddr tests that temporary SLAAC addresses are generated when +// configured to do so as part of IPv6 Privacy Extensions. +func TestAutoGenTempAddr(t *testing.T) { + const ( + nicID = 1 + newMinVL = 5 + newMinVLDuration = newMinVL * time.Second + ) + + savedMinPrefixInformationValidLifetimeForUpdate := stack.MinPrefixInformationValidLifetimeForUpdate + savedMaxDesync := stack.MaxDesyncFactor + defer func() { + stack.MinPrefixInformationValidLifetimeForUpdate = savedMinPrefixInformationValidLifetimeForUpdate + stack.MaxDesyncFactor = savedMaxDesync + }() + stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration + stack.MaxDesyncFactor = time.Nanosecond + + prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) + + tests := []struct { + name string + dupAddrTransmits uint8 + retransmitTimer time.Duration + }{ + { + name: "DAD disabled", + }, + { + name: "DAD enabled", + dupAddrTransmits: 1, + retransmitTimer: time.Second, + }, + } + + // This Run will not return until the parallel tests finish. + // + // We need this because we need to do some teardown work after the + // parallel tests complete. + // + // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for + // more details. + t.Run("group", func(t *testing.T) { + for i, test := range tests { + i := i + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + seed := []byte{uint8(i)} + var tempIIDHistory [header.IIDSize]byte + header.InitialTempIID(tempIIDHistory[:], seed, nicID) + newTempAddr := func(stableAddr tcpip.Address) tcpip.AddressWithPrefix { + return header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableAddr) + } + + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 2), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + DupAddrDetectTransmits: test.dupAddrTransmits, + RetransmitTimer: test.retransmitTimer, + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + TempIIDSeed: seed, + }) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(defaultAsyncEventTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + } + + expectDADEventAsync := func(addr tcpip.Address) { + t.Helper() + + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncEventTimeout): + t.Fatal("timed out waiting for DAD event") + } + } + + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with zero valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0)) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly auto-generated an address with 0 lifetime; event = %+v", e) + default: + } + + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) + expectAutoGenAddrEvent(addr1, newAddr) + expectDADEventAsync(addr1.Address) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly got an auto gen addr event = %+v", e) + default: + } + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero valid & preferred lifetimes. + tempAddr1 := newTempAddr(addr1.Address) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) + expectAutoGenAddrEvent(tempAddr1, newAddr) + expectDADEventAsync(tempAddr1.Address) + if mismatch := addressCheck(s.NICInfo()[1].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Receive an RA with prefix2 in an NDP Prefix Information option (PI) + // with preferred lifetime > valid lifetime + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6)) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly auto-generated an address with preferred lifetime > valid lifetime; event = %+v", e) + default: + } + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Receive an RA with prefix2 in a PI w/ non-zero valid and preferred + // lifetimes. + tempAddr2 := newTempAddr(addr2.Address) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) + expectAutoGenAddrEvent(addr2, newAddr) + expectDADEventAsync(addr2.Address) + expectAutoGenAddrEventAsync(tempAddr2, newAddr) + expectDADEventAsync(tempAddr2.Address) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Deprecate prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) + expectAutoGenAddrEvent(addr1, deprecatedAddr) + expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Refresh lifetimes for prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Reduce valid lifetime and deprecate addresses of prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) + expectAutoGenAddrEvent(addr1, deprecatedAddr) + expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Wait for addrs of prefix1 to be invalidated. They should be + // invalidated at the same time. + select { + case e := <-ndpDisp.autoGenAddrC: + var nextAddr tcpip.AddressWithPrefix + if e.addr == addr1 { + if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + nextAddr = tempAddr1 + } else { + if diff := checkAutoGenAddrEvent(e, tempAddr1, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + nextAddr = addr1 + } + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, nextAddr, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(defaultTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + case <-time.After(newMinVLDuration + defaultTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" { + t.Fatal(mismatch) + } + + // Receive an RA with prefix2 in a PI w/ 0 lifetimes. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 0, 0)) + expectAutoGenAddrEvent(addr2, deprecatedAddr) + expectAutoGenAddrEvent(tempAddr2, deprecatedAddr) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Errorf("got unexpected auto gen addr event = %+v", e) + default: + } + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" { + t.Fatal(mismatch) + } + }) + } + }) +} + +// TestNoAutoGenTempAddrForLinkLocal test that temporary SLAAC addresses are not +// generated for auto generated link-local addresses. +func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) { + const nicID = 1 + + savedMaxDesyncFactor := stack.MaxDesyncFactor + defer func() { + stack.MaxDesyncFactor = savedMaxDesyncFactor + }() + stack.MaxDesyncFactor = time.Nanosecond + + tests := []struct { + name string + dupAddrTransmits uint8 + retransmitTimer time.Duration + }{ + { + name: "DAD disabled", + }, + { + name: "DAD enabled", + dupAddrTransmits: 1, + retransmitTimer: time.Second, + }, + } + + // This Run will not return until the parallel tests finish. + // + // We need this because we need to do some teardown work after the + // parallel tests complete. + // + // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for + // more details. + t.Run("group", func(t *testing.T) { + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + AutoGenTempGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + AutoGenIPv6LinkLocal: true, + }) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + // The stable link-local address should auto-generate and resolve DAD. + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, tcpip.AddressWithPrefix{Address: llAddr1, PrefixLen: header.IIDOffsetInIPv6Address * 8}, newAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, llAddr1, true, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncEventTimeout): + t.Fatal("timed out waiting for DAD event") + } + + // No new addresses should be generated. + select { + case e := <-ndpDisp.autoGenAddrC: + t.Errorf("got unxpected auto gen addr event = %+v", e) + case <-time.After(defaultAsyncEventTimeout): + } + }) + } + }) +} + +// TestNoAutoGenTempAddrWithoutStableAddr tests that a temporary SLAAC address +// will not be generated until after DAD completes, even if a new Router +// Advertisement is received to refresh lifetimes. +func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { + const ( + nicID = 1 + dadTransmits = 1 + retransmitTimer = 2 * time.Second + ) + + savedMaxDesyncFactor := stack.MaxDesyncFactor + defer func() { + stack.MaxDesyncFactor = savedMaxDesyncFactor + }() + stack.MaxDesyncFactor = 0 + + prefix, _, addr := prefixSubnetAddr(0, linkAddr1) + var tempIIDHistory [header.IIDSize]byte + header.InitialTempIID(tempIIDHistory[:], nil, nicID) + tempAddr := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + // Receive an RA to trigger SLAAC for prefix. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + + // DAD on the stable address for prefix has not yet completed. Receiving a new + // RA that would refresh lifetimes should not generate a temporary SLAAC + // address for the prefix. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpected auto gen addr event = %+v", e) + default: + } + + // Wait for DAD to complete for the stable address then expect the temporary + // address to be generated. + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout): + t.Fatal("timed out waiting for DAD event") + } + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, tempAddr, newAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(defaultAsyncEventTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } +} + +// TestAutoGenTempAddrRegen tests that temporary SLAAC addresses are +// regenerated. +func TestAutoGenTempAddrRegen(t *testing.T) { + const ( + nicID = 1 + regenAfter = 2 * time.Second + newMinVL = 10 + newMinVLDuration = newMinVL * time.Second + ) + + savedMaxDesyncFactor := stack.MaxDesyncFactor + savedMinMaxTempAddrPreferredLifetime := stack.MinMaxTempAddrPreferredLifetime + savedMinMaxTempAddrValidLifetime := stack.MinMaxTempAddrValidLifetime + defer func() { + stack.MaxDesyncFactor = savedMaxDesyncFactor + stack.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime + stack.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime + }() + stack.MaxDesyncFactor = 0 + stack.MinMaxTempAddrPreferredLifetime = newMinVLDuration + stack.MinMaxTempAddrValidLifetime = newMinVLDuration + + prefix, _, addr := prefixSubnetAddr(0, linkAddr1) + var tempIIDHistory [header.IIDSize]byte + header.InitialTempIID(tempIIDHistory[:], nil, nicID) + tempAddr1 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + tempAddr2 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + tempAddr3 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), + } + e := channel.New(0, 1280, linkAddr1) + ndpConfigs := stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + RegenAdvanceDuration: newMinVLDuration - regenAfter, + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(timeout): + t.Fatal("timed out waiting for addr auto gen event") + } + } + + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero valid & preferred lifetimes. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + expectAutoGenAddrEvent(addr, newAddr) + expectAutoGenAddrEvent(tempAddr1, newAddr) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Wait for regeneration + expectAutoGenAddrEventAsync(tempAddr2, newAddr, regenAfter+defaultAsyncEventTimeout) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Wait for regeneration + expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncEventTimeout) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2, tempAddr3}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Stop generating temporary addresses + ndpConfigs.AutoGenTempGlobalAddresses = false + if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil { + t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err) + } + + // Wait for all the temporary addresses to get invalidated. + tempAddrs := []tcpip.AddressWithPrefix{tempAddr1, tempAddr2, tempAddr3} + invalidateAfter := newMinVLDuration - 2*regenAfter + for _, addr := range tempAddrs { + // Wait for a deprecation then invalidation event, or just an invalidation + // event. We need to cover both cases but cannot deterministically hit both + // cases because the deprecation and invalidation timers could fire in any + // order. + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, deprecatedAddr); diff == "" { + // If we get a deprecation event first, we should get an invalidation + // event almost immediately after. + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(defaultAsyncEventTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + } else if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff == "" { + // If we get an invalidation event first, we shouldn't get a deprecation + // event after. + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly got an auto-generated event = %+v", e) + case <-time.After(defaultTimeout): + } + } else { + t.Fatalf("got unexpected auto-generated event = %+v", e) + } + case <-time.After(invalidateAfter + defaultAsyncEventTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + + invalidateAfter = regenAfter + } + if mismatch := addressCheck(s.NICInfo()[1].ProtocolAddresses, []tcpip.AddressWithPrefix{addr}, tempAddrs); mismatch != "" { + t.Fatal(mismatch) + } +} + +// TestAutoGenTempAddrRegenTimerUpdates tests that a temporary address's +// regeneration timer gets updated when refreshing the address's lifetimes. +func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) { + const ( + nicID = 1 + regenAfter = 2 * time.Second + newMinVL = 10 + newMinVLDuration = newMinVL * time.Second + ) + + savedMaxDesyncFactor := stack.MaxDesyncFactor + savedMinMaxTempAddrPreferredLifetime := stack.MinMaxTempAddrPreferredLifetime + savedMinMaxTempAddrValidLifetime := stack.MinMaxTempAddrValidLifetime + defer func() { + stack.MaxDesyncFactor = savedMaxDesyncFactor + stack.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime + stack.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime + }() + stack.MaxDesyncFactor = 0 + stack.MinMaxTempAddrPreferredLifetime = newMinVLDuration + stack.MinMaxTempAddrValidLifetime = newMinVLDuration + + prefix, _, addr := prefixSubnetAddr(0, linkAddr1) + var tempIIDHistory [header.IIDSize]byte + header.InitialTempIID(tempIIDHistory[:], nil, nicID) + tempAddr1 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + tempAddr2 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + tempAddr3 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), + } + e := channel.New(0, 1280, linkAddr1) + ndpConfigs := stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + RegenAdvanceDuration: newMinVLDuration - regenAfter, + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, + }) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(timeout): + t.Fatal("timed out waiting for addr auto gen event") + } + } + + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero valid & preferred lifetimes. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + expectAutoGenAddrEvent(addr, newAddr) + expectAutoGenAddrEvent(tempAddr1, newAddr) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Deprecate the prefix. + // + // A new temporary address should be generated after the regeneration + // time has passed since the prefix is deprecated. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 0)) + expectAutoGenAddrEvent(addr, deprecatedAddr) + expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpected auto gen addr event = %+v", e) + case <-time.After(regenAfter + defaultAsyncEventTimeout): + } + + // Prefer the prefix again. + // + // A new temporary address should immediately be generated since the + // regeneration time has already passed since the last address was generated + // - this regeneration does not depend on a timer. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + expectAutoGenAddrEvent(tempAddr2, newAddr) + + // Increase the maximum lifetimes for temporary addresses to large values + // then refresh the lifetimes of the prefix. + // + // A new address should not be generated after the regeneration time that was + // expected for the previous check. This is because the preferred lifetime for + // the temporary addresses has increased, so it will take more time to + // regenerate a new temporary address. Note, new addresses are only + // regenerated after the preferred lifetime - the regenerate advance duration + // as paased. + ndpConfigs.MaxTempAddrValidLifetime = 100 * time.Second + ndpConfigs.MaxTempAddrPreferredLifetime = 100 * time.Second + if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil { + t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err) + } + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpected auto gen addr event = %+v", e) + case <-time.After(regenAfter + defaultAsyncEventTimeout): + } + + // Set the maximum lifetimes for temporary addresses such that on the next + // RA, the regeneration timer gets reset. + // + // The maximum lifetime is the sum of the minimum lifetimes for temporary + // addresses + the time that has already passed since the last address was + // generated so that the regeneration timer is needed to generate the next + // address. + newLifetimes := newMinVLDuration + regenAfter + defaultAsyncEventTimeout + ndpConfigs.MaxTempAddrValidLifetime = newLifetimes + ndpConfigs.MaxTempAddrPreferredLifetime = newLifetimes + if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil { + t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err) + } + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncEventTimeout) +} + // stackAndNdpDispatcherWithDefaultRoute returns an ndpDispatcher, // channel.Endpoint and stack.Stack. // @@ -2196,7 +2916,6 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { } else { t.Fatalf("got unexpected auto-generated event") } - case <-time.After(newMinVLDuration + defaultAsyncEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } @@ -2808,9 +3527,7 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { } } -// TestAutoGenAddrWithOpaqueIIDDADRetries tests the regeneration of an -// auto-generated IPv6 address in response to a DAD conflict. -func TestAutoGenAddrWithOpaqueIIDDADRetries(t *testing.T) { +func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { const nicID = 1 const nicName = "nic" const dadTransmits = 1 @@ -2818,6 +3535,13 @@ func TestAutoGenAddrWithOpaqueIIDDADRetries(t *testing.T) { const maxMaxRetries = 3 const lifetimeSeconds = 10 + // Needed for the temporary address sub test. + savedMaxDesync := stack.MaxDesyncFactor + defer func() { + stack.MaxDesyncFactor = savedMaxDesync + }() + stack.MaxDesyncFactor = time.Nanosecond + var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte secretKey := secretKeyBuf[:] n, err := rand.Read(secretKey) @@ -2830,185 +3554,234 @@ func TestAutoGenAddrWithOpaqueIIDDADRetries(t *testing.T) { prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1) - for maxRetries := uint8(0); maxRetries <= maxMaxRetries; maxRetries++ { - for numFailures := uint8(0); numFailures <= maxRetries+1; numFailures++ { - addrTypes := []struct { - name string - ndpConfigs stack.NDPConfigurations - autoGenLinkLocal bool - subnet tcpip.Subnet - triggerSLAACFn func(e *channel.Endpoint) - }{ - { - name: "Global address", - ndpConfigs: stack.NDPConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, - HandleRAs: true, - AutoGenGlobalAddresses: true, - AutoGenAddressConflictRetries: maxRetries, - }, - subnet: subnet, - triggerSLAACFn: func(e *channel.Endpoint) { - // Receive an RA with prefix1 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds)) + addrForSubnet := func(subnet tcpip.Subnet, dadCounter uint8) tcpip.AddressWithPrefix { + addrBytes := []byte(subnet.ID()) + return tcpip.AddressWithPrefix{ + Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, dadCounter, secretKey)), + PrefixLen: 64, + } + } - }, - }, - { - name: "LinkLocal address", - ndpConfigs: stack.NDPConfigurations{ - DupAddrDetectTransmits: dadTransmits, - RetransmitTimer: retransmitTimer, - AutoGenAddressConflictRetries: maxRetries, - }, - autoGenLinkLocal: true, - subnet: header.IPv6LinkLocalPrefix.Subnet(), - triggerSLAACFn: func(e *channel.Endpoint) {}, - }, + expectAutoGenAddrEvent := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } + default: + t.Fatal("expected addr auto gen event") + } + } - for _, addrType := range addrTypes { - maxRetries := maxRetries - numFailures := numFailures - addrType := addrType + expectAutoGenAddrEventAsync := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() - t.Run(fmt.Sprintf("%s with %d max retries and %d failures", addrType.name, maxRetries, numFailures), func(t *testing.T) { - t.Parallel() + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(defaultAsyncEventTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + } - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal, - NDPConfigs: addrType.ndpConfigs, - NDPDisp: &ndpDisp, - OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: func(_ tcpip.NICID, nicName string) string { - return nicName - }, - SecretKey: secretKey, - }, - }) - opts := stack.NICOptions{Name: nicName} - if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { - t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err) - } + expectDADEvent := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, resolved bool) { + t.Helper() - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr, resolved, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected DAD event") + } + } - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } + expectDADEventAsync := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, resolved bool) { + t.Helper() - addrType.triggerSLAACFn(e) + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr, resolved, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout): + t.Fatal("timed out waiting for DAD event") + } + } - // Simulate DAD conflicts so the address is regenerated. - for i := uint8(0); i < numFailures; i++ { - addrBytes := []byte(addrType.subnet.ID()) - addr := tcpip.AddressWithPrefix{ - Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], addrType.subnet, nicName, i, secretKey)), - PrefixLen: 64, - } - expectAutoGenAddrEvent(addr, newAddr) + stableAddrForTempAddrTest := addrForSubnet(subnet, 0) - // Should not have any addresses assigned to the NIC. - mainAddr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err) + addrTypes := []struct { + name string + ndpConfigs stack.NDPConfigurations + autoGenLinkLocal bool + prepareFn func(t *testing.T, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix + addrGenFn func(dadCounter uint8, tempIIDHistory []byte) tcpip.AddressWithPrefix + }{ + { + name: "Global address", + ndpConfigs: stack.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + HandleRAs: true, + AutoGenGlobalAddresses: true, + }, + prepareFn: func(_ *testing.T, _ *ndpDispatcher, e *channel.Endpoint, _ []byte) []tcpip.AddressWithPrefix { + // Receive an RA with prefix1 in a PI. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds)) + return nil + + }, + addrGenFn: func(dadCounter uint8, _ []byte) tcpip.AddressWithPrefix { + return addrForSubnet(subnet, dadCounter) + }, + }, + { + name: "LinkLocal address", + ndpConfigs: stack.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + }, + autoGenLinkLocal: true, + prepareFn: func(*testing.T, *ndpDispatcher, *channel.Endpoint, []byte) []tcpip.AddressWithPrefix { + return nil + }, + addrGenFn: func(dadCounter uint8, _ []byte) tcpip.AddressWithPrefix { + return addrForSubnet(header.IPv6LinkLocalPrefix.Subnet(), dadCounter) + }, + }, + { + name: "Temporary address", + ndpConfigs: stack.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + }, + prepareFn: func(t *testing.T, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix { + header.InitialTempIID(tempIIDHistory, nil, nicID) + + // Generate a stable SLAAC address so temporary addresses will be + // generated. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + expectAutoGenAddrEvent(t, ndpDisp, stableAddrForTempAddrTest, newAddr) + expectDADEventAsync(t, ndpDisp, stableAddrForTempAddrTest.Address, true) + + // The stable address will be assigned throughout the test. + return []tcpip.AddressWithPrefix{stableAddrForTempAddrTest} + }, + addrGenFn: func(_ uint8, tempIIDHistory []byte) tcpip.AddressWithPrefix { + return header.GenerateTempIPv6SLAACAddr(tempIIDHistory, stableAddrForTempAddrTest.Address) + }, + }, + } + + for _, addrType := range addrTypes { + // This Run will not return until the parallel tests finish. + // + // We need this because we need to do some teardown work after the parallel + // tests complete and limit the number of parallel tests running at the same + // time to reduce flakes. + // + // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for + // more details. + t.Run(addrType.name, func(t *testing.T) { + for maxRetries := uint8(0); maxRetries <= maxMaxRetries; maxRetries++ { + for numFailures := uint8(0); numFailures <= maxRetries+1; numFailures++ { + maxRetries := maxRetries + numFailures := numFailures + addrType := addrType + + t.Run(fmt.Sprintf("%d max retries and %d failures", maxRetries, numFailures), func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), } - if want := (tcpip.AddressWithPrefix{}); mainAddr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", mainAddr, want) + e := channel.New(0, 1280, linkAddr1) + ndpConfigs := addrType.ndpConfigs + ndpConfigs.AutoGenAddressConflictRetries = maxRetries + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal, + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: func(_ tcpip.NICID, nicName string) string { + return nicName + }, + SecretKey: secretKey, + }, + }) + opts := stack.NICOptions{Name: nicName} + if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err) } - // Simulate a DAD conflict. - if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil { - t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err) - } - expectAutoGenAddrEvent(addr, invalidatedAddr) - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + var tempIIDHistory [header.IIDSize]byte + stableAddrs := addrType.prepareFn(t, &ndpDisp, e, tempIIDHistory[:]) + + // Simulate DAD conflicts so the address is regenerated. + for i := uint8(0); i < numFailures; i++ { + addr := addrType.addrGenFn(i, tempIIDHistory[:]) + expectAutoGenAddrEventAsync(t, &ndpDisp, addr, newAddr) + + // Should not have any new addresses assigned to the NIC. + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, stableAddrs, nil); mismatch != "" { + t.Fatal(mismatch) } - default: - t.Fatal("expected DAD event") - } - // Attempting to add the address manually should not fail if the - // address's state was cleaned up when DAD failed. - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr.Address); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr.Address, err) - } - if err := s.RemoveAddress(nicID, addr.Address); err != nil { - t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err) - } - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + // Simulate a DAD conflict. + if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil { + t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err) } - default: - t.Fatal("expected DAD event") - } - } + expectAutoGenAddrEvent(t, &ndpDisp, addr, invalidatedAddr) + expectDADEvent(t, &ndpDisp, addr.Address, false) - // Should not have any addresses assigned to the NIC. - mainAddr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err) - } - if want := (tcpip.AddressWithPrefix{}); mainAddr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", mainAddr, want) - } + // Attempting to add the address manually should not fail if the + // address's state was cleaned up when DAD failed. + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr.Address); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr.Address, err) + } + if err := s.RemoveAddress(nicID, addr.Address); err != nil { + t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err) + } + expectDADEvent(t, &ndpDisp, addr.Address, false) + } - // If we had less failures than generation attempts, we should have an - // address after DAD resolves. - if maxRetries+1 > numFailures { - addrBytes := []byte(addrType.subnet.ID()) - addr := tcpip.AddressWithPrefix{ - Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], addrType.subnet, nicName, numFailures, secretKey)), - PrefixLen: 64, + // Should not have any new addresses assigned to the NIC. + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, stableAddrs, nil); mismatch != "" { + t.Fatal(mismatch) } - expectAutoGenAddrEvent(addr, newAddr) - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + // If we had less failures than generation attempts, we should have + // an address after DAD resolves. + if maxRetries+1 > numFailures { + addr := addrType.addrGenFn(numFailures, tempIIDHistory[:]) + expectAutoGenAddrEventAsync(t, &ndpDisp, addr, newAddr) + expectDADEventAsync(t, &ndpDisp, addr.Address, true) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, append(stableAddrs, addr), nil); mismatch != "" { + t.Fatal(mismatch) } - case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout): - t.Fatal("timed out waiting for DAD event") } - mainAddr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err) - } - if mainAddr != addr { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", mainAddr, addr) + // Should not attempt address generation again. + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly got an auto-generated address event = %+v", e) + case <-time.After(defaultAsyncEventTimeout): } - } - - // Should not attempt address regeneration again. - select { - case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly got an auto-generated address event = %+v", e) - case <-time.After(defaultAsyncEventTimeout): - } - }) + }) + } } - } + }) } } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 016dbe15e..25188b4fb 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -131,6 +131,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState), slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState), } + nic.mu.ndp.initializeTempAddrState() // Register supported packet endpoint protocols. for _, netProto := range header.Ethertypes { @@ -1014,14 +1015,14 @@ func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { switch r.protocol { case header.IPv6ProtocolNumber: - return n.removePermanentIPv6EndpointLocked(r, true /* allowSLAAPrefixInvalidation */) + return n.removePermanentIPv6EndpointLocked(r, true /* allowSLAACInvalidation */) default: r.expireLocked() return nil } } -func (n *NIC) removePermanentIPv6EndpointLocked(r *referencedNetworkEndpoint, allowSLAACPrefixInvalidation bool) *tcpip.Error { +func (n *NIC) removePermanentIPv6EndpointLocked(r *referencedNetworkEndpoint, allowSLAACInvalidation bool) *tcpip.Error { addr := r.addrWithPrefix() isIPv6Unicast := header.IsV6UnicastAddress(addr.Address) @@ -1031,8 +1032,11 @@ func (n *NIC) removePermanentIPv6EndpointLocked(r *referencedNetworkEndpoint, al // If we are removing an address generated via SLAAC, cleanup // its SLAAC resources and notify the integrator. - if r.configType == slaac { - n.mu.ndp.cleanupSLAACAddrResourcesAndNotify(addr, allowSLAACPrefixInvalidation) + switch r.configType { + case slaac: + n.mu.ndp.cleanupSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation) + case slaacTemp: + n.mu.ndp.cleanupTempSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation) } } @@ -1448,12 +1452,19 @@ func (n *NIC) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error { // If the address is a SLAAC address, do not invalidate its SLAAC prefix as a // new address will be generated for it. - if err := n.removePermanentIPv6EndpointLocked(ref, false /* allowSLAACPrefixInvalidation */); err != nil { + if err := n.removePermanentIPv6EndpointLocked(ref, false /* allowSLAACInvalidation */); err != nil { return err } - if ref.configType == slaac { - n.mu.ndp.regenerateSLAACAddr(ref.addrWithPrefix().Subnet()) + prefix := ref.addrWithPrefix().Subnet() + + switch ref.configType { + case slaac: + n.mu.ndp.regenerateSLAACAddr(prefix) + case slaacTemp: + // Do not reset the generation attempts counter for the prefix as the + // temporary address is being regenerated in response to a DAD conflict. + n.mu.ndp.regenerateTempSLAACAddr(prefix, false /* resetGenAttempts */) } return nil @@ -1552,9 +1563,14 @@ const ( // multicast group). static networkEndpointConfigType = iota - // A slaac configured endpoint is an IPv6 endpoint that was - // added by SLAAC as per RFC 4862 section 5.5.3. + // A SLAAC configured endpoint is an IPv6 endpoint that was added by + // SLAAC as per RFC 4862 section 5.5.3. slaac + + // A temporary SLAAC configured endpoint is an IPv6 endpoint that was added by + // SLAAC as per RFC 4941. Temporary SLAAC addresses are short-lived and are + // not expected to be valid (or preferred) forever; hence the term temporary. + slaacTemp ) type referencedNetworkEndpoint struct { diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 41398a1b6..4a2dc3dc6 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -464,6 +464,10 @@ type Stack struct { // (IIDs) as outlined by RFC 7217. opaqueIIDOpts OpaqueInterfaceIdentifierOptions + // tempIIDSeed is used to seed the initial temporary interface identifier + // history value used to generate IIDs for temporary SLAAC addresses. + tempIIDSeed []byte + // forwarder holds the packets that wait for their link-address resolutions // to complete, and forwards them when each resolution is done. forwarder *forwardQueue @@ -541,6 +545,21 @@ type Options struct { // // RandSource must be thread-safe. RandSource mathrand.Source + + // TempIIDSeed is used to seed the initial temporary interface identifier + // history value used to generate IIDs for temporary SLAAC addresses. + // + // Temporary SLAAC adresses are short-lived addresses which are unpredictable + // and random from the perspective of other nodes on the network. It is + // recommended that the seed be a random byte buffer of at least + // header.IIDSize bytes to make sure that temporary SLAAC addresses are + // sufficiently random. It should follow minimum randomness requirements for + // security as outlined by RFC 4086. + // + // Note: using a nil value, the same seed across netstack program runs, or a + // seed that is too small would reduce randomness and increase predictability, + // defeating the purpose of temporary SLAAC addresses. + TempIIDSeed []byte } // TransportEndpointInfo holds useful information about a transport endpoint @@ -664,6 +683,7 @@ func New(opts Options) *Stack { uniqueIDGenerator: opts.UniqueID, ndpDisp: opts.NDPDisp, opaqueIIDOpts: opts.OpaqueIIDOpts, + tempIIDSeed: opts.TempIIDSeed, forwarder: newForwardQueue(), randomGenerator: mathrand.New(randSrc), } -- cgit v1.2.3 From 043b7d83bd76c616fa32b815528eec77f2aad5ff Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Thu, 30 Apr 2020 10:21:57 -0700 Subject: Prefer temporary addresses Implement rule 7 of Source Address Selection RFC 6724 section 5. This makes temporary (short-lived) addresses preferred over non-temporary addresses when earlier rules are equal. Test: stack_test.TestIPv6SourceAddressSelectionScopeAndSameAddress PiperOrigin-RevId: 309250975 --- pkg/tcpip/stack/nic.go | 7 +++++- pkg/tcpip/stack/stack_test.go | 56 +++++++++++++++++++++++++++++++++++++++---- 2 files changed, 58 insertions(+), 5 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 25188b4fb..440970a21 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -452,7 +452,7 @@ type ipv6AddrCandidate struct { // primaryIPv6Endpoint returns an IPv6 endpoint following Source Address // Selection (RFC 6724 section 5). // -// Note, only rules 1-3 are followed. +// Note, only rules 1-3 and 7 are followed. // // remoteAddr must be a valid IPv6 address. func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEndpoint { @@ -523,6 +523,11 @@ func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEn return sbDep } + // Prefer temporary addresses as per RFC 6724 section 5 rule 7. + if saTemp, sbTemp := sa.ref.configType == slaacTemp, sb.ref.configType == slaacTemp; saTemp != sbTemp { + return saTemp + } + // sa and sb are equal, return the endpoint that is closest to the front of // the primary endpoint list. return i < j diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index c7634ceb1..4a686c891 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -2870,14 +2870,25 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { globalAddr1 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") globalAddr2 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") nicID = 1 + lifetimeSeconds = 9999 ) + prefix1, _, stableGlobalAddr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, _, stableGlobalAddr2 := prefixSubnetAddr(1, linkAddr1) + + var tempIIDHistory [header.IIDSize]byte + header.InitialTempIID(tempIIDHistory[:], nil, nicID) + tempGlobalAddr1 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableGlobalAddr1.Address).Address + tempGlobalAddr2 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableGlobalAddr2.Address).Address + // Rule 3 is not tested here, and is instead tested by NDP's AutoGenAddr test. tests := []struct { - name string - nicAddrs []tcpip.Address - connectAddr tcpip.Address - expectedLocalAddr tcpip.Address + name string + slaacPrefixForTempAddrBeforeNICAddrAdd tcpip.AddressWithPrefix + nicAddrs []tcpip.Address + slaacPrefixForTempAddrAfterNICAddrAdd tcpip.AddressWithPrefix + connectAddr tcpip.Address + expectedLocalAddr tcpip.Address }{ // Test Rule 1 of RFC 6724 section 5. { @@ -2967,6 +2978,22 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { expectedLocalAddr: uniqueLocalAddr1, }, + // Test Rule 7 of RFC 6724 section 5. + { + name: "Temp Global most preferred (last address)", + slaacPrefixForTempAddrBeforeNICAddrAdd: prefix1, + nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, + connectAddr: globalAddr2, + expectedLocalAddr: tempGlobalAddr1, + }, + { + name: "Temp Global most preferred (first address)", + nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, + slaacPrefixForTempAddrAfterNICAddrAdd: prefix1, + connectAddr: globalAddr2, + expectedLocalAddr: tempGlobalAddr1, + }, + // Test returning the endpoint that is closest to the front when // candidate addresses are "equal" from the perspective of RFC 6724 // section 5. @@ -2988,6 +3015,13 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { connectAddr: uniqueLocalAddr2, expectedLocalAddr: linkLocalAddr1, }, + { + name: "Temp Global for Global", + slaacPrefixForTempAddrBeforeNICAddrAdd: prefix1, + slaacPrefixForTempAddrAfterNICAddrAdd: prefix2, + connectAddr: globalAddr1, + expectedLocalAddr: tempGlobalAddr2, + }, } for _, test := range tests { @@ -2996,6 +3030,12 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + }, + NDPDisp: &ndpDispatcher{}, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -3007,12 +3047,20 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { }}) s.AddLinkAddress(nicID, llAddr3, linkAddr3) + if test.slaacPrefixForTempAddrBeforeNICAddrAdd != (tcpip.AddressWithPrefix{}) { + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, test.slaacPrefixForTempAddrBeforeNICAddrAdd, true, true, lifetimeSeconds, lifetimeSeconds)) + } + for _, a := range test.nicAddrs { if err := s.AddAddress(nicID, ipv6.ProtocolNumber, a); err != nil { t.Errorf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, a, err) } } + if test.slaacPrefixForTempAddrAfterNICAddrAdd != (tcpip.AddressWithPrefix{}) { + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, test.slaacPrefixForTempAddrAfterNICAddrAdd, true, true, lifetimeSeconds, lifetimeSeconds)) + } + if t.Failed() { t.FailNow() } -- cgit v1.2.3 From ae15d90436ec5ecd8795bed2a357b1990123e8fd Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Thu, 30 Apr 2020 16:39:18 -0700 Subject: FIFO QDisc implementation Updates #231 PiperOrigin-RevId: 309323808 --- benchmarks/tcp/BUILD | 1 + benchmarks/tcp/tcp_proxy.go | 3 +- pkg/tcpip/link/fdbased/BUILD | 1 + pkg/tcpip/link/fdbased/endpoint.go | 104 +++++++---- pkg/tcpip/link/fdbased/endpoint_unsafe.go | 10 -- pkg/tcpip/link/qdisc/fifo/BUILD | 19 +++ pkg/tcpip/link/qdisc/fifo/endpoint.go | 209 +++++++++++++++++++++++ pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go | 84 +++++++++ pkg/tcpip/network/arp/arp.go | 7 + pkg/tcpip/network/ipv4/ipv4.go | 5 + pkg/tcpip/network/ipv6/ipv6.go | 5 + pkg/tcpip/stack/forwarder_test.go | 4 + pkg/tcpip/stack/packet_buffer.go | 6 + pkg/tcpip/stack/registration.go | 4 + pkg/tcpip/stack/route.go | 6 + pkg/tcpip/stack/stack_test.go | 4 + pkg/tcpip/transport/tcp/connect.go | 3 + runsc/boot/BUILD | 1 + runsc/boot/config.go | 5 + runsc/boot/network.go | 50 ++++++ runsc/main.go | 8 +- runsc/sandbox/network.go | 5 +- 22 files changed, 497 insertions(+), 47 deletions(-) create mode 100644 pkg/tcpip/link/qdisc/fifo/BUILD create mode 100644 pkg/tcpip/link/qdisc/fifo/endpoint.go create mode 100644 pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go (limited to 'pkg/tcpip/stack') diff --git a/benchmarks/tcp/BUILD b/benchmarks/tcp/BUILD index d5e401acc..6dde7d9e6 100644 --- a/benchmarks/tcp/BUILD +++ b/benchmarks/tcp/BUILD @@ -10,6 +10,7 @@ go_binary( "//pkg/tcpip", "//pkg/tcpip/adapters/gonet", "//pkg/tcpip/link/fdbased", + "//pkg/tcpip/link/qdisc/fifo", "//pkg/tcpip/network/arp", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/stack", diff --git a/benchmarks/tcp/tcp_proxy.go b/benchmarks/tcp/tcp_proxy.go index 73b7c4f5b..dc1593b34 100644 --- a/benchmarks/tcp/tcp_proxy.go +++ b/benchmarks/tcp/tcp_proxy.go @@ -36,6 +36,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/link/fdbased" + "gvisor.dev/gvisor/pkg/tcpip/link/qdisc/fifo" "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -203,7 +204,7 @@ func newNetstackImpl(mode string) (impl, error) { if err != nil { return nil, fmt.Errorf("failed to create FD endpoint: %v", err) } - if err := s.CreateNIC(nicID, ep); err != nil { + if err := s.CreateNIC(nicID, fifo.New(ep, runtime.GOMAXPROCS(0), 1000)); err != nil { return nil, fmt.Errorf("error creating NIC %q: %v", *iface, err) } if err := s.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD index abe725548..aa6db9aea 100644 --- a/pkg/tcpip/link/fdbased/BUILD +++ b/pkg/tcpip/link/fdbased/BUILD @@ -14,6 +14,7 @@ go_library( ], visibility = ["//visibility:public"], deps = [ + "//pkg/binary", "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index b857ce9d0..53a9712c6 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -44,6 +44,7 @@ import ( "syscall" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -428,7 +429,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne } } - vnetHdrBuf := vnetHdrToByteSlice(&vnetHdr) + vnetHdrBuf := binary.Marshal(make([]byte, 0, virtioNetHdrSize), binary.LittleEndian, vnetHdr) return rawfile.NonBlockingWrite3(fd, vnetHdrBuf, pkt.Header.View(), pkt.Data.ToView()) } @@ -439,19 +440,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne return rawfile.NonBlockingWrite3(fd, pkt.Header.View(), pkt.Data.ToView(), nil) } -// WritePackets writes outbound packets to the file descriptor. If it is not -// currently writable, the packet is dropped. -// -// NOTE: This API uses sendmmsg to batch packets. As a result the underlying FD -// picked to write the packet out has to be the same for all packets in the -// list. In other words all packets in the batch should belong to the same -// flow. -func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - n := pkts.Len() - - mmsgHdrs := make([]rawfile.MMsgHdr, n) - i := 0 - for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { +func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tcpip.Error) { + // Send a batch of packets through batchFD. + mmsgHdrs := make([]rawfile.MMsgHdr, 0, len(batch)) + for _, pkt := range batch { var ethHdrBuf []byte iovLen := 0 if e.hdrSize > 0 { @@ -459,13 +451,13 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe ethHdrBuf = make([]byte, header.EthernetMinimumSize) eth := header.Ethernet(ethHdrBuf) ethHdr := &header.EthernetFields{ - DstAddr: r.RemoteLinkAddress, - Type: protocol, + DstAddr: pkt.EgressRoute.RemoteLinkAddress, + Type: pkt.NetworkProtocolNumber, } // Preserve the src address if it's set in the route. - if r.LocalLinkAddress != "" { - ethHdr.SrcAddr = r.LocalLinkAddress + if pkt.EgressRoute.LocalLinkAddress != "" { + ethHdr.SrcAddr = pkt.EgressRoute.LocalLinkAddress } else { ethHdr.SrcAddr = e.addr } @@ -473,34 +465,34 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe iovLen++ } - var vnetHdrBuf []byte vnetHdr := virtioNetHdr{} + var vnetHdrBuf []byte if e.Capabilities()&stack.CapabilityHardwareGSO != 0 { - if gso != nil { + if pkt.GSOOptions != nil { vnetHdr.hdrLen = uint16(pkt.Header.UsedLength()) - if gso.NeedsCsum { + if pkt.GSOOptions.NeedsCsum { vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM - vnetHdr.csumStart = header.EthernetMinimumSize + gso.L3HdrLen - vnetHdr.csumOffset = gso.CsumOffset + vnetHdr.csumStart = header.EthernetMinimumSize + pkt.GSOOptions.L3HdrLen + vnetHdr.csumOffset = pkt.GSOOptions.CsumOffset } - if gso.Type != stack.GSONone && uint16(pkt.Data.Size()) > gso.MSS { - switch gso.Type { + if pkt.GSOOptions.Type != stack.GSONone && uint16(pkt.Data.Size()) > pkt.GSOOptions.MSS { + switch pkt.GSOOptions.Type { case stack.GSOTCPv4: vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4 case stack.GSOTCPv6: vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV6 default: - panic(fmt.Sprintf("Unknown gso type: %v", gso.Type)) + panic(fmt.Sprintf("Unknown gso type: %v", pkt.GSOOptions.Type)) } - vnetHdr.gsoSize = gso.MSS + vnetHdr.gsoSize = pkt.GSOOptions.MSS } } - vnetHdrBuf = vnetHdrToByteSlice(&vnetHdr) + vnetHdrBuf = binary.Marshal(make([]byte, 0, virtioNetHdrSize), binary.LittleEndian, vnetHdr) iovLen++ } iovecs := make([]syscall.Iovec, iovLen+1+len(pkt.Data.Views())) - mmsgHdr := &mmsgHdrs[i] + var mmsgHdr rawfile.MMsgHdr mmsgHdr.Msg.Iov = &iovecs[0] iovecIdx := 0 if vnetHdrBuf != nil { @@ -535,22 +527,68 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe pktSize += vec.Len } mmsgHdr.Msg.Iovlen = uint64(iovecIdx) - i++ + mmsgHdrs = append(mmsgHdrs, mmsgHdr) } packets := 0 - for packets < n { - fd := e.fds[pkts.Front().Hash%uint32(len(e.fds))] - sent, err := rawfile.NonBlockingSendMMsg(fd, mmsgHdrs) + for len(mmsgHdrs) > 0 { + sent, err := rawfile.NonBlockingSendMMsg(batchFD, mmsgHdrs) if err != nil { return packets, err } packets += sent mmsgHdrs = mmsgHdrs[sent:] } + return packets, nil } +// WritePackets writes outbound packets to the underlying file descriptors. If +// one is not currently writable, the packet is dropped. +// +// Being a batch API, each packet in pkts should have the following +// fields populated: +// - pkt.EgressRoute +// - pkt.GSOOptions +// - pkt.NetworkProtocolNumber +func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + // Preallocate to avoid repeated reallocation as we append to batch. + // batchSz is 47 because when SWGSO is in use then a single 65KB TCP + // segment can get split into 46 segments of 1420 bytes and a single 216 + // byte segment. + const batchSz = 47 + batch := make([]*stack.PacketBuffer, 0, batchSz) + batchFD := -1 + sentPackets := 0 + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if len(batch) == 0 { + batchFD = e.fds[pkt.Hash%uint32(len(e.fds))] + } + pktFD := e.fds[pkt.Hash%uint32(len(e.fds))] + if sendNow := pktFD != batchFD; !sendNow { + batch = append(batch, pkt) + continue + } + n, err := e.sendBatch(batchFD, batch) + sentPackets += n + if err != nil { + return sentPackets, err + } + batch = batch[:0] + batch = append(batch, pkt) + batchFD = pktFD + } + + if len(batch) != 0 { + n, err := e.sendBatch(batchFD, batch) + sentPackets += n + if err != nil { + return sentPackets, err + } + } + return sentPackets, nil +} + // viewsEqual tests whether v1 and v2 refer to the same backing bytes. func viewsEqual(vs1, vs2 []buffer.View) bool { return len(vs1) == len(vs2) && (len(vs1) == 0 || &vs1[0] == &vs2[0]) diff --git a/pkg/tcpip/link/fdbased/endpoint_unsafe.go b/pkg/tcpip/link/fdbased/endpoint_unsafe.go index d81858353..df14eaad1 100644 --- a/pkg/tcpip/link/fdbased/endpoint_unsafe.go +++ b/pkg/tcpip/link/fdbased/endpoint_unsafe.go @@ -17,17 +17,7 @@ package fdbased import ( - "reflect" "unsafe" ) const virtioNetHdrSize = int(unsafe.Sizeof(virtioNetHdr{})) - -func vnetHdrToByteSlice(hdr *virtioNetHdr) (slice []byte) { - *(*reflect.SliceHeader)(unsafe.Pointer(&slice)) = reflect.SliceHeader{ - Data: uintptr((unsafe.Pointer(hdr))), - Len: virtioNetHdrSize, - Cap: virtioNetHdrSize, - } - return -} diff --git a/pkg/tcpip/link/qdisc/fifo/BUILD b/pkg/tcpip/link/qdisc/fifo/BUILD new file mode 100644 index 000000000..054c213bc --- /dev/null +++ b/pkg/tcpip/link/qdisc/fifo/BUILD @@ -0,0 +1,19 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "fifo", + srcs = [ + "endpoint.go", + "packet_buffer_queue.go", + ], + visibility = ["//visibility:public"], + deps = [ + "//pkg/sleep", + "//pkg/sync", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go new file mode 100644 index 000000000..be9fec3b3 --- /dev/null +++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go @@ -0,0 +1,209 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package fifo provides the implementation of data-link layer endpoints that +// wrap another endpoint and queues all outbound packets and asynchronously +// dispatches them to the lower endpoint. +package fifo + +import ( + "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// endpoint represents a LinkEndpoint which implements a FIFO queue for all +// outgoing packets. endpoint can have 1 or more underlying queueDispatchers. +// All outgoing packets are consistenly hashed to a single underlying queue +// using the PacketBuffer.Hash if set, otherwise all packets are queued to the +// first queue to avoid reordering in case of missing hash. +type endpoint struct { + dispatcher stack.NetworkDispatcher + lower stack.LinkEndpoint + wg sync.WaitGroup + dispatchers []*queueDispatcher +} + +// queueDispatcher is responsible for dispatching all outbound packets in its +// queue. It will also smartly batch packets when possible and write them +// through the lower LinkEndpoint. +type queueDispatcher struct { + lower stack.LinkEndpoint + q *packetBufferQueue + newPacketWaker sleep.Waker + closeWaker sleep.Waker +} + +// New creates a new fifo link endpoint with the n queues with maximum +// capacity of queueLen. +func New(lower stack.LinkEndpoint, n int, queueLen int) stack.LinkEndpoint { + e := &endpoint{ + lower: lower, + } + // Create the required dispatchers + for i := 0; i < n; i++ { + qd := &queueDispatcher{ + q: &packetBufferQueue{limit: queueLen}, + lower: lower, + } + e.dispatchers = append(e.dispatchers, qd) + e.wg.Add(1) + go func() { + defer e.wg.Done() + qd.dispatchLoop() + }() + } + return e +} + +func (q *queueDispatcher) dispatchLoop() { + const newPacketWakerID = 1 + const closeWakerID = 2 + s := sleep.Sleeper{} + s.AddWaker(&q.newPacketWaker, newPacketWakerID) + s.AddWaker(&q.closeWaker, closeWakerID) + defer s.Done() + + const batchSize = 32 + var batch stack.PacketBufferList + for { + id, ok := s.Fetch(true) + if ok && id == closeWakerID { + return + } + for pkt := q.q.dequeue(); pkt != nil; pkt = q.q.dequeue() { + batch.PushBack(pkt) + if batch.Len() < batchSize && !q.q.empty() { + continue + } + // We pass a protocol of zero here because each packet carries its + // NetworkProtocol. + q.lower.WritePackets(nil /* route */, nil /* gso */, batch, 0 /* protocol */) + for pkt := batch.Front(); pkt != nil; pkt = pkt.Next() { + pkt.EgressRoute.Release() + batch.Remove(pkt) + } + batch.Reset() + } + } +} + +// DeliverNetworkPacket implements stack.NetworkDispatcher.DeliverNetworkPacket. +func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { + e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, pkt) +} + +// Attach implements stack.LinkEndpoint.Attach. +func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { + e.dispatcher = dispatcher + e.lower.Attach(e) +} + +// IsAttached implements stack.LinkEndpoint.IsAttached. +func (e *endpoint) IsAttached() bool { + return e.dispatcher != nil +} + +// MTU implements stack.LinkEndpoint.MTU. +func (e *endpoint) MTU() uint32 { + return e.lower.MTU() +} + +// Capabilities implements stack.LinkEndpoint.Capabilities. +func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { + return e.lower.Capabilities() +} + +// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. +func (e *endpoint) MaxHeaderLength() uint16 { + return e.lower.MaxHeaderLength() +} + +// LinkAddress implements stack.LinkEndpoint.LinkAddress. +func (e *endpoint) LinkAddress() tcpip.LinkAddress { + return e.lower.LinkAddress() +} + +// GSOMaxSize returns the maximum GSO packet size. +func (e *endpoint) GSOMaxSize() uint32 { + if gso, ok := e.lower.(stack.GSOEndpoint); ok { + return gso.GSOMaxSize() + } + return 0 +} + +// WritePacket implements stack.LinkEndpoint.WritePacket. +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { + // WritePacket caller's do not set the following fields in PacketBuffer + // so we populate them here. + newRoute := r.Clone() + pkt.EgressRoute = &newRoute + pkt.GSOOptions = gso + pkt.NetworkProtocolNumber = protocol + d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)] + if !d.q.enqueue(&pkt) { + return tcpip.ErrNoBufferSpace + } + d.newPacketWaker.Assert() + return nil +} + +// WritePackets implements stack.LinkEndpoint.WritePackets. +// +// Being a batch API each packet in pkts should have the following fields +// populated: +// - pkt.EgressRoute +// - pkt.GSOOptions +// - pkt.NetworkProtocolNumber +func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + enqueued := 0 + for pkt := pkts.Front(); pkt != nil; { + d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)] + nxt := pkt.Next() + // Since qdisc can hold onto a packet for long we should Clone + // the route here to ensure it doesn't get released while the + // packet is still in our queue. + newRoute := pkt.EgressRoute.Clone() + pkt.EgressRoute = &newRoute + if !d.q.enqueue(pkt) { + if enqueued > 0 { + d.newPacketWaker.Assert() + } + return enqueued, tcpip.ErrNoBufferSpace + } + pkt = nxt + enqueued++ + d.newPacketWaker.Assert() + } + return enqueued, nil +} + +// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. +func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + return e.lower.WriteRawPacket(vv) +} + +// Wait implements stack.LinkEndpoint.Wait. +func (e *endpoint) Wait() { + e.lower.Wait() + + // The linkEP is gone. Teardown the outbound dispatcher goroutines. + for i := range e.dispatchers { + e.dispatchers[i].closeWaker.Assert() + } + + e.wg.Wait() +} diff --git a/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go b/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go new file mode 100644 index 000000000..eb5abb906 --- /dev/null +++ b/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go @@ -0,0 +1,84 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fifo + +import ( + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// packetBufferQueue is a bounded, thread-safe queue of PacketBuffers. +// +type packetBufferQueue struct { + mu sync.Mutex + list stack.PacketBufferList + limit int + used int +} + +// emptyLocked determines if the queue is empty. +// Preconditions: q.mu must be held. +func (q *packetBufferQueue) emptyLocked() bool { + return q.used == 0 +} + +// empty determines if the queue is empty. +func (q *packetBufferQueue) empty() bool { + q.mu.Lock() + r := q.emptyLocked() + q.mu.Unlock() + + return r +} + +// setLimit updates the limit. No PacketBuffers are immediately dropped in case +// the queue becomes full due to the new limit. +func (q *packetBufferQueue) setLimit(limit int) { + q.mu.Lock() + q.limit = limit + q.mu.Unlock() +} + +// enqueue adds the given packet to the queue. +// +// Returns true when the PacketBuffer is successfully added to the queue, in +// which case ownership of the reference is transferred to the queue. And +// returns false if the queue is full, in which case ownership is retained by +// the caller. +func (q *packetBufferQueue) enqueue(s *stack.PacketBuffer) bool { + q.mu.Lock() + r := q.used < q.limit + if r { + q.list.PushBack(s) + q.used++ + } + q.mu.Unlock() + + return r +} + +// dequeue removes and returns the next PacketBuffer from queue, if one exists. +// Ownership is transferred to the caller. +func (q *packetBufferQueue) dequeue() *stack.PacketBuffer { + q.mu.Lock() + s := q.list.Front() + if s != nil { + q.list.Remove(s) + q.used-- + } + q.mu.Unlock() + + return s +} diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 7acbfa0a8..9f47b4ff2 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -42,6 +42,7 @@ const ( // endpoint implements stack.NetworkEndpoint. type endpoint struct { + protocol *protocol nicID tcpip.NICID linkEP stack.LinkEndpoint linkAddrCache stack.LinkAddressCache @@ -83,6 +84,11 @@ func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderPara return tcpip.ErrNotSupported } +// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber. +func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { + return e.protocol.Number() +} + // WritePackets implements stack.NetworkEndpoint.WritePackets. func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, stack.NetworkHeaderParams) (int, *tcpip.Error) { return 0, tcpip.ErrNotSupported @@ -142,6 +148,7 @@ func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWi return nil, tcpip.ErrBadLocalAddress } return &endpoint{ + protocol: p, nicID: nicID, linkEP: sender, linkAddrCache: linkAddrCache, diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 104aafbed..a9dec0c0e 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -118,6 +118,11 @@ func (e *endpoint) GSOMaxSize() uint32 { return 0 } +// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber. +func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { + return e.protocol.Number() +} + // writePacketFragments calls e.linkEP.WritePacket with each packet fragment to // write. It assumes that the IP header is entirely in pkt.Header but does not // assume that only the IP header is in pkt.Header. It assumes that the input diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 331b0817b..82928fb66 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -416,6 +416,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { // Close cleans up resources associated with the endpoint. func (*endpoint) Close() {} +// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber. +func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { + return e.protocol.Number() +} + type protocol struct { // defaultTTL is the current default TTL for the protocol. Only the // uint8 portion of it is meaningful and it must be accessed diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index e9c652042..9bc97b84e 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -89,6 +89,10 @@ func (f *fwdTestNetworkEndpoint) Capabilities() LinkEndpointCapabilities { return f.ep.Capabilities() } +func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { + return f.proto.Number() +} + func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error { // Add the protocol's header to the packet and send it to the link // endpoint. diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index dc125f25e..06d312207 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -60,6 +60,12 @@ type PacketBuffer struct { // Owner is implemented by task to get the uid and gid. // Only set for locally generated packets. Owner tcpip.PacketOwner + + // The following fields are only set by the qdisc layer when the packet + // is added to a queue. + EgressRoute *Route + GSOOptions *GSO + NetworkProtocolNumber tcpip.NetworkProtocolNumber } // Clone makes a copy of pk. It clones the Data field, which creates a new diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 23ca9ee03..b331427c6 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -269,6 +269,10 @@ type NetworkEndpoint interface { // Close is called when the endpoint is reomved from a stack. Close() + + // NetworkProtocolNumber returns the tcpip.NetworkProtocolNumber for + // this endpoint. + NetworkProtocolNumber() tcpip.NetworkProtocolNumber } // NetworkProtocol is the interface that needs to be implemented by network diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index a0e5e0300..53148dc03 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -217,6 +217,12 @@ func (r *Route) MTU() uint32 { return r.ref.ep.MTU() } +// NetworkProtocolNumber returns the NetworkProtocolNumber of the underlying +// network endpoint. +func (r *Route) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { + return r.ref.ep.NetworkProtocolNumber() +} + // Release frees all resources associated with the route. func (r *Route) Release() { if r.ref != nil { diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 4a686c891..3f4e08434 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -126,6 +126,10 @@ func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities { return f.ep.Capabilities() } +func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { + return f.proto.Number() +} + func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error { // Increment the sent packet count in the protocol descriptor. f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++ diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 76e27bf26..a7e088d4e 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -801,6 +801,9 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso pkt.Header = buffer.NewPrependable(hdrSize) pkt.Hash = tf.txHash pkt.Owner = owner + pkt.EgressRoute = r + pkt.GSOOptions = gso + pkt.NetworkProtocolNumber = r.NetworkProtocolNumber() data.ReadToVV(&pkt.Data, packetSize) buildTCPHdr(r, tf, &pkt, gso) tf.seq = tf.seq.Add(seqnum.Size(packetSize)) diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD index ed3c8f546..abcaf4206 100644 --- a/runsc/boot/BUILD +++ b/runsc/boot/BUILD @@ -88,6 +88,7 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/link/fdbased", "//pkg/tcpip/link/loopback", + "//pkg/tcpip/link/qdisc/fifo", "//pkg/tcpip/link/sniffer", "//pkg/tcpip/network/arp", "//pkg/tcpip/network/ipv4", diff --git a/runsc/boot/config.go b/runsc/boot/config.go index 715a19112..6d6a705f8 100644 --- a/runsc/boot/config.go +++ b/runsc/boot/config.go @@ -187,6 +187,10 @@ type Config struct { // SoftwareGSO indicates that software segmentation offload is enabled. SoftwareGSO bool + // QDisc indicates the type of queuening discipline to use by default + // for non-loopback interfaces. + QDisc QueueingDiscipline + // LogPackets indicates that all network packets should be logged. LogPackets bool @@ -294,6 +298,7 @@ func (c *Config) ToFlags() []string { "--gso=" + strconv.FormatBool(c.HardwareGSO), "--software-gso=" + strconv.FormatBool(c.SoftwareGSO), "--overlayfs-stale-read=" + strconv.FormatBool(c.OverlayfsStaleRead), + "--qdisc=" + c.QDisc.String(), } if c.CPUNumFromQuota { f = append(f, "--cpu-num-from-quota") diff --git a/runsc/boot/network.go b/runsc/boot/network.go index bee6ee336..0af30456e 100644 --- a/runsc/boot/network.go +++ b/runsc/boot/network.go @@ -17,6 +17,7 @@ package boot import ( "fmt" "net" + "runtime" "strings" "syscall" @@ -24,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/link/fdbased" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" + "gvisor.dev/gvisor/pkg/tcpip/link/qdisc/fifo" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" @@ -75,6 +77,44 @@ type DefaultRoute struct { Name string } +// QueueingDiscipline is used to specify the kind of Queueing Discipline to +// apply for a give FDBasedLink. +type QueueingDiscipline int + +const ( + // QDiscNone disables any queueing for the underlying FD. + QDiscNone QueueingDiscipline = iota + + // QDiscFIFO applies a simple fifo based queue to the underlying + // FD. + QDiscFIFO +) + +// MakeQueueingDiscipline if possible the equivalent QueuingDiscipline for s +// else returns an error. +func MakeQueueingDiscipline(s string) (QueueingDiscipline, error) { + switch s { + case "none": + return QDiscNone, nil + case "fifo": + return QDiscFIFO, nil + default: + return 0, fmt.Errorf("unsupported qdisc specified: %q", s) + } +} + +// String implements fmt.Stringer. +func (q QueueingDiscipline) String() string { + switch q { + case QDiscNone: + return "none" + case QDiscFIFO: + return "fifo" + default: + panic(fmt.Sprintf("Invalid queueing discipline: %d", q)) + } +} + // FDBasedLink configures an fd-based link. type FDBasedLink struct { Name string @@ -84,6 +124,7 @@ type FDBasedLink struct { GSOMaxSize uint32 SoftwareGSOEnabled bool LinkAddress net.HardwareAddr + QDisc QueueingDiscipline // NumChannels controls how many underlying FD's are to be used to // create this endpoint. @@ -185,6 +226,8 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct } mac := tcpip.LinkAddress(link.LinkAddress) + log.Infof("gso max size is: %d", link.GSOMaxSize) + linkEP, err := fdbased.New(&fdbased.Options{ FDs: FDs, MTU: uint32(link.MTU), @@ -199,6 +242,13 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct return err } + switch link.QDisc { + case QDiscNone: + case QDiscFIFO: + log.Infof("Enabling FIFO QDisc on %q", link.Name) + linkEP = fifo.New(linkEP, runtime.GOMAXPROCS(0), 1000) + } + log.Infof("Enabling interface %q with id %d on addresses %+v (%v) w/ %d channels", link.Name, nicID, link.Addresses, mac, link.NumChannels) if err := n.createNICWithAddrs(nicID, link.Name, linkEP, link.Addresses); err != nil { return err diff --git a/runsc/main.go b/runsc/main.go index 8e594c58e..0216e9481 100644 --- a/runsc/main.go +++ b/runsc/main.go @@ -72,6 +72,7 @@ var ( network = flag.String("network", "sandbox", "specifies which network to use: sandbox (default), host, none. Using network inside the sandbox is more secure because it's isolated from the host network.") hardwareGSO = flag.Bool("gso", true, "enable hardware segmentation offload if it is supported by a network device.") softwareGSO = flag.Bool("software-gso", true, "enable software segmentation offload when hardware ofload can't be enabled.") + qDisc = flag.String("qdisc", "none", "specifies which queueing discipline to apply by default to the non loopback nics used by the sandbox.") fileAccess = flag.String("file-access", "exclusive", "specifies which filesystem to use for the root mount: exclusive (default), shared. Volume mounts are always shared.") fsGoferHostUDS = flag.Bool("fsgofer-host-uds", false, "allow the gofer to mount Unix Domain Sockets.") overlay = flag.Bool("overlay", false, "wrap filesystem mounts with writable overlay. All modifications are stored in memory inside the sandbox.") @@ -198,6 +199,11 @@ func main() { cmd.Fatalf("%v", err) } + queueingDiscipline, err := boot.MakeQueueingDiscipline(*qDisc) + if err != nil { + cmd.Fatalf("%s", err) + } + // Sets the reference leak check mode. Also set it in config below to // propagate it to child processes. refs.SetLeakMode(refsLeakMode) @@ -232,7 +238,7 @@ func main() { OverlayfsStaleRead: *overlayfsStaleRead, CPUNumFromQuota: *cpuNumFromQuota, VFS2: *vfs2Enabled, - + QDisc: queueingDiscipline, TestOnlyAllowRunAsCurrentUserWithoutChroot: *testOnlyAllowRunAsCurrentUserWithoutChroot, TestOnlyTestNameEnv: *testOnlyTestNameEnv, } diff --git a/runsc/sandbox/network.go b/runsc/sandbox/network.go index bc093fba5..209bfdb20 100644 --- a/runsc/sandbox/network.go +++ b/runsc/sandbox/network.go @@ -62,7 +62,7 @@ func setupNetwork(conn *urpc.Client, pid int, spec *specs.Spec, conf *boot.Confi // Build the path to the net namespace of the sandbox process. // This is what we will copy. nsPath := filepath.Join("/proc", strconv.Itoa(pid), "ns/net") - if err := createInterfacesAndRoutesFromNS(conn, nsPath, conf.HardwareGSO, conf.SoftwareGSO, conf.NumNetworkChannels); err != nil { + if err := createInterfacesAndRoutesFromNS(conn, nsPath, conf.HardwareGSO, conf.SoftwareGSO, conf.NumNetworkChannels, conf.QDisc); err != nil { return fmt.Errorf("creating interfaces from net namespace %q: %v", nsPath, err) } case boot.NetworkHost: @@ -115,7 +115,7 @@ func isRootNS() (bool, error) { // createInterfacesAndRoutesFromNS scrapes the interface and routes from the // net namespace with the given path, creates them in the sandbox, and removes // them from the host. -func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, hardwareGSO bool, softwareGSO bool, numNetworkChannels int) error { +func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, hardwareGSO bool, softwareGSO bool, numNetworkChannels int, qDisc boot.QueueingDiscipline) error { // Join the network namespace that we will be copying. restore, err := joinNetNS(nsPath) if err != nil { @@ -201,6 +201,7 @@ func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, hardwareG MTU: iface.MTU, Routes: routes, NumChannels: numNetworkChannels, + QDisc: qDisc, } // Get the link for the interface. -- cgit v1.2.3 From 5e1e61fbcbe8fa3cc8b104fadb8cdef3ad29c31f Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Fri, 1 May 2020 16:08:26 -0700 Subject: Automated rollback of changelist 308674219 PiperOrigin-RevId: 309491861 --- pkg/sentry/socket/netfilter/tcp_matcher.go | 5 +- pkg/sentry/socket/netfilter/udp_matcher.go | 5 +- pkg/tcpip/buffer/view.go | 55 ++++++++++---- pkg/tcpip/buffer/view_test.go | 113 +++++++++++++++++++++++++++++ pkg/tcpip/link/fdbased/endpoint.go | 3 + pkg/tcpip/link/loopback/loopback.go | 10 +-- pkg/tcpip/link/sharedmem/sharedmem_test.go | 2 +- pkg/tcpip/link/sniffer/sniffer.go | 65 +++++++++++++---- pkg/tcpip/network/arp/arp.go | 5 +- pkg/tcpip/network/ipv4/icmp.go | 20 +++-- pkg/tcpip/network/ipv4/ipv4.go | 12 ++- pkg/tcpip/network/ipv6/icmp.go | 74 ++++++++++++------- pkg/tcpip/network/ipv6/icmp_test.go | 3 +- pkg/tcpip/network/ipv6/ipv6.go | 6 +- pkg/tcpip/stack/forwarder_test.go | 13 ++-- pkg/tcpip/stack/iptables.go | 22 +++++- pkg/tcpip/stack/iptables_targets.go | 23 ++++-- pkg/tcpip/stack/nic.go | 34 +++------ pkg/tcpip/stack/packet_buffer.go | 8 +- pkg/tcpip/stack/stack_test.go | 10 ++- pkg/tcpip/stack/transport_test.go | 5 +- pkg/tcpip/transport/icmp/endpoint.go | 8 +- pkg/tcpip/transport/tcp/segment.go | 29 +++++--- pkg/tcpip/transport/tcp/tcp_test.go | 4 +- pkg/tcpip/transport/udp/endpoint.go | 6 +- pkg/tcpip/transport/udp/protocol.go | 9 ++- 26 files changed, 402 insertions(+), 147 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index ff1cfd8f6..55c0f04f3 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -121,12 +121,13 @@ func (tm *TCPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa tcpHeader = header.TCP(pkt.TransportHeader) } else { // The TCP header hasn't been parsed yet. We have to do it here. - if len(pkt.Data.First()) < header.TCPMinimumSize { + hdr, ok := pkt.Data.PullUp(header.TCPMinimumSize) + if !ok { // There's no valid TCP header here, so we hotdrop the // packet. return false, true } - tcpHeader = header.TCP(pkt.Data.First()) + tcpHeader = header.TCP(hdr) } // Check whether the source and destination ports are within the diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index 3359418c1..04d03d494 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -120,12 +120,13 @@ func (um *UDPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa udpHeader = header.UDP(pkt.TransportHeader) } else { // The UDP header hasn't been parsed yet. We have to do it here. - if len(pkt.Data.First()) < header.UDPMinimumSize { + hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize) + if !ok { // There's no valid UDP header here, so we hotdrop the // packet. return false, true } - udpHeader = header.UDP(pkt.Data.First()) + udpHeader = header.UDP(hdr) } // Check whether the source and destination ports are within the diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index 8ec5d5d5c..f01217c91 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -77,7 +77,8 @@ func NewVectorisedView(size int, views []View) VectorisedView { return VectorisedView{views: views, size: size} } -// TrimFront removes the first "count" bytes of the vectorised view. +// TrimFront removes the first "count" bytes of the vectorised view. It panics +// if count > vv.Size(). func (vv *VectorisedView) TrimFront(count int) { for count > 0 && len(vv.views) > 0 { if count < len(vv.views[0]) { @@ -86,7 +87,7 @@ func (vv *VectorisedView) TrimFront(count int) { return } count -= len(vv.views[0]) - vv.RemoveFirst() + vv.removeFirst() } } @@ -104,7 +105,7 @@ func (vv *VectorisedView) Read(v View) (copied int, err error) { count -= len(vv.views[0]) copy(v[copied:], vv.views[0]) copied += len(vv.views[0]) - vv.RemoveFirst() + vv.removeFirst() } if copied == 0 { return 0, io.EOF @@ -126,7 +127,7 @@ func (vv *VectorisedView) ReadToVV(dstVV *VectorisedView, count int) (copied int count -= len(vv.views[0]) dstVV.AppendView(vv.views[0]) copied += len(vv.views[0]) - vv.RemoveFirst() + vv.removeFirst() } return copied } @@ -162,22 +163,37 @@ func (vv *VectorisedView) Clone(buffer []View) VectorisedView { return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size} } -// First returns the first view of the vectorised view. -func (vv *VectorisedView) First() View { +// PullUp returns the first "count" bytes of the vectorised view. If those +// bytes aren't already contiguous inside the vectorised view, PullUp will +// reallocate as needed to make them contiguous. PullUp fails and returns false +// when count > vv.Size(). +func (vv *VectorisedView) PullUp(count int) (View, bool) { if len(vv.views) == 0 { - return nil + return nil, count == 0 + } + if count <= len(vv.views[0]) { + return vv.views[0][:count], true + } + if count > vv.size { + return nil, false } - return vv.views[0] -} -// RemoveFirst removes the first view of the vectorised view. -func (vv *VectorisedView) RemoveFirst() { - if len(vv.views) == 0 { - return + newFirst := NewView(count) + i := 0 + for offset := 0; offset < count; i++ { + copy(newFirst[offset:], vv.views[i]) + if count-offset < len(vv.views[i]) { + vv.views[i].TrimFront(count - offset) + break + } + offset += len(vv.views[i]) + vv.views[i] = nil } - vv.size -= len(vv.views[0]) - vv.views[0] = nil - vv.views = vv.views[1:] + // We're guaranteed that i > 0, since count is too large for the first + // view. + vv.views[i-1] = newFirst + vv.views = vv.views[i-1:] + return newFirst, true } // Size returns the size in bytes of the entire content stored in the vectorised view. @@ -225,3 +241,10 @@ func (vv *VectorisedView) Readers() []bytes.Reader { } return readers } + +// removeFirst panics when len(vv.views) < 1. +func (vv *VectorisedView) removeFirst() { + vv.size -= len(vv.views[0]) + vv.views[0] = nil + vv.views = vv.views[1:] +} diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go index 106e1994c..c56795c7b 100644 --- a/pkg/tcpip/buffer/view_test.go +++ b/pkg/tcpip/buffer/view_test.go @@ -16,6 +16,7 @@ package buffer import ( + "bytes" "reflect" "testing" ) @@ -370,3 +371,115 @@ func TestVVRead(t *testing.T) { }) } } + +var pullUpTestCases = []struct { + comment string + in VectorisedView + count int + want []byte + result VectorisedView + ok bool +}{ + { + comment: "simple case", + in: vv(2, "12"), + count: 1, + want: []byte("1"), + result: vv(2, "12"), + ok: true, + }, + { + comment: "entire View", + in: vv(2, "1", "2"), + count: 1, + want: []byte("1"), + result: vv(2, "1", "2"), + ok: true, + }, + { + comment: "spanning across two Views", + in: vv(3, "1", "23"), + count: 2, + want: []byte("12"), + result: vv(3, "12", "3"), + ok: true, + }, + { + comment: "spanning across all Views", + in: vv(5, "1", "23", "45"), + count: 5, + want: []byte("12345"), + result: vv(5, "12345"), + ok: true, + }, + { + comment: "count = 0", + in: vv(1, "1"), + count: 0, + want: []byte{}, + result: vv(1, "1"), + ok: true, + }, + { + comment: "count = size", + in: vv(1, "1"), + count: 1, + want: []byte("1"), + result: vv(1, "1"), + ok: true, + }, + { + comment: "count too large", + in: vv(3, "1", "23"), + count: 4, + want: nil, + result: vv(3, "1", "23"), + ok: false, + }, + { + comment: "empty vv", + in: vv(0, ""), + count: 1, + want: nil, + result: vv(0, ""), + ok: false, + }, + { + comment: "empty vv, count = 0", + in: vv(0, ""), + count: 0, + want: nil, + result: vv(0, ""), + ok: true, + }, + { + comment: "empty views", + in: vv(3, "", "1", "", "23"), + count: 2, + want: []byte("12"), + result: vv(3, "12", "3"), + ok: true, + }, +} + +func TestPullUp(t *testing.T) { + for _, c := range pullUpTestCases { + got, ok := c.in.PullUp(c.count) + + // Is the return value right? + if ok != c.ok { + t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got an ok of %t. Want %t", + c.comment, c.count, c.in, ok, c.ok) + } + if bytes.Compare(got, View(c.want)) != 0 { + t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got %v. Want %v", + c.comment, c.count, c.in, got, c.want) + } + + // Is the underlying structure right? + if !reflect.DeepEqual(c.in, c.result) { + t.Errorf("Test %q failed when calling PullUp(%d). Got vv with structure %v. Wanted %v", + c.comment, c.count, c.in, c.result) + } + } +} diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index 53a9712c6..affa1bbdf 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -436,6 +436,9 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne if pkt.Data.Size() == 0 { return rawfile.NonBlockingWrite(fd, pkt.Header.View()) } + if pkt.Header.UsedLength() == 0 { + return rawfile.NonBlockingWrite(fd, pkt.Data.ToView()) + } return rawfile.NonBlockingWrite3(fd, pkt.Header.View(), pkt.Data.ToView(), nil) } diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 1e2255bfa..073c84ef9 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -98,13 +98,13 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - // Reject the packet if it's shorter than an ethernet header. - if vv.Size() < header.EthernetMinimumSize { + // There should be an ethernet header at the beginning of vv. + hdr, ok := vv.PullUp(header.EthernetMinimumSize) + if !ok { + // Reject the packet if it's shorter than an ethernet header. return tcpip.ErrBadAddress } - - // There should be an ethernet header at the beginning of vv. - linkHeader := header.Ethernet(vv.First()[:header.EthernetMinimumSize]) + linkHeader := header.Ethernet(hdr) vv.TrimFront(len(linkHeader)) e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), stack.PacketBuffer{ Data: vv, diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 27ea3f531..33f640b85 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -674,7 +674,7 @@ func TestSimpleReceive(t *testing.T) { // Wait for packet to be received, then check it. c.waitForPackets(1, time.After(5*time.Second), "Timeout waiting for packet") c.mu.Lock() - rcvd := []byte(c.packets[0].vv.First()) + rcvd := []byte(c.packets[0].vv.ToView()) c.packets = c.packets[:0] c.mu.Unlock() diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index be2537a82..0799c8f4d 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -171,11 +171,7 @@ func (e *endpoint) GSOMaxSize() uint32 { func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { writer := e.writer if writer == nil && atomic.LoadUint32(&LogPackets) == 1 { - first := pkt.Header.View() - if len(first) == 0 { - first = pkt.Data.First() - } - logPacket(prefix, protocol, first, gso) + logPacket(prefix, protocol, pkt, gso) } if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 { totalLength := pkt.Header.UsedLength() + pkt.Data.Size() @@ -238,7 +234,7 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { // Wait implements stack.LinkEndpoint.Wait. func (e *endpoint) Wait() { e.lower.Wait() } -func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View, gso *stack.GSO) { +func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) { // Figure out the network layer info. var transProto uint8 src := tcpip.Address("unknown") @@ -247,28 +243,49 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie size := uint16(0) var fragmentOffset uint16 var moreFragments bool + + // Create a clone of pkt, including any headers if present. Avoid allocating + // backing memory for the clone. + views := [8]buffer.View{} + vv := buffer.NewVectorisedView(0, views[:0]) + vv.AppendView(pkt.Header.View()) + vv.Append(pkt.Data) + switch protocol { case header.IPv4ProtocolNumber: - ipv4 := header.IPv4(b) + hdr, ok := vv.PullUp(header.IPv4MinimumSize) + if !ok { + return + } + ipv4 := header.IPv4(hdr) fragmentOffset = ipv4.FragmentOffset() moreFragments = ipv4.Flags()&header.IPv4FlagMoreFragments == header.IPv4FlagMoreFragments src = ipv4.SourceAddress() dst = ipv4.DestinationAddress() transProto = ipv4.Protocol() size = ipv4.TotalLength() - uint16(ipv4.HeaderLength()) - b = b[ipv4.HeaderLength():] + vv.TrimFront(int(ipv4.HeaderLength())) id = int(ipv4.ID()) case header.IPv6ProtocolNumber: - ipv6 := header.IPv6(b) + hdr, ok := vv.PullUp(header.IPv6MinimumSize) + if !ok { + return + } + ipv6 := header.IPv6(hdr) src = ipv6.SourceAddress() dst = ipv6.DestinationAddress() transProto = ipv6.NextHeader() size = ipv6.PayloadLength() - b = b[header.IPv6MinimumSize:] + vv.TrimFront(header.IPv6MinimumSize) case header.ARPProtocolNumber: - arp := header.ARP(b) + hdr, ok := vv.PullUp(header.ARPSize) + if !ok { + return + } + vv.TrimFront(header.ARPSize) + arp := header.ARP(hdr) log.Infof( "%s arp %v (%v) -> %v (%v) valid:%v", prefix, @@ -284,7 +301,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie // We aren't guaranteed to have a transport header - it's possible for // writes via raw endpoints to contain only network headers. - if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && len(b) < minSize { + if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && vv.Size() < minSize { log.Infof("%s %v -> %v transport protocol: %d, but no transport header found (possible raw packet)", prefix, src, dst, transProto) return } @@ -297,7 +314,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie switch tcpip.TransportProtocolNumber(transProto) { case header.ICMPv4ProtocolNumber: transName = "icmp" - icmp := header.ICMPv4(b) + hdr, ok := vv.PullUp(header.ICMPv4MinimumSize) + if !ok { + break + } + icmp := header.ICMPv4(hdr) icmpType := "unknown" if fragmentOffset == 0 { switch icmp.Type() { @@ -330,7 +351,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie case header.ICMPv6ProtocolNumber: transName = "icmp" - icmp := header.ICMPv6(b) + hdr, ok := vv.PullUp(header.ICMPv6MinimumSize) + if !ok { + break + } + icmp := header.ICMPv6(hdr) icmpType := "unknown" switch icmp.Type() { case header.ICMPv6DstUnreachable: @@ -361,7 +386,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie case header.UDPProtocolNumber: transName = "udp" - udp := header.UDP(b) + hdr, ok := vv.PullUp(header.UDPMinimumSize) + if !ok { + break + } + udp := header.UDP(hdr) if fragmentOffset == 0 && len(udp) >= header.UDPMinimumSize { srcPort = udp.SourcePort() dstPort = udp.DestinationPort() @@ -371,7 +400,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie case header.TCPProtocolNumber: transName = "tcp" - tcp := header.TCP(b) + hdr, ok := vv.PullUp(header.TCPMinimumSize) + if !ok { + break + } + tcp := header.TCP(hdr) if fragmentOffset == 0 && len(tcp) >= header.TCPMinimumSize { offset := int(tcp.DataOffset()) if offset < header.TCPMinimumSize { diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 9f47b4ff2..9d0797af7 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -99,7 +99,10 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf } func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - v := pkt.Data.First() + v, ok := pkt.Data.PullUp(header.ARPSize) + if !ok { + return + } h := header.ARP(v) if !h.IsValid() { return diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index c4bf1ba5c..4cbefe5ab 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -25,7 +25,11 @@ import ( // used to find out which transport endpoint must be notified about the ICMP // packet. func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { - h := header.IPv4(pkt.Data.First()) + h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return + } + hdr := header.IPv4(h) // We don't use IsValid() here because ICMP only requires that the IP // header plus 8 bytes of the transport header be included. So it's @@ -34,12 +38,12 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // // Drop packet if it doesn't have the basic IPv4 header or if the // original source address doesn't match the endpoint's address. - if len(h) < header.IPv4MinimumSize || h.SourceAddress() != e.id.LocalAddress { + if hdr.SourceAddress() != e.id.LocalAddress { return } - hlen := int(h.HeaderLength()) - if pkt.Data.Size() < hlen || h.FragmentOffset() != 0 { + hlen := int(hdr.HeaderLength()) + if pkt.Data.Size() < hlen || hdr.FragmentOffset() != 0 { // We won't be able to handle this if it doesn't contain the // full IPv4 header, or if it's a fragment not at offset 0 // (because it won't have the transport header). @@ -48,15 +52,15 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // Skip the ip header, then deliver control message. pkt.Data.TrimFront(hlen) - p := h.TransportProtocol() - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + p := hdr.TransportProtocol() + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) { stats := r.Stats() received := stats.ICMP.V4PacketsReceived - v := pkt.Data.First() - if len(v) < header.ICMPv4MinimumSize { + v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) + if !ok { received.Invalid.Increment() return } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index a9dec0c0e..1d61fddad 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -333,7 +333,11 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required // checks. - ip := header.IPv4(pkt.Data.First()) + h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return tcpip.ErrInvalidOptionValue + } + ip := header.IPv4(h) if !ip.IsValid(pkt.Data.Size()) { return tcpip.ErrInvalidOptionValue } @@ -383,7 +387,11 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - headerView := pkt.Data.First() + headerView, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + r.Stats().IP.MalformedPacketsReceived.Increment() + return + } h := header.IPv4(headerView) if !h.IsValid(pkt.Data.Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index b68983d10..bdf3a0d25 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -28,7 +28,11 @@ import ( // used to find out which transport endpoint must be notified about the ICMP // packet. func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { - h := header.IPv6(pkt.Data.First()) + h, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + if !ok { + return + } + hdr := header.IPv6(h) // We don't use IsValid() here because ICMP only requires that up to // 1280 bytes of the original packet be included. So it's likely that it @@ -36,17 +40,21 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // // Drop packet if it doesn't have the basic IPv6 header or if the // original source address doesn't match the endpoint's address. - if len(h) < header.IPv6MinimumSize || h.SourceAddress() != e.id.LocalAddress { + if hdr.SourceAddress() != e.id.LocalAddress { return } // Skip the IP header, then handle the fragmentation header if there // is one. pkt.Data.TrimFront(header.IPv6MinimumSize) - p := h.TransportProtocol() + p := hdr.TransportProtocol() if p == header.IPv6FragmentHeader { - f := header.IPv6Fragment(pkt.Data.First()) - if !f.IsValid() || f.FragmentOffset() != 0 { + f, ok := pkt.Data.PullUp(header.IPv6FragmentHeaderSize) + if !ok { + return + } + fragHdr := header.IPv6Fragment(f) + if !fragHdr.IsValid() || fragHdr.FragmentOffset() != 0 { // We can't handle fragments that aren't at offset 0 // because they don't have the transport headers. return @@ -55,19 +63,19 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. // Skip fragmentation header and find out the actual protocol // number. pkt.Data.TrimFront(header.IPv6FragmentHeaderSize) - p = f.TransportProtocol() + p = fragHdr.TransportProtocol() } // Deliver the control packet to the transport endpoint. - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.PacketBuffer, hasFragmentHeader bool) { stats := r.Stats().ICMP sent := stats.V6PacketsSent received := stats.V6PacketsReceived - v := pkt.Data.First() - if len(v) < header.ICMPv6MinimumSize { + v, ok := pkt.Data.PullUp(header.ICMPv6HeaderSize) + if !ok { received.Invalid.Increment() return } @@ -76,11 +84,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P // Validate ICMPv6 checksum before processing the packet. // - // Only the first view in vv is accounted for by h. To account for the - // rest of vv, a shallow copy is made and the first view is removed. // This copy is used as extra payload during the checksum calculation. payload := pkt.Data.Clone(nil) - payload.RemoveFirst() + payload.TrimFront(len(h)) if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want { received.Invalid.Increment() return @@ -101,34 +107,40 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P switch h.Type() { case header.ICMPv6PacketTooBig: received.PacketTooBig.Increment() - if len(v) < header.ICMPv6PacketTooBigMinimumSize { + hdr, ok := pkt.Data.PullUp(header.ICMPv6PacketTooBigMinimumSize) + if !ok { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize) - mtu := h.MTU() + mtu := header.ICMPv6(hdr).MTU() e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) case header.ICMPv6DstUnreachable: received.DstUnreachable.Increment() - if len(v) < header.ICMPv6DstUnreachableMinimumSize { + hdr, ok := pkt.Data.PullUp(header.ICMPv6DstUnreachableMinimumSize) + if !ok { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize) - switch h.Code() { + switch header.ICMPv6(hdr).Code() { case header.ICMPv6PortUnreachable: e.handleControl(stack.ControlPortUnreachable, 0, pkt) } case header.ICMPv6NeighborSolicit: received.NeighborSolicit.Increment() - if len(v) < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() { + if pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() { received.Invalid.Increment() return } - ns := header.NDPNeighborSolicit(h.NDPPayload()) + // The remainder of payload must be only the neighbor solicitation, so + // payload.ToView() always returns the solicitation. Per RFC 6980 section 5, + // NDP messages cannot be fragmented. Also note that in the common case NDP + // datagrams are very small and ToView() will not incur allocations. + ns := header.NDPNeighborSolicit(payload.ToView()) it, err := ns.Options().Iter(true) if err != nil { // If we have a malformed NDP NS option, drop the packet. @@ -286,12 +298,16 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6NeighborAdvert: received.NeighborAdvert.Increment() - if len(v) < header.ICMPv6NeighborAdvertSize || !isNDPValid() { + if pkt.Data.Size() < header.ICMPv6NeighborAdvertSize || !isNDPValid() { received.Invalid.Increment() return } - na := header.NDPNeighborAdvert(h.NDPPayload()) + // The remainder of payload must be only the neighbor advertisement, so + // payload.ToView() always returns the advertisement. Per RFC 6980 section + // 5, NDP messages cannot be fragmented. Also note that in the common case + // NDP datagrams are very small and ToView() will not incur allocations. + na := header.NDPNeighborAdvert(payload.ToView()) it, err := na.Options().Iter(true) if err != nil { // If we have a malformed NDP NA option, drop the packet. @@ -363,14 +379,15 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6EchoRequest: received.EchoRequest.Increment() - if len(v) < header.ICMPv6EchoMinimumSize { + icmpHdr, ok := pkt.Data.PullUp(header.ICMPv6EchoMinimumSize) + if !ok { received.Invalid.Increment() return } pkt.Data.TrimFront(header.ICMPv6EchoMinimumSize) hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize) packet := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) - copy(packet, h) + copy(packet, icmpHdr) packet.SetType(header.ICMPv6EchoReply) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ @@ -384,7 +401,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6EchoReply: received.EchoReply.Increment() - if len(v) < header.ICMPv6EchoMinimumSize { + if pkt.Data.Size() < header.ICMPv6EchoMinimumSize { received.Invalid.Increment() return } @@ -406,8 +423,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P case header.ICMPv6RouterAdvert: received.RouterAdvert.Increment() - p := h.NDPPayload() - if len(p) < header.NDPRAMinimumSize || !isNDPValid() { + // Is the NDP payload of sufficient size to hold a Router + // Advertisement? + if pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize || !isNDPValid() { received.Invalid.Increment() return } @@ -425,7 +443,11 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P return } - ra := header.NDPRouterAdvert(p) + // The remainder of payload must be only the router advertisement, so + // payload.ToView() always returns the advertisement. Per RFC 6980 section + // 5, NDP messages cannot be fragmented. Also note that in the common case + // NDP datagrams are very small and ToView() will not incur allocations. + ra := header.NDPRouterAdvert(payload.ToView()) opts := ra.Options() // Are options valid as per the wire format? diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index bd099a7f8..d412ff688 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -166,7 +166,8 @@ func TestICMPCounts(t *testing.T) { }, { typ: header.ICMPv6NeighborSolicit, - size: header.ICMPv6NeighborSolicitMinimumSize}, + size: header.ICMPv6NeighborSolicitMinimumSize, + }, { typ: header.ICMPv6NeighborAdvert, size: header.ICMPv6NeighborAdvertMinimumSize, diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 82928fb66..daf1fcbc6 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -171,7 +171,11 @@ func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffe // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { - headerView := pkt.Data.First() + headerView, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + if !ok { + r.Stats().IP.MalformedPacketsReceived.Increment() + return + } h := header.IPv6(headerView) if !h.IsValid(pkt.Data.Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index 9bc97b84e..8084d50bc 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -70,7 +70,10 @@ func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID { func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt PacketBuffer) { // Consume the network header. - b := pkt.Data.First() + b, ok := pkt.Data.PullUp(fwdTestNetHeaderLen) + if !ok { + return + } pkt.Data.TrimFront(fwdTestNetHeaderLen) // Dispatch the packet to the transport protocol. @@ -477,7 +480,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Header.View() + b := p.Pkt.Data.ToView() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -521,7 +524,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Header.View() + b := p.Pkt.Data.ToView() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -568,7 +571,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Header.View() + b := p.Pkt.Data.ToView() if b[0] != 3 { t.Fatalf("got b[0] = %d, want = 3", b[0]) } @@ -623,7 +626,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // The first 5 packets (address 3 to 7) should not be forwarded // because their address resolutions are interrupted. - b := p.Pkt.Header.View() + b := p.Pkt.Data.ToView() if b[0] < 8 { t.Fatalf("got b[0] = %d, want b[0] >= 8", b[0]) } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 6c0a4b24d..6b91159d4 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -212,6 +212,11 @@ func (it *IPTables) Check(hook Hook, pkt PacketBuffer) bool { // CheckPackets runs pkts through the rules for hook and returns a map of packets that // should not go forward. // +// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// +// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a +// precondition. +// // NOTE: unlike the Check API the returned map contains packets that should be // dropped. func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*PacketBuffer]struct{}) { @@ -226,7 +231,9 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*Pa return drop } -// Precondition: pkt.NetworkHeader is set. +// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a +// precondition. func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. @@ -271,14 +278,21 @@ func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx return chainDrop } -// Precondition: pk.NetworkHeader is set. +// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a +// precondition. func (it *IPTables) checkRule(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) (RuleVerdict, int) { rule := table.Rules[ruleIdx] // If pkt.NetworkHeader hasn't been set yet, it will be contained in - // pkt.Data.First(). + // pkt.Data. if pkt.NetworkHeader == nil { - pkt.NetworkHeader = pkt.Data.First() + var ok bool + pkt.NetworkHeader, ok = pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + // Precondition has been violated. + panic(fmt.Sprintf("iptables checks require IPv4 headers of at least %d bytes", header.IPv4MinimumSize)) + } } // Check whether the packet matches the IP header filter. diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 7b4543caf..8be61f4b1 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -96,9 +96,12 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { newPkt := pkt.Clone() // Set network header. - headerView := newPkt.Data.First() + headerView, ok := newPkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return RuleDrop, 0 + } netHeader := header.IPv4(headerView) - newPkt.NetworkHeader = headerView[:header.IPv4MinimumSize] + newPkt.NetworkHeader = headerView hlen := int(netHeader.HeaderLength()) tlen := int(netHeader.TotalLength()) @@ -117,10 +120,14 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { if newPkt.TransportHeader != nil { udpHeader = header.UDP(newPkt.TransportHeader) } else { - if len(pkt.Data.First()) < header.UDPMinimumSize { + if pkt.Data.Size() < header.UDPMinimumSize { + return RuleDrop, 0 + } + hdr, ok := newPkt.Data.PullUp(header.UDPMinimumSize) + if !ok { return RuleDrop, 0 } - udpHeader = header.UDP(newPkt.Data.First()) + udpHeader = header.UDP(hdr) } udpHeader.SetDestinationPort(rt.MinPort) case header.TCPProtocolNumber: @@ -128,10 +135,14 @@ func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { if newPkt.TransportHeader != nil { tcpHeader = header.TCP(newPkt.TransportHeader) } else { - if len(pkt.Data.First()) < header.TCPMinimumSize { + if pkt.Data.Size() < header.TCPMinimumSize { return RuleDrop, 0 } - tcpHeader = header.TCP(newPkt.TransportHeader) + hdr, ok := newPkt.Data.PullUp(header.TCPMinimumSize) + if !ok { + return RuleDrop, 0 + } + tcpHeader = header.TCP(hdr) } // TODO(gvisor.dev/issue/170): Need to recompute checksum // and implement nat connection tracking to support TCP. diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 440970a21..7b54919bb 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1212,12 +1212,12 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link n.stack.stats.IP.PacketsReceived.Increment() } - if len(pkt.Data.First()) < netProto.MinimumPacketSize() { + netHeader, ok := pkt.Data.PullUp(netProto.MinimumPacketSize()) + if !ok { n.stack.stats.MalformedRcvdPackets.Increment() return } - - src, dst := netProto.ParseAddresses(pkt.Data.First()) + src, dst := netProto.ParseAddresses(netHeader) if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil { // The source address is one of our own, so we never should have gotten a @@ -1298,22 +1298,8 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { // TODO(b/143425874) Decrease the TTL field in forwarded packets. - - firstData := pkt.Data.First() - pkt.Data.RemoveFirst() - - if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen == 0 { - pkt.Header = buffer.NewPrependableFromView(firstData) - } else { - firstDataLen := len(firstData) - - // pkt.Header should have enough capacity to hold n.linkEP's headers. - pkt.Header = buffer.NewPrependable(firstDataLen + linkHeaderLen) - - // TODO(b/151227689): avoid copying the packet when forwarding - if n := copy(pkt.Header.Prepend(firstDataLen), firstData); n != firstDataLen { - panic(fmt.Sprintf("copied %d bytes, expected %d", n, firstDataLen)) - } + if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen != 0 { + pkt.Header = buffer.NewPrependable(linkHeaderLen) } if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil { @@ -1341,12 +1327,13 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // validly formed. n.stack.demux.deliverRawPacket(r, protocol, pkt) - if len(pkt.Data.First()) < transProto.MinimumPacketSize() { + transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize()) + if !ok { n.stack.stats.MalformedRcvdPackets.Increment() return } - srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) + srcPort, dstPort, err := transProto.ParsePorts(transHeader) if err != nil { n.stack.stats.MalformedRcvdPackets.Increment() return @@ -1384,11 +1371,12 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp // ICMPv4 only guarantees that 8 bytes of the transport protocol will // be present in the payload. We know that the ports are within the // first 8 bytes for all known transport protocols. - if len(pkt.Data.First()) < 8 { + transHeader, ok := pkt.Data.PullUp(8) + if !ok { return } - srcPort, dstPort, err := transProto.ParsePorts(pkt.Data.First()) + srcPort, dstPort, err := transProto.ParsePorts(transHeader) if err != nil { return } diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 06d312207..9ff80ab24 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -37,7 +37,13 @@ type PacketBuffer struct { Data buffer.VectorisedView // Header holds the headers of outbound packets. As a packet is passed - // down the stack, each layer adds to Header. + // down the stack, each layer adds to Header. Note that forwarded + // packets don't populate Headers on their way out -- their headers and + // payload are never parsed out and remain in Data. + // + // TODO(gvisor.dev/issue/170): Forwarded packets don't currently + // populate Header, but should. This will be doable once early parsing + // (https://github.com/google/gvisor/pull/1995) is supported. Header buffer.Prependable // These fields are used by both inbound and outbound packets. They diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 3f4e08434..1a2cf007c 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -95,16 +95,18 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffe f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ // Consume the network header. - b := pkt.Data.First() + b, ok := pkt.Data.PullUp(fakeNetHeaderLen) + if !ok { + return + } pkt.Data.TrimFront(fakeNetHeaderLen) // Handle control packets. if b[2] == uint8(fakeControlProtocol) { - nb := pkt.Data.First() - if len(nb) < fakeNetHeaderLen { + nb, ok := pkt.Data.PullUp(fakeNetHeaderLen) + if !ok { return } - pkt.Data.TrimFront(fakeNetHeaderLen) f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, pkt) return diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 3084e6593..a611e44ab 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -642,10 +642,11 @@ func TestTransportForwarding(t *testing.T) { t.Fatal("Response packet not forwarded") } - if dst := p.Pkt.Header.View()[0]; dst != 3 { + hdrs := p.Pkt.Data.ToView() + if dst := hdrs[0]; dst != 3 { t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) } - if src := p.Pkt.Header.View()[1]; src != 1 { + if src := hdrs[1]; src != 1 { t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) } } diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index feef8dca0..b1d820372 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -747,15 +747,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // Only accept echo replies. switch e.NetProto { case header.IPv4ProtocolNumber: - h := header.ICMPv4(pkt.Data.First()) - if h.Type() != header.ICMPv4EchoReply { + h, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) + if !ok || header.ICMPv4(h).Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } case header.IPv6ProtocolNumber: - h := header.ICMPv6(pkt.Data.First()) - if h.Type() != header.ICMPv6EchoReply { + h, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize) + if !ok || header.ICMPv6(h).Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 40461fd31..7712ce652 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -144,7 +144,11 @@ func (s *segment) logicalLen() seqnum.Size { // TCP checksum and stores the checksum and result of checksum verification in // the csum and csumValid fields of the segment. func (s *segment) parse() bool { - h := header.TCP(s.data.First()) + h, ok := s.data.PullUp(header.TCPMinimumSize) + if !ok { + return false + } + hdr := header.TCP(h) // h is the header followed by the payload. We check that the offset to // the data respects the following constraints: @@ -156,12 +160,16 @@ func (s *segment) parse() bool { // N.B. The segment has already been validated as having at least the // minimum TCP size before reaching here, so it's safe to read the // fields. - offset := int(h.DataOffset()) - if offset < header.TCPMinimumSize || offset > len(h) { + offset := int(hdr.DataOffset()) + if offset < header.TCPMinimumSize { + return false + } + hdrWithOpts, ok := s.data.PullUp(offset) + if !ok { return false } - s.options = []byte(h[header.TCPMinimumSize:offset]) + s.options = []byte(hdrWithOpts[header.TCPMinimumSize:]) s.parsedOptions = header.ParseTCPOptions(s.options) // Query the link capabilities to decide if checksum validation is @@ -173,18 +181,19 @@ func (s *segment) parse() bool { s.data.TrimFront(offset) } if verifyChecksum { - s.csum = h.Checksum() + hdr = header.TCP(hdrWithOpts) + s.csum = hdr.Checksum() xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size())) - xsum = h.CalculateChecksum(xsum) + xsum = hdr.CalculateChecksum(xsum) s.data.TrimFront(offset) xsum = header.ChecksumVV(s.data, xsum) s.csumValid = xsum == 0xffff } - s.sequenceNumber = seqnum.Value(h.SequenceNumber()) - s.ackNumber = seqnum.Value(h.AckNumber()) - s.flags = h.Flags() - s.window = seqnum.Size(h.WindowSize()) + s.sequenceNumber = seqnum.Value(hdr.SequenceNumber()) + s.ackNumber = seqnum.Value(hdr.AckNumber()) + s.flags = hdr.Flags() + s.window = seqnum.Size(hdr.WindowSize()) return true } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 7e574859b..33e2b9a09 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -3556,7 +3556,7 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { AckNum: c.IRS.Add(1), RcvWnd: 30000, }) - tcpbuf := vv.First()[header.IPv4MinimumSize:] + tcpbuf := vv.ToView()[header.IPv4MinimumSize:] tcpbuf[header.TCPDataOffset] = ((header.TCPMinimumSize - 1) / 4) << 4 c.SendSegment(vv) @@ -3583,7 +3583,7 @@ func TestReceivedIncorrectChecksumIncrement(t *testing.T) { AckNum: c.IRS.Add(1), RcvWnd: 30000, }) - tcpbuf := vv.First()[header.IPv4MinimumSize:] + tcpbuf := vv.ToView()[header.IPv4MinimumSize:] // Overwrite a byte in the payload which should cause checksum // verification to fail. tcpbuf[(tcpbuf[header.TCPDataOffset]>>4)*4] = 0x4 diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index edb54f0be..756ab913a 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1250,8 +1250,8 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // endpoint. func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) { // Get the header then trim it from the view. - hdr := header.UDP(pkt.Data.First()) - if int(hdr.Length()) > pkt.Data.Size() { + hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize) + if !ok || int(header.UDP(hdr).Length()) > pkt.Data.Size() { // Malformed packet. e.stack.Stats().UDP.MalformedPacketsReceived.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() @@ -1286,7 +1286,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk senderAddress: tcpip.FullAddress{ NIC: r.NICID(), Addr: id.RemoteAddress, - Port: hdr.SourcePort(), + Port: header.UDP(hdr).SourcePort(), }, } packet.data = pkt.Data diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 6e31a9bac..52af6de22 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -68,8 +68,13 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // that don't match any existing endpoint. func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { // Get the header then trim it from the view. - hdr := header.UDP(pkt.Data.First()) - if int(hdr.Length()) > pkt.Data.Size() { + h, ok := pkt.Data.PullUp(header.UDPMinimumSize) + if !ok { + // Malformed packet. + r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() + return true + } + if int(header.UDP(h).Length()) > pkt.Data.Size() { // Malformed packet. r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() return true -- cgit v1.2.3 From 40d6aae1220292985a85ee03248ad5781edb4c80 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Fri, 1 May 2020 16:32:15 -0700 Subject: Regenerate SLAAC address on conflicts with the NIC If the NIC already has a generated SLAAC address, regenerate a new SLAAC address until one is generated that does not conflict with the NIC's existing addresses, up to a maximum of 10 attempts. This applies to both stable and temporary SLAAC addresses. Test: stack_test.TestMixedSLAACAddrConflictRegen PiperOrigin-RevId: 309495628 --- pkg/tcpip/stack/ndp.go | 141 +++++++++++++++++++----------- pkg/tcpip/stack/ndp_test.go | 209 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 299 insertions(+), 51 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index fbc0e3f55..15343acbc 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -145,6 +145,10 @@ const ( // minRegenAdvanceDuration is the minimum duration before the deprecation // of a temporary address when a new address will be generated. minRegenAdvanceDuration = time.Duration(0) + + // maxSLAACAddrLocalRegenAttempts is the maximum number of times to attempt + // SLAAC address regenerations in response to a NIC-local conflict. + maxSLAACAddrLocalRegenAttempts = 10 ) var ( @@ -565,11 +569,18 @@ type slaacPrefixState struct { // Nonzero only when the address is not preferred forever. preferredUntil time.Time - // The endpoint for the stable address generated for a SLAAC prefix. - // - // May only be nil when a SLAAC address is being (re-)generated. Otherwise, - // must not be nil as all SLAAC prefixes must have a stable address. - stableAddrRef *referencedNetworkEndpoint + // State associated with the stable address generated for the prefix. + stableAddr struct { + // The address's endpoint. + // + // May only be nil when the address is being (re-)generated. Otherwise, + // must not be nil as all SLAAC prefixes must have a stable address. + ref *referencedNetworkEndpoint + + // The number of times an address has been generated locally where the NIC + // already had the generated address. + localGenerationFailures uint8 + } // The temporary (short-lived) addresses generated for the SLAAC prefix. tempAddrs map[tcpip.Address]tempSLAACAddrState @@ -579,7 +590,7 @@ type slaacPrefixState struct { // in the generation and DAD process at any time. That is, no two addresses // will be generated at the same time for a given SLAAC prefix. - // The number of times an address has been generated. + // The number of times an address has been generated and added to the NIC. // // Addresses may be regenerated in reseponse to a DAD conflicts. generationAttempts uint8 @@ -1125,7 +1136,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the deprecated SLAAC prefix %s", prefix)) } - ndp.deprecateSLAACAddress(state.stableAddrRef) + ndp.deprecateSLAACAddress(state.stableAddr.ref) }), invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { state, ok := ndp.slaacPrefixes[prefix] @@ -1166,7 +1177,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { } // If the address is assigned (DAD resolved), generate a temporary address. - if state.stableAddrRef.getKind() == permanent { + if state.stableAddr.ref.getKind() == permanent { // Reset the generation attempts counter as we are starting the generation // of a new address for the SLAAC prefix. ndp.generateTempSLAACAddr(prefix, &state, true /* resetGenAttempts */) @@ -1179,11 +1190,6 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { // // The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) addSLAACAddr(addr tcpip.AddressWithPrefix, configType networkEndpointConfigType, deprecated bool) *referencedNetworkEndpoint { - // If the nic already has this address, do nothing further. - if ndp.nic.hasPermanentAddrLocked(addr.Address) { - return nil - } - // Inform the integrator that we have a new SLAAC address. ndpDisp := ndp.nic.stack.ndpDisp if ndpDisp == nil { @@ -1216,7 +1222,7 @@ func (ndp *ndpState) addSLAACAddr(addr tcpip.AddressWithPrefix, configType netwo // // The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixState) bool { - if r := state.stableAddrRef; r != nil { + if r := state.stableAddr.ref; r != nil { panic(fmt.Sprintf("ndp: SLAAC prefix %s already has a permenant address %s", prefix, r.addrWithPrefix())) } @@ -1226,42 +1232,62 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt return false } + var generatedAddr tcpip.AddressWithPrefix addrBytes := []byte(prefix.ID()) - if oIID := ndp.nic.stack.opaqueIIDOpts; oIID.NICNameFromID != nil { - addrBytes = header.AppendOpaqueInterfaceIdentifier( - addrBytes[:header.IIDOffsetInIPv6Address], - prefix, - oIID.NICNameFromID(ndp.nic.ID(), ndp.nic.name), - state.generationAttempts, - oIID.SecretKey, - ) - } else if state.generationAttempts == 0 { - // Only attempt to generate an interface-specific IID if we have a valid - // link address. - // - // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by - // LinkEndpoint.LinkAddress) before reaching this point. - linkAddr := ndp.nic.linkEP.LinkAddress() - if !header.IsValidUnicastEthernetAddress(linkAddr) { + + for i := 0; ; i++ { + // If we were unable to generate an address after the maximum SLAAC address + // local regeneration attempts, do nothing further. + if i == maxSLAACAddrLocalRegenAttempts { return false } - // Generate an address within prefix from the modified EUI-64 of ndp's NIC's - // Ethernet MAC address. - header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:]) - } else { - // We have no way to regenerate an address when addresses are not generated - // with opaque IIDs. - return false - } + dadCounter := state.generationAttempts + state.stableAddr.localGenerationFailures + if oIID := ndp.nic.stack.opaqueIIDOpts; oIID.NICNameFromID != nil { + addrBytes = header.AppendOpaqueInterfaceIdentifier( + addrBytes[:header.IIDOffsetInIPv6Address], + prefix, + oIID.NICNameFromID(ndp.nic.ID(), ndp.nic.name), + dadCounter, + oIID.SecretKey, + ) + } else if dadCounter == 0 { + // Modified-EUI64 based IIDs have no way to resolve DAD conflicts, so if + // the DAD counter is non-zero, we cannot use this method. + // + // Only attempt to generate an interface-specific IID if we have a valid + // link address. + // + // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by + // LinkEndpoint.LinkAddress) before reaching this point. + linkAddr := ndp.nic.linkEP.LinkAddress() + if !header.IsValidUnicastEthernetAddress(linkAddr) { + return false + } + + // Generate an address within prefix from the modified EUI-64 of ndp's + // NIC's Ethernet MAC address. + header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:]) + } else { + // We have no way to regenerate an address in response to an address + // conflict when addresses are not generated with opaque IIDs. + return false + } + + generatedAddr = tcpip.AddressWithPrefix{ + Address: tcpip.Address(addrBytes), + PrefixLen: validPrefixLenForAutoGen, + } - generatedAddr := tcpip.AddressWithPrefix{ - Address: tcpip.Address(addrBytes), - PrefixLen: validPrefixLenForAutoGen, + if !ndp.nic.hasPermanentAddrLocked(generatedAddr.Address) { + break + } + + state.stableAddr.localGenerationFailures++ } if ref := ndp.addSLAACAddr(generatedAddr, slaac, time.Since(state.preferredUntil) >= 0 /* deprecated */); ref != nil { - state.stableAddrRef = ref + state.stableAddr.ref = ref state.generationAttempts++ return true } @@ -1315,7 +1341,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla return false } - stableAddr := prefixState.stableAddrRef.ep.ID().LocalAddress + stableAddr := prefixState.stableAddr.ref.ep.ID().LocalAddress now := time.Now() // As per RFC 4941 section 3.3 step 4, the valid lifetime of a temporary @@ -1354,7 +1380,20 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla return false } - generatedAddr := header.GenerateTempIPv6SLAACAddr(ndp.temporaryIIDHistory[:], stableAddr) + // Attempt to generate a new address that is not already assigned to the NIC. + var generatedAddr tcpip.AddressWithPrefix + for i := 0; ; i++ { + // If we were unable to generate an address after the maximum SLAAC address + // local regeneration attempts, do nothing further. + if i == maxSLAACAddrLocalRegenAttempts { + return false + } + + generatedAddr = header.GenerateTempIPv6SLAACAddr(ndp.temporaryIIDHistory[:], stableAddr) + if !ndp.nic.hasPermanentAddrLocked(generatedAddr.Address) { + break + } + } // As per RFC RFC 4941 section 3.3 step 5, we MUST NOT create a temporary // address with a zero preferred lifetime. The checks above ensure this @@ -1450,9 +1489,9 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat // If the preferred lifetime is zero, then the prefix should be deprecated. deprecated := pl == 0 if deprecated { - ndp.deprecateSLAACAddress(prefixState.stableAddrRef) + ndp.deprecateSLAACAddress(prefixState.stableAddr.ref) } else { - prefixState.stableAddrRef.deprecated = false + prefixState.stableAddr.ref.deprecated = false } // If prefix was preferred for some finite lifetime before, stop the @@ -1514,7 +1553,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat // If DAD is not yet complete on the stable address, there is no need to do // work with temporary addresses. - if prefixState.stableAddrRef.getKind() != permanent { + if prefixState.stableAddr.ref.getKind() != permanent { return } @@ -1617,7 +1656,7 @@ func (ndp *ndpState) deprecateSLAACAddress(ref *referencedNetworkEndpoint) { // // The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, state slaacPrefixState) { - if r := state.stableAddrRef; r != nil { + if r := state.stableAddr.ref; r != nil { // Since we are already invalidating the prefix, do not invalidate the // prefix when removing the address. if err := ndp.nic.removePermanentIPv6EndpointLocked(r, false /* allowSLAACInvalidation */); err != nil { @@ -1639,14 +1678,14 @@ func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPr prefix := addr.Subnet() state, ok := ndp.slaacPrefixes[prefix] - if !ok || state.stableAddrRef == nil || addr.Address != state.stableAddrRef.ep.ID().LocalAddress { + if !ok || state.stableAddr.ref == nil || addr.Address != state.stableAddr.ref.ep.ID().LocalAddress { return } if !invalidatePrefix { // If the prefix is not being invalidated, disassociate the address from the // prefix and do nothing further. - state.stableAddrRef = nil + state.stableAddr.ref = nil ndp.slaacPrefixes[prefix] = state return } @@ -1665,7 +1704,7 @@ func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaa ndp.invalidateTempSLAACAddr(state.tempAddrs, tempAddr, tempAddrState) } - state.stableAddrRef = nil + state.stableAddr.ref = nil state.deprecationTimer.StopLocked() state.invalidationTimer.StopLocked() delete(ndp.slaacPrefixes, prefix) diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 421df674f..67f012840 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -2521,6 +2521,215 @@ func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) { expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncEventTimeout) } +// TestMixedSLAACAddrConflictRegen tests SLAAC address regeneration in response +// to a mix of DAD conflicts and NIC-local conflicts. +func TestMixedSLAACAddrConflictRegen(t *testing.T) { + const ( + nicID = 1 + nicName = "nic" + lifetimeSeconds = 9999 + // From stack.maxSLAACAddrLocalRegenAttempts + maxSLAACAddrLocalRegenAttempts = 10 + // We use 2 more addreses than the maximum local regeneration attempts + // because we want to also trigger regeneration in response to a DAD + // conflicts for this test. + maxAddrs = maxSLAACAddrLocalRegenAttempts + 2 + dupAddrTransmits = 1 + retransmitTimer = time.Second + ) + + var tempIIDHistoryWithModifiedEUI64 [header.IIDSize]byte + header.InitialTempIID(tempIIDHistoryWithModifiedEUI64[:], nil, nicID) + + var tempIIDHistoryWithOpaqueIID [header.IIDSize]byte + header.InitialTempIID(tempIIDHistoryWithOpaqueIID[:], nil, nicID) + + prefix, subnet, stableAddrWithModifiedEUI64 := prefixSubnetAddr(0, linkAddr1) + var stableAddrsWithOpaqueIID [maxAddrs]tcpip.AddressWithPrefix + var tempAddrsWithOpaqueIID [maxAddrs]tcpip.AddressWithPrefix + var tempAddrsWithModifiedEUI64 [maxAddrs]tcpip.AddressWithPrefix + addrBytes := []byte(subnet.ID()) + for i := 0; i < maxAddrs; i++ { + stableAddrsWithOpaqueIID[i] = tcpip.AddressWithPrefix{ + Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, uint8(i), nil)), + PrefixLen: header.IIDOffsetInIPv6Address * 8, + } + // When generating temporary addresses, the resolved stable address for the + // SLAAC prefix will be the first address stable address generated for the + // prefix as we will not simulate address conflicts for the stable addresses + // in tests involving temporary addresses. Address conflicts for stable + // addresses will be done in their own tests. + tempAddrsWithOpaqueIID[i] = header.GenerateTempIPv6SLAACAddr(tempIIDHistoryWithOpaqueIID[:], stableAddrsWithOpaqueIID[0].Address) + tempAddrsWithModifiedEUI64[i] = header.GenerateTempIPv6SLAACAddr(tempIIDHistoryWithModifiedEUI64[:], stableAddrWithModifiedEUI64.Address) + } + + tests := []struct { + name string + addrs []tcpip.AddressWithPrefix + tempAddrs bool + initialExpect tcpip.AddressWithPrefix + nicNameFromID func(tcpip.NICID, string) string + }{ + { + name: "Stable addresses with opaque IIDs", + addrs: stableAddrsWithOpaqueIID[:], + nicNameFromID: func(tcpip.NICID, string) string { + return nicName + }, + }, + { + name: "Temporary addresses with opaque IIDs", + addrs: tempAddrsWithOpaqueIID[:], + tempAddrs: true, + initialExpect: stableAddrsWithOpaqueIID[0], + nicNameFromID: func(tcpip.NICID, string) string { + return nicName + }, + }, + { + name: "Temporary addresses with modified EUI64", + addrs: tempAddrsWithModifiedEUI64[:], + tempAddrs: true, + initialExpect: stableAddrWithModifiedEUI64, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), + } + e := channel.New(0, 1280, linkAddr1) + ndpConfigs := stack.NDPConfigurations{ + HandleRAs: true, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: test.tempAddrs, + AutoGenAddressConflictRetries: 1, + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: test.nicNameFromID, + }, + }) + + s.SetRouteTable([]tcpip.Route{{ + Destination: header.IPv6EmptySubnet, + Gateway: llAddr2, + NIC: nicID, + }}) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + for j := 0; j < len(test.addrs)-1; j++ { + // The NIC will not attempt to generate an address in response to a + // NIC-local conflict after some maximum number of attempts. We skip + // creating a conflict for the address that would be generated as part + // of the last attempt so we can simulate a DAD conflict for this + // address and restart the NIC-local generation process. + if j == maxSLAACAddrLocalRegenAttempts-1 { + continue + } + + if err := s.AddAddress(nicID, ipv6.ProtocolNumber, test.addrs[j].Address); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, test.addrs[j].Address, err) + } + } + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } + + expectAutoGenAddrAsyncEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(defaultAsyncEventTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + } + + expectDADEventAsync := func(addr tcpip.Address) { + t.Helper() + + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + case <-time.After(dupAddrTransmits*retransmitTimer + defaultAsyncEventTimeout): + t.Fatal("timed out waiting for DAD event") + } + } + + // Enable DAD. + ndpDisp.dadC = make(chan ndpDADEvent, 2) + ndpConfigs.DupAddrDetectTransmits = dupAddrTransmits + ndpConfigs.RetransmitTimer = retransmitTimer + if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil { + t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err) + } + + // Do SLAAC for prefix. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds)) + if test.initialExpect != (tcpip.AddressWithPrefix{}) { + expectAutoGenAddrEvent(test.initialExpect, newAddr) + expectDADEventAsync(test.initialExpect.Address) + } + + // The last local generation attempt should succeed, but we introduce a + // DAD failure to restart the local generation process. + addr := test.addrs[maxSLAACAddrLocalRegenAttempts-1] + expectAutoGenAddrAsyncEvent(addr, newAddr) + if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil { + t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err) + } + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected DAD event") + } + expectAutoGenAddrEvent(addr, invalidatedAddr) + + // The last address generated should resolve DAD. + addr = test.addrs[len(test.addrs)-1] + expectAutoGenAddrAsyncEvent(addr, newAddr) + expectDADEventAsync(addr.Address) + + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpected auto gen addr event = %+v", e) + default: + } + }) + } +} + // stackAndNdpDispatcherWithDefaultRoute returns an ndpDispatcher, // channel.Endpoint and stack.Stack. // -- cgit v1.2.3 From b660f16d18827f0310594c80d9387de11430f15f Mon Sep 17 00:00:00 2001 From: Nayana Bidari Date: Fri, 27 Mar 2020 12:18:45 -0700 Subject: Support for connection tracking of TCP packets. Connection tracking is used to track packets in prerouting and output hooks of iptables. The NAT rules modify the tuples in connections. The connection tracking code modifies the packets by looking at the modified tuples. --- pkg/sentry/socket/netfilter/netfilter.go | 14 +- pkg/sentry/socket/netfilter/targets.go | 3 +- pkg/sentry/socket/netfilter/tcp_matcher.go | 17 +- pkg/sentry/socket/netfilter/udp_matcher.go | 17 +- pkg/tcpip/header/tcp.go | 17 + pkg/tcpip/network/ipv4/ipv4.go | 46 ++- pkg/tcpip/stack/BUILD | 2 + pkg/tcpip/stack/conntrack.go | 480 ++++++++++++++++++++++ pkg/tcpip/stack/iptables.go | 51 ++- pkg/tcpip/stack/iptables_targets.go | 115 +++--- pkg/tcpip/stack/iptables_types.go | 4 +- pkg/tcpip/stack/nic.go | 4 +- pkg/tcpip/stack/packet_buffer.go | 4 + pkg/tcpip/stack/route.go | 13 + pkg/tcpip/stack/stack.go | 19 + pkg/tcpip/transport/tcp/BUILD | 2 +- pkg/tcpip/transport/tcp/rcv.go | 19 +- pkg/tcpip/transport/tcp/rcv_test.go | 6 +- pkg/tcpip/transport/tcpconntrack/BUILD | 1 - pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go | 16 +- test/iptables/filter_output.go | 2 +- test/iptables/iptables_test.go | 66 ++- test/iptables/iptables_util.go | 2 +- test/iptables/nat.go | 103 ++++- 24 files changed, 858 insertions(+), 165 deletions(-) create mode 100644 pkg/tcpip/stack/conntrack.go (limited to 'pkg/tcpip/stack') diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 72d093aa8..40736fb38 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -251,7 +251,7 @@ func marshalTarget(target stack.Target) []byte { case stack.ReturnTarget: return marshalStandardTarget(stack.RuleReturn) case stack.RedirectTarget: - return marshalRedirectTarget() + return marshalRedirectTarget(tg) case JumpTarget: return marshalJumpTarget(tg) default: @@ -288,7 +288,7 @@ func marshalErrorTarget(errorName string) []byte { return binary.Marshal(ret, usermem.ByteOrder, target) } -func marshalRedirectTarget() []byte { +func marshalRedirectTarget(rt stack.RedirectTarget) []byte { // This is a redirect target named redirect target := linux.XTRedirectTarget{ Target: linux.XTEntryTarget{ @@ -298,6 +298,16 @@ func marshalRedirectTarget() []byte { copy(target.Target.Name[:], redirectTargetName) ret := make([]byte, 0, linux.SizeOfXTRedirectTarget) + target.NfRange.RangeSize = 1 + if rt.RangeProtoSpecified { + target.NfRange.RangeIPV4.Flags |= linux.NF_NAT_RANGE_PROTO_SPECIFIED + } + // Convert port from little endian to big endian. + port := make([]byte, 2) + binary.LittleEndian.PutUint16(port, rt.MinPort) + target.NfRange.RangeIPV4.MinPort = binary.BigEndian.Uint16(port) + binary.LittleEndian.PutUint16(port, rt.MaxPort) + target.NfRange.RangeIPV4.MaxPort = binary.BigEndian.Uint16(port) return binary.Marshal(ret, usermem.ByteOrder, target) } diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go index c948de876..84abe8d29 100644 --- a/pkg/sentry/socket/netfilter/targets.go +++ b/pkg/sentry/socket/netfilter/targets.go @@ -15,6 +15,7 @@ package netfilter import ( + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -29,6 +30,6 @@ type JumpTarget struct { } // Action implements stack.Target.Action. -func (jt JumpTarget) Action(stack.PacketBuffer) (stack.RuleVerdict, int) { +func (jt JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrackTable, stack.Hook, *stack.GSO, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) { return stack.RuleJump, jt.RuleNum } diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index 55c0f04f3..57a1e1c12 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -120,14 +120,27 @@ func (tm *TCPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa if pkt.TransportHeader != nil { tcpHeader = header.TCP(pkt.TransportHeader) } else { + var length int + if hook == stack.Prerouting { + // The network header hasn't been parsed yet. We have to do it here. + hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + // There's no valid TCP header here, so we hotdrop the + // packet. + return false, true + } + h := header.IPv4(hdr) + pkt.NetworkHeader = hdr + length = int(h.HeaderLength()) + } // The TCP header hasn't been parsed yet. We have to do it here. - hdr, ok := pkt.Data.PullUp(header.TCPMinimumSize) + hdr, ok := pkt.Data.PullUp(length + header.TCPMinimumSize) if !ok { // There's no valid TCP header here, so we hotdrop the // packet. return false, true } - tcpHeader = header.TCP(hdr) + tcpHeader = header.TCP(hdr[length:]) } // Check whether the source and destination ports are within the diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index 04d03d494..cfa9e621d 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -119,14 +119,27 @@ func (um *UDPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceNa if pkt.TransportHeader != nil { udpHeader = header.UDP(pkt.TransportHeader) } else { + var length int + if hook == stack.Prerouting { + // The network header hasn't been parsed yet. We have to do it here. + hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + // There's no valid UDP header here, so we hotdrop the + // packet. + return false, true + } + h := header.IPv4(hdr) + pkt.NetworkHeader = hdr + length = int(h.HeaderLength()) + } // The UDP header hasn't been parsed yet. We have to do it here. - hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize) + hdr, ok := pkt.Data.PullUp(length + header.UDPMinimumSize) if !ok { // There's no valid UDP header here, so we hotdrop the // packet. return false, true } - udpHeader = header.UDP(hdr) + udpHeader = header.UDP(hdr[length:]) } // Check whether the source and destination ports are within the diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go index 13480687d..21581257b 100644 --- a/pkg/tcpip/header/tcp.go +++ b/pkg/tcpip/header/tcp.go @@ -594,3 +594,20 @@ func AddTCPOptionPadding(options []byte, offset int) int { } return paddingToAdd } + +// Acceptable checks if a segment that starts at segSeq and has length segLen is +// "acceptable" for arriving in a receive window that starts at rcvNxt and ends +// before rcvAcc, according to the table on page 26 and 69 of RFC 793. +func Acceptable(segSeq seqnum.Value, segLen seqnum.Size, rcvNxt, rcvAcc seqnum.Value) bool { + if rcvNxt == rcvAcc { + return segLen == 0 && segSeq == rcvNxt + } + if segLen == 0 { + // rcvWnd is incremented by 1 because that is Linux's behavior despite the + // RFC. + return segSeq.InRange(rcvNxt, rcvAcc.Add(1)) + } + // Page 70 of RFC 793 allows packets that can be made "acceptable" by trimming + // the payload, so we'll accept any payload that overlaps the receieve window. + return rcvNxt.LessThan(segSeq.Add(segLen)) && segSeq.LessThan(rcvAcc) +} diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 1d61fddad..9db42b2a4 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -252,11 +252,31 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // iptables filtering. All packets that reach here are locally // generated. ipt := e.stack.IPTables() - if ok := ipt.Check(stack.Output, pkt); !ok { + if ok := ipt.Check(stack.Output, &pkt, gso, r, ""); !ok { // iptables is telling us to drop the packet. return nil } + if pkt.NatDone { + // If the packet is manipulated as per NAT Ouput rules, handle packet + // based on destination address and do not send the packet to link layer. + netHeader := header.IPv4(pkt.NetworkHeader) + ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()) + if err == nil { + src := netHeader.SourceAddress() + dst := netHeader.DestinationAddress() + route := r.ReverseRoute(src, dst) + + views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) + views[0] = pkt.Header.View() + views = append(views, pkt.Data.Views()...) + packet := stack.PacketBuffer{ + Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views)} + ep.HandlePacket(&route, packet) + return nil + } + } + if r.Loop&stack.PacketLoop != 0 { // The inbound path expects the network header to still be in // the PacketBuffer's Data field. @@ -302,8 +322,8 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // iptables filtering. All packets that reach here are locally // generated. ipt := e.stack.IPTables() - dropped := ipt.CheckPackets(stack.Output, pkts) - if len(dropped) == 0 { + dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r) + if len(dropped) == 0 && len(natPkts) == 0 { // Fast path: If no packets are to be dropped then we can just invoke the // faster WritePackets API directly. n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) @@ -318,6 +338,24 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe if _, ok := dropped[pkt]; ok { continue } + if _, ok := natPkts[pkt]; ok { + netHeader := header.IPv4(pkt.NetworkHeader) + ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()) + if err == nil { + src := netHeader.SourceAddress() + dst := netHeader.DestinationAddress() + route := r.ReverseRoute(src, dst) + + views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) + views[0] = pkt.Header.View() + views = append(views, pkt.Data.Views()...) + packet := stack.PacketBuffer{ + Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views)} + ep.HandlePacket(&route, packet) + n++ + continue + } + } if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, *pkt); err != nil { r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) return n, err @@ -407,7 +445,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. ipt := e.stack.IPTables() - if ok := ipt.Check(stack.Input, pkt); !ok { + if ok := ipt.Check(stack.Input, &pkt, nil, nil, ""); !ok { // iptables is telling us to drop the packet. return } diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 5e963a4af..f71073207 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -30,6 +30,7 @@ go_template_instance( go_library( name = "stack", srcs = [ + "conntrack.go", "dhcpv6configurationfromndpra_string.go", "forwarder.go", "icmp_rate_limit.go", @@ -62,6 +63,7 @@ go_library( "//pkg/tcpip/header", "//pkg/tcpip/ports", "//pkg/tcpip/seqnum", + "//pkg/tcpip/transport/tcpconntrack", "//pkg/waiter", "@org_golang_x_time//rate:go_default_library", ], diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go new file mode 100644 index 000000000..7d1ede1f2 --- /dev/null +++ b/pkg/tcpip/stack/conntrack.go @@ -0,0 +1,480 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "encoding/binary" + "sync" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack" +) + +// Connection tracking is used to track and manipulate packets for NAT rules. +// The connection is created for a packet if it does not exist. Every connection +// contains two tuples (original and reply). The tuples are manipulated if there +// is a matching NAT rule. The packet is modified by looking at the tuples in the +// Prerouting and Output hooks. + +// Direction of the tuple. +type ctDirection int + +const ( + dirOriginal ctDirection = iota + dirReply +) + +// Status of connection. +// TODO(gvisor.dev/issue/170): Add other states of connection. +type connStatus int + +const ( + connNew connStatus = iota + connEstablished +) + +// Manipulation type for the connection. +type manipType int + +const ( + manipDstPrerouting manipType = iota + manipDstOutput +) + +// connTrackMutable is the manipulatable part of the tuple. +type connTrackMutable struct { + // addr is source address of the tuple. + addr tcpip.Address + + // port is source port of the tuple. + port uint16 + + // protocol is network layer protocol. + protocol tcpip.NetworkProtocolNumber +} + +// connTrackImmutable is the non-manipulatable part of the tuple. +type connTrackImmutable struct { + // addr is destination address of the tuple. + addr tcpip.Address + + // direction is direction (original or reply) of the tuple. + direction ctDirection + + // port is destination port of the tuple. + port uint16 + + // protocol is transport layer protocol. + protocol tcpip.TransportProtocolNumber +} + +// connTrackTuple represents the tuple which is created from the +// packet. +type connTrackTuple struct { + // dst is non-manipulatable part of the tuple. + dst connTrackImmutable + + // src is manipulatable part of the tuple. + src connTrackMutable +} + +// connTrackTupleHolder is the container of tuple and connection. +type ConnTrackTupleHolder struct { + // conn is pointer to the connection tracking entry. + conn *connTrack + + // tuple is original or reply tuple. + tuple connTrackTuple +} + +// connTrack is the connection. +type connTrack struct { + // originalTupleHolder contains tuple in original direction. + originalTupleHolder ConnTrackTupleHolder + + // replyTupleHolder contains tuple in reply direction. + replyTupleHolder ConnTrackTupleHolder + + // status indicates connection is new or established. + status connStatus + + // timeout indicates the time connection should be active. + timeout time.Duration + + // manip indicates if the packet should be manipulated. + manip manipType + + // tcb is TCB control block. It is used to keep track of states + // of tcp connection. + tcb tcpconntrack.TCB + + // tcbHook indicates if the packet is inbound or outbound to + // update the state of tcb. + tcbHook Hook +} + +// ConnTrackTable contains a map of all existing connections created for +// NAT rules. +type ConnTrackTable struct { + // connMu protects connTrackTable. + connMu sync.RWMutex + + // connTrackTable maintains a map of tuples needed for connection tracking + // for iptables NAT rules. The key for the map is an integer calculated + // using seed, source address, destination address, source port and + // destination port. + CtMap map[uint32]ConnTrackTupleHolder + + // seed is a one-time random value initialized at stack startup + // and is used in calculation of hash key for connection tracking + // table. + Seed uint32 +} + +// parseHeaders sets headers in the packet. +func parseHeaders(pkt *PacketBuffer) { + newPkt := pkt.Clone() + + // Set network header. + hdr, ok := newPkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return + } + netHeader := header.IPv4(hdr) + newPkt.NetworkHeader = hdr + length := int(netHeader.HeaderLength()) + + // TODO(gvisor.dev/issue/170): Need to support for other + // protocols as well. + // Set transport header. + switch protocol := netHeader.TransportProtocol(); protocol { + case header.UDPProtocolNumber: + if newPkt.TransportHeader == nil { + h, ok := newPkt.Data.PullUp(length + header.UDPMinimumSize) + if !ok { + return + } + newPkt.TransportHeader = buffer.View(header.UDP(h[length:])) + } + case header.TCPProtocolNumber: + if newPkt.TransportHeader == nil { + h, ok := newPkt.Data.PullUp(length + header.TCPMinimumSize) + if !ok { + return + } + newPkt.TransportHeader = buffer.View(header.TCP(h[length:])) + } + } + pkt.NetworkHeader = newPkt.NetworkHeader + pkt.TransportHeader = newPkt.TransportHeader +} + +// packetToTuple converts packet to a tuple in original direction. +func packetToTuple(pkt PacketBuffer, hook Hook) (connTrackTuple, *tcpip.Error) { + var tuple connTrackTuple + + netHeader := header.IPv4(pkt.NetworkHeader) + // TODO(gvisor.dev/issue/170): Need to support for other + // protocols as well. + if netHeader == nil || netHeader.TransportProtocol() != header.TCPProtocolNumber { + return tuple, tcpip.ErrUnknownProtocol + } + tcpHeader := header.TCP(pkt.TransportHeader) + if tcpHeader == nil { + return tuple, tcpip.ErrUnknownProtocol + } + + tuple.src.addr = netHeader.SourceAddress() + tuple.src.port = tcpHeader.SourcePort() + tuple.src.protocol = header.IPv4ProtocolNumber + + tuple.dst.addr = netHeader.DestinationAddress() + tuple.dst.port = tcpHeader.DestinationPort() + tuple.dst.protocol = netHeader.TransportProtocol() + + return tuple, nil +} + +// getReplyTuple creates reply tuple for the given tuple. +func getReplyTuple(tuple connTrackTuple) connTrackTuple { + var replyTuple connTrackTuple + replyTuple.src.addr = tuple.dst.addr + replyTuple.src.port = tuple.dst.port + replyTuple.src.protocol = tuple.src.protocol + replyTuple.dst.addr = tuple.src.addr + replyTuple.dst.port = tuple.src.port + replyTuple.dst.protocol = tuple.dst.protocol + replyTuple.dst.direction = dirReply + + return replyTuple +} + +// makeNewConn creates new connection. +func makeNewConn(tuple, replyTuple connTrackTuple) connTrack { + var conn connTrack + conn.status = connNew + conn.originalTupleHolder.tuple = tuple + conn.originalTupleHolder.conn = &conn + conn.replyTupleHolder.tuple = replyTuple + conn.replyTupleHolder.conn = &conn + + return conn +} + +// getTupleHash returns hash of the tuple. The fields used for +// generating hash are seed (generated once for stack), source address, +// destination address, source port and destination ports. +func (ct *ConnTrackTable) getTupleHash(tuple connTrackTuple) uint32 { + h := jenkins.Sum32(ct.Seed) + h.Write([]byte(tuple.src.addr)) + h.Write([]byte(tuple.dst.addr)) + portBuf := make([]byte, 2) + binary.LittleEndian.PutUint16(portBuf, tuple.src.port) + h.Write([]byte(portBuf)) + binary.LittleEndian.PutUint16(portBuf, tuple.dst.port) + h.Write([]byte(portBuf)) + + return h.Sum32() +} + +// connTrackForPacket returns connTrack for packet. +// TODO(gvisor.dev/issue/170): Only TCP packets are supported. Need to support other +// transport protocols. +func (ct *ConnTrackTable) connTrackForPacket(pkt *PacketBuffer, hook Hook, createConn bool) (*connTrack, ctDirection) { + if hook == Prerouting { + // Headers will not be set in Prerouting. + // TODO(gvisor.dev/issue/170): Change this after parsing headers + // code is added. + parseHeaders(pkt) + } + + var dir ctDirection + tuple, err := packetToTuple(*pkt, hook) + if err != nil { + return nil, dir + } + + ct.connMu.Lock() + defer ct.connMu.Unlock() + + connTrackTable := ct.CtMap + hash := ct.getTupleHash(tuple) + + var conn *connTrack + switch createConn { + case true: + // If connection does not exist for the hash, create a new + // connection. + replyTuple := getReplyTuple(tuple) + replyHash := ct.getTupleHash(replyTuple) + newConn := makeNewConn(tuple, replyTuple) + conn = &newConn + + // Add tupleHolders to the map. + // TODO(gvisor.dev/issue/170): Need to support collisions using linked list. + ct.CtMap[hash] = conn.originalTupleHolder + ct.CtMap[replyHash] = conn.replyTupleHolder + default: + tupleHolder, ok := connTrackTable[hash] + if !ok { + return nil, dir + } + + // If this is the reply of new connection, set the connection + // status as ESTABLISHED. + conn = tupleHolder.conn + if conn.status == connNew && tupleHolder.tuple.dst.direction == dirReply { + conn.status = connEstablished + } + if tupleHolder.conn == nil { + panic("tupleHolder has null connection tracking entry") + } + + dir = tupleHolder.tuple.dst.direction + } + return conn, dir +} + +// SetNatInfo will manipulate the tuples according to iptables NAT rules. +func (ct *ConnTrackTable) SetNatInfo(pkt *PacketBuffer, rt RedirectTarget, hook Hook) { + // Get the connection. Connection is always created before this + // function is called. + conn, _ := ct.connTrackForPacket(pkt, hook, false) + if conn == nil { + panic("connection should be created to manipulate tuples.") + } + replyTuple := conn.replyTupleHolder.tuple + replyHash := ct.getTupleHash(replyTuple) + + // TODO(gvisor.dev/issue/170): Support only redirect of ports. Need to + // support changing of address for Prerouting. + + // Change the port as per the iptables rule. This tuple will be used + // to manipulate the packet in HandlePacket. + conn.replyTupleHolder.tuple.src.addr = rt.MinIP + conn.replyTupleHolder.tuple.src.port = rt.MinPort + newHash := ct.getTupleHash(conn.replyTupleHolder.tuple) + + // Add the changed tuple to the map. + ct.connMu.Lock() + defer ct.connMu.Unlock() + ct.CtMap[newHash] = conn.replyTupleHolder + if hook == Output { + conn.replyTupleHolder.conn.manip = manipDstOutput + } + + // Delete the old tuple. + delete(ct.CtMap, replyHash) +} + +// handlePacketPrerouting manipulates ports for packets in Prerouting hook. +// TODO(gvisor.dev/issue/170): Change address for Prerouting hook.. +func handlePacketPrerouting(pkt *PacketBuffer, conn *connTrack, dir ctDirection) { + netHeader := header.IPv4(pkt.NetworkHeader) + tcpHeader := header.TCP(pkt.TransportHeader) + + // For prerouting redirection, packets going in the original direction + // have their destinations modified and replies have their sources + // modified. + switch dir { + case dirOriginal: + port := conn.replyTupleHolder.tuple.src.port + tcpHeader.SetDestinationPort(port) + netHeader.SetDestinationAddress(conn.replyTupleHolder.tuple.src.addr) + case dirReply: + port := conn.originalTupleHolder.tuple.dst.port + tcpHeader.SetSourcePort(port) + netHeader.SetSourceAddress(conn.originalTupleHolder.tuple.dst.addr) + } + + netHeader.SetChecksum(0) + netHeader.SetChecksum(^netHeader.CalculateChecksum()) +} + +// handlePacketOutput manipulates ports for packets in Output hook. +func handlePacketOutput(pkt *PacketBuffer, conn *connTrack, gso *GSO, r *Route, dir ctDirection) { + netHeader := header.IPv4(pkt.NetworkHeader) + tcpHeader := header.TCP(pkt.TransportHeader) + + // For output redirection, packets going in the original direction + // have their destinations modified and replies have their sources + // modified. For prerouting redirection, we only reach this point + // when replying, so packet sources are modified. + if conn.manip == manipDstOutput && dir == dirOriginal { + port := conn.replyTupleHolder.tuple.src.port + tcpHeader.SetDestinationPort(port) + netHeader.SetDestinationAddress(conn.replyTupleHolder.tuple.src.addr) + } else { + port := conn.originalTupleHolder.tuple.dst.port + tcpHeader.SetSourcePort(port) + netHeader.SetSourceAddress(conn.originalTupleHolder.tuple.dst.addr) + } + + // Calculate the TCP checksum and set it. + tcpHeader.SetChecksum(0) + hdr := &pkt.Header + length := uint16(pkt.Data.Size()+hdr.UsedLength()) - uint16(netHeader.HeaderLength()) + xsum := r.PseudoHeaderChecksum(header.TCPProtocolNumber, length) + if gso != nil && gso.NeedsCsum { + tcpHeader.SetChecksum(xsum) + } else if r.Capabilities()&CapabilityTXChecksumOffload == 0 { + xsum = header.ChecksumVVWithOffset(pkt.Data, xsum, int(tcpHeader.DataOffset()), pkt.Data.Size()) + tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum)) + } + + netHeader.SetChecksum(0) + netHeader.SetChecksum(^netHeader.CalculateChecksum()) +} + +// HandlePacket will manipulate the port and address of the packet if the +// connection exists. +func (ct *ConnTrackTable) HandlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) { + if pkt.NatDone { + return + } + + if hook != Prerouting && hook != Output { + return + } + + conn, dir := ct.connTrackForPacket(pkt, hook, false) + // Connection or Rule not found for the packet. + if conn == nil { + return + } + + netHeader := header.IPv4(pkt.NetworkHeader) + // TODO(gvisor.dev/issue/170): Need to support for other transport + // protocols as well. + if netHeader == nil || netHeader.TransportProtocol() != header.TCPProtocolNumber { + return + } + + tcpHeader := header.TCP(pkt.TransportHeader) + if tcpHeader == nil { + return + } + + switch hook { + case Prerouting: + handlePacketPrerouting(pkt, conn, dir) + case Output: + handlePacketOutput(pkt, conn, gso, r, dir) + } + pkt.NatDone = true + + // Update the state of tcb. + // TODO(gvisor.dev/issue/170): Add support in tcpcontrack to handle + // other tcp states. + var st tcpconntrack.Result + if conn.tcb.IsEmpty() { + conn.tcb.Init(tcpHeader) + conn.tcbHook = hook + } else { + switch hook { + case conn.tcbHook: + st = conn.tcb.UpdateStateOutbound(tcpHeader) + default: + st = conn.tcb.UpdateStateInbound(tcpHeader) + } + } + + // Delete conntrack if tcp connection is closed. + if st == tcpconntrack.ResultClosedByPeer || st == tcpconntrack.ResultClosedBySelf || st == tcpconntrack.ResultReset { + ct.deleteConnTrack(conn) + } +} + +// deleteConnTrack deletes the connection. +func (ct *ConnTrackTable) deleteConnTrack(conn *connTrack) { + if conn == nil { + return + } + + tuple := conn.originalTupleHolder.tuple + hash := ct.getTupleHash(tuple) + replyTuple := conn.replyTupleHolder.tuple + replyHash := ct.getTupleHash(replyTuple) + + ct.connMu.Lock() + defer ct.connMu.Unlock() + + delete(ct.CtMap, hash) + delete(ct.CtMap, replyHash) +} diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 6b91159d4..7c3c47d50 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -17,6 +17,7 @@ package stack import ( "fmt" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -110,6 +111,10 @@ func DefaultTables() IPTables { Prerouting: []string{TablenameMangle, TablenameNat}, Output: []string{TablenameMangle, TablenameNat, TablenameFilter}, }, + connections: ConnTrackTable{ + CtMap: make(map[uint32]ConnTrackTupleHolder), + Seed: generateRandUint32(), + }, } } @@ -173,12 +178,16 @@ const ( // dropped. // // Precondition: pkt.NetworkHeader is set. -func (it *IPTables) Check(hook Hook, pkt PacketBuffer) bool { +func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, address tcpip.Address) bool { + // Packets are manipulated only if connection and matching + // NAT rule exists. + it.connections.HandlePacket(pkt, hook, gso, r) + // Go through each table containing the hook. for _, tablename := range it.Priorities[hook] { table := it.Tables[tablename] ruleIdx := table.BuiltinChains[hook] - switch verdict := it.checkChain(hook, pkt, table, ruleIdx); verdict { + switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address); verdict { // If the table returns Accept, move on to the next table. case chainAccept: continue @@ -189,7 +198,7 @@ func (it *IPTables) Check(hook Hook, pkt PacketBuffer) bool { // Any Return from a built-in chain means we have to // call the underflow. underflow := table.Rules[table.Underflows[hook]] - switch v, _ := underflow.Target.Action(pkt); v { + switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, gso, r, address); v { case RuleAccept: continue case RuleDrop: @@ -219,26 +228,34 @@ func (it *IPTables) Check(hook Hook, pkt PacketBuffer) bool { // // NOTE: unlike the Check API the returned map contains packets that should be // dropped. -func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList) (drop map[*PacketBuffer]struct{}) { +func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *Route) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - if ok := it.Check(hook, *pkt); !ok { - if drop == nil { - drop = make(map[*PacketBuffer]struct{}) + if !pkt.NatDone { + if ok := it.Check(hook, pkt, gso, r, ""); !ok { + if drop == nil { + drop = make(map[*PacketBuffer]struct{}) + } + drop[pkt] = struct{}{} + } + if pkt.NatDone { + if natPkts == nil { + natPkts = make(map[*PacketBuffer]struct{}) + } + natPkts[pkt] = struct{}{} } - drop[pkt] = struct{}{} } } - return drop + return drop, natPkts } // Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a +// TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a // precondition. -func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) chainVerdict { +func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. for ruleIdx < len(table.Rules) { - switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx); verdict { + switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, address); verdict { case RuleAccept: return chainAccept @@ -255,7 +272,7 @@ func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx ruleIdx++ continue } - switch verdict := it.checkChain(hook, pkt, table, jumpTo); verdict { + switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, address); verdict { case chainAccept: return chainAccept case chainDrop: @@ -279,9 +296,9 @@ func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx } // Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a +// TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a // precondition. -func (it *IPTables) checkRule(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) (RuleVerdict, int) { +func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) { rule := table.Rules[ruleIdx] // If pkt.NetworkHeader hasn't been set yet, it will be contained in @@ -304,7 +321,7 @@ func (it *IPTables) checkRule(hook Hook, pkt PacketBuffer, table Table, ruleIdx // Go through each rule matcher. If they all match, run // the rule target. for _, matcher := range rule.Matchers { - matches, hotdrop := matcher.Match(hook, pkt, "") + matches, hotdrop := matcher.Match(hook, *pkt, "") if hotdrop { return RuleDrop, 0 } @@ -315,7 +332,7 @@ func (it *IPTables) checkRule(hook Hook, pkt PacketBuffer, table Table, ruleIdx } // All the matchers matched, so run the target. - return rule.Target.Action(pkt) + return rule.Target.Action(pkt, &it.connections, hook, gso, r, address) } func filterMatch(filter IPHeaderFilter, hdr header.IPv4) bool { diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 8be61f4b1..36cc6275d 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -24,7 +24,7 @@ import ( type AcceptTarget struct{} // Action implements Target.Action. -func (AcceptTarget) Action(packet PacketBuffer) (RuleVerdict, int) { +func (AcceptTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleAccept, 0 } @@ -32,7 +32,7 @@ func (AcceptTarget) Action(packet PacketBuffer) (RuleVerdict, int) { type DropTarget struct{} // Action implements Target.Action. -func (DropTarget) Action(packet PacketBuffer) (RuleVerdict, int) { +func (DropTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleDrop, 0 } @@ -41,7 +41,7 @@ func (DropTarget) Action(packet PacketBuffer) (RuleVerdict, int) { type ErrorTarget struct{} // Action implements Target.Action. -func (ErrorTarget) Action(packet PacketBuffer) (RuleVerdict, int) { +func (ErrorTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { log.Debugf("ErrorTarget triggered.") return RuleDrop, 0 } @@ -52,7 +52,7 @@ type UserChainTarget struct { } // Action implements Target.Action. -func (UserChainTarget) Action(PacketBuffer) (RuleVerdict, int) { +func (UserChainTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { panic("UserChainTarget should never be called.") } @@ -61,7 +61,7 @@ func (UserChainTarget) Action(PacketBuffer) (RuleVerdict, int) { type ReturnTarget struct{} // Action implements Target.Action. -func (ReturnTarget) Action(PacketBuffer) (RuleVerdict, int) { +func (ReturnTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleReturn, 0 } @@ -75,16 +75,16 @@ type RedirectTarget struct { // redirect. RangeProtoSpecified bool - // Min address used to redirect. + // MinIP indicates address used to redirect. MinIP tcpip.Address - // Max address used to redirect. + // MaxIP indicates address used to redirect. MaxIP tcpip.Address - // Min port used to redirect. + // MinPort indicates port used to redirect. MinPort uint16 - // Max port used to redirect. + // MaxPort indicates port used to redirect. MaxPort uint16 } @@ -92,61 +92,76 @@ type RedirectTarget struct { // TODO(gvisor.dev/issue/170): Parse headers without copying. The current // implementation only works for PREROUTING and calls pkt.Clone(), neither // of which should be the case. -func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) { - newPkt := pkt.Clone() +func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrackTable, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) { + // Packet is already manipulated. + if pkt.NatDone { + return RuleAccept, 0 + } // Set network header. - headerView, ok := newPkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - return RuleDrop, 0 + if hook == Prerouting { + parseHeaders(pkt) } - netHeader := header.IPv4(headerView) - newPkt.NetworkHeader = headerView - hlen := int(netHeader.HeaderLength()) - tlen := int(netHeader.TotalLength()) - newPkt.Data.TrimFront(hlen) - newPkt.Data.CapLength(tlen - hlen) + // Drop the packet if network and transport header are not set. + if pkt.NetworkHeader == nil || pkt.TransportHeader == nil { + return RuleDrop, 0 + } - // TODO(gvisor.dev/issue/170): Change destination address to - // loopback or interface address on which the packet was - // received. + // Change the address to localhost (127.0.0.1) in Output and + // to primary address of the incoming interface in Prerouting. + switch hook { + case Output: + rt.MinIP = tcpip.Address([]byte{127, 0, 0, 1}) + rt.MaxIP = tcpip.Address([]byte{127, 0, 0, 1}) + case Prerouting: + rt.MinIP = address + rt.MaxIP = address + default: + panic("redirect target is supported only on output and prerouting hooks") + } // TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if // we need to change dest address (for OUTPUT chain) or ports. + netHeader := header.IPv4(pkt.NetworkHeader) switch protocol := netHeader.TransportProtocol(); protocol { case header.UDPProtocolNumber: - var udpHeader header.UDP - if newPkt.TransportHeader != nil { - udpHeader = header.UDP(newPkt.TransportHeader) - } else { - if pkt.Data.Size() < header.UDPMinimumSize { - return RuleDrop, 0 - } - hdr, ok := newPkt.Data.PullUp(header.UDPMinimumSize) - if !ok { - return RuleDrop, 0 + udpHeader := header.UDP(pkt.TransportHeader) + udpHeader.SetDestinationPort(rt.MinPort) + + // Calculate UDP checksum and set it. + if hook == Output { + udpHeader.SetChecksum(0) + hdr := &pkt.Header + length := uint16(pkt.Data.Size()+hdr.UsedLength()) - uint16(netHeader.HeaderLength()) + + // Only calculate the checksum if offloading isn't supported. + if r.Capabilities()&CapabilityTXChecksumOffload == 0 { + xsum := r.PseudoHeaderChecksum(protocol, length) + for _, v := range pkt.Data.Views() { + xsum = header.Checksum(v, xsum) + } + udpHeader.SetChecksum(0) + udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum)) } - udpHeader = header.UDP(hdr) } - udpHeader.SetDestinationPort(rt.MinPort) + // Change destination address. + netHeader.SetDestinationAddress(rt.MinIP) + netHeader.SetChecksum(0) + netHeader.SetChecksum(^netHeader.CalculateChecksum()) + pkt.NatDone = true case header.TCPProtocolNumber: - var tcpHeader header.TCP - if newPkt.TransportHeader != nil { - tcpHeader = header.TCP(newPkt.TransportHeader) - } else { - if pkt.Data.Size() < header.TCPMinimumSize { - return RuleDrop, 0 - } - hdr, ok := newPkt.Data.PullUp(header.TCPMinimumSize) - if !ok { - return RuleDrop, 0 - } - tcpHeader = header.TCP(hdr) + if ct == nil { + return RuleAccept, 0 + } + + // Set up conection for matching NAT rule. + // Only the first packet of the connection comes here. + // Other packets will be manipulated in connection tracking. + if conn, _ := ct.connTrackForPacket(pkt, hook, true); conn != nil { + ct.SetNatInfo(pkt, rt, hook) + ct.HandlePacket(pkt, hook, gso, r) } - // TODO(gvisor.dev/issue/170): Need to recompute checksum - // and implement nat connection tracking to support TCP. - tcpHeader.SetDestinationPort(rt.MinPort) default: return RuleDrop, 0 } diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 2ffb55f2a..1bb0ba1bd 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -82,6 +82,8 @@ type IPTables struct { // list is the order in which each table should be visited for that // hook. Priorities map[Hook][]string + + connections ConnTrackTable } // A Table defines a set of chains and hooks into the network stack. It is @@ -176,5 +178,5 @@ type Target interface { // Action takes an action on the packet and returns a verdict on how // traversal should (or should not) continue. If the return value is // Jump, it also returns the index of the rule to jump to. - Action(packet PacketBuffer) (RuleVerdict, int) + Action(packet *PacketBuffer, connections *ConnTrackTable, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 7b54919bb..8f4c1fe42 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1230,8 +1230,10 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link // TODO(gvisor.dev/issue/170): Not supporting iptables for IPv6 yet. if protocol == header.IPv4ProtocolNumber { + // iptables filtering. ipt := n.stack.IPTables() - if ok := ipt.Check(Prerouting, pkt); !ok { + address := n.primaryAddress(protocol) + if ok := ipt.Check(Prerouting, &pkt, nil, nil, address.Address); !ok { // iptables is telling us to drop the packet. return } diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 9ff80ab24..926df4d7b 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -72,6 +72,10 @@ type PacketBuffer struct { EgressRoute *Route GSOOptions *GSO NetworkProtocolNumber tcpip.NetworkProtocolNumber + + // NatDone indicates if the packet has been manipulated as per NAT + // iptables rule. + NatDone bool } // Clone makes a copy of pk. It clones the Data field, which creates a new diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 53148dc03..150297ab9 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -261,3 +261,16 @@ func (r *Route) MakeLoopedRoute() Route { func (r *Route) Stack() *Stack { return r.ref.stack() } + +// ReverseRoute returns new route with given source and destination address. +func (r *Route) ReverseRoute(src tcpip.Address, dst tcpip.Address) Route { + return Route{ + NetProto: r.NetProto, + LocalAddress: dst, + LocalLinkAddress: r.RemoteLinkAddress, + RemoteAddress: src, + RemoteLinkAddress: r.LocalLinkAddress, + ref: r.ref, + Loop: r.Loop, + } +} diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 4a2dc3dc6..e33fae4eb 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1885,3 +1885,22 @@ func generateRandInt64() int64 { } return v } + +// FindNetworkEndpoint returns the network endpoint for the given address. +func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, address tcpip.Address) (NetworkEndpoint, *tcpip.Error) { + s.mu.Lock() + defer s.mu.Unlock() + + for _, nic := range s.nics { + id := NetworkEndpointID{address} + + if ref, ok := nic.mu.endpoints[id]; ok { + nic.mu.RLock() + defer nic.mu.RUnlock() + + // An endpoint with this id exists, check if it can be used and return it. + return ref.ep, nil + } + } + return nil, tcpip.ErrBadAddress +} diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index f2aa69069..f38eb6833 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -115,7 +115,7 @@ go_test( size = "small", srcs = ["rcv_test.go"], deps = [ - ":tcp", + "//pkg/tcpip/header", "//pkg/tcpip/seqnum", ], ) diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index a4b73b588..6fe97fefd 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -70,24 +70,7 @@ func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale // acceptable checks if the segment sequence number range is acceptable // according to the table on page 26 of RFC 793. func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool { - return Acceptable(segSeq, segLen, r.rcvNxt, r.rcvAcc) -} - -// Acceptable checks if a segment that starts at segSeq and has length segLen is -// "acceptable" for arriving in a receive window that starts at rcvNxt and ends -// before rcvAcc, according to the table on page 26 and 69 of RFC 793. -func Acceptable(segSeq seqnum.Value, segLen seqnum.Size, rcvNxt, rcvAcc seqnum.Value) bool { - if rcvNxt == rcvAcc { - return segLen == 0 && segSeq == rcvNxt - } - if segLen == 0 { - // rcvWnd is incremented by 1 because that is Linux's behavior despite the - // RFC. - return segSeq.InRange(rcvNxt, rcvAcc.Add(1)) - } - // Page 70 of RFC 793 allows packets that can be made "acceptable" by trimming - // the payload, so we'll accept any payload that overlaps the receieve window. - return rcvNxt.LessThan(segSeq.Add(segLen)) && segSeq.LessThan(rcvAcc) + return header.Acceptable(segSeq, segLen, r.rcvNxt, r.rcvAcc) } // getSendParams returns the parameters needed by the sender when building diff --git a/pkg/tcpip/transport/tcp/rcv_test.go b/pkg/tcpip/transport/tcp/rcv_test.go index dc02729ce..c9eeff935 100644 --- a/pkg/tcpip/transport/tcp/rcv_test.go +++ b/pkg/tcpip/transport/tcp/rcv_test.go @@ -17,8 +17,8 @@ package rcv_test import ( "testing" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" ) func TestAcceptable(t *testing.T) { @@ -67,8 +67,8 @@ func TestAcceptable(t *testing.T) { {105, 2, 108, 108, false}, {105, 2, 109, 109, false}, } { - if got := tcp.Acceptable(tt.segSeq, tt.segLen, tt.rcvNxt, tt.rcvAcc); got != tt.want { - t.Errorf("tcp.Acceptable(%d, %d, %d, %d) = %t, want %t", tt.segSeq, tt.segLen, tt.rcvNxt, tt.rcvAcc, got, tt.want) + if got := header.Acceptable(tt.segSeq, tt.segLen, tt.rcvNxt, tt.rcvAcc); got != tt.want { + t.Errorf("header.Acceptable(%d, %d, %d, %d) = %t, want %t", tt.segSeq, tt.segLen, tt.rcvNxt, tt.rcvAcc, got, tt.want) } } } diff --git a/pkg/tcpip/transport/tcpconntrack/BUILD b/pkg/tcpip/transport/tcpconntrack/BUILD index 2025ff757..3ad6994a7 100644 --- a/pkg/tcpip/transport/tcpconntrack/BUILD +++ b/pkg/tcpip/transport/tcpconntrack/BUILD @@ -9,7 +9,6 @@ go_library( deps = [ "//pkg/tcpip/header", "//pkg/tcpip/seqnum", - "//pkg/tcpip/transport/tcp", ], ) diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go index 30d05200f..12bc1b5b5 100644 --- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go +++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go @@ -20,7 +20,6 @@ package tcpconntrack import ( "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" - "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" ) // Result is returned when the state of a TCB is updated in response to an @@ -312,7 +311,7 @@ type stream struct { // the window is zero, if it's a packet with no payload and sequence number // equal to una. func (s *stream) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool { - return tcp.Acceptable(segSeq, segLen, s.una, s.end) + return header.Acceptable(segSeq, segLen, s.una, s.end) } // closed determines if the stream has already been closed. This happens when @@ -338,3 +337,16 @@ func logicalLen(tcp header.TCP) seqnum.Size { } return l } + +// IsEmpty returns true if tcb is not initialized. +func (t *TCB) IsEmpty() bool { + if t.inbound != (stream{}) || t.outbound != (stream{}) { + return false + } + + if t.firstFin != nil || t.state != ResultDrop { + return false + } + + return true +} diff --git a/test/iptables/filter_output.go b/test/iptables/filter_output.go index f6d974b85..b1382d353 100644 --- a/test/iptables/filter_output.go +++ b/test/iptables/filter_output.go @@ -42,7 +42,7 @@ func (FilterOutputDropTCPDestPort) Name() string { // ContainerAction implements TestCase.ContainerAction. func (FilterOutputDropTCPDestPort) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil { + if err := filterTable("-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--dport", "1024:65535", "-j", "DROP"); err != nil { return err } diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go index 334d8e676..63a862d35 100644 --- a/test/iptables/iptables_test.go +++ b/test/iptables/iptables_test.go @@ -115,30 +115,6 @@ func TestFilterInputDropOnlyUDP(t *testing.T) { singleTest(t, FilterInputDropOnlyUDP{}) } -func TestNATRedirectUDPPort(t *testing.T) { - // TODO(gvisor.dev/issue/170): Enable when supported. - t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).") - singleTest(t, NATRedirectUDPPort{}) -} - -func TestNATRedirectTCPPort(t *testing.T) { - // TODO(gvisor.dev/issue/170): Enable when supported. - t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).") - singleTest(t, NATRedirectTCPPort{}) -} - -func TestNATDropUDP(t *testing.T) { - // TODO(gvisor.dev/issue/170): Enable when supported. - t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).") - singleTest(t, NATDropUDP{}) -} - -func TestNATAcceptAll(t *testing.T) { - // TODO(gvisor.dev/issue/170): Enable when supported. - t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).") - singleTest(t, NATAcceptAll{}) -} - func TestFilterInputDropTCPDestPort(t *testing.T) { singleTest(t, FilterInputDropTCPDestPort{}) } @@ -164,14 +140,10 @@ func TestFilterInputReturnUnderflow(t *testing.T) { } func TestFilterOutputDropTCPDestPort(t *testing.T) { - // TODO(gvisor.dev/issue/170): Enable when supported. - t.Skip("filter OUTPUT isn't supported yet (gvisor.dev/issue/170).") singleTest(t, FilterOutputDropTCPDestPort{}) } func TestFilterOutputDropTCPSrcPort(t *testing.T) { - // TODO(gvisor.dev/issue/170): Enable when supported. - t.Skip("filter OUTPUT isn't supported yet (gvisor.dev/issue/170).") singleTest(t, FilterOutputDropTCPSrcPort{}) } @@ -235,44 +207,54 @@ func TestOutputInvertDestination(t *testing.T) { singleTest(t, FilterOutputInvertDestination{}) } +func TestNATPreRedirectUDPPort(t *testing.T) { + singleTest(t, NATPreRedirectUDPPort{}) +} + +func TestNATPreRedirectTCPPort(t *testing.T) { + singleTest(t, NATPreRedirectTCPPort{}) +} + +func TestNATOutRedirectUDPPort(t *testing.T) { + singleTest(t, NATOutRedirectUDPPort{}) +} + +func TestNATOutRedirectTCPPort(t *testing.T) { + singleTest(t, NATOutRedirectTCPPort{}) +} + +func TestNATDropUDP(t *testing.T) { + singleTest(t, NATDropUDP{}) +} + +func TestNATAcceptAll(t *testing.T) { + singleTest(t, NATAcceptAll{}) +} + func TestNATOutRedirectIP(t *testing.T) { - // TODO(gvisor.dev/issue/170): Enable when supported. - t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).") singleTest(t, NATOutRedirectIP{}) } func TestNATOutDontRedirectIP(t *testing.T) { - // TODO(gvisor.dev/issue/170): Enable when supported. - t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).") singleTest(t, NATOutDontRedirectIP{}) } func TestNATOutRedirectInvert(t *testing.T) { - // TODO(gvisor.dev/issue/170): Enable when supported. - t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).") singleTest(t, NATOutRedirectInvert{}) } func TestNATPreRedirectIP(t *testing.T) { - // TODO(gvisor.dev/issue/170): Enable when supported. - t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).") singleTest(t, NATPreRedirectIP{}) } func TestNATPreDontRedirectIP(t *testing.T) { - // TODO(gvisor.dev/issue/170): Enable when supported. - t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).") singleTest(t, NATPreDontRedirectIP{}) } func TestNATPreRedirectInvert(t *testing.T) { - // TODO(gvisor.dev/issue/170): Enable when supported. - t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).") singleTest(t, NATPreRedirectInvert{}) } func TestNATRedirectRequiresProtocol(t *testing.T) { - // TODO(gvisor.dev/issue/170): Enable when supported. - t.Skip("NAT isn't supported yet (gvisor.dev/issue/170).") singleTest(t, NATRedirectRequiresProtocol{}) } diff --git a/test/iptables/iptables_util.go b/test/iptables/iptables_util.go index 2a00677be..2f988cd18 100644 --- a/test/iptables/iptables_util.go +++ b/test/iptables/iptables_util.go @@ -151,7 +151,7 @@ func connectTCP(ip net.IP, port int, timeout time.Duration) error { return err } if err := testutil.Poll(callback, timeout); err != nil { - return fmt.Errorf("timed out waiting to connect IP, most recent error: %v", err) + return fmt.Errorf("timed out waiting to connect IP on port %v, most recent error: %v", port, err) } return nil diff --git a/test/iptables/nat.go b/test/iptables/nat.go index 40096901c..0a10ce7fe 100644 --- a/test/iptables/nat.go +++ b/test/iptables/nat.go @@ -26,8 +26,10 @@ const ( ) func init() { - RegisterTestCase(NATRedirectUDPPort{}) - RegisterTestCase(NATRedirectTCPPort{}) + RegisterTestCase(NATPreRedirectUDPPort{}) + RegisterTestCase(NATPreRedirectTCPPort{}) + RegisterTestCase(NATOutRedirectUDPPort{}) + RegisterTestCase(NATOutRedirectTCPPort{}) RegisterTestCase(NATDropUDP{}) RegisterTestCase(NATAcceptAll{}) RegisterTestCase(NATPreRedirectIP{}) @@ -39,16 +41,16 @@ func init() { RegisterTestCase(NATRedirectRequiresProtocol{}) } -// NATRedirectUDPPort tests that packets are redirected to different port. -type NATRedirectUDPPort struct{} +// NATPreRedirectUDPPort tests that packets are redirected to different port. +type NATPreRedirectUDPPort struct{} // Name implements TestCase.Name. -func (NATRedirectUDPPort) Name() string { - return "NATRedirectUDPPort" +func (NATPreRedirectUDPPort) Name() string { + return "NATPreRedirectUDPPort" } // ContainerAction implements TestCase.ContainerAction. -func (NATRedirectUDPPort) ContainerAction(ip net.IP) error { +func (NATPreRedirectUDPPort) ContainerAction(ip net.IP) error { if err := natTable("-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil { return err } @@ -61,33 +63,53 @@ func (NATRedirectUDPPort) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (NATRedirectUDPPort) LocalAction(ip net.IP) error { +func (NATPreRedirectUDPPort) LocalAction(ip net.IP) error { return sendUDPLoop(ip, acceptPort, sendloopDuration) } -// NATRedirectTCPPort tests that connections are redirected on specified ports. -type NATRedirectTCPPort struct{} +// NATPreRedirectTCPPort tests that connections are redirected on specified ports. +type NATPreRedirectTCPPort struct{} // Name implements TestCase.Name. -func (NATRedirectTCPPort) Name() string { - return "NATRedirectTCPPort" +func (NATPreRedirectTCPPort) Name() string { + return "NATPreRedirectTCPPort" } // ContainerAction implements TestCase.ContainerAction. -func (NATRedirectTCPPort) ContainerAction(ip net.IP) error { - if err := natTable("-A", "PREROUTING", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil { +func (NATPreRedirectTCPPort) ContainerAction(ip net.IP) error { + if err := natTable("-A", "PREROUTING", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err != nil { return err } // Listen for TCP packets on redirect port. - return listenTCP(redirectPort, sendloopDuration) + return listenTCP(acceptPort, sendloopDuration) } // LocalAction implements TestCase.LocalAction. -func (NATRedirectTCPPort) LocalAction(ip net.IP) error { +func (NATPreRedirectTCPPort) LocalAction(ip net.IP) error { return connectTCP(ip, dropPort, sendloopDuration) } +// NATOutRedirectUDPPort tests that packets are redirected to different port. +type NATOutRedirectUDPPort struct{} + +// Name implements TestCase.Name. +func (NATOutRedirectUDPPort) Name() string { + return "NATOutRedirectUDPPort" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATOutRedirectUDPPort) ContainerAction(ip net.IP) error { + dest := []byte{200, 0, 0, 1} + return loopbackTest(dest, "-A", "OUTPUT", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)) +} + +// LocalAction implements TestCase.LocalAction. +func (NATOutRedirectUDPPort) LocalAction(ip net.IP) error { + // No-op. + return nil +} + // NATDropUDP tests that packets are not received in ports other than redirect // port. type NATDropUDP struct{} @@ -329,3 +351,52 @@ func loopbackTest(dest net.IP, args ...string) error { // sendCh will always take the full sendloop time. return <-sendCh } + +// NATOutRedirectTCPPort tests that connections are redirected on specified ports. +type NATOutRedirectTCPPort struct{} + +// Name implements TestCase.Name. +func (NATOutRedirectTCPPort) Name() string { + return "NATOutRedirectTCPPort" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATOutRedirectTCPPort) ContainerAction(ip net.IP) error { + if err := natTable("-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err != nil { + return err + } + + timeout := 20 * time.Second + dest := []byte{127, 0, 0, 1} + localAddr := net.TCPAddr{ + IP: dest, + Port: acceptPort, + } + + // Starts listening on port. + lConn, err := net.ListenTCP("tcp", &localAddr) + if err != nil { + return err + } + defer lConn.Close() + + // Accept connections on port. + lConn.SetDeadline(time.Now().Add(timeout)) + err = connectTCP(ip, dropPort, timeout) + if err != nil { + return err + } + + conn, err := lConn.AcceptTCP() + if err != nil { + return err + } + conn.Close() + + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (NATOutRedirectTCPPort) LocalAction(ip net.IP) error { + return nil +} -- cgit v1.2.3 From 485ca36adf18fbb587df9f34a2deea730882fb36 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Wed, 6 May 2020 15:57:42 -0700 Subject: Do not assume no DHCPv6 configurations Do not assume that networks need any DHCPv6 configurations. Instead, notify the NDP dispatcher in response to the first NDP RA's DHCPv6 flags, even if the flags indicate no DHCPv6 configurations are available. PiperOrigin-RevId: 310245068 --- .../stack/dhcpv6configurationfromndpra_string.go | 11 +++++----- pkg/tcpip/stack/ndp.go | 9 ++++---- pkg/tcpip/stack/ndp_test.go | 24 +++++++++++++++++++--- 3 files changed, 32 insertions(+), 12 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go b/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go index 8b4213eec..d199ded6a 100644 --- a/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go +++ b/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Code generated by "stringer -type=DHCPv6ConfigurationFromNDPRA"; DO NOT EDIT. +// Code generated by "stringer -type DHCPv6ConfigurationFromNDPRA"; DO NOT EDIT. package stack @@ -22,9 +22,9 @@ func _() { // An "invalid array index" compiler error signifies that the constant values have changed. // Re-run the stringer command to generate them again. var x [1]struct{} - _ = x[DHCPv6NoConfiguration-0] - _ = x[DHCPv6ManagedAddress-1] - _ = x[DHCPv6OtherConfigurations-2] + _ = x[DHCPv6NoConfiguration-1] + _ = x[DHCPv6ManagedAddress-2] + _ = x[DHCPv6OtherConfigurations-3] } const _DHCPv6ConfigurationFromNDPRA_name = "DHCPv6NoConfigurationDHCPv6ManagedAddressDHCPv6OtherConfigurations" @@ -32,8 +32,9 @@ const _DHCPv6ConfigurationFromNDPRA_name = "DHCPv6NoConfigurationDHCPv6ManagedAd var _DHCPv6ConfigurationFromNDPRA_index = [...]uint8{0, 21, 41, 66} func (i DHCPv6ConfigurationFromNDPRA) String() string { + i -= 1 if i < 0 || i >= DHCPv6ConfigurationFromNDPRA(len(_DHCPv6ConfigurationFromNDPRA_index)-1) { - return "DHCPv6ConfigurationFromNDPRA(" + strconv.FormatInt(int64(i), 10) + ")" + return "DHCPv6ConfigurationFromNDPRA(" + strconv.FormatInt(int64(i+1), 10) + ")" } return _DHCPv6ConfigurationFromNDPRA_name[_DHCPv6ConfigurationFromNDPRA_index[i]:_DHCPv6ConfigurationFromNDPRA_index[i+1]] } diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 15343acbc..526c7d6ff 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -199,9 +199,11 @@ var ( type DHCPv6ConfigurationFromNDPRA int const ( + _ DHCPv6ConfigurationFromNDPRA = iota + // DHCPv6NoConfiguration indicates that no configurations are available via // DHCPv6. - DHCPv6NoConfiguration DHCPv6ConfigurationFromNDPRA = iota + DHCPv6NoConfiguration // DHCPv6ManagedAddress indicates that addresses are available via DHCPv6. // @@ -315,9 +317,6 @@ type NDPDispatcher interface { // OnDHCPv6Configuration will be called with an updated configuration that is // available via DHCPv6 for a specified NIC. // - // NDPDispatcher assumes that the initial configuration available by DHCPv6 is - // DHCPv6NoConfiguration. - // // This function is not permitted to block indefinitely. It must not // call functions on the stack itself. OnDHCPv6Configuration(tcpip.NICID, DHCPv6ConfigurationFromNDPRA) @@ -1808,6 +1807,8 @@ func (ndp *ndpState) cleanupState(hostOnly bool) { if got := len(ndp.defaultRouters); got != 0 { panic(fmt.Sprintf("ndp: still have discovered default routers after cleaning up; found = %d", got)) } + + ndp.dhcpv6Configuration = 0 } // startSolicitingRouters starts soliciting routers, as per RFC 4861 section diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 67f012840..b3d174cdd 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -4888,7 +4888,12 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { } } - // The initial DHCPv6 configuration should be stack.DHCPv6NoConfiguration. + // Even if the first RA reports no DHCPv6 configurations are available, the + // dispatcher should get an event. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) + expectDHCPv6Event(stack.DHCPv6NoConfiguration) + // Receiving the same update again should not result in an event to the + // dispatcher. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) expectNoDHCPv6Event() @@ -4896,8 +4901,6 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { // Configurations. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) expectDHCPv6Event(stack.DHCPv6OtherConfigurations) - // Receiving the same update again should not result in an event to the - // NDPDispatcher. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) expectNoDHCPv6Event() @@ -4933,6 +4936,21 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { expectDHCPv6Event(stack.DHCPv6OtherConfigurations) e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) expectNoDHCPv6Event() + + // Cycling the NIC should cause the last DHCPv6 configuration to be cleared. + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + + // Receive an RA that updates the DHCPv6 configuration to Other + // Configurations. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) + expectDHCPv6Event(stack.DHCPv6OtherConfigurations) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) + expectNoDHCPv6Event() } // TestRouterSolicitation tests the initial Router Solicitations that are sent -- cgit v1.2.3 From cfd30665c1d857f20dd05e67c6da6833770e2141 Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Fri, 8 May 2020 15:39:04 -0700 Subject: iptables - filter packets using outgoing interface. Enables commands with -o (--out-interface) for iptables rules. $ iptables -A OUTPUT -o eth0 -j ACCEPT PiperOrigin-RevId: 310642286 --- pkg/abi/linux/netfilter.go | 2 +- pkg/sentry/socket/netfilter/netfilter.go | 40 ++++++-- pkg/tcpip/network/ipv4/ipv4.go | 8 +- pkg/tcpip/stack/iptables.go | 42 ++++++-- pkg/tcpip/stack/iptables_types.go | 13 +++ pkg/tcpip/stack/nic.go | 2 +- pkg/tcpip/stack/stack.go | 16 ++- test/iptables/filter_output.go | 170 +++++++++++++++++++++++++++++++ test/iptables/iptables_test.go | 24 +++++ test/iptables/iptables_util.go | 15 +++ 10 files changed, 309 insertions(+), 23 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go index a8d4f9d69..46d8b0b42 100644 --- a/pkg/abi/linux/netfilter.go +++ b/pkg/abi/linux/netfilter.go @@ -146,7 +146,7 @@ type IPTIP struct { // OutputInterface is the output network interface. OutputInterface [IFNAMSIZ]byte - // InputInterfaceMask is the intput interface mask. + // InputInterfaceMask is the input interface mask. InputInterfaceMask [IFNAMSIZ]byte // OuputInterfaceMask is the output interface mask. diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 40736fb38..41a1ce031 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -17,6 +17,7 @@ package netfilter import ( + "bytes" "errors" "fmt" @@ -204,6 +205,16 @@ func convertNetstackToBinary(tablename string, table stack.Table) (linux.KernelI TargetOffset: linux.SizeOfIPTEntry, }, } + copy(entry.IPTEntry.IP.Dst[:], rule.Filter.Dst) + copy(entry.IPTEntry.IP.DstMask[:], rule.Filter.DstMask) + copy(entry.IPTEntry.IP.OutputInterface[:], rule.Filter.OutputInterface) + copy(entry.IPTEntry.IP.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask) + if rule.Filter.DstInvert { + entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_DSTIP + } + if rule.Filter.OutputInterfaceInvert { + entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_VIA_OUT + } for _, matcher := range rule.Matchers { // Serialize the matcher and add it to the @@ -719,11 +730,27 @@ func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) { if len(iptip.Dst) != header.IPv4AddressSize || len(iptip.DstMask) != header.IPv4AddressSize { return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of destination (%d) and/or destination mask (%d) fields", len(iptip.Dst), len(iptip.DstMask)) } + + n := bytes.IndexByte([]byte(iptip.OutputInterface[:]), 0) + if n == -1 { + n = len(iptip.OutputInterface) + } + ifname := string(iptip.OutputInterface[:n]) + + n = bytes.IndexByte([]byte(iptip.OutputInterfaceMask[:]), 0) + if n == -1 { + n = len(iptip.OutputInterfaceMask) + } + ifnameMask := string(iptip.OutputInterfaceMask[:n]) + return stack.IPHeaderFilter{ - Protocol: tcpip.TransportProtocolNumber(iptip.Protocol), - Dst: tcpip.Address(iptip.Dst[:]), - DstMask: tcpip.Address(iptip.DstMask[:]), - DstInvert: iptip.InverseFlags&linux.IPT_INV_DSTIP != 0, + Protocol: tcpip.TransportProtocolNumber(iptip.Protocol), + Dst: tcpip.Address(iptip.Dst[:]), + DstMask: tcpip.Address(iptip.DstMask[:]), + DstInvert: iptip.InverseFlags&linux.IPT_INV_DSTIP != 0, + OutputInterface: ifname, + OutputInterfaceMask: ifnameMask, + OutputInterfaceInvert: iptip.InverseFlags&linux.IPT_INV_VIA_OUT != 0, }, nil } @@ -732,16 +759,15 @@ func containsUnsupportedFields(iptip linux.IPTIP) bool { // - Protocol // - Dst and DstMask // - The inverse destination IP check flag + // - OutputInterface, OutputInterfaceMask and its inverse. var emptyInetAddr = linux.InetAddr{} var emptyInterface = [linux.IFNAMSIZ]byte{} // Disable any supported inverse flags. - inverseMask := uint8(linux.IPT_INV_DSTIP) + inverseMask := uint8(linux.IPT_INV_DSTIP) | uint8(linux.IPT_INV_VIA_OUT) return iptip.Src != emptyInetAddr || iptip.SrcMask != emptyInetAddr || iptip.InputInterface != emptyInterface || - iptip.OutputInterface != emptyInterface || iptip.InputInterfaceMask != emptyInterface || - iptip.OutputInterfaceMask != emptyInterface || iptip.Flags != 0 || iptip.InverseFlags&^inverseMask != 0 } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 9db42b2a4..64046cbbf 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -249,10 +249,11 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) pkt.NetworkHeader = buffer.View(ip) + nicName := e.stack.FindNICNameFromID(e.NICID()) // iptables filtering. All packets that reach here are locally // generated. ipt := e.stack.IPTables() - if ok := ipt.Check(stack.Output, &pkt, gso, r, ""); !ok { + if ok := ipt.Check(stack.Output, &pkt, gso, r, "", nicName); !ok { // iptables is telling us to drop the packet. return nil } @@ -319,10 +320,11 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe pkt = pkt.Next() } + nicName := e.stack.FindNICNameFromID(e.NICID()) // iptables filtering. All packets that reach here are locally // generated. ipt := e.stack.IPTables() - dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r) + dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName) if len(dropped) == 0 && len(natPkts) == 0 { // Fast path: If no packets are to be dropped then we can just invoke the // faster WritePackets API directly. @@ -445,7 +447,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. ipt := e.stack.IPTables() - if ok := ipt.Check(stack.Input, &pkt, nil, nil, ""); !ok { + if ok := ipt.Check(stack.Input, &pkt, nil, nil, "", ""); !ok { // iptables is telling us to drop the packet. return } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 7c3c47d50..443423b3c 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -16,6 +16,7 @@ package stack import ( "fmt" + "strings" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -178,7 +179,7 @@ const ( // dropped. // // Precondition: pkt.NetworkHeader is set. -func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, address tcpip.Address) bool { +func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, address tcpip.Address, nicName string) bool { // Packets are manipulated only if connection and matching // NAT rule exists. it.connections.HandlePacket(pkt, hook, gso, r) @@ -187,7 +188,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr for _, tablename := range it.Priorities[hook] { table := it.Tables[tablename] ruleIdx := table.BuiltinChains[hook] - switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address); verdict { + switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict { // If the table returns Accept, move on to the next table. case chainAccept: continue @@ -228,10 +229,10 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr // // NOTE: unlike the Check API the returned map contains packets that should be // dropped. -func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *Route) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { +func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *Route, nicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { if !pkt.NatDone { - if ok := it.Check(hook, pkt, gso, r, ""); !ok { + if ok := it.Check(hook, pkt, gso, r, "", nicName); !ok { if drop == nil { drop = make(map[*PacketBuffer]struct{}) } @@ -251,11 +252,11 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r * // Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. // TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a // precondition. -func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address) chainVerdict { +func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. for ruleIdx < len(table.Rules) { - switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, address); verdict { + switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict { case RuleAccept: return chainAccept @@ -272,7 +273,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId ruleIdx++ continue } - switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, address); verdict { + switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, address, nicName); verdict { case chainAccept: return chainAccept case chainDrop: @@ -298,7 +299,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId // Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. // TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a // precondition. -func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) { +func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) (RuleVerdict, int) { rule := table.Rules[ruleIdx] // If pkt.NetworkHeader hasn't been set yet, it will be contained in @@ -313,7 +314,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx } // Check whether the packet matches the IP header filter. - if !filterMatch(rule.Filter, header.IPv4(pkt.NetworkHeader)) { + if !filterMatch(rule.Filter, header.IPv4(pkt.NetworkHeader), hook, nicName) { // Continue on to the next rule. return RuleJump, ruleIdx + 1 } @@ -335,7 +336,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx return rule.Target.Action(pkt, &it.connections, hook, gso, r, address) } -func filterMatch(filter IPHeaderFilter, hdr header.IPv4) bool { +func filterMatch(filter IPHeaderFilter, hdr header.IPv4, hook Hook, nicName string) bool { // TODO(gvisor.dev/issue/170): Support other fields of the filter. // Check the transport protocol. if filter.Protocol != 0 && filter.Protocol != hdr.TransportProtocol() { @@ -355,5 +356,26 @@ func filterMatch(filter IPHeaderFilter, hdr header.IPv4) bool { return false } + // Check the output interface. + // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING + // hooks after supported. + if hook == Output { + n := len(filter.OutputInterface) + if n == 0 { + return true + } + + // If the interface name ends with '+', any interface which begins + // with the name should be matched. + ifName := filter.OutputInterface + matches = true + if strings.HasSuffix(ifName, "+") { + matches = strings.HasPrefix(nicName, ifName[:n-1]) + } else { + matches = nicName == ifName + } + return filter.OutputInterfaceInvert != matches + } + return true } diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 1bb0ba1bd..fe06007ae 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -158,6 +158,19 @@ type IPHeaderFilter struct { // true the filter will match packets that fail the destination // comparison. DstInvert bool + + // OutputInterface matches the name of the outgoing interface for the + // packet. + OutputInterface string + + // OutputInterfaceMask masks the characters of the interface name when + // comparing with OutputInterface. + OutputInterfaceMask string + + // OutputInterfaceInvert inverts the meaning of outgoing interface check, + // i.e. when true the filter will match packets that fail the outgoing + // interface comparison. + OutputInterfaceInvert bool } // A Matcher is the interface for matching packets. diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 8f4c1fe42..54103fdb3 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1233,7 +1233,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link // iptables filtering. ipt := n.stack.IPTables() address := n.primaryAddress(protocol) - if ok := ipt.Check(Prerouting, &pkt, nil, nil, address.Address); !ok { + if ok := ipt.Check(Prerouting, &pkt, nil, nil, address.Address, ""); !ok { // iptables is telling us to drop the packet. return } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index e33fae4eb..b39ffa9fb 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1898,9 +1898,23 @@ func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, addres nic.mu.RLock() defer nic.mu.RUnlock() - // An endpoint with this id exists, check if it can be used and return it. + // An endpoint with this id exists, check if it can be + // used and return it. return ref.ep, nil } } return nil, tcpip.ErrBadAddress } + +// FindNICNameFromID returns the name of the nic for the given NICID. +func (s *Stack) FindNICNameFromID(id tcpip.NICID) string { + s.mu.Lock() + defer s.mu.Unlock() + + nic, ok := s.nics[id] + if !ok { + return "" + } + + return nic.Name() +} diff --git a/test/iptables/filter_output.go b/test/iptables/filter_output.go index b1382d353..c145bd1e9 100644 --- a/test/iptables/filter_output.go +++ b/test/iptables/filter_output.go @@ -29,6 +29,12 @@ func init() { RegisterTestCase(FilterOutputAcceptUDPOwner{}) RegisterTestCase(FilterOutputDropUDPOwner{}) RegisterTestCase(FilterOutputOwnerFail{}) + RegisterTestCase(FilterOutputInterfaceAccept{}) + RegisterTestCase(FilterOutputInterfaceDrop{}) + RegisterTestCase(FilterOutputInterface{}) + RegisterTestCase(FilterOutputInterfaceBeginsWith{}) + RegisterTestCase(FilterOutputInterfaceInvertDrop{}) + RegisterTestCase(FilterOutputInterfaceInvertAccept{}) } // FilterOutputDropTCPDestPort tests that connections are not accepted on @@ -286,3 +292,167 @@ func (FilterOutputInvertDestination) ContainerAction(ip net.IP) error { func (FilterOutputInvertDestination) LocalAction(ip net.IP) error { return listenUDP(acceptPort, sendloopDuration) } + +// FilterOutputInterfaceAccept tests that packets are sent via interface +// matching the iptables rule. +type FilterOutputInterfaceAccept struct{} + +// Name implements TestCase.Name. +func (FilterOutputInterfaceAccept) Name() string { + return "FilterOutputInterfaceAccept" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputInterfaceAccept) ContainerAction(ip net.IP) error { + ifname, ok := getInterfaceName() + if !ok { + return fmt.Errorf("no interface is present, except loopback") + } + if err := filterTable("-A", "OUTPUT", "-p", "udp", "-o", ifname, "-j", "ACCEPT"); err != nil { + return err + } + + return sendUDPLoop(ip, acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputInterfaceAccept) LocalAction(ip net.IP) error { + return listenUDP(acceptPort, sendloopDuration) +} + +// FilterOutputInterfaceDrop tests that packets are not sent via interface +// matching the iptables rule. +type FilterOutputInterfaceDrop struct{} + +// Name implements TestCase.Name. +func (FilterOutputInterfaceDrop) Name() string { + return "FilterOutputInterfaceDrop" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputInterfaceDrop) ContainerAction(ip net.IP) error { + ifname, ok := getInterfaceName() + if !ok { + return fmt.Errorf("no interface is present, except loopback") + } + if err := filterTable("-A", "OUTPUT", "-p", "udp", "-o", ifname, "-j", "DROP"); err != nil { + return err + } + + return sendUDPLoop(ip, acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputInterfaceDrop) LocalAction(ip net.IP) error { + if err := listenUDP(acceptPort, sendloopDuration); err == nil { + return fmt.Errorf("packets should not be received on port %v, but are received", acceptPort) + } + + return nil +} + +// FilterOutputInterface tests that packets are sent via interface which is +// not matching the interface name in the iptables rule. +type FilterOutputInterface struct{} + +// Name implements TestCase.Name. +func (FilterOutputInterface) Name() string { + return "FilterOutputInterface" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputInterface) ContainerAction(ip net.IP) error { + if err := filterTable("-A", "OUTPUT", "-p", "udp", "-o", "lo", "-j", "DROP"); err != nil { + return err + } + + return sendUDPLoop(ip, acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputInterface) LocalAction(ip net.IP) error { + return listenUDP(acceptPort, sendloopDuration) +} + +// FilterOutputInterfaceBeginsWith tests that packets are not sent via an +// interface which begins with the given interface name. +type FilterOutputInterfaceBeginsWith struct{} + +// Name implements TestCase.Name. +func (FilterOutputInterfaceBeginsWith) Name() string { + return "FilterOutputInterfaceBeginsWith" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputInterfaceBeginsWith) ContainerAction(ip net.IP) error { + if err := filterTable("-A", "OUTPUT", "-p", "udp", "-o", "e+", "-j", "DROP"); err != nil { + return err + } + + return sendUDPLoop(ip, acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputInterfaceBeginsWith) LocalAction(ip net.IP) error { + if err := listenUDP(acceptPort, sendloopDuration); err == nil { + return fmt.Errorf("packets should not be received on port %v, but are received", acceptPort) + } + + return nil +} + +// FilterOutputInterfaceInvertDrop tests that we selectively do not send +// packets via interface not matching the interface name. +type FilterOutputInterfaceInvertDrop struct{} + +// Name implements TestCase.Name. +func (FilterOutputInterfaceInvertDrop) Name() string { + return "FilterOutputInterfaceInvertDrop" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputInterfaceInvertDrop) ContainerAction(ip net.IP) error { + if err := filterTable("-A", "OUTPUT", "-p", "tcp", "!", "-o", "lo", "-j", "DROP"); err != nil { + return err + } + + // Listen for TCP packets on accept port. + if err := listenTCP(acceptPort, sendloopDuration); err == nil { + return fmt.Errorf("connection on port %d should not be accepted, but got accepted", acceptPort) + } + + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputInterfaceInvertDrop) LocalAction(ip net.IP) error { + if err := connectTCP(ip, acceptPort, sendloopDuration); err == nil { + return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", acceptPort) + } + + return nil +} + +// FilterOutputInterfaceInvertAccept tests that we can selectively send packets +// not matching the specific outgoing interface. +type FilterOutputInterfaceInvertAccept struct{} + +// Name implements TestCase.Name. +func (FilterOutputInterfaceInvertAccept) Name() string { + return "FilterOutputInterfaceInvertAccept" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterOutputInterfaceInvertAccept) ContainerAction(ip net.IP) error { + if err := filterTable("-A", "OUTPUT", "-p", "tcp", "!", "-o", "lo", "-j", "ACCEPT"); err != nil { + return err + } + + // Listen for TCP packets on accept port. + return listenTCP(acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterOutputInterfaceInvertAccept) LocalAction(ip net.IP) error { + return connectTCP(ip, acceptPort, sendloopDuration) +} diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go index 63a862d35..84eb75a40 100644 --- a/test/iptables/iptables_test.go +++ b/test/iptables/iptables_test.go @@ -167,6 +167,30 @@ func TestFilterOutputOwnerFail(t *testing.T) { singleTest(t, FilterOutputOwnerFail{}) } +func TestFilterOutputInterfaceAccept(t *testing.T) { + singleTest(t, FilterOutputInterfaceAccept{}) +} + +func TestFilterOutputInterfaceDrop(t *testing.T) { + singleTest(t, FilterOutputInterfaceDrop{}) +} + +func TestFilterOutputInterface(t *testing.T) { + singleTest(t, FilterOutputInterface{}) +} + +func TestFilterOutputInterfaceBeginsWith(t *testing.T) { + singleTest(t, FilterOutputInterfaceBeginsWith{}) +} + +func TestFilterOutputInterfaceInvertDrop(t *testing.T) { + singleTest(t, FilterOutputInterfaceInvertDrop{}) +} + +func TestFilterOutputInterfaceInvertAccept(t *testing.T) { + singleTest(t, FilterOutputInterfaceInvertAccept{}) +} + func TestJumpSerialize(t *testing.T) { singleTest(t, FilterInputSerializeJump{}) } diff --git a/test/iptables/iptables_util.go b/test/iptables/iptables_util.go index 2f988cd18..7146edbb9 100644 --- a/test/iptables/iptables_util.go +++ b/test/iptables/iptables_util.go @@ -169,3 +169,18 @@ func localAddrs() ([]string, error) { } return addrStrs, nil } + +// getInterfaceName returns the name of the interface other than loopback. +func getInterfaceName() (string, bool) { + var ifname string + if interfaces, err := net.Interfaces(); err == nil { + for _, intf := range interfaces { + if intf.Name != "lo" { + ifname = intf.Name + break + } + } + } + + return ifname, ifname != "" +} -- cgit v1.2.3 From 420b791a3d6e0e6e2fc30c6f8be013bce7ca6549 Mon Sep 17 00:00:00 2001 From: Adin Scannell Date: Fri, 15 May 2020 20:03:54 -0700 Subject: Minor formatting updates for gvisor.dev. * Aggregate architecture Overview in "What is gVisor?" as it makes more sense in one place. * Drop "user-space kernel" and use "application kernel". The term "user-space kernel" is confusing when some platform implementation do not run in user-space (instead running in guest ring zero). * Clear up the relationship between the Platform page in the user guide and the Platform page in the architecture guide, and ensure they are cross-linked. * Restore the call-to-action quick start link in the main page, and drop the GitHub link (which also appears in the top-right). * Improve image formatting by centering all doc and blog images, and move the image captions to the alt text. PiperOrigin-RevId: 311845158 --- README.md | 79 ++--- g3doc/BUILD | 4 + g3doc/Layers.png | Bin 0 -> 11044 bytes g3doc/Layers.svg | 1 + g3doc/Machine-Virtualization.png | Bin 0 -> 13205 bytes g3doc/Machine-Virtualization.svg | 1 + g3doc/README.md | 161 +++++++++- g3doc/Rule-Based-Execution.png | Bin 0 -> 6780 bytes g3doc/Rule-Based-Execution.svg | 1 + g3doc/Sentry-Gofer.png | Bin 0 -> 9064 bytes g3doc/Sentry-Gofer.svg | 1 + g3doc/architecture_guide/BUILD | 30 +- g3doc/architecture_guide/Layers.png | Bin 11044 -> 0 bytes g3doc/architecture_guide/Layers.svg | 1 - .../architecture_guide/Machine-Virtualization.png | Bin 13205 -> 0 bytes .../architecture_guide/Machine-Virtualization.svg | 1 - g3doc/architecture_guide/README.md | 83 ----- g3doc/architecture_guide/Rule-Based-Execution.png | Bin 6780 -> 0 bytes g3doc/architecture_guide/Rule-Based-Execution.svg | 1 - g3doc/architecture_guide/Sentry-Gofer.png | Bin 9064 -> 0 bytes g3doc/architecture_guide/Sentry-Gofer.svg | 1 - g3doc/architecture_guide/performance.md | 35 ++- g3doc/architecture_guide/platforms.md | 109 +++---- g3doc/architecture_guide/platforms.png | Bin 0 -> 21384 bytes g3doc/architecture_guide/platforms.svg | 334 +++++++++++++++++++++ g3doc/architecture_guide/resources.md | 27 +- g3doc/architecture_guide/resources.png | Bin 0 -> 16621 bytes g3doc/architecture_guide/resources.svg | 208 +++++++++++++ g3doc/architecture_guide/security.md | 28 +- g3doc/architecture_guide/security.png | Bin 0 -> 16932 bytes g3doc/architecture_guide/security.svg | 153 ++++++++++ g3doc/user_guide/filesystem.md | 4 +- g3doc/user_guide/platforms.md | 100 +++--- pkg/sentry/arch/syscalls_arm64.go | 2 +- pkg/sentry/kernel/pipe/pipe_util.go | 2 +- pkg/sentry/kernel/task_syscall.go | 4 +- pkg/sentry/socket/netstack/netstack.go | 6 +- pkg/tcpip/stack/stack.go | 4 +- pkg/tcpip/transport/tcp/endpoint.go | 2 +- pkg/tcpip/transport/tcp/tcp_test.go | 2 +- runsc/cmd/help.go | 12 +- website/BUILD | 1 - website/_layouts/docs.html | 2 + website/_sass/front.scss | 4 +- website/_sass/style.scss | 10 + website/blog/2019-11-18-security-basics.md | 28 +- website/blog/2020-04-02-networking-security.md | 8 +- website/index.md | 10 +- 48 files changed, 1054 insertions(+), 406 deletions(-) create mode 100644 g3doc/Layers.png create mode 100644 g3doc/Layers.svg create mode 100644 g3doc/Machine-Virtualization.png create mode 100644 g3doc/Machine-Virtualization.svg create mode 100644 g3doc/Rule-Based-Execution.png create mode 100644 g3doc/Rule-Based-Execution.svg create mode 100644 g3doc/Sentry-Gofer.png create mode 100644 g3doc/Sentry-Gofer.svg delete mode 100644 g3doc/architecture_guide/Layers.png delete mode 100644 g3doc/architecture_guide/Layers.svg delete mode 100644 g3doc/architecture_guide/Machine-Virtualization.png delete mode 100644 g3doc/architecture_guide/Machine-Virtualization.svg delete mode 100644 g3doc/architecture_guide/README.md delete mode 100644 g3doc/architecture_guide/Rule-Based-Execution.png delete mode 100644 g3doc/architecture_guide/Rule-Based-Execution.svg delete mode 100644 g3doc/architecture_guide/Sentry-Gofer.png delete mode 100644 g3doc/architecture_guide/Sentry-Gofer.svg create mode 100644 g3doc/architecture_guide/platforms.png create mode 100644 g3doc/architecture_guide/platforms.svg create mode 100644 g3doc/architecture_guide/resources.png create mode 100644 g3doc/architecture_guide/resources.svg create mode 100644 g3doc/architecture_guide/security.png create mode 100644 g3doc/architecture_guide/security.svg (limited to 'pkg/tcpip/stack') diff --git a/README.md b/README.md index de3e06f4e..442f5672a 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ ## What is gVisor? -**gVisor** is a user-space kernel, written in Go, that implements a substantial +**gVisor** is a application kernel, written in Go, that implements a substantial portion of the Linux system surface. It includes an [Open Container Initiative (OCI)][oci] runtime called `runsc` that provides an isolation boundary between the application and the host kernel. The `runsc` @@ -15,16 +15,17 @@ containers. ## Why does gVisor exist? Containers are not a [**sandbox**][sandbox]. While containers have -revolutionized how we develop, package, and deploy applications, running -untrusted or potentially malicious code without additional isolation is not a -good idea. The efficiency and performance gains from using a single, shared -kernel also mean that container escape is possible with a single vulnerability. - -gVisor is a user-space kernel for containers. It limits the host kernel surface -accessible to the application while still giving the application access to all -the features it expects. Unlike most kernels, gVisor does not assume or require -a fixed set of physical resources; instead, it leverages existing host kernel -functionality and runs as a normal user-space process. In other words, gVisor +revolutionized how we develop, package, and deploy applications, using them to +run untrusted or potentially malicious code without additional isolation is not +a good idea. While using a single, shared kernel allows for efficiency and +performance gains, it also means that container escape is possible with a single +vulnerability. + +gVisor is an application kernel for containers. It limits the host kernel +surface accessible to the application while still giving the application access +to all the features it expects. Unlike most kernels, gVisor does not assume or +require a fixed set of physical resources; instead, it leverages existing host +kernel functionality and runs as a normal process. In other words, gVisor implements Linux by way of Linux. gVisor should not be confused with technologies and tools to harden containers @@ -39,33 +40,24 @@ be found at [gvisor.dev][gvisor-dev]. ## Installing from source -gVisor currently requires x86\_64 Linux to build, though support for other -architectures may become available in the future. +gVisor builds on x86_64 and ARM64. Other architectures may become available in +the future. + +For the purposes of these instructions, [bazel][bazel] and other build +dependencies are wrapped in a build container. It is possible to use +[bazel][bazel] directly, or type `make help` for standard targets. ### Requirements Make sure the following dependencies are installed: * Linux 4.14.77+ ([older linux][old-linux]) -* [git][git] -* [Bazel][bazel] 1.2+ -* [Python][python] * [Docker version 17.09.0 or greater][docker] -* C++ toolchain supporting C++17 (GCC 7+, Clang 5+) -* Gold linker (e.g. `binutils-gold` package on Ubuntu) ### Building Build and install the `runsc` binary: -``` -bazel build runsc -sudo cp ./bazel-bin/runsc/linux_amd64_pure_stripped/runsc /usr/local/bin -``` - -If you don't want to install bazel on your system, you can build runsc in a -Docker container: - ``` make runsc sudo cp ./bazel-bin/runsc/linux_amd64_pure_stripped/runsc /usr/local/bin @@ -73,41 +65,19 @@ sudo cp ./bazel-bin/runsc/linux_amd64_pure_stripped/runsc /usr/local/bin ### Testing -The test suite can be run with Bazel: - -``` -bazel test //... -``` - -or in a Docker container: +To run standard test suites, you can use: ``` make unit-tests make tests ``` -### Using remote execution - -If you have a [Remote Build Execution][rbe] environment, you can use it to speed -up build and test cycles. - -You must authenticate with the project first: +To run specific tests, you can specify the target: ``` -gcloud auth application-default login --no-launch-browser +make test TARGET="//runsc:version_test" ``` -Then invoke bazel with the following flags: - -``` ---config=remote ---project_id=$PROJECT ---remote_instance_name=projects/$PROJECT/instances/default_instance -``` - -You can also add those flags to your local ~/.bazelrc to avoid needing to -specify them each time on the command line. - ### Using `go get` This project uses [bazel][bazel] to build and manage dependencies. A synthetic @@ -128,7 +98,7 @@ development on this branch is not supported. Development should occur on the ## Community & Governance -The governance model is documented in our [community][community] repository. +See [GOVERNANCE.md](GOVERANCE.md) for project governance information. The [gvisor-users mailing list][gvisor-users-list] and [gvisor-dev mailing list][gvisor-dev-list] are good starting points for @@ -145,12 +115,9 @@ See [Contributing.md](CONTRIBUTING.md). [bazel]: https://bazel.build [community]: https://gvisor.googlesource.com/community [docker]: https://www.docker.com -[git]: https://git-scm.com [gvisor-users-list]: https://groups.google.com/forum/#!forum/gvisor-users +[gvisor-dev]: https://gvisor.dev [gvisor-dev-list]: https://groups.google.com/forum/#!forum/gvisor-dev [oci]: https://www.opencontainers.org [old-linux]: https://gvisor.dev/docs/user_guide/networking/#gso -[python]: https://python.org -[rbe]: https://blog.bazel.build/2018/10/05/remote-build-execution.html [sandbox]: https://en.wikipedia.org/wiki/Sandbox_(computer_security) -[gvisor-dev]: https://gvisor.dev diff --git a/g3doc/BUILD b/g3doc/BUILD index 24177ad06..dbbf96204 100644 --- a/g3doc/BUILD +++ b/g3doc/BUILD @@ -9,6 +9,10 @@ doc( name = "index", src = "README.md", category = "Project", + data = glob([ + "*.png", + "*.svg", + ]), permalink = "/docs/", weight = "0", ) diff --git a/g3doc/Layers.png b/g3doc/Layers.png new file mode 100644 index 000000000..308c6c451 Binary files /dev/null and b/g3doc/Layers.png differ diff --git a/g3doc/Layers.svg b/g3doc/Layers.svg new file mode 100644 index 000000000..0a366f841 --- /dev/null +++ b/g3doc/Layers.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/g3doc/Machine-Virtualization.png b/g3doc/Machine-Virtualization.png new file mode 100644 index 000000000..1ba2ed6b2 Binary files /dev/null and b/g3doc/Machine-Virtualization.png differ diff --git a/g3doc/Machine-Virtualization.svg b/g3doc/Machine-Virtualization.svg new file mode 100644 index 000000000..5352da07b --- /dev/null +++ b/g3doc/Machine-Virtualization.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/g3doc/README.md b/g3doc/README.md index 7999f5d47..304a91493 100644 --- a/g3doc/README.md +++ b/g3doc/README.md @@ -1,6 +1,6 @@ # What is gVisor? -gVisor is a user-space kernel, written in Go, that implements a substantial +gVisor is an application kernel, written in Go, that implements a substantial portion of the [Linux system call interface][linux]. It provides an additional layer of isolation between running applications and the host operating system. @@ -9,19 +9,160 @@ that makes it easy to work with existing container tooling. The `runsc` runtime integrates with Docker and Kubernetes, making it simple to run sandboxed containers. -gVisor takes a distinct approach to container sandboxing and makes a different -set of technical trade-offs compared to existing sandbox technologies, thus -providing new tools and ideas for the container security landscape. - gVisor can be used with Docker, Kubernetes, or directly using `runsc`. Use the links below to see detailed instructions for each of them: -* [Docker](./user_guide/quick_start/docker/): The quickest and easiest way to - get started. -* [Kubernetes](./user_guide/quick_start/kubernetes/): Isolate Pods in your K8s - cluster with gVisor. -* [OCI Quick Start](./user_guide/quick_start/oci/): Expert mode. Customize +* [Docker](./user_guide/quick_start/docker.md): The quickest and easiest way + to get started. +* [Kubernetes](./user_guide/quick_start/kubernetes.md): Isolate Pods in your + K8s cluster with gVisor. +* [OCI Quick Start](./user_guide/quick_start/oci.md): Expert mode. Customize gVisor for your environment. +## What does gVisor do? + +gVisor provides a virtualized environment in order to sandbox containers. The +system interfaces normally implemented by the host kernel are moved into a +distinct, per-sandbox application kernel in order to minimize the risk of an +container escape exploit. gVisor does not introduce large fixed overheads +however, and still retains a process-like model with respect to resource +utilization. + +## How is this different? + +Two other approaches are commonly taken to provide stronger isolation than +native containers. + +**Machine-level virtualization**, such as [KVM][kvm] and [Xen][xen], exposes +virtualized hardware to a guest kernel via a Virtual Machine Monitor (VMM). This +virtualized hardware is generally enlightened (paravirtualized) and additional +mechanisms can be used to improve the visibility between the guest and host +(e.g. balloon drivers, paravirtualized spinlocks). Running containers in +distinct virtual machines can provide great isolation, compatibility and +performance (though nested virtualization may bring challenges in this area), +but for containers it often requires additional proxies and agents, and may +require a larger resource footprint and slower start-up times. + +![Machine-level virtualization](Machine-Virtualization.png "Machine-level virtualization") + +**Rule-based execution**, such as [seccomp][seccomp], [SELinux][selinux] and +[AppArmor][apparmor], allows the specification of a fine-grained security policy +for an application or container. These schemes typically rely on hooks +implemented inside the host kernel to enforce the rules. If the surface can be +made small enough, then this is an excellent way to sandbox applications and +maintain native performance. However, in practice it can be extremely difficult +(if not impossible) to reliably define a policy for arbitrary, previously +unknown applications, making this approach challenging to apply universally. + +![Rule-based execution](Rule-Based-Execution.png "Rule-based execution") + +Rule-based execution is often combined with additional layers for +defense-in-depth. + +**gVisor** provides a third isolation mechanism, distinct from those above. + +gVisor intercepts application system calls and acts as the guest kernel, without +the need for translation through virtualized hardware. gVisor may be thought of +as either a merged guest kernel and VMM, or as seccomp on steroids. This +architecture allows it to provide a flexible resource footprint (i.e. one based +on threads and memory mappings, not fixed guest physical resources) while also +lowering the fixed costs of virtualization. However, this comes at the price of +reduced application compatibility and higher per-system call overhead. + +![gVisor](Layers.png "gVisor") + +On top of this, gVisor employs rule-based execution to provide defense-in-depth +(details below). + +gVisor's approach is similar to [User Mode Linux (UML)][uml], although UML +virtualizes hardware internally and thus provides a fixed resource footprint. + +Each of the above approaches may excel in distinct scenarios. For example, +machine-level virtualization will face challenges achieving high density, while +gVisor may provide poor performance for system call heavy workloads. + +## Why Go? + +gVisor is written in [Go][golang] in order to avoid security pitfalls that can +plague kernels. With Go, there are strong types, built-in bounds checks, no +uninitialized variables, no use-after-free, no stack overflow, and a built-in +race detector. However, the use of Go has its challenges, and the runtime often +introduces performance overhead. + +## What are the different components? + +A gVisor sandbox consists of multiple processes. These processes collectively +comprise an environment in which one or more containers can be run. + +Each sandbox has its own isolated instance of: + +* The **Sentry**, which is a kernel that runs the containers and intercepts + and responds to system calls made by the application. + +Each container running in the sandbox has its own isolated instance of: + +* A **Gofer** which provides file system access to the containers. + +![gVisor architecture diagram](Sentry-Gofer.png "gVisor architecture diagram") + +## What is runsc? + +The entrypoint to running a sandboxed container is the `runsc` executable. +`runsc` implements the [Open Container Initiative (OCI)][oci] runtime +specification, which is used by Docker and Kubernetes. This means that OCI +compatible _filesystem bundles_ can be run by `runsc`. Filesystem bundles are +comprised of a `config.json` file containing container configuration, and a root +filesystem for the container. Please see the [OCI runtime spec][runtime-spec] +for more information on filesystem bundles. `runsc` implements multiple commands +that perform various functions such as starting, stopping, listing, and querying +the status of containers. + +### Sentry + + + +The Sentry is the largest component of gVisor. It can be thought of as a +application kernel. The Sentry implements all the kernel functionality needed by +the application, including: system calls, signal delivery, memory management and +page faulting logic, the threading model, and more. + +When the application makes a system call, the +[Platform](./architecture_guide/platforms.md) redirects the call to the Sentry, +which will do the necessary work to service it. It is important to note that the +Sentry does not pass system calls through to the host kernel. As a userspace +application, the Sentry will make some host system calls to support its +operation, but it does not allow the application to directly control the system +calls it makes. For example, the Sentry is not able to open files directly; file +system operations that extend beyond the sandbox (not internal `/proc` files, +pipes, etc) are sent to the Gofer, described below. + +### Gofer + + + +The Gofer is a standard host process which is started with each container and +communicates with the Sentry via the [9P protocol][9p] over a socket or shared +memory channel. The Sentry process is started in a restricted seccomp container +without access to file system resources. The Gofer mediates all access to the +these resources, providing an additional level of isolation. + +### Application + +The application is a normal Linux binary provided to gVisor in an OCI runtime +bundle. gVisor aims to provide an environment equivalent to Linux v4.4, so +applications should be able to run unmodified. However, gVisor does not +presently implement every system call, `/proc` file, or `/sys` file so some +incompatibilities may occur. See [Commpatibility](./user_guide/compatibility.md) +for more information. + +[9p]: https://en.wikipedia.org/wiki/9P_(protocol) +[apparmor]: https://wiki.ubuntu.com/AppArmor +[golang]: https://golang.org +[kvm]: https://www.linux-kvm.org [linux]: https://en.wikipedia.org/wiki/Linux_kernel_interfaces [oci]: https://www.opencontainers.org +[runtime-spec]: https://github.com/opencontainers/runtime-spec +[seccomp]: https://www.kernel.org/doc/Documentation/prctl/seccomp_filter.txt +[selinux]: https://selinuxproject.org +[uml]: http://user-mode-linux.sourceforge.net/ +[xen]: https://www.xenproject.org diff --git a/g3doc/Rule-Based-Execution.png b/g3doc/Rule-Based-Execution.png new file mode 100644 index 000000000..b42654a90 Binary files /dev/null and b/g3doc/Rule-Based-Execution.png differ diff --git a/g3doc/Rule-Based-Execution.svg b/g3doc/Rule-Based-Execution.svg new file mode 100644 index 000000000..bd6717043 --- /dev/null +++ b/g3doc/Rule-Based-Execution.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/g3doc/Sentry-Gofer.png b/g3doc/Sentry-Gofer.png new file mode 100644 index 000000000..ca2c27ef7 Binary files /dev/null and b/g3doc/Sentry-Gofer.png differ diff --git a/g3doc/Sentry-Gofer.svg b/g3doc/Sentry-Gofer.svg new file mode 100644 index 000000000..5c10750d2 --- /dev/null +++ b/g3doc/Sentry-Gofer.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/g3doc/architecture_guide/BUILD b/g3doc/architecture_guide/BUILD index 72038305b..404f627a4 100644 --- a/g3doc/architecture_guide/BUILD +++ b/g3doc/architecture_guide/BUILD @@ -5,31 +5,13 @@ package( licenses = ["notice"], ) -doc( - name = "index", - src = "README.md", - category = "Architecture Guide", - data = [ - "Layers.png", - "Layers.svg", - "Machine-Virtualization.png", - "Machine-Virtualization.svg", - "Rule-Based-Execution.png", - "Rule-Based-Execution.svg", - "Sentry-Gofer.png", - "Sentry-Gofer.svg", - ], - permalink = "/docs/architecture_guide/", - weight = "0", -) - doc( name = "platforms", src = "platforms.md", category = "Architecture Guide", data = [ - "Sentry-Gofer.png", - "Sentry-Gofer.svg", + "platforms.png", + "platforms.svg", ], permalink = "/docs/architecture_guide/platforms/", weight = "40", @@ -39,6 +21,10 @@ doc( name = "resources", src = "resources.md", category = "Architecture Guide", + data = [ + "resources.png", + "resources.svg", + ], permalink = "/docs/architecture_guide/resources/", weight = "30", ) @@ -48,8 +34,8 @@ doc( src = "security.md", category = "Architecture Guide", data = [ - "Layers.png", - "Layers.svg", + "security.png", + "security.svg", ], permalink = "/docs/architecture_guide/security/", weight = "10", diff --git a/g3doc/architecture_guide/Layers.png b/g3doc/architecture_guide/Layers.png deleted file mode 100644 index 308c6c451..000000000 Binary files a/g3doc/architecture_guide/Layers.png and /dev/null differ diff --git a/g3doc/architecture_guide/Layers.svg b/g3doc/architecture_guide/Layers.svg deleted file mode 100644 index 0a366f841..000000000 --- a/g3doc/architecture_guide/Layers.svg +++ /dev/null @@ -1 +0,0 @@ - \ No newline at end of file diff --git a/g3doc/architecture_guide/Machine-Virtualization.png b/g3doc/architecture_guide/Machine-Virtualization.png deleted file mode 100644 index 1ba2ed6b2..000000000 Binary files a/g3doc/architecture_guide/Machine-Virtualization.png and /dev/null differ diff --git a/g3doc/architecture_guide/Machine-Virtualization.svg b/g3doc/architecture_guide/Machine-Virtualization.svg deleted file mode 100644 index 5352da07b..000000000 --- a/g3doc/architecture_guide/Machine-Virtualization.svg +++ /dev/null @@ -1 +0,0 @@ - \ No newline at end of file diff --git a/g3doc/architecture_guide/README.md b/g3doc/architecture_guide/README.md deleted file mode 100644 index ab9ef7174..000000000 --- a/g3doc/architecture_guide/README.md +++ /dev/null @@ -1,83 +0,0 @@ -# Overview - -gVisor provides a virtualized environment in order to sandbox untrusted -containers. The system interfaces normally implemented by the host kernel are -moved into a distinct, per-sandbox user space kernel in order to minimize the -risk of an exploit. gVisor does not introduce large fixed overheads however, and -still retains a process-like model with respect to resource utilization. - -## How is this different? - -Two other approaches are commonly taken to provide stronger isolation than -native containers. - -**Machine-level virtualization**, such as [KVM][kvm] and [Xen][xen], exposes -virtualized hardware to a guest kernel via a Virtual Machine Monitor (VMM). This -virtualized hardware is generally enlightened (paravirtualized) and additional -mechanisms can be used to improve the visibility between the guest and host -(e.g. balloon drivers, paravirtualized spinlocks). Running containers in -distinct virtual machines can provide great isolation, compatibility and -performance (though nested virtualization may bring challenges in this area), -but for containers it often requires additional proxies and agents, and may -require a larger resource footprint and slower start-up times. - -![Machine-level virtualization](Machine-Virtualization.png "Machine-level virtualization") - -**Rule-based execution**, such as [seccomp][seccomp], [SELinux][selinux] and -[AppArmor][apparmor], allows the specification of a fine-grained security policy -for an application or container. These schemes typically rely on hooks -implemented inside the host kernel to enforce the rules. If the surface can be -made small enough (i.e. a sufficiently complete policy defined), then this is an -excellent way to sandbox applications and maintain native performance. However, -in practice it can be extremely difficult (if not impossible) to reliably define -a policy for arbitrary, previously unknown applications, making this approach -challenging to apply universally. - -![Rule-based execution](Rule-Based-Execution.png "Rule-based execution") - -Rule-based execution is often combined with additional layers for -defense-in-depth. - -**gVisor** provides a third isolation mechanism, distinct from those above. - -gVisor intercepts application system calls and acts as the guest kernel, without -the need for translation through virtualized hardware. gVisor may be thought of -as either a merged guest kernel and VMM, or as seccomp on steroids. This -architecture allows it to provide a flexible resource footprint (i.e. one based -on threads and memory mappings, not fixed guest physical resources) while also -lowering the fixed costs of virtualization. However, this comes at the price of -reduced application compatibility and higher per-system call overhead. - -![gVisor](Layers.png "gVisor") - -On top of this, gVisor employs rule-based execution to provide defense-in-depth -(details below). - -gVisor's approach is similar to [User Mode Linux (UML)][uml], although UML -virtualizes hardware internally and thus provides a fixed resource footprint. - -Each of the above approaches may excel in distinct scenarios. For example, -machine-level virtualization will face challenges achieving high density, while -gVisor may provide poor performance for system call heavy workloads. - -### Why Go? - -gVisor is written in [Go][golang] in order to avoid security pitfalls that can -plague kernels. With Go, there are strong types, built-in bounds checks, no -uninitialized variables, no use-after-free, no stack overflow, and a built-in -race detector. (The use of Go has its challenges too, and isn't free.) - -### What about Gofers? - - - -Gofers mediate file system interactions, and are used to provide additional -isolation. For more details, see the [Platform Guide](./platforms.md). - -[apparmor]: https://wiki.ubuntu.com/AppArmor -[golang]: https://golang.org -[kvm]: https://www.linux-kvm.org -[seccomp]: https://www.kernel.org/doc/Documentation/prctl/seccomp_filter.txt -[selinux]: https://selinuxproject.org -[uml]: http://user-mode-linux.sourceforge.net/ -[xen]: https://www.xenproject.org diff --git a/g3doc/architecture_guide/Rule-Based-Execution.png b/g3doc/architecture_guide/Rule-Based-Execution.png deleted file mode 100644 index b42654a90..000000000 Binary files a/g3doc/architecture_guide/Rule-Based-Execution.png and /dev/null differ diff --git a/g3doc/architecture_guide/Rule-Based-Execution.svg b/g3doc/architecture_guide/Rule-Based-Execution.svg deleted file mode 100644 index bd6717043..000000000 --- a/g3doc/architecture_guide/Rule-Based-Execution.svg +++ /dev/null @@ -1 +0,0 @@ - \ No newline at end of file diff --git a/g3doc/architecture_guide/Sentry-Gofer.png b/g3doc/architecture_guide/Sentry-Gofer.png deleted file mode 100644 index ca2c27ef7..000000000 Binary files a/g3doc/architecture_guide/Sentry-Gofer.png and /dev/null differ diff --git a/g3doc/architecture_guide/Sentry-Gofer.svg b/g3doc/architecture_guide/Sentry-Gofer.svg deleted file mode 100644 index 5c10750d2..000000000 --- a/g3doc/architecture_guide/Sentry-Gofer.svg +++ /dev/null @@ -1 +0,0 @@ - \ No newline at end of file diff --git a/g3doc/architecture_guide/performance.md b/g3doc/architecture_guide/performance.md index 3862d78ee..39dbb0045 100644 --- a/g3doc/architecture_guide/performance.md +++ b/g3doc/architecture_guide/performance.md @@ -13,12 +13,13 @@ forms: additional cycles and memory usage, which may manifest as increased latency, reduced throughput or density, or not at all. In general, these costs come from two different sources. -First, the existence of the [Sentry](../) means that additional memory will be -required, and application system calls must traverse additional layers of -software. The design emphasizes [security](../security/) and therefore we chose -to use a language for the Sentry that provides benefits in this domain but may -not yet offer the raw performance of other choices. Costs imposed by these -design choices are **structural costs**. +First, the existence of the [Sentry](../README.md#sentry) means that additional +memory will be required, and application system calls must traverse additional +layers of software. The design emphasizes +[security](/docs/architecture_guide/security/) and therefore we chose to use a +language for the Sentry that provides benefits in this domain but may not yet +offer the raw performance of other choices. Costs imposed by these design +choices are **structural costs**. Second, as gVisor is an independent implementation of the system call surface, many of the subsystems or specific calls are not as optimized as more mature @@ -50,7 +51,7 @@ Virtual Machines (VMs) with the following specifications: Through this document, `runsc` is used to indicate the runtime provided by gVisor. When relevant, we use the name `runsc-platform` to describe a specific -[platform choice](../platforms/). +[platform choice](/docs/architecture_guide/platforms/). **Except where specified, all tests below are conducted with the `ptrace` platform. The `ptrace` platform works everywhere and does not require hardware @@ -131,11 +132,11 @@ full start-up and run time for the workload, which trains a model. ## System calls Some **structural costs** of gVisor are heavily influenced by the -[platform choice](../platforms/), which implements system call interception. -Today, gVisor supports a variety of platforms. These platforms present distinct -performance, compatibility and security trade-offs. For example, the KVM -platform has low overhead system call interception but runs poorly with nested -virtualization. +[platform choice](/docs/architecture_guide/platforms/), which implements system +call interception. Today, gVisor supports a variety of platforms. These +platforms present distinct performance, compatibility and security trade-offs. +For example, the KVM platform has low overhead system call interception but runs +poorly with nested virtualization. {% include graph.html id="syscall" url="/performance/syscall.csv" title="perf.py syscall --runtime=runc --runtime=runsc-ptrace --runtime=runsc-kvm" y_min="100" @@ -163,7 +164,8 @@ overhead. Some of these costs above are **structural costs**, and `redis` is likely to remain a challenging performance scenario. However, optimizing the -[platform](../platforms/) will also have a dramatic impact. +[platform](/docs/architecture_guide/platforms/) will also have a dramatic +impact. ## Start-up time @@ -184,7 +186,7 @@ similarly loads a number of modules and binds an HTTP server. > Note: most of the time overhead above is associated Docker itself. This is > evident with the empty `runc` benchmark. To avoid these costs with `runsc`, > you may also consider using `runsc do` mode or invoking the -> [OCI runtime](../../user_guide/quick_start/oci/) directly. +> [OCI runtime](../user_guide/quick_start/oci.md) directly. ## Network @@ -222,8 +224,9 @@ In terms of raw disk I/O, gVisor does not introduce significant fundamental overhead. For general file operations, gVisor introduces a small fixed overhead for data that transitions across the sandbox boundary. This manifests as **structural costs** in some cases, since these operations must be routed -through the [Gofer](../) as a result of our [security model](../security/), but -in most cases are dominated by **implementation costs**, due to an internal +through the [Gofer](../README.md#gofer) as a result of our +[Security Model](/docs/architecture_guide/security/), but in most cases are +dominated by **implementation costs**, due to an internal [Virtual File System][vfs] (VFS) implementation that needs improvement. {% include graph.html id="fio-bw" url="/performance/fio.csv" title="perf.py fio diff --git a/g3doc/architecture_guide/platforms.md b/g3doc/architecture_guide/platforms.md index 6e63da8ce..d112c9a28 100644 --- a/g3doc/architecture_guide/platforms.md +++ b/g3doc/architecture_guide/platforms.md @@ -1,86 +1,61 @@ # Platform Guide -A gVisor sandbox consists of multiple processes when running. These processes -collectively comprise a shared environment in which one or more containers can -be run. +[TOC] -Each sandbox has its own isolated instance of: - -* The **Sentry**, A user-space kernel that runs the container and intercepts - and responds to system calls made by the application. - -Each container running in the sandbox has its own isolated instance of: - -* A **Gofer** which provides file system access to the container. - -![gVisor architecture diagram](Sentry-Gofer.png "gVisor architecture diagram") - -## runsc - -The entrypoint to running a sandboxed container is the `runsc` executable. -`runsc` implements the [Open Container Initiative (OCI)][oci] runtime -specification. This means that OCI compatible _filesystem bundles_ can be run by -`runsc`. Filesystem bundles are comprised of a `config.json` file containing -container configuration, and a root filesystem for the container. Please see the -[OCI runtime spec][runtime-spec] for more information on filesystem bundles. -`runsc` implements multiple commands that perform various functions such as -starting, stopping, listing, and querying the status of containers. +gVisor requires a platform to implement interception of syscalls, basic context +switching, and memory mapping functionality. Internally, gVisor uses an +abstraction sensibly called [Platform][platform]. A simplified version of this +interface looks like: -## Sentry +```golang +type Platform interface { + NewAddressSpace() (AddressSpace, error) + NewContext() Context +} -The Sentry is the largest component of gVisor. It can be thought of as a -userspace OS kernel. The Sentry implements all the kernel functionality needed -by the untrusted application. It implements all of the supported system calls, -signal delivery, memory management and page faulting logic, the threading model, -and more. +type Context interface { + Switch(as AddressSpace, ac arch.Context) (..., error) +} -When the untrusted application makes a system call, the currently used platform -redirects the call to the Sentry, which will do the necessary work to service -it. It is important to note that the Sentry will not simply pass through system -calls to the host kernel. As a userspace application, the Sentry will make some -host system calls to support its operation, but it will not allow the -application to directly control the system calls it makes. +type AddressSpace interface { + MapFile(addr usermem.Addr, f File, fr FileRange, at usermem.AccessType, ...) error + Unmap(addr usermem.Addr, length uint64) +} +``` -The Sentry aims to present an equivalent environment to (upstream) Linux v4.4. +There are a number of different ways to implement this interface that come with +various trade-offs, generally around performance and hardware requirements. -File system operations that extend beyond the sandbox (not internal /proc files, -pipes, etc) are sent to the Gofer, described below. +## Implementations -## Platforms +The choice of platform depends on the context in which `runsc` is executing. In +general, virtualized platforms may be limited to platforms that do not require +hardware virtualized support (since the hardware is already in use): -gVisor requires a platform to implement interception of syscalls, basic context -switching, and memory mapping functionality. +![Platforms](platforms.png "Platform examples.") ### ptrace -The ptrace platform uses `PTRACE_SYSEMU` to execute user code without allowing -it to execute host system calls. This platform can run anywhere that ptrace -works (even VMs without nested virtualization). - -### KVM (experimental) +The ptrace platform uses [PTRACE_SYSEMU][ptrace] to execute user code without +allowing it to execute host system calls. This platform can run anywhere that +`ptrace` works (even VMs without nested virtualization), which is ubiquitous. -The KVM platform allows the Sentry to act as both guest OS and VMM, switching -back and forth between the two worlds seamlessly. The KVM platform can run on -bare-metal or in a VM with nested virtualization enabled. While there is no -virtualized hardware layer -- the sandbox retains a process model -- gVisor -leverages virtualization extensions available on modern processors in order to -improve isolation and performance of address space switches. +Unfortunately, the ptrace platform has high context switch overhead, so system +call-heavy applications may pay a [performance penalty](./performance.md). -## Gofer +### KVM -The Gofer is a normal host Linux process. The Gofer is started with each sandbox -and connected to the Sentry. The Sentry process is started in a restricted -seccomp container without access to file system resources. The Gofer provides -the Sentry access to file system resources via the 9P protocol and provides an -additional level of isolation. +The KVM platform uses the kernel's [KVM][kvm] functionality to allow the Sentry +to act as both guest OS and VMM. The KVM platform can run on bare-metal or in a +VM with nested virtualization enabled. While there is no virtualized hardware +layer -- the sandbox retains a process model -- gVisor leverages virtualization +extensions available on modern processors in order to improve isolation and +performance of address space switches. -## Application +## Changing Platforms -The application (aka the untrusted application) is a normal Linux binary -provided to gVisor in an OCI runtime bundle. gVisor aims to provide an -environment equivalent to Linux v4.4, so applications should be able to run -unmodified. However, gVisor does not presently implement every system call, -/proc file, or /sys file so some incompatibilities may occur. +See [Changing Platforms](../user_guide/platforms.md). -[oci]: https://www.opencontainers.org -[runtime-spec]: https://github.com/opencontainers/runtime-spec +[kvm]: https://www.kernel.org/doc/Documentation/virtual/kvm/api.txt +[platform]: https://cs.opensource.google/gvisor/gvisor/+/release-20190304.1:pkg/sentry/platform/platform.go;l=33 +[ptrace]: http://man7.org/linux/man-pages/man2/ptrace.2.html diff --git a/g3doc/architecture_guide/platforms.png b/g3doc/architecture_guide/platforms.png new file mode 100644 index 000000000..005d56feb Binary files /dev/null and b/g3doc/architecture_guide/platforms.png differ diff --git a/g3doc/architecture_guide/platforms.svg b/g3doc/architecture_guide/platforms.svg new file mode 100644 index 000000000..b0bac9ba7 --- /dev/null +++ b/g3doc/architecture_guide/platforms.svg @@ -0,0 +1,334 @@ + + + + + + + + + + image/svg+xml + + + + + + + + + + gVisor + workload + host + + + + gVisor + workload + + KVM + + + gVisor + workload + + ptrace + + ptrace + VM + + guest + + + gVisor + workload + + ptrace + + diff --git a/g3doc/architecture_guide/resources.md b/g3doc/architecture_guide/resources.md index 894f995ae..1dec37bd1 100644 --- a/g3doc/architecture_guide/resources.md +++ b/g3doc/architecture_guide/resources.md @@ -10,9 +10,10 @@ sandbox to be highly dynamic in terms of resource usage: spanning a large number of cores and large amount of memory when busy, and yielding those resources back to the host when not. -Some of the details here may depend on the [platform](../platforms/), but in -general this page describes the resource model used by gVisor. If you're not -familiar with the terms here, uou may want to start with the [Overview](../). +In order words, the shape of the sandbox should closely track the shape of the +sandboxed process: + +![Resource model](resources.png "Workloads of different shapes.") ## Processes @@ -23,9 +24,9 @@ the sandbox (e.g. via a [Docker exec][exec]). ## Networking -Similarly to processes, the sandbox attaches a network endpoint to the system, -but runs it's own network stack. All network resources, other than packets in -flight, exist only inside the sandbox, bound by relevant resource limits. +The sandbox attaches a network endpoint to the system, but runs it's own network +stack. All network resources, other than packets in flight on the host, exist +only inside the sandbox, bound by relevant resource limits. You can interact with network endpoints exposed by the sandbox, just as you would any other container, but network introspection similarly requires entering @@ -33,15 +34,14 @@ the sandbox. ## Files -Files may be backed by different implementations. For host-native files (where a -file descriptor is available), the Gofer may return a file descriptor to the -Sentry via [SCM_RIGHTS][scmrights][^1]. +Files in the sandbox may be backed by different implementations. For host-native +files (where a file descriptor is available), the Gofer may return a file +descriptor to the Sentry via [SCM_RIGHTS][scmrights][^1]. These files may be read from and written to through standard system calls, and also mapped into the associated application's address space. This allows the same host memory to be shared across multiple sandboxes, although this mechanism -does not preclude the use of side-channels (see the -[security model](../security/)). +does not preclude the use of side-channels (see [Security Model](./security.md). Note that some file systems exist only within the context of the sandbox. For example, in many cases a `tmpfs` mount will be available at `/tmp` or @@ -64,8 +64,9 @@ scheduling decisions about all application threads. ## Time Time in the sandbox is provided by the Sentry, through its own [vDSO][vdso] and -timekeeping implementation. This is divorced from the host time, and no state is -shared with the host, although the time will be initialized with the host clock. +time-keeping implementation. This is distinct from the host time, and no state +is shared with the host, although the time will be initialized with the host +clock. The Sentry runs timers to note the passage of time, much like a kernel running on hardware (though the timers are software timers, in this case). These timers diff --git a/g3doc/architecture_guide/resources.png b/g3doc/architecture_guide/resources.png new file mode 100644 index 000000000..f715008ec Binary files /dev/null and b/g3doc/architecture_guide/resources.png differ diff --git a/g3doc/architecture_guide/resources.svg b/g3doc/architecture_guide/resources.svg new file mode 100644 index 000000000..fd7805d90 --- /dev/null +++ b/g3doc/architecture_guide/resources.svg @@ -0,0 +1,208 @@ + + + + + + + + + + image/svg+xml + + + + + + + + + + + + + + gVisor + gVisor + gVisor + workload + workload + workload + host + + diff --git a/g3doc/architecture_guide/security.md b/g3doc/architecture_guide/security.md index f78586291..b99b86332 100644 --- a/g3doc/architecture_guide/security.md +++ b/g3doc/architecture_guide/security.md @@ -86,15 +86,17 @@ a substitute for a secure architecture*. ## Goals: Limiting Exposure -gVisor’s primary design goal is to minimize the System API attack vector while -still providing a process model. There are two primary security principles that -inform this design. First, the application’s direct interactions with the host -System API are intercepted by the Sentry, which implements the System API -instead. Second, the System API accessible to the Sentry itself is minimized to -a safer, restricted set. The first principle minimizes the possibility of direct -exploitation of the host System API by applications, and the second principle -minimizes indirect exploitability, which is the exploitation by an exploited or -buggy Sentry (e.g. chaining an exploit). +![Threat model](security.png "Threat model.") + +gVisor’s primary design goal is to minimize the System API attack vector through +multiple layers of defense, while still providing a process model. There are two +primary security principles that inform this design. First, the application’s +direct interactions with the host System API are intercepted by the Sentry, +which implements the System API instead. Second, the System API accessible to +the Sentry itself is minimized to a safer, restricted set. The first principle +minimizes the possibility of direct exploitation of the host System API by +applications, and the second principle minimizes indirect exploitability, which +is the exploitation by an exploited or buggy Sentry (e.g. chaining an exploit). The first principle is similar to the security basis for a Virtual Machine (VM). With a VM, an application’s interactions with the host are replaced by @@ -210,9 +212,9 @@ crashes are recorded and triaged to similarly identify material issues. ### Is this more or less secure than a Virtual Machine? The security of a VM depends to a large extent on what is exposed from the host -kernel and user space support code. For example, device emulation code in the +kernel and userspace support code. For example, device emulation code in the host kernel (e.g. APIC) or optimizations (e.g. vhost) can be more complex than a -simple system call, and exploits carry the same risks. Similarly, the user space +simple system call, and exploits carry the same risks. Similarly, the userspace support code is frequently unsandboxed, and exploits, while rare, may allow unfettered access to the system. @@ -245,8 +247,8 @@ In gVisor, the platforms that use ptrace operate differently. The stubs that are traced are never allowed to continue execution into the host kernel and complete a call directly. Instead, all system calls are interpreted and handled by the Sentry itself, who reflects resulting register state back into the tracee before -continuing execution in user space. This is very similar to the mechanism used -by User-Mode Linux (UML). +continuing execution in userspace. This is very similar to the mechanism used by +User-Mode Linux (UML). [dirtycow]: https://en.wikipedia.org/wiki/Dirty_COW [clang]: https://en.wikipedia.org/wiki/C_(programming_language) diff --git a/g3doc/architecture_guide/security.png b/g3doc/architecture_guide/security.png new file mode 100644 index 000000000..c29befbf6 Binary files /dev/null and b/g3doc/architecture_guide/security.png differ diff --git a/g3doc/architecture_guide/security.svg b/g3doc/architecture_guide/security.svg new file mode 100644 index 000000000..0575e2dec --- /dev/null +++ b/g3doc/architecture_guide/security.svg @@ -0,0 +1,153 @@ + + + + + + + + + + image/svg+xml + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/g3doc/user_guide/filesystem.md b/g3doc/user_guide/filesystem.md index 6c69f42a1..cd00762dd 100644 --- a/g3doc/user_guide/filesystem.md +++ b/g3doc/user_guide/filesystem.md @@ -4,8 +4,8 @@ gVisor accesses the filesystem through a file proxy, called the Gofer. The gofer runs as a separate process, that is isolated from the sandbox. Gofer instances -communicate with their respective sentry using the 9P protocol. For a more -detailed explanation see [Overview > Gofer](../../architecture_guide/#gofer). +communicate with their respective sentry using the 9P protocol. For another +explanation see [What is gVisor?](../README.md). ## Sandbox overlay diff --git a/g3doc/user_guide/platforms.md b/g3doc/user_guide/platforms.md index eefb6b222..752025881 100644 --- a/g3doc/user_guide/platforms.md +++ b/g3doc/user_guide/platforms.md @@ -1,56 +1,27 @@ -# Platforms (KVM) +# Changing Platforms [TOC] -This document will help you set up your system to use a different gVisor -platform. +This guide described how to change the +[platform](../architecture_guide/platforms.md) used by `runsc`. -## What is a Platform? +## Prerequisites -gVisor requires a *platform* to implement interception of syscalls, basic -context switching, and memory mapping functionality. These are described in more -depth in the [Platform Design](../../architecture_guide/platforms/). +If you intend to run the KVM platform, you will also to have KVM installed on +your system. If you are running a Debian based system like Debian or Ubuntu you +can usually do this by ensuring the module is loaded, and permissions are +appropriately set on the `/dev/kvm` device. -## Selecting a Platform - -The platform is selected by the `--platform` command line flag passed to -`runsc`. By default, the ptrace platform is selected. To select a different -platform, modify your Docker configuration (`/etc/docker/daemon.json`) to pass -this argument: - -```json -{ - "runtimes": { - "runsc": { - "path": "/usr/local/bin/runsc", - "runtimeArgs": [ - "--platform=kvm" - ] - } - } -} -``` - -You must restart the Docker daemon after making changes to this file, typically -this is done via `systemd`: +If you have an Intel CPU: ```bash -sudo systemctl restart docker +sudo modprobe kvm-intel && sudo chmod a+rw /dev/kvm ``` -## Example: Using the KVM Platform - -The KVM platform is currently experimental; however, it provides several -benefits over the default ptrace platform. - -### Prerequisites - -You will also to have KVM installed on your system. If you are running a Debian -based system like Debian or Ubuntu you can usually do this by installing the -`qemu-kvm` package. +If you have an AMD CPU: ```bash -sudo apt-get install qemu-kvm +sudo modprobe kvm-amd && sudo chmod a+rw /dev/kvm ``` If you are using a virtual machine you will need to make sure that nested @@ -68,31 +39,22 @@ cause of security issues (e.g. [CVE-2018-12904](https://nvd.nist.gov/vuln/detail/CVE-2018-12904)). It is not recommended for production.*** -### Configuring Docker - -Per above, you will need to configure Docker to use `runsc` with the KVM -platform. You will remember from the Docker Quick Start that you configured -Docker to use `runsc` as the runtime. Docker allows you to add multiple runtimes -to the Docker configuration. +## Configuring Docker -Add a new entry for the KVM platform entry to your Docker configuration -(`/etc/docker/daemon.json`) in order to provide the `--platform=kvm` runtime -argument. - -In the end, the file should look something like: +The platform is selected by the `--platform` command line flag passed to +`runsc`. By default, the ptrace platform is selected. For example, to select the +KVM platform, modify your Docker configuration (`/etc/docker/daemon.json`) to +pass the `--platform` argument: ```json { "runtimes": { "runsc": { - "path": "/usr/local/bin/runsc" - }, - "runsc-kvm": { "path": "/usr/local/bin/runsc", "runtimeArgs": [ "--platform=kvm" ] - } + } } } ``` @@ -104,13 +66,27 @@ this is done via `systemd`: sudo systemctl restart docker ``` -## Running a container +Note that you may configure multiple runtimes using different platforms. For +example, the following configuration has one configuration for ptrace and one +for the KVM platform: -Now run your container using the `runsc-kvm` runtime. This will run the -container using the KVM platform: - -```bash -docker run --runtime=runsc-kvm --rm hello-world +```json +{ + "runtimes": { + "runsc-ptrace": { + "path": "/usr/local/bin/runsc", + "runtimeArgs": [ + "--platform=ptrace" + ] + }, + "runsc-kvm": { + "path": "/usr/local/bin/runsc", + "runtimeArgs": [ + "--platform=kvm" + ] + } + } +} ``` [nested-azure]: https://docs.microsoft.com/en-us/azure/virtual-machines/windows/nested-virtualization diff --git a/pkg/sentry/arch/syscalls_arm64.go b/pkg/sentry/arch/syscalls_arm64.go index 92d062513..95dfd1e90 100644 --- a/pkg/sentry/arch/syscalls_arm64.go +++ b/pkg/sentry/arch/syscalls_arm64.go @@ -23,7 +23,7 @@ const restartSyscallNr = uintptr(128) // // In linux, at the entry of the syscall handler(el0_svc_common()), value of R0 // is saved to the pt_regs.orig_x0 in kernel code. But currently, the orig_x0 -// was not accessible to the user space application, so we have to do the same +// was not accessible to the userspace application, so we have to do the same // operation in the sentry code to save the R0 value into the App context. func (c *context64) SyscallSaveOrig() { c.OrigR0 = c.Regs.Regs[0] diff --git a/pkg/sentry/kernel/pipe/pipe_util.go b/pkg/sentry/kernel/pipe/pipe_util.go index 5a1d4fd57..aacf28da2 100644 --- a/pkg/sentry/kernel/pipe/pipe_util.go +++ b/pkg/sentry/kernel/pipe/pipe_util.go @@ -144,7 +144,7 @@ func (p *Pipe) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArgume if v > math.MaxInt32 { v = math.MaxInt32 // Silently truncate. } - // Copy result to user-space. + // Copy result to userspace. _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{ AddressSpaceActive: true, }) diff --git a/pkg/sentry/kernel/task_syscall.go b/pkg/sentry/kernel/task_syscall.go index c9db78e06..a5903b0b5 100644 --- a/pkg/sentry/kernel/task_syscall.go +++ b/pkg/sentry/kernel/task_syscall.go @@ -199,10 +199,10 @@ func (t *Task) doSyscall() taskRunState { // // On x86, register rax was shared by syscall number and return // value, and at the entry of the syscall handler, the rax was - // saved to regs.orig_rax which was exposed to user space. + // saved to regs.orig_rax which was exposed to userspace. // But on arm64, syscall number was passed through X8, and the X0 // was shared by the first syscall argument and return value. The - // X0 was saved to regs.orig_x0 which was not exposed to user space. + // X0 was saved to regs.orig_x0 which was not exposed to userspace. // So we have to do the same operation here to save the X0 value // into the task context. t.Arch().SyscallSaveOrig() diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 9dea2b5ff..60df51dae 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -2718,7 +2718,7 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy v = math.MaxInt32 } - // Copy result to user-space. + // Copy result to userspace. _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{ AddressSpaceActive: true, }) @@ -2787,7 +2787,7 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc if v > math.MaxInt32 { v = math.MaxInt32 } - // Copy result to user-space. + // Copy result to userspace. _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{ AddressSpaceActive: true, }) @@ -2803,7 +2803,7 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc v = math.MaxInt32 } - // Copy result to user-space. + // Copy result to userspace. _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{ AddressSpaceActive: true, }) diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index b39ffa9fb..0ab4c3e19 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -235,11 +235,11 @@ type RcvBufAutoTuneParams struct { // was started. MeasureTime time.Time - // CopiedBytes is the number of bytes copied to user space since + // CopiedBytes is the number of bytes copied to userspace since // this measure began. CopiedBytes int - // PrevCopiedBytes is the number of bytes copied to user space in + // PrevCopiedBytes is the number of bytes copied to userspace in // the previous RTT period. PrevCopiedBytes int diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 71735029e..b5ba972f1 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1097,7 +1097,7 @@ func (e *endpoint) initialReceiveWindow() int { } // ModerateRecvBuf adjusts the receive buffer and the advertised window -// based on the number of bytes copied to user space. +// based on the number of bytes copied to userspace. func (e *endpoint) ModerateRecvBuf(copied int) { e.LockUser() defer e.UnlockUser() diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 0b4512c65..6ef32a1b3 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -5869,7 +5869,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) { // Invoke the moderation API. This is required for auto-tuning // to happen. This method is normally expected to be invoked // from a higher layer than tcpip.Endpoint. So we simulate - // copying to user-space by invoking it explicitly here. + // copying to userspace by invoking it explicitly here. c.EP.ModerateRecvBuf(totalCopied) // Now send a keep-alive packet to trigger an ACK so that we can diff --git a/runsc/cmd/help.go b/runsc/cmd/help.go index c7d210140..cd85dabbb 100644 --- a/runsc/cmd/help.go +++ b/runsc/cmd/help.go @@ -65,16 +65,10 @@ func (h *Help) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{} switch f.NArg() { case 0: fmt.Fprintf(h.cdr.Output, "Usage: %s \n\n", h.cdr.Name()) - fmt.Fprintf(h.cdr.Output, `runsc is a command line client for running applications packaged in the Open -Container Initiative (OCI) format. Applications run by runsc are run in an -isolated gVisor sandbox that emulates a Linux environment. + fmt.Fprintf(h.cdr.Output, `runsc is the gVisor container runtime. -gVisor is a user-space kernel, written in Go, that implements a substantial -portion of the Linux system call interface. It provides an additional layer -of isolation between running applications and the host operating system. - -Functionality is provided by subcommands. For additonal help on individual -subcommands use "%s %s ". +Functionality is provided by subcommands. For help with a specific subcommand, +use "%s %s ". `, h.cdr.Name(), h.Name()) h.cdr.VisitGroups(func(g *subcommands.CommandGroup) { diff --git a/website/BUILD b/website/BUILD index d6afd5f44..c97b2560b 100644 --- a/website/BUILD +++ b/website/BUILD @@ -138,7 +138,6 @@ docs( "//g3doc:community", "//g3doc:index", "//g3doc:roadmap", - "//g3doc/architecture_guide:index", "//g3doc/architecture_guide:performance", "//g3doc/architecture_guide:platforms", "//g3doc/architecture_guide:resources", diff --git a/website/_layouts/docs.html b/website/_layouts/docs.html index 33ea8e1de..549305089 100644 --- a/website/_layouts/docs.html +++ b/website/_layouts/docs.html @@ -51,7 +51,9 @@ categories: Create issue

{% endif %} +
{{ content }} +
diff --git a/website/_sass/front.scss b/website/_sass/front.scss index 44a7e3473..0e4208f3c 100644 --- a/website/_sass/front.scss +++ b/website/_sass/front.scss @@ -4,12 +4,14 @@ background-repeat: no-repeat; background-size: cover; background-blend-mode: darken; - background-color: rgba(0, 0, 0, 0.1); + background-color: rgba(0, 0, 0, 0.3); p { color: #fff; margin-top: 0; margin-bottom: 0; font-weight: 300; + font-size: 24px; + line-height: 30px; } } diff --git a/website/_sass/style.scss b/website/_sass/style.scss index 520ea469a..4deb945d4 100644 --- a/website/_sass/style.scss +++ b/website/_sass/style.scss @@ -142,3 +142,13 @@ table th { margin-top: 10px; margin-bottom: 20px; } + +.docs-content * img { + display: block; + margin: 20px auto; +} + +.blog-content * img { + display: block; + margin: 20px auto; +} diff --git a/website/blog/2019-11-18-security-basics.md b/website/blog/2019-11-18-security-basics.md index ed6d97ffe..fbdd511dd 100644 --- a/website/blog/2019-11-18-security-basics.md +++ b/website/blog/2019-11-18-security-basics.md @@ -56,15 +56,9 @@ in combination: redundant walls, scattered draw bridges, small bottle-neck entrances, moats, etc. A simplified version of the design is below -([more detailed version](/docs/architecture_guide/))[^2]: +([more detailed version](/docs/))[^2]: --------------------------------------------------------------------------------- - -![Figure 1](/assets/images/2019-11-18-security-basics-figure1.png) - -Figure 1: Simplified design of gVisor. - --------------------------------------------------------------------------------- +![Figure 1](/assets/images/2019-11-18-security-basics-figure1.png "Simplified design of gVisor.") In order to discuss design principles, the following components are important to know: @@ -134,13 +128,7 @@ minimum level of permission is required for it to perform its function. Specifically, the closer you are to the untrusted application, the less privilege you have. --------------------------------------------------------------------------------- - -![Figure 2](/assets/images/2019-11-18-security-basics-figure2.png) - -Figure 2: runsc components and their privileges. - --------------------------------------------------------------------------------- +![Figure 2](/assets/images/2019-11-18-security-basics-figure2.png "runsc components and their privileges.") This is evident in how runsc (the drop in gVisor binary for Docker/Kubernetes) constructs the sandbox. The Sentry has the least privilege possible (it can't @@ -222,15 +210,7 @@ the host Linux syscalls. In other words, with gVisor, applications get the vast majority (and growing) functionality of Linux containers for only 68 possible syscalls to the Host OS. 350 syscalls to 68 is attack surface reduction. --------------------------------------------------------------------------------- - -![Figure 3](/assets/images/2019-11-18-security-basics-figure3.png) - -Figure 3: Reduction of Attack Surface of the Syscall Table. Note that the -Senty's Syscall Emulation Layer keeps the Containerized Process from ever -calling the Host OS. - --------------------------------------------------------------------------------- +![Figure 3](/assets/images/2019-11-18-security-basics-figure3.png "Reduction of Attack Surface of the Syscall Table. Note that the Senty's Syscall Emulation Layer keeps the Containerized Process from ever calling the Host OS.") ## Secure-by-default diff --git a/website/blog/2020-04-02-networking-security.md b/website/blog/2020-04-02-networking-security.md index 78f0a6714..5a5e38fd7 100644 --- a/website/blog/2020-04-02-networking-security.md +++ b/website/blog/2020-04-02-networking-security.md @@ -69,13 +69,7 @@ a similar syscall). Moreover, because packets typically come from off-host (e.g. the internet), the Host OS's packet processing code has received a lot of scrutiny, hopefully resulting in a high degree of hardening. --------------------------------------------------------------------------------- - -![Figure 1](/assets/images/2020-04-02-networking-security-figure1.png) - -Figure 1: Netstack and gVisor - --------------------------------------------------------------------------------- +![Figure 1](/assets/images/2020-04-02-networking-security-figure1.png "Network and gVisor.") ## Writing a network stack diff --git a/website/index.md b/website/index.md index 95d5d16f0..84f877d49 100644 --- a/website/index.md +++ b/website/index.md @@ -3,10 +3,10 @@
-

gVisor is an application kernel and container runtime providing defense-in-depth for containers anywhere.

+

gVisor is an application kernel for containers that provides efficient defense-in-depth anywhere.

+ Quick start  Learn More  - GitHub 

@@ -19,8 +19,8 @@

Container-native Security

-

By providing each container with its own userspace kernel, gVisor limits - the attack surface of the host. This protection does not limit +

By providing each container with its own application kernel, gVisor + limits the attack surface of the host. This protection does not limit functionality: gVisor runs unmodified binaries and integrates with container orchestration systems, such as Docker and Kubernetes, and supports features such as volumes and sidecars.

@@ -43,7 +43,7 @@ The pluggable platform architecture of gVisor allows it to run anywhere, enabling consistent security policies across multiple environments without having to rearchitect your infrastructure.

- Get Started » + Read More »
-- cgit v1.2.3 From 835e8c89a5e8ed0008821b1036ce3f507394dce5 Mon Sep 17 00:00:00 2001 From: Sam Balana Date: Wed, 27 May 2020 17:34:04 -0700 Subject: Remove linkEP from DeliverNetworkPacket The specified LinkEndpoint is not being used in a significant way. No behavior change, existing tests pass. This change is a breaking change. PiperOrigin-RevId: 313496602 --- pkg/tcpip/link/channel/channel.go | 2 +- pkg/tcpip/link/fdbased/endpoint.go | 2 +- pkg/tcpip/link/fdbased/endpoint_test.go | 2 +- pkg/tcpip/link/fdbased/mmap.go | 2 +- pkg/tcpip/link/fdbased/packet_dispatchers.go | 4 ++-- pkg/tcpip/link/loopback/loopback.go | 4 ++-- pkg/tcpip/link/muxed/injectable.go | 2 +- pkg/tcpip/link/qdisc/fifo/endpoint.go | 4 ++-- pkg/tcpip/link/sharedmem/sharedmem.go | 2 +- pkg/tcpip/link/sharedmem/sharedmem_test.go | 2 +- pkg/tcpip/link/sniffer/sniffer.go | 4 ++-- pkg/tcpip/link/waitable/waitable.go | 4 ++-- pkg/tcpip/link/waitable/waitable_test.go | 8 ++++---- pkg/tcpip/stack/forwarder_test.go | 2 +- pkg/tcpip/stack/nic.go | 4 ++-- pkg/tcpip/stack/nic_test.go | 2 +- pkg/tcpip/stack/registration.go | 2 +- 17 files changed, 26 insertions(+), 26 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index 9bf67686d..5eb78b398 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -187,7 +187,7 @@ func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt stack // InjectLinkAddr injects an inbound packet with a remote link address. func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt stack.PacketBuffer) { - e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, pkt) + e.dispatcher.DeliverNetworkPacket(remote, "" /* local */, protocol, pkt) } // Attach saves the stack network-layer dispatcher for use later when packets diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index affa1bbdf..5ee508d48 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -642,7 +642,7 @@ func (e *InjectableEndpoint) Attach(dispatcher stack.NetworkDispatcher) { // InjectInbound injects an inbound packet. func (e *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { - e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, pkt) + e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, pkt) } // NewInjectable creates a new fd-based InjectableEndpoint. diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index 3bfb15a8e..6f41a71a8 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -103,7 +103,7 @@ func (c *context) cleanup() { } } -func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (c *context) DeliverNetworkPacket(remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { c.ch <- packetInfo{remote, protocol, pkt} } diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go index fe2bf3b0b..ca4229ed6 100644 --- a/pkg/tcpip/link/fdbased/mmap.go +++ b/pkg/tcpip/link/fdbased/mmap.go @@ -191,7 +191,7 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) { } pkt = pkt[d.e.hdrSize:] - d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, stack.PacketBuffer{ + d.e.dispatcher.DeliverNetworkPacket(remote, local, p, stack.PacketBuffer{ Data: buffer.View(pkt).ToVectorisedView(), LinkHeader: buffer.View(eth), }) diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go index 3aaabd1aa..26c96a655 100644 --- a/pkg/tcpip/link/fdbased/packet_dispatchers.go +++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go @@ -145,7 +145,7 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { } pkt.Data.TrimFront(d.e.hdrSize) - d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, pkt) + d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pkt) // Prepare e.views for another packet: release used views. for i := 0; i < used; i++ { @@ -301,7 +301,7 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { LinkHeader: buffer.View(eth), } pkt.Data.TrimFront(d.e.hdrSize) - d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, pkt) + d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pkt) // Prepare e.views for another packet: release used views. for i := 0; i < used; i++ { diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 073c84ef9..20d9e95f6 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -84,7 +84,7 @@ func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Netw // Because we're immediately turning around and writing the packet back // to the rx path, we intentionally don't preserve the remote and local // link addresses from the stack.Route we're passed. - e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, stack.PacketBuffer{ + e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, stack.PacketBuffer{ Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), }) @@ -106,7 +106,7 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { } linkHeader := header.Ethernet(hdr) vv.TrimFront(len(linkHeader)) - e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), stack.PacketBuffer{ + e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, linkHeader.Type(), stack.PacketBuffer{ Data: vv, LinkHeader: buffer.View(linkHeader), }) diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index a5478ce17..f0769830a 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -81,7 +81,7 @@ func (m *InjectableEndpoint) IsAttached() bool { // InjectInbound implements stack.InjectableLinkEndpoint. func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { - m.dispatcher.DeliverNetworkPacket(m, "" /* remote */, "" /* local */, protocol, pkt) + m.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, pkt) } // WritePackets writes outbound packets to the appropriate diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go index 54432194d..ec5c5048a 100644 --- a/pkg/tcpip/link/qdisc/fifo/endpoint.go +++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go @@ -102,8 +102,8 @@ func (q *queueDispatcher) dispatchLoop() { } // DeliverNetworkPacket implements stack.NetworkDispatcher.DeliverNetworkPacket. -func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { - e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, pkt) +func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { + e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt) } // Attach implements stack.LinkEndpoint.Attach. diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 0796d717e..f5dec0a7f 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -275,7 +275,7 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) { // Send packet up the stack. eth := header.Ethernet(b[:header.EthernetMinimumSize]) - d.DeliverNetworkPacket(e, eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), stack.PacketBuffer{ + d.DeliverNetworkPacket(eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), stack.PacketBuffer{ Data: buffer.View(b[header.EthernetMinimumSize:]).ToVectorisedView(), LinkHeader: buffer.View(eth), }) diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 33f640b85..f3fc62607 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -131,7 +131,7 @@ func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress return c } -func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (c *testContext) DeliverNetworkPacket(remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { c.mu.Lock() c.packets = append(c.packets, packetInfo{ addr: remoteLinkAddr, diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index da1c520ae..b060d4627 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -120,9 +120,9 @@ func NewWithWriter(lower stack.LinkEndpoint, writer io.Writer, snapLen uint32) ( // DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is // called by the link-layer endpoint being wrapped when a packet arrives, and // logs the packet before forwarding to the actual dispatcher. -func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { e.dumpPacket("recv", nil, protocol, &pkt) - e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, pkt) + e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt) } // Attach implements the stack.LinkEndpoint interface. It saves the dispatcher diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index 2b3741276..f5a05929f 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -50,12 +50,12 @@ func New(lower stack.LinkEndpoint) *Endpoint { // It is called by the link-layer endpoint being wrapped when a packet arrives, // and only forwards to the actual dispatcher if Wait or WaitDispatch haven't // been called. -func (e *Endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { if !e.dispatchGate.Enter() { return } - e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, pkt) + e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt) e.dispatchGate.Leave() } diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index 54eb5322b..0a9b99f18 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -35,7 +35,7 @@ type countedEndpoint struct { dispatcher stack.NetworkDispatcher } -func (e *countedEndpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (e *countedEndpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { e.dispatchCount++ } @@ -120,21 +120,21 @@ func TestWaitDispatch(t *testing.T) { } // Dispatch and check that it goes through. - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, stack.PacketBuffer{}) + ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.PacketBuffer{}) if want := 1; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } // Wait on writes, then try to dispatch. It must go through. wep.WaitWrite() - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, stack.PacketBuffer{}) + ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.PacketBuffer{}) if want := 2; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } // Wait on dispatches, then try to dispatch. It must not go through. wep.WaitDispatch() - ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, stack.PacketBuffer{}) + ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.PacketBuffer{}) if want := 2; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index 8084d50bc..344d60baa 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -209,7 +209,7 @@ func (e *fwdTestLinkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber // InjectLinkAddr injects an inbound packet with a remote link address. func (e *fwdTestLinkEndpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt PacketBuffer) { - e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, pkt) + e.dispatcher.DeliverNetworkPacket(remote, "" /* local */, protocol, pkt) } // Attach saves the stack network-layer dispatcher for use later when packets diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 54103fdb3..9213cbb9a 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1167,7 +1167,7 @@ func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, // Note that the ownership of the slice backing vv is retained by the caller. // This rule applies only to the slice itself, not to the items of the slice; // the ownership of the items is not retained by the caller. -func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { +func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { n.mu.RLock() enabled := n.mu.enabled // If the NIC is not yet enabled, don't receive any packets. @@ -1240,7 +1240,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link } if ref := n.getRef(protocol, dst); ref != nil { - handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, pkt) + handlePacket(protocol, dst, src, n.linkEP.LinkAddress(), remote, ref, pkt) return } diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index d672fc157..b01b3f476 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -44,7 +44,7 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { t.FailNow() } - nic.DeliverNetworkPacket(nil, "", "", 0, PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()}) + nic.DeliverNetworkPacket("", "", 0, PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()}) if got := nic.stats.DisabledRx.Packets.Value(); got != 1 { t.Errorf("got DisabledRx.Packets = %d, want = 1", got) diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index b331427c6..a25657cd7 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -326,7 +326,7 @@ type NetworkDispatcher interface { // packets sent via loopback), and won't have the field set. // // DeliverNetworkPacket takes ownership of pkt. - DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) + DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) } // LinkEndpointCapabilities is the type associated with the capabilities -- cgit v1.2.3 From c55b84e16aeb4481106661e3877c50edbf281762 Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Thu, 28 May 2020 16:44:15 -0700 Subject: Enable iptables source filtering (-s/--source) --- pkg/sentry/socket/netfilter/netfilter.go | 21 ++++++++--- pkg/tcpip/stack/iptables.go | 47 +----------------------- pkg/tcpip/stack/iptables_types.go | 62 ++++++++++++++++++++++++++++++++ test/iptables/filter_input.go | 60 +++++++++++++++++++++++++++++++ test/iptables/iptables_test.go | 8 +++++ 5 files changed, 147 insertions(+), 51 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 789bb94c8..47ff48c00 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -64,6 +64,8 @@ const enableLogging = false var emptyFilter = stack.IPHeaderFilter{ Dst: "\x00\x00\x00\x00", DstMask: "\x00\x00\x00\x00", + Src: "\x00\x00\x00\x00", + SrcMask: "\x00\x00\x00\x00", } // nflog logs messages related to the writing and reading of iptables. @@ -214,11 +216,16 @@ func convertNetstackToBinary(tablename string, table stack.Table) (linux.KernelI } copy(entry.IPTEntry.IP.Dst[:], rule.Filter.Dst) copy(entry.IPTEntry.IP.DstMask[:], rule.Filter.DstMask) + copy(entry.IPTEntry.IP.Src[:], rule.Filter.Src) + copy(entry.IPTEntry.IP.SrcMask[:], rule.Filter.SrcMask) copy(entry.IPTEntry.IP.OutputInterface[:], rule.Filter.OutputInterface) copy(entry.IPTEntry.IP.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask) if rule.Filter.DstInvert { entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_DSTIP } + if rule.Filter.SrcInvert { + entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_SRCIP + } if rule.Filter.OutputInterfaceInvert { entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_VIA_OUT } @@ -737,6 +744,9 @@ func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) { if len(iptip.Dst) != header.IPv4AddressSize || len(iptip.DstMask) != header.IPv4AddressSize { return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of destination (%d) and/or destination mask (%d) fields", len(iptip.Dst), len(iptip.DstMask)) } + if len(iptip.Src) != header.IPv4AddressSize || len(iptip.SrcMask) != header.IPv4AddressSize { + return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of source (%d) and/or source mask (%d) fields", len(iptip.Src), len(iptip.SrcMask)) + } n := bytes.IndexByte([]byte(iptip.OutputInterface[:]), 0) if n == -1 { @@ -755,6 +765,9 @@ func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) { Dst: tcpip.Address(iptip.Dst[:]), DstMask: tcpip.Address(iptip.DstMask[:]), DstInvert: iptip.InverseFlags&linux.IPT_INV_DSTIP != 0, + Src: tcpip.Address(iptip.Src[:]), + SrcMask: tcpip.Address(iptip.SrcMask[:]), + SrcInvert: iptip.InverseFlags&linux.IPT_INV_SRCIP != 0, OutputInterface: ifname, OutputInterfaceMask: ifnameMask, OutputInterfaceInvert: iptip.InverseFlags&linux.IPT_INV_VIA_OUT != 0, @@ -765,15 +778,13 @@ func containsUnsupportedFields(iptip linux.IPTIP) bool { // The following features are supported: // - Protocol // - Dst and DstMask + // - Src and SrcMask // - The inverse destination IP check flag // - OutputInterface, OutputInterfaceMask and its inverse. - var emptyInetAddr = linux.InetAddr{} var emptyInterface = [linux.IFNAMSIZ]byte{} // Disable any supported inverse flags. - inverseMask := uint8(linux.IPT_INV_DSTIP) | uint8(linux.IPT_INV_VIA_OUT) - return iptip.Src != emptyInetAddr || - iptip.SrcMask != emptyInetAddr || - iptip.InputInterface != emptyInterface || + inverseMask := uint8(linux.IPT_INV_DSTIP) | uint8(linux.IPT_INV_SRCIP) | uint8(linux.IPT_INV_VIA_OUT) + return iptip.InputInterface != emptyInterface || iptip.InputInterfaceMask != emptyInterface || iptip.Flags != 0 || iptip.InverseFlags&^inverseMask != 0 diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 443423b3c..709ede3fa 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -16,7 +16,6 @@ package stack import ( "fmt" - "strings" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -314,7 +313,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx } // Check whether the packet matches the IP header filter. - if !filterMatch(rule.Filter, header.IPv4(pkt.NetworkHeader), hook, nicName) { + if !rule.Filter.match(header.IPv4(pkt.NetworkHeader), hook, nicName) { // Continue on to the next rule. return RuleJump, ruleIdx + 1 } @@ -335,47 +334,3 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx // All the matchers matched, so run the target. return rule.Target.Action(pkt, &it.connections, hook, gso, r, address) } - -func filterMatch(filter IPHeaderFilter, hdr header.IPv4, hook Hook, nicName string) bool { - // TODO(gvisor.dev/issue/170): Support other fields of the filter. - // Check the transport protocol. - if filter.Protocol != 0 && filter.Protocol != hdr.TransportProtocol() { - return false - } - - // Check the destination IP. - dest := hdr.DestinationAddress() - matches := true - for i := range filter.Dst { - if dest[i]&filter.DstMask[i] != filter.Dst[i] { - matches = false - break - } - } - if matches == filter.DstInvert { - return false - } - - // Check the output interface. - // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING - // hooks after supported. - if hook == Output { - n := len(filter.OutputInterface) - if n == 0 { - return true - } - - // If the interface name ends with '+', any interface which begins - // with the name should be matched. - ifName := filter.OutputInterface - matches = true - if strings.HasSuffix(ifName, "+") { - matches = strings.HasPrefix(nicName, ifName[:n-1]) - } else { - matches = nicName == ifName - } - return filter.OutputInterfaceInvert != matches - } - - return true -} diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index fe06007ae..a3bd3e700 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -15,7 +15,10 @@ package stack import ( + "strings" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" ) // A Hook specifies one of the hooks built into the network stack. @@ -159,6 +162,16 @@ type IPHeaderFilter struct { // comparison. DstInvert bool + // Src matches the source IP address. + Src tcpip.Address + + // SrcMask masks bits of the source IP address when comparing with Src. + SrcMask tcpip.Address + + // SrcInvert inverts the meaning of the source IP check, i.e. when true the + // filter will match packets that fail the source comparison. + SrcInvert bool + // OutputInterface matches the name of the outgoing interface for the // packet. OutputInterface string @@ -173,6 +186,55 @@ type IPHeaderFilter struct { OutputInterfaceInvert bool } +// match returns whether hdr matches the filter. +func (fl IPHeaderFilter) match(hdr header.IPv4, hook Hook, nicName string) bool { + // TODO(gvisor.dev/issue/170): Support other fields of the filter. + // Check the transport protocol. + if fl.Protocol != 0 && fl.Protocol != hdr.TransportProtocol() { + return false + } + + // Check the source and destination IPs. + if !filterAddress(hdr.DestinationAddress(), fl.DstMask, fl.Dst, fl.DstInvert) || !filterAddress(hdr.SourceAddress(), fl.SrcMask, fl.Src, fl.SrcInvert) { + return false + } + + // Check the output interface. + // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING + // hooks after supported. + if hook == Output { + n := len(fl.OutputInterface) + if n == 0 { + return true + } + + // If the interface name ends with '+', any interface which begins + // with the name should be matched. + ifName := fl.OutputInterface + matches := true + if strings.HasSuffix(ifName, "+") { + matches = strings.HasPrefix(nicName, ifName[:n-1]) + } else { + matches = nicName == ifName + } + return fl.OutputInterfaceInvert != matches + } + + return true +} + +// filterAddress returns whether addr matches the filter. +func filterAddress(addr, mask, filterAddr tcpip.Address, invert bool) bool { + matches := true + for i := range filterAddr { + if addr[i]&mask[i] != filterAddr[i] { + matches = false + break + } + } + return matches != invert +} + // A Matcher is the interface for matching packets. type Matcher interface { // Name returns the name of the Matcher. diff --git a/test/iptables/filter_input.go b/test/iptables/filter_input.go index 41e0cfa8d..14e385f5a 100644 --- a/test/iptables/filter_input.go +++ b/test/iptables/filter_input.go @@ -49,6 +49,8 @@ func init() { RegisterTestCase(FilterInputJumpTwice{}) RegisterTestCase(FilterInputDestination{}) RegisterTestCase(FilterInputInvertDestination{}) + RegisterTestCase(FilterInputSource{}) + RegisterTestCase(FilterInputInvertSource{}) } // FilterInputDropUDP tests that we can drop UDP traffic. @@ -667,3 +669,61 @@ func (FilterInputInvertDestination) ContainerAction(ip net.IP) error { func (FilterInputInvertDestination) LocalAction(ip net.IP) error { return sendUDPLoop(ip, acceptPort, sendloopDuration) } + +// FilterInputSource verifies that we can filter packets via `-d +// `. +type FilterInputSource struct{} + +// Name implements TestCase.Name. +func (FilterInputSource) Name() string { + return "FilterInputSource" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputSource) ContainerAction(ip net.IP) error { + // Make INPUT's default action DROP, then ACCEPT all packets from this + // machine. + rules := [][]string{ + {"-P", "INPUT", "DROP"}, + {"-A", "INPUT", "-s", fmt.Sprintf("%v", ip), "-j", "ACCEPT"}, + } + if err := filterTableRules(rules); err != nil { + return err + } + + return listenUDP(acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputSource) LocalAction(ip net.IP) error { + return sendUDPLoop(ip, acceptPort, sendloopDuration) +} + +// FilterInputInvertSource verifies that we can filter packets via `! -d +// `. +type FilterInputInvertSource struct{} + +// Name implements TestCase.Name. +func (FilterInputInvertSource) Name() string { + return "FilterInputInvertSource" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputInvertSource) ContainerAction(ip net.IP) error { + // Make INPUT's default action DROP, then ACCEPT all packets not bound + // for 127.0.0.1. + rules := [][]string{ + {"-P", "INPUT", "DROP"}, + {"-A", "INPUT", "!", "-s", localIP, "-j", "ACCEPT"}, + } + if err := filterTableRules(rules); err != nil { + return err + } + + return listenUDP(acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputInvertSource) LocalAction(ip net.IP) error { + return sendUDPLoop(ip, acceptPort, sendloopDuration) +} diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go index 4fd2cb46a..172ad9e16 100644 --- a/test/iptables/iptables_test.go +++ b/test/iptables/iptables_test.go @@ -302,3 +302,11 @@ func TestNATPreRedirectInvert(t *testing.T) { func TestNATRedirectRequiresProtocol(t *testing.T) { singleTest(t, NATRedirectRequiresProtocol{}) } + +func TestInputSource(t *testing.T) { + singleTest(t, FilterInputSource{}) +} + +func TestInputInvertSource(t *testing.T) { + singleTest(t, FilterInputInvertSource{}) +} -- cgit v1.2.3 From 341be65421edff16fd9eeb593301ce1d66148772 Mon Sep 17 00:00:00 2001 From: Ting-Yu Wang Date: Fri, 29 May 2020 12:58:50 -0700 Subject: Update WritePacket* API to take ownership of packets to be written. Updates #2404. PiperOrigin-RevId: 313834784 --- pkg/tcpip/stack/nic.go | 5 ++++- pkg/tcpip/stack/registration.go | 19 ++++++++++--------- pkg/tcpip/stack/route.go | 15 ++++++++++++--- 3 files changed, 26 insertions(+), 13 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 9213cbb9a..05646e5e2 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1304,13 +1304,16 @@ func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt pkt.Header = buffer.NewPrependable(linkHeaderLen) } + // WritePacket takes ownership of pkt, calculate numBytes first. + numBytes := pkt.Header.UsedLength() + pkt.Data.Size() + if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return } n.stats.Tx.Packets.Increment() - n.stats.Tx.Bytes.IncrementBy(uint64(pkt.Header.UsedLength() + pkt.Data.Size())) + n.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) } // DeliverTransportPacket delivers the packets to the appropriate transport diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index a25657cd7..db89234e8 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -240,16 +240,17 @@ type NetworkEndpoint interface { MaxHeaderLength() uint16 // WritePacket writes a packet to the given destination address and - // protocol. It sets pkt.NetworkHeader. pkt.TransportHeader must have - // already been set. + // protocol. It takes ownership of pkt. pkt.TransportHeader must have already + // been set. WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error // WritePackets writes packets to the given destination address and - // protocol. pkts must not be zero length. + // protocol. pkts must not be zero length. It takes ownership of pkts and + // underlying packets. WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) // WriteHeaderIncludedPacket writes a packet that includes a network - // header to the given destination address. + // header to the given destination address. It takes ownership of pkt. WriteHeaderIncludedPacket(r *Route, pkt PacketBuffer) *tcpip.Error // ID returns the network protocol endpoint ID. @@ -382,9 +383,8 @@ type LinkEndpoint interface { LinkAddress() tcpip.LinkAddress // WritePacket writes a packet with the given protocol through the - // given route. It sets pkt.LinkHeader if a link layer header exists. - // pkt.NetworkHeader and pkt.TransportHeader must have already been - // set. + // given route. It takes ownership of pkt. pkt.NetworkHeader and + // pkt.TransportHeader must have already been set. // // To participate in transparent bridging, a LinkEndpoint implementation // should call eth.Encode with header.EthernetFields.SrcAddr set to @@ -392,7 +392,8 @@ type LinkEndpoint interface { WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) *tcpip.Error // WritePackets writes packets with the given protocol through the - // given route. pkts must not be zero length. + // given route. pkts must not be zero length. It takes ownership of pkts and + // underlying packets. // // Right now, WritePackets is used only when the software segmentation // offload is enabled. If it will be used for something else, it may @@ -400,7 +401,7 @@ type LinkEndpoint interface { WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) // WriteRawPacket writes a packet directly to the link. The packet - // should already have an ethernet header. + // should already have an ethernet header. It takes ownership of vv. WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error // Attach attaches the data link layer endpoint to the network-layer diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 150297ab9..3d0e5cc6e 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -158,12 +158,15 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt PacketBuff return tcpip.ErrInvalidEndpointState } + // WritePacket takes ownership of pkt, calculate numBytes first. + numBytes := pkt.Header.UsedLength() + pkt.Data.Size() + err := r.ref.ep.WritePacket(r, gso, params, pkt) if err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() } else { r.ref.nic.stats.Tx.Packets.Increment() - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(pkt.Header.UsedLength() + pkt.Data.Size())) + r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) } return err } @@ -175,9 +178,12 @@ func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHead return 0, tcpip.ErrInvalidEndpointState } + // WritePackets takes ownership of pkt, calculate length first. + numPkts := pkts.Len() + n, err := r.ref.ep.WritePackets(r, gso, pkts, params) if err != nil { - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len() - n)) + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(numPkts - n)) } r.ref.nic.stats.Tx.Packets.IncrementBy(uint64(n)) @@ -198,12 +204,15 @@ func (r *Route) WriteHeaderIncludedPacket(pkt PacketBuffer) *tcpip.Error { return tcpip.ErrInvalidEndpointState } + // WriteHeaderIncludedPacket takes ownership of pkt, calculate numBytes first. + numBytes := pkt.Data.Size() + if err := r.ref.ep.WriteHeaderIncludedPacket(r, pkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return err } r.ref.nic.stats.Tx.Packets.Increment() - r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(pkt.Data.Size())) + r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) return nil } -- cgit v1.2.3 From d3a8bffe04595910714ec67231585bc33dab2b5b Mon Sep 17 00:00:00 2001 From: Ting-Yu Wang Date: Wed, 3 Jun 2020 14:57:57 -0700 Subject: Pass PacketBuffer as pointer. Historically we've been passing PacketBuffer by shallow copying through out the stack. Right now, this is only correct as the caller would not use PacketBuffer after passing into the next layer in netstack. With new buffer management effort in gVisor/netstack, PacketBuffer will own a Buffer (to be added). Internally, both PacketBuffer and Buffer may have pointers and shallow copying shouldn't be used. Updates #2404. PiperOrigin-RevId: 314610879 --- pkg/sentry/socket/netfilter/owner_matcher.go | 2 +- pkg/sentry/socket/netfilter/tcp_matcher.go | 2 +- pkg/sentry/socket/netfilter/udp_matcher.go | 2 +- pkg/tcpip/link/channel/channel.go | 8 ++--- pkg/tcpip/link/fdbased/endpoint.go | 4 +-- pkg/tcpip/link/fdbased/endpoint_test.go | 10 +++---- pkg/tcpip/link/fdbased/mmap.go | 2 +- pkg/tcpip/link/fdbased/packet_dispatchers.go | 4 +-- pkg/tcpip/link/loopback/loopback.go | 6 ++-- pkg/tcpip/link/muxed/injectable.go | 4 +-- pkg/tcpip/link/muxed/injectable_test.go | 4 +-- pkg/tcpip/link/qdisc/fifo/endpoint.go | 6 ++-- pkg/tcpip/link/sharedmem/sharedmem.go | 4 +-- pkg/tcpip/link/sharedmem/sharedmem_test.go | 26 ++++++++--------- pkg/tcpip/link/sniffer/sniffer.go | 8 ++--- pkg/tcpip/link/tun/device.go | 2 +- pkg/tcpip/link/waitable/waitable.go | 4 +-- pkg/tcpip/link/waitable/waitable_test.go | 16 +++++----- pkg/tcpip/network/arp/arp.go | 10 +++---- pkg/tcpip/network/arp/arp_test.go | 2 +- pkg/tcpip/network/ip_test.go | 30 ++++++++++++------- pkg/tcpip/network/ipv4/icmp.go | 12 +++++--- pkg/tcpip/network/ipv4/ipv4.go | 34 +++++++++++----------- pkg/tcpip/network/ipv4/ipv4_test.go | 28 +++++++++++------- pkg/tcpip/network/ipv6/icmp.go | 10 +++---- pkg/tcpip/network/ipv6/icmp_test.go | 14 ++++----- pkg/tcpip/network/ipv6/ipv6.go | 8 ++--- pkg/tcpip/network/ipv6/ipv6_test.go | 8 ++--- pkg/tcpip/network/ipv6/ndp_test.go | 10 +++---- pkg/tcpip/stack/conntrack.go | 4 +-- pkg/tcpip/stack/forwarder.go | 4 +-- pkg/tcpip/stack/forwarder_test.go | 34 +++++++++++----------- pkg/tcpip/stack/iptables.go | 2 +- pkg/tcpip/stack/iptables_types.go | 2 +- pkg/tcpip/stack/ndp.go | 4 +-- pkg/tcpip/stack/ndp_test.go | 14 ++++----- pkg/tcpip/stack/nic.go | 12 ++++---- pkg/tcpip/stack/nic_test.go | 2 +- pkg/tcpip/stack/packet_buffer.go | 33 +++++++++++++++++++-- pkg/tcpip/stack/registration.go | 26 ++++++++--------- pkg/tcpip/stack/route.go | 4 +-- pkg/tcpip/stack/stack.go | 4 +-- pkg/tcpip/stack/stack_test.go | 26 ++++++++--------- pkg/tcpip/stack/transport_demuxer.go | 14 ++++----- pkg/tcpip/stack/transport_demuxer_test.go | 4 +-- pkg/tcpip/stack/transport_test.go | 22 +++++++------- pkg/tcpip/transport/icmp/endpoint.go | 8 ++--- pkg/tcpip/transport/icmp/protocol.go | 2 +- pkg/tcpip/transport/packet/endpoint.go | 2 +- pkg/tcpip/transport/raw/endpoint.go | 6 ++-- pkg/tcpip/transport/tcp/connect.go | 4 +-- pkg/tcpip/transport/tcp/dispatcher.go | 2 +- pkg/tcpip/transport/tcp/endpoint.go | 4 +-- pkg/tcpip/transport/tcp/forwarder.go | 2 +- pkg/tcpip/transport/tcp/protocol.go | 4 +-- pkg/tcpip/transport/tcp/segment.go | 2 +- pkg/tcpip/transport/tcp/testing/context/context.go | 10 +++---- pkg/tcpip/transport/udp/endpoint.go | 10 +++++-- pkg/tcpip/transport/udp/forwarder.go | 4 +-- pkg/tcpip/transport/udp/protocol.go | 6 ++-- pkg/tcpip/transport/udp/udp_test.go | 4 +-- 61 files changed, 306 insertions(+), 255 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/sentry/socket/netfilter/owner_matcher.go b/pkg/sentry/socket/netfilter/owner_matcher.go index 3863293c7..1b4e0ad79 100644 --- a/pkg/sentry/socket/netfilter/owner_matcher.go +++ b/pkg/sentry/socket/netfilter/owner_matcher.go @@ -111,7 +111,7 @@ func (*OwnerMatcher) Name() string { } // Match implements Matcher.Match. -func (om *OwnerMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceName string) (bool, bool) { +func (om *OwnerMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { // Support only for OUTPUT chain. // TODO(gvisor.dev/issue/170): Need to support for POSTROUTING chain also. if hook != stack.Output { diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index 57a1e1c12..ebabdf334 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -96,7 +96,7 @@ func (*TCPMatcher) Name() string { } // Match implements Matcher.Match. -func (tm *TCPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceName string) (bool, bool) { +func (tm *TCPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { netHeader := header.IPv4(pkt.NetworkHeader) if netHeader.TransportProtocol() != header.TCPProtocolNumber { diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index cfa9e621d..98b9943f8 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -93,7 +93,7 @@ func (*UDPMatcher) Name() string { } // Match implements Matcher.Match. -func (um *UDPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceName string) (bool, bool) { +func (um *UDPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { netHeader := header.IPv4(pkt.NetworkHeader) // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index 5eb78b398..20b183da0 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -181,12 +181,12 @@ func (e *Endpoint) NumQueued() int { } // InjectInbound injects an inbound packet. -func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { e.InjectLinkAddr(protocol, "", pkt) } // InjectLinkAddr injects an inbound packet with a remote link address. -func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt stack.PacketBuffer) { +func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt *stack.PacketBuffer) { e.dispatcher.DeliverNetworkPacket(remote, "" /* local */, protocol, pkt) } @@ -229,13 +229,13 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { } // WritePacket stores outbound packets into the channel. -func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { // Clone r then release its resource so we only get the relevant fields from // stack.Route without holding a reference to a NIC's endpoint. route := r.Clone() route.Release() p := PacketInfo{ - Pkt: &pkt, + Pkt: pkt, Proto: protocol, GSO: gso, Route: route, diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index 5ee508d48..f34082e1a 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -387,7 +387,7 @@ const ( // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { if e.hdrSize > 0 { // Add ethernet header if needed. eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) @@ -641,7 +641,7 @@ func (e *InjectableEndpoint) Attach(dispatcher stack.NetworkDispatcher) { } // InjectInbound injects an inbound packet. -func (e *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (e *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, pkt) } diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index 6f41a71a8..eaee7e5d7 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -45,7 +45,7 @@ const ( type packetInfo struct { raddr tcpip.LinkAddress proto tcpip.NetworkProtocolNumber - contents stack.PacketBuffer + contents *stack.PacketBuffer } type context struct { @@ -103,7 +103,7 @@ func (c *context) cleanup() { } } -func (c *context) DeliverNetworkPacket(remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (c *context) DeliverNetworkPacket(remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { c.ch <- packetInfo{remote, protocol, pkt} } @@ -179,7 +179,7 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash u L3HdrLen: header.IPv4MaximumHeaderSize, } } - if err := c.ep.WritePacket(r, gso, proto, stack.PacketBuffer{ + if err := c.ep.WritePacket(r, gso, proto, &stack.PacketBuffer{ Header: hdr, Data: payload.ToVectorisedView(), Hash: hash, @@ -295,7 +295,7 @@ func TestPreserveSrcAddress(t *testing.T) { // WritePacket panics given a prependable with anything less than // the minimum size of the ethernet header. hdr := buffer.NewPrependable(header.EthernetMinimumSize) - if err := c.ep.WritePacket(r, nil /* gso */, proto, stack.PacketBuffer{ + if err := c.ep.WritePacket(r, nil /* gso */, proto, &stack.PacketBuffer{ Header: hdr, Data: buffer.VectorisedView{}, }); err != nil { @@ -358,7 +358,7 @@ func TestDeliverPacket(t *testing.T) { want := packetInfo{ raddr: raddr, proto: proto, - contents: stack.PacketBuffer{ + contents: &stack.PacketBuffer{ Data: buffer.View(b).ToVectorisedView(), LinkHeader: buffer.View(hdr), }, diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go index ca4229ed6..2dfd29aa9 100644 --- a/pkg/tcpip/link/fdbased/mmap.go +++ b/pkg/tcpip/link/fdbased/mmap.go @@ -191,7 +191,7 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) { } pkt = pkt[d.e.hdrSize:] - d.e.dispatcher.DeliverNetworkPacket(remote, local, p, stack.PacketBuffer{ + d.e.dispatcher.DeliverNetworkPacket(remote, local, p, &stack.PacketBuffer{ Data: buffer.View(pkt).ToVectorisedView(), LinkHeader: buffer.View(eth), }) diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go index 26c96a655..f04738cfb 100644 --- a/pkg/tcpip/link/fdbased/packet_dispatchers.go +++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go @@ -139,7 +139,7 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { } used := d.capViews(n, BufConfig) - pkt := stack.PacketBuffer{ + pkt := &stack.PacketBuffer{ Data: buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)), LinkHeader: buffer.View(eth), } @@ -296,7 +296,7 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { } used := d.capViews(k, int(n), BufConfig) - pkt := stack.PacketBuffer{ + pkt := &stack.PacketBuffer{ Data: buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)), LinkHeader: buffer.View(eth), } diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 20d9e95f6..568c6874f 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -76,7 +76,7 @@ func (*endpoint) Wait() {} // WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound // packets to the network-layer dispatcher. -func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) views[0] = pkt.Header.View() views = append(views, pkt.Data.Views()...) @@ -84,7 +84,7 @@ func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Netw // Because we're immediately turning around and writing the packet back // to the rx path, we intentionally don't preserve the remote and local // link addresses from the stack.Route we're passed. - e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, stack.PacketBuffer{ + e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, &stack.PacketBuffer{ Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), }) @@ -106,7 +106,7 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { } linkHeader := header.Ethernet(hdr) vv.TrimFront(len(linkHeader)) - e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, linkHeader.Type(), stack.PacketBuffer{ + e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, linkHeader.Type(), &stack.PacketBuffer{ Data: vv, LinkHeader: buffer.View(linkHeader), }) diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index f0769830a..c69d6b7e9 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -80,7 +80,7 @@ func (m *InjectableEndpoint) IsAttached() bool { } // InjectInbound implements stack.InjectableLinkEndpoint. -func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { m.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, pkt) } @@ -98,7 +98,7 @@ func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts s // WritePacket writes outbound packets to the appropriate LinkInjectableEndpoint // based on the RemoteAddress. HandleLocal only works if r.RemoteAddress has a // route registered in this endpoint. -func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { if endpoint, ok := m.routes[r.RemoteAddress]; ok { return endpoint.WritePacket(r, gso, protocol, pkt) } diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go index 87c734c1f..0744f66d6 100644 --- a/pkg/tcpip/link/muxed/injectable_test.go +++ b/pkg/tcpip/link/muxed/injectable_test.go @@ -50,7 +50,7 @@ func TestInjectableEndpointDispatch(t *testing.T) { hdr.Prepend(1)[0] = 0xFA packetRoute := stack.Route{RemoteAddress: dstIP} - endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, stack.PacketBuffer{ + endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(), }) @@ -70,7 +70,7 @@ func TestInjectableEndpointDispatchHdrOnly(t *testing.T) { hdr := buffer.NewPrependable(1) hdr.Prepend(1)[0] = 0xFA packetRoute := stack.Route{RemoteAddress: dstIP} - endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, stack.PacketBuffer{ + endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buffer.NewView(0).ToVectorisedView(), }) diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go index ec5c5048a..b5dfb7850 100644 --- a/pkg/tcpip/link/qdisc/fifo/endpoint.go +++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go @@ -102,7 +102,7 @@ func (q *queueDispatcher) dispatchLoop() { } // DeliverNetworkPacket implements stack.NetworkDispatcher.DeliverNetworkPacket. -func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt) } @@ -146,7 +146,7 @@ func (e *endpoint) GSOMaxSize() uint32 { } // WritePacket implements stack.LinkEndpoint.WritePacket. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { // WritePacket caller's do not set the following fields in PacketBuffer // so we populate them here. newRoute := r.Clone() @@ -154,7 +154,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne pkt.GSOOptions = gso pkt.NetworkProtocolNumber = protocol d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)] - if !d.q.enqueue(&pkt) { + if !d.q.enqueue(pkt) { return tcpip.ErrNoBufferSpace } d.newPacketWaker.Assert() diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index f5dec0a7f..0374a2441 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -185,7 +185,7 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress { // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { // Add the ethernet header here. eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) pkt.LinkHeader = buffer.View(eth) @@ -275,7 +275,7 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) { // Send packet up the stack. eth := header.Ethernet(b[:header.EthernetMinimumSize]) - d.DeliverNetworkPacket(eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), stack.PacketBuffer{ + d.DeliverNetworkPacket(eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), &stack.PacketBuffer{ Data: buffer.View(b[header.EthernetMinimumSize:]).ToVectorisedView(), LinkHeader: buffer.View(eth), }) diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index f3fc62607..28a2e88ba 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -131,7 +131,7 @@ func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress return c } -func (c *testContext) DeliverNetworkPacket(remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (c *testContext) DeliverNetworkPacket(remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { c.mu.Lock() c.packets = append(c.packets, packetInfo{ addr: remoteLinkAddr, @@ -273,7 +273,7 @@ func TestSimpleSend(t *testing.T) { randomFill(buf) proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, nil /* gso */, proto, stack.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, proto, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { @@ -345,7 +345,7 @@ func TestPreserveSrcAddressInSend(t *testing.T) { hdr := buffer.NewPrependable(header.EthernetMinimumSize) proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, nil /* gso */, proto, stack.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, proto, &stack.PacketBuffer{ Header: hdr, }); err != nil { t.Fatalf("WritePacket failed: %v", err) @@ -401,7 +401,7 @@ func TestFillTxQueue(t *testing.T) { for i := queuePipeSize / 40; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { @@ -419,7 +419,7 @@ func TestFillTxQueue(t *testing.T) { // Next attempt to write must fail. hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != want { @@ -447,7 +447,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { // Send two packets so that the id slice has at least two slots. for i := 2; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { @@ -470,7 +470,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { ids := make(map[uint64]struct{}) for i := queuePipeSize / 40; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { @@ -488,7 +488,7 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { // Next attempt to write must fail. hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != want { @@ -514,7 +514,7 @@ func TestFillTxMemory(t *testing.T) { ids := make(map[uint64]struct{}) for i := queueDataSize / bufferSize; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { @@ -533,7 +533,7 @@ func TestFillTxMemory(t *testing.T) { // Next attempt to write must fail. hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }) @@ -561,7 +561,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { // until there is only one buffer left. for i := queueDataSize/bufferSize - 1; i > 0; i-- { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { @@ -577,7 +577,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) uu := buffer.NewView(bufferSize).ToVectorisedView() - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: uu, }); err != want { @@ -588,7 +588,7 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { // Attempt to write the one-buffer packet again. It must succeed. { hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, stack.PacketBuffer{ + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ Header: hdr, Data: buf.ToVectorisedView(), }); err != nil { diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index b060d4627..ae3186314 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -120,8 +120,8 @@ func NewWithWriter(lower stack.LinkEndpoint, writer io.Writer, snapLen uint32) ( // DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is // called by the link-layer endpoint being wrapped when a packet arrives, and // logs the packet before forwarding to the actual dispatcher. -func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { - e.dumpPacket("recv", nil, protocol, &pkt) +func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.dumpPacket("recv", nil, protocol, pkt) e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt) } @@ -208,8 +208,8 @@ func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.Netw // WritePacket implements the stack.LinkEndpoint interface. It is called by // higher-level protocols to write packets; it just logs the packet and // forwards the request to the lower endpoint. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { - e.dumpPacket("send", gso, protocol, &pkt) +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + e.dumpPacket("send", gso, protocol, pkt) return e.lower.WritePacket(r, gso, protocol, pkt) } diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index 617446ea2..6bc9033d0 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -213,7 +213,7 @@ func (d *Device) Write(data []byte) (int64, error) { remote = tcpip.LinkAddress(zeroMAC[:]) } - pkt := stack.PacketBuffer{ + pkt := &stack.PacketBuffer{ Data: buffer.View(data).ToVectorisedView(), } if ethHdr != nil { diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index f5a05929f..949b3f2b2 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -50,7 +50,7 @@ func New(lower stack.LinkEndpoint) *Endpoint { // It is called by the link-layer endpoint being wrapped when a packet arrives, // and only forwards to the actual dispatcher if Wait or WaitDispatch haven't // been called. -func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { if !e.dispatchGate.Enter() { return } @@ -99,7 +99,7 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { // WritePacket implements stack.LinkEndpoint.WritePacket. It is called by // higher-level protocols to write packets. It only forwards packets to the // lower endpoint if Wait or WaitWrite haven't been called. -func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { if !e.writeGate.Enter() { return nil } diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index 0a9b99f18..63bf40562 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -35,7 +35,7 @@ type countedEndpoint struct { dispatcher stack.NetworkDispatcher } -func (e *countedEndpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (e *countedEndpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { e.dispatchCount++ } @@ -65,7 +65,7 @@ func (e *countedEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr } -func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { e.writeCount++ return nil } @@ -89,21 +89,21 @@ func TestWaitWrite(t *testing.T) { wep := New(ep) // Write and check that it goes through. - wep.WritePacket(nil, nil /* gso */, 0, stack.PacketBuffer{}) + wep.WritePacket(nil, nil /* gso */, 0, &stack.PacketBuffer{}) if want := 1; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } // Wait on dispatches, then try to write. It must go through. wep.WaitDispatch() - wep.WritePacket(nil, nil /* gso */, 0, stack.PacketBuffer{}) + wep.WritePacket(nil, nil /* gso */, 0, &stack.PacketBuffer{}) if want := 2; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } // Wait on writes, then try to write. It must not go through. wep.WaitWrite() - wep.WritePacket(nil, nil /* gso */, 0, stack.PacketBuffer{}) + wep.WritePacket(nil, nil /* gso */, 0, &stack.PacketBuffer{}) if want := 2; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } @@ -120,21 +120,21 @@ func TestWaitDispatch(t *testing.T) { } // Dispatch and check that it goes through. - ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.PacketBuffer{}) + ep.dispatcher.DeliverNetworkPacket("", "", 0, &stack.PacketBuffer{}) if want := 1; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } // Wait on writes, then try to dispatch. It must go through. wep.WaitWrite() - ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.PacketBuffer{}) + ep.dispatcher.DeliverNetworkPacket("", "", 0, &stack.PacketBuffer{}) if want := 2; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } // Wait on dispatches, then try to dispatch. It must not go through. wep.WaitDispatch() - ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.PacketBuffer{}) + ep.dispatcher.DeliverNetworkPacket("", "", 0, &stack.PacketBuffer{}) if want := 2; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 9d0797af7..ea1acba83 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -80,7 +80,7 @@ func (e *endpoint) MaxHeaderLength() uint16 { func (e *endpoint) Close() {} -func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } @@ -94,11 +94,11 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList return 0, tcpip.ErrNotSupported } -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } -func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { +func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { v, ok := pkt.Data.PullUp(header.ARPSize) if !ok { return @@ -122,7 +122,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget()) copy(packet.HardwareAddressTarget(), h.HardwareAddressSender()) copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender()) - e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{ + e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, &stack.PacketBuffer{ Header: hdr, }) fallthrough // also fill the cache from requests @@ -177,7 +177,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. copy(h.ProtocolAddressSender(), localAddr) copy(h.ProtocolAddressTarget(), addr) - return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{ + return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, &stack.PacketBuffer{ Header: hdr, }) } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 1646d9cde..66e67429c 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -103,7 +103,7 @@ func TestDirectRequest(t *testing.T) { inject := func(addr tcpip.Address) { copy(h.ProtocolAddressTarget(), addr) - c.linkEP.InjectInbound(arp.ProtocolNumber, stack.PacketBuffer{ + c.linkEP.InjectInbound(arp.ProtocolNumber, &stack.PacketBuffer{ Data: v.ToVectorisedView(), }) } diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 4c20301c6..d9b62f2db 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -96,7 +96,7 @@ func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buff // DeliverTransportPacket is called by network endpoints after parsing incoming // packets. This is used by the test object to verify that the results of the // parsing are expected. -func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt stack.PacketBuffer) { +func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) { t.checkValues(protocol, pkt.Data, r.RemoteAddress, r.LocalAddress) t.dataCalls++ } @@ -104,7 +104,7 @@ func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.Trans // DeliverTransportControlPacket is called by network endpoints after parsing // incoming control (ICMP) packets. This is used by the test object to verify // that the results of the parsing are expected. -func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { +func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { t.checkValues(trans, pkt.Data, remote, local) if typ != t.typ { t.t.Errorf("typ = %v, want %v", typ, t.typ) @@ -150,7 +150,7 @@ func (*testObject) Wait() {} // WritePacket is called by network endpoints after producing a packet and // writing it to the link endpoint. This is used by the test object to verify // that the produced packet is as expected. -func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { var prot tcpip.TransportProtocolNumber var srcAddr tcpip.Address var dstAddr tcpip.Address @@ -246,7 +246,11 @@ func TestIPv4Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{ + Protocol: 123, + TTL: 123, + TOS: stack.DefaultTOS, + }, &stack.PacketBuffer{ Header: hdr, Data: payload.ToVectorisedView(), }); err != nil { @@ -289,7 +293,7 @@ func TestIPv4Receive(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - ep.HandlePacket(&r, stack.PacketBuffer{ + ep.HandlePacket(&r, &stack.PacketBuffer{ Data: view.ToVectorisedView(), }) if o.dataCalls != 1 { @@ -379,7 +383,7 @@ func TestIPv4ReceiveControl(t *testing.T) { o.extra = c.expectedExtra vv := view[:len(view)-c.trunc].ToVectorisedView() - ep.HandlePacket(&r, stack.PacketBuffer{ + ep.HandlePacket(&r, &stack.PacketBuffer{ Data: vv, }) if want := c.expectedCount; o.controlCalls != want { @@ -444,7 +448,7 @@ func TestIPv4FragmentationReceive(t *testing.T) { } // Send first segment. - ep.HandlePacket(&r, stack.PacketBuffer{ + ep.HandlePacket(&r, &stack.PacketBuffer{ Data: frag1.ToVectorisedView(), }) if o.dataCalls != 0 { @@ -452,7 +456,7 @@ func TestIPv4FragmentationReceive(t *testing.T) { } // Send second segment. - ep.HandlePacket(&r, stack.PacketBuffer{ + ep.HandlePacket(&r, &stack.PacketBuffer{ Data: frag2.ToVectorisedView(), }) if o.dataCalls != 1 { @@ -487,7 +491,11 @@ func TestIPv6Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{ + Protocol: 123, + TTL: 123, + TOS: stack.DefaultTOS, + }, &stack.PacketBuffer{ Header: hdr, Data: payload.ToVectorisedView(), }); err != nil { @@ -530,7 +538,7 @@ func TestIPv6Receive(t *testing.T) { t.Fatalf("could not find route: %v", err) } - ep.HandlePacket(&r, stack.PacketBuffer{ + ep.HandlePacket(&r, &stack.PacketBuffer{ Data: view.ToVectorisedView(), }) if o.dataCalls != 1 { @@ -644,7 +652,7 @@ func TestIPv6ReceiveControl(t *testing.T) { // Set ICMPv6 checksum. icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{})) - ep.HandlePacket(&r, stack.PacketBuffer{ + ep.HandlePacket(&r, &stack.PacketBuffer{ Data: view[:len(view)-c.trunc].ToVectorisedView(), }) if want := c.expectedCount; o.controlCalls != want { diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 4cbefe5ab..d1c3ae835 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -24,7 +24,7 @@ import ( // the original packet that caused the ICMP one to be sent. This information is // used to find out which transport endpoint must be notified about the ICMP // packet. -func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { +func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) if !ok { return @@ -56,7 +56,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } -func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) { +func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { stats := r.Stats() received := stats.ICMP.V4PacketsReceived v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) @@ -88,7 +88,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) { // It's possible that a raw socket expects to receive this. h.SetChecksum(wantChecksum) - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, stack.PacketBuffer{ + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, &stack.PacketBuffer{ Data: pkt.Data.Clone(nil), NetworkHeader: append(buffer.View(nil), pkt.NetworkHeader...), }) @@ -102,7 +102,11 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt stack.PacketBuffer) { pkt.SetChecksum(0) pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0))) sent := stats.ICMP.V4PacketsSent - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + Protocol: header.ICMPv4ProtocolNumber, + TTL: r.DefaultTTL(), + TOS: stack.DefaultTOS, + }, &stack.PacketBuffer{ Header: hdr, Data: vv, TransportHeader: buffer.View(pkt), diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 64046cbbf..9cd7592f4 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -129,7 +129,7 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { // packet's stated length matches the length of the header+payload. mtu // includes the IP header and options. This does not support the DontFragment // IP flag. -func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt *stack.PacketBuffer) *tcpip.Error { // This packet is too big, it needs to be fragmented. ip := header.IPv4(pkt.Header.View()) flags := ip.Flags() @@ -169,7 +169,7 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, if i > 0 { newPayload := pkt.Data.Clone(nil) newPayload.CapLength(innerMTU) - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{ + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{ Header: pkt.Header, Data: newPayload, NetworkHeader: buffer.View(h), @@ -188,7 +188,7 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, newPayload := pkt.Data.Clone(nil) newPayloadLength := outerMTU - pkt.Header.UsedLength() newPayload.CapLength(newPayloadLength) - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{ + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{ Header: pkt.Header, Data: newPayload, NetworkHeader: buffer.View(h), @@ -202,7 +202,7 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, startOfHdr := pkt.Header startOfHdr.TrimBack(pkt.Header.UsedLength() - outerMTU) emptyVV := buffer.NewVectorisedView(0, []buffer.View{}) - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, stack.PacketBuffer{ + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{ Header: startOfHdr, Data: emptyVV, NetworkHeader: buffer.View(h), @@ -245,7 +245,7 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) pkt.NetworkHeader = buffer.View(ip) @@ -253,7 +253,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // iptables filtering. All packets that reach here are locally // generated. ipt := e.stack.IPTables() - if ok := ipt.Check(stack.Output, &pkt, gso, r, "", nicName); !ok { + if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok { // iptables is telling us to drop the packet. return nil } @@ -271,9 +271,9 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) views[0] = pkt.Header.View() views = append(views, pkt.Data.Views()...) - packet := stack.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views)} - ep.HandlePacket(&route, packet) + ep.HandlePacket(&route, &stack.PacketBuffer{ + Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), + }) return nil } } @@ -286,7 +286,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw views = append(views, pkt.Data.Views()...) loopedR := r.MakeLoopedRoute() - e.HandlePacket(&loopedR, stack.PacketBuffer{ + e.HandlePacket(&loopedR, &stack.PacketBuffer{ Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), }) @@ -351,14 +351,14 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) views[0] = pkt.Header.View() views = append(views, pkt.Data.Views()...) - packet := stack.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views)} - ep.HandlePacket(&route, packet) + ep.HandlePacket(&route, &stack.PacketBuffer{ + Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), + }) n++ continue } } - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, *pkt); err != nil { + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) return n, err } @@ -370,7 +370,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required // checks. h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) @@ -426,7 +426,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuf // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { +func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { headerView, ok := pkt.Data.PullUp(header.IPv4MinimumSize) if !ok { r.Stats().IP.MalformedPacketsReceived.Increment() @@ -447,7 +447,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. ipt := e.stack.IPTables() - if ok := ipt.Check(stack.Input, &pkt, nil, nil, "", ""); !ok { + if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok { // iptables is telling us to drop the packet. return } diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 36035c820..c208ebd99 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -114,7 +114,7 @@ func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer. // comparePayloads compared the contents of all the packets against the contents // of the source packet. -func compareFragments(t *testing.T, packets []stack.PacketBuffer, sourcePacketInfo stack.PacketBuffer, mtu uint32) { +func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketInfo *stack.PacketBuffer, mtu uint32) { t.Helper() // Make a complete array of the sourcePacketInfo packet. source := header.IPv4(packets[0].Header.View()[:header.IPv4MinimumSize]) @@ -174,7 +174,7 @@ func compareFragments(t *testing.T, packets []stack.PacketBuffer, sourcePacketIn type errorChannel struct { *channel.Endpoint - Ch chan stack.PacketBuffer + Ch chan *stack.PacketBuffer packetCollectorErrors []*tcpip.Error } @@ -184,7 +184,7 @@ type errorChannel struct { func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel { return &errorChannel{ Endpoint: channel.New(size, mtu, linkAddr), - Ch: make(chan stack.PacketBuffer, size), + Ch: make(chan *stack.PacketBuffer, size), packetCollectorErrors: packetCollectorErrors, } } @@ -203,7 +203,7 @@ func (e *errorChannel) Drain() int { } // WritePacket stores outbound packets into the channel. -func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) *tcpip.Error { +func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { select { case e.Ch <- pkt: default: @@ -282,13 +282,17 @@ func TestFragmentation(t *testing.T) { for _, ft := range fragTests { t.Run(ft.description, func(t *testing.T) { hdr, payload := makeHdrAndPayload(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes) - source := stack.PacketBuffer{ + source := &stack.PacketBuffer{ Header: hdr, // Save the source payload because WritePacket will modify it. Data: payload.Clone(nil), } c := buildContext(t, nil, ft.mtu) - err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{ + Protocol: tcp.ProtocolNumber, + TTL: 42, + TOS: stack.DefaultTOS, + }, &stack.PacketBuffer{ Header: hdr, Data: payload, }) @@ -296,7 +300,7 @@ func TestFragmentation(t *testing.T) { t.Errorf("err got %v, want %v", err, nil) } - var results []stack.PacketBuffer + var results []*stack.PacketBuffer L: for { select { @@ -338,7 +342,11 @@ func TestFragmentationErrors(t *testing.T) { t.Run(ft.description, func(t *testing.T) { hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes) c := buildContext(t, ft.packetCollectorErrors, ft.mtu) - err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ + Protocol: tcp.ProtocolNumber, + TTL: 42, + TOS: stack.DefaultTOS, + }, &stack.PacketBuffer{ Header: hdr, Data: payload, }) @@ -460,7 +468,7 @@ func TestInvalidFragments(t *testing.T) { s.CreateNIC(nicID, sniffer.New(ep)) for _, pkt := range tc.packets { - ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, stack.PacketBuffer{ + ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, &stack.PacketBuffer{ Data: buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}), }) } @@ -698,7 +706,7 @@ func TestReceiveFragments(t *testing.T) { vv := hdr.View().ToVectorisedView() vv.AppendView(frag.payload) - e.InjectInbound(header.IPv4ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(header.IPv4ProtocolNumber, &stack.PacketBuffer{ Data: vv, }) } diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index bdf3a0d25..b62fb1de6 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -27,7 +27,7 @@ import ( // the original packet that caused the ICMP one to be sent. This information is // used to find out which transport endpoint must be notified about the ICMP // packet. -func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { +func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { h, ok := pkt.Data.PullUp(header.IPv6MinimumSize) if !ok { return @@ -70,7 +70,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt stack. e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } -func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.PacketBuffer, hasFragmentHeader bool) { +func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt *stack.PacketBuffer, hasFragmentHeader bool) { stats := r.Stats().ICMP sent := stats.V6PacketsSent received := stats.V6PacketsReceived @@ -288,7 +288,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P // // The IP Hop Limit field has a value of 255, i.e., the packet // could not possibly have been forwarded by a router. - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, }); err != nil { sent.Dropped.Increment() @@ -390,7 +390,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt stack.P copy(packet, icmpHdr) packet.SetType(header.ICMPv6EchoReply) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, Data: pkt.Data, }); err != nil { @@ -532,7 +532,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. }) // TODO(stijlist): count this in ICMP stats. - return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, stack.PacketBuffer{ + return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, &stack.PacketBuffer{ Header: hdr, }) } diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index d412ff688..a720f626f 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -57,7 +57,7 @@ func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress { return "" } -func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, stack.PacketBuffer) *tcpip.Error { +func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error { return nil } @@ -67,7 +67,7 @@ type stubDispatcher struct { stack.TransportDispatcher } -func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, stack.PacketBuffer) { +func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, *stack.PacketBuffer) { } type stubLinkAddressCache struct { @@ -189,7 +189,7 @@ func TestICMPCounts(t *testing.T) { SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - ep.HandlePacket(&r, stack.PacketBuffer{ + ep.HandlePacket(&r, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) } @@ -328,7 +328,7 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header. views := []buffer.View{pi.Pkt.Header.View(), pi.Pkt.Data.ToView()} size := pi.Pkt.Header.UsedLength() + pi.Pkt.Data.Size() vv := buffer.NewVectorisedView(size, views) - args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), stack.PacketBuffer{ + args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), &stack.PacketBuffer{ Data: vv, }) } @@ -563,7 +563,7 @@ func TestICMPChecksumValidationSimple(t *testing.T) { SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) } @@ -740,7 +740,7 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) } @@ -918,7 +918,7 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}), }) } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index daf1fcbc6..0d94ad122 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -116,7 +116,7 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) pkt.NetworkHeader = buffer.View(ip) @@ -128,7 +128,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw views = append(views, pkt.Data.Views()...) loopedR := r.MakeLoopedRoute() - e.HandlePacket(&loopedR, stack.PacketBuffer{ + e.HandlePacket(&loopedR, &stack.PacketBuffer{ Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), }) @@ -163,14 +163,14 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet // supported by IPv6. -func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { +func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { // TODO(b/146666412): Support IPv6 header-included packets. return tcpip.ErrNotSupported } // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { +func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { headerView, ok := pkt.Data.PullUp(header.IPv6MinimumSize) if !ok { r.Stats().IP.MalformedPacketsReceived.Increment() diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 841a0cb7a..213ff64f2 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -65,7 +65,7 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst DstAddr: dst, }) - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -123,7 +123,7 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst DstAddr: dst, }) - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -637,7 +637,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { DstAddr: addr2, }) - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -1238,7 +1238,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { vv := hdr.View().ToVectorisedView() vv.Append(f.data) - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ Data: vv, }) } diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 12b70f7e9..3c141b91b 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -136,7 +136,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -380,7 +380,7 @@ func TestNeighorSolicitationResponse(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, stack.PacketBuffer{ + e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -497,7 +497,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectInbound(ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -568,7 +568,7 @@ func TestNDPValidation(t *testing.T) { SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - ep.HandlePacket(r, stack.PacketBuffer{ + ep.HandlePacket(r, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) } @@ -884,7 +884,7 @@ func TestRouterAdvertValidation(t *testing.T) { t.Fatalf("got rxRA = %d, want = 0", got) } - e.InjectInbound(header.IPv6ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 7d1ede1f2..d4053be08 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -186,7 +186,7 @@ func parseHeaders(pkt *PacketBuffer) { } // packetToTuple converts packet to a tuple in original direction. -func packetToTuple(pkt PacketBuffer, hook Hook) (connTrackTuple, *tcpip.Error) { +func packetToTuple(pkt *PacketBuffer, hook Hook) (connTrackTuple, *tcpip.Error) { var tuple connTrackTuple netHeader := header.IPv4(pkt.NetworkHeader) @@ -265,7 +265,7 @@ func (ct *ConnTrackTable) connTrackForPacket(pkt *PacketBuffer, hook Hook, creat } var dir ctDirection - tuple, err := packetToTuple(*pkt, hook) + tuple, err := packetToTuple(pkt, hook) if err != nil { return nil, dir } diff --git a/pkg/tcpip/stack/forwarder.go b/pkg/tcpip/stack/forwarder.go index 6b64cd37f..3eff141e6 100644 --- a/pkg/tcpip/stack/forwarder.go +++ b/pkg/tcpip/stack/forwarder.go @@ -32,7 +32,7 @@ type pendingPacket struct { nic *NIC route *Route proto tcpip.NetworkProtocolNumber - pkt PacketBuffer + pkt *PacketBuffer } type forwardQueue struct { @@ -50,7 +50,7 @@ func newForwardQueue() *forwardQueue { return &forwardQueue{packets: make(map[<-chan struct{}][]*pendingPacket)} } -func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { +func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { shouldWait := false f.Lock() diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index 344d60baa..63537aaad 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -68,7 +68,7 @@ func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID { return &f.id } -func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt PacketBuffer) { +func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) { // Consume the network header. b, ok := pkt.Data.PullUp(fwdTestNetHeaderLen) if !ok { @@ -96,7 +96,7 @@ func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNu return f.proto.Number() } -func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error { +func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error { // Add the protocol's header to the packet and send it to the link // endpoint. b := pkt.Header.Prepend(fwdTestNetHeaderLen) @@ -112,7 +112,7 @@ func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuf panic("not implemented") } -func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt PacketBuffer) *tcpip.Error { +func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } @@ -190,7 +190,7 @@ func (f *fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumb type fwdTestPacketInfo struct { RemoteLinkAddress tcpip.LinkAddress LocalLinkAddress tcpip.LinkAddress - Pkt PacketBuffer + Pkt *PacketBuffer } type fwdTestLinkEndpoint struct { @@ -203,12 +203,12 @@ type fwdTestLinkEndpoint struct { } // InjectInbound injects an inbound packet. -func (e *fwdTestLinkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { +func (e *fwdTestLinkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { e.InjectLinkAddr(protocol, "", pkt) } // InjectLinkAddr injects an inbound packet with a remote link address. -func (e *fwdTestLinkEndpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt PacketBuffer) { +func (e *fwdTestLinkEndpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt *PacketBuffer) { e.dispatcher.DeliverNetworkPacket(remote, "" /* local */, protocol, pkt) } @@ -251,7 +251,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr } -func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) *tcpip.Error { +func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { p := fwdTestPacketInfo{ RemoteLinkAddress: r.RemoteLinkAddress, LocalLinkAddress: r.LocalLinkAddress, @@ -270,7 +270,7 @@ func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.Netw func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { n := 0 for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - e.WritePacket(r, gso, protocol, *pkt) + e.WritePacket(r, gso, protocol, pkt) n++ } @@ -280,7 +280,7 @@ func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuffer // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { p := fwdTestPacketInfo{ - Pkt: PacketBuffer{Data: vv}, + Pkt: &PacketBuffer{Data: vv}, } select { @@ -362,7 +362,7 @@ func TestForwardingWithStaticResolver(t *testing.T) { // forwarded to NIC 2. buf := buffer.NewView(30) buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -399,7 +399,7 @@ func TestForwardingWithFakeResolver(t *testing.T) { // forwarded to NIC 2. buf := buffer.NewView(30) buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -430,7 +430,7 @@ func TestForwardingWithNoResolver(t *testing.T) { // forwarded to NIC 2. buf := buffer.NewView(30) buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -460,7 +460,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { // not be forwarded. buf := buffer.NewView(30) buf[0] = 4 - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -468,7 +468,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { // forwarded to NIC 2. buf = buffer.NewView(30) buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -510,7 +510,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { for i := 0; i < 2; i++ { buf := buffer.NewView(30) buf[0] = 3 - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) } @@ -557,7 +557,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { buf[0] = 3 // Set the packet sequence number. binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i)) - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) } @@ -610,7 +610,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // maxPendingResolutions + 7). buf := buffer.NewView(30) buf[0] = byte(3 + i) - ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 709ede3fa..d989dbe91 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -321,7 +321,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx // Go through each rule matcher. If they all match, run // the rule target. for _, matcher := range rule.Matchers { - matches, hotdrop := matcher.Match(hook, *pkt, "") + matches, hotdrop := matcher.Match(hook, pkt, "") if hotdrop { return RuleDrop, 0 } diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index a3bd3e700..af72b9c46 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -245,7 +245,7 @@ type Matcher interface { // used for suspicious packets. // // Precondition: packet.NetworkHeader is set. - Match(hook Hook, packet PacketBuffer, interfaceName string) (matches bool, hotdrop bool) + Match(hook Hook, packet *PacketBuffer, interfaceName string) (matches bool, hotdrop bool) } // A Target is the interface for taking an action for a packet. diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 526c7d6ff..ae7a8f740 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -750,7 +750,7 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address) *tcpip.Error { Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS, - }, PacketBuffer{Header: hdr}, + }, &PacketBuffer{Header: hdr}, ); err != nil { sent.Dropped.Increment() return err @@ -1881,7 +1881,7 @@ func (ndp *ndpState) startSolicitingRouters() { Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS, - }, PacketBuffer{Header: hdr}, + }, &PacketBuffer{Header: hdr}, ); err != nil { sent.Dropped.Increment() log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.nic.ID(), err) diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index b3d174cdd..58f1ebf60 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -613,7 +613,7 @@ func TestDADFail(t *testing.T) { // Receive a packet to simulate multiple nodes owning or // attempting to own the same address. hdr := test.makeBuf(addr1) - e.InjectInbound(header.IPv6ProtocolNumber, stack.PacketBuffer{ + e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{ Data: hdr.View().ToVectorisedView(), }) @@ -935,7 +935,7 @@ func TestSetNDPConfigurations(t *testing.T) { // raBufWithOptsAndDHCPv6 returns a valid NDP Router Advertisement with options // and DHCPv6 configurations specified. -func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) stack.PacketBuffer { +func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) *stack.PacketBuffer { icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + int(optSer.Length()) hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) pkt := header.ICMPv6(hdr.Prepend(icmpSize)) @@ -970,14 +970,14 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo DstAddr: header.IPv6AllNodesMulticastAddress, }) - return stack.PacketBuffer{Data: hdr.View().ToVectorisedView()} + return &stack.PacketBuffer{Data: hdr.View().ToVectorisedView()} } // raBufWithOpts returns a valid NDP Router Advertisement with options. // // Note, raBufWithOpts does not populate any of the RA fields other than the // Router Lifetime. -func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) stack.PacketBuffer { +func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) *stack.PacketBuffer { return raBufWithOptsAndDHCPv6(ip, rl, false, false, optSer) } @@ -986,7 +986,7 @@ func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializ // // Note, raBufWithDHCPv6 does not populate any of the RA fields other than the // DHCPv6 related ones. -func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) stack.PacketBuffer { +func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) *stack.PacketBuffer { return raBufWithOptsAndDHCPv6(ip, 0, managedAddresses, otherConfiguratiosns, header.NDPOptionsSerializer{}) } @@ -994,7 +994,7 @@ func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bo // // Note, raBuf does not populate any of the RA fields other than the // Router Lifetime. -func raBuf(ip tcpip.Address, rl uint16) stack.PacketBuffer { +func raBuf(ip tcpip.Address, rl uint16) *stack.PacketBuffer { return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{}) } @@ -1003,7 +1003,7 @@ func raBuf(ip tcpip.Address, rl uint16) stack.PacketBuffer { // // Note, raBufWithPI does not populate any of the RA fields other than the // Router Lifetime. -func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) stack.PacketBuffer { +func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) *stack.PacketBuffer { flags := uint8(0) if onLink { // The OnLink flag is the 7th bit in the flags byte. diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 05646e5e2..ec8e3cb85 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1153,7 +1153,7 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool { return joins != 0 } -func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt PacketBuffer) { +func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt *PacketBuffer) { r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */) r.RemoteLinkAddress = remotelinkAddr @@ -1167,7 +1167,7 @@ func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, // Note that the ownership of the slice backing vv is retained by the caller. // This rule applies only to the slice itself, not to the items of the slice; // the ownership of the items is not retained by the caller. -func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { +func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { n.mu.RLock() enabled := n.mu.enabled // If the NIC is not yet enabled, don't receive any packets. @@ -1233,7 +1233,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp // iptables filtering. ipt := n.stack.IPTables() address := n.primaryAddress(protocol) - if ok := ipt.Check(Prerouting, &pkt, nil, nil, address.Address, ""); !ok { + if ok := ipt.Check(Prerouting, pkt, nil, nil, address.Address, ""); !ok { // iptables is telling us to drop the packet. return } @@ -1298,7 +1298,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp } } -func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) { +func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { // TODO(b/143425874) Decrease the TTL field in forwarded packets. if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen != 0 { pkt.Header = buffer.NewPrependable(linkHeaderLen) @@ -1318,7 +1318,7 @@ func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. -func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer) { +func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) { state, ok := n.stack.transportProtocols[protocol] if !ok { n.stack.stats.UnknownProtocolRcvdPackets.Increment() @@ -1365,7 +1365,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // DeliverTransportControlPacket delivers control packets to the appropriate // transport protocol endpoint. -func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer) { +func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer) { state, ok := n.stack.transportProtocols[trans] if !ok { return diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index b01b3f476..fea46158c 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -44,7 +44,7 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { t.FailNow() } - nic.DeliverNetworkPacket("", "", 0, PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()}) + nic.DeliverNetworkPacket("", "", 0, &PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()}) if got := nic.stats.DisabledRx.Packets.Value(); got != 1 { t.Errorf("got DisabledRx.Packets = %d, want = 1", got) diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 926df4d7b..1b5da6017 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -24,6 +24,8 @@ import ( // multiple endpoints. Clone() should be called in such cases so that // modifications to the Data field do not affect other copies. type PacketBuffer struct { + _ noCopy + // PacketBufferEntry is used to build an intrusive list of // PacketBuffers. PacketBufferEntry @@ -82,7 +84,32 @@ type PacketBuffer struct { // VectorisedView but does not deep copy the underlying bytes. // // Clone also does not deep copy any of its other fields. -func (pk PacketBuffer) Clone() PacketBuffer { - pk.Data = pk.Data.Clone(nil) - return pk +// +// FIXME(b/153685824): Data gets copied but not other header references. +func (pk *PacketBuffer) Clone() *PacketBuffer { + return &PacketBuffer{ + PacketBufferEntry: pk.PacketBufferEntry, + Data: pk.Data.Clone(nil), + Header: pk.Header, + LinkHeader: pk.LinkHeader, + NetworkHeader: pk.NetworkHeader, + TransportHeader: pk.TransportHeader, + Hash: pk.Hash, + Owner: pk.Owner, + EgressRoute: pk.EgressRoute, + GSOOptions: pk.GSOOptions, + NetworkProtocolNumber: pk.NetworkProtocolNumber, + NatDone: pk.NatDone, + } } + +// noCopy may be embedded into structs which must not be copied +// after the first use. +// +// See https://golang.org/issues/8005#issuecomment-190753527 +// for details. +type noCopy struct{} + +// Lock is a no-op used by -copylocks checker from `go vet`. +func (*noCopy) Lock() {} +func (*noCopy) Unlock() {} diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index db89234e8..94f177841 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -67,12 +67,12 @@ type TransportEndpoint interface { // this transport endpoint. It sets pkt.TransportHeader. // // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, id TransportEndpointID, pkt PacketBuffer) + HandlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) // HandleControlPacket is called by the stack when new control (e.g. // ICMP) packets arrive to this transport endpoint. // HandleControlPacket takes ownership of pkt. - HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt PacketBuffer) + HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt *PacketBuffer) // Abort initiates an expedited endpoint teardown. It puts the endpoint // in a closed state and frees all resources associated with it. This @@ -100,7 +100,7 @@ type RawTransportEndpoint interface { // layer up. // // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, pkt PacketBuffer) + HandlePacket(r *Route, pkt *PacketBuffer) } // PacketEndpoint is the interface that needs to be implemented by packet @@ -118,7 +118,7 @@ type PacketEndpoint interface { // should construct its own ethernet header for applications. // // HandlePacket takes ownership of pkt. - HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt PacketBuffer) + HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *PacketBuffer) } // TransportProtocol is the interface that needs to be implemented by transport @@ -150,7 +150,7 @@ type TransportProtocol interface { // stats purposes only). // // HandleUnknownDestinationPacket takes ownership of pkt. - HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt PacketBuffer) bool + HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the @@ -180,7 +180,7 @@ type TransportDispatcher interface { // pkt.NetworkHeader must be set before calling DeliverTransportPacket. // // DeliverTransportPacket takes ownership of pkt. - DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer) + DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) // DeliverTransportControlPacket delivers control packets to the // appropriate transport protocol endpoint. @@ -189,7 +189,7 @@ type TransportDispatcher interface { // DeliverTransportControlPacket. // // DeliverTransportControlPacket takes ownership of pkt. - DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer) + DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer) } // PacketLooping specifies where an outbound packet should be sent. @@ -242,7 +242,7 @@ type NetworkEndpoint interface { // WritePacket writes a packet to the given destination address and // protocol. It takes ownership of pkt. pkt.TransportHeader must have already // been set. - WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error + WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error // WritePackets writes packets to the given destination address and // protocol. pkts must not be zero length. It takes ownership of pkts and @@ -251,7 +251,7 @@ type NetworkEndpoint interface { // WriteHeaderIncludedPacket writes a packet that includes a network // header to the given destination address. It takes ownership of pkt. - WriteHeaderIncludedPacket(r *Route, pkt PacketBuffer) *tcpip.Error + WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error // ID returns the network protocol endpoint ID. ID() *NetworkEndpointID @@ -266,7 +266,7 @@ type NetworkEndpoint interface { // this network endpoint. It sets pkt.NetworkHeader. // // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, pkt PacketBuffer) + HandlePacket(r *Route, pkt *PacketBuffer) // Close is called when the endpoint is reomved from a stack. Close() @@ -327,7 +327,7 @@ type NetworkDispatcher interface { // packets sent via loopback), and won't have the field set. // // DeliverNetworkPacket takes ownership of pkt. - DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) + DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) } // LinkEndpointCapabilities is the type associated with the capabilities @@ -389,7 +389,7 @@ type LinkEndpoint interface { // To participate in transparent bridging, a LinkEndpoint implementation // should call eth.Encode with header.EthernetFields.SrcAddr set to // r.LocalLinkAddress if it is provided. - WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) *tcpip.Error + WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error // WritePackets writes packets with the given protocol through the // given route. pkts must not be zero length. It takes ownership of pkts and @@ -431,7 +431,7 @@ type InjectableLinkEndpoint interface { LinkEndpoint // InjectInbound injects an inbound packet. - InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) + InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) // InjectOutbound writes a fully formed outbound packet directly to the // link. diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 3d0e5cc6e..f5b6ca0b9 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -153,7 +153,7 @@ func (r *Route) IsResolutionRequired() bool { } // WritePacket writes the packet through the given route. -func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error { +func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error { if !r.ref.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } @@ -199,7 +199,7 @@ func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHead // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. -func (r *Route) WriteHeaderIncludedPacket(pkt PacketBuffer) *tcpip.Error { +func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error { if !r.ref.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 0ab4c3e19..8af06cb9a 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -52,7 +52,7 @@ const ( type transportProtocolState struct { proto TransportProtocol - defaultHandler func(r *Route, id TransportEndpointID, pkt PacketBuffer) bool + defaultHandler func(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool } // TCPProbeFunc is the expected function type for a TCP probe function to be @@ -778,7 +778,7 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, // // It must be called only during initialization of the stack. Changing it as the // stack is operating is not supported. -func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, PacketBuffer) bool) { +func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, *PacketBuffer) bool) { state := s.transportProtocols[p] if state != nil { state.defaultHandler = h diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 1a2cf007c..f6ddc3ced 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -90,7 +90,7 @@ func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID { return &f.id } -func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { +func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // Increment the received packet count in the protocol descriptor. f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ @@ -132,7 +132,7 @@ func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumbe return f.proto.Number() } -func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error { +func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { // Increment the sent packet count in the protocol descriptor. f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++ @@ -147,7 +147,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) views[0] = pkt.Header.View() views = append(views, pkt.Data.Views()...) - f.HandlePacket(r, stack.PacketBuffer{ + f.HandlePacket(r, &stack.PacketBuffer{ Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), }) } @@ -163,7 +163,7 @@ func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts panic("not implemented") } -func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error { +func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } @@ -293,7 +293,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet with wrong address is not delivered. buf[0] = 3 - ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 0 { @@ -305,7 +305,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to first endpoint. buf[0] = 1 - ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 1 { @@ -317,7 +317,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to second endpoint. buf[0] = 2 - ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 1 { @@ -328,7 +328,7 @@ func TestNetworkReceive(t *testing.T) { } // Make sure packet is not delivered if protocol number is wrong. - ep.InjectInbound(fakeNetNumber-1, stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber-1, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 1 { @@ -340,7 +340,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet that is too small is dropped. buf.CapLength(2) - ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeNet.packetCount[1] != 1 { @@ -362,7 +362,7 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro func send(r stack.Route, payload buffer.View) *tcpip.Error { hdr := buffer.NewPrependable(int(r.MaxHeaderLength())) - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, Data: payload.ToVectorisedView(), }) @@ -420,7 +420,7 @@ func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte b func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) { t.Helper() - ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if got := fakeNet.PacketCount(localAddrByte); got != want { @@ -2263,7 +2263,7 @@ func TestNICStats(t *testing.T) { // Send a packet to address 1. buf := buffer.NewView(30) - ep1.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + ep1.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want { @@ -2345,7 +2345,7 @@ func TestNICForwarding(t *testing.T) { // Send a packet to dstAddr. buf := buffer.NewView(30) buf[0] = dstAddr[0] - ep1.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + ep1.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 9a33ed375..e09866405 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -152,7 +152,7 @@ func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt PacketBuffer) { +func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) { epsByNIC.mu.RLock() mpep, ok := epsByNIC.endpoints[r.ref.nic.ID()] @@ -183,7 +183,7 @@ func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, p } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt PacketBuffer) { +func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt *PacketBuffer) { epsByNIC.mu.RLock() defer epsByNIC.mu.RUnlock() @@ -251,7 +251,7 @@ type transportDemuxer struct { // the dispatcher to delivery packets to the QueuePacket method instead of // calling HandlePacket directly on the endpoint. type queuedTransportProtocol interface { - QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt PacketBuffer) + QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt *PacketBuffer) } func newTransportDemuxer(stack *Stack) *transportDemuxer { @@ -379,7 +379,7 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32 return mpep.endpoints[idx] } -func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt PacketBuffer) { +func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt *PacketBuffer) { ep.mu.RLock() queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}] // HandlePacket takes ownership of pkt, so each endpoint needs @@ -470,7 +470,7 @@ func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolN // deliverPacket attempts to find one or more matching transport endpoints, and // then, if matches are found, delivers the packet to them. Returns true if // the packet no longer needs to be handled. -func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer, id TransportEndpointID) bool { +func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] if !ok { return false @@ -520,7 +520,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // deliverRawPacket attempts to deliver the given packet and returns whether it // was delivered successfully. -func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer) bool { +func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) bool { eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] if !ok { return false @@ -544,7 +544,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr // deliverControlPacket attempts to deliver the given control packet. Returns // true if it found an endpoint, false otherwise. -func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer, id TransportEndpointID) bool { +func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{net, trans}] if !ok { return false diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 2474a7db3..67d778137 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -127,7 +127,7 @@ func (c *testContext) sendV4Packet(payload []byte, h *headers, linkEpID tcpip.NI u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ + c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), NetworkHeader: buffer.View(ip), TransportHeader: buffer.View(u), @@ -165,7 +165,7 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{ + c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), NetworkHeader: buffer.View(ip), TransportHeader: buffer.View(u), diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index a611e44ab..cb350ead3 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -88,7 +88,7 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions if err != nil { return 0, nil, err } - if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, Data: buffer.View(v).ToVectorisedView(), }); err != nil { @@ -215,7 +215,7 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro return tcpip.FullAddress{}, nil } -func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ stack.PacketBuffer) { +func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ *stack.PacketBuffer) { // Increment the number of received packets. f.proto.packetCount++ if f.acceptQueue != nil { @@ -232,7 +232,7 @@ func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportE } } -func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, stack.PacketBuffer) { +func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, *stack.PacketBuffer) { // Increment the number of received control packets. f.proto.controlCount++ } @@ -289,7 +289,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp return 0, 0, nil } -func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, stack.PacketBuffer) bool { +func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) bool { return true } @@ -369,7 +369,7 @@ func TestTransportReceive(t *testing.T) { // Make sure packet with wrong protocol is not delivered. buf[0] = 1 buf[2] = 0 - linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.packetCount != 0 { @@ -380,7 +380,7 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 3 buf[2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.packetCount != 0 { @@ -391,7 +391,7 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 2 buf[2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.packetCount != 1 { @@ -446,7 +446,7 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 0 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = 0 - linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.controlCount != 0 { @@ -457,7 +457,7 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 3 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.controlCount != 0 { @@ -468,7 +468,7 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 2 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) if fakeTrans.controlCount != 1 { @@ -623,7 +623,7 @@ func TestTransportForwarding(t *testing.T) { req[0] = 1 req[1] = 3 req[2] = byte(fakeTransNumber) - ep2.InjectInbound(fakeNetNumber, stack.PacketBuffer{ + ep2.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: req.ToVectorisedView(), }) diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index b1d820372..29ff68df3 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -450,7 +450,7 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi if ttl == 0 { ttl = r.DefaultTTL() } - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, Data: data.ToVectorisedView(), TransportHeader: buffer.View(icmpv4), @@ -481,7 +481,7 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err if ttl == 0 { ttl = r.DefaultTTL() } - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, stack.PacketBuffer{ + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, Data: dataVV, TransportHeader: buffer.View(icmpv6), @@ -743,7 +743,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // Only accept echo replies. switch e.NetProto { case header.IPv4ProtocolNumber: @@ -805,7 +805,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { } // State implements tcpip.Endpoint.State. The ICMP endpoint currently doesn't diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 3c47692b2..2ec6749c7 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -104,7 +104,7 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, stack.PacketBuffer) bool { +func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) bool { return true } diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 23158173d..bab2d63ae 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -298,7 +298,7 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { } // HandlePacket implements stack.PacketEndpoint.HandlePacket. -func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) { +func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { ep.rcvMu.Lock() // Drop the packet if our buffer is currently full. diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index eee754a5a..25a17940d 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -348,7 +348,7 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, switch e.NetProto { case header.IPv4ProtocolNumber: if !e.associated { - if err := route.WriteHeaderIncludedPacket(stack.PacketBuffer{ + if err := route.WriteHeaderIncludedPacket(&stack.PacketBuffer{ Data: buffer.View(payloadBytes).ToVectorisedView(), }); err != nil { return 0, nil, err @@ -357,7 +357,7 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, } hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength())) - if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ + if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, Data: buffer.View(payloadBytes).ToVectorisedView(), Owner: e.owner, @@ -584,7 +584,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { } // HandlePacket implements stack.RawTransportEndpoint.HandlePacket. -func (e *endpoint) HandlePacket(route *stack.Route, pkt stack.PacketBuffer) { +func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { e.rcvMu.Lock() // Drop the packet if our buffer is currently full. diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index e4a06c9e1..7da93dcc4 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -833,13 +833,13 @@ func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stac return sendTCPBatch(r, tf, data, gso, owner) } - pkt := stack.PacketBuffer{ + pkt := &stack.PacketBuffer{ Header: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen), Data: data, Hash: tf.txHash, Owner: owner, } - buildTCPHdr(r, tf, &pkt, gso) + buildTCPHdr(r, tf, pkt, gso) if tf.ttl == 0 { tf.ttl = r.DefaultTTL() diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go index 6062ca916..047704c80 100644 --- a/pkg/tcpip/transport/tcp/dispatcher.go +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -186,7 +186,7 @@ func (d *dispatcher) wait() { } } -func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt stack.PacketBuffer) { +func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { ep := stackEP.(*endpoint) s := newSegment(r, id, pkt) if !s.parse() { diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index b5ba972f1..d048ef90c 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -2462,7 +2462,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { }, nil } -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // TCP HandlePacket is not required anymore as inbound packets first // land at the Dispatcher which then can either delivery using the // worker go routine or directly do the invoke the tcp processing inline @@ -2481,7 +2481,7 @@ func (e *endpoint) enqueueSegment(s *segment) bool { } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { switch typ { case stack.ControlPacketTooBig: e.sndBufMu.Lock() diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 704d01c64..070b634b4 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -61,7 +61,7 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward // // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. -func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { +func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { s := newSegment(r, id, pkt) defer s.decRef() diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 2a2a7ddeb..c827d0277 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -206,7 +206,7 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // to a specific processing queue. Each queue is serviced by its own processor // goroutine which is responsible for dequeuing and doing full TCP dispatch of // the packet. -func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt stack.PacketBuffer) { +func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { p.dispatcher.queuePacket(r, ep, id, pkt) } @@ -217,7 +217,7 @@ func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id st // a reset is sent in response to any incoming segment except another reset. In // particular, SYNs addressed to a non-existent connection are rejected by this // means." -func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { +func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { s := newSegment(r, id, pkt) defer s.decRef() diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 074edded6..0c099e2fd 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -60,7 +60,7 @@ type segment struct { xmitCount uint32 } -func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) *segment { +func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment { s := &segment{ refCnt: 1, id: id, diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 7b1d72cf4..9721f6caf 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -316,7 +316,7 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt copy(icmp[header.ICMPv4PayloadOffset:], p2) // Inject packet. - c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) } @@ -372,7 +372,7 @@ func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcp // SendSegment sends a TCP segment that has already been built and written to a // buffer.VectorisedView. func (c *Context) SendSegment(s buffer.VectorisedView) { - c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ Data: s, }) } @@ -380,7 +380,7 @@ func (c *Context) SendSegment(s buffer.VectorisedView) { // SendPacket builds and sends a TCP segment(with the provided payload & TCP // headers) in an IPv4 packet via the link layer endpoint. func (c *Context) SendPacket(payload []byte, h *Headers) { - c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ Data: c.BuildSegment(payload, h), }) } @@ -389,7 +389,7 @@ func (c *Context) SendPacket(payload []byte, h *Headers) { // & TCPheaders) in an IPv4 packet via the link layer endpoint using the // provided source and destination IPv4 addresses. func (c *Context) SendPacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) { - c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ Data: c.BuildSegmentWithAddrs(payload, h, src, dst), }) } @@ -564,7 +564,7 @@ func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcp t.SetChecksum(^t.CalculateChecksum(xsum)) // Inject packet. - c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 647b2067a..79faa7869 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -921,7 +921,11 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u if useDefaultTTL { ttl = r.DefaultTTL() } - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}, stack.PacketBuffer{ + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + Protocol: ProtocolNumber, + TTL: ttl, + TOS: tos, + }, &stack.PacketBuffer{ Header: hdr, Data: data, TransportHeader: buffer.View(udp), @@ -1269,7 +1273,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) { +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // Get the header then trim it from the view. hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize) if !ok || int(header.UDP(hdr).Length()) > pkt.Data.Size() { @@ -1336,7 +1340,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt stack.PacketBuffer) { +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { if typ == stack.ControlPortUnreachable { e.mu.RLock() defer e.mu.RUnlock() diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index a674ceb68..7abfa0ed2 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -43,7 +43,7 @@ func NewForwarder(s *stack.Stack, handler func(*ForwarderRequest)) *Forwarder { // // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. -func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { +func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { f.handler(&ForwarderRequest{ stack: f.stack, route: r, @@ -61,7 +61,7 @@ type ForwarderRequest struct { stack *stack.Stack route *stack.Route id stack.TransportEndpointID - pkt stack.PacketBuffer + pkt *stack.PacketBuffer } // ID returns the 4-tuple (src address, src port, dst address, dst port) that diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 52af6de22..e320c5758 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -66,7 +66,7 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool { +func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { // Get the header then trim it from the view. h, ok := pkt.Data.PullUp(header.UDPMinimumSize) if !ok { @@ -140,7 +140,7 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans pkt.SetType(header.ICMPv4DstUnreachable) pkt.SetCode(header.ICMPv4PortUnreachable) pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload)) - r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ + r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, Data: payload, }) @@ -177,7 +177,7 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans pkt.SetType(header.ICMPv6DstUnreachable) pkt.SetCode(header.ICMPv6PortUnreachable) pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload)) - r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, stack.PacketBuffer{ + r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, Data: payload, }) diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 8acaa607a..e8ade882b 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -440,7 +440,7 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), NetworkHeader: buffer.View(ip), TransportHeader: buffer.View(u), @@ -487,7 +487,7 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool // Inject packet. - c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), NetworkHeader: buffer.View(ip), TransportHeader: buffer.View(u), -- cgit v1.2.3 From 41da7a568b1e4f46b3bc09724996556fb18b4d16 Mon Sep 17 00:00:00 2001 From: Ting-Yu Wang Date: Thu, 4 Jun 2020 15:38:33 -0700 Subject: Fix copylocks error about copying IPTables. IPTables.connections contains a sync.RWMutex. Copying it will trigger copylocks analysis. Tested by manually enabling nogo tests. sync.RWMutex is added to IPTables for the additional race condition discovered. PiperOrigin-RevId: 314817019 --- pkg/sentry/socket/netfilter/netfilter.go | 36 ++++++++++++--------------- pkg/sentry/socket/netstack/stack.go | 9 +++---- pkg/tcpip/stack/iptables.go | 42 +++++++++++++++++++++++++++----- pkg/tcpip/stack/iptables_types.go | 15 ++++++++---- pkg/tcpip/stack/stack.go | 23 ++++------------- pkg/tcpip/transport/icmp/endpoint.go | 5 ---- pkg/tcpip/transport/packet/endpoint.go | 5 ---- pkg/tcpip/transport/raw/endpoint.go | 5 ---- pkg/tcpip/transport/tcp/endpoint.go | 5 ---- pkg/tcpip/transport/udp/endpoint.go | 5 ---- runsc/boot/loader.go | 2 +- 11 files changed, 71 insertions(+), 81 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 47ff48c00..66015e2bc 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -144,31 +144,27 @@ func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen } func findTable(stk *stack.Stack, tablename linux.TableName) (stack.Table, error) { - ipt := stk.IPTables() - table, ok := ipt.Tables[tablename.String()] + table, ok := stk.IPTables().GetTable(tablename.String()) if !ok { return stack.Table{}, fmt.Errorf("couldn't find table %q", tablename) } return table, nil } -// FillDefaultIPTables sets stack's IPTables to the default tables and -// populates them with metadata. -func FillDefaultIPTables(stk *stack.Stack) { - ipt := stack.DefaultTables() - - // In order to fill in the metadata, we have to translate ipt from its - // netstack format to Linux's giant-binary-blob format. - for name, table := range ipt.Tables { - _, metadata, err := convertNetstackToBinary(name, table) - if err != nil { - panic(fmt.Errorf("Unable to set default IP tables: %v", err)) +// FillIPTablesMetadata populates stack's IPTables with metadata. +func FillIPTablesMetadata(stk *stack.Stack) { + stk.IPTables().ModifyTables(func(tables map[string]stack.Table) { + // In order to fill in the metadata, we have to translate ipt from its + // netstack format to Linux's giant-binary-blob format. + for name, table := range tables { + _, metadata, err := convertNetstackToBinary(name, table) + if err != nil { + panic(fmt.Errorf("Unable to set default IP tables: %v", err)) + } + table.SetMetadata(metadata) + tables[name] = table } - table.SetMetadata(metadata) - ipt.Tables[name] = table - } - - stk.SetIPTables(ipt) + }) } // convertNetstackToBinary converts the iptables as stored in netstack to the @@ -573,15 +569,13 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { // - There are no chains without an unconditional final rule. // - There are no chains without an unconditional underflow rule. - ipt := stk.IPTables() table.SetMetadata(metadata{ HookEntry: replace.HookEntry, Underflow: replace.Underflow, NumEntries: replace.NumEntries, Size: replace.Size, }) - ipt.Tables[replace.Name.String()] = table - stk.SetIPTables(ipt) + stk.IPTables().ReplaceTable(replace.Name.String(), table) return nil } diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index f5fa18136..9b44c2b89 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -362,14 +362,13 @@ func (s *Stack) RouteTable() []inet.Route { } // IPTables returns the stack's iptables. -func (s *Stack) IPTables() (stack.IPTables, error) { +func (s *Stack) IPTables() (*stack.IPTables, error) { return s.Stack.IPTables(), nil } -// FillDefaultIPTables sets the stack's iptables to the default tables, which -// allow and do not modify all traffic. -func (s *Stack) FillDefaultIPTables() { - netfilter.FillDefaultIPTables(s.Stack) +// FillIPTablesMetadata populates stack's IPTables with metadata. +func (s *Stack) FillIPTablesMetadata() { + netfilter.FillIPTablesMetadata(s.Stack) } // Resume implements inet.Stack.Resume. diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index d989dbe91..4e9b404c8 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -43,11 +43,11 @@ const HookUnset = -1 // DefaultTables returns a default set of tables. Each chain is set to accept // all packets. -func DefaultTables() IPTables { +func DefaultTables() *IPTables { // TODO(gvisor.dev/issue/170): We may be able to swap out some strings for // iotas. - return IPTables{ - Tables: map[string]Table{ + return &IPTables{ + tables: map[string]Table{ TablenameNat: Table{ Rules: []Rule{ Rule{Target: AcceptTarget{}}, @@ -106,7 +106,7 @@ func DefaultTables() IPTables { UserChains: map[string]int{}, }, }, - Priorities: map[Hook][]string{ + priorities: map[Hook][]string{ Input: []string{TablenameNat, TablenameFilter}, Prerouting: []string{TablenameMangle, TablenameNat}, Output: []string{TablenameMangle, TablenameNat, TablenameFilter}, @@ -158,6 +158,36 @@ func EmptyNatTable() Table { } } +// GetTable returns table by name. +func (it *IPTables) GetTable(name string) (Table, bool) { + it.mu.RLock() + defer it.mu.RUnlock() + t, ok := it.tables[name] + return t, ok +} + +// ReplaceTable replaces or inserts table by name. +func (it *IPTables) ReplaceTable(name string, table Table) { + it.mu.Lock() + defer it.mu.Unlock() + it.tables[name] = table +} + +// ModifyTables acquires write-lock and calls fn with internal name-to-table +// map. This function can be used to update multiple tables atomically. +func (it *IPTables) ModifyTables(fn func(map[string]Table)) { + it.mu.Lock() + defer it.mu.Unlock() + fn(it.tables) +} + +// GetPriorities returns slice of priorities associated with hook. +func (it *IPTables) GetPriorities(hook Hook) []string { + it.mu.RLock() + defer it.mu.RUnlock() + return it.priorities[hook] +} + // A chainVerdict is what a table decides should be done with a packet. type chainVerdict int @@ -184,8 +214,8 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr it.connections.HandlePacket(pkt, hook, gso, r) // Go through each table containing the hook. - for _, tablename := range it.Priorities[hook] { - table := it.Tables[tablename] + for _, tablename := range it.GetPriorities(hook) { + table, _ := it.GetTable(tablename) ruleIdx := table.BuiltinChains[hook] switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict { // If the table returns Accept, move on to the next table. diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index af72b9c46..4a6a5c6f1 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -16,6 +16,7 @@ package stack import ( "strings" + "sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -78,13 +79,17 @@ const ( // IPTables holds all the tables for a netstack. type IPTables struct { - // Tables maps table names to tables. User tables have arbitrary names. - Tables map[string]Table + // mu protects tables and priorities. + mu sync.RWMutex - // Priorities maps each hook to a list of table names. The order of the + // tables maps table names to tables. User tables have arbitrary names. mu + // needs to be locked for accessing. + tables map[string]Table + + // priorities maps each hook to a list of table names. The order of the // list is the order in which each table should be visited for that - // hook. - Priorities map[Hook][]string + // hook. mu needs to be locked for accessing. + priorities map[Hook][]string connections ConnTrackTable } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 8af06cb9a..294ce8775 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -424,12 +424,8 @@ type Stack struct { // handleLocal allows non-loopback interfaces to loop packets. handleLocal bool - // tablesMu protects iptables. - tablesMu sync.RWMutex - - // tables are the iptables packet filtering and manipulation rules. The are - // protected by tablesMu.` - tables IPTables + // tables are the iptables packet filtering and manipulation rules. + tables *IPTables // resumableEndpoints is a list of endpoints that need to be resumed if the // stack is being restored. @@ -676,6 +672,7 @@ func New(opts Options) *Stack { clock: clock, stats: opts.Stats.FillIn(), handleLocal: opts.HandleLocal, + tables: DefaultTables(), icmpRateLimiter: NewICMPRateLimiter(), seed: generateRandUint32(), ndpConfigs: opts.NDPConfigs, @@ -1741,18 +1738,8 @@ func (s *Stack) IsInGroup(nicID tcpip.NICID, multicastAddr tcpip.Address) (bool, } // IPTables returns the stack's iptables. -func (s *Stack) IPTables() IPTables { - s.tablesMu.RLock() - t := s.tables - s.tablesMu.RUnlock() - return t -} - -// SetIPTables sets the stack's iptables. -func (s *Stack) SetIPTables(ipt IPTables) { - s.tablesMu.Lock() - s.tables = ipt - s.tablesMu.Unlock() +func (s *Stack) IPTables() *IPTables { + return s.tables } // ICMPLimit returns the maximum number of ICMP messages that can be sent diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 29ff68df3..3bc72bc19 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -140,11 +140,6 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } -// IPTables implements tcpip.Endpoint.IPTables. -func (e *endpoint) IPTables() (stack.IPTables, error) { - return e.stack.IPTables(), nil -} - // Read reads data from the endpoint. This method does not block if // there is no data pending. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index bab2d63ae..baf08eda6 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -132,11 +132,6 @@ func (ep *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (ep *endpoint) ModerateRecvBuf(copied int) {} -// IPTables implements tcpip.Endpoint.IPTables. -func (ep *endpoint) IPTables() (stack.IPTables, error) { - return ep.stack.IPTables(), nil -} - // Read implements tcpip.Endpoint.Read. func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { ep.rcvMu.Lock() diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 25a17940d..21c34fac2 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -166,11 +166,6 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } -// IPTables implements tcpip.Endpoint.IPTables. -func (e *endpoint) IPTables() (stack.IPTables, error) { - return e.stack.IPTables(), nil -} - // Read implements tcpip.Endpoint.Read. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { if !e.associated { diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index d048ef90c..edca98160 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1172,11 +1172,6 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } -// IPTables implements tcpip.Endpoint.IPTables. -func (e *endpoint) IPTables() (stack.IPTables, error) { - return e.stack.IPTables(), nil -} - // Read reads data from the endpoint. func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { e.LockUser() diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 79faa7869..663af8fec 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -247,11 +247,6 @@ func (e *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (e *endpoint) ModerateRecvBuf(copied int) {} -// IPTables implements tcpip.Endpoint.IPTables. -func (e *endpoint) IPTables() (stack.IPTables, error) { - return e.stack.IPTables(), nil -} - // Read reads data from the endpoint. This method does not block if // there is no data pending. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index f802bc9fb..002479612 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -1056,7 +1056,7 @@ func newEmptySandboxNetworkStack(clock tcpip.Clock, uniqueID stack.UniqueID) (in return nil, fmt.Errorf("SetTransportProtocolOption failed: %v", err) } - s.FillDefaultIPTables() + s.FillIPTablesMetadata() return &s, nil } -- cgit v1.2.3 From 74a7d76c9777820fcd7bd6002481eb959f58e247 Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Wed, 3 Jun 2020 19:57:39 -0700 Subject: iptables: loopback traffic skips prerouting chain Loopback traffic is not affected by rules in the PREROUTING chain. This change is also necessary for istio's envoy to talk to other components in the same pod. --- pkg/tcpip/network/ipv4/ipv4.go | 49 ++++++++++------------- pkg/tcpip/stack/nic.go | 3 +- test/iptables/iptables_test.go | 4 ++ test/iptables/nat.go | 89 ++++++++++++++++++++++++++++++------------ 4 files changed, 89 insertions(+), 56 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 9cd7592f4..959f7e007 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -258,38 +258,24 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw return nil } + // If the packet is manipulated as per NAT Ouput rules, handle packet + // based on destination address and do not send the packet to link layer. + // TODO(gvisor.dev/issue/170): We should do this for every packet, rather than + // only NATted packets, but removing this check short circuits broadcasts + // before they are sent out to other hosts. if pkt.NatDone { - // If the packet is manipulated as per NAT Ouput rules, handle packet - // based on destination address and do not send the packet to link layer. netHeader := header.IPv4(pkt.NetworkHeader) ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()) if err == nil { - src := netHeader.SourceAddress() - dst := netHeader.DestinationAddress() - route := r.ReverseRoute(src, dst) - - views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) - views[0] = pkt.Header.View() - views = append(views, pkt.Data.Views()...) - ep.HandlePacket(&route, &stack.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), - }) + route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) + handleLoopback(&route, pkt, ep) return nil } } if r.Loop&stack.PacketLoop != 0 { - // The inbound path expects the network header to still be in - // the PacketBuffer's Data field. - views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) - views[0] = pkt.Header.View() - views = append(views, pkt.Data.Views()...) loopedR := r.MakeLoopedRoute() - - e.HandlePacket(&loopedR, &stack.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), - }) - + handleLoopback(&loopedR, pkt, e) loopedR.Release() } if r.Loop&stack.PacketOut == 0 { @@ -305,6 +291,17 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw return nil } +func handleLoopback(route *stack.Route, pkt *stack.PacketBuffer, ep stack.NetworkEndpoint) { + // The inbound path expects the network header to still be in + // the PacketBuffer's Data field. + views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) + views[0] = pkt.Header.View() + views = append(views, pkt.Data.Views()...) + ep.HandlePacket(route, &stack.PacketBuffer{ + Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), + }) +} + // WritePackets implements stack.NetworkEndpoint.WritePackets. func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) { if r.Loop&stack.PacketLoop != 0 { @@ -347,13 +344,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe src := netHeader.SourceAddress() dst := netHeader.DestinationAddress() route := r.ReverseRoute(src, dst) - - views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) - views[0] = pkt.Header.View() - views = append(views, pkt.Data.Views()...) - ep.HandlePacket(&route, &stack.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), - }) + handleLoopback(&route, pkt, ep) n++ continue } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index ec8e3cb85..6664aea06 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1229,7 +1229,8 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp } // TODO(gvisor.dev/issue/170): Not supporting iptables for IPv6 yet. - if protocol == header.IPv4ProtocolNumber { + // Loopback traffic skips the prerouting chain. + if protocol == header.IPv4ProtocolNumber && !n.isLoopback() { // iptables filtering. ipt := n.stack.IPTables() address := n.primaryAddress(protocol) diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go index 172ad9e16..38319a3b2 100644 --- a/test/iptables/iptables_test.go +++ b/test/iptables/iptables_test.go @@ -303,6 +303,10 @@ func TestNATRedirectRequiresProtocol(t *testing.T) { singleTest(t, NATRedirectRequiresProtocol{}) } +func TestNATLoopbackSkipsPrerouting(t *testing.T) { + singleTest(t, NATLoopbackSkipsPrerouting{}) +} + func TestInputSource(t *testing.T) { singleTest(t, FilterInputSource{}) } diff --git a/test/iptables/nat.go b/test/iptables/nat.go index 0a10ce7fe..5e54a3963 100644 --- a/test/iptables/nat.go +++ b/test/iptables/nat.go @@ -39,6 +39,7 @@ func init() { RegisterTestCase(NATOutDontRedirectIP{}) RegisterTestCase(NATOutRedirectInvert{}) RegisterTestCase(NATRedirectRequiresProtocol{}) + RegisterTestCase(NATLoopbackSkipsPrerouting{}) } // NATPreRedirectUDPPort tests that packets are redirected to different port. @@ -326,32 +327,6 @@ func (NATRedirectRequiresProtocol) LocalAction(ip net.IP) error { return nil } -// loopbackTests runs an iptables rule and ensures that packets sent to -// dest:dropPort are received by localhost:acceptPort. -func loopbackTest(dest net.IP, args ...string) error { - if err := natTable(args...); err != nil { - return err - } - sendCh := make(chan error) - listenCh := make(chan error) - go func() { - sendCh <- sendUDPLoop(dest, dropPort, sendloopDuration) - }() - go func() { - listenCh <- listenUDP(acceptPort, sendloopDuration) - }() - select { - case err := <-listenCh: - if err != nil { - return err - } - case <-time.After(sendloopDuration): - return errors.New("timed out") - } - // sendCh will always take the full sendloop time. - return <-sendCh -} - // NATOutRedirectTCPPort tests that connections are redirected on specified ports. type NATOutRedirectTCPPort struct{} @@ -400,3 +375,65 @@ func (NATOutRedirectTCPPort) ContainerAction(ip net.IP) error { func (NATOutRedirectTCPPort) LocalAction(ip net.IP) error { return nil } + +// NATLoopbackSkipsPrerouting tests that packets sent via loopback aren't +// affected by PREROUTING rules. +type NATLoopbackSkipsPrerouting struct{} + +// Name implements TestCase.Name. +func (NATLoopbackSkipsPrerouting) Name() string { + return "NATLoopbackSkipsPrerouting" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATLoopbackSkipsPrerouting) ContainerAction(ip net.IP) error { + // Redirect anything sent to localhost to an unused port. + dest := []byte{127, 0, 0, 1} + if err := natTable("-A", "PREROUTING", "-p", "tcp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", dropPort)); err != nil { + return err + } + + // Establish a connection via localhost. If the PREROUTING rule did apply to + // loopback traffic, the connection would fail. + sendCh := make(chan error) + go func() { + sendCh <- connectTCP(dest, acceptPort, sendloopDuration) + }() + + if err := listenTCP(acceptPort, sendloopDuration); err != nil { + return err + } + return <-sendCh +} + +// LocalAction implements TestCase.LocalAction. +func (NATLoopbackSkipsPrerouting) LocalAction(ip net.IP) error { + // No-op. + return nil +} + +// loopbackTests runs an iptables rule and ensures that packets sent to +// dest:dropPort are received by localhost:acceptPort. +func loopbackTest(dest net.IP, args ...string) error { + if err := natTable(args...); err != nil { + return err + } + sendCh := make(chan error) + listenCh := make(chan error) + go func() { + sendCh <- sendUDPLoop(dest, dropPort, sendloopDuration) + }() + go func() { + listenCh <- listenUDP(acceptPort, sendloopDuration) + }() + select { + case err := <-listenCh: + if err != nil { + return err + } + case <-time.After(sendloopDuration): + return errors.New("timed out") + } + // sendCh will always take the full sendloop time. + return <-sendCh +} -- cgit v1.2.3 From 32b823fcdb00a7d6eb5ddcd378f19a659edc3da3 Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Sun, 7 Jun 2020 13:37:25 -0700 Subject: netstack: parse incoming packet headers up-front Netstack has traditionally parsed headers on-demand as a packet moves up the stack. This is conceptually simple and convenient, but incompatible with iptables, where headers can be inspected and mangled before even a routing decision is made. This changes header parsing to happen early in the incoming packet path, as soon as the NIC gets the packet from a link endpoint. Even if an invalid packet is found (e.g. a TCP header of insufficient length), the packet is passed up the stack for proper stats bookkeeping. PiperOrigin-RevId: 315179302 --- pkg/sentry/socket/netfilter/tcp_matcher.go | 34 +------ pkg/sentry/socket/netfilter/udp_matcher.go | 34 +------ pkg/tcpip/header/ipv4.go | 5 + pkg/tcpip/header/ipv6_extension_headers.go | 7 ++ pkg/tcpip/network/arp/arp.go | 17 +++- pkg/tcpip/network/fragmentation/fragmentation.go | 4 +- pkg/tcpip/network/ip_test.go | 48 +++++---- pkg/tcpip/network/ipv4/icmp.go | 3 + pkg/tcpip/network/ipv4/ipv4.go | 74 +++++++------- pkg/tcpip/network/ipv4/ipv4_test.go | 12 +++ pkg/tcpip/network/ipv6/icmp.go | 7 +- pkg/tcpip/network/ipv6/icmp_test.go | 43 ++++----- pkg/tcpip/network/ipv6/ipv6.go | 118 +++++++++++++++++++---- pkg/tcpip/network/ipv6/ndp_test.go | 33 ++++--- pkg/tcpip/stack/conntrack.go | 46 --------- pkg/tcpip/stack/forwarder_test.go | 78 ++++++++------- pkg/tcpip/stack/iptables_targets.go | 5 - pkg/tcpip/stack/nic.go | 51 ++++++++-- pkg/tcpip/stack/registration.go | 13 +++ pkg/tcpip/stack/stack_test.go | 70 ++++++++------ pkg/tcpip/stack/transport_test.go | 19 +++- pkg/tcpip/transport/icmp/protocol.go | 10 ++ pkg/tcpip/transport/raw/endpoint.go | 5 +- pkg/tcpip/transport/tcp/protocol.go | 21 ++++ pkg/tcpip/transport/tcp/segment.go | 35 +++---- pkg/tcpip/transport/udp/endpoint.go | 6 +- pkg/tcpip/transport/udp/protocol.go | 45 +++++---- pkg/tcpip/transport/udp/udp_test.go | 60 ++++++++++-- 28 files changed, 539 insertions(+), 364 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index ebabdf334..4f98ee2d5 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -111,36 +111,10 @@ func (tm *TCPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceN return false, false } - // Now we need the transport header. However, this may not have been set - // yet. - // TODO(gvisor.dev/issue/170): Parsing the transport header should - // ultimately be moved into the stack.Check codepath as matchers are - // added. - var tcpHeader header.TCP - if pkt.TransportHeader != nil { - tcpHeader = header.TCP(pkt.TransportHeader) - } else { - var length int - if hook == stack.Prerouting { - // The network header hasn't been parsed yet. We have to do it here. - hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - // There's no valid TCP header here, so we hotdrop the - // packet. - return false, true - } - h := header.IPv4(hdr) - pkt.NetworkHeader = hdr - length = int(h.HeaderLength()) - } - // The TCP header hasn't been parsed yet. We have to do it here. - hdr, ok := pkt.Data.PullUp(length + header.TCPMinimumSize) - if !ok { - // There's no valid TCP header here, so we hotdrop the - // packet. - return false, true - } - tcpHeader = header.TCP(hdr[length:]) + tcpHeader := header.TCP(pkt.TransportHeader) + if len(tcpHeader) < header.TCPMinimumSize { + // There's no valid TCP header here, so we drop the packet immediately. + return false, true } // Check whether the source and destination ports are within the diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index 98b9943f8..3f20fc891 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -110,36 +110,10 @@ func (um *UDPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceN return false, false } - // Now we need the transport header. However, this may not have been set - // yet. - // TODO(gvisor.dev/issue/170): Parsing the transport header should - // ultimately be moved into the stack.Check codepath as matchers are - // added. - var udpHeader header.UDP - if pkt.TransportHeader != nil { - udpHeader = header.UDP(pkt.TransportHeader) - } else { - var length int - if hook == stack.Prerouting { - // The network header hasn't been parsed yet. We have to do it here. - hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - // There's no valid UDP header here, so we hotdrop the - // packet. - return false, true - } - h := header.IPv4(hdr) - pkt.NetworkHeader = hdr - length = int(h.HeaderLength()) - } - // The UDP header hasn't been parsed yet. We have to do it here. - hdr, ok := pkt.Data.PullUp(length + header.UDPMinimumSize) - if !ok { - // There's no valid UDP header here, so we hotdrop the - // packet. - return false, true - } - udpHeader = header.UDP(hdr[length:]) + udpHeader := header.UDP(pkt.TransportHeader) + if len(udpHeader) < header.UDPMinimumSize { + // There's no valid UDP header here, so we drop the packet immediately. + return false, true } // Check whether the source and destination ports are within the diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index 76839eb92..62ac932bb 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -159,6 +159,11 @@ func (b IPv4) Flags() uint8 { return uint8(binary.BigEndian.Uint16(b[flagsFO:]) >> 13) } +// More returns whether the more fragments flag is set. +func (b IPv4) More() bool { + return b.Flags()&IPv4FlagMoreFragments != 0 +} + // TTL returns the "TTL" field of the ipv4 header. func (b IPv4) TTL() uint8 { return b[ttl] diff --git a/pkg/tcpip/header/ipv6_extension_headers.go b/pkg/tcpip/header/ipv6_extension_headers.go index 2c4591409..3499d8399 100644 --- a/pkg/tcpip/header/ipv6_extension_headers.go +++ b/pkg/tcpip/header/ipv6_extension_headers.go @@ -354,6 +354,13 @@ func (b IPv6FragmentExtHdr) ID() uint32 { return binary.BigEndian.Uint32(b[ipv6FragmentExtHdrIdentificationOffset:]) } +// IsAtomic returns whether the fragment header indicates an atomic fragment. An +// atomic fragment is a fragment that contains all the data required to +// reassemble a full packet. +func (b IPv6FragmentExtHdr) IsAtomic() bool { + return !b.More() && b.FragmentOffset() == 0 +} + // IPv6PayloadIterator is an iterator over the contents of an IPv6 payload. // // The IPv6 payload may contain IPv6 extension headers before any upper layer diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index ea1acba83..7f27a840d 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -99,11 +99,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu } func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { - v, ok := pkt.Data.PullUp(header.ARPSize) - if !ok { - return - } - h := header.ARP(v) + h := header.ARP(pkt.NetworkHeader) if !h.IsValid() { return } @@ -209,6 +205,17 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} +// Parse implements stack.NetworkProtocol.Parse. +func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { + hdr, ok := pkt.Data.PullUp(header.ARPSize) + if !ok { + return 0, false, false + } + pkt.NetworkHeader = hdr + pkt.Data.TrimFront(header.ARPSize) + return 0, false, true +} + var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff}) // NewProtocol returns an ARP network protocol. diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go index f42abc4bb..2982450f8 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/fragmentation/fragmentation.go @@ -81,8 +81,8 @@ func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout t } } -// Process processes an incoming fragment belonging to an ID -// and returns a complete packet when all the packets belonging to that ID have been received. +// Process processes an incoming fragment belonging to an ID and returns a +// complete packet when all the packets belonging to that ID have been received. func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, error) { f.mu.Lock() r, ok := f.reassemblers[id] diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index d9b62f2db..7c8fb3e0a 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -293,9 +293,9 @@ func TestIPv4Receive(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - ep.HandlePacket(&r, &stack.PacketBuffer{ - Data: view.ToVectorisedView(), - }) + pkt := stack.PacketBuffer{Data: view.ToVectorisedView()} + proto.Parse(&pkt) + ep.HandlePacket(&r, &pkt) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -382,10 +382,7 @@ func TestIPv4ReceiveControl(t *testing.T) { o.typ = c.expectedTyp o.extra = c.expectedExtra - vv := view[:len(view)-c.trunc].ToVectorisedView() - ep.HandlePacket(&r, &stack.PacketBuffer{ - Data: vv, - }) + ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv4MinimumSize)) if want := c.expectedCount; o.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) } @@ -448,17 +445,17 @@ func TestIPv4FragmentationReceive(t *testing.T) { } // Send first segment. - ep.HandlePacket(&r, &stack.PacketBuffer{ - Data: frag1.ToVectorisedView(), - }) + pkt := stack.PacketBuffer{Data: frag1.ToVectorisedView()} + proto.Parse(&pkt) + ep.HandlePacket(&r, &pkt) if o.dataCalls != 0 { t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls) } // Send second segment. - ep.HandlePacket(&r, &stack.PacketBuffer{ - Data: frag2.ToVectorisedView(), - }) + pkt = stack.PacketBuffer{Data: frag2.ToVectorisedView()} + proto.Parse(&pkt) + ep.HandlePacket(&r, &pkt) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -538,9 +535,9 @@ func TestIPv6Receive(t *testing.T) { t.Fatalf("could not find route: %v", err) } - ep.HandlePacket(&r, &stack.PacketBuffer{ - Data: view.ToVectorisedView(), - }) + pkt := stack.PacketBuffer{Data: view.ToVectorisedView()} + proto.Parse(&pkt) + ep.HandlePacket(&r, &pkt) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -652,12 +649,25 @@ func TestIPv6ReceiveControl(t *testing.T) { // Set ICMPv6 checksum. icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{})) - ep.HandlePacket(&r, &stack.PacketBuffer{ - Data: view[:len(view)-c.trunc].ToVectorisedView(), - }) + ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv6MinimumSize)) if want := c.expectedCount; o.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) } }) } } + +// truncatedPacket returns a PacketBuffer based on a truncated view. If view, +// after truncation, is large enough to hold a network header, it makes part of +// view the packet's NetworkHeader and the rest its Data. Otherwise all of view +// becomes Data. +func truncatedPacket(view buffer.View, trunc, netHdrLen int) *stack.PacketBuffer { + v := view[:len(view)-trunc] + if len(v) < netHdrLen { + return &stack.PacketBuffer{Data: v.ToVectorisedView()} + } + return &stack.PacketBuffer{ + NetworkHeader: v[:netHdrLen], + Data: v[netHdrLen:].ToVectorisedView(), + } +} diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index d1c3ae835..1b67aa066 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -59,6 +59,9 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { stats := r.Stats() received := stats.ICMP.V4PacketsReceived + // TODO(gvisor.dev/issue/170): ICMP packets don't have their + // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a + // full explanation. v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) if !ok { received.Invalid.Increment() diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 959f7e007..7e9f16c90 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -21,6 +21,7 @@ package ipv4 import ( + "fmt" "sync/atomic" "gvisor.dev/gvisor/pkg/tcpip" @@ -268,14 +269,14 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()) if err == nil { route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) - handleLoopback(&route, pkt, ep) + ep.HandlePacket(&route, pkt) return nil } } if r.Loop&stack.PacketLoop != 0 { loopedR := r.MakeLoopedRoute() - handleLoopback(&loopedR, pkt, e) + e.HandlePacket(&loopedR, pkt) loopedR.Release() } if r.Loop&stack.PacketOut == 0 { @@ -291,17 +292,6 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw return nil } -func handleLoopback(route *stack.Route, pkt *stack.PacketBuffer, ep stack.NetworkEndpoint) { - // The inbound path expects the network header to still be in - // the PacketBuffer's Data field. - views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) - views[0] = pkt.Header.View() - views = append(views, pkt.Data.Views()...) - ep.HandlePacket(route, &stack.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), - }) -} - // WritePackets implements stack.NetworkEndpoint.WritePackets. func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) { if r.Loop&stack.PacketLoop != 0 { @@ -339,12 +329,11 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } if _, ok := natPkts[pkt]; ok { netHeader := header.IPv4(pkt.NetworkHeader) - ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()) - if err == nil { + if ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()); err == nil { src := netHeader.SourceAddress() dst := netHeader.DestinationAddress() route := r.ReverseRoute(src, dst) - handleLoopback(&route, pkt, ep) + ep.HandlePacket(&route, pkt) n++ continue } @@ -418,22 +407,11 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { - headerView, ok := pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { + h := header.IPv4(pkt.NetworkHeader) + if !h.IsValid(pkt.Data.Size() + len(pkt.NetworkHeader) + len(pkt.TransportHeader)) { r.Stats().IP.MalformedPacketsReceived.Increment() return } - h := header.IPv4(headerView) - if !h.IsValid(pkt.Data.Size()) { - r.Stats().IP.MalformedPacketsReceived.Increment() - return - } - pkt.NetworkHeader = headerView[:h.HeaderLength()] - - hlen := int(h.HeaderLength()) - tlen := int(h.TotalLength()) - pkt.Data.TrimFront(hlen) - pkt.Data.CapLength(tlen - hlen) // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. @@ -443,9 +421,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { return } - more := (h.Flags() & header.IPv4FlagMoreFragments) != 0 - if more || h.FragmentOffset() != 0 { - if pkt.Data.Size() == 0 { + if h.More() || h.FragmentOffset() != 0 { + if pkt.Data.Size()+len(pkt.TransportHeader) == 0 { // Drop the packet as it's marked as a fragment but has // no payload. r.Stats().IP.MalformedPacketsReceived.Increment() @@ -464,7 +441,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { } var ready bool var err error - pkt.Data, ready, err = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, pkt.Data) + pkt.Data, ready, err = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, h.More(), pkt.Data) if err != nil { r.Stats().IP.MalformedPacketsReceived.Increment() r.Stats().IP.MalformedFragmentsReceived.Increment() @@ -476,7 +453,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { } p := h.TransportProtocol() if p == header.ICMPv4ProtocolNumber { - headerView.CapLength(hlen) + pkt.NetworkHeader.CapLength(int(h.HeaderLength())) e.handleICMP(r, pkt) return } @@ -556,6 +533,35 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} +// Parse implements stack.TransportProtocol.Parse. +func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { + hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return 0, false, false + } + ipHdr := header.IPv4(hdr) + + // If there are options, pull those into hdr as well. + if headerLen := int(ipHdr.HeaderLength()); headerLen > header.IPv4MinimumSize && headerLen <= pkt.Data.Size() { + hdr, ok = pkt.Data.PullUp(headerLen) + if !ok { + panic(fmt.Sprintf("There are only %d bytes in pkt.Data, but there should be at least %d", pkt.Data.Size(), headerLen)) + } + ipHdr = header.IPv4(hdr) + } + + // If this is a fragment, don't bother parsing the transport header. + parseTransportHeader := true + if ipHdr.More() || ipHdr.FragmentOffset() != 0 { + parseTransportHeader = false + } + + pkt.NetworkHeader = hdr + pkt.Data.TrimFront(len(hdr)) + pkt.Data.CapLength(int(ipHdr.TotalLength()) - len(hdr)) + return ipHdr.TransportProtocol(), parseTransportHeader, true +} + // calculateMTU calculates the network-layer payload MTU based on the link-layer // payload mtu. func calculateMTU(mtu uint32) uint32 { diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index c208ebd99..11e579c4b 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -652,6 +652,18 @@ func TestReceiveFragments(t *testing.T) { }, expectedPayloads: [][]byte{udpPayload1, udpPayload2}, }, + { + name: "Fragment without followup", + fragments: []fragmentData{ + { + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 0, + payload: ipv4Payload1[:64], + }, + }, + expectedPayloads: nil, + }, } for _, test := range tests { diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index b62fb1de6..2ff7eedf4 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -70,17 +70,20 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } -func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt *stack.PacketBuffer, hasFragmentHeader bool) { +func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragmentHeader bool) { stats := r.Stats().ICMP sent := stats.V6PacketsSent received := stats.V6PacketsReceived + // TODO(gvisor.dev/issue/170): ICMP packets don't have their + // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a + // full explanation. v, ok := pkt.Data.PullUp(header.ICMPv6HeaderSize) if !ok { received.Invalid.Increment() return } h := header.ICMPv6(v) - iph := header.IPv6(netHeader) + iph := header.IPv6(pkt.NetworkHeader) // Validate ICMPv6 checksum before processing the packet. // diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index a720f626f..52a01b44e 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -179,36 +179,32 @@ func TestICMPCounts(t *testing.T) { }, } - handleIPv6Payload := func(hdr buffer.Prependable) { - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + handleIPv6Payload := func(icmp header.ICMPv6) { + ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), + PayloadLength: uint16(len(icmp)), NextHeader: uint8(header.ICMPv6ProtocolNumber), HopLimit: header.NDPHopLimit, SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) ep.HandlePacket(&r, &stack.PacketBuffer{ - Data: hdr.View().ToVectorisedView(), + NetworkHeader: buffer.View(ip), + Data: buffer.View(icmp).ToVectorisedView(), }) } for _, typ := range types { - extraDataLen := len(typ.extraData) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen) - extraData := buffer.View(hdr.Prepend(extraDataLen)) - copy(extraData, typ.extraData) - pkt := header.ICMPv6(hdr.Prepend(typ.size)) - pkt.SetType(typ.typ) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, extraData.ToVectorisedView())) - - handleIPv6Payload(hdr) + icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) + copy(icmp[typ.size:], typ.extraData) + icmp.SetType(typ.typ) + icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView())) + handleIPv6Payload(icmp) } // Construct an empty ICMP packet so that // Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented. - handleIPv6Payload(buffer.NewPrependable(header.IPv6MinimumSize)) + handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize))) icmpv6Stats := s.Stats().ICMP.V6PacketsReceived visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) { @@ -546,25 +542,22 @@ func TestICMPChecksumValidationSimple(t *testing.T) { } handleIPv6Payload := func(checksum bool) { - extraDataLen := len(typ.extraData) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen) - extraData := buffer.View(hdr.Prepend(extraDataLen)) - copy(extraData, typ.extraData) - pkt := header.ICMPv6(hdr.Prepend(typ.size)) - pkt.SetType(typ.typ) + icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) + copy(icmp[typ.size:], typ.extraData) + icmp.SetType(typ.typ) if checksum { - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, extraData.ToVectorisedView())) + icmp.SetChecksum(header.ICMPv6Checksum(icmp, lladdr1, lladdr0, buffer.View{}.ToVectorisedView())) } - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(typ.size + extraDataLen), + PayloadLength: uint16(len(icmp)), NextHeader: uint8(header.ICMPv6ProtocolNumber), HopLimit: header.NDPHopLimit, SrcAddr: lladdr1, DstAddr: lladdr0, }) e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ - Data: hdr.View().ToVectorisedView(), + Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}), }) } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 0d94ad122..95fbcf2d1 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -171,22 +171,20 @@ func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuff // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { - headerView, ok := pkt.Data.PullUp(header.IPv6MinimumSize) - if !ok { + h := header.IPv6(pkt.NetworkHeader) + if !h.IsValid(pkt.Data.Size() + len(pkt.NetworkHeader) + len(pkt.TransportHeader)) { r.Stats().IP.MalformedPacketsReceived.Increment() return } - h := header.IPv6(headerView) - if !h.IsValid(pkt.Data.Size()) { - r.Stats().IP.MalformedPacketsReceived.Increment() - return - } - - pkt.NetworkHeader = headerView[:header.IPv6MinimumSize] - pkt.Data.TrimFront(header.IPv6MinimumSize) - pkt.Data.CapLength(int(h.PayloadLength())) - it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), pkt.Data) + // vv consists of: + // - Any IPv6 header bytes after the first 40 (i.e. extensions). + // - The transport header, if present. + // - Any other payload data. + vv := pkt.NetworkHeader[header.IPv6MinimumSize:].ToVectorisedView() + vv.AppendView(pkt.TransportHeader) + vv.Append(pkt.Data) + it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), vv) hasFragmentHeader := false for firstHeader := true; ; firstHeader = false { @@ -262,9 +260,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { case header.IPv6FragmentExtHdr: hasFragmentHeader = true - fragmentOffset := extHdr.FragmentOffset() - more := extHdr.More() - if !more && fragmentOffset == 0 { + if extHdr.IsAtomic() { // This fragment extension header indicates that this packet is an // atomic fragment. An atomic fragment is a fragment that contains // all the data required to reassemble a full packet. As per RFC 6946, @@ -277,9 +273,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // Don't consume the iterator if we have the first fragment because we // will use it to validate that the first fragment holds the upper layer // header. - rawPayload := it.AsRawHeader(fragmentOffset != 0 /* consume */) + rawPayload := it.AsRawHeader(extHdr.FragmentOffset() != 0 /* consume */) - if fragmentOffset == 0 { + if extHdr.FragmentOffset() == 0 { // Check that the iterator ends with a raw payload as the first fragment // should include all headers up to and including any upper layer // headers, as per RFC 8200 section 4.5; only upper layer data @@ -332,7 +328,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { } // The packet is a fragment, let's try to reassemble it. - start := fragmentOffset * header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit + start := extHdr.FragmentOffset() * header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit last := start + uint16(fragmentPayloadLen) - 1 // Drop the packet if the fragmentOffset is incorrect. i.e the @@ -345,7 +341,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { } var ready bool - pkt.Data, ready, err = e.fragmentation.Process(hash.IPv6FragmentHash(h, extHdr.ID()), start, last, more, rawPayload.Buf) + // Note that pkt doesn't have its transport header set after reassembly, + // and won't until DeliverNetworkPacket sets it. + pkt.Data, ready, err = e.fragmentation.Process(hash.IPv6FragmentHash(h, extHdr.ID()), start, last, extHdr.More(), rawPayload.Buf) if err != nil { r.Stats().IP.MalformedPacketsReceived.Increment() r.Stats().IP.MalformedFragmentsReceived.Increment() @@ -394,10 +392,17 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { case header.IPv6RawPayloadHeader: // If the last header in the payload isn't a known IPv6 extension header, // handle it as if it is transport layer data. + + // For unfragmented packets, extHdr still contains the transport header. + // Get rid of it. + // + // For reassembled fragments, pkt.TransportHeader is unset, so this is a + // no-op and pkt.Data begins with the transport header. + extHdr.Buf.TrimFront(len(pkt.TransportHeader)) pkt.Data = extHdr.Buf if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber { - e.handleICMP(r, headerView, pkt, hasFragmentHeader) + e.handleICMP(r, pkt, hasFragmentHeader) } else { r.Stats().IP.PacketsDelivered.Increment() // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error @@ -505,6 +510,79 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} +// Parse implements stack.TransportProtocol.Parse. +func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { + hdr, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + if !ok { + return 0, false, false + } + ipHdr := header.IPv6(hdr) + + // dataClone consists of: + // - Any IPv6 header bytes after the first 40 (i.e. extensions). + // - The transport header, if present. + // - Any other payload data. + views := [8]buffer.View{} + dataClone := pkt.Data.Clone(views[:]) + dataClone.TrimFront(header.IPv6MinimumSize) + it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipHdr.NextHeader()), dataClone) + + // Iterate over the IPv6 extensions to find their length. + // + // Parsing occurs again in HandlePacket because we don't track the + // extensions in PacketBuffer. Unfortunately, that means HandlePacket + // has to do the parsing work again. + var nextHdr tcpip.TransportProtocolNumber + foundNext := true + extensionsSize := 0 +traverseExtensions: + for extHdr, done, err := it.Next(); ; extHdr, done, err = it.Next() { + if err != nil { + break + } + // If we exhaust the extension list, the entire packet is the IPv6 header + // and (possibly) extensions. + if done { + extensionsSize = dataClone.Size() + foundNext = false + break + } + + switch extHdr := extHdr.(type) { + case header.IPv6FragmentExtHdr: + // If this is an atomic fragment, we don't have to treat it specially. + if !extHdr.More() && extHdr.FragmentOffset() == 0 { + continue + } + // This is a non-atomic fragment and has to be re-assembled before we can + // examine the payload for a transport header. + foundNext = false + + case header.IPv6RawPayloadHeader: + // We've found the payload after any extensions. + extensionsSize = dataClone.Size() - extHdr.Buf.Size() + nextHdr = tcpip.TransportProtocolNumber(extHdr.Identifier) + break traverseExtensions + + default: + // Any other extension is a no-op, keep looping until we find the payload. + } + } + + // Put the IPv6 header with extensions in pkt.NetworkHeader. + hdr, ok = pkt.Data.PullUp(header.IPv6MinimumSize + extensionsSize) + if !ok { + panic(fmt.Sprintf("pkt.Data should have at least %d bytes, but only has %d.", header.IPv6MinimumSize+extensionsSize, pkt.Data.Size())) + } + ipHdr = header.IPv6(hdr) + + pkt.NetworkHeader = hdr + pkt.Data.TrimFront(len(hdr)) + pkt.Data.CapLength(int(ipHdr.PayloadLength())) + + return nextHdr, foundNext, true +} + // calculateMTU calculates the network-layer payload MTU based on the link-layer // payload mtu. func calculateMTU(mtu uint32) uint32 { diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 3c141b91b..64239ce9a 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -551,25 +551,29 @@ func TestNDPValidation(t *testing.T) { return s, ep, r } - handleIPv6Payload := func(hdr buffer.Prependable, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) { + handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) { nextHdr := uint8(header.ICMPv6ProtocolNumber) + var extensions buffer.View if atomicFragment { - bytes := hdr.Prepend(header.IPv6FragmentExtHdrLength) - bytes[0] = nextHdr + extensions = buffer.NewView(header.IPv6FragmentExtHdrLength) + extensions[0] = nextHdr nextHdr = uint8(header.IPv6FragmentExtHdrIdentifier) } - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize + len(extensions))) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), + PayloadLength: uint16(len(payload) + len(extensions)), NextHeader: nextHdr, HopLimit: hopLimit, SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) + if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) { + t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n) + } ep.HandlePacket(r, &stack.PacketBuffer{ - Data: hdr.View().ToVectorisedView(), + NetworkHeader: buffer.View(ip), + Data: payload.ToVectorisedView(), }) } @@ -676,14 +680,11 @@ func TestNDPValidation(t *testing.T) { invalid := stats.Invalid typStat := typ.statCounter(stats) - extraDataLen := len(typ.extraData) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen + header.IPv6FragmentExtHdrLength) - extraData := buffer.View(hdr.Prepend(extraDataLen)) - copy(extraData, typ.extraData) - pkt := header.ICMPv6(hdr.Prepend(typ.size)) - pkt.SetType(typ.typ) - pkt.SetCode(test.code) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, extraData.ToVectorisedView())) + icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) + copy(icmp[typ.size:], typ.extraData) + icmp.SetType(typ.typ) + icmp.SetCode(test.code) + icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView())) // Rx count of the NDP message should initially be 0. if got := typStat.Value(); got != 0 { @@ -699,7 +700,7 @@ func TestNDPValidation(t *testing.T) { t.FailNow() } - handleIPv6Payload(hdr, test.hopLimit, test.atomicFragment, ep, &r) + handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep, &r) // Rx count of the NDP packet should have increased. if got := typStat.Value(); got != 1 { diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index d4053be08..05bf62788 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -20,7 +20,6 @@ import ( "time" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack" @@ -147,44 +146,6 @@ type ConnTrackTable struct { Seed uint32 } -// parseHeaders sets headers in the packet. -func parseHeaders(pkt *PacketBuffer) { - newPkt := pkt.Clone() - - // Set network header. - hdr, ok := newPkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - return - } - netHeader := header.IPv4(hdr) - newPkt.NetworkHeader = hdr - length := int(netHeader.HeaderLength()) - - // TODO(gvisor.dev/issue/170): Need to support for other - // protocols as well. - // Set transport header. - switch protocol := netHeader.TransportProtocol(); protocol { - case header.UDPProtocolNumber: - if newPkt.TransportHeader == nil { - h, ok := newPkt.Data.PullUp(length + header.UDPMinimumSize) - if !ok { - return - } - newPkt.TransportHeader = buffer.View(header.UDP(h[length:])) - } - case header.TCPProtocolNumber: - if newPkt.TransportHeader == nil { - h, ok := newPkt.Data.PullUp(length + header.TCPMinimumSize) - if !ok { - return - } - newPkt.TransportHeader = buffer.View(header.TCP(h[length:])) - } - } - pkt.NetworkHeader = newPkt.NetworkHeader - pkt.TransportHeader = newPkt.TransportHeader -} - // packetToTuple converts packet to a tuple in original direction. func packetToTuple(pkt *PacketBuffer, hook Hook) (connTrackTuple, *tcpip.Error) { var tuple connTrackTuple @@ -257,13 +218,6 @@ func (ct *ConnTrackTable) getTupleHash(tuple connTrackTuple) uint32 { // TODO(gvisor.dev/issue/170): Only TCP packets are supported. Need to support other // transport protocols. func (ct *ConnTrackTable) connTrackForPacket(pkt *PacketBuffer, hook Hook, createConn bool) (*connTrack, ctDirection) { - if hook == Prerouting { - // Headers will not be set in Prerouting. - // TODO(gvisor.dev/issue/170): Change this after parsing headers - // code is added. - parseHeaders(pkt) - } - var dir ctDirection tuple, err := packetToTuple(pkt, hook) if err != nil { diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index 63537aaad..a6546cef0 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -33,6 +33,10 @@ const ( // except where another value is explicitly used. It is chosen to match // the MTU of loopback interfaces on linux systems. fwdTestNetDefaultMTU = 65536 + + dstAddrOffset = 0 + srcAddrOffset = 1 + protocolNumberOffset = 2 ) // fwdTestNetworkEndpoint is a network-layer protocol endpoint. @@ -69,15 +73,8 @@ func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID { } func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) { - // Consume the network header. - b, ok := pkt.Data.PullUp(fwdTestNetHeaderLen) - if !ok { - return - } - pkt.Data.TrimFront(fwdTestNetHeaderLen) - // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), pkt) + f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), pkt) } func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { @@ -100,9 +97,9 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkH // Add the protocol's header to the packet and send it to the link // endpoint. b := pkt.Header.Prepend(fwdTestNetHeaderLen) - b[0] = r.RemoteAddress[0] - b[1] = f.id.LocalAddress[0] - b[2] = byte(params.Protocol) + b[dstAddrOffset] = r.RemoteAddress[0] + b[srcAddrOffset] = f.id.LocalAddress[0] + b[protocolNumberOffset] = byte(params.Protocol) return f.ep.WritePacket(r, gso, fwdTestNetNumber, pkt) } @@ -140,7 +137,17 @@ func (f *fwdTestNetworkProtocol) DefaultPrefixLen() int { } func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { - return tcpip.Address(v[1:2]), tcpip.Address(v[0:1]) + return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1]) +} + +func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) { + netHeader, ok := pkt.Data.PullUp(fwdTestNetHeaderLen) + if !ok { + return 0, false, false + } + pkt.NetworkHeader = netHeader + pkt.Data.TrimFront(fwdTestNetHeaderLen) + return tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), true, true } func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) (NetworkEndpoint, *tcpip.Error) { @@ -361,7 +368,7 @@ func TestForwardingWithStaticResolver(t *testing.T) { // Inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. buf := buffer.NewView(30) - buf[0] = 3 + buf[dstAddrOffset] = 3 ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -398,7 +405,7 @@ func TestForwardingWithFakeResolver(t *testing.T) { // Inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. buf := buffer.NewView(30) - buf[0] = 3 + buf[dstAddrOffset] = 3 ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -429,7 +436,7 @@ func TestForwardingWithNoResolver(t *testing.T) { // inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. buf := buffer.NewView(30) - buf[0] = 3 + buf[dstAddrOffset] = 3 ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -459,7 +466,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { // Inject an inbound packet to address 4 on NIC 1. This packet should // not be forwarded. buf := buffer.NewView(30) - buf[0] = 4 + buf[dstAddrOffset] = 4 ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -467,7 +474,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { // Inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. buf = buffer.NewView(30) - buf[0] = 3 + buf[dstAddrOffset] = 3 ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -480,9 +487,8 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Data.ToView() - if b[0] != 3 { - t.Fatalf("got b[0] = %d, want = 3", b[0]) + if p.Pkt.NetworkHeader[dstAddrOffset] != 3 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", p.Pkt.NetworkHeader[dstAddrOffset]) } // Test that the address resolution happened correctly. @@ -509,7 +515,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { // Inject two inbound packets to address 3 on NIC 1. for i := 0; i < 2; i++ { buf := buffer.NewView(30) - buf[0] = 3 + buf[dstAddrOffset] = 3 ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -524,9 +530,8 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Data.ToView() - if b[0] != 3 { - t.Fatalf("got b[0] = %d, want = 3", b[0]) + if p.Pkt.NetworkHeader[dstAddrOffset] != 3 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", p.Pkt.NetworkHeader[dstAddrOffset]) } // Test that the address resolution happened correctly. @@ -554,7 +559,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { for i := 0; i < maxPendingPacketsPerResolution+5; i++ { // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1. buf := buffer.NewView(30) - buf[0] = 3 + buf[dstAddrOffset] = 3 // Set the packet sequence number. binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i)) ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ @@ -571,14 +576,18 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { t.Fatal("packet not forwarded") } - b := p.Pkt.Data.ToView() - if b[0] != 3 { - t.Fatalf("got b[0] = %d, want = 3", b[0]) + if b := p.Pkt.Header.View(); b[dstAddrOffset] != 3 { + t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset]) + } + seqNumBuf, ok := p.Pkt.Data.PullUp(2) // The sequence number is a uint16 (2 bytes). + if !ok { + t.Fatalf("p.Pkt.Data is too short to hold a sequence number: %d", p.Pkt.Data.Size()) } - // The first 5 packets should not be forwarded so the the - // sequemnce number should start with 5. + + // The first 5 packets should not be forwarded so the sequence number should + // start with 5. want := uint16(i + 5) - if n := binary.BigEndian.Uint16(b[fwdTestNetHeaderLen:]); n != want { + if n := binary.BigEndian.Uint16(seqNumBuf); n != want { t.Fatalf("got the packet #%d, want = #%d", n, want) } @@ -609,7 +618,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // Each packet has a different destination address (3 to // maxPendingResolutions + 7). buf := buffer.NewView(30) - buf[0] = byte(3 + i) + buf[dstAddrOffset] = byte(3 + i) ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -626,9 +635,8 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // The first 5 packets (address 3 to 7) should not be forwarded // because their address resolutions are interrupted. - b := p.Pkt.Data.ToView() - if b[0] < 8 { - t.Fatalf("got b[0] = %d, want b[0] >= 8", b[0]) + if p.Pkt.NetworkHeader[dstAddrOffset] < 8 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", p.Pkt.NetworkHeader[dstAddrOffset]) } // Test that the address resolution happened correctly. diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 36cc6275d..92e31643e 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -98,11 +98,6 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrackTable, hook Hook return RuleAccept, 0 } - // Set network header. - if hook == Prerouting { - parseHeaders(pkt) - } - // Drop the packet if network and transport header are not set. if pkt.NetworkHeader == nil || pkt.TransportHeader == nil { return RuleDrop, 0 diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 6664aea06..d756ae6f5 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1212,12 +1212,21 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp n.stack.stats.IP.PacketsReceived.Increment() } - netHeader, ok := pkt.Data.PullUp(netProto.MinimumPacketSize()) + // Parse headers. + transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt) if !ok { + // The packet is too small to contain a network header. n.stack.stats.MalformedRcvdPackets.Increment() return } - src, dst := netProto.ParseAddresses(netHeader) + if hasTransportHdr { + // Parse the transport header if present. + if state, ok := n.stack.transportProtocols[transProtoNum]; ok { + state.proto.Parse(pkt) + } + } + + src, dst := netProto.ParseAddresses(pkt.NetworkHeader) if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil { // The source address is one of our own, so we never should have gotten a @@ -1301,8 +1310,18 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { // TODO(b/143425874) Decrease the TTL field in forwarded packets. - if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen != 0 { - pkt.Header = buffer.NewPrependable(linkHeaderLen) + // TODO(b/151227689): Avoid copying the packet when forwarding. We can do this + // by having lower layers explicity write each header instead of just + // pkt.Header. + + // pkt may have set its NetworkHeader and TransportHeader. If we're + // forwarding, we'll have to copy them into pkt.Header. + pkt.Header = buffer.NewPrependable(int(n.linkEP.MaxHeaderLength()) + len(pkt.NetworkHeader) + len(pkt.TransportHeader)) + if n := copy(pkt.Header.Prepend(len(pkt.TransportHeader)), pkt.TransportHeader); n != len(pkt.TransportHeader) { + panic(fmt.Sprintf("copied %d bytes, expected %d", n, len(pkt.TransportHeader))) + } + if n := copy(pkt.Header.Prepend(len(pkt.NetworkHeader)), pkt.NetworkHeader); n != len(pkt.NetworkHeader) { + panic(fmt.Sprintf("copied %d bytes, expected %d", n, len(pkt.NetworkHeader))) } // WritePacket takes ownership of pkt, calculate numBytes first. @@ -1333,13 +1352,31 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // validly formed. n.stack.demux.deliverRawPacket(r, protocol, pkt) - transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize()) - if !ok { + // TransportHeader is nil only when pkt is an ICMP packet or was reassembled + // from fragments. + if pkt.TransportHeader == nil { + // TODO(gvisor.dev/issue/170): ICMP packets don't have their + // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a + // full explanation. + if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber { + transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize()) + if !ok { + n.stack.stats.MalformedRcvdPackets.Increment() + return + } + pkt.TransportHeader = transHeader + } else { + // This is either a bad packet or was re-assembled from fragments. + transProto.Parse(pkt) + } + } + + if len(pkt.TransportHeader) < transProto.MinimumPacketSize() { n.stack.stats.MalformedRcvdPackets.Increment() return } - srcPort, dstPort, err := transProto.ParsePorts(transHeader) + srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader) if err != nil { n.stack.stats.MalformedRcvdPackets.Increment() return diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 94f177841..5cbc946b6 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -168,6 +168,11 @@ type TransportProtocol interface { // Wait waits for any worker goroutines owned by the protocol to stop. Wait() + + // Parse sets pkt.TransportHeader and trims pkt.Data appropriately. It does + // neither and returns false if pkt.Data is too small, i.e. pkt.Data.Size() < + // MinimumPacketSize() + Parse(pkt *PacketBuffer) (ok bool) } // TransportDispatcher contains the methods used by the network stack to deliver @@ -313,6 +318,14 @@ type NetworkProtocol interface { // Wait waits for any worker goroutines owned by the protocol to stop. Wait() + + // Parse sets pkt.NetworkHeader and trims pkt.Data appropriately. It + // returns: + // - The encapsulated protocol, if present. + // - Whether there is an encapsulated transport protocol payload (e.g. ARP + // does not encapsulate anything). + // - Whether pkt.Data was large enough to parse and set pkt.NetworkHeader. + Parse(pkt *PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) } // NetworkDispatcher contains the methods used by the network stack to deliver diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index f6ddc3ced..ffef9bc2c 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -52,6 +52,10 @@ const ( // where another value is explicitly used. It is chosen to match the MTU // of loopback interfaces on linux systems. defaultMTU = 65536 + + dstAddrOffset = 0 + srcAddrOffset = 1 + protocolNumberOffset = 2 ) // fakeNetworkEndpoint is a network-layer protocol endpoint. It counts sent and @@ -94,26 +98,24 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuff // Increment the received packet count in the protocol descriptor. f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ - // Consume the network header. - b, ok := pkt.Data.PullUp(fakeNetHeaderLen) - if !ok { - return - } - pkt.Data.TrimFront(fakeNetHeaderLen) - // Handle control packets. - if b[2] == uint8(fakeControlProtocol) { + if pkt.NetworkHeader[protocolNumberOffset] == uint8(fakeControlProtocol) { nb, ok := pkt.Data.PullUp(fakeNetHeaderLen) if !ok { return } pkt.Data.TrimFront(fakeNetHeaderLen) - f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, pkt) + f.dispatcher.DeliverTransportControlPacket( + tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]), + tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]), + fakeNetNumber, + tcpip.TransportProtocolNumber(nb[protocolNumberOffset]), + stack.ControlPortUnreachable, 0, pkt) return } // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), pkt) + f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), pkt) } func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { @@ -138,18 +140,13 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params // Add the protocol's header to the packet and send it to the link // endpoint. - b := pkt.Header.Prepend(fakeNetHeaderLen) - b[0] = r.RemoteAddress[0] - b[1] = f.id.LocalAddress[0] - b[2] = byte(params.Protocol) + pkt.NetworkHeader = pkt.Header.Prepend(fakeNetHeaderLen) + pkt.NetworkHeader[dstAddrOffset] = r.RemoteAddress[0] + pkt.NetworkHeader[srcAddrOffset] = f.id.LocalAddress[0] + pkt.NetworkHeader[protocolNumberOffset] = byte(params.Protocol) if r.Loop&stack.PacketLoop != 0 { - views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) - views[0] = pkt.Header.View() - views = append(views, pkt.Data.Views()...) - f.HandlePacket(r, &stack.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), - }) + f.HandlePacket(r, pkt) } if r.Loop&stack.PacketOut == 0 { return nil @@ -205,7 +202,7 @@ func (f *fakeNetworkProtocol) PacketCount(intfAddr byte) int { } func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { - return tcpip.Address(v[1:2]), tcpip.Address(v[0:1]) + return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1]) } func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint, _ *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) { @@ -247,6 +244,17 @@ func (*fakeNetworkProtocol) Close() {} // Wait implements TransportProtocol.Wait. func (*fakeNetworkProtocol) Wait() {} +// Parse implements TransportProtocol.Parse. +func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) { + hdr, ok := pkt.Data.PullUp(fakeNetHeaderLen) + if !ok { + return 0, false, false + } + pkt.NetworkHeader = hdr + pkt.Data.TrimFront(fakeNetHeaderLen) + return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true +} + func fakeNetFactory() stack.NetworkProtocol { return &fakeNetworkProtocol{} } @@ -292,7 +300,7 @@ func TestNetworkReceive(t *testing.T) { buf := buffer.NewView(30) // Make sure packet with wrong address is not delivered. - buf[0] = 3 + buf[dstAddrOffset] = 3 ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -304,7 +312,7 @@ func TestNetworkReceive(t *testing.T) { } // Make sure packet is delivered to first endpoint. - buf[0] = 1 + buf[dstAddrOffset] = 1 ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -316,7 +324,7 @@ func TestNetworkReceive(t *testing.T) { } // Make sure packet is delivered to second endpoint. - buf[0] = 2 + buf[dstAddrOffset] = 2 ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) @@ -982,7 +990,7 @@ func TestAddressRemoval(t *testing.T) { buf := buffer.NewView(30) // Send and receive packets, and verify they are received. - buf[0] = localAddrByte + buf[dstAddrOffset] = localAddrByte testRecv(t, fakeNet, localAddrByte, ep, buf) testSendTo(t, s, remoteAddr, ep, nil) @@ -1032,7 +1040,7 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { } // Send and receive packets, and verify they are received. - buf[0] = localAddrByte + buf[dstAddrOffset] = localAddrByte testRecv(t, fakeNet, localAddrByte, ep, buf) testSend(t, r, ep, nil) testSendTo(t, s, remoteAddr, ep, nil) @@ -1114,7 +1122,7 @@ func TestEndpointExpiration(t *testing.T) { fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) buf := buffer.NewView(30) - buf[0] = localAddrByte + buf[dstAddrOffset] = localAddrByte if promiscuous { if err := s.SetPromiscuousMode(nicID, true); err != nil { @@ -1277,7 +1285,7 @@ func TestPromiscuousMode(t *testing.T) { // Write a packet, and check that it doesn't get delivered as we don't // have a matching endpoint. const localAddrByte byte = 0x01 - buf[0] = localAddrByte + buf[dstAddrOffset] = localAddrByte testFailingRecv(t, fakeNet, localAddrByte, ep, buf) // Set promiscuous mode, then check that packet is delivered. @@ -1658,7 +1666,7 @@ func TestAddressRangeAcceptsMatchingPacket(t *testing.T) { buf := buffer.NewView(30) const localAddrByte byte = 0x01 - buf[0] = localAddrByte + buf[dstAddrOffset] = localAddrByte subnet, err := tcpip.NewSubnet(tcpip.Address("\x00"), tcpip.AddressMask("\xF0")) if err != nil { t.Fatal("NewSubnet failed:", err) @@ -1766,7 +1774,7 @@ func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) { buf := buffer.NewView(30) const localAddrByte byte = 0x01 - buf[0] = localAddrByte + buf[dstAddrOffset] = localAddrByte subnet, err := tcpip.NewSubnet(tcpip.Address("\x10"), tcpip.AddressMask("\xF0")) if err != nil { t.Fatal("NewSubnet failed:", err) @@ -2344,7 +2352,7 @@ func TestNICForwarding(t *testing.T) { // Send a packet to dstAddr. buf := buffer.NewView(30) - buf[0] = dstAddr[0] + buf[dstAddrOffset] = dstAddr[0] ep1.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ Data: buf.ToVectorisedView(), }) diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index cb350ead3..ad61c09d6 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -83,7 +83,8 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions return 0, nil, tcpip.ErrNoRoute } - hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength())) + hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength()) + fakeTransHeaderLen) + hdr.Prepend(fakeTransHeaderLen) v, err := p.FullPayload() if err != nil { return 0, nil, err @@ -324,6 +325,17 @@ func (*fakeTransportProtocol) Close() {} // Wait implements TransportProtocol.Wait. func (*fakeTransportProtocol) Wait() {} +// Parse implements TransportProtocol.Parse. +func (*fakeTransportProtocol) Parse(pkt *stack.PacketBuffer) bool { + hdr, ok := pkt.Data.PullUp(fakeTransHeaderLen) + if !ok { + return false + } + pkt.TransportHeader = hdr + pkt.Data.TrimFront(fakeTransHeaderLen) + return true +} + func fakeTransFactory() stack.TransportProtocol { return &fakeTransportProtocol{} } @@ -642,11 +654,10 @@ func TestTransportForwarding(t *testing.T) { t.Fatal("Response packet not forwarded") } - hdrs := p.Pkt.Data.ToView() - if dst := hdrs[0]; dst != 3 { + if dst := p.Pkt.NetworkHeader[0]; dst != 3 { t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) } - if src := hdrs[1]; src != 1 { + if src := p.Pkt.NetworkHeader[1]; src != 1 { t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) } } diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 2ec6749c7..74ef6541e 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -124,6 +124,16 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} +// Parse implements stack.TransportProtocol.Parse. +func (*protocol) Parse(pkt *stack.PacketBuffer) bool { + // TODO(gvisor.dev/issue/170): Implement parsing of ICMP. + // + // Right now, the Parse() method is tied to enabled protocols passed into + // stack.New. This works for UDP and TCP, but we handle ICMP traffic even + // when netstack users don't pass ICMP as a supported protocol. + return false +} + // NewProtocol4 returns an ICMPv4 transport protocol. func NewProtocol4() stack.TransportProtocol { return &protocol{ProtocolNumber4} diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 21c34fac2..a406d815e 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -627,8 +627,9 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { }, } - networkHeader := append(buffer.View(nil), pkt.NetworkHeader...) - combinedVV := networkHeader.ToVectorisedView() + headers := append(buffer.View(nil), pkt.NetworkHeader...) + headers = append(headers, pkt.TransportHeader...) + combinedVV := headers.ToVectorisedView() combinedVV.Append(pkt.Data) packet.data = combinedVV packet.timestampNS = e.stack.NowNanoseconds() diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index c827d0277..73b8a6782 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -21,6 +21,7 @@ package tcp import ( + "fmt" "runtime" "strings" "time" @@ -490,6 +491,26 @@ func (p *protocol) SynRcvdCounter() *synRcvdCounter { return &p.synRcvdCount } +// Parse implements stack.TransportProtocol.Parse. +func (*protocol) Parse(pkt *stack.PacketBuffer) bool { + hdr, ok := pkt.Data.PullUp(header.TCPMinimumSize) + if !ok { + return false + } + + // If the header has options, pull those up as well. + if offset := int(header.TCP(hdr).DataOffset()); offset > header.TCPMinimumSize && offset <= pkt.Data.Size() { + hdr, ok = pkt.Data.PullUp(offset) + if !ok { + panic(fmt.Sprintf("There should be at least %d bytes in pkt.Data.", offset)) + } + } + + pkt.TransportHeader = hdr + pkt.Data.TrimFront(len(hdr)) + return true +} + // NewProtocol returns a TCP transport protocol. func NewProtocol() stack.TransportProtocol { return &protocol{ diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 0c099e2fd..0280892a8 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -35,6 +35,7 @@ type segment struct { id stack.TransportEndpointID `state:"manual"` route stack.Route `state:"manual"` data buffer.VectorisedView `state:".(buffer.VectorisedView)"` + hdr header.TCP // views is used as buffer for data when its length is large // enough to store a VectorisedView. views [8]buffer.View `state:"nosave"` @@ -67,6 +68,7 @@ func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketB route: r.Clone(), } s.data = pkt.Data.Clone(s.views[:]) + s.hdr = header.TCP(pkt.TransportHeader) s.rcvdTime = time.Now() return s } @@ -146,12 +148,6 @@ func (s *segment) logicalLen() seqnum.Size { // TCP checksum and stores the checksum and result of checksum verification in // the csum and csumValid fields of the segment. func (s *segment) parse() bool { - h, ok := s.data.PullUp(header.TCPMinimumSize) - if !ok { - return false - } - hdr := header.TCP(h) - // h is the header followed by the payload. We check that the offset to // the data respects the following constraints: // 1. That it's at least the minimum header size; if we don't do this @@ -162,16 +158,12 @@ func (s *segment) parse() bool { // N.B. The segment has already been validated as having at least the // minimum TCP size before reaching here, so it's safe to read the // fields. - offset := int(hdr.DataOffset()) - if offset < header.TCPMinimumSize { - return false - } - hdrWithOpts, ok := s.data.PullUp(offset) - if !ok { + offset := int(s.hdr.DataOffset()) + if offset < header.TCPMinimumSize || offset > len(s.hdr) { return false } - s.options = []byte(hdrWithOpts[header.TCPMinimumSize:]) + s.options = []byte(s.hdr[header.TCPMinimumSize:]) s.parsedOptions = header.ParseTCPOptions(s.options) // Query the link capabilities to decide if checksum validation is @@ -180,22 +172,19 @@ func (s *segment) parse() bool { if s.route.Capabilities()&stack.CapabilityRXChecksumOffload != 0 { s.csumValid = true verifyChecksum = false - s.data.TrimFront(offset) } if verifyChecksum { - hdr = header.TCP(hdrWithOpts) - s.csum = hdr.Checksum() - xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size())) - xsum = hdr.CalculateChecksum(xsum) - s.data.TrimFront(offset) + s.csum = s.hdr.Checksum() + xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size()+len(s.hdr))) + xsum = s.hdr.CalculateChecksum(xsum) xsum = header.ChecksumVV(s.data, xsum) s.csumValid = xsum == 0xffff } - s.sequenceNumber = seqnum.Value(hdr.SequenceNumber()) - s.ackNumber = seqnum.Value(hdr.AckNumber()) - s.flags = hdr.Flags() - s.window = seqnum.Size(hdr.WindowSize()) + s.sequenceNumber = seqnum.Value(s.hdr.SequenceNumber()) + s.ackNumber = seqnum.Value(s.hdr.AckNumber()) + s.flags = s.hdr.Flags() + s.window = seqnum.Size(s.hdr.WindowSize()) return true } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 663af8fec..8c7895713 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1270,16 +1270,14 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // endpoint. func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // Get the header then trim it from the view. - hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize) - if !ok || int(header.UDP(hdr).Length()) > pkt.Data.Size() { + hdr := header.UDP(pkt.TransportHeader) + if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize { // Malformed packet. e.stack.Stats().UDP.MalformedPacketsReceived.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } - pkt.Data.TrimFront(header.UDPMinimumSize) - e.rcvMu.Lock() e.stack.Stats().UDP.PacketsReceived.Increment() e.stats.PacketsReceived.Increment() diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index e320c5758..4218e7d03 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -67,14 +67,8 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { - // Get the header then trim it from the view. - h, ok := pkt.Data.PullUp(header.UDPMinimumSize) - if !ok { - // Malformed packet. - r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() - return true - } - if int(header.UDP(h).Length()) > pkt.Data.Size() { + hdr := header.UDP(pkt.TransportHeader) + if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize { // Malformed packet. r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() return true @@ -121,7 +115,7 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans } headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize available := int(mtu) - headerLen - payloadLen := len(pkt.NetworkHeader) + pkt.Data.Size() + payloadLen := len(pkt.NetworkHeader) + len(pkt.TransportHeader) + pkt.Data.Size() if payloadLen > available { payloadLen = available } @@ -130,9 +124,10 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans // For example, a raw or packet socket may use what UDP // considers an unreachable destination. Thus we deep copy pkt // to prevent multiple ownership and SR errors. - newNetHeader := append(buffer.View(nil), pkt.NetworkHeader...) - payload := newNetHeader.ToVectorisedView() - payload.Append(pkt.Data.ToView().ToVectorisedView()) + newHeader := append(buffer.View(nil), pkt.NetworkHeader...) + newHeader = append(newHeader, pkt.TransportHeader...) + payload := newHeader.ToVectorisedView() + payload.AppendView(pkt.Data.ToView()) payload.CapLength(payloadLen) hdr := buffer.NewPrependable(headerLen) @@ -141,8 +136,9 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans pkt.SetCode(header.ICMPv4PortUnreachable) pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload)) r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - Data: payload, + Header: hdr, + TransportHeader: buffer.View(pkt), + Data: payload, }) case header.IPv6AddressSize: @@ -164,11 +160,11 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans } headerLen := int(r.MaxHeaderLength()) + header.ICMPv6DstUnreachableMinimumSize available := int(mtu) - headerLen - payloadLen := len(pkt.NetworkHeader) + pkt.Data.Size() + payloadLen := len(pkt.NetworkHeader) + len(pkt.TransportHeader) + pkt.Data.Size() if payloadLen > available { payloadLen = available } - payload := buffer.NewVectorisedView(len(pkt.NetworkHeader), []buffer.View{pkt.NetworkHeader}) + payload := buffer.NewVectorisedView(len(pkt.NetworkHeader)+len(pkt.TransportHeader), []buffer.View{pkt.NetworkHeader, pkt.TransportHeader}) payload.Append(pkt.Data) payload.CapLength(payloadLen) @@ -178,8 +174,9 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans pkt.SetCode(header.ICMPv6PortUnreachable) pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload)) r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - Data: payload, + Header: hdr, + TransportHeader: buffer.View(pkt), + Data: payload, }) } return true @@ -201,6 +198,18 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} +// Parse implements stack.TransportProtocol.Parse. +func (*protocol) Parse(pkt *stack.PacketBuffer) bool { + h, ok := pkt.Data.PullUp(header.UDPMinimumSize) + if !ok { + // Packet is too small + return false + } + pkt.TransportHeader = h + pkt.Data.TrimFront(header.UDPMinimumSize) + return true +} + // NewProtocol returns a UDP transport protocol. func NewProtocol() stack.TransportProtocol { return &protocol{} diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index e8ade882b..313a3f117 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -441,9 +441,7 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool // Inject packet. c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ - Data: buf.ToVectorisedView(), - NetworkHeader: buffer.View(ip), - TransportHeader: buffer.View(u), + Data: buf.ToVectorisedView(), }) } @@ -488,9 +486,7 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool // Inject packet. c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ - Data: buf.ToVectorisedView(), - NetworkHeader: buffer.View(ip), - TransportHeader: buffer.View(u), + Data: buf.ToVectorisedView(), }) } @@ -1720,6 +1716,58 @@ func TestIncrementMalformedPacketsReceived(t *testing.T) { } } +// TestShortHeader verifies that when a packet with a too-short UDP header is +// received, the malformed received global stat gets incremented. +func TestShortHeader(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv6.ProtocolNumber) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + c.t.Helper() + h := unicastV6.header4Tuple(incoming) + + // Allocate a buffer for an IPv6 and too-short UDP header. + const udpSize = header.UDPMinimumSize - 1 + buf := buffer.NewView(header.IPv6MinimumSize + udpSize) + // Initialize the IP header. + ip := header.IPv6(buf) + ip.Encode(&header.IPv6Fields{ + TrafficClass: testTOS, + PayloadLength: uint16(udpSize), + NextHeader: uint8(udp.ProtocolNumber), + HopLimit: 65, + SrcAddr: h.srcAddr.Addr, + DstAddr: h.dstAddr.Addr, + }) + + // Initialize the UDP header. + udpHdr := header.UDP(buffer.NewView(header.UDPMinimumSize)) + udpHdr.Encode(&header.UDPFields{ + SrcPort: h.srcAddr.Port, + DstPort: h.dstAddr.Port, + Length: header.UDPMinimumSize, + }) + // Calculate the UDP pseudo-header checksum. + xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(udpHdr))) + udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum)) + // Copy all but the last byte of the UDP header into the packet. + copy(buf[header.IPv6MinimumSize:], udpHdr) + + // Inject packet. + c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + if got, want := c.s.Stats().MalformedRcvdPackets.Value(), uint64(1); got != want { + t.Errorf("got c.s.Stats().MalformedRcvdPackets.Value() = %d, want = %d", got, want) + } +} + // TestShutdownRead verifies endpoint read shutdown and error // stats increment on packet receive. func TestShutdownRead(t *testing.T) { -- cgit v1.2.3 From 2d3b9d18e79c93dc3186b60de06ec587852c980f Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Tue, 9 Jun 2020 11:31:49 -0700 Subject: Handle removed NIC in NDP timer for packet tx NDP packets are sent periodically from NDP timers. These timers do not hold the NIC lock when sending packets as the packet write operation may take some time. While the lock is not held, the NIC may be removed by some other goroutine. This change handles that scenario gracefully. Test: stack_test.TestRemoveNICWhileHandlingRSTimer PiperOrigin-RevId: 315524143 --- pkg/tcpip/stack/BUILD | 1 + pkg/tcpip/stack/ndp.go | 96 ++++++++++++----- pkg/tcpip/stack/nic.go | 38 ++++--- pkg/tcpip/stack/nic_test.go | 257 ++++++++++++++++++++++++++++++++++++++++++++ pkg/tcpip/stack/route.go | 4 + pkg/tcpip/stack/stack.go | 7 ++ 6 files changed, 364 insertions(+), 39 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index f71073207..afca925ad 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -110,5 +110,6 @@ go_test( "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", ], ) diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index ae7a8f740..e28c23d66 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -467,8 +467,17 @@ type ndpState struct { // The default routers discovered through Router Advertisements. defaultRouters map[tcpip.Address]defaultRouterState - // The timer used to send the next router solicitation message. - rtrSolicitTimer *time.Timer + rtrSolicit struct { + // The timer used to send the next router solicitation message. + timer *time.Timer + + // Used to let the Router Solicitation timer know that it has been stopped. + // + // Must only be read from or written to while protected by the lock of + // the NIC this ndpState is associated with. MUST be set when the timer is + // set. + done *bool + } // The on-link prefixes discovered through Router Advertisements' Prefix // Information option. @@ -648,13 +657,14 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref // as starting a goroutine but we use a timer that fires immediately so we can // reset it for the next DAD iteration. timer = time.AfterFunc(0, func() { - ndp.nic.mu.RLock() + ndp.nic.mu.Lock() + defer ndp.nic.mu.Unlock() + if done { // If we reach this point, it means that the DAD timer fired after // another goroutine already obtained the NIC lock and stopped DAD // before this function obtained the NIC lock. Simply return here and do // nothing further. - ndp.nic.mu.RUnlock() return } @@ -665,15 +675,23 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref } dadDone := remaining == 0 - ndp.nic.mu.RUnlock() var err *tcpip.Error if !dadDone { - err = ndp.sendDADPacket(addr) + // Use the unspecified address as the source address when performing DAD. + ref := ndp.nic.getRefOrCreateTempLocked(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint) + + // Do not hold the lock when sending packets which may be a long running + // task or may block link address resolution. We know this is safe + // because immediately after obtaining the lock again, we check if DAD + // has been stopped before doing any work with the NIC. Note, DAD would be + // stopped if the NIC was disabled or removed, or if the address was + // removed. + ndp.nic.mu.Unlock() + err = ndp.sendDADPacket(addr, ref) + ndp.nic.mu.Lock() } - ndp.nic.mu.Lock() - defer ndp.nic.mu.Unlock() if done { // If we reach this point, it means that DAD was stopped after we released // the NIC's read lock and before we obtained the write lock. @@ -721,17 +739,24 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref // addr. // // addr must be a tentative IPv6 address on ndp's NIC. -func (ndp *ndpState) sendDADPacket(addr tcpip.Address) *tcpip.Error { +// +// The NIC ndp belongs to MUST NOT be locked. +func (ndp *ndpState) sendDADPacket(addr tcpip.Address, ref *referencedNetworkEndpoint) *tcpip.Error { snmc := header.SolicitedNodeAddr(addr) - // Use the unspecified address as the source address when performing DAD. - ref := ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, forceSpoofing) - r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false) + r := makeRoute(header.IPv6ProtocolNumber, ref.ep.ID().LocalAddress, snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false) defer r.Release() // Route should resolve immediately since snmc is a multicast address so a // remote link address can be calculated without a resolution process. if c, err := r.Resolve(nil); err != nil { + // Do not consider the NIC being unknown or disabled as a fatal error. + // Since this method is required to be called when the NIC is not locked, + // the NIC could have been disabled or removed by another goroutine. + if err == tcpip.ErrUnknownNICID || err != tcpip.ErrInvalidEndpointState { + return err + } + panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.nic.ID(), err)) } else if c != nil { panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.nic.ID())) @@ -1816,7 +1841,7 @@ func (ndp *ndpState) cleanupState(hostOnly bool) { // // The NIC ndp belongs to MUST be locked. func (ndp *ndpState) startSolicitingRouters() { - if ndp.rtrSolicitTimer != nil { + if ndp.rtrSolicit.timer != nil { // We are already soliciting routers. return } @@ -1833,14 +1858,27 @@ func (ndp *ndpState) startSolicitingRouters() { delay = time.Duration(rand.Int63n(int64(ndp.configs.MaxRtrSolicitationDelay))) } - ndp.rtrSolicitTimer = time.AfterFunc(delay, func() { + var done bool + ndp.rtrSolicit.done = &done + ndp.rtrSolicit.timer = time.AfterFunc(delay, func() { + ndp.nic.mu.Lock() + if done { + // If we reach this point, it means that the RS timer fired after another + // goroutine already obtained the NIC lock and stopped solicitations. + // Simply return here and do nothing further. + ndp.nic.mu.Unlock() + return + } + // As per RFC 4861 section 4.1, the source of the RS is an address assigned // to the sending interface, or the unspecified address if no address is // assigned to the sending interface. - ref := ndp.nic.primaryIPv6Endpoint(header.IPv6AllRoutersMulticastAddress) + ref := ndp.nic.primaryIPv6EndpointRLocked(header.IPv6AllRoutersMulticastAddress) if ref == nil { - ref = ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, forceSpoofing) + ref = ndp.nic.getRefOrCreateTempLocked(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint) } + ndp.nic.mu.Unlock() + localAddr := ref.ep.ID().LocalAddress r := makeRoute(header.IPv6ProtocolNumber, localAddr, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false) defer r.Release() @@ -1849,6 +1887,13 @@ func (ndp *ndpState) startSolicitingRouters() { // header.IPv6AllRoutersMulticastAddress is a multicast address so a // remote link address can be calculated without a resolution process. if c, err := r.Resolve(nil); err != nil { + // Do not consider the NIC being unknown or disabled as a fatal error. + // Since this method is required to be called when the NIC is not locked, + // the NIC could have been disabled or removed by another goroutine. + if err == tcpip.ErrUnknownNICID || err == tcpip.ErrInvalidEndpointState { + return + } + panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID(), err)) } else if c != nil { panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID())) @@ -1893,17 +1938,18 @@ func (ndp *ndpState) startSolicitingRouters() { } ndp.nic.mu.Lock() - defer ndp.nic.mu.Unlock() - if remaining == 0 { - ndp.rtrSolicitTimer = nil - } else if ndp.rtrSolicitTimer != nil { + if done || remaining == 0 { + ndp.rtrSolicit.timer = nil + ndp.rtrSolicit.done = nil + } else if ndp.rtrSolicit.timer != nil { // Note, we need to explicitly check to make sure that // the timer field is not nil because if it was nil but // we still reached this point, then we know the NIC // was requested to stop soliciting routers so we don't // need to send the next Router Solicitation message. - ndp.rtrSolicitTimer.Reset(ndp.configs.RtrSolicitationInterval) + ndp.rtrSolicit.timer.Reset(ndp.configs.RtrSolicitationInterval) } + ndp.nic.mu.Unlock() }) } @@ -1913,13 +1959,15 @@ func (ndp *ndpState) startSolicitingRouters() { // // The NIC ndp belongs to MUST be locked. func (ndp *ndpState) stopSolicitingRouters() { - if ndp.rtrSolicitTimer == nil { + if ndp.rtrSolicit.timer == nil { // Nothing to do. return } - ndp.rtrSolicitTimer.Stop() - ndp.rtrSolicitTimer = nil + *ndp.rtrSolicit.done = true + ndp.rtrSolicit.timer.Stop() + ndp.rtrSolicit.timer = nil + ndp.rtrSolicit.done = nil } // initializeTempAddrState initializes state related to temporary SLAAC diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index d756ae6f5..644c0d437 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -457,8 +457,20 @@ type ipv6AddrCandidate struct { // remoteAddr must be a valid IPv6 address. func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEndpoint { n.mu.RLock() - defer n.mu.RUnlock() + ref := n.primaryIPv6EndpointRLocked(remoteAddr) + n.mu.RUnlock() + return ref +} +// primaryIPv6EndpointLocked returns an IPv6 endpoint following Source Address +// Selection (RFC 6724 section 5). +// +// Note, only rules 1-3 and 7 are followed. +// +// remoteAddr must be a valid IPv6 address. +// +// n.mu MUST be read locked. +func (n *NIC) primaryIPv6EndpointRLocked(remoteAddr tcpip.Address) *referencedNetworkEndpoint { primaryAddrs := n.mu.primary[header.IPv6ProtocolNumber] if len(primaryAddrs) == 0 { @@ -568,11 +580,6 @@ const ( // promiscuous indicates that the NIC's promiscuous flag should be observed // when getting a NIC's referenced network endpoint. promiscuous - - // forceSpoofing indicates that the NIC should be assumed to be spoofing, - // regardless of what the NIC's spoofing flag is when getting a NIC's - // referenced network endpoint. - forceSpoofing ) func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint { @@ -591,8 +598,6 @@ func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.A // or spoofing. Promiscuous mode will only be checked if promiscuous is true. // Similarly, spoofing will only be checked if spoofing is true. func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, tempRef getRefBehaviour) *referencedNetworkEndpoint { - id := NetworkEndpointID{address} - n.mu.RLock() var spoofingOrPromiscuous bool @@ -601,11 +606,9 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t spoofingOrPromiscuous = n.mu.spoofing case promiscuous: spoofingOrPromiscuous = n.mu.promiscuous - case forceSpoofing: - spoofingOrPromiscuous = true } - if ref, ok := n.mu.endpoints[id]; ok { + if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok { // An endpoint with this id exists, check if it can be used and return it. switch ref.getKind() { case permanentExpired: @@ -654,11 +657,18 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t // endpoint, create a new "temporary" endpoint. It will only exist while // there's a route through it. n.mu.Lock() - if ref, ok := n.mu.endpoints[id]; ok { + ref := n.getRefOrCreateTempLocked(protocol, address, peb) + n.mu.Unlock() + return ref +} + +/// getRefOrCreateTempLocked returns an existing endpoint for address or creates +/// and returns a temporary endpoint. +func (n *NIC) getRefOrCreateTempLocked(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint { + if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok { // No need to check the type as we are ok with expired endpoints at this // point. if ref.tryIncRef() { - n.mu.Unlock() return ref } // tryIncRef failing means the endpoint is scheduled to be removed once the @@ -670,7 +680,6 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t // Add a new temporary endpoint. netProto, ok := n.stack.networkProtocols[protocol] if !ok { - n.mu.Unlock() return nil } ref, _ := n.addAddressLocked(tcpip.ProtocolAddress{ @@ -681,7 +690,6 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t }, }, peb, temporary, static, false) - n.mu.Unlock() return ref } diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index fea46158c..31f865260 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -15,11 +15,268 @@ package stack import ( + "math" "testing" + "time" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" ) +var _ LinkEndpoint = (*testLinkEndpoint)(nil) + +// A LinkEndpoint that throws away outgoing packets. +// +// We use this instead of the channel endpoint as the channel package depends on +// the stack package which this test lives in, causing a cyclic dependency. +type testLinkEndpoint struct { + dispatcher NetworkDispatcher +} + +// Attach implements LinkEndpoint.Attach. +func (e *testLinkEndpoint) Attach(dispatcher NetworkDispatcher) { + e.dispatcher = dispatcher +} + +// IsAttached implements LinkEndpoint.IsAttached. +func (e *testLinkEndpoint) IsAttached() bool { + return e.dispatcher != nil +} + +// MTU implements LinkEndpoint.MTU. +func (*testLinkEndpoint) MTU() uint32 { + return math.MaxUint16 +} + +// Capabilities implements LinkEndpoint.Capabilities. +func (*testLinkEndpoint) Capabilities() LinkEndpointCapabilities { + return CapabilityResolutionRequired +} + +// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength. +func (*testLinkEndpoint) MaxHeaderLength() uint16 { + return 0 +} + +// LinkAddress returns the link address of this endpoint. +func (*testLinkEndpoint) LinkAddress() tcpip.LinkAddress { + return "" +} + +// Wait implements LinkEndpoint.Wait. +func (*testLinkEndpoint) Wait() {} + +// WritePacket implements LinkEndpoint.WritePacket. +func (e *testLinkEndpoint) WritePacket(*Route, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error { + return nil +} + +// WritePackets implements LinkEndpoint.WritePackets. +func (e *testLinkEndpoint) WritePackets(*Route, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + // Our tests don't use this so we don't support it. + return 0, tcpip.ErrNotSupported +} + +// WriteRawPacket implements LinkEndpoint.WriteRawPacket. +func (e *testLinkEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { + // Our tests don't use this so we don't support it. + return tcpip.ErrNotSupported +} + +var _ NetworkEndpoint = (*testIPv6Endpoint)(nil) + +// An IPv6 NetworkEndpoint that throws away outgoing packets. +// +// We use this instead of ipv6.endpoint because the ipv6 package depends on +// the stack package which this test lives in, causing a cyclic dependency. +type testIPv6Endpoint struct { + nicID tcpip.NICID + id NetworkEndpointID + prefixLen int + linkEP LinkEndpoint + protocol *testIPv6Protocol +} + +// DefaultTTL implements NetworkEndpoint.DefaultTTL. +func (*testIPv6Endpoint) DefaultTTL() uint8 { + return 0 +} + +// MTU implements NetworkEndpoint.MTU. +func (e *testIPv6Endpoint) MTU() uint32 { + return e.linkEP.MTU() - header.IPv6MinimumSize +} + +// Capabilities implements NetworkEndpoint.Capabilities. +func (e *testIPv6Endpoint) Capabilities() LinkEndpointCapabilities { + return e.linkEP.Capabilities() +} + +// MaxHeaderLength implements NetworkEndpoint.MaxHeaderLength. +func (e *testIPv6Endpoint) MaxHeaderLength() uint16 { + return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize +} + +// WritePacket implements NetworkEndpoint.WritePacket. +func (*testIPv6Endpoint) WritePacket(*Route, *GSO, NetworkHeaderParams, *PacketBuffer) *tcpip.Error { + return nil +} + +// WritePackets implements NetworkEndpoint.WritePackets. +func (*testIPv6Endpoint) WritePackets(*Route, *GSO, PacketBufferList, NetworkHeaderParams) (int, *tcpip.Error) { + // Our tests don't use this so we don't support it. + return 0, tcpip.ErrNotSupported +} + +// WriteHeaderIncludedPacket implements +// NetworkEndpoint.WriteHeaderIncludedPacket. +func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) *tcpip.Error { + // Our tests don't use this so we don't support it. + return tcpip.ErrNotSupported +} + +// ID implements NetworkEndpoint.ID. +func (e *testIPv6Endpoint) ID() *NetworkEndpointID { + return &e.id +} + +// PrefixLen implements NetworkEndpoint.PrefixLen. +func (e *testIPv6Endpoint) PrefixLen() int { + return e.prefixLen +} + +// NICID implements NetworkEndpoint.NICID. +func (e *testIPv6Endpoint) NICID() tcpip.NICID { + return e.nicID +} + +// HandlePacket implements NetworkEndpoint.HandlePacket. +func (*testIPv6Endpoint) HandlePacket(*Route, *PacketBuffer) { +} + +// Close implements NetworkEndpoint.Close. +func (*testIPv6Endpoint) Close() {} + +// NetworkProtocolNumber implements NetworkEndpoint.NetworkProtocolNumber. +func (*testIPv6Endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { + return header.IPv6ProtocolNumber +} + +var _ NetworkProtocol = (*testIPv6Protocol)(nil) + +// An IPv6 NetworkProtocol that supports the bare minimum to make a stack +// believe it supports IPv6. +// +// We use this instead of ipv6.protocol because the ipv6 package depends on +// the stack package which this test lives in, causing a cyclic dependency. +type testIPv6Protocol struct{} + +// Number implements NetworkProtocol.Number. +func (*testIPv6Protocol) Number() tcpip.NetworkProtocolNumber { + return header.IPv6ProtocolNumber +} + +// MinimumPacketSize implements NetworkProtocol.MinimumPacketSize. +func (*testIPv6Protocol) MinimumPacketSize() int { + return header.IPv6MinimumSize +} + +// DefaultPrefixLen implements NetworkProtocol.DefaultPrefixLen. +func (*testIPv6Protocol) DefaultPrefixLen() int { + return header.IPv6AddressSize * 8 +} + +// ParseAddresses implements NetworkProtocol.ParseAddresses. +func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { + h := header.IPv6(v) + return h.SourceAddress(), h.DestinationAddress() +} + +// NewEndpoint implements NetworkProtocol.NewEndpoint. +func (p *testIPv6Protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, _ LinkAddressCache, _ TransportDispatcher, linkEP LinkEndpoint, _ *Stack) (NetworkEndpoint, *tcpip.Error) { + return &testIPv6Endpoint{ + nicID: nicID, + id: NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, + prefixLen: addrWithPrefix.PrefixLen, + linkEP: linkEP, + protocol: p, + }, nil +} + +// SetOption implements NetworkProtocol.SetOption. +func (*testIPv6Protocol) SetOption(interface{}) *tcpip.Error { + return nil +} + +// Option implements NetworkProtocol.Option. +func (*testIPv6Protocol) Option(interface{}) *tcpip.Error { + return nil +} + +// Close implements NetworkProtocol.Close. +func (*testIPv6Protocol) Close() {} + +// Wait implements NetworkProtocol.Wait. +func (*testIPv6Protocol) Wait() {} + +// Parse implements NetworkProtocol.Parse. +func (*testIPv6Protocol) Parse(*PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) { + return 0, false, false +} + +var _ LinkAddressResolver = (*testIPv6Protocol)(nil) + +// LinkAddressProtocol implements LinkAddressResolver. +func (*testIPv6Protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { + return header.IPv6ProtocolNumber +} + +// LinkAddressRequest implements LinkAddressResolver. +func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error { + return nil +} + +// ResolveStaticAddress implements LinkAddressResolver. +func (*testIPv6Protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { + if header.IsV6MulticastAddress(addr) { + return header.EthernetAddressFromMulticastIPv6Address(addr), true + } + return "", false +} + +// Test the race condition where a NIC is removed and an RS timer fires at the +// same time. +func TestRemoveNICWhileHandlingRSTimer(t *testing.T) { + const ( + nicID = 1 + + maxRtrSolicitations = 5 + ) + + e := testLinkEndpoint{} + s := New(Options{ + NetworkProtocols: []NetworkProtocol{&testIPv6Protocol{}}, + NDPConfigs: NDPConfigurations{ + MaxRtrSolicitations: maxRtrSolicitations, + RtrSolicitationInterval: minimumRtrSolicitationInterval, + }, + }) + + if err := s.CreateNIC(nicID, &e); err != nil { + t.Fatalf("s.CreateNIC(%d, _) = %s", nicID, err) + } + + s.mu.Lock() + // Wait for the router solicitation timer to fire and block trying to obtain + // the stack lock when doing link address resolution. + time.Sleep(minimumRtrSolicitationInterval * 2) + if err := s.removeNICLocked(nicID); err != nil { + t.Fatalf("s.removeNICLocked(%d) = %s", nicID, err) + } + s.mu.Unlock() +} + func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { // When the NIC is disabled, the only field that matters is the stats field. // This test is limited to stats counter checks. diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index f5b6ca0b9..d65f8049e 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -113,6 +113,8 @@ func (r *Route) GSOMaxSize() uint32 { // If address resolution is required, ErrNoLinkAddress and a notification channel is // returned for the top level caller to block. Channel is closed once address resolution // is complete (success or not). +// +// The NIC r uses must not be locked. func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { if !r.IsResolutionRequired() { // Nothing to do if there is no cache (which does the resolution on cache miss) or @@ -148,6 +150,8 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) { // IsResolutionRequired returns true if Resolve() must be called to resolve // the link address before the this route can be written to. +// +// The NIC r uses must not be locked. func (r *Route) IsResolutionRequired() bool { return r.ref.isValidForOutgoing() && r.ref.linkCache != nil && r.RemoteLinkAddress == "" } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 294ce8775..648791302 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1017,6 +1017,13 @@ func (s *Stack) RemoveNIC(id tcpip.NICID) *tcpip.Error { s.mu.Lock() defer s.mu.Unlock() + return s.removeNICLocked(id) +} + +// removeNICLocked removes NIC and all related routes from the network stack. +// +// s.mu must be locked. +func (s *Stack) removeNICLocked(id tcpip.NICID) *tcpip.Error { nic, ok := s.nics[id] if !ok { return tcpip.ErrUnknownNICID -- cgit v1.2.3 From a085e562d0592bccc99e9e0380706a8025f70d53 Mon Sep 17 00:00:00 2001 From: Ian Gudger Date: Wed, 10 Jun 2020 23:48:03 -0700 Subject: Add support for SO_REUSEADDR to UDP sockets/endpoints. On UDP sockets, SO_REUSEADDR allows multiple sockets to bind to the same address, but only delivers packets to the most recently bound socket. This differs from the behavior of SO_REUSEADDR on TCP sockets. SO_REUSEADDR for TCP sockets will likely need an almost completely independent implementation. SO_REUSEADDR has some odd interactions with the similar SO_REUSEPORT. These interactions are tested fairly extensively and all but one particularly odd one (that honestly seems like a bug) behave the same on gVisor and Linux. PiperOrigin-RevId: 315844832 --- pkg/tcpip/ports/ports.go | 116 ++++++++++++++------- pkg/tcpip/stack/BUILD | 1 + pkg/tcpip/stack/stack.go | 12 +-- pkg/tcpip/stack/transport_demuxer.go | 95 ++++++++--------- pkg/tcpip/stack/transport_demuxer_test.go | 3 +- pkg/tcpip/stack/transport_test.go | 7 +- pkg/tcpip/transport/icmp/BUILD | 1 + pkg/tcpip/transport/icmp/endpoint.go | 7 +- pkg/tcpip/transport/tcp/accept.go | 4 +- pkg/tcpip/transport/tcp/endpoint.go | 19 +++- pkg/tcpip/transport/udp/endpoint.go | 42 +++++--- pkg/tcpip/transport/udp/forwarder.go | 3 +- .../linux/socket_bind_to_device_sequence.cc | 29 ++++-- test/syscalls/linux/socket_inet_loopback.cc | 9 -- test/syscalls/linux/socket_ip_udp_generic.cc | 6 -- test/syscalls/linux/socket_ip_unbound.cc | 56 +++++++--- test/syscalls/linux/socket_ipv4_udp_unbound.cc | 23 ++-- 17 files changed, 255 insertions(+), 178 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go index b937cb84b..edc29ad27 100644 --- a/pkg/tcpip/ports/ports.go +++ b/pkg/tcpip/ports/ports.go @@ -54,17 +54,27 @@ type Flags struct { LoadBalanced bool } -func (f Flags) bits() reuseFlag { - var rf reuseFlag +// Bits converts the Flags to their bitset form. +func (f Flags) Bits() BitFlags { + var rf BitFlags if f.MostRecent { - rf |= mostRecentFlag + rf |= MostRecentFlag } if f.LoadBalanced { - rf |= loadBalancedFlag + rf |= LoadBalancedFlag } return rf } +// Effective returns the effective behavior of a flag config. +func (f Flags) Effective() Flags { + e := f + if e.LoadBalanced && e.MostRecent { + e.MostRecent = false + } + return e +} + // PortManager manages allocating, reserving and releasing ports. type PortManager struct { mu sync.RWMutex @@ -78,56 +88,88 @@ type PortManager struct { hint uint32 } -type reuseFlag int +// BitFlags is a bitset representation of Flags. +type BitFlags uint32 const ( - mostRecentFlag reuseFlag = 1 << iota - loadBalancedFlag + // MostRecentFlag represents Flags.MostRecent. + MostRecentFlag BitFlags = 1 << iota + + // LoadBalancedFlag represents Flags.LoadBalanced. + LoadBalancedFlag + + // nextFlag is the value that the next added flag will have. + // + // It is used to calculate FlagMask below. It is also the number of + // valid flag states. nextFlag - flagMask = nextFlag - 1 + // FlagMask is a bit mask for BitFlags. + FlagMask = nextFlag - 1 ) -type portNode struct { - // refs stores the count for each possible flag combination. +// ToFlags converts the bitset into a Flags struct. +func (f BitFlags) ToFlags() Flags { + return Flags{ + MostRecent: f&MostRecentFlag != 0, + LoadBalanced: f&LoadBalancedFlag != 0, + } +} + +// FlagCounter counts how many references each flag combination has. +type FlagCounter struct { + // refs stores the count for each possible flag combination, (0 though + // FlagMask). refs [nextFlag]int } -func (p portNode) totalRefs() int { +// AddRef increases the reference count for a specific flag combination. +func (c *FlagCounter) AddRef(flags BitFlags) { + c.refs[flags]++ +} + +// DropRef decreases the reference count for a specific flag combination. +func (c *FlagCounter) DropRef(flags BitFlags) { + c.refs[flags]-- +} + +// TotalRefs calculates the total number of references for all flag +// combinations. +func (c FlagCounter) TotalRefs() int { var total int - for _, r := range p.refs { + for _, r := range c.refs { total += r } return total } -// flagRefs returns the number of references with all specified flags. -func (p portNode) flagRefs(flags reuseFlag) int { +// FlagRefs returns the number of references with all specified flags. +func (c FlagCounter) FlagRefs(flags BitFlags) int { var total int - for i, r := range p.refs { - if reuseFlag(i)&flags == flags { + for i, r := range c.refs { + if BitFlags(i)&flags == flags { total += r } } return total } -// allRefsHave returns if all references have all specified flags. -func (p portNode) allRefsHave(flags reuseFlag) bool { - for i, r := range p.refs { - if reuseFlag(i)&flags == flags && r > 0 { +// AllRefsHave returns if all references have all specified flags. +func (c FlagCounter) AllRefsHave(flags BitFlags) bool { + for i, r := range c.refs { + if BitFlags(i)&flags != flags && r > 0 { return false } } return true } -// intersectionRefs returns the set of flags shared by all references. -func (p portNode) intersectionRefs() reuseFlag { - intersection := flagMask - for i, r := range p.refs { +// IntersectionRefs returns the set of flags shared by all references. +func (c FlagCounter) IntersectionRefs() BitFlags { + intersection := FlagMask + for i, r := range c.refs { if r > 0 { - intersection &= reuseFlag(i) + intersection &= BitFlags(i) } } return intersection @@ -135,26 +177,26 @@ func (p portNode) intersectionRefs() reuseFlag { // deviceNode is never empty. When it has no elements, it is removed from the // map that references it. -type deviceNode map[tcpip.NICID]portNode +type deviceNode map[tcpip.NICID]FlagCounter // isAvailable checks whether binding is possible by device. If not binding to a -// device, check against all portNodes. If binding to a specific device, check +// device, check against all FlagCounters. If binding to a specific device, check // against the unspecified device and the provided device. // // If either of the port reuse flags is enabled on any of the nodes, all nodes // sharing a port must share at least one reuse flag. This matches Linux's // behavior. func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID) bool { - flagBits := flags.bits() + flagBits := flags.Bits() if bindToDevice == 0 { // Trying to binding all devices. if flagBits == 0 { // Can't bind because the (addr,port) is already bound. return false } - intersection := flagMask + intersection := FlagMask for _, p := range d { - i := p.intersectionRefs() + i := p.IntersectionRefs() intersection &= i if intersection&flagBits == 0 { // Can't bind because the (addr,port) was @@ -165,17 +207,17 @@ func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID) bool { return true } - intersection := flagMask + intersection := FlagMask if p, ok := d[0]; ok { - intersection = p.intersectionRefs() + intersection = p.IntersectionRefs() if intersection&flagBits == 0 { return false } } if p, ok := d[bindToDevice]; ok { - i := p.intersectionRefs() + i := p.IntersectionRefs() intersection &= i if intersection&flagBits == 0 { return false @@ -324,7 +366,7 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber if !s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice) { return false } - flagBits := flags.bits() + flagBits := flags.Bits() // Reserve port on all network protocols. for _, network := range networks { @@ -340,7 +382,7 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber m[addr] = d } n := d[bindToDevice] - n.refs[flagBits]++ + n.AddRef(flagBits) d[bindToDevice] = n } @@ -353,7 +395,7 @@ func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transp s.mu.Lock() defer s.mu.Unlock() - flagBits := flags.bits() + flagBits := flags.Bits() for _, network := range networks { desc := portDescriptor{network, transport, port} @@ -368,7 +410,7 @@ func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transp } n.refs[flagBits]-- d[bindToDevice] = n - if n.refs == [nextFlag]int{} { + if n.TotalRefs() == 0 { delete(d, bindToDevice) } if len(d) == 0 { diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index afca925ad..24f52b735 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -89,6 +89,7 @@ go_test( "//pkg/tcpip/link/loopback", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", + "//pkg/tcpip/ports", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/udp", "//pkg/waiter", diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 648791302..a2190341c 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1404,25 +1404,25 @@ func (s *Stack) RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep. // transport dispatcher. Received packets that match the provided id will be // delivered to the given endpoint; specifying a nic is optional, but // nic-specific IDs have precedence over global ones. -func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { - return s.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort, bindToDevice) +func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { + return s.demux.registerEndpoint(netProtos, protocol, id, ep, flags, bindToDevice) } // UnregisterTransportEndpoint removes the endpoint with the given id from the // stack transport dispatcher. -func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { - s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice) +func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { + s.demux.unregisterEndpoint(netProtos, protocol, id, ep, flags, bindToDevice) } // StartTransportEndpointCleanup removes the endpoint with the given id from // the stack transport dispatcher. It also transitions it to the cleanup stage. -func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { +func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { s.mu.Lock() defer s.mu.Unlock() s.cleanupEndpoints[ep] = struct{}{} - s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice) + s.demux.unregisterEndpoint(netProtos, protocol, id, ep, flags, bindToDevice) } // CompleteTransportEndpointCleanup removes the endpoint from the cleanup diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index e09866405..118b449d5 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -15,7 +15,6 @@ package stack import ( - "container/heap" "fmt" "math/rand" @@ -23,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/ports" ) type protocolIDs struct { @@ -43,14 +43,14 @@ type transportEndpoints struct { // unregisterEndpoint unregisters the endpoint with the given id such that it // won't receive any more packets. -func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { +func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { eps.mu.Lock() defer eps.mu.Unlock() epsByNIC, ok := eps.endpoints[id] if !ok { return } - if !epsByNIC.unregisterEndpoint(bindToDevice, ep) { + if !epsByNIC.unregisterEndpoint(bindToDevice, ep, flags) { return } delete(eps.endpoints, id) @@ -204,7 +204,7 @@ func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpoint // registerEndpoint returns true if it succeeds. It fails and returns // false if ep already has an element with the same key. -func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { +func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { epsByNIC.mu.Lock() defer epsByNIC.mu.Unlock() @@ -214,23 +214,22 @@ func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto t demux: d, netProto: netProto, transProto: transProto, - reuse: reusePort, } epsByNIC.endpoints[bindToDevice] = multiPortEp } - return multiPortEp.singleRegisterEndpoint(t, reusePort) + return multiPortEp.singleRegisterEndpoint(t, flags) } // unregisterEndpoint returns true if endpointsByNIC has to be unregistered. -func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint) bool { +func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint, flags ports.Flags) bool { epsByNIC.mu.Lock() defer epsByNIC.mu.Unlock() multiPortEp, ok := epsByNIC.endpoints[bindToDevice] if !ok { return false } - if multiPortEp.unregisterEndpoint(t) { + if multiPortEp.unregisterEndpoint(t, flags) { delete(epsByNIC.endpoints, bindToDevice) } return len(epsByNIC.endpoints) == 0 @@ -279,10 +278,10 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer { // registerEndpoint registers the given endpoint with the dispatcher such that // packets that match the endpoint ID are delivered to it. -func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { +func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { for i, n := range netProtos { - if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort, bindToDevice); err != nil { - d.unregisterEndpoint(netProtos[:i], protocol, id, ep, bindToDevice) + if err := d.singleRegisterEndpoint(n, protocol, id, ep, flags, bindToDevice); err != nil { + d.unregisterEndpoint(netProtos[:i], protocol, id, ep, flags, bindToDevice) return err } } @@ -290,35 +289,6 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum return nil } -type transportEndpointHeap []TransportEndpoint - -var _ heap.Interface = (*transportEndpointHeap)(nil) - -func (h *transportEndpointHeap) Len() int { - return len(*h) -} - -func (h *transportEndpointHeap) Less(i, j int) bool { - return (*h)[i].UniqueID() < (*h)[j].UniqueID() -} - -func (h *transportEndpointHeap) Swap(i, j int) { - (*h)[i], (*h)[j] = (*h)[j], (*h)[i] -} - -func (h *transportEndpointHeap) Push(x interface{}) { - *h = append(*h, x.(TransportEndpoint)) -} - -func (h *transportEndpointHeap) Pop() interface{} { - old := *h - n := len(old) - x := old[n-1] - old[n-1] = nil - *h = old[:n-1] - return x -} - // multiPortEndpoint is a container for TransportEndpoints which are bound to // the same pair of address and port. endpointsArr always has at least one // element. @@ -334,9 +304,10 @@ type multiPortEndpoint struct { netProto tcpip.NetworkProtocolNumber transProto tcpip.TransportProtocolNumber - endpoints transportEndpointHeap - // reuse indicates if more than one endpoint is allowed. - reuse bool + // endpoints stores the transport endpoints in the order in which they + // were bound. This is required for UDP SO_REUSEADDR. + endpoints []TransportEndpoint + flags ports.FlagCounter } func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint { @@ -362,6 +333,10 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32 return mpep.endpoints[0] } + if mpep.flags.IntersectionRefs().ToFlags().Effective().MostRecent { + return mpep.endpoints[len(mpep.endpoints)-1] + } + payload := []byte{ byte(id.LocalPort), byte(id.LocalPort >> 8), @@ -401,40 +376,47 @@ func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, p // singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint // list. The list might be empty already. -func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error { +func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags ports.Flags) *tcpip.Error { ep.mu.Lock() defer ep.mu.Unlock() + bits := flags.Bits() + if len(ep.endpoints) != 0 { // If it was previously bound, we need to check if we can bind again. - if !ep.reuse || !reusePort { + if ep.flags.TotalRefs() > 0 && bits&ep.flags.IntersectionRefs() == 0 { return tcpip.ErrPortInUse } } - heap.Push(&ep.endpoints, t) + ep.endpoints = append(ep.endpoints, t) + ep.flags.AddRef(bits) return nil } // unregisterEndpoint returns true if multiPortEndpoint has to be unregistered. -func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint) bool { +func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint, flags ports.Flags) bool { ep.mu.Lock() defer ep.mu.Unlock() for i, endpoint := range ep.endpoints { if endpoint == t { - heap.Remove(&ep.endpoints, i) + copy(ep.endpoints[i:], ep.endpoints[i+1:]) + ep.endpoints[len(ep.endpoints)-1] = nil + ep.endpoints = ep.endpoints[:len(ep.endpoints)-1] + + ep.flags.DropRef(flags.Bits()) break } } return len(ep.endpoints) == 0 } -func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { +func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { if id.RemotePort != 0 { - // TODO(eyalsoha): Why? - reusePort = false + // SO_REUSEPORT only applies to bound/listening endpoints. + flags.LoadBalanced = false } eps, ok := d.protocol[protocolIDs{netProto, protocol}] @@ -454,15 +436,20 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol eps.endpoints[id] = epsByNIC } - return epsByNIC.registerEndpoint(d, netProto, protocol, ep, reusePort, bindToDevice) + return epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice) } // unregisterEndpoint unregisters the endpoint with the given id such that it // won't receive any more packets. -func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { +func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { + if id.RemotePort != 0 { + // SO_REUSEPORT only applies to bound/listening endpoints. + flags.LoadBalanced = false + } + for _, n := range netProtos { if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok { - eps.unregisterEndpoint(id, ep, bindToDevice) + eps.unregisterEndpoint(id, ep, flags, bindToDevice) } } } diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 67d778137..73dada928 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" @@ -195,7 +196,7 @@ func TestTransportDemuxerRegister(t *testing.T) { if !ok { t.Fatalf("%T does not implement stack.TransportEndpoint", ep) } - if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, false, 0), test.want; got != want { + if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, ports.Flags{}, 0), test.want; got != want { t.Fatalf("s.RegisterTransportEndpoint(...) = %s, want %s", got, want) } }) diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index ad61c09d6..7e8b84867 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" + "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" ) @@ -154,7 +155,7 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // Try to register so that we can start receiving packets. f.ID.RemoteAddress = addr.Addr - err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, false /* reuse */, 0 /* bindToDevice */) + err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */) if err != nil { return err } @@ -199,8 +200,8 @@ func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error { fakeTransNumber, stack.TransportEndpointID{LocalAddress: a.Addr}, f, - false, /* reuse */ - 0, /* bindtoDevice */ + ports.Flags{}, + 0, /* bindtoDevice */ ); err != nil { return err } diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD index 9ce625c17..7e5c79776 100644 --- a/pkg/tcpip/transport/icmp/BUILD +++ b/pkg/tcpip/transport/icmp/BUILD @@ -31,6 +31,7 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/ports", "//pkg/tcpip/stack", "//pkg/tcpip/transport/raw", "//pkg/tcpip/transport/tcp", diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 57e0a069b..8ce294002 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" ) @@ -110,7 +111,7 @@ func (e *endpoint) Close() { e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite switch e.state { case stateBound, stateConnected: - e.stack.UnregisterTransportEndpoint(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, 0 /* bindToDevice */) + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, 0 /* bindToDevice */) } // Close the receive list and drain it. @@ -607,14 +608,14 @@ func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.Networ if id.LocalPort != 0 { // The endpoint already has a local port, just attempt to // register it. - err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindToDevice */) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindToDevice */) return id, err } // We need to find a port for the endpoint. _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { id.LocalPort = p - err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindtodevice */) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindtodevice */) switch err { case nil: return true, nil diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index e6a23c978..ad197e8db 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -27,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" @@ -238,13 +239,14 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i n.mu.Lock() // Register new endpoint so that packets are routed to it. - if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.boundBindToDevice); err != nil { + if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, ports.Flags{LoadBalanced: n.reusePort}, n.boundBindToDevice); err != nil { n.mu.Unlock() n.Close() return nil, err } n.isRegistered = true + n.registeredReusePort = n.reusePort return n, nil } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 19f7bf449..6e4d607da 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -465,6 +465,10 @@ type endpoint struct { // reusePort is set to true if SO_REUSEPORT is enabled. reusePort bool + // registeredReusePort is set if the current endpoint registration was + // done with SO_REUSEPORT enabled. + registeredReusePort bool + // bindToDevice is set to the NIC on which to bind or disabled if 0. bindToDevice tcpip.NICID @@ -1021,8 +1025,9 @@ func (e *endpoint) closeNoShutdownLocked() { // in Listen() when trying to register. if e.EndpointState() == StateListen && e.isPortReserved { if e.isRegistered { - e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice) + e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, ports.Flags{LoadBalanced: e.registeredReusePort}, e.boundBindToDevice) e.isRegistered = false + e.registeredReusePort = false } e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice) @@ -1086,8 +1091,9 @@ func (e *endpoint) cleanupLocked() { e.workerCleanup = false if e.isRegistered { - e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice) + e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, ports.Flags{LoadBalanced: e.registeredReusePort}, e.boundBindToDevice) e.isRegistered = false + e.registeredReusePort = false } if e.isPortReserved { @@ -2088,10 +2094,11 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc if e.ID.LocalPort != 0 { // The endpoint is bound to a port, attempt to register it. - err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, e.reusePort, e.boundBindToDevice) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, ports.Flags{LoadBalanced: e.reusePort}, e.boundBindToDevice) if err != nil { return err } + e.registeredReusePort = e.reusePort } else { // The endpoint doesn't have a local port yet, so try to get // one. Make sure that it isn't one that will result in the same @@ -2123,12 +2130,13 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc id := e.ID id.LocalPort = p - switch e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) { + switch e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, ports.Flags{LoadBalanced: e.reusePort}, e.bindToDevice) { case nil: // Port picking successful. Save the details of // the selected port. e.ID = id e.boundBindToDevice = e.bindToDevice + e.registeredReusePort = e.reusePort return true, nil case tcpip.ErrPortInUse: return false, nil @@ -2326,12 +2334,13 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { } // Register the endpoint. - if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.reusePort, e.boundBindToDevice); err != nil { + if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, ports.Flags{LoadBalanced: e.reusePort}, e.boundBindToDevice); err != nil { return err } e.isRegistered = true e.setEndpointState(StateListen) + e.registeredReusePort = e.reusePort // The channel may be non-nil when we're restoring the endpoint, and it // may be pre-populated with some previously accepted (but not Accepted) diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index c5e3c73ef..df5efbf6a 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -103,7 +103,7 @@ type endpoint struct { multicastAddr tcpip.Address multicastNICID tcpip.NICID multicastLoop bool - reusePort bool + portFlags ports.Flags bindToDevice tcpip.NICID broadcast bool @@ -214,7 +214,7 @@ func (e *endpoint) Close() { switch e.state { case StateBound, StateConnected: - e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice) + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice) e.boundBindToDevice = 0 e.boundPortFlags = ports.Flags{} @@ -558,10 +558,13 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { e.mu.Unlock() case tcpip.ReuseAddressOption: + e.mu.Lock() + e.portFlags.MostRecent = v + e.mu.Unlock() case tcpip.ReusePortOption: e.mu.Lock() - e.reusePort = v + e.portFlags.LoadBalanced = v e.mu.Unlock() case tcpip.V6OnlyOption: @@ -795,11 +798,15 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { return v, nil case tcpip.ReuseAddressOption: - return false, nil + e.mu.RLock() + v := e.portFlags.MostRecent + e.mu.RUnlock() + + return v, nil case tcpip.ReusePortOption: e.mu.RLock() - v := e.reusePort + v := e.portFlags.LoadBalanced e.mu.RUnlock() return v, nil @@ -968,6 +975,11 @@ func (e *endpoint) Disconnect() *tcpip.Error { id stack.TransportEndpointID btd tcpip.NICID ) + + // We change this value below and we need the old value to unregister + // the endpoint. + boundPortFlags := e.boundPortFlags + // Exclude ephemerally bound endpoints. if e.BindNICID != 0 || e.ID.LocalAddress == "" { var err *tcpip.Error @@ -980,16 +992,17 @@ func (e *endpoint) Disconnect() *tcpip.Error { return err } e.state = StateBound + boundPortFlags = e.boundPortFlags } else { if e.ID.LocalPort != 0 { // Release the ephemeral port. - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, boundPortFlags, e.boundBindToDevice) e.boundPortFlags = ports.Flags{} } e.state = StateInitial } - e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice) + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, boundPortFlags, e.boundBindToDevice) e.ID = id e.boundBindToDevice = btd e.route.Release() @@ -1061,6 +1074,8 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } } + oldPortFlags := e.boundPortFlags + id, btd, err := e.registerWithStack(nicID, netProtos, id) if err != nil { return err @@ -1068,7 +1083,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // Remove the old registration. if e.ID.LocalPort != 0 { - e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundBindToDevice) + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, oldPortFlags, e.boundBindToDevice) } e.ID = id @@ -1132,20 +1147,15 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) { if e.ID.LocalPort == 0 { - flags := ports.Flags{ - LoadBalanced: e.reusePort, - // FIXME(b/129164367): Support SO_REUSEADDR. - MostRecent: false, - } - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, flags, e.bindToDevice) + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, e.bindToDevice) if err != nil { return id, e.bindToDevice, err } - e.boundPortFlags = flags id.LocalPort = port } + e.boundPortFlags = e.portFlags - err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, e.bindToDevice) if err != nil { e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, e.bindToDevice) e.boundPortFlags = ports.Flags{} diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index 7abfa0ed2..c67e0ba95 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -73,7 +73,7 @@ func (r *ForwarderRequest) ID() stack.TransportEndpointID { // CreateEndpoint creates a connected UDP endpoint for the session request. func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { ep := newEndpoint(r.stack, r.route.NetProto, queue) - if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.reusePort, ep.bindToDevice); err != nil { + if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.portFlags, ep.bindToDevice); err != nil { ep.Close() return nil, err } @@ -82,6 +82,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, ep.route = r.route.Clone() ep.dstPort = r.id.RemotePort ep.RegisterNICID = r.route.NICID() + ep.boundPortFlags = ep.portFlags ep.state = StateConnected diff --git a/test/syscalls/linux/socket_bind_to_device_sequence.cc b/test/syscalls/linux/socket_bind_to_device_sequence.cc index 1967329ee..4b2cb0606 100644 --- a/test/syscalls/linux/socket_bind_to_device_sequence.cc +++ b/test/syscalls/linux/socket_bind_to_device_sequence.cc @@ -363,8 +363,9 @@ TEST_P(BindToDeviceSequenceTest, BindTwiceWithReuseOnce) { } TEST_P(BindToDeviceSequenceTest, BindWithReuseAddr) { - // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. - SKIP_IF(IsRunningOnGvisor()); + // FIXME(b/158253835): Support SO_REUSEADDR on TCP sockets. + SKIP_IF(IsRunningOnGvisor() && + (GetParam().type & SOCK_STREAM) == SOCK_STREAM); ASSERT_NO_FATAL_FAILURE( BindSocket(/* reusePort */ false, /* reuse_addr */ true)); @@ -405,8 +406,9 @@ TEST_P(BindToDeviceSequenceTest, BindReuseAddrReusePortThenReusePort) { } TEST_P(BindToDeviceSequenceTest, BindReuseAddrReusePortThenReuseAddr) { - // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. - SKIP_IF(IsRunningOnGvisor()); + // FIXME(b/158253835): Support SO_REUSEADDR on TCP sockets. + SKIP_IF(IsRunningOnGvisor() && + (GetParam().type & SOCK_STREAM) == SOCK_STREAM); ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, /* reuse_addr */ true, @@ -434,8 +436,9 @@ TEST_P(BindToDeviceSequenceTest, BindDoubleReuseAddrReusePortThenReusePort) { } TEST_P(BindToDeviceSequenceTest, BindDoubleReuseAddrReusePortThenReuseAddr) { - // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. - SKIP_IF(IsRunningOnGvisor()); + // FIXME(b/158253835): Support SO_REUSEADDR on TCP sockets. + SKIP_IF(IsRunningOnGvisor() && + (GetParam().type & SOCK_STREAM) == SOCK_STREAM); ASSERT_NO_FATAL_FAILURE(BindSocket( /* reuse_port */ true, /* reuse_addr */ true, /* bind_to_device */ 0)); @@ -469,10 +472,20 @@ TEST_P(BindToDeviceSequenceTest, BindReuseAddrThenReuseAddr) { /* bind_to_device */ 0, EADDRINUSE)); } -// This behavior seems like a bug? TEST_P(BindToDeviceSequenceTest, BindReuseAddrThenReuseAddrReusePortThenReuseAddr) { - // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. + // The behavior described in this test seems like a Linux bug. It doesn't + // make any sense and it is unlikely that any applications rely on it. + // + // Both SO_REUSEADDR and SO_REUSEPORT allow binding multiple UDP sockets to + // the same address and deliver each packet to exactly one of the bound + // sockets. If both are enabled, one of the strategies is selected to route + // packets. The strategy is selected dynamically based on the settings of the + // currently bound sockets. Usually, the strategy is selected based on the + // common setting (SO_REUSEADDR or SO_REUSEPORT) amongst the sockets, but for + // some reason, Linux allows binding sets of sockets with no overlapping + // settings in some situations. In this case, it is not obvious which strategy + // would be selected as the configured setting is a contradiction. SKIP_IF(IsRunningOnGvisor()); ASSERT_NO_FATAL_FAILURE(BindSocket( diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index fa890ec98..a914d9a12 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -1900,9 +1900,6 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReserved) { TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReservedReuseAddr) { auto const& param = GetParam(); - // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. - SKIP_IF(IsRunningOnGvisor() && param.type == SOCK_DGRAM); - // Bind the v6 loopback on a dual stack socket. TestAddress const& test_addr = V6Loopback(); sockaddr_storage bound_addr = test_addr.addr; @@ -2091,9 +2088,6 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedEphemeralPortReservedResueAddr) { auto const& param = GetParam(); - // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. - SKIP_IF(IsRunningOnGvisor() && param.type == SOCK_DGRAM); - // Bind the v4 loopback on a dual stack socket. TestAddress const& test_addr = V4MappedLoopback(); sockaddr_storage bound_addr = test_addr.addr; @@ -2283,9 +2277,6 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReserved) { TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReservedReuseAddr) { auto const& param = GetParam(); - // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. - SKIP_IF(IsRunningOnGvisor() && param.type == SOCK_DGRAM); - // Bind the v4 loopback on a v4 socket. TestAddress const& test_addr = V4Loopback(); sockaddr_storage bound_addr = test_addr.addr; diff --git a/test/syscalls/linux/socket_ip_udp_generic.cc b/test/syscalls/linux/socket_ip_udp_generic.cc index 1c533fdf2..edb86aded 100644 --- a/test/syscalls/linux/socket_ip_udp_generic.cc +++ b/test/syscalls/linux/socket_ip_udp_generic.cc @@ -225,9 +225,6 @@ TEST_P(UDPSocketPairTest, ReuseAddrDefault) { TEST_P(UDPSocketPairTest, SetReuseAddr) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. - SKIP_IF(IsRunningOnGvisor()); - ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); @@ -292,9 +289,6 @@ TEST_P(UDPSocketPairTest, SetReusePort) { TEST_P(UDPSocketPairTest, SetReuseAddrReusePort) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. - SKIP_IF(IsRunningOnGvisor()); - ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, sizeof(kSockOptOn)), SyscallSucceeds()); diff --git a/test/syscalls/linux/socket_ip_unbound.cc b/test/syscalls/linux/socket_ip_unbound.cc index af8dda19c..5536ab53a 100644 --- a/test/syscalls/linux/socket_ip_unbound.cc +++ b/test/syscalls/linux/socket_ip_unbound.cc @@ -157,7 +157,7 @@ TEST_P(IPUnboundSocketTest, TOSDefault) { int get = -1; socklen_t get_sz = sizeof(get); constexpr int kDefaultTOS = 0; - EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), + ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_sz, sizeof(get)); EXPECT_EQ(get, kDefaultTOS); @@ -173,7 +173,7 @@ TEST_P(IPUnboundSocketTest, SetTOS) { int get = -1; socklen_t get_sz = sizeof(get); - EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), + ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_sz, sizeof(get)); EXPECT_EQ(get, set); @@ -188,7 +188,7 @@ TEST_P(IPUnboundSocketTest, ZeroTOS) { SyscallSucceedsWithValue(0)); int get = -1; socklen_t get_sz = sizeof(get); - EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), + ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_sz, sizeof(get)); EXPECT_EQ(get, set); @@ -210,7 +210,7 @@ TEST_P(IPUnboundSocketTest, InvalidLargeTOS) { } int get = -1; socklen_t get_sz = sizeof(get); - EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), + ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_sz, sizeof(get)); EXPECT_EQ(get, kDefaultTOS); @@ -229,7 +229,7 @@ TEST_P(IPUnboundSocketTest, CheckSkipECN) { } int get = -1; socklen_t get_sz = sizeof(get); - EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), + ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_sz, sizeof(get)); EXPECT_EQ(get, expect); @@ -249,7 +249,7 @@ TEST_P(IPUnboundSocketTest, ZeroTOSOptionSize) { } int get = -1; socklen_t get_sz = 0; - EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), + ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_sz, 0); EXPECT_EQ(get, -1); @@ -276,7 +276,7 @@ TEST_P(IPUnboundSocketTest, SmallTOSOptionSize) { } uint get = -1; socklen_t get_sz = i; - EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), + ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_sz, expect_sz); // Account for partial copies by getsockopt, retrieve the lower @@ -297,7 +297,7 @@ TEST_P(IPUnboundSocketTest, LargeTOSOptionSize) { // We expect the system call handler to only copy atmost sizeof(int) bytes // as asserted by the check below. Hence, we do not expect the copy to // overflow in getsockopt. - EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), + ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_sz, sizeof(int)); EXPECT_EQ(get, set); @@ -325,7 +325,7 @@ TEST_P(IPUnboundSocketTest, NegativeTOS) { } int get = -1; socklen_t get_sz = sizeof(get); - EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), + ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_sz, sizeof(get)); EXPECT_EQ(get, expect); @@ -338,20 +338,20 @@ TEST_P(IPUnboundSocketTest, InvalidNegativeTOS) { TOSOption t = GetTOSOption(GetParam().domain); int expect; if (GetParam().domain == AF_INET) { - EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz), + ASSERT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz), SyscallSucceedsWithValue(0)); expect = static_cast(set); if (GetParam().protocol == IPPROTO_TCP) { expect &= ~INET_ECN_MASK; } } else { - EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz), + ASSERT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz), SyscallFailsWithErrno(EINVAL)); expect = 0; } int get = 0; socklen_t get_sz = sizeof(get); - EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), + ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_sz, sizeof(get)); EXPECT_EQ(get, expect); @@ -422,6 +422,38 @@ TEST_P(IPUnboundSocketTest, InsufficientBufferTOS) { EXPECT_THAT(sendmsg(socket->get(), &msg, 0), SyscallFailsWithErrno(EINVAL)); } +TEST_P(IPUnboundSocketTest, ReuseAddrDefault) { + auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + // FIXME(b/76031995): gVisor TCP sockets default to SO_REUSEADDR enabled. + SKIP_IF(IsRunningOnGvisor() && + (GetParam().type & SOCK_STREAM) == SOCK_STREAM); + + int get = -1; + socklen_t get_sz = sizeof(get); + ASSERT_THAT( + getsockopt(socket->get(), SOL_SOCKET, SO_REUSEADDR, &get, &get_sz), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get, kSockOptOff); + EXPECT_EQ(get_sz, sizeof(get)); +} + +TEST_P(IPUnboundSocketTest, SetReuseAddr) { + auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + ASSERT_THAT(setsockopt(socket->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceedsWithValue(0)); + + int get = -1; + socklen_t get_sz = sizeof(get); + ASSERT_THAT( + getsockopt(socket->get(), SOL_SOCKET, SO_REUSEADDR, &get, &get_sz), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get, kSockOptOn); + EXPECT_EQ(get_sz, sizeof(get)); +} + INSTANTIATE_TEST_SUITE_P( IPUnboundSockets, IPUnboundSocketTest, ::testing::ValuesIn(VecCat(VecCat( diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.cc b/test/syscalls/linux/socket_ipv4_udp_unbound.cc index 1294d9050..4db9553e1 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound.cc @@ -1672,10 +1672,10 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToBcastThenSend) { } // Check that SO_REUSEADDR always delivers to the most recently bound socket. -TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrDistribution) { - // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. - SKIP_IF(IsRunningOnGvisor()); - +// +// FIXME(gvisor.dev/issue/873): Endpoint order is not restored correctly. Enable +// random and co-op save (below) once that is fixed. +TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrDistribution_NoRandomSave) { std::vector> sockets; sockets.emplace_back(ASSERT_NO_ERRNO_AND_VALUE(NewSocket())); @@ -1696,6 +1696,9 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrDistribution) { constexpr int kMessageSize = 200; + // FIXME(gvisor.dev/issue/873): Endpoint order is not restored correctly. + const DisableSave ds; + for (int i = 0; i < 10; i++) { // Add a new receiver. sockets.emplace_back(ASSERT_NO_ERRNO_AND_VALUE(NewSocket())); @@ -1836,9 +1839,6 @@ TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConvertibleToReusePort) { } TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConvertibleToReuseAddr) { - // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. - SKIP_IF(IsRunningOnGvisor()); - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto socket3 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); @@ -1880,9 +1880,6 @@ TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConvertibleToReuseAddr) { } TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConversionReversable1) { - // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. - SKIP_IF(IsRunningOnGvisor()); - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto socket3 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); @@ -1927,9 +1924,6 @@ TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConversionReversable1) { } TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConversionReversable2) { - // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. - SKIP_IF(IsRunningOnGvisor()); - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto socket3 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); @@ -2020,9 +2014,6 @@ TEST_P(IPv4UDPUnboundSocketTest, BindDoubleReuseAddrReusePortThenReusePort) { } TEST_P(IPv4UDPUnboundSocketTest, BindDoubleReuseAddrReusePortThenReuseAddr) { - // FIXME(b/129164367): Support SO_REUSEADDR on UDP sockets. - SKIP_IF(IsRunningOnGvisor()); - auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto socket3 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); -- cgit v1.2.3 From 4c0a8bdaf5e21ac85a4275e9008e5cd4294f45f3 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Thu, 11 Jun 2020 16:08:06 -0700 Subject: Do not use tentative addresses for routes Tentative addresses should not be used when finding a route. This change fixes a bug where a tentative address may have been used. Test: stack_test.TestDADResolve PiperOrigin-RevId: 315997624 --- pkg/tcpip/stack/ndp_test.go | 75 ++++++++++++++++++++++++++++++++++++--------- pkg/tcpip/stack/nic.go | 52 ++++++++++++++++++------------- 2 files changed, 91 insertions(+), 36 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 58f1ebf60..ae326b3ab 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -421,28 +421,52 @@ func TestDADResolve(t *testing.T) { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } + // We add a default route so the call to FindRoute below will succeed + // once we have an assigned address. + s.SetRouteTable([]tcpip.Route{{ + Destination: header.IPv6EmptySubnet, + Gateway: addr3, + NIC: nicID, + }}) + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) } // Address should not be considered bound to the NIC yet (DAD ongoing). - addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { + if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) + } else if want := (tcpip.AddressWithPrefix{}); addr != want { t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) } // Make sure the address does not resolve before the resolution time has // passed. time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - defaultAsyncEventTimeout) - addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) + if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { + t.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) + } else if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + } + // Should not get a route even if we specify the local address as the + // tentative address. + { + r, err := s.FindRoute(nicID, "", addr2, header.IPv6ProtocolNumber, false) + if err != tcpip.ErrNoRoute { + t.Errorf("got FindRoute(%d, '', %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute) + } + r.Release() } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + { + r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false) + if err != tcpip.ErrNoRoute { + t.Errorf("got FindRoute(%d, %s, %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr1, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute) + } + r.Release() + } + + if t.Failed() { + t.FailNow() } // Wait for DAD to resolve. @@ -454,12 +478,33 @@ func TestDADResolve(t *testing.T) { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } } - addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) + if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { + t.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) + } else if addr.Address != addr1 { + t.Errorf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, addr, addr1) } - if addr.Address != addr1 { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, addr, addr1) + // Should get a route using the address now that it is resolved. + { + r, err := s.FindRoute(nicID, "", addr2, header.IPv6ProtocolNumber, false) + if err != nil { + t.Errorf("got FindRoute(%d, '', %s, %d, false): %s", nicID, addr2, header.IPv6ProtocolNumber, err) + } else if r.LocalAddress != addr1 { + t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, addr1) + } + r.Release() + } + { + r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false) + if err != nil { + t.Errorf("got FindRoute(%d, %s, %s, %d, false): %s", nicID, addr1, addr2, header.IPv6ProtocolNumber, err) + } else if r.LocalAddress != addr1 { + t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, addr1) + } + r.Release() + } + + if t.Failed() { + t.FailNow() } // Should not have sent any more NS messages. diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 644c0d437..afb7dfeaf 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -610,18 +610,14 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok { // An endpoint with this id exists, check if it can be used and return it. - switch ref.getKind() { - case permanentExpired: - if !spoofingOrPromiscuous { - n.mu.RUnlock() - return nil - } - fallthrough - case temporary, permanent: - if ref.tryIncRef() { - n.mu.RUnlock() - return ref - } + if !ref.isAssignedRLocked(spoofingOrPromiscuous) { + n.mu.RUnlock() + return nil + } + + if ref.tryIncRef() { + n.mu.RUnlock() + return ref } } @@ -689,7 +685,6 @@ func (n *NIC) getRefOrCreateTempLocked(protocol tcpip.NetworkProtocolNumber, add PrefixLen: netProto.DefaultPrefixLen(), }, }, peb, temporary, static, false) - return ref } @@ -1660,8 +1655,8 @@ func (r *referencedNetworkEndpoint) setKind(kind networkEndpointKind) { } // isValidForOutgoing returns true if the endpoint can be used to send out a -// packet. It requires the endpoint to not be marked expired (i.e., its address -// has been removed), or the NIC to be in spoofing mode. +// packet. It requires the endpoint to not be marked expired (i.e., its address) +// has been removed) unless the NIC is in spoofing mode, or temporary. func (r *referencedNetworkEndpoint) isValidForOutgoing() bool { r.nic.mu.RLock() defer r.nic.mu.RUnlock() @@ -1669,13 +1664,28 @@ func (r *referencedNetworkEndpoint) isValidForOutgoing() bool { return r.isValidForOutgoingRLocked() } -// isValidForOutgoingRLocked returns true if the endpoint can be used to send -// out a packet. It requires the endpoint to not be marked expired (i.e., its -// address has been removed), or the NIC to be in spoofing mode. -// -// r's NIC must be read locked. +// isValidForOutgoingRLocked is the same as isValidForOutgoing but requires +// r.nic.mu to be read locked. func (r *referencedNetworkEndpoint) isValidForOutgoingRLocked() bool { - return r.nic.mu.enabled && (r.getKind() != permanentExpired || r.nic.mu.spoofing) + if !r.nic.mu.enabled { + return false + } + + return r.isAssignedRLocked(r.nic.mu.spoofing) +} + +// isAssignedRLocked returns true if r is considered to be assigned to the NIC. +// +// r.nic.mu must be read locked. +func (r *referencedNetworkEndpoint) isAssignedRLocked(spoofingOrPromiscuous bool) bool { + switch r.getKind() { + case permanentTentative: + return false + case permanentExpired: + return spoofingOrPromiscuous + default: + return true + } } // expireLocked decrements the reference count and marks the permanent endpoint -- cgit v1.2.3 From 57286eb642b9becc566f8e9c1dcbe24619f7772b Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Wed, 17 Jun 2020 17:21:19 -0700 Subject: Increase timeouts for NDP tests ... to help reduce flakes. When waiting for an event to occur, use a timeout of 10s. When waiting for an event to not occur, use a timeout of 1s. Test: Ran test locally w/ run count of 1000 with and without gotsan. PiperOrigin-RevId: 316998128 --- pkg/tcpip/stack/ndp_test.go | 157 ++++++++++++++++++++++-------------------- pkg/tcpip/stack/stack_test.go | 2 +- 2 files changed, 84 insertions(+), 75 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index ae326b3ab..6f86abc98 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -36,15 +36,24 @@ import ( ) const ( - addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") - linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07") - linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08") - linkAddr4 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x09") - defaultTimeout = 100 * time.Millisecond - defaultAsyncEventTimeout = time.Second + addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") + linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") + linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07") + linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08") + linkAddr4 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x09") + + // Extra time to use when waiting for an async event to occur. + defaultAsyncPositiveEventTimeout = 10 * time.Second + + // Extra time to use when waiting for an async event to not occur. + // + // Since a negative check is used to make sure an event did not happen, it is + // okay to use a smaller timeout compared to the positive case since execution + // stall in regards to the monotonic clock will not affect the expected + // outcome. + defaultAsyncNegativeEventTimeout = time.Second ) var ( @@ -442,7 +451,7 @@ func TestDADResolve(t *testing.T) { // Make sure the address does not resolve before the resolution time has // passed. - time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - defaultAsyncEventTimeout) + time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - defaultAsyncNegativeEventTimeout) if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { t.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) } else if want := (tcpip.AddressWithPrefix{}); addr != want { @@ -471,7 +480,7 @@ func TestDADResolve(t *testing.T) { // Wait for DAD to resolve. select { - case <-time.After(2 * defaultAsyncEventTimeout): + case <-time.After(defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr1, true, nil); diff != "" { @@ -1169,7 +1178,7 @@ func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { select { case <-ndpDisp.routerC: t.Fatal("should not have received any router events") - case <-time.After(lifetimeSeconds*time.Second + defaultTimeout): + case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout): } } @@ -1252,7 +1261,7 @@ func TestRouterDiscovery(t *testing.T) { // Wait for the normal lifetime plus an extra bit for the // router to get invalidated. If we don't get an invalidation // event after this time, then something is wrong. - expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncEventTimeout) + expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout) // Rx an RA from lladdr2 with huge lifetime. e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) @@ -1269,7 +1278,7 @@ func TestRouterDiscovery(t *testing.T) { // Wait for the normal lifetime plus an extra bit for the // router to get invalidated. If we don't get an invalidation // event after this time, then something is wrong. - expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncEventTimeout) + expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout) } // TestRouterDiscoveryMaxRouters tests that only @@ -1418,7 +1427,7 @@ func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { select { case <-ndpDisp.prefixC: t.Fatal("should not have received any prefix events") - case <-time.After(lifetimeSeconds*time.Second + defaultTimeout): + case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout): } } @@ -1500,7 +1509,7 @@ func TestPrefixDiscovery(t *testing.T) { if diff := checkPrefixEvent(e, subnet2, false); diff != "" { t.Errorf("prefix event mismatch (-want +got):\n%s", diff) } - case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncEventTimeout): + case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for prefix discovery event") } @@ -1565,7 +1574,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") - case <-time.After(testInfiniteLifetime + defaultTimeout): + case <-time.After(testInfiniteLifetime + defaultAsyncNegativeEventTimeout): } // Receive an RA with finite lifetime. @@ -1590,7 +1599,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") - case <-time.After(testInfiniteLifetime + defaultTimeout): + case <-time.After(testInfiniteLifetime + defaultAsyncNegativeEventTimeout): } // Receive an RA with a prefix with a lifetime value greater than the @@ -1599,7 +1608,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") - case <-time.After((testInfiniteLifetimeSeconds+1)*time.Second + defaultTimeout): + case <-time.After((testInfiniteLifetimeSeconds+1)*time.Second + defaultAsyncNegativeEventTimeout): } // Receive an RA with 0 lifetime. @@ -1835,7 +1844,7 @@ func TestAutoGenAddr(t *testing.T) { if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(newMinVLDuration + defaultAsyncEventTimeout): + case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { @@ -1962,7 +1971,7 @@ func TestAutoGenTempAddr(t *testing.T) { if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultAsyncEventTimeout): + case <-time.After(defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } } @@ -1975,7 +1984,7 @@ func TestAutoGenTempAddr(t *testing.T) { if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } - case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncEventTimeout): + case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD event") } } @@ -2081,10 +2090,10 @@ func TestAutoGenTempAddr(t *testing.T) { if diff := checkAutoGenAddrEvent(e, nextAddr, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultTimeout): + case <-time.After(defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } - case <-time.After(newMinVLDuration + defaultTimeout): + case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" { @@ -2180,7 +2189,7 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) { if diff := checkDADEvent(e, nicID, llAddr1, true, nil); diff != "" { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } - case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncEventTimeout): + case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD event") } @@ -2188,7 +2197,7 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) { select { case e := <-ndpDisp.autoGenAddrC: t.Errorf("got unxpected auto gen addr event = %+v", e) - case <-time.After(defaultAsyncEventTimeout): + case <-time.After(defaultAsyncNegativeEventTimeout): } }) } @@ -2265,7 +2274,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } - case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout): + case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD event") } select { @@ -2273,7 +2282,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { if diff := checkAutoGenAddrEvent(e, tempAddr, newAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultAsyncEventTimeout): + case <-time.After(defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } } @@ -2363,13 +2372,13 @@ func TestAutoGenTempAddrRegen(t *testing.T) { } // Wait for regeneration - expectAutoGenAddrEventAsync(tempAddr2, newAddr, regenAfter+defaultAsyncEventTimeout) + expectAutoGenAddrEventAsync(tempAddr2, newAddr, regenAfter+defaultAsyncPositiveEventTimeout) if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2}, nil); mismatch != "" { t.Fatal(mismatch) } // Wait for regeneration - expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncEventTimeout) + expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout) if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2, tempAddr3}, nil); mismatch != "" { t.Fatal(mismatch) } @@ -2398,7 +2407,7 @@ func TestAutoGenTempAddrRegen(t *testing.T) { if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultAsyncEventTimeout): + case <-time.After(defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } } else if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff == "" { @@ -2407,12 +2416,12 @@ func TestAutoGenTempAddrRegen(t *testing.T) { select { case e := <-ndpDisp.autoGenAddrC: t.Fatalf("unexpectedly got an auto-generated event = %+v", e) - case <-time.After(defaultTimeout): + case <-time.After(defaultAsyncNegativeEventTimeout): } } else { t.Fatalf("got unexpected auto-generated event = %+v", e) } - case <-time.After(invalidateAfter + defaultAsyncEventTimeout): + case <-time.After(invalidateAfter + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } @@ -2517,7 +2526,7 @@ func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) { select { case e := <-ndpDisp.autoGenAddrC: t.Fatalf("unexpected auto gen addr event = %+v", e) - case <-time.After(regenAfter + defaultAsyncEventTimeout): + case <-time.After(regenAfter + defaultAsyncNegativeEventTimeout): } // Prefer the prefix again. @@ -2546,7 +2555,7 @@ func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) { select { case e := <-ndpDisp.autoGenAddrC: t.Fatalf("unexpected auto gen addr event = %+v", e) - case <-time.After(regenAfter + defaultAsyncEventTimeout): + case <-time.After(regenAfter + defaultAsyncNegativeEventTimeout): } // Set the maximum lifetimes for temporary addresses such that on the next @@ -2556,14 +2565,14 @@ func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) { // addresses + the time that has already passed since the last address was // generated so that the regeneration timer is needed to generate the next // address. - newLifetimes := newMinVLDuration + regenAfter + defaultAsyncEventTimeout + newLifetimes := newMinVLDuration + regenAfter + defaultAsyncNegativeEventTimeout ndpConfigs.MaxTempAddrValidLifetime = newLifetimes ndpConfigs.MaxTempAddrPreferredLifetime = newLifetimes if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil { t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err) } e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) - expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncEventTimeout) + expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout) } // TestMixedSLAACAddrConflictRegen tests SLAAC address regeneration in response @@ -2711,7 +2720,7 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultAsyncEventTimeout): + case <-time.After(defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } } @@ -2724,7 +2733,7 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } - case <-time.After(dupAddrTransmits*retransmitTimer + defaultAsyncEventTimeout): + case <-time.After(dupAddrTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD event") } } @@ -3070,7 +3079,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { expectPrimaryAddr(addr1) // Wait for addr of prefix1 to be deprecated. - expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncEventTimeout) + expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should not have %s in the list of addresses", addr1) } @@ -3110,7 +3119,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { expectPrimaryAddr(addr1) // Wait for addr of prefix1 to be deprecated. - expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncEventTimeout) + expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should not have %s in the list of addresses", addr1) } @@ -3124,7 +3133,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { } // Wait for addr of prefix1 to be invalidated. - expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncEventTimeout) + expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncPositiveEventTimeout) if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should not have %s in the list of addresses", addr1) } @@ -3156,7 +3165,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultAsyncEventTimeout): + case <-time.After(defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" { @@ -3165,12 +3174,12 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly got an auto-generated event") - case <-time.After(defaultTimeout): + case <-time.After(defaultAsyncNegativeEventTimeout): } } else { t.Fatalf("got unexpected auto-generated event") } - case <-time.After(newMinVLDuration + defaultAsyncEventTimeout): + case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { @@ -3295,7 +3304,7 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(minVLSeconds*time.Second + defaultAsyncEventTimeout): + case <-time.After(minVLSeconds*time.Second + defaultAsyncPositiveEventTimeout): t.Fatal("timeout waiting for addr auto gen event") } }) @@ -3439,7 +3448,7 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly received an auto gen addr event") - case <-time.After(time.Duration(test.evl)*time.Second - defaultAsyncEventTimeout): + case <-time.After(time.Duration(test.evl)*time.Second - defaultAsyncNegativeEventTimeout): } // Wait for the invalidation event. @@ -3448,7 +3457,7 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(2 * defaultAsyncEventTimeout): + case <-time.After(defaultAsyncPositiveEventTimeout): t.Fatal("timeout waiting for addr auto gen event") } }) @@ -3509,7 +3518,7 @@ func TestAutoGenAddrRemoval(t *testing.T) { select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly received an auto gen addr event") - case <-time.After(lifetimeSeconds*time.Second + defaultTimeout): + case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout): } } @@ -3672,7 +3681,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly received an auto gen addr event") - case <-time.After(lifetimeSeconds*time.Second + defaultTimeout): + case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout): } if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) { t.Fatalf("Should have %s in the list of addresses", addr1) @@ -3770,7 +3779,7 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultAsyncEventTimeout): + case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { @@ -3837,7 +3846,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultAsyncEventTimeout): + case <-time.After(defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } } @@ -3863,7 +3872,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { if diff := checkDADEvent(e, nicID, addr, resolved, nil); diff != "" { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } - case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout): + case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD event") } } @@ -4030,7 +4039,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { select { case e := <-ndpDisp.autoGenAddrC: t.Fatalf("unexpectedly got an auto-generated address event = %+v", e) - case <-time.After(defaultAsyncEventTimeout): + case <-time.After(defaultAsyncNegativeEventTimeout): } }) } @@ -4149,7 +4158,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { select { case e := <-ndpDisp.autoGenAddrC: t.Fatalf("unexpectedly got an auto-generated address event = %+v", e) - case <-time.After(defaultAsyncEventTimeout): + case <-time.After(defaultAsyncNegativeEventTimeout): } }) } @@ -4251,7 +4260,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } - case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout): + case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD event") } @@ -4277,7 +4286,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultAsyncEventTimeout): + case <-time.After(defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for invalidated auto gen addr event after deprecation") } } else { @@ -4285,7 +4294,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } } - case <-time.After(lifetimeSeconds*time.Second - failureTimer - dadTransmits*retransmitTimer + defaultAsyncEventTimeout): + case <-time.After(lifetimeSeconds*time.Second - failureTimer - dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for auto gen addr event") } } @@ -4869,7 +4878,7 @@ func TestCleanupNDPState(t *testing.T) { // Should not get any more events (invalidation timers should have been // cancelled when the NDP state was cleaned up). - time.Sleep(lifetimeSeconds*time.Second + defaultTimeout) + time.Sleep(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout) select { case <-ndpDisp.routerC: t.Error("unexpected router event") @@ -5172,24 +5181,24 @@ func TestRouterSolicitation(t *testing.T) { // Make sure each RS is sent at the right time. remaining := test.maxRtrSolicit if remaining > 0 { - waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultAsyncEventTimeout) + waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultAsyncPositiveEventTimeout) remaining-- } for ; remaining > 0; remaining-- { - if test.effectiveRtrSolicitInt > defaultAsyncEventTimeout { - waitForNothing(test.effectiveRtrSolicitInt - defaultAsyncEventTimeout) - waitForPkt(2 * defaultAsyncEventTimeout) + if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout { + waitForNothing(test.effectiveRtrSolicitInt - defaultAsyncNegativeEventTimeout) + waitForPkt(defaultAsyncPositiveEventTimeout) } else { - waitForPkt(test.effectiveRtrSolicitInt * defaultAsyncEventTimeout) + waitForPkt(test.effectiveRtrSolicitInt + defaultAsyncPositiveEventTimeout) } } // Make sure no more RS. if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay { - waitForNothing(test.effectiveRtrSolicitInt + defaultAsyncEventTimeout) + waitForNothing(test.effectiveRtrSolicitInt + defaultAsyncNegativeEventTimeout) } else { - waitForNothing(test.effectiveMaxRtrSolicitDelay + defaultAsyncEventTimeout) + waitForNothing(test.effectiveMaxRtrSolicitDelay + defaultAsyncNegativeEventTimeout) } // Make sure the counter got properly @@ -5305,11 +5314,11 @@ func TestStopStartSolicitingRouters(t *testing.T) { // Stop soliciting routers. test.stopFn(t, s, true /* first */) - ctx, cancel := context.WithTimeout(context.Background(), delay+defaultAsyncEventTimeout) + ctx, cancel := context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout) defer cancel() if _, ok := e.ReadContext(ctx); ok { // A single RS may have been sent before solicitations were stopped. - ctx, cancel := context.WithTimeout(context.Background(), interval+defaultAsyncEventTimeout) + ctx, cancel := context.WithTimeout(context.Background(), interval+defaultAsyncNegativeEventTimeout) defer cancel() if _, ok = e.ReadContext(ctx); ok { t.Fatal("should not have sent more than one RS message") @@ -5319,7 +5328,7 @@ func TestStopStartSolicitingRouters(t *testing.T) { // Stopping router solicitations after it has already been stopped should // do nothing. test.stopFn(t, s, false /* first */) - ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncEventTimeout) + ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout) defer cancel() if _, ok := e.ReadContext(ctx); ok { t.Fatal("unexpectedly got a packet after router solicitation has been stopepd") @@ -5332,10 +5341,10 @@ func TestStopStartSolicitingRouters(t *testing.T) { // Start soliciting routers. test.startFn(t, s) - waitForPkt(delay + defaultAsyncEventTimeout) - waitForPkt(interval + defaultAsyncEventTimeout) - waitForPkt(interval + defaultAsyncEventTimeout) - ctx, cancel = context.WithTimeout(context.Background(), interval+defaultAsyncEventTimeout) + waitForPkt(delay + defaultAsyncPositiveEventTimeout) + waitForPkt(interval + defaultAsyncPositiveEventTimeout) + waitForPkt(interval + defaultAsyncPositiveEventTimeout) + ctx, cancel = context.WithTimeout(context.Background(), interval+defaultAsyncNegativeEventTimeout) defer cancel() if _, ok := e.ReadContext(ctx); ok { t.Fatal("unexpectedly got an extra packet after sending out the expected RSs") @@ -5344,7 +5353,7 @@ func TestStopStartSolicitingRouters(t *testing.T) { // Starting router solicitations after it has already completed should do // nothing. test.startFn(t, s) - ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncEventTimeout) + ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout) defer cancel() if _, ok := e.ReadContext(ctx); ok { t.Fatal("unexpectedly got a packet after finishing router solicitations") diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index ffef9bc2c..5aacbf53e 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -3305,7 +3305,7 @@ func TestDoDADWhenNICEnabled(t *testing.T) { // Wait for DAD to resolve. select { - case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout): + case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr.AddressWithPrefix.Address, true, nil); diff != "" { -- cgit v1.2.3 From 09b2fca40c61f9ec8d6745f422f6f45b399e8f94 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Thu, 18 Jun 2020 00:08:55 -0700 Subject: Cleanup tcp.timer and tcpip.Route When a tcp.timer or tcpip.Route is no longer used, clean up its resources so that unused memory may be released. PiperOrigin-RevId: 317046582 --- pkg/tcpip/stack/stack.go | 6 ++--- pkg/tcpip/transport/tcp/BUILD | 10 +++++++- pkg/tcpip/transport/tcp/timer.go | 1 + pkg/tcpip/transport/tcp/timer_test.go | 47 +++++++++++++++++++++++++++++++++++ 4 files changed, 60 insertions(+), 4 deletions(-) create mode 100644 pkg/tcpip/transport/tcp/timer_test.go (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index a2190341c..51abe32a7 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1033,14 +1033,14 @@ func (s *Stack) removeNICLocked(id tcpip.NICID) *tcpip.Error { // Remove routes in-place. n tracks the number of routes written. n := 0 for i, r := range s.routeTable { + s.routeTable[i] = tcpip.Route{} if r.NIC != id { // Keep this route. - if i > n { - s.routeTable[n] = r - } + s.routeTable[n] = r n++ } } + s.routeTable = s.routeTable[:n] return nic.remove() diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index e26f01fae..6baeda8e4 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -76,7 +76,7 @@ go_library( ) go_test( - name = "tcp_test", + name = "tcp_x_test", size = "medium", srcs = [ "dual_stack_test.go", @@ -115,3 +115,11 @@ go_test( "//pkg/tcpip/seqnum", ], ) + +go_test( + name = "tcp_test", + size = "small", + srcs = ["timer_test.go"], + library = ":tcp", + deps = ["//pkg/sleep"], +) diff --git a/pkg/tcpip/transport/tcp/timer.go b/pkg/tcpip/transport/tcp/timer.go index c70525f27..7981d469b 100644 --- a/pkg/tcpip/transport/tcp/timer.go +++ b/pkg/tcpip/transport/tcp/timer.go @@ -85,6 +85,7 @@ func (t *timer) init(w *sleep.Waker) { // cleanup frees all resources associated with the timer. func (t *timer) cleanup() { t.timer.Stop() + *t = timer{} } // checkExpiration checks if the given timer has actually expired, it should be diff --git a/pkg/tcpip/transport/tcp/timer_test.go b/pkg/tcpip/transport/tcp/timer_test.go new file mode 100644 index 000000000..dbd6dff54 --- /dev/null +++ b/pkg/tcpip/transport/tcp/timer_test.go @@ -0,0 +1,47 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "testing" + "time" + + "gvisor.dev/gvisor/pkg/sleep" +) + +func TestCleanup(t *testing.T) { + const ( + timerDurationSeconds = 2 + isAssertedTimeoutSeconds = timerDurationSeconds + 1 + ) + + tmr := timer{} + w := sleep.Waker{} + tmr.init(&w) + tmr.enable(timerDurationSeconds * time.Second) + tmr.cleanup() + + if want := (timer{}); tmr != want { + t.Errorf("got tmr = %+v, want = %+v", tmr, want) + } + + // The waker should not be asserted. + for i := 0; i < isAssertedTimeoutSeconds; i++ { + time.Sleep(time.Second) + if w.IsAsserted() { + t.Fatalf("waker asserted unexpectedly") + } + } +} -- cgit v1.2.3 From 28b8a5cc3ac538333756084da28d7f13f13b5c87 Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Thu, 18 Jun 2020 17:00:47 -0700 Subject: iptables: remove metadata struct Metadata was useful for debugging and safety, but enough tests exist that we should see failures when (de)serialization is broken. It made stack initialization more cumbersome and it's also getting in the way of ip6tables. PiperOrigin-RevId: 317210653 --- pkg/sentry/socket/netfilter/netfilter.go | 103 ++++++------------------------- pkg/sentry/socket/netstack/stack.go | 6 -- pkg/tcpip/stack/iptables.go | 8 --- pkg/tcpip/stack/iptables_types.go | 16 +---- runsc/boot/loader.go | 2 - 5 files changed, 21 insertions(+), 114 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 66015e2bc..f7abe77d3 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -41,19 +41,6 @@ const errorTargetName = "ERROR" // change the destination port/destination IP for packets. const redirectTargetName = "REDIRECT" -// Metadata is used to verify that we are correctly serializing and -// deserializing iptables into structs consumable by the iptables tool. We save -// a metadata struct when the tables are written, and when they are read out we -// verify that certain fields are the same. -// -// metadata is used by this serialization/deserializing code, not netstack. -type metadata struct { - HookEntry [linux.NF_INET_NUMHOOKS]uint32 - Underflow [linux.NF_INET_NUMHOOKS]uint32 - NumEntries uint32 - Size uint32 -} - // enableLogging controls whether to log the (de)serialization of netfilter // structs between userspace and netstack. These logs are useful when // developing iptables, but can pollute sentry logs otherwise. @@ -83,29 +70,13 @@ func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr) (linux.IPT return linux.IPTGetinfo{}, syserr.FromError(err) } - // Find the appropriate table. - table, err := findTable(stack, info.Name) + _, info, err := convertNetstackToBinary(stack, info.Name) if err != nil { - nflog("%v", err) + nflog("couldn't convert iptables: %v", err) return linux.IPTGetinfo{}, syserr.ErrInvalidArgument } - // Get the hooks that apply to this table. - info.ValidHooks = table.ValidHooks() - - // Grab the metadata struct, which is used to store information (e.g. - // the number of entries) that applies to the user's encoding of - // iptables, but not netstack's. - metadata := table.Metadata().(metadata) - - // Set values from metadata. - info.HookEntry = metadata.HookEntry - info.Underflow = metadata.Underflow - info.NumEntries = metadata.NumEntries - info.Size = metadata.Size - nflog("returning info: %+v", info) - return info, nil } @@ -118,23 +89,13 @@ func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen return linux.KernelIPTGetEntries{}, syserr.FromError(err) } - // Find the appropriate table. - table, err := findTable(stack, userEntries.Name) - if err != nil { - nflog("%v", err) - return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument - } - // Convert netstack's iptables rules to something that the iptables // tool can understand. - entries, meta, err := convertNetstackToBinary(userEntries.Name.String(), table) + entries, _, err := convertNetstackToBinary(stack, userEntries.Name) if err != nil { nflog("couldn't read entries: %v", err) return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument } - if meta != table.Metadata().(metadata) { - panic(fmt.Sprintf("Table %q metadata changed between writing and reading. Was saved as %+v, but is now %+v", userEntries.Name.String(), table.Metadata().(metadata), meta)) - } if binary.Size(entries) > uintptr(outLen) { nflog("insufficient GetEntries output size: %d", uintptr(outLen)) return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument @@ -143,44 +104,26 @@ func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen return entries, nil } -func findTable(stk *stack.Stack, tablename linux.TableName) (stack.Table, error) { - table, ok := stk.IPTables().GetTable(tablename.String()) - if !ok { - return stack.Table{}, fmt.Errorf("couldn't find table %q", tablename) - } - return table, nil -} - -// FillIPTablesMetadata populates stack's IPTables with metadata. -func FillIPTablesMetadata(stk *stack.Stack) { - stk.IPTables().ModifyTables(func(tables map[string]stack.Table) { - // In order to fill in the metadata, we have to translate ipt from its - // netstack format to Linux's giant-binary-blob format. - for name, table := range tables { - _, metadata, err := convertNetstackToBinary(name, table) - if err != nil { - panic(fmt.Errorf("Unable to set default IP tables: %v", err)) - } - table.SetMetadata(metadata) - tables[name] = table - } - }) -} - // convertNetstackToBinary converts the iptables as stored in netstack to the // format expected by the iptables tool. Linux stores each table as a binary // blob that can only be traversed by parsing a bit, reading some offsets, // jumping to those offsets, parsing again, etc. -func convertNetstackToBinary(tablename string, table stack.Table) (linux.KernelIPTGetEntries, metadata, error) { - // Return values. +func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (linux.KernelIPTGetEntries, linux.IPTGetinfo, error) { + table, ok := stack.IPTables().GetTable(tablename.String()) + if !ok { + return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("couldn't find table %q", tablename) + } + var entries linux.KernelIPTGetEntries - var meta metadata + var info linux.IPTGetinfo + info.ValidHooks = table.ValidHooks() // The table name has to fit in the struct. if linux.XT_TABLE_MAXNAMELEN < len(tablename) { - return linux.KernelIPTGetEntries{}, metadata{}, fmt.Errorf("table name %q too long.", tablename) + return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("table name %q too long", tablename) } - copy(entries.Name[:], tablename) + copy(info.Name[:], tablename[:]) + copy(entries.Name[:], tablename[:]) for ruleIdx, rule := range table.Rules { nflog("convert to binary: current offset: %d", entries.Size) @@ -189,14 +132,14 @@ func convertNetstackToBinary(tablename string, table stack.Table) (linux.KernelI for hook, hookRuleIdx := range table.BuiltinChains { if hookRuleIdx == ruleIdx { nflog("convert to binary: found hook %d at offset %d", hook, entries.Size) - meta.HookEntry[hook] = entries.Size + info.HookEntry[hook] = entries.Size } } // Is this a chain underflow point? for underflow, underflowRuleIdx := range table.Underflows { if underflowRuleIdx == ruleIdx { nflog("convert to binary: found underflow %d at offset %d", underflow, entries.Size) - meta.Underflow[underflow] = entries.Size + info.Underflow[underflow] = entries.Size } } @@ -251,12 +194,12 @@ func convertNetstackToBinary(tablename string, table stack.Table) (linux.KernelI entries.Size += uint32(entry.NextOffset) entries.Entrytable = append(entries.Entrytable, entry) - meta.NumEntries++ + info.NumEntries++ } - nflog("convert to binary: finished with an marshalled size of %d", meta.Size) - meta.Size = entries.Size - return entries, meta, nil + nflog("convert to binary: finished with an marshalled size of %d", info.Size) + info.Size = entries.Size + return entries, info, nil } func marshalTarget(target stack.Target) []byte { @@ -569,12 +512,6 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { // - There are no chains without an unconditional final rule. // - There are no chains without an unconditional underflow rule. - table.SetMetadata(metadata{ - HookEntry: replace.HookEntry, - Underflow: replace.Underflow, - NumEntries: replace.NumEntries, - Size: replace.Size, - }) stk.IPTables().ReplaceTable(replace.Name.String(), table) return nil diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index f97f9b6f3..ee11742a6 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -18,7 +18,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/inet" - "gvisor.dev/gvisor/pkg/sentry/socket/netfilter" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" @@ -366,11 +365,6 @@ func (s *Stack) IPTables() (*stack.IPTables, error) { return s.Stack.IPTables(), nil } -// FillIPTablesMetadata populates stack's IPTables with metadata. -func (s *Stack) FillIPTablesMetadata() { - netfilter.FillIPTablesMetadata(s.Stack) -} - // Resume implements inet.Stack.Resume. func (s *Stack) Resume() { s.Stack.Resume() diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 4e9b404c8..dc2b77c9d 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -173,14 +173,6 @@ func (it *IPTables) ReplaceTable(name string, table Table) { it.tables[name] = table } -// ModifyTables acquires write-lock and calls fn with internal name-to-table -// map. This function can be used to update multiple tables atomically. -func (it *IPTables) ModifyTables(fn func(map[string]Table)) { - it.mu.Lock() - defer it.mu.Unlock() - fn(it.tables) -} - // GetPriorities returns slice of priorities associated with hook. func (it *IPTables) GetPriorities(hook Hook) []string { it.mu.RLock() diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 4a6a5c6f1..72f1dd329 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -95,7 +95,7 @@ type IPTables struct { } // A Table defines a set of chains and hooks into the network stack. It is -// really just a list of rules with some metadata for entrypoints and such. +// really just a list of rules. type Table struct { // Rules holds the rules that make up the table. Rules []Rule @@ -110,10 +110,6 @@ type Table struct { // UserChains holds user-defined chains for the keyed by name. Users // can give their chains arbitrary names. UserChains map[string]int - - // Metadata holds information about the Table that is useful to users - // of IPTables, but not to the netstack IPTables code itself. - metadata interface{} } // ValidHooks returns a bitmap of the builtin hooks for the given table. @@ -125,16 +121,6 @@ func (table *Table) ValidHooks() uint32 { return hooks } -// Metadata returns the metadata object stored in table. -func (table *Table) Metadata() interface{} { - return table.metadata -} - -// SetMetadata sets the metadata object stored in table. -func (table *Table) SetMetadata(metadata interface{}) { - table.metadata = metadata -} - // A Rule is a packet processing rule. It consists of two pieces. First it // contains zero or more matchers, each of which is a specification of which // packets this rule applies to. If there are no matchers in the rule, it diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index c6efcdc83..081db39c1 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -1071,8 +1071,6 @@ func newEmptySandboxNetworkStack(clock tcpip.Clock, uniqueID stack.UniqueID) (in return nil, fmt.Errorf("SetTransportProtocolOption failed: %s", err) } - s.FillIPTablesMetadata() - return &s, nil } -- cgit v1.2.3 From 0c169b6ad598200a57db7bf0f679da1d6cb395c4 Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Thu, 18 Jun 2020 19:45:13 -0700 Subject: iptables: skip iptables if no rules are set Users that never set iptables rules shouldn't incur the iptables performance cost. Suggested by Ian (@iangudger). PiperOrigin-RevId: 317232921 --- pkg/tcpip/stack/iptables.go | 10 ++++++++++ pkg/tcpip/stack/iptables_types.go | 11 ++++++++--- 2 files changed, 18 insertions(+), 3 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index dc2b77c9d..62d4eb1b6 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -170,6 +170,7 @@ func (it *IPTables) GetTable(name string) (Table, bool) { func (it *IPTables) ReplaceTable(name string, table Table) { it.mu.Lock() defer it.mu.Unlock() + it.modified = true it.tables[name] = table } @@ -201,6 +202,15 @@ const ( // // Precondition: pkt.NetworkHeader is set. func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, address tcpip.Address, nicName string) bool { + // Many users never configure iptables. Spare them the cost of rule + // traversal if rules have never been set. + it.mu.RLock() + if !it.modified { + it.mu.RUnlock() + return true + } + it.mu.RUnlock() + // Packets are manipulated only if connection and matching // NAT rule exists. it.connections.HandlePacket(pkt, hook, gso, r) diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 72f1dd329..7026990c4 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -79,11 +79,11 @@ const ( // IPTables holds all the tables for a netstack. type IPTables struct { - // mu protects tables and priorities. + // mu protects tables, priorities, and modified. mu sync.RWMutex - // tables maps table names to tables. User tables have arbitrary names. mu - // needs to be locked for accessing. + // tables maps table names to tables. User tables have arbitrary names. + // mu needs to be locked for accessing. tables map[string]Table // priorities maps each hook to a list of table names. The order of the @@ -91,6 +91,11 @@ type IPTables struct { // hook. mu needs to be locked for accessing. priorities map[Hook][]string + // modified is whether tables have been modified at least once. It is + // used to elide the iptables performance overhead for workloads that + // don't utilize iptables. + modified bool + connections ConnTrackTable } -- cgit v1.2.3 From 2141013dcef04de0591e36462ea997be1b6374d7 Mon Sep 17 00:00:00 2001 From: Ian Gudger Date: Tue, 23 Jun 2020 19:14:05 -0700 Subject: Add support for SO_REUSEADDR to TCP sockets/endpoints. For TCP sockets, SO_REUSEADDR relaxes the rules for binding addresses. gVisor/netstack already supported a behavior similar to SO_REUSEADDR, but did not allow disabling it. This change brings the SO_REUSEADDR behavior closer to the behavior implemented by Linux and adds a new SO_REUSEADDR disabled behavior. Like Linux, SO_REUSEADDR is now disabled by default. PiperOrigin-RevId: 317984380 --- pkg/tcpip/ports/ports.go | 206 ++++++++++++++++---- pkg/tcpip/ports/ports_test.go | 58 +++++- pkg/tcpip/stack/stack.go | 6 + pkg/tcpip/stack/transport_demuxer.go | 65 ++++++- pkg/tcpip/transport/tcp/accept.go | 140 ++++++++++---- pkg/tcpip/transport/tcp/endpoint.go | 100 +++++----- pkg/tcpip/transport/tcp/endpoint_state.go | 31 +-- pkg/tcpip/transport/tcp/tcp_test.go | 18 ++ pkg/tcpip/transport/udp/endpoint.go | 8 +- .../linux/socket_bind_to_device_sequence.cc | 12 -- test/syscalls/linux/socket_inet_loopback.cc | 210 +++++++++++++++++---- test/syscalls/linux/socket_ip_unbound.cc | 4 - 12 files changed, 633 insertions(+), 225 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go index edc29ad27..f6d592eb5 100644 --- a/pkg/tcpip/ports/ports.go +++ b/pkg/tcpip/ports/ports.go @@ -52,6 +52,9 @@ type Flags struct { // // LoadBalanced takes precidence over MostRecent. LoadBalanced bool + + // TupleOnly represents TCP SO_REUSEADDR. + TupleOnly bool } // Bits converts the Flags to their bitset form. @@ -63,6 +66,9 @@ func (f Flags) Bits() BitFlags { if f.LoadBalanced { rf |= LoadBalancedFlag } + if f.TupleOnly { + rf |= TupleOnlyFlag + } return rf } @@ -98,6 +104,9 @@ const ( // LoadBalancedFlag represents Flags.LoadBalanced. LoadBalancedFlag + // TupleOnlyFlag represents Flags.TupleOnly. + TupleOnlyFlag + // nextFlag is the value that the next added flag will have. // // It is used to calculate FlagMask below. It is also the number of @@ -106,6 +115,10 @@ const ( // FlagMask is a bit mask for BitFlags. FlagMask = nextFlag - 1 + + // MultiBindFlagMask contains the flags that allow binding the same + // tuple multiple times. + MultiBindFlagMask = MostRecentFlag | LoadBalancedFlag ) // ToFlags converts the bitset into a Flags struct. @@ -113,6 +126,7 @@ func (f BitFlags) ToFlags() Flags { return Flags{ MostRecent: f&MostRecentFlag != 0, LoadBalanced: f&LoadBalancedFlag != 0, + TupleOnly: f&TupleOnlyFlag != 0, } } @@ -175,9 +189,54 @@ func (c FlagCounter) IntersectionRefs() BitFlags { return intersection } +type destination struct { + addr tcpip.Address + port uint16 +} + +func makeDestination(a tcpip.FullAddress) destination { + return destination{ + a.Addr, + a.Port, + } +} + +// portNode is never empty. When it has no elements, it is removed from the +// map that references it. +type portNode map[destination]FlagCounter + +// intersectionRefs calculates the intersection of flag bit values which affect +// the specified destination. +// +// If no destinations are present, all flag values are returned as there are no +// entries to limit possible flag values of a new entry. +// +// In addition to the intersection, the number of intersecting refs is +// returned. +func (p portNode) intersectionRefs(dst destination) (BitFlags, int) { + intersection := FlagMask + var count int + + for d, f := range p { + if d == dst { + intersection &= f.IntersectionRefs() + count++ + continue + } + // Wildcard destinations affect all destinations for TupleOnly. + if d.addr == anyIPAddress || dst.addr == anyIPAddress { + // Only bitwise and the TupleOnlyFlag. + intersection &= ((^TupleOnlyFlag) | f.IntersectionRefs()) + count++ + } + } + + return intersection, count +} + // deviceNode is never empty. When it has no elements, it is removed from the // map that references it. -type deviceNode map[tcpip.NICID]FlagCounter +type deviceNode map[tcpip.NICID]portNode // isAvailable checks whether binding is possible by device. If not binding to a // device, check against all FlagCounters. If binding to a specific device, check @@ -186,17 +245,15 @@ type deviceNode map[tcpip.NICID]FlagCounter // If either of the port reuse flags is enabled on any of the nodes, all nodes // sharing a port must share at least one reuse flag. This matches Linux's // behavior. -func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID) bool { +func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID, dst destination) bool { flagBits := flags.Bits() if bindToDevice == 0 { - // Trying to binding all devices. - if flagBits == 0 { - // Can't bind because the (addr,port) is already bound. - return false - } intersection := FlagMask for _, p := range d { - i := p.IntersectionRefs() + i, c := p.intersectionRefs(dst) + if c == 0 { + continue + } intersection &= i if intersection&flagBits == 0 { // Can't bind because the (addr,port) was @@ -210,16 +267,17 @@ func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID) bool { intersection := FlagMask if p, ok := d[0]; ok { - intersection = p.IntersectionRefs() - if intersection&flagBits == 0 { + var c int + intersection, c = p.intersectionRefs(dst) + if c > 0 && intersection&flagBits == 0 { return false } } if p, ok := d[bindToDevice]; ok { - i := p.IntersectionRefs() + i, c := p.intersectionRefs(dst) intersection &= i - if intersection&flagBits == 0 { + if c > 0 && intersection&flagBits == 0 { return false } } @@ -233,12 +291,12 @@ type bindAddresses map[tcpip.Address]deviceNode // isAvailable checks whether an IP address is available to bind to. If the // address is the "any" address, check all other addresses. Otherwise, just // check against the "any" address and the provided address. -func (b bindAddresses) isAvailable(addr tcpip.Address, flags Flags, bindToDevice tcpip.NICID) bool { +func (b bindAddresses) isAvailable(addr tcpip.Address, flags Flags, bindToDevice tcpip.NICID, dst destination) bool { if addr == anyIPAddress { // If binding to the "any" address then check that there are no conflicts // with all addresses. for _, d := range b { - if !d.isAvailable(flags, bindToDevice) { + if !d.isAvailable(flags, bindToDevice, dst) { return false } } @@ -247,14 +305,14 @@ func (b bindAddresses) isAvailable(addr tcpip.Address, flags Flags, bindToDevice // Check that there is no conflict with the "any" address. if d, ok := b[anyIPAddress]; ok { - if !d.isAvailable(flags, bindToDevice) { + if !d.isAvailable(flags, bindToDevice, dst) { return false } } // Check that this is no conflict with the provided address. if d, ok := b[addr]; ok { - if !d.isAvailable(flags, bindToDevice) { + if !d.isAvailable(flags, bindToDevice, dst) { return false } } @@ -320,17 +378,17 @@ func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p ui } // IsPortAvailable tests if the given port is available on all given protocols. -func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) bool { +func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) bool { s.mu.Lock() defer s.mu.Unlock() - return s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice) + return s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice, makeDestination(dest)) } -func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) bool { +func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dst destination) bool { for _, network := range networks { desc := portDescriptor{network, transport, port} if addrs, ok := s.allocatedPorts[desc]; ok { - if !addrs.isAvailable(addr, flags, bindToDevice) { + if !addrs.isAvailable(addr, flags, bindToDevice, dst) { return false } } @@ -342,14 +400,16 @@ func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumb // reserved by another endpoint. If port is zero, ReservePort will search for // an unreserved ephemeral port and reserve it, returning its value in the // "port" return value. -func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) (reservedPort uint16, err *tcpip.Error) { +func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) (reservedPort uint16, err *tcpip.Error) { s.mu.Lock() defer s.mu.Unlock() + dst := makeDestination(dest) + // If a port is specified, just try to reserve it for all network // protocols. if port != 0 { - if !s.reserveSpecificPort(networks, transport, addr, port, flags, bindToDevice) { + if !s.reserveSpecificPort(networks, transport, addr, port, flags, bindToDevice, dst) { return 0, tcpip.ErrPortInUse } return port, nil @@ -357,15 +417,16 @@ func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transp // A port wasn't specified, so try to find one. return s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { - return s.reserveSpecificPort(networks, transport, addr, p, flags, bindToDevice), nil + return s.reserveSpecificPort(networks, transport, addr, p, flags, bindToDevice, dst), nil }) } // reserveSpecificPort tries to reserve the given port on all given protocols. -func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) bool { - if !s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice) { +func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dst destination) bool { + if !s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice, dst) { return false } + flagBits := flags.Bits() // Reserve port on all network protocols. @@ -381,9 +442,65 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber d = make(deviceNode) m[addr] = d } - n := d[bindToDevice] + p := d[bindToDevice] + if p == nil { + p = make(portNode) + } + n := p[dst] n.AddRef(flagBits) - d[bindToDevice] = n + p[dst] = n + d[bindToDevice] = p + } + + return true +} + +// ReserveTuple adds a port reservation for the tuple on all given protocol. +func (s *PortManager) ReserveTuple(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) bool { + flagBits := flags.Bits() + dst := makeDestination(dest) + + s.mu.Lock() + defer s.mu.Unlock() + + // It is easier to undo the entire reservation, so if we find that the + // tuple can't be fully added, finish and undo the whole thing. + undo := false + + // Reserve port on all network protocols. + for _, network := range networks { + desc := portDescriptor{network, transport, port} + m, ok := s.allocatedPorts[desc] + if !ok { + m = make(bindAddresses) + s.allocatedPorts[desc] = m + } + d, ok := m[addr] + if !ok { + d = make(deviceNode) + m[addr] = d + } + p := d[bindToDevice] + if p == nil { + p = make(portNode) + } + + n := p[dst] + if n.TotalRefs() != 0 && n.IntersectionRefs()&flagBits == 0 { + // Tuple already exists. + undo = true + } + n.AddRef(flagBits) + p[dst] = n + d[bindToDevice] = p + } + + if undo { + // releasePortLocked decrements the counts (rather than setting + // them to zero), so it will undo the incorrect incrementing + // above. + s.releasePortLocked(networks, transport, addr, port, flagBits, bindToDevice, dst) + return false } return true @@ -391,12 +508,14 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber // ReleasePort releases the reservation on a port/IP combination so that it can // be reserved by other endpoints. -func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) { +func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) { s.mu.Lock() defer s.mu.Unlock() - flagBits := flags.Bits() + s.releasePortLocked(networks, transport, addr, port, flags.Bits(), bindToDevice, makeDestination(dest)) +} +func (s *PortManager) releasePortLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags BitFlags, bindToDevice tcpip.NICID, dst destination) { for _, network := range networks { desc := portDescriptor{network, transport, port} if m, ok := s.allocatedPorts[desc]; ok { @@ -404,21 +523,32 @@ func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transp if !ok { continue } - n, ok := d[bindToDevice] + p, ok := d[bindToDevice] if !ok { continue } - n.refs[flagBits]-- - d[bindToDevice] = n - if n.TotalRefs() == 0 { - delete(d, bindToDevice) + n, ok := p[dst] + if !ok { + continue + } + n.DropRef(flags) + if n.TotalRefs() > 0 { + p[dst] = n + continue } - if len(d) == 0 { - delete(m, addr) + delete(p, dst) + if len(p) > 0 { + continue + } + delete(d, bindToDevice) + if len(d) > 0 { + continue } - if len(m) == 0 { - delete(s.allocatedPorts, desc) + delete(m, addr) + if len(m) > 0 { + continue } + delete(s.allocatedPorts, desc) } } } diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go index d6969d050..58db5868c 100644 --- a/pkg/tcpip/ports/ports_test.go +++ b/pkg/tcpip/ports/ports_test.go @@ -36,6 +36,7 @@ type portReserveTestAction struct { flags Flags release bool device tcpip.NICID + dest tcpip.FullAddress } func TestPortReservation(t *testing.T) { @@ -272,6 +273,54 @@ func TestPortReservation(t *testing.T) { {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse}, }, + }, { + tname: "bind tuple with reuseaddr, and then wildcard with reuseaddr", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{}, want: nil}, + }, + }, { + tname: "bind tuple with reuseaddr, and then wildcard", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil}, + {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind wildcard with reuseaddr, and then tuple with reuseaddr", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil}, + }, + }, { + tname: "bind tuple with reuseaddr, and then wildcard", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind two tuples with reuseaddr", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 25}, want: nil}, + }, + }, { + tname: "bind two tuples", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil}, + {port: 24, ip: fakeIPAddress, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 25}, want: nil}, + }, + }, { + tname: "bind wildcard, and then tuple with reuseaddr", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, dest: tcpip.FullAddress{}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind wildcard twice with reuseaddr", + actions: []portReserveTestAction{ + {port: 24, ip: anyIPAddress, flags: Flags{TupleOnly: true}, want: nil}, + {port: 24, ip: anyIPAddress, flags: Flags{TupleOnly: true}, want: nil}, + }, }, } { t.Run(test.tname, func(t *testing.T) { @@ -280,19 +329,18 @@ func TestPortReservation(t *testing.T) { for _, test := range test.actions { if test.release { - pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device) + pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest) continue } - gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device) + gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest) if err != test.want { - t.Fatalf("ReservePort(.., .., %s, %d, %+v, %d) = %v, want %v", test.ip, test.port, test.flags, test.device, err, test.want) + t.Fatalf("ReservePort(.., .., %s, %d, %+v, %d, %v) = %v, want %v", test.ip, test.port, test.flags, test.device, test.dest, err, test.want) } if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) { - t.Fatalf("ReservePort(.., .., .., 0) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral) + t.Fatalf("ReservePort(.., .., .., 0, ..) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral) } } }) - } } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 51abe32a7..e92ec0c24 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1408,6 +1408,12 @@ func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.N return s.demux.registerEndpoint(netProtos, protocol, id, ep, flags, bindToDevice) } +// CheckRegisterTransportEndpoint checks if an endpoint can be registered with +// the stack transport dispatcher. +func (s *Stack) CheckRegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { + return s.demux.checkEndpoint(netProtos, protocol, id, flags, bindToDevice) +} + // UnregisterTransportEndpoint removes the endpoint with the given id from the // stack transport dispatcher. func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 118b449d5..b902c6ca9 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -221,6 +221,18 @@ func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto t return multiPortEp.singleRegisterEndpoint(t, flags) } +func (epsByNIC *endpointsByNIC) checkEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { + epsByNIC.mu.RLock() + defer epsByNIC.mu.RUnlock() + + multiPortEp, ok := epsByNIC.endpoints[bindToDevice] + if !ok { + return nil + } + + return multiPortEp.singleCheckEndpoint(flags) +} + // unregisterEndpoint returns true if endpointsByNIC has to be unregistered. func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint, flags ports.Flags) bool { epsByNIC.mu.Lock() @@ -289,6 +301,17 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum return nil } +// checkEndpoint checks if an endpoint can be registered with the dispatcher. +func (d *transportDemuxer) checkEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { + for _, n := range netProtos { + if err := d.singleCheckEndpoint(n, protocol, id, flags, bindToDevice); err != nil { + return err + } + } + + return nil +} + // multiPortEndpoint is a container for TransportEndpoints which are bound to // the same pair of address and port. endpointsArr always has at least one // element. @@ -380,7 +403,7 @@ func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags p ep.mu.Lock() defer ep.mu.Unlock() - bits := flags.Bits() + bits := flags.Bits() & ports.MultiBindFlagMask if len(ep.endpoints) != 0 { // If it was previously bound, we need to check if we can bind again. @@ -395,6 +418,22 @@ func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags p return nil } +func (ep *multiPortEndpoint) singleCheckEndpoint(flags ports.Flags) *tcpip.Error { + ep.mu.RLock() + defer ep.mu.RUnlock() + + bits := flags.Bits() & ports.MultiBindFlagMask + + if len(ep.endpoints) != 0 { + // If it was previously bound, we need to check if we can bind again. + if ep.flags.TotalRefs() > 0 && bits&ep.flags.IntersectionRefs() == 0 { + return tcpip.ErrPortInUse + } + } + + return nil +} + // unregisterEndpoint returns true if multiPortEndpoint has to be unregistered. func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint, flags ports.Flags) bool { ep.mu.Lock() @@ -406,7 +445,7 @@ func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint, flags ports ep.endpoints[len(ep.endpoints)-1] = nil ep.endpoints = ep.endpoints[:len(ep.endpoints)-1] - ep.flags.DropRef(flags.Bits()) + ep.flags.DropRef(flags.Bits() & ports.MultiBindFlagMask) break } } @@ -439,6 +478,28 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol return epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice) } +func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { + if id.RemotePort != 0 { + // SO_REUSEPORT only applies to bound/listening endpoints. + flags.LoadBalanced = false + } + + eps, ok := d.protocol[protocolIDs{netProto, protocol}] + if !ok { + return tcpip.ErrUnknownProtocol + } + + eps.mu.RLock() + defer eps.mu.RUnlock() + + epsByNIC, ok := eps.endpoints[id] + if !ok { + return nil + } + + return epsByNIC.checkEndpoint(d, netProto, protocol, flags, bindToDevice) +} + // unregisterEndpoint unregisters the endpoint with the given id such that it // won't receive any more packets. func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 7679fe169..6e00e5526 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -27,7 +27,6 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" @@ -199,9 +198,8 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu } // createConnectingEndpoint creates a new endpoint in a connecting state, with -// the connection parameters given by the arguments. The endpoint is returned -// with n.mu held. -func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) { +// the connection parameters given by the arguments. +func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) *endpoint { // Create a new endpoint. netProto := l.netProto if netProto == 0 { @@ -227,22 +225,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i // window to grow to a really large value. n.rcvAutoParams.prevCopied = n.initialReceiveWindow() - // Lock the endpoint before registering to ensure that no out of - // band changes are possible due to incoming packets etc till - // the endpoint is done initializing. - n.mu.Lock() - - // Register new endpoint so that packets are routed to it. - if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, ports.Flags{LoadBalanced: n.reusePort}, n.boundBindToDevice); err != nil { - n.mu.Unlock() - n.Close() - return nil, err - } - - n.isRegistered = true - n.registeredReusePort = n.reusePort - - return n, nil + return n } // createEndpointAndPerformHandshake creates a new endpoint in connected state @@ -253,10 +236,12 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head // Create new endpoint. irs := s.sequenceNumber isn := generateSecureISN(s.id, l.stack.Seed()) - ep, err := l.createConnectingEndpoint(s, isn, irs, opts, queue) - if err != nil { - return nil, err - } + ep := l.createConnectingEndpoint(s, isn, irs, opts, queue) + + // Lock the endpoint before registering to ensure that no out of + // band changes are possible due to incoming packets etc till + // the endpoint is done initializing. + ep.mu.Lock() ep.owner = owner // listenEP is nil when listenContext is used by tcp.Forwarder. @@ -264,18 +249,13 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head if l.listenEP != nil { l.listenEP.mu.Lock() if l.listenEP.EndpointState() != StateListen { + l.listenEP.mu.Unlock() // Ensure we release any registrations done by the newly // created endpoint. ep.mu.Unlock() ep.Close() - // Wake up any waiters. This is strictly not required normally - // as a socket that was never accepted can't really have any - // registered waiters except when stack.Wait() is called which - // waits for all registered endpoints to stop and expects an - // EventHUp. - ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) return nil, tcpip.ErrConnectionAborted } l.addPendingEndpoint(ep) @@ -284,21 +264,44 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head // to the newly created endpoint. l.listenEP.propagateInheritableOptionsLocked(ep) + if !ep.reserveTupleLocked() { + ep.mu.Unlock() + ep.Close() + + if l.listenEP != nil { + l.removePendingEndpoint(ep) + l.listenEP.mu.Unlock() + } + + return nil, tcpip.ErrConnectionAborted + } + deferAccept = l.listenEP.deferAccept l.listenEP.mu.Unlock() } + // Register new endpoint so that packets are routed to it. + if err := ep.stack.RegisterTransportEndpoint(ep.boundNICID, ep.effectiveNetProtos, ProtocolNumber, ep.ID, ep, ep.boundPortFlags, ep.boundBindToDevice); err != nil { + ep.mu.Unlock() + ep.Close() + + if l.listenEP != nil { + l.removePendingEndpoint(ep) + } + + ep.drainClosingSegmentQueue() + + return nil, err + } + + ep.isRegistered = true + // Perform the 3-way handshake. h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept) if err := h.execute(); err != nil { ep.mu.Unlock() ep.Close() - // Wake up any waiters. This is strictly not required normally - // as a socket that was never accepted can't really have any - // registered waiters except when stack.Wait() is called which - // waits for all registered endpoints to stop and expects an - // EventHUp. - ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) + ep.notifyAborted() if l.listenEP != nil { l.removePendingEndpoint(ep) @@ -374,6 +377,43 @@ func (e *endpoint) deliverAccepted(n *endpoint) { // Precondition: e.mu and n.mu must be held. func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) { n.userTimeout = e.userTimeout + n.portFlags = e.portFlags + n.boundBindToDevice = e.boundBindToDevice + n.boundPortFlags = e.boundPortFlags +} + +// reserveTupleLocked reserves an accepted endpoint's tuple. +// +// Preconditions: +// * propagateInheritableOptionsLocked has been called. +// * e.mu is held. +func (e *endpoint) reserveTupleLocked() bool { + dest := tcpip.FullAddress{Addr: e.ID.RemoteAddress, Port: e.ID.RemotePort} + if !e.stack.ReserveTuple( + e.effectiveNetProtos, + ProtocolNumber, + e.ID.LocalAddress, + e.ID.LocalPort, + e.boundPortFlags, + e.boundBindToDevice, + dest, + ) { + return false + } + + e.isPortReserved = true + e.boundDest = dest + return true +} + +// notifyAborted wakes up any waiters on registered, but not accepted +// endpoints. +// +// This is strictly not required normally as a socket that was never accepted +// can't really have any registered waiters except when stack.Wait() is called +// which waits for all registered endpoints to stop and expects an EventHUp. +func (e *endpoint) notifyAborted() { + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) } // handleSynSegment is called in its own goroutine once the listening endpoint @@ -568,16 +608,34 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr } - n, err := ctx.createConnectingEndpoint(s, iss, irs, rcvdSynOptions, &waiter.Queue{}) - if err != nil { + n := ctx.createConnectingEndpoint(s, iss, irs, rcvdSynOptions, &waiter.Queue{}) + + n.mu.Lock() + + // Propagate any inheritable options from the listening endpoint + // to the newly created endpoint. + e.propagateInheritableOptionsLocked(n) + + if !n.reserveTupleLocked() { + n.mu.Unlock() + n.Close() + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() return } - // Propagate any inheritable options from the listening endpoint - // to the newly created endpoint. - e.propagateInheritableOptionsLocked(n) + // Register new endpoint so that packets are routed to it. + if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.boundPortFlags, n.boundBindToDevice); err != nil { + n.mu.Unlock() + n.Close() + + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() + return + } + + n.isRegistered = true // clear the tsOffset for the newly created // endpoint as the Timestamp was already diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 10df2bcd5..1e4c2f507 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -396,7 +396,8 @@ type endpoint struct { mu sync.Mutex `state:"nosave"` ownedByUser uint32 - // state must be read/set using the EndpointState()/setEndpointState() methods. + // state must be read/set using the EndpointState()/setEndpointState() + // methods. state EndpointState `state:".(EndpointState)"` // origEndpointState is only used during a restore phase to save the @@ -405,8 +406,8 @@ type endpoint struct { origEndpointState EndpointState `state:"nosave"` isPortReserved bool `state:"manual"` - isRegistered bool - boundNICID tcpip.NICID `state:"manual"` + isRegistered bool `state:"manual"` + boundNICID tcpip.NICID route stack.Route `state:"manual"` ttl uint8 v6only bool @@ -415,10 +416,14 @@ type endpoint struct { // disabling SO_BROADCAST, albeit as a NOOP. broadcast bool + // portFlags stores the current values of port related flags. + portFlags ports.Flags + // Values used to reserve a port or register a transport endpoint // (which ever happens first). boundBindToDevice tcpip.NICID boundPortFlags ports.Flags + boundDest tcpip.FullAddress // effectiveNetProtos contains the network protocols actually in use. In // most cases it will only contain "netProto", but in cases like IPv6 @@ -426,7 +431,7 @@ type endpoint struct { // protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g., // IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped // address). - effectiveNetProtos []tcpip.NetworkProtocolNumber `state:"manual"` + effectiveNetProtos []tcpip.NetworkProtocolNumber // workerRunning specifies if a worker goroutine is running. workerRunning bool @@ -462,13 +467,6 @@ type endpoint struct { // sack holds TCP SACK related information for this endpoint. sack SACKInfo - // reusePort is set to true if SO_REUSEPORT is enabled. - reusePort bool - - // registeredReusePort is set if the current endpoint registration was - // done with SO_REUSEPORT enabled. - registeredReusePort bool - // bindToDevice is set to the NIC on which to bind or disabled if 0. bindToDevice tcpip.NICID @@ -488,7 +486,6 @@ type endpoint struct { // The options below aren't implemented, but we remember the user // settings because applications expect to be able to set/query these // options. - reuseAddr bool // slowAck holds the negated state of quick ack. It is stubbed out and // does nothing. @@ -838,7 +835,6 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue rcvBufSize: DefaultReceiveBufferSize, sndBufSize: DefaultSendBufferSize, sndMTU: int(math.MaxInt32), - reuseAddr: true, keepalive: keepalive{ // Linux defaults. idle: 2 * time.Hour, @@ -1025,15 +1021,15 @@ func (e *endpoint) closeNoShutdownLocked() { // in Listen() when trying to register. if e.EndpointState() == StateListen && e.isPortReserved { if e.isRegistered { - e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, ports.Flags{LoadBalanced: e.registeredReusePort}, e.boundBindToDevice) + e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) e.isRegistered = false - e.registeredReusePort = false } - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, e.boundDest) e.isPortReserved = false e.boundBindToDevice = 0 e.boundPortFlags = ports.Flags{} + e.boundDest = tcpip.FullAddress{} } // Mark endpoint as closed. @@ -1091,17 +1087,17 @@ func (e *endpoint) cleanupLocked() { e.workerCleanup = false if e.isRegistered { - e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, ports.Flags{LoadBalanced: e.registeredReusePort}, e.boundBindToDevice) + e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) e.isRegistered = false - e.registeredReusePort = false } if e.isPortReserved { - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, e.boundDest) e.isPortReserved = false } e.boundBindToDevice = 0 e.boundPortFlags = ports.Flags{} + e.boundDest = tcpip.FullAddress{} e.route.Release() e.stack.CompleteTransportEndpointCleanup(e) @@ -1522,12 +1518,12 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { case tcpip.ReuseAddressOption: e.LockUser() - e.reuseAddr = v + e.portFlags.TupleOnly = v e.UnlockUser() case tcpip.ReusePortOption: e.LockUser() - e.reusePort = v + e.portFlags.LoadBalanced = v e.UnlockUser() case tcpip.V6OnlyOption: @@ -1831,14 +1827,14 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { case tcpip.ReuseAddressOption: e.LockUser() - v := e.reuseAddr + v := e.portFlags.TupleOnly e.UnlockUser() return v, nil case tcpip.ReusePortOption: e.LockUser() - v := e.reusePort + v := e.portFlags.LoadBalanced e.UnlockUser() return v, nil @@ -2091,8 +2087,6 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } defer r.Release() - origID := e.ID - netProtos := []tcpip.NetworkProtocolNumber{netProto} e.ID.LocalAddress = r.LocalAddress e.ID.RemoteAddress = r.RemoteAddress @@ -2100,11 +2094,10 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc if e.ID.LocalPort != 0 { // The endpoint is bound to a port, attempt to register it. - err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, ports.Flags{LoadBalanced: e.reusePort}, e.boundBindToDevice) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) if err != nil { return err } - e.registeredReusePort = e.reusePort } else { // The endpoint doesn't have a local port yet, so try to get // one. Make sure that it isn't one that will result in the same @@ -2128,40 +2121,33 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc if sameAddr && p == e.ID.RemotePort { return false, nil } - // reusePort is false below because connect cannot reuse a port even if - // reusePort was set. - if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.ID.LocalAddress, p, ports.Flags{LoadBalanced: false}, e.bindToDevice) { + if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr); err != nil { return false, nil } id := e.ID id.LocalPort = p - switch e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, ports.Flags{LoadBalanced: e.reusePort}, e.bindToDevice) { - case nil: - // Port picking successful. Save the details of - // the selected port. - e.ID = id - e.boundBindToDevice = e.bindToDevice - e.registeredReusePort = e.reusePort - return true, nil - case tcpip.ErrPortInUse: - return false, nil - default: + if err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.portFlags, e.bindToDevice); err != nil { + e.stack.ReleasePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr) + if err == tcpip.ErrPortInUse { + return false, nil + } return false, err } + + // Port picking successful. Save the details of + // the selected port. + e.ID = id + e.isPortReserved = true + e.boundBindToDevice = e.bindToDevice + e.boundPortFlags = e.portFlags + e.boundDest = addr + return true, nil }); err != nil { return err } } - // Remove the port reservation. This can happen when Bind is called - // before Connect: in such a case we don't want to hold on to - // reservations anymore. - if e.isPortReserved { - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.boundPortFlags, e.boundBindToDevice) - e.isPortReserved = false - } - e.isRegistered = true e.setEndpointState(StateConnecting) e.route = r.Clone() @@ -2340,13 +2326,12 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { } // Register the endpoint. - if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, ports.Flags{LoadBalanced: e.reusePort}, e.boundBindToDevice); err != nil { + if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice); err != nil { return err } e.isRegistered = true e.setEndpointState(StateListen) - e.registeredReusePort = e.reusePort // The channel may be non-nil when we're restoring the endpoint, and it // may be pre-populated with some previously accepted (but not Accepted) @@ -2433,16 +2418,13 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { } } - flags := ports.Flags{ - LoadBalanced: e.reusePort, - } - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, flags, e.bindToDevice) + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, e.bindToDevice, tcpip.FullAddress{}) if err != nil { return err } e.boundBindToDevice = e.bindToDevice - e.boundPortFlags = flags + e.boundPortFlags = e.portFlags e.isPortReserved = true e.effectiveNetProtos = netProtos e.ID.LocalPort = port @@ -2450,7 +2432,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { // Any failures beyond this point must remove the port registration. defer func(portFlags ports.Flags, bindToDevice tcpip.NICID) { if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, portFlags, bindToDevice) + e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, portFlags, bindToDevice, tcpip.FullAddress{}) e.isPortReserved = false e.effectiveNetProtos = nil e.ID.LocalPort = 0 @@ -2473,6 +2455,10 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { e.ID.LocalAddress = addr.Addr } + if err := e.stack.CheckRegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e.boundPortFlags, e.boundBindToDevice); err != nil { + return err + } + // Mark endpoint as bound. e.setEndpointState(StateBound) diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 0bebec2d1..8258c0ecc 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -93,10 +93,6 @@ func (e *endpoint) beforeSave() { if e.waiterQueue != nil && !e.waiterQueue.IsEmpty() { panic("endpoint still has waiters upon save") } - - if e.EndpointState() != StateClose && !((e.EndpointState() == StateBound || e.EndpointState() == StateListen) == e.isPortReserved) { - panic("endpoints which are not in the closed state must have a reserved port IFF they are in bound or listen state") - } } // saveAcceptedChan is invoked by stateify. @@ -198,14 +194,17 @@ func (e *endpoint) Resume(s *stack.Stack) { } bind := func() { - if len(e.BindAddr) == 0 { - e.BindAddr = e.ID.LocalAddress + addr, _, err := e.checkV4MappedLocked(tcpip.FullAddress{Addr: e.BindAddr, Port: e.ID.LocalPort}) + if err != nil { + panic("unable to parse BindAddr: " + err.String()) } - addr := e.BindAddr - port := e.ID.LocalPort - if err := e.Bind(tcpip.FullAddress{Addr: addr, Port: port}); err != nil { - panic(fmt.Sprintf("endpoint binding [%v]:%d failed: %v", addr, port, err)) + if ok := e.stack.ReserveTuple(e.effectiveNetProtos, ProtocolNumber, addr.Addr, addr.Port, e.boundPortFlags, e.boundBindToDevice, e.boundDest); !ok { + panic(fmt.Sprintf("unable to re-reserve tuple (%v, %q, %d, %+v, %d, %v)", e.effectiveNetProtos, addr.Addr, addr.Port, e.boundPortFlags, e.boundBindToDevice, e.boundDest)) } + e.isPortReserved = true + + // Mark endpoint as bound. + e.setEndpointState(StateBound) } switch { @@ -277,17 +276,7 @@ func (e *endpoint) Resume(s *stack.Stack) { tcpip.AsyncLoading.Done() }() case epState == StateClose: - if e.isPortReserved { - tcpip.AsyncLoading.Add(1) - go func() { - connectedLoading.Wait() - listenLoading.Wait() - connectingLoading.Wait() - bind() - e.setEndpointState(StateClose) - tcpip.AsyncLoading.Done() - }() - } + e.isPortReserved = false e.state = StateClose e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index aca6a7951..2632a3c67 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -3879,6 +3879,9 @@ func TestReusePort(t *testing.T) { if err != nil { t.Fatalf("NewEndpoint failed; %s", err) } + if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { + t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) + } if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } @@ -3888,6 +3891,9 @@ func TestReusePort(t *testing.T) { if err != nil { t.Fatalf("NewEndpoint failed; %s", err) } + if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { + t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) + } if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } @@ -3898,6 +3904,9 @@ func TestReusePort(t *testing.T) { if err != nil { t.Fatalf("NewEndpoint failed; %s", err) } + if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { + t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) + } if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } @@ -3910,6 +3919,9 @@ func TestReusePort(t *testing.T) { if err != nil { t.Fatalf("NewEndpoint failed; %s", err) } + if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { + t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) + } if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } @@ -3920,6 +3932,9 @@ func TestReusePort(t *testing.T) { if err != nil { t.Fatalf("NewEndpoint failed; %s", err) } + if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { + t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) + } if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } @@ -3932,6 +3947,9 @@ func TestReusePort(t *testing.T) { if err != nil { t.Fatalf("NewEndpoint failed; %s", err) } + if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { + t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) + } if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 40d66ef09..6ea212093 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -231,7 +231,7 @@ func (e *endpoint) Close() { switch e.state { case StateBound, StateConnected: e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{}) e.boundBindToDevice = 0 e.boundPortFlags = ports.Flags{} } @@ -1047,7 +1047,7 @@ func (e *endpoint) Disconnect() *tcpip.Error { } else { if e.ID.LocalPort != 0 { // Release the ephemeral port. - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, boundPortFlags, e.boundBindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{}) e.boundPortFlags = ports.Flags{} } e.state = StateInitial @@ -1198,7 +1198,7 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) { if e.ID.LocalPort == 0 { - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, e.bindToDevice) + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, e.bindToDevice, tcpip.FullAddress{}) if err != nil { return id, e.bindToDevice, err } @@ -1208,7 +1208,7 @@ func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.Networ err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, e.bindToDevice) if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, e.bindToDevice) + e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, e.bindToDevice, tcpip.FullAddress{}) e.boundPortFlags = ports.Flags{} } return id, e.bindToDevice, err diff --git a/test/syscalls/linux/socket_bind_to_device_sequence.cc b/test/syscalls/linux/socket_bind_to_device_sequence.cc index 4b2cb0606..6557c8e8e 100644 --- a/test/syscalls/linux/socket_bind_to_device_sequence.cc +++ b/test/syscalls/linux/socket_bind_to_device_sequence.cc @@ -363,10 +363,6 @@ TEST_P(BindToDeviceSequenceTest, BindTwiceWithReuseOnce) { } TEST_P(BindToDeviceSequenceTest, BindWithReuseAddr) { - // FIXME(b/158253835): Support SO_REUSEADDR on TCP sockets. - SKIP_IF(IsRunningOnGvisor() && - (GetParam().type & SOCK_STREAM) == SOCK_STREAM); - ASSERT_NO_FATAL_FAILURE( BindSocket(/* reusePort */ false, /* reuse_addr */ true)); ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, @@ -406,10 +402,6 @@ TEST_P(BindToDeviceSequenceTest, BindReuseAddrReusePortThenReusePort) { } TEST_P(BindToDeviceSequenceTest, BindReuseAddrReusePortThenReuseAddr) { - // FIXME(b/158253835): Support SO_REUSEADDR on TCP sockets. - SKIP_IF(IsRunningOnGvisor() && - (GetParam().type & SOCK_STREAM) == SOCK_STREAM); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, /* reuse_addr */ true, /* bind_to_device */ 0)); @@ -436,10 +428,6 @@ TEST_P(BindToDeviceSequenceTest, BindDoubleReuseAddrReusePortThenReusePort) { } TEST_P(BindToDeviceSequenceTest, BindDoubleReuseAddrReusePortThenReuseAddr) { - // FIXME(b/158253835): Support SO_REUSEADDR on TCP sockets. - SKIP_IF(IsRunningOnGvisor() && - (GetParam().type & SOCK_STREAM) == SOCK_STREAM); - ASSERT_NO_FATAL_FAILURE(BindSocket( /* reuse_port */ true, /* reuse_addr */ true, /* bind_to_device */ 0)); ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index a914d9a12..18b9e4b70 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -17,6 +17,7 @@ #include #include #include +#include #include #include @@ -686,24 +687,9 @@ TEST_P(SocketInetLoopbackTest, TCPFinWait2Test_NoRandomSave) { // be restarted causing the final bind/connect to fail. DisableSave ds; - // TODO(gvisor.dev/issue/1030): Portmanager does not track all 5 tuple - // reservations which causes the bind() to succeed on gVisor but connect - // correctly fails. - if (IsRunningOnGvisor()) { - ASSERT_THAT( - bind(conn_fd2.get(), reinterpret_cast(&conn_bound_addr), - conn_addrlen), - SyscallSucceeds()); - ASSERT_THAT(RetryEINTR(connect)(conn_fd2.get(), - reinterpret_cast(&conn_addr), - conn_addrlen), - SyscallFailsWithErrno(EADDRINUSE)); - } else { - ASSERT_THAT( - bind(conn_fd2.get(), reinterpret_cast(&conn_bound_addr), - conn_addrlen), - SyscallFailsWithErrno(EADDRINUSE)); - } + ASSERT_THAT(bind(conn_fd2.get(), + reinterpret_cast(&conn_bound_addr), conn_addrlen), + SyscallFailsWithErrno(EADDRINUSE)); // Sleep for a little over the linger timeout to reduce flakiness in // save/restore tests. @@ -1737,6 +1723,171 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, DualStackV6AnyReservesEverything) { ASSERT_THAT(bind(fd_v4.get(), reinterpret_cast(&addr_v4), test_addr_v4.addr_len), SyscallFailsWithErrno(EADDRINUSE)); + + // Verify that binding the v4 any on the same port with a v4 socket + // fails. + TestAddress const& test_addr_v4_any = V4Any(); + sockaddr_storage addr_v4_any = test_addr_v4_any.addr; + ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4_any.family(), &addr_v4_any, port)); + const FileDescriptor fd_v4_any = ASSERT_NO_ERRNO_AND_VALUE( + Socket(test_addr_v4_any.family(), param.type, 0)); + ASSERT_THAT(bind(fd_v4_any.get(), reinterpret_cast(&addr_v4_any), + test_addr_v4_any.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); +} + +TEST_P(SocketMultiProtocolInetLoopbackTest, + DualStackV6AnyReuseAddrDoesNotReserveV4Any) { + auto const& param = GetParam(); + + // Bind the v6 any on a dual stack socket. + TestAddress const& test_addr_dual = V6Any(); + sockaddr_storage addr_dual = test_addr_dual.addr; + const FileDescriptor fd_dual = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_dual.family(), param.type, 0)); + ASSERT_THAT(setsockopt(fd_dual.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(fd_dual.get(), reinterpret_cast(&addr_dual), + test_addr_dual.addr_len), + SyscallSucceeds()); + + // Get the port that we bound. + socklen_t addrlen = test_addr_dual.addr_len; + ASSERT_THAT(getsockname(fd_dual.get(), + reinterpret_cast(&addr_dual), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr_dual.family(), addr_dual)); + + // Verify that binding the v4 any on the same port with a v4 socket succeeds. + TestAddress const& test_addr_v4_any = V4Any(); + sockaddr_storage addr_v4_any = test_addr_v4_any.addr; + ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4_any.family(), &addr_v4_any, port)); + const FileDescriptor fd_v4_any = ASSERT_NO_ERRNO_AND_VALUE( + Socket(test_addr_v4_any.family(), param.type, 0)); + ASSERT_THAT(setsockopt(fd_v4_any.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(fd_v4_any.get(), reinterpret_cast(&addr_v4_any), + test_addr_v4_any.addr_len), + SyscallSucceeds()); +} + +TEST_P(SocketMultiProtocolInetLoopbackTest, + DualStackV6AnyReuseAddrListenReservesV4Any) { + auto const& param = GetParam(); + + // Only TCP sockets are supported. + SKIP_IF((param.type & SOCK_STREAM) == 0); + + // Bind the v6 any on a dual stack socket. + TestAddress const& test_addr_dual = V6Any(); + sockaddr_storage addr_dual = test_addr_dual.addr; + const FileDescriptor fd_dual = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_dual.family(), param.type, 0)); + ASSERT_THAT(setsockopt(fd_dual.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(fd_dual.get(), reinterpret_cast(&addr_dual), + test_addr_dual.addr_len), + SyscallSucceeds()); + + ASSERT_THAT(listen(fd_dual.get(), 5), SyscallSucceeds()); + + // Get the port that we bound. + socklen_t addrlen = test_addr_dual.addr_len; + ASSERT_THAT(getsockname(fd_dual.get(), + reinterpret_cast(&addr_dual), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr_dual.family(), addr_dual)); + + // Verify that binding the v4 any on the same port with a v4 socket succeeds. + TestAddress const& test_addr_v4_any = V4Any(); + sockaddr_storage addr_v4_any = test_addr_v4_any.addr; + ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4_any.family(), &addr_v4_any, port)); + const FileDescriptor fd_v4_any = ASSERT_NO_ERRNO_AND_VALUE( + Socket(test_addr_v4_any.family(), param.type, 0)); + ASSERT_THAT(setsockopt(fd_v4_any.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + ASSERT_THAT(bind(fd_v4_any.get(), reinterpret_cast(&addr_v4_any), + test_addr_v4_any.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); +} + +TEST_P(SocketMultiProtocolInetLoopbackTest, + DualStackV6AnyWithListenReservesEverything) { + auto const& param = GetParam(); + + // Only TCP sockets are supported. + SKIP_IF((param.type & SOCK_STREAM) == 0); + + // Bind the v6 any on a dual stack socket. + TestAddress const& test_addr_dual = V6Any(); + sockaddr_storage addr_dual = test_addr_dual.addr; + const FileDescriptor fd_dual = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_dual.family(), param.type, 0)); + ASSERT_THAT(bind(fd_dual.get(), reinterpret_cast(&addr_dual), + test_addr_dual.addr_len), + SyscallSucceeds()); + + ASSERT_THAT(listen(fd_dual.get(), 5), SyscallSucceeds()); + + // Get the port that we bound. + socklen_t addrlen = test_addr_dual.addr_len; + ASSERT_THAT(getsockname(fd_dual.get(), + reinterpret_cast(&addr_dual), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr_dual.family(), addr_dual)); + + // Verify that binding the v6 loopback with the same port fails. + TestAddress const& test_addr_v6 = V6Loopback(); + sockaddr_storage addr_v6 = test_addr_v6.addr; + ASSERT_NO_ERRNO(SetAddrPort(test_addr_v6.family(), &addr_v6, port)); + const FileDescriptor fd_v6 = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v6.family(), param.type, 0)); + ASSERT_THAT(bind(fd_v6.get(), reinterpret_cast(&addr_v6), + test_addr_v6.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); + + // Verify that binding the v4 loopback on the same port with a v6 socket + // fails. + TestAddress const& test_addr_v4_mapped = V4MappedLoopback(); + sockaddr_storage addr_v4_mapped = test_addr_v4_mapped.addr; + ASSERT_NO_ERRNO( + SetAddrPort(test_addr_v4_mapped.family(), &addr_v4_mapped, port)); + const FileDescriptor fd_v4_mapped = ASSERT_NO_ERRNO_AND_VALUE( + Socket(test_addr_v4_mapped.family(), param.type, 0)); + ASSERT_THAT( + bind(fd_v4_mapped.get(), reinterpret_cast(&addr_v4_mapped), + test_addr_v4_mapped.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); + + // Verify that binding the v4 loopback on the same port with a v4 socket + // fails. + TestAddress const& test_addr_v4 = V4Loopback(); + sockaddr_storage addr_v4 = test_addr_v4.addr; + ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4.family(), &addr_v4, port)); + const FileDescriptor fd_v4 = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v4.family(), param.type, 0)); + ASSERT_THAT(bind(fd_v4.get(), reinterpret_cast(&addr_v4), + test_addr_v4.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); + + // Verify that binding the v4 any on the same port with a v4 socket + // fails. + TestAddress const& test_addr_v4_any = V4Any(); + sockaddr_storage addr_v4_any = test_addr_v4_any.addr; + ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4_any.family(), &addr_v4_any, port)); + const FileDescriptor fd_v4_any = ASSERT_NO_ERRNO_AND_VALUE( + Socket(test_addr_v4_any.family(), param.type, 0)); + ASSERT_THAT(bind(fd_v4_any.get(), reinterpret_cast(&addr_v4_any), + test_addr_v4_any.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); } TEST_P(SocketMultiProtocolInetLoopbackTest, V6OnlyV6AnyReservesV6) { @@ -1798,10 +1949,6 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6OnlyV6AnyReservesV6) { TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReserved) { auto const& param = GetParam(); - // FIXME(b/76031995): Support disabling SO_REUSEADDR for TCP sockets and make - // it disabled by default. - SKIP_IF(IsRunningOnGvisor() && param.type == SOCK_STREAM); - for (int i = 0; true; i++) { // Bind the v6 loopback on a dual stack socket. TestAddress const& test_addr = V6Loopback(); @@ -1864,17 +2011,6 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReserved) { test_addr_v6.addr_len), SyscallFailsWithErrno(EADDRINUSE)); - // Verify that binding the v4 any with the same port fails. - TestAddress const& test_addr_v4_any = V4Any(); - sockaddr_storage addr_v4_any = test_addr_v4_any.addr; - ASSERT_NO_ERRNO( - SetAddrPort(test_addr_v4_any.family(), &addr_v4_any, ephemeral_port)); - const FileDescriptor fd_v4_any = ASSERT_NO_ERRNO_AND_VALUE( - Socket(test_addr_v4_any.family(), param.type, 0)); - ASSERT_THAT(bind(fd_v4_any.get(), reinterpret_cast(&addr_v4_any), - test_addr_v4_any.addr_len), - SyscallFailsWithErrno(EADDRINUSE)); - // Verify that we can still bind the v4 loopback on the same port. TestAddress const& test_addr_v4_mapped = V4MappedLoopback(); sockaddr_storage addr_v4_mapped = test_addr_v4_mapped.addr; @@ -1963,10 +2099,6 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReservedReuseAddr) { TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedEphemeralPortReserved) { auto const& param = GetParam(); - // FIXME(b/76031995): Support disabling SO_REUSEADDR for TCP sockets and make - // it disabled by default. - SKIP_IF(IsRunningOnGvisor() && param.type == SOCK_STREAM); - for (int i = 0; true; i++) { // Bind the v4 loopback on a dual stack socket. TestAddress const& test_addr = V4MappedLoopback(); @@ -2152,10 +2284,6 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReserved) { auto const& param = GetParam(); - // FIXME(b/76031995): Support disabling SO_REUSEADDR for TCP sockets and make - // it disabled by default. - SKIP_IF(IsRunningOnGvisor() && param.type == SOCK_STREAM); - for (int i = 0; true; i++) { // Bind the v4 loopback on a v4 socket. TestAddress const& test_addr = V4Loopback(); diff --git a/test/syscalls/linux/socket_ip_unbound.cc b/test/syscalls/linux/socket_ip_unbound.cc index 796598eee..1c7b0cf90 100644 --- a/test/syscalls/linux/socket_ip_unbound.cc +++ b/test/syscalls/linux/socket_ip_unbound.cc @@ -425,10 +425,6 @@ TEST_P(IPUnboundSocketTest, InsufficientBufferTOS) { TEST_P(IPUnboundSocketTest, ReuseAddrDefault) { auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - // FIXME(b/76031995): gVisor TCP sockets default to SO_REUSEADDR enabled. - SKIP_IF(IsRunningOnGvisor() && - (GetParam().type & SOCK_STREAM) == SOCK_STREAM); - int get = -1; socklen_t get_sz = sizeof(get); ASSERT_THAT( -- cgit v1.2.3 From b070e218c6fe61c6ef98e0a3af5ad58d7e627632 Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Wed, 24 Jun 2020 10:21:44 -0700 Subject: Add support for Stack level options. Linux controls socket send/receive buffers using a few sysctl variables - net.core.rmem_default - net.core.rmem_max - net.core.wmem_max - net.core.wmem_default - net.ipv4.tcp_rmem - net.ipv4.tcp_wmem The first 4 control the default socket buffer sizes for all sockets raw/packet/tcp/udp and also the maximum permitted socket buffer that can be specified in setsockopt(SOL_SOCKET, SO_(RCV|SND)BUF,...). The last two control the TCP auto-tuning limits and override the default specified in rmem_default/wmem_default as well as the max limits. Netstack today only implements tcp_rmem/tcp_wmem and incorrectly uses it to limit the maximum size in setsockopt() as well as uses it for raw/udp sockets. This changelist introduces the other 4 and updates the udp/raw sockets to use the newly introduced variables. The values for min/max match the current tcp_rmem/wmem values and the default value buffers for UDP/RAW sockets is updated to match the linux value of 212KiB up from the really low current value of 32 KiB. Updates #3043 Fixes #3043 PiperOrigin-RevId: 318089805 --- benchmarks/tcp/tcp_proxy.go | 2 +- pkg/sentry/socket/netstack/stack.go | 12 +-- pkg/tcpip/stack/BUILD | 1 + pkg/tcpip/stack/stack.go | 18 ++++ pkg/tcpip/stack/stack_options.go | 106 +++++++++++++++++++++ pkg/tcpip/stack/stack_test.go | 80 ++++++++++++++++ pkg/tcpip/tcpip.go | 27 +----- pkg/tcpip/transport/raw/endpoint.go | 20 ++-- pkg/tcpip/transport/tcp/connect.go | 2 +- pkg/tcpip/transport/tcp/endpoint.go | 16 ++-- pkg/tcpip/transport/tcp/endpoint_state.go | 10 +- pkg/tcpip/transport/tcp/protocol.go | 53 ++++++++--- pkg/tcpip/transport/tcp/tcp_sack_test.go | 4 +- pkg/tcpip/transport/tcp/tcp_test.go | 18 ++-- pkg/tcpip/transport/tcp/testing/context/context.go | 6 +- pkg/tcpip/transport/udp/endpoint.go | 20 ++-- pkg/tcpip/transport/udp/protocol.go | 49 +--------- runsc/boot/loader.go | 2 +- 18 files changed, 306 insertions(+), 140 deletions(-) create mode 100644 pkg/tcpip/stack/stack_options.go (limited to 'pkg/tcpip/stack') diff --git a/benchmarks/tcp/tcp_proxy.go b/benchmarks/tcp/tcp_proxy.go index b3a4dbea3..4b7ca7a14 100644 --- a/benchmarks/tcp/tcp_proxy.go +++ b/benchmarks/tcp/tcp_proxy.go @@ -228,7 +228,7 @@ func newNetstackImpl(mode string) (impl, error) { }) // Set protocol options. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackSACKEnabled(*sack)); err != nil { + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(*sack)); err != nil { return nil, fmt.Errorf("SetTransportProtocolOption for SACKEnabled failed: %s", err) } diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index d2fb655ea..548442b96 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -143,7 +143,7 @@ func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { // TCPReceiveBufferSize implements inet.Stack.TCPReceiveBufferSize. func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) { - var rs tcpip.StackReceiveBufferSizeOption + var rs tcp.ReceiveBufferSizeOption err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &rs) return inet.TCPBufferSize{ Min: rs.Min, @@ -154,7 +154,7 @@ func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) { // SetTCPReceiveBufferSize implements inet.Stack.SetTCPReceiveBufferSize. func (s *Stack) SetTCPReceiveBufferSize(size inet.TCPBufferSize) error { - rs := tcpip.StackReceiveBufferSizeOption{ + rs := tcp.ReceiveBufferSizeOption{ Min: size.Min, Default: size.Default, Max: size.Max, @@ -164,7 +164,7 @@ func (s *Stack) SetTCPReceiveBufferSize(size inet.TCPBufferSize) error { // TCPSendBufferSize implements inet.Stack.TCPSendBufferSize. func (s *Stack) TCPSendBufferSize() (inet.TCPBufferSize, error) { - var ss tcpip.StackSendBufferSizeOption + var ss tcp.SendBufferSizeOption err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &ss) return inet.TCPBufferSize{ Min: ss.Min, @@ -175,7 +175,7 @@ func (s *Stack) TCPSendBufferSize() (inet.TCPBufferSize, error) { // SetTCPSendBufferSize implements inet.Stack.SetTCPSendBufferSize. func (s *Stack) SetTCPSendBufferSize(size inet.TCPBufferSize) error { - ss := tcpip.StackSendBufferSizeOption{ + ss := tcp.SendBufferSizeOption{ Min: size.Min, Default: size.Default, Max: size.Max, @@ -185,14 +185,14 @@ func (s *Stack) SetTCPSendBufferSize(size inet.TCPBufferSize) error { // TCPSACKEnabled implements inet.Stack.TCPSACKEnabled. func (s *Stack) TCPSACKEnabled() (bool, error) { - var sack tcpip.StackSACKEnabled + var sack tcp.SACKEnabled err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &sack) return bool(sack), syserr.TranslateNetstackError(err).ToError() } // SetTCPSACKEnabled implements inet.Stack.SetTCPSACKEnabled. func (s *Stack) SetTCPSACKEnabled(enabled bool) error { - return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackSACKEnabled(enabled))).ToError() + return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enabled))).ToError() } // Statistics implements inet.Stack.Statistics. diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 24f52b735..794ddb5c8 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -48,6 +48,7 @@ go_library( "route.go", "stack.go", "stack_global_state.go", + "stack_options.go", "transport_demuxer.go", ], visibility = ["//visibility:public"], diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index e92ec0c24..cdcfb8321 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -471,6 +471,14 @@ type Stack struct { // randomGenerator is an injectable pseudo random generator that can be // used when a random number is required. randomGenerator *mathrand.Rand + + // sendBufferSize holds the min/default/max send buffer sizes for + // endpoints other than TCP. + sendBufferSize SendBufferSizeOption + + // receiveBufferSize holds the min/default/max receive buffer sizes for + // endpoints other than TCP. + receiveBufferSize ReceiveBufferSizeOption } // UniqueID is an abstract generator of unique identifiers. @@ -683,6 +691,16 @@ func New(opts Options) *Stack { tempIIDSeed: opts.TempIIDSeed, forwarder: newForwardQueue(), randomGenerator: mathrand.New(randSrc), + sendBufferSize: SendBufferSizeOption{ + Min: MinBufferSize, + Default: DefaultBufferSize, + Max: DefaultMaxBufferSize, + }, + receiveBufferSize: ReceiveBufferSizeOption{ + Min: MinBufferSize, + Default: DefaultBufferSize, + Max: DefaultMaxBufferSize, + }, } // Add specified network protocols. diff --git a/pkg/tcpip/stack/stack_options.go b/pkg/tcpip/stack/stack_options.go new file mode 100644 index 000000000..0b093e6c5 --- /dev/null +++ b/pkg/tcpip/stack/stack_options.go @@ -0,0 +1,106 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import "gvisor.dev/gvisor/pkg/tcpip" + +const ( + // MinBufferSize is the smallest size of a receive or send buffer. + MinBufferSize = 4 << 10 // 4 KiB + + // DefaultBufferSize is the default size of the send/recv buffer for a + // transport endpoint. + DefaultBufferSize = 212 << 10 // 212 KiB + + // DefaultMaxBufferSize is the default maximum permitted size of a + // send/receive buffer. + DefaultMaxBufferSize = 4 << 20 // 4 MiB +) + +// SendBufferSizeOption is used by stack.(Stack*).Option/SetOption to +// get/set the default, min and max send buffer sizes. +type SendBufferSizeOption struct { + Min int + Default int + Max int +} + +// ReceiveBufferSizeOption is used by stack.(Stack*).Option/SetOption to +// get/set the default, min and max receive buffer sizes. +type ReceiveBufferSizeOption struct { + Min int + Default int + Max int +} + +// SetOption allows setting stack wide options. +func (s *Stack) SetOption(option interface{}) *tcpip.Error { + switch v := option.(type) { + case SendBufferSizeOption: + // Make sure we don't allow lowering the buffer below minimum + // required for stack to work. + if v.Min < MinBufferSize { + return tcpip.ErrInvalidOptionValue + } + + if v.Default < v.Min || v.Default > v.Max { + return tcpip.ErrInvalidOptionValue + } + + s.mu.Lock() + s.sendBufferSize = v + s.mu.Unlock() + return nil + + case ReceiveBufferSizeOption: + // Make sure we don't allow lowering the buffer below minimum + // required for stack to work. + if v.Min < MinBufferSize { + return tcpip.ErrInvalidOptionValue + } + + if v.Default < v.Min || v.Default > v.Max { + return tcpip.ErrInvalidOptionValue + } + + s.mu.Lock() + s.receiveBufferSize = v + s.mu.Unlock() + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } +} + +// Option allows retrieving stack wide options. +func (s *Stack) Option(option interface{}) *tcpip.Error { + switch v := option.(type) { + case *SendBufferSizeOption: + s.mu.RLock() + *v = s.sendBufferSize + s.mu.RUnlock() + return nil + + case *ReceiveBufferSizeOption: + s.mu.RLock() + *v = s.receiveBufferSize + s.mu.RUnlock() + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } +} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 5aacbf53e..7657a4101 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -3338,3 +3338,83 @@ func TestDoDADWhenNICEnabled(t *testing.T) { t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, addr.AddressWithPrefix) } } + +func TestStackReceiveBufferSizeOption(t *testing.T) { + const sMin = stack.MinBufferSize + testCases := []struct { + name string + rs stack.ReceiveBufferSizeOption + err *tcpip.Error + }{ + // Invalid configurations. + {"min_below_zero", stack.ReceiveBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, + {"min_zero", stack.ReceiveBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, + {"default_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin - 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, + {"default_above_max", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin}, tcpip.ErrInvalidOptionValue}, + {"max_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, + + // Valid Configurations + {"in_ascending_order", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil}, + {"all_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil}, + {"min_default_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil}, + {"default_max_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + s := stack.New(stack.Options{}) + defer s.Close() + if err := s.SetOption(tc.rs); err != tc.err { + t.Fatalf("s.SetOption(%#v) = %v, want: %v", tc.rs, err, tc.err) + } + var rs stack.ReceiveBufferSizeOption + if tc.err == nil { + if err := s.Option(&rs); err != nil { + t.Fatalf("s.Option(%#v) = %v, want: nil", rs, err) + } + if got, want := rs, tc.rs; got != want { + t.Fatalf("s.Option(..) returned unexpected value got: %#v, want: %#v", got, want) + } + } + }) + } +} + +func TestStackSendBufferSizeOption(t *testing.T) { + const sMin = stack.MinBufferSize + testCases := []struct { + name string + ss stack.SendBufferSizeOption + err *tcpip.Error + }{ + // Invalid configurations. + {"min_below_zero", stack.SendBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, + {"min_zero", stack.SendBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, + {"default_below_min", stack.SendBufferSizeOption{Min: 0, Default: sMin - 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, + {"default_above_max", stack.SendBufferSizeOption{Min: 0, Default: sMin + 1, Max: sMin}, tcpip.ErrInvalidOptionValue}, + {"max_below_min", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, + + // Valid Configurations + {"in_ascending_order", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil}, + {"all_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil}, + {"min_default_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil}, + {"default_max_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + s := stack.New(stack.Options{}) + defer s.Close() + if err := s.SetOption(tc.ss); err != tc.err { + t.Fatalf("s.SetOption(%+v) = %v, want: %v", tc.ss, err, tc.err) + } + var ss stack.SendBufferSizeOption + if tc.err == nil { + if err := s.Option(&ss); err != nil { + t.Fatalf("s.Option(%+v) = %v, want: nil", ss, err) + } + if got, want := ss, tc.ss; got != want { + t.Fatalf("s.Option(..) returned unexpected value got: %#v, want: %#v", got, want) + } + } + }) + } +} diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 956232a44..4d45dcc42 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -813,33 +813,8 @@ type OutOfBandInlineOption int // a default TTL. type DefaultTTLOption uint8 -// StackSACKEnabled is used by stack.(*Stack).TransportProtocolOption to -// enable/disable SACK support in TCP. See: https://tools.ietf.org/html/rfc2018. -type StackSACKEnabled bool - -// StackDelayEnabled is used by stack.(Stack*).TransportProtocolOption to -// enable/disable Nagle's algorithm in TCP. -type StackDelayEnabled bool - -// StackSendBufferSizeOption is used by stack.(Stack*).TransportProtocolOption -// to get/set the default, min and max send buffer sizes. -type StackSendBufferSizeOption struct { - Min int - Default int - Max int -} - -// StackReceiveBufferSizeOption is used by -// stack.(Stack*).TransportProtocolOption to get/set the default, min and max -// receive buffer sizes. -type StackReceiveBufferSizeOption struct { - Min int - Default int - Max int -} - // -// IPPacketInfo is the message struture for IP_PKTINFO. +// IPPacketInfo is the message structure for IP_PKTINFO. // // +stateify savable type IPPacketInfo struct { diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 6a7977259..dd514d397 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -111,13 +111,13 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt } // Override with stack defaults. - var ss tcpip.StackSendBufferSizeOption - if err := s.TransportProtocolOption(transProto, &ss); err == nil { + var ss stack.SendBufferSizeOption + if err := s.Option(&ss); err == nil { e.sndBufSizeMax = ss.Default } - var rs tcpip.StackReceiveBufferSizeOption - if err := s.TransportProtocolOption(transProto, &rs); err == nil { + var rs stack.ReceiveBufferSizeOption + if err := s.Option(&rs); err == nil { e.rcvBufSizeMax = rs.Default } @@ -541,9 +541,9 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { case tcpip.SendBufferSizeOption: // Make sure the send buffer size is within the min and max // allowed. - var ss tcpip.StackSendBufferSizeOption - if err := e.stack.TransportProtocolOption(e.TransProto, &ss); err != nil { - panic(fmt.Sprintf("s.TransportProtocolOption(%d, %+v) = %s", e.TransProto, ss, err)) + var ss stack.SendBufferSizeOption + if err := e.stack.Option(&ss); err != nil { + panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err)) } if v > ss.Max { v = ss.Max @@ -559,9 +559,9 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. - var rs tcpip.StackReceiveBufferSizeOption - if err := e.stack.TransportProtocolOption(e.TransProto, &rs); err != nil { - panic(fmt.Sprintf("s.TransportProtocolOption(%d, %+v) = %s", e.TransProto, rs, err)) + var rs stack.ReceiveBufferSizeOption + if err := e.stack.Option(&rs); err != nil { + panic(fmt.Sprintf("s.Option(%#v) = %s", rs, err)) } if v > rs.Max { v = rs.Max diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 377643b82..9d4dce826 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -521,7 +521,7 @@ func (h *handshake) execute() *tcpip.Error { s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment) defer s.Done() - var sackEnabled tcpip.StackSACKEnabled + var sackEnabled SACKEnabled if err := h.ep.stack.TransportProtocolOption(ProtocolNumber, &sackEnabled); err != nil { // If stack returned an error when checking for SACKEnabled // status then just default to switching off SACK negotiation. diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 1e4c2f507..99a691815 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -847,12 +847,12 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue maxSynRetries: DefaultSynRetries, } - var ss tcpip.StackSendBufferSizeOption + var ss SendBufferSizeOption if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil { e.sndBufSize = ss.Default } - var rs tcpip.StackReceiveBufferSizeOption + var rs ReceiveBufferSizeOption if err := s.TransportProtocolOption(ProtocolNumber, &rs); err == nil { e.rcvBufSize = rs.Default } @@ -867,7 +867,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue e.rcvAutoParams.disabled = !bool(mrb) } - var de tcpip.StackDelayEnabled + var de DelayEnabled if err := s.TransportProtocolOption(ProtocolNumber, &de); err == nil && de { e.SetSockOptBool(tcpip.DelayOption, true) } @@ -1584,7 +1584,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. - var rs tcpip.StackReceiveBufferSizeOption + var rs ReceiveBufferSizeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { if v < rs.Min { v = rs.Min @@ -1634,7 +1634,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { case tcpip.SendBufferSizeOption: // Make sure the send buffer size is within the min and max // allowed. - var ss tcpip.StackSendBufferSizeOption + var ss SendBufferSizeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { if v < ss.Min { v = ss.Min @@ -1674,7 +1674,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return tcpip.ErrInvalidOptionValue } } - var rs tcpip.StackReceiveBufferSizeOption + var rs ReceiveBufferSizeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { if v < rs.Min/2 { v = rs.Min / 2 @@ -2595,7 +2595,7 @@ func (e *endpoint) receiveBufferSize() int { } func (e *endpoint) maxReceiveBufferSize() int { - var rs tcpip.StackReceiveBufferSizeOption + var rs ReceiveBufferSizeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil { // As a fallback return the hardcoded max buffer size. return MaxBufferSize @@ -2676,7 +2676,7 @@ func timeStampOffset() uint32 { // if the SYN options indicate that the SACK option was negotiated and the TCP // stack is configured to enable TCP SACK option. func (e *endpoint) maybeEnableSACKPermitted(synOpts *header.TCPSynOptions) { - var v tcpip.StackSACKEnabled + var v SACKEnabled if err := e.stack.TransportProtocolOption(ProtocolNumber, &v); err != nil { // Stack doesn't support SACK. So just return. return diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 8258c0ecc..abf1ac5c9 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -182,13 +182,17 @@ func (e *endpoint) Resume(s *stack.Stack) { epState := e.origEndpointState switch epState { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: - var ss tcpip.StackSendBufferSizeOption + var ss SendBufferSizeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max { panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max)) } - if e.rcvBufSize < ss.Min || e.rcvBufSize > ss.Max { - panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, ss.Min, ss.Max)) + } + + var rs ReceiveBufferSizeOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { + if e.rcvBufSize < rs.Min || e.rcvBufSize > rs.Max { + panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, rs.Min, rs.Max)) } } } diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 3cff55afa..f2ae6ce50 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -76,6 +76,31 @@ const ( ccCubic = "cubic" ) +// SACKEnabled is used by stack.(*Stack).TransportProtocolOption to +// enable/disable SACK support in TCP. See: https://tools.ietf.org/html/rfc2018. +type SACKEnabled bool + +// DelayEnabled is used by stack.(Stack*).TransportProtocolOption to +// enable/disable Nagle's algorithm in TCP. +type DelayEnabled bool + +// SendBufferSizeOption is used by stack.(Stack*).TransportProtocolOption +// to get/set the default, min and max TCP send buffer sizes. +type SendBufferSizeOption struct { + Min int + Default int + Max int +} + +// ReceiveBufferSizeOption is used by +// stack.(Stack*).TransportProtocolOption to get/set the default, min and max +// TCP receive buffer sizes. +type ReceiveBufferSizeOption struct { + Min int + Default int + Max int +} + // syncRcvdCounter tracks the number of endpoints in the SYN-RCVD state. The // value is protected by a mutex so that we can increment only when it's // guaranteed not to go above a threshold. @@ -137,8 +162,8 @@ type protocol struct { mu sync.RWMutex sackEnabled bool delayEnabled bool - sendBufferSize tcpip.StackSendBufferSizeOption - recvBufferSize tcpip.StackReceiveBufferSizeOption + sendBufferSize SendBufferSizeOption + recvBufferSize ReceiveBufferSizeOption congestionControl string availableCongestionControl []string moderateReceiveBuffer bool @@ -249,19 +274,19 @@ func replyWithReset(s *segment, tos, ttl uint8) { // SetOption implements stack.TransportProtocol.SetOption. func (p *protocol) SetOption(option interface{}) *tcpip.Error { switch v := option.(type) { - case tcpip.StackSACKEnabled: + case SACKEnabled: p.mu.Lock() p.sackEnabled = bool(v) p.mu.Unlock() return nil - case tcpip.StackDelayEnabled: + case DelayEnabled: p.mu.Lock() p.delayEnabled = bool(v) p.mu.Unlock() return nil - case tcpip.StackSendBufferSizeOption: + case SendBufferSizeOption: if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { return tcpip.ErrInvalidOptionValue } @@ -270,7 +295,7 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { p.mu.Unlock() return nil - case tcpip.StackReceiveBufferSizeOption: + case ReceiveBufferSizeOption: if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { return tcpip.ErrInvalidOptionValue } @@ -363,25 +388,25 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { // Option implements stack.TransportProtocol.Option. func (p *protocol) Option(option interface{}) *tcpip.Error { switch v := option.(type) { - case *tcpip.StackSACKEnabled: + case *SACKEnabled: p.mu.RLock() - *v = tcpip.StackSACKEnabled(p.sackEnabled) + *v = SACKEnabled(p.sackEnabled) p.mu.RUnlock() return nil - case *tcpip.StackDelayEnabled: + case *DelayEnabled: p.mu.RLock() - *v = tcpip.StackDelayEnabled(p.delayEnabled) + *v = DelayEnabled(p.delayEnabled) p.mu.RUnlock() return nil - case *tcpip.StackSendBufferSizeOption: + case *SendBufferSizeOption: p.mu.RLock() *v = p.sendBufferSize p.mu.RUnlock() return nil - case *tcpip.StackReceiveBufferSizeOption: + case *ReceiveBufferSizeOption: p.mu.RLock() *v = p.recvBufferSize p.mu.RUnlock() @@ -491,12 +516,12 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool { // NewProtocol returns a TCP transport protocol. func NewProtocol() stack.TransportProtocol { return &protocol{ - sendBufferSize: tcpip.StackSendBufferSizeOption{ + sendBufferSize: SendBufferSizeOption{ Min: MinBufferSize, Default: DefaultSendBufferSize, Max: MaxBufferSize, }, - recvBufferSize: tcpip.StackReceiveBufferSizeOption{ + recvBufferSize: ReceiveBufferSizeOption{ Min: MinBufferSize, Default: DefaultReceiveBufferSize, Max: MaxBufferSize, diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go index 812e503bc..99521f0c1 100644 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go @@ -46,8 +46,8 @@ func createConnectedWithSACKAndTS(c *context.Context) *context.RawEndpoint { func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) { t.Helper() - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackSACKEnabled(enable)); err != nil { - t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, StackSACKEnabled(%t) = %s", enable, err) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enable)); err != nil { + t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, SACKEnabled(%t) = %s", enable, err) } } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 2632a3c67..169adb16b 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -4005,7 +4005,7 @@ func TestDefaultBufferSizes(t *testing.T) { checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize) // Change the default send buffer size. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackSendBufferSizeOption{ + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{ Min: 1, Default: tcp.DefaultSendBufferSize * 2, Max: tcp.DefaultSendBufferSize * 20}); err != nil { @@ -4022,7 +4022,7 @@ func TestDefaultBufferSizes(t *testing.T) { checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize) // Change the default receive buffer size. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackReceiveBufferSizeOption{ + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{ Min: 1, Default: tcp.DefaultReceiveBufferSize * 3, Max: tcp.DefaultReceiveBufferSize * 30}); err != nil { @@ -4053,11 +4053,11 @@ func TestMinMaxBufferSizes(t *testing.T) { defer ep.Close() // Change the min/max values for send/receive - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackReceiveBufferSizeOption{Min: 200, Default: tcp.DefaultReceiveBufferSize * 2, Max: tcp.DefaultReceiveBufferSize * 20}); err != nil { + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 200, Default: tcp.DefaultReceiveBufferSize * 2, Max: tcp.DefaultReceiveBufferSize * 20}); err != nil { t.Fatalf("SetTransportProtocolOption failed: %s", err) } - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackSendBufferSizeOption{Min: 300, Default: tcp.DefaultSendBufferSize * 3, Max: tcp.DefaultSendBufferSize * 30}); err != nil { + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 300, Default: tcp.DefaultSendBufferSize * 3, Max: tcp.DefaultSendBufferSize * 30}); err != nil { t.Fatalf("SetTransportProtocolOption failed: %s", err) } @@ -5696,7 +5696,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // the segment queue holding unprocessed packets is limited to 500. const receiveBufferSize = 80 << 10 // 80KB. const maxReceiveBufferSize = receiveBufferSize * 10 - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil { + if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil { t.Fatalf("SetTransportProtocolOption failed: %s", err) } @@ -5817,7 +5817,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) { // the segment queue holding unprocessed packets is limited to 300. const receiveBufferSize = 80 << 10 // 80KB. const maxReceiveBufferSize = receiveBufferSize * 10 - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil { + if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil { t.Fatalf("SetTransportProtocolOption failed: %s", err) } @@ -5959,7 +5959,7 @@ func TestDelayEnabled(t *testing.T) { checkDelayOption(t, c, false, false) // Delay is disabled by default. for _, v := range []struct { - delayEnabled tcpip.StackDelayEnabled + delayEnabled tcp.DelayEnabled wantDelayOption bool }{ {delayEnabled: false, wantDelayOption: false}, @@ -5974,10 +5974,10 @@ func TestDelayEnabled(t *testing.T) { } } -func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcpip.StackDelayEnabled, wantDelayOption bool) { +func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.DelayEnabled, wantDelayOption bool) { t.Helper() - var gotDelayEnabled tcpip.StackDelayEnabled + var gotDelayEnabled tcp.DelayEnabled if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &gotDelayEnabled); err != nil { t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %s", err) } diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 9e262c272..06fde2a79 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -144,11 +144,11 @@ func New(t *testing.T, mtu uint32) *Context { }) // Allow minimum send/receive buffer sizes to be 1 during tests. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackSendBufferSizeOption{Min: 1, Default: tcp.DefaultSendBufferSize, Max: 10 * tcp.DefaultSendBufferSize}); err != nil { + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 1, Default: tcp.DefaultSendBufferSize, Max: 10 * tcp.DefaultSendBufferSize}); err != nil { t.Fatalf("SetTransportProtocolOption failed: %s", err) } - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackReceiveBufferSizeOption{Min: 1, Default: tcp.DefaultReceiveBufferSize, Max: 10 * tcp.DefaultReceiveBufferSize}); err != nil { + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: tcp.DefaultReceiveBufferSize, Max: 10 * tcp.DefaultReceiveBufferSize}); err != nil { t.Fatalf("SetTransportProtocolOption failed: %s", err) } @@ -1091,7 +1091,7 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions // SACKEnabled returns true if the TCP Protocol option SACKEnabled is set to true // for the Stack in the context. func (c *Context) SACKEnabled() bool { - var v tcpip.StackSACKEnabled + var v tcp.SACKEnabled if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &v); err != nil { // Stack doesn't support SACK. So just return. return false diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 6ea212093..8bdc1ee1f 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -190,13 +190,13 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue } // Override with stack defaults. - var ss tcpip.StackSendBufferSizeOption - if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil { + var ss stack.SendBufferSizeOption + if err := s.Option(&ss); err == nil { e.sndBufSizeMax = ss.Default } - var rs tcpip.StackReceiveBufferSizeOption - if err := s.TransportProtocolOption(ProtocolNumber, &rs); err == nil { + var rs stack.ReceiveBufferSizeOption + if err := s.Option(&rs); err == nil { e.rcvBufSizeMax = rs.Default } @@ -629,9 +629,9 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. - var rs tcpip.StackReceiveBufferSizeOption - if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil { - panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %+v) = %s", ProtocolNumber, rs, err)) + var rs stack.ReceiveBufferSizeOption + if err := e.stack.Option(&rs); err != nil { + panic(fmt.Sprintf("e.stack.Option(%#v) = %s", rs, err)) } if v < rs.Min { @@ -648,9 +648,9 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { case tcpip.SendBufferSizeOption: // Make sure the send buffer size is within the min and max // allowed. - var ss tcpip.StackSendBufferSizeOption - if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err != nil { - panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %+v) = %s", ProtocolNumber, ss, err)) + var ss stack.SendBufferSizeOption + if err := e.stack.Option(&ss); err != nil { + panic(fmt.Sprintf("e.stack.Option(%#v) = %s", ss, err)) } if v < ss.Min { diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index fc93f93c0..0e7464e3a 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -21,7 +21,6 @@ package udp import ( - "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -50,9 +49,6 @@ const ( ) type protocol struct { - mu sync.RWMutex - sendBufferSize tcpip.StackSendBufferSizeOption - recvBufferSize tcpip.StackReceiveBufferSizeOption } // Number returns the udp protocol number. @@ -203,48 +199,12 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans // SetOption implements stack.TransportProtocol.SetOption. func (p *protocol) SetOption(option interface{}) *tcpip.Error { - switch v := option.(type) { - case tcpip.StackSendBufferSizeOption: - if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { - return tcpip.ErrInvalidOptionValue - } - p.mu.Lock() - p.sendBufferSize = v - p.mu.Unlock() - return nil - - case tcpip.StackReceiveBufferSizeOption: - if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { - return tcpip.ErrInvalidOptionValue - } - p.mu.Lock() - p.recvBufferSize = v - p.mu.Unlock() - return nil - - default: - return tcpip.ErrUnknownProtocolOption - } + return tcpip.ErrUnknownProtocolOption } // Option implements stack.TransportProtocol.Option. func (p *protocol) Option(option interface{}) *tcpip.Error { - switch v := option.(type) { - case *tcpip.StackSendBufferSizeOption: - p.mu.RLock() - *v = p.sendBufferSize - p.mu.RUnlock() - return nil - - case *tcpip.StackReceiveBufferSizeOption: - p.mu.RLock() - *v = p.recvBufferSize - p.mu.RUnlock() - return nil - - default: - return tcpip.ErrUnknownProtocolOption - } + return tcpip.ErrUnknownProtocolOption } // Close implements stack.TransportProtocol.Close. @@ -267,8 +227,5 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool { // NewProtocol returns a UDP transport protocol. func NewProtocol() stack.TransportProtocol { - return &protocol{ - sendBufferSize: tcpip.StackSendBufferSizeOption{Min: MinBufferSize, Default: DefaultSendBufferSize, Max: MaxBufferSize}, - recvBufferSize: tcpip.StackReceiveBufferSizeOption{Min: MinBufferSize, Default: DefaultReceiveBufferSize, Max: MaxBufferSize}, - } + return &protocol{} } diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index 081db39c1..b5df1deb9 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -1058,7 +1058,7 @@ func newEmptySandboxNetworkStack(clock tcpip.Clock, uniqueID stack.UniqueID) (in })} // Enable SACK Recovery. - if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackSACKEnabled(true)); err != nil { + if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(true)); err != nil { return nil, fmt.Errorf("failed to enable SACK: %s", err) } -- cgit v1.2.3 From 7fb6cc286fffab2f408a5dbc228e0db706104682 Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Thu, 25 Jun 2020 21:20:29 -0700 Subject: conntrack refactor, no behavior changes - Split connTrackForPacket into 2 functions instead of switching on flag - Replace hash with struct keys. - Remove prefixes where possible - Remove unused connStatus, timeout - Flatten ConnTrack struct a bit - some intermediate structs had no meaning outside of the context of their parent. - Protect conn.tcb with a mutex - Remove redundant error checking (e.g. when is pkt.NetworkHeader valid) - Clarify that HandlePacket and CreateConnFor are the expected entrypoints for ConnTrack PiperOrigin-RevId: 318407168 --- pkg/sentry/socket/netfilter/targets.go | 2 +- pkg/tcpip/stack/conntrack.go | 405 ++++++++++++--------------------- pkg/tcpip/stack/iptables.go | 7 +- pkg/tcpip/stack/iptables_targets.go | 23 +- pkg/tcpip/stack/iptables_types.go | 4 +- 5 files changed, 168 insertions(+), 273 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go index 84abe8d29..b91ba3ab3 100644 --- a/pkg/sentry/socket/netfilter/targets.go +++ b/pkg/sentry/socket/netfilter/targets.go @@ -30,6 +30,6 @@ type JumpTarget struct { } // Action implements stack.Target.Action. -func (jt JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrackTable, stack.Hook, *stack.GSO, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) { +func (jt JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.GSO, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) { return stack.RuleJump, jt.RuleNum } diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 05bf62788..af9c325ca 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -15,39 +15,29 @@ package stack import ( - "encoding/binary" "sync" - "time" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack" ) // Connection tracking is used to track and manipulate packets for NAT rules. -// The connection is created for a packet if it does not exist. Every connection -// contains two tuples (original and reply). The tuples are manipulated if there -// is a matching NAT rule. The packet is modified by looking at the tuples in the -// Prerouting and Output hooks. +// The connection is created for a packet if it does not exist. Every +// connection contains two tuples (original and reply). The tuples are +// manipulated if there is a matching NAT rule. The packet is modified by +// looking at the tuples in the Prerouting and Output hooks. +// +// Currently, only TCP tracking is supported. // Direction of the tuple. -type ctDirection int +type direction int const ( - dirOriginal ctDirection = iota + dirOriginal direction = iota dirReply ) -// Status of connection. -// TODO(gvisor.dev/issue/170): Add other states of connection. -type connStatus int - -const ( - connNew connStatus = iota - connEstablished -) - // Manipulation type for the connection. type manipType int @@ -56,250 +46,171 @@ const ( manipDstOutput ) -// connTrackMutable is the manipulatable part of the tuple. -type connTrackMutable struct { - // addr is source address of the tuple. - addr tcpip.Address +// tuple holds a connection's identifying and manipulating data in one +// direction. It is immutable. +type tuple struct { + tupleID - // port is source port of the tuple. - port uint16 + // conn is the connection tracking entry this tuple belongs to. + conn *conn - // protocol is network layer protocol. - protocol tcpip.NetworkProtocolNumber + // direction is the direction of the tuple. + direction direction } -// connTrackImmutable is the non-manipulatable part of the tuple. -type connTrackImmutable struct { - // addr is destination address of the tuple. - addr tcpip.Address - - // direction is direction (original or reply) of the tuple. - direction ctDirection - - // port is destination port of the tuple. - port uint16 - - // protocol is transport layer protocol. - protocol tcpip.TransportProtocolNumber +// tupleID uniquely identifies a connection in one direction. It currently +// contains enough information to distinguish between any TCP or UDP +// connection, and will need to be extended to support other protocols. +type tupleID struct { + srcAddr tcpip.Address + srcPort uint16 + dstAddr tcpip.Address + dstPort uint16 + transProto tcpip.TransportProtocolNumber + netProto tcpip.NetworkProtocolNumber } -// connTrackTuple represents the tuple which is created from the -// packet. -type connTrackTuple struct { - // dst is non-manipulatable part of the tuple. - dst connTrackImmutable - - // src is manipulatable part of the tuple. - src connTrackMutable +// reply creates the reply tupleID. +func (ti tupleID) reply() tupleID { + return tupleID{ + srcAddr: ti.dstAddr, + srcPort: ti.dstPort, + dstAddr: ti.srcAddr, + dstPort: ti.srcPort, + transProto: ti.transProto, + netProto: ti.netProto, + } } -// connTrackTupleHolder is the container of tuple and connection. -type ConnTrackTupleHolder struct { - // conn is pointer to the connection tracking entry. - conn *connTrack - - // tuple is original or reply tuple. - tuple connTrackTuple -} +// conn is a tracked connection. +type conn struct { + // original is the tuple in original direction. It is immutable. + original tuple -// connTrack is the connection. -type connTrack struct { - // originalTupleHolder contains tuple in original direction. - originalTupleHolder ConnTrackTupleHolder + // reply is the tuple in reply direction. It is immutable. + reply tuple - // replyTupleHolder contains tuple in reply direction. - replyTupleHolder ConnTrackTupleHolder + // manip indicates if the packet should be manipulated. It is immutable. + manip manipType - // status indicates connection is new or established. - status connStatus + // tcbHook indicates if the packet is inbound or outbound to + // update the state of tcb. It is immutable. + tcbHook Hook - // timeout indicates the time connection should be active. - timeout time.Duration - - // manip indicates if the packet should be manipulated. - manip manipType + // mu protects tcb. + mu sync.Mutex // tcb is TCB control block. It is used to keep track of states - // of tcp connection. + // of tcp connection and is protected by mu. tcb tcpconntrack.TCB - - // tcbHook indicates if the packet is inbound or outbound to - // update the state of tcb. - tcbHook Hook } -// ConnTrackTable contains a map of all existing connections created for -// NAT rules. -type ConnTrackTable struct { - // connMu protects connTrackTable. - connMu sync.RWMutex - - // connTrackTable maintains a map of tuples needed for connection tracking - // for iptables NAT rules. The key for the map is an integer calculated - // using seed, source address, destination address, source port and - // destination port. - CtMap map[uint32]ConnTrackTupleHolder - - // seed is a one-time random value initialized at stack startup - // and is used in calculation of hash key for connection tracking - // table. - Seed uint32 -} +// ConnTrack tracks all connections created for NAT rules. Most users are +// expected to only call handlePacket and createConnFor. +type ConnTrack struct { + // mu protects conns. + mu sync.RWMutex -// packetToTuple converts packet to a tuple in original direction. -func packetToTuple(pkt *PacketBuffer, hook Hook) (connTrackTuple, *tcpip.Error) { - var tuple connTrackTuple + // conns maintains a map of tuples needed for connection tracking for + // iptables NAT rules. It is protected by mu. + conns map[tupleID]tuple +} - netHeader := header.IPv4(pkt.NetworkHeader) +// packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid +// TCP header. +func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) { // TODO(gvisor.dev/issue/170): Need to support for other // protocols as well. + netHeader := header.IPv4(pkt.NetworkHeader) if netHeader == nil || netHeader.TransportProtocol() != header.TCPProtocolNumber { - return tuple, tcpip.ErrUnknownProtocol + return tupleID{}, tcpip.ErrUnknownProtocol } tcpHeader := header.TCP(pkt.TransportHeader) if tcpHeader == nil { - return tuple, tcpip.ErrUnknownProtocol + return tupleID{}, tcpip.ErrUnknownProtocol } - tuple.src.addr = netHeader.SourceAddress() - tuple.src.port = tcpHeader.SourcePort() - tuple.src.protocol = header.IPv4ProtocolNumber - - tuple.dst.addr = netHeader.DestinationAddress() - tuple.dst.port = tcpHeader.DestinationPort() - tuple.dst.protocol = netHeader.TransportProtocol() - - return tuple, nil -} - -// getReplyTuple creates reply tuple for the given tuple. -func getReplyTuple(tuple connTrackTuple) connTrackTuple { - var replyTuple connTrackTuple - replyTuple.src.addr = tuple.dst.addr - replyTuple.src.port = tuple.dst.port - replyTuple.src.protocol = tuple.src.protocol - replyTuple.dst.addr = tuple.src.addr - replyTuple.dst.port = tuple.src.port - replyTuple.dst.protocol = tuple.dst.protocol - replyTuple.dst.direction = dirReply - - return replyTuple + return tupleID{ + srcAddr: netHeader.SourceAddress(), + srcPort: tcpHeader.SourcePort(), + dstAddr: netHeader.DestinationAddress(), + dstPort: tcpHeader.DestinationPort(), + transProto: netHeader.TransportProtocol(), + netProto: header.IPv4ProtocolNumber, + }, nil } -// makeNewConn creates new connection. -func makeNewConn(tuple, replyTuple connTrackTuple) connTrack { - var conn connTrack - conn.status = connNew - conn.originalTupleHolder.tuple = tuple - conn.originalTupleHolder.conn = &conn - conn.replyTupleHolder.tuple = replyTuple - conn.replyTupleHolder.conn = &conn - - return conn -} - -// getTupleHash returns hash of the tuple. The fields used for -// generating hash are seed (generated once for stack), source address, -// destination address, source port and destination ports. -func (ct *ConnTrackTable) getTupleHash(tuple connTrackTuple) uint32 { - h := jenkins.Sum32(ct.Seed) - h.Write([]byte(tuple.src.addr)) - h.Write([]byte(tuple.dst.addr)) - portBuf := make([]byte, 2) - binary.LittleEndian.PutUint16(portBuf, tuple.src.port) - h.Write([]byte(portBuf)) - binary.LittleEndian.PutUint16(portBuf, tuple.dst.port) - h.Write([]byte(portBuf)) - - return h.Sum32() +// newConn creates new connection. +func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn { + conn := conn{ + manip: manip, + tcbHook: hook, + } + conn.original = tuple{conn: &conn, tupleID: orig} + conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply} + return &conn } -// connTrackForPacket returns connTrack for packet. -// TODO(gvisor.dev/issue/170): Only TCP packets are supported. Need to support other -// transport protocols. -func (ct *ConnTrackTable) connTrackForPacket(pkt *PacketBuffer, hook Hook, createConn bool) (*connTrack, ctDirection) { - var dir ctDirection - tuple, err := packetToTuple(pkt, hook) +// connFor gets the conn for pkt if it exists, or returns nil +// if it does not. It returns an error when pkt does not contain a valid TCP +// header. +// TODO(gvisor.dev/issue/170): Only TCP packets are supported. Need to support +// other transport protocols. +func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) { + tid, err := packetToTupleID(pkt) if err != nil { - return nil, dir + return nil, dirOriginal } - ct.connMu.Lock() - defer ct.connMu.Unlock() - - connTrackTable := ct.CtMap - hash := ct.getTupleHash(tuple) - - var conn *connTrack - switch createConn { - case true: - // If connection does not exist for the hash, create a new - // connection. - replyTuple := getReplyTuple(tuple) - replyHash := ct.getTupleHash(replyTuple) - newConn := makeNewConn(tuple, replyTuple) - conn = &newConn - - // Add tupleHolders to the map. - // TODO(gvisor.dev/issue/170): Need to support collisions using linked list. - ct.CtMap[hash] = conn.originalTupleHolder - ct.CtMap[replyHash] = conn.replyTupleHolder - default: - tupleHolder, ok := connTrackTable[hash] - if !ok { - return nil, dir - } - - // If this is the reply of new connection, set the connection - // status as ESTABLISHED. - conn = tupleHolder.conn - if conn.status == connNew && tupleHolder.tuple.dst.direction == dirReply { - conn.status = connEstablished - } - if tupleHolder.conn == nil { - panic("tupleHolder has null connection tracking entry") - } + ct.mu.Lock() + defer ct.mu.Unlock() - dir = tupleHolder.tuple.dst.direction + tuple, ok := ct.conns[tid] + if !ok { + return nil, dirOriginal } - return conn, dir + return tuple.conn, tuple.direction } -// SetNatInfo will manipulate the tuples according to iptables NAT rules. -func (ct *ConnTrackTable) SetNatInfo(pkt *PacketBuffer, rt RedirectTarget, hook Hook) { - // Get the connection. Connection is always created before this - // function is called. - conn, _ := ct.connTrackForPacket(pkt, hook, false) - if conn == nil { - panic("connection should be created to manipulate tuples.") +// createConnFor creates a new conn for pkt. +func (ct *ConnTrack) createConnFor(pkt *PacketBuffer, hook Hook, rt RedirectTarget) *conn { + tid, err := packetToTupleID(pkt) + if err != nil { + return nil + } + if hook != Prerouting && hook != Output { + return nil } - replyTuple := conn.replyTupleHolder.tuple - replyHash := ct.getTupleHash(replyTuple) - - // TODO(gvisor.dev/issue/170): Support only redirect of ports. Need to - // support changing of address for Prerouting. - // Change the port as per the iptables rule. This tuple will be used - // to manipulate the packet in HandlePacket. - conn.replyTupleHolder.tuple.src.addr = rt.MinIP - conn.replyTupleHolder.tuple.src.port = rt.MinPort - newHash := ct.getTupleHash(conn.replyTupleHolder.tuple) + // Create a new connection and change the port as per the iptables + // rule. This tuple will be used to manipulate the packet in + // handlePacket. + replyTID := tid.reply() + replyTID.srcAddr = rt.MinIP + replyTID.srcPort = rt.MinPort + var manip manipType + switch hook { + case Prerouting: + manip = manipDstPrerouting + case Output: + manip = manipDstOutput + } + conn := newConn(tid, replyTID, manip, hook) // Add the changed tuple to the map. - ct.connMu.Lock() - defer ct.connMu.Unlock() - ct.CtMap[newHash] = conn.replyTupleHolder - if hook == Output { - conn.replyTupleHolder.conn.manip = manipDstOutput - } + // TODO(gvisor.dev/issue/170): Need to support collisions using linked + // list. + ct.mu.Lock() + defer ct.mu.Unlock() + ct.conns[tid] = conn.original + ct.conns[replyTID] = conn.reply - // Delete the old tuple. - delete(ct.CtMap, replyHash) + return conn } // handlePacketPrerouting manipulates ports for packets in Prerouting hook. -// TODO(gvisor.dev/issue/170): Change address for Prerouting hook.. -func handlePacketPrerouting(pkt *PacketBuffer, conn *connTrack, dir ctDirection) { +// TODO(gvisor.dev/issue/170): Change address for Prerouting hook. +func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) { netHeader := header.IPv4(pkt.NetworkHeader) tcpHeader := header.TCP(pkt.TransportHeader) @@ -308,13 +219,13 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *connTrack, dir ctDirection) // modified. switch dir { case dirOriginal: - port := conn.replyTupleHolder.tuple.src.port + port := conn.reply.srcPort tcpHeader.SetDestinationPort(port) - netHeader.SetDestinationAddress(conn.replyTupleHolder.tuple.src.addr) + netHeader.SetDestinationAddress(conn.reply.srcAddr) case dirReply: - port := conn.originalTupleHolder.tuple.dst.port + port := conn.original.dstPort tcpHeader.SetSourcePort(port) - netHeader.SetSourceAddress(conn.originalTupleHolder.tuple.dst.addr) + netHeader.SetSourceAddress(conn.original.dstAddr) } netHeader.SetChecksum(0) @@ -322,7 +233,7 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *connTrack, dir ctDirection) } // handlePacketOutput manipulates ports for packets in Output hook. -func handlePacketOutput(pkt *PacketBuffer, conn *connTrack, gso *GSO, r *Route, dir ctDirection) { +func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir direction) { netHeader := header.IPv4(pkt.NetworkHeader) tcpHeader := header.TCP(pkt.TransportHeader) @@ -331,13 +242,13 @@ func handlePacketOutput(pkt *PacketBuffer, conn *connTrack, gso *GSO, r *Route, // modified. For prerouting redirection, we only reach this point // when replying, so packet sources are modified. if conn.manip == manipDstOutput && dir == dirOriginal { - port := conn.replyTupleHolder.tuple.src.port + port := conn.reply.srcPort tcpHeader.SetDestinationPort(port) - netHeader.SetDestinationAddress(conn.replyTupleHolder.tuple.src.addr) + netHeader.SetDestinationAddress(conn.reply.srcAddr) } else { - port := conn.originalTupleHolder.tuple.dst.port + port := conn.original.dstPort tcpHeader.SetSourcePort(port) - netHeader.SetSourceAddress(conn.originalTupleHolder.tuple.dst.addr) + netHeader.SetSourceAddress(conn.original.dstAddr) } // Calculate the TCP checksum and set it. @@ -356,9 +267,9 @@ func handlePacketOutput(pkt *PacketBuffer, conn *connTrack, gso *GSO, r *Route, netHeader.SetChecksum(^netHeader.CalculateChecksum()) } -// HandlePacket will manipulate the port and address of the packet if the +// handlePacket will manipulate the port and address of the packet if the // connection exists. -func (ct *ConnTrackTable) HandlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) { +func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) { if pkt.NatDone { return } @@ -367,21 +278,9 @@ func (ct *ConnTrackTable) HandlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r return } - conn, dir := ct.connTrackForPacket(pkt, hook, false) - // Connection or Rule not found for the packet. + conn, dir := ct.connFor(pkt) if conn == nil { - return - } - - netHeader := header.IPv4(pkt.NetworkHeader) - // TODO(gvisor.dev/issue/170): Need to support for other transport - // protocols as well. - if netHeader == nil || netHeader.TransportProtocol() != header.TCPProtocolNumber { - return - } - - tcpHeader := header.TCP(pkt.TransportHeader) - if tcpHeader == nil { + // Connection not found for the packet or the packet is invalid. return } @@ -396,7 +295,10 @@ func (ct *ConnTrackTable) HandlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r // Update the state of tcb. // TODO(gvisor.dev/issue/170): Add support in tcpcontrack to handle // other tcp states. + conn.mu.Lock() + defer conn.mu.Unlock() var st tcpconntrack.Result + tcpHeader := header.TCP(pkt.TransportHeader) if conn.tcb.IsEmpty() { conn.tcb.Init(tcpHeader) conn.tcbHook = hook @@ -409,26 +311,21 @@ func (ct *ConnTrackTable) HandlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r } } - // Delete conntrack if tcp connection is closed. + // Delete conn if tcp connection is closed. if st == tcpconntrack.ResultClosedByPeer || st == tcpconntrack.ResultClosedBySelf || st == tcpconntrack.ResultReset { - ct.deleteConnTrack(conn) + ct.deleteConn(conn) } } -// deleteConnTrack deletes the connection. -func (ct *ConnTrackTable) deleteConnTrack(conn *connTrack) { +// deleteConn deletes the connection. +func (ct *ConnTrack) deleteConn(conn *conn) { if conn == nil { return } - tuple := conn.originalTupleHolder.tuple - hash := ct.getTupleHash(tuple) - replyTuple := conn.replyTupleHolder.tuple - replyHash := ct.getTupleHash(replyTuple) - - ct.connMu.Lock() - defer ct.connMu.Unlock() + ct.mu.Lock() + defer ct.mu.Unlock() - delete(ct.CtMap, hash) - delete(ct.CtMap, replyHash) + delete(ct.conns, conn.original.tupleID) + delete(ct.conns, conn.reply.tupleID) } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 62d4eb1b6..974d77c36 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -111,9 +111,8 @@ func DefaultTables() *IPTables { Prerouting: []string{TablenameMangle, TablenameNat}, Output: []string{TablenameMangle, TablenameNat, TablenameFilter}, }, - connections: ConnTrackTable{ - CtMap: make(map[uint32]ConnTrackTupleHolder), - Seed: generateRandUint32(), + connections: ConnTrack{ + conns: make(map[tupleID]tuple), }, } } @@ -213,7 +212,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr // Packets are manipulated only if connection and matching // NAT rule exists. - it.connections.HandlePacket(pkt, hook, gso, r) + it.connections.handlePacket(pkt, hook, gso, r) // Go through each table containing the hook. for _, tablename := range it.GetPriorities(hook) { diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 92e31643e..d43f60c67 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -24,7 +24,7 @@ import ( type AcceptTarget struct{} // Action implements Target.Action. -func (AcceptTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleAccept, 0 } @@ -32,7 +32,7 @@ func (AcceptTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, t type DropTarget struct{} // Action implements Target.Action. -func (DropTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleDrop, 0 } @@ -41,7 +41,7 @@ func (DropTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcp type ErrorTarget struct{} // Action implements Target.Action. -func (ErrorTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { log.Debugf("ErrorTarget triggered.") return RuleDrop, 0 } @@ -52,7 +52,7 @@ type UserChainTarget struct { } // Action implements Target.Action. -func (UserChainTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { panic("UserChainTarget should never be called.") } @@ -61,7 +61,7 @@ func (UserChainTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route type ReturnTarget struct{} // Action implements Target.Action. -func (ReturnTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleReturn, 0 } @@ -92,7 +92,7 @@ type RedirectTarget struct { // TODO(gvisor.dev/issue/170): Parse headers without copying. The current // implementation only works for PREROUTING and calls pkt.Clone(), neither // of which should be the case. -func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrackTable, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) { +func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) { // Packet is already manipulated. if pkt.NatDone { return RuleAccept, 0 @@ -150,12 +150,11 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrackTable, hook Hook return RuleAccept, 0 } - // Set up conection for matching NAT rule. - // Only the first packet of the connection comes here. - // Other packets will be manipulated in connection tracking. - if conn, _ := ct.connTrackForPacket(pkt, hook, true); conn != nil { - ct.SetNatInfo(pkt, rt, hook) - ct.HandlePacket(pkt, hook, gso, r) + // Set up conection for matching NAT rule. Only the first + // packet of the connection comes here. Other packets will be + // manipulated in connection tracking. + if conn := ct.createConnFor(pkt, hook, rt); conn != nil { + ct.handlePacket(pkt, hook, gso, r) } default: return RuleDrop, 0 diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 7026990c4..c528ec381 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -96,7 +96,7 @@ type IPTables struct { // don't utilize iptables. modified bool - connections ConnTrackTable + connections ConnTrack } // A Table defines a set of chains and hooks into the network stack. It is @@ -249,5 +249,5 @@ type Target interface { // Action takes an action on the packet and returns a verdict on how // traversal should (or should not) continue. If the return value is // Jump, it also returns the index of the rule to jump to. - Action(packet *PacketBuffer, connections *ConnTrackTable, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) + Action(packet *PacketBuffer, connections *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) } -- cgit v1.2.3 From 1e5b0a973293d9a6bf7b60349763ee3e71343dea Mon Sep 17 00:00:00 2001 From: Ting-Yu Wang Date: Mon, 6 Jul 2020 12:09:03 -0700 Subject: Shard some slow tests. stack_x_test: 2m -> 20s tcp_x_test: 80s -> 25s PiperOrigin-RevId: 319828101 --- pkg/tcpip/stack/BUILD | 1 + pkg/tcpip/transport/tcp/BUILD | 1 + 2 files changed, 2 insertions(+) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 794ddb5c8..e65c731c2 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -79,6 +79,7 @@ go_test( "transport_demuxer_test.go", "transport_test.go", ], + shard_count = 20, deps = [ ":stack", "//pkg/rand", diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 6baeda8e4..18ff89ffc 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -86,6 +86,7 @@ go_test( "tcp_test.go", "tcp_timestamp_test.go", ], + shard_count = 10, deps = [ ":tcp", "//pkg/sync", -- cgit v1.2.3 From 7e4d2d63eef3f643222181b3a227bfc7ebe44db2 Mon Sep 17 00:00:00 2001 From: Ting-Yu Wang Date: Tue, 7 Jul 2020 15:55:40 -0700 Subject: icmp: When setting TransportHeader, remove from the Data portion. The current convention is when a header is set to pkt.XxxHeader field, it gets removed from pkt.Data. ICMP does not currently follow this convention. PiperOrigin-RevId: 320078606 --- pkg/tcpip/stack/nic.go | 7 +++++-- pkg/tcpip/transport/icmp/endpoint.go | 12 +++++++----- 2 files changed, 12 insertions(+), 7 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index afb7dfeaf..7b80534e6 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1358,16 +1358,19 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // TransportHeader is nil only when pkt is an ICMP packet or was reassembled // from fragments. if pkt.TransportHeader == nil { - // TODO(gvisor.dev/issue/170): ICMP packets don't have their - // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a + // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader + // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a // full explanation. if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber { + // ICMP packets may be longer, but until icmp.Parse is implemented, here + // we parse it using the minimum size. transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize()) if !ok { n.stack.stats.MalformedRcvdPackets.Increment() return } pkt.TransportHeader = transHeader + pkt.Data.TrimFront(len(pkt.TransportHeader)) } else { // This is either a bad packet or was re-assembled from fragments. transProto.Parse(pkt) diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 8ce294002..62d1acad4 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -744,15 +744,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // Only accept echo replies. switch e.NetProto { case header.IPv4ProtocolNumber: - h, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) - if !ok || header.ICMPv4(h).Type() != header.ICMPv4EchoReply { + h := header.ICMPv4(pkt.TransportHeader) + if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } case header.IPv6ProtocolNumber: - h, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize) - if !ok || header.ICMPv6(h).Type() != header.ICMPv6EchoReply { + h := header.ICMPv6(pkt.TransportHeader) + if len(h) < header.ICMPv6MinimumSize || h.Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return @@ -786,7 +786,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk }, } - packet.data = pkt.Data + // ICMP socket's data includes ICMP header. + packet.data = pkt.TransportHeader.ToVectorisedView() + packet.data.Append(pkt.Data) e.rcvList.PushBack(packet) e.rcvBufSize += packet.data.Size() -- cgit v1.2.3 From 9c32fd3f4d8f6e63d922c1c58b7d1f1f504fa2bc Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Sun, 12 Jul 2020 17:20:50 -0700 Subject: Do not copy sleep.Waker sleep.Waker's fields are modified as values. PiperOrigin-RevId: 320873451 --- pkg/sleep/BUILD | 1 + pkg/sleep/sleep_test.go | 20 +++++++------------- pkg/sleep/sleep_unsafe.go | 7 +++++++ pkg/sync/BUILD | 1 + pkg/sync/nocopy.go | 28 ++++++++++++++++++++++++++++ pkg/tcpip/stack/packet_buffer.go | 14 ++------------ pkg/tcpip/timer.go | 21 ++++----------------- 7 files changed, 50 insertions(+), 42 deletions(-) create mode 100644 pkg/sync/nocopy.go (limited to 'pkg/tcpip/stack') diff --git a/pkg/sleep/BUILD b/pkg/sleep/BUILD index e131455f7..ae0fe1522 100644 --- a/pkg/sleep/BUILD +++ b/pkg/sleep/BUILD @@ -12,6 +12,7 @@ go_library( "sleep_unsafe.go", ], visibility = ["//:sandbox"], + deps = ["//pkg/sync"], ) go_test( diff --git a/pkg/sleep/sleep_test.go b/pkg/sleep/sleep_test.go index af47e2ba1..1dd11707d 100644 --- a/pkg/sleep/sleep_test.go +++ b/pkg/sleep/sleep_test.go @@ -379,10 +379,7 @@ func TestRace(t *testing.T) { // TestRaceInOrder tests that multiple wakers can continuously send wake requests to // the sleeper and that the wakers are retrieved in the order asserted. func TestRaceInOrder(t *testing.T) { - const wakers = 100 - const wakeRequests = 10000 - - w := make([]Waker, wakers) + w := make([]Waker, 10000) s := Sleeper{} // Associate each waker and start goroutines that will assert them. @@ -390,19 +387,16 @@ func TestRaceInOrder(t *testing.T) { s.AddWaker(&w[i], i) } go func() { - n := 0 - for n < wakeRequests { - wk := w[n%len(w)] - wk.Assert() - n++ + for i := range w { + w[i].Assert() } }() // Wait for all wake up notifications from all wakers. - for i := 0; i < wakeRequests; i++ { - v, _ := s.Fetch(true) - if got, want := v, i%wakers; got != want { - t.Fatalf("got %d want %d", got, want) + for want := range w { + got, _ := s.Fetch(true) + if got != want { + t.Fatalf("got %d want %d", got, want) } } } diff --git a/pkg/sleep/sleep_unsafe.go b/pkg/sleep/sleep_unsafe.go index f68c12620..118805492 100644 --- a/pkg/sleep/sleep_unsafe.go +++ b/pkg/sleep/sleep_unsafe.go @@ -75,6 +75,8 @@ package sleep import ( "sync/atomic" "unsafe" + + "gvisor.dev/gvisor/pkg/sync" ) const ( @@ -323,7 +325,12 @@ func (s *Sleeper) enqueueAssertedWaker(w *Waker) { // // This struct is thread-safe, that is, its methods can be called concurrently // by multiple goroutines. +// +// Note, it is not safe to copy a Waker as its fields are modified by value +// (the pointer fields are individually modified with atomic operations). type Waker struct { + _ sync.NoCopy + // s is the sleeper that this waker can wake up. Only one sleeper at a // time is allowed. This field can have three classes of values: // nil -- the waker is not asserted: it either is not associated with diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD index d0d77e19c..4d47207f7 100644 --- a/pkg/sync/BUILD +++ b/pkg/sync/BUILD @@ -33,6 +33,7 @@ go_library( "aliases.go", "memmove_unsafe.go", "mutex_unsafe.go", + "nocopy.go", "norace_unsafe.go", "race_unsafe.go", "rwmutex_unsafe.go", diff --git a/pkg/sync/nocopy.go b/pkg/sync/nocopy.go new file mode 100644 index 000000000..722b29501 --- /dev/null +++ b/pkg/sync/nocopy.go @@ -0,0 +1,28 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sync + +// NoCopy may be embedded into structs which must not be copied +// after the first use. +// +// See https://golang.org/issues/8005#issuecomment-190753527 +// for details. +type NoCopy struct{} + +// Lock is a no-op used by -copylocks checker from `go vet`. +func (*NoCopy) Lock() {} + +// Unlock is a no-op used by -copylocks checker from `go vet`. +func (*NoCopy) Unlock() {} diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 1b5da6017..e3556d5d2 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -14,6 +14,7 @@ package stack import ( + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" ) @@ -24,7 +25,7 @@ import ( // multiple endpoints. Clone() should be called in such cases so that // modifications to the Data field do not affect other copies. type PacketBuffer struct { - _ noCopy + _ sync.NoCopy // PacketBufferEntry is used to build an intrusive list of // PacketBuffers. @@ -102,14 +103,3 @@ func (pk *PacketBuffer) Clone() *PacketBuffer { NatDone: pk.NatDone, } } - -// noCopy may be embedded into structs which must not be copied -// after the first use. -// -// See https://golang.org/issues/8005#issuecomment-190753527 -// for details. -type noCopy struct{} - -// Lock is a no-op used by -copylocks checker from `go vet`. -func (*noCopy) Lock() {} -func (*noCopy) Unlock() {} diff --git a/pkg/tcpip/timer.go b/pkg/tcpip/timer.go index 59f3b391f..5554c573f 100644 --- a/pkg/tcpip/timer.go +++ b/pkg/tcpip/timer.go @@ -15,8 +15,9 @@ package tcpip import ( - "sync" "time" + + "gvisor.dev/gvisor/pkg/sync" ) // cancellableTimerInstance is a specific instance of CancellableTimer. @@ -92,6 +93,8 @@ func (t *cancellableTimerInstance) stop() { // Note, it is not safe to copy a CancellableTimer as its timer instance creates // a closure over the address of the CancellableTimer. type CancellableTimer struct { + _ sync.NoCopy + // The active instance of a cancellable timer. instance cancellableTimerInstance @@ -157,22 +160,6 @@ func (t *CancellableTimer) Reset(d time.Duration) { } } -// Lock is a no-op used by the copylocks checker from go vet. -// -// See CancellableTimer for details about why it shouldn't be copied. -// -// See https://github.com/golang/go/issues/8005#issuecomment-190753527 for more -// details about the copylocks checker. -func (*CancellableTimer) Lock() {} - -// Unlock is a no-op used by the copylocks checker from go vet. -// -// See CancellableTimer for details about why it shouldn't be copied. -// -// See https://github.com/golang/go/issues/8005#issuecomment-190753527 for more -// details about the copylocks checker. -func (*CancellableTimer) Unlock() {} - // NewCancellableTimer returns an unscheduled CancellableTimer with the given // locker and fn. // -- cgit v1.2.3 From 43c209f48e0aa9024705583cc6f0fafa7d6380ca Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Mon, 13 Jul 2020 11:59:26 -0700 Subject: garbage collect connections As in Linux, we must periodically clean up unused connections. PiperOrigin-RevId: 321003353 --- pkg/tcpip/stack/BUILD | 14 ++ pkg/tcpip/stack/conntrack.go | 277 ++++++++++++++++++---- pkg/tcpip/stack/iptables.go | 42 +++- pkg/tcpip/stack/iptables_state.go | 40 ++++ pkg/tcpip/stack/iptables_types.go | 11 + pkg/tcpip/stack/stack.go | 1 + pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go | 5 + 7 files changed, 347 insertions(+), 43 deletions(-) create mode 100644 pkg/tcpip/stack/iptables_state.go (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index e65c731c2..800bf3f08 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -27,6 +27,18 @@ go_template_instance( }, ) +go_template_instance( + name = "tuple_list", + out = "tuple_list.go", + package = "stack", + prefix = "tuple", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*tuple", + "Linker": "*tuple", + }, +) + go_library( name = "stack", srcs = [ @@ -35,6 +47,7 @@ go_library( "forwarder.go", "icmp_rate_limit.go", "iptables.go", + "iptables_state.go", "iptables_targets.go", "iptables_types.go", "linkaddrcache.go", @@ -50,6 +63,7 @@ go_library( "stack_global_state.go", "stack_options.go", "transport_demuxer.go", + "tuple_list.go", ], visibility = ["//visibility:public"], deps = [ diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index af9c325ca..d39baf620 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -15,9 +15,12 @@ package stack import ( + "encoding/binary" "sync" + "time" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack" ) @@ -30,6 +33,10 @@ import ( // // Currently, only TCP tracking is supported. +// Our hash table has 16K buckets. +// TODO(gvisor.dev/issue/170): These should be tunable. +const numBuckets = 1 << 14 + // Direction of the tuple. type direction int @@ -48,7 +55,12 @@ const ( // tuple holds a connection's identifying and manipulating data in one // direction. It is immutable. +// +// +stateify savable type tuple struct { + // tupleEntry is used to build an intrusive list of tuples. + tupleEntry + tupleID // conn is the connection tracking entry this tuple belongs to. @@ -61,6 +73,8 @@ type tuple struct { // tupleID uniquely identifies a connection in one direction. It currently // contains enough information to distinguish between any TCP or UDP // connection, and will need to be extended to support other protocols. +// +// +stateify savable type tupleID struct { srcAddr tcpip.Address srcPort uint16 @@ -83,6 +97,8 @@ func (ti tupleID) reply() tupleID { } // conn is a tracked connection. +// +// +stateify savable type conn struct { // original is the tuple in original direction. It is immutable. original tuple @@ -98,22 +114,67 @@ type conn struct { tcbHook Hook // mu protects tcb. - mu sync.Mutex + mu sync.Mutex `state:"nosave"` // tcb is TCB control block. It is used to keep track of states // of tcp connection and is protected by mu. tcb tcpconntrack.TCB + + // lastUsed is the last time the connection saw a relevant packet, and + // is updated by each packet on the connection. It is protected by mu. + lastUsed time.Time `state:".(unixTime)"` +} + +// timedOut returns whether the connection timed out based on its state. +func (cn *conn) timedOut(now time.Time) bool { + const establishedTimeout = 5 * 24 * time.Hour + const defaultTimeout = 120 * time.Second + cn.mu.Lock() + defer cn.mu.Unlock() + if cn.tcb.State() == tcpconntrack.ResultAlive { + // Use the same default as Linux, which doesn't delete + // established connections for 5(!) days. + return now.Sub(cn.lastUsed) > establishedTimeout + } + // Use the same default as Linux, which lets connections in most states + // other than established remain for <= 120 seconds. + return now.Sub(cn.lastUsed) > defaultTimeout } // ConnTrack tracks all connections created for NAT rules. Most users are // expected to only call handlePacket and createConnFor. +// +// ConnTrack keeps all connections in a slice of buckets, each of which holds a +// linked list of tuples. This gives us some desirable properties: +// - Each bucket has its own lock, lessening lock contention. +// - The slice is large enough that lists stay short (<10 elements on average). +// Thus traversal is fast. +// - During linked list traversal we reap expired connections. This amortizes +// the cost of reaping them and makes reapUnused faster. +// +// Locks are ordered by their location in the buckets slice. That is, a +// goroutine that locks buckets[i] can only lock buckets[j] s.t. i < j. +// +// +stateify savable type ConnTrack struct { - // mu protects conns. - mu sync.RWMutex + // seed is a one-time random value initialized at stack startup + // and is used in the calculation of hash keys for the list of buckets. + // It is immutable. + seed uint32 - // conns maintains a map of tuples needed for connection tracking for - // iptables NAT rules. It is protected by mu. - conns map[tupleID]tuple + // mu protects the buckets slice, but not buckets' contents. Only take + // the write lock if you are modifying the slice or saving for S/R. + mu sync.RWMutex `state:"nosave"` + + // buckets is protected by mu. + buckets []bucket +} + +// +stateify savable +type bucket struct { + // mu protects tuples. + mu sync.Mutex `state:"nosave"` + tuples tupleList } // packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid @@ -143,8 +204,9 @@ func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) { // newConn creates new connection. func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn { conn := conn{ - manip: manip, - tcbHook: hook, + manip: manip, + tcbHook: hook, + lastUsed: time.Now(), } conn.original = tuple{conn: &conn, tupleID: orig} conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply} @@ -162,14 +224,28 @@ func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) { return nil, dirOriginal } - ct.mu.Lock() - defer ct.mu.Unlock() - - tuple, ok := ct.conns[tid] - if !ok { - return nil, dirOriginal + bucket := ct.bucket(tid) + now := time.Now() + + ct.mu.RLock() + defer ct.mu.RUnlock() + ct.buckets[bucket].mu.Lock() + defer ct.buckets[bucket].mu.Unlock() + + // Iterate over the tuples in a bucket, cleaning up any unused + // connections we find. + for other := ct.buckets[bucket].tuples.Front(); other != nil; other = other.Next() { + // Clean up any timed-out connections we happen to find. + if ct.reapTupleLocked(other, bucket, now) { + // The tuple expired. + continue + } + if tid == other.tupleID { + return other.conn, other.direction + } } - return tuple.conn, tuple.direction + + return nil, dirOriginal } // createConnFor creates a new conn for pkt. @@ -197,13 +273,31 @@ func (ct *ConnTrack) createConnFor(pkt *PacketBuffer, hook Hook, rt RedirectTarg } conn := newConn(tid, replyTID, manip, hook) - // Add the changed tuple to the map. - // TODO(gvisor.dev/issue/170): Need to support collisions using linked - // list. - ct.mu.Lock() - defer ct.mu.Unlock() - ct.conns[tid] = conn.original - ct.conns[replyTID] = conn.reply + // Lock the buckets in the correct order. + tupleBucket := ct.bucket(tid) + replyBucket := ct.bucket(replyTID) + ct.mu.RLock() + defer ct.mu.RUnlock() + if tupleBucket < replyBucket { + ct.buckets[tupleBucket].mu.Lock() + ct.buckets[replyBucket].mu.Lock() + } else if tupleBucket > replyBucket { + ct.buckets[replyBucket].mu.Lock() + ct.buckets[tupleBucket].mu.Lock() + } else { + // Both tuples are in the same bucket. + ct.buckets[tupleBucket].mu.Lock() + } + + // Add the tuple to the map. + ct.buckets[tupleBucket].tuples.PushFront(&conn.original) + ct.buckets[replyBucket].tuples.PushFront(&conn.reply) + + // Unlocking can happen in any order. + ct.buckets[tupleBucket].mu.Unlock() + if tupleBucket != replyBucket { + ct.buckets[replyBucket].mu.Unlock() + } return conn } @@ -297,35 +391,134 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou // other tcp states. conn.mu.Lock() defer conn.mu.Unlock() - var st tcpconntrack.Result - tcpHeader := header.TCP(pkt.TransportHeader) - if conn.tcb.IsEmpty() { + + // Mark the connection as having been used recently so it isn't reaped. + conn.lastUsed = time.Now() + // Update connection state. + if tcpHeader := header.TCP(pkt.TransportHeader); conn.tcb.IsEmpty() { conn.tcb.Init(tcpHeader) conn.tcbHook = hook + } else if hook == conn.tcbHook { + conn.tcb.UpdateStateOutbound(tcpHeader) } else { - switch hook { - case conn.tcbHook: - st = conn.tcb.UpdateStateOutbound(tcpHeader) - default: - st = conn.tcb.UpdateStateInbound(tcpHeader) - } + conn.tcb.UpdateStateInbound(tcpHeader) } +} + +// bucket gets the conntrack bucket for a tupleID. +func (ct *ConnTrack) bucket(id tupleID) int { + h := jenkins.Sum32(ct.seed) + h.Write([]byte(id.srcAddr)) + h.Write([]byte(id.dstAddr)) + shortBuf := make([]byte, 2) + binary.LittleEndian.PutUint16(shortBuf, id.srcPort) + h.Write([]byte(shortBuf)) + binary.LittleEndian.PutUint16(shortBuf, id.dstPort) + h.Write([]byte(shortBuf)) + binary.LittleEndian.PutUint16(shortBuf, uint16(id.transProto)) + h.Write([]byte(shortBuf)) + binary.LittleEndian.PutUint16(shortBuf, uint16(id.netProto)) + h.Write([]byte(shortBuf)) + ct.mu.RLock() + defer ct.mu.RUnlock() + return int(h.Sum32()) % len(ct.buckets) +} - // Delete conn if tcp connection is closed. - if st == tcpconntrack.ResultClosedByPeer || st == tcpconntrack.ResultClosedBySelf || st == tcpconntrack.ResultReset { - ct.deleteConn(conn) +// reapUnused deletes timed out entries from the conntrack map. The rules for +// reaping are: +// - Most reaping occurs in connFor, which is called on each packet. connFor +// cleans up the bucket the packet's connection maps to. Thus calls to +// reapUnused should be fast. +// - Each call to reapUnused traverses a fraction of the conntrack table. +// Specifically, it traverses len(ct.buckets)/fractionPerReaping. +// - After reaping, reapUnused decides when it should next run based on the +// ratio of expired connections to examined connections. If the ratio is +// greater than maxExpiredPct, it schedules the next run quickly. Otherwise it +// slightly increases the interval between runs. +// - maxFullTraversal caps the time it takes to traverse the entire table. +// +// reapUnused returns the next bucket that should be checked and the time after +// which it should be called again. +func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, time.Duration) { + // TODO(gvisor.dev/issue/170): This can be more finely controlled, as + // it is in Linux via sysctl. + const fractionPerReaping = 128 + const maxExpiredPct = 50 + const maxFullTraversal = 60 * time.Second + const minInterval = 10 * time.Millisecond + const maxInterval = maxFullTraversal / fractionPerReaping + + now := time.Now() + checked := 0 + expired := 0 + var idx int + ct.mu.RLock() + defer ct.mu.RUnlock() + for i := 0; i < len(ct.buckets)/fractionPerReaping; i++ { + idx = (i + start) % len(ct.buckets) + ct.buckets[idx].mu.Lock() + for tuple := ct.buckets[idx].tuples.Front(); tuple != nil; tuple = tuple.Next() { + checked++ + if ct.reapTupleLocked(tuple, idx, now) { + expired++ + } + } + ct.buckets[idx].mu.Unlock() + } + // We already checked buckets[idx]. + idx++ + + // If half or more of the connections are expired, the table has gotten + // stale. Reschedule quickly. + expiredPct := 0 + if checked != 0 { + expiredPct = expired * 100 / checked + } + if expiredPct > maxExpiredPct { + return idx, minInterval + } + if interval := prevInterval + minInterval; interval <= maxInterval { + // Increment the interval between runs. + return idx, interval } + // We've hit the maximum interval. + return idx, maxInterval } -// deleteConn deletes the connection. -func (ct *ConnTrack) deleteConn(conn *conn) { - if conn == nil { - return +// reapTupleLocked tries to remove tuple and its reply from the table. It +// returns whether the tuple's connection has timed out. +// +// Preconditions: ct.mu is locked for reading and bucket is locked. +func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bool { + if !tuple.conn.timedOut(now) { + return false } - ct.mu.Lock() - defer ct.mu.Unlock() + // To maintain lock order, we can only reap these tuples if the reply + // appears later in the table. + replyBucket := ct.bucket(tuple.reply()) + if bucket > replyBucket { + return true + } + + // Don't re-lock if both tuples are in the same bucket. + differentBuckets := bucket != replyBucket + if differentBuckets { + ct.buckets[replyBucket].mu.Lock() + } + + // We have the buckets locked and can remove both tuples. + if tuple.direction == dirOriginal { + ct.buckets[replyBucket].tuples.Remove(&tuple.conn.reply) + } else { + ct.buckets[replyBucket].tuples.Remove(&tuple.conn.original) + } + ct.buckets[bucket].tuples.Remove(tuple) + + // Don't re-unlock if both tuples are in the same bucket. + if differentBuckets { + ct.buckets[replyBucket].mu.Unlock() + } - delete(ct.conns, conn.original.tupleID) - delete(ct.conns, conn.reply.tupleID) + return true } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 974d77c36..f846ea2e5 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -16,6 +16,7 @@ package stack import ( "fmt" + "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -41,6 +42,9 @@ const ( // underflow. const HookUnset = -1 +// reaperDelay is how long to wait before starting to reap connections. +const reaperDelay = 5 * time.Second + // DefaultTables returns a default set of tables. Each chain is set to accept // all packets. func DefaultTables() *IPTables { @@ -112,8 +116,9 @@ func DefaultTables() *IPTables { Output: []string{TablenameMangle, TablenameNat, TablenameFilter}, }, connections: ConnTrack{ - conns: make(map[tupleID]tuple), + seed: generateRandUint32(), }, + reaperDone: make(chan struct{}, 1), } } @@ -169,6 +174,12 @@ func (it *IPTables) GetTable(name string) (Table, bool) { func (it *IPTables) ReplaceTable(name string, table Table) { it.mu.Lock() defer it.mu.Unlock() + // If iptables is being enabled, initialize the conntrack table and + // reaper. + if !it.modified { + it.connections.buckets = make([]bucket, numBuckets) + it.startReaper(reaperDelay) + } it.modified = true it.tables[name] = table } @@ -249,6 +260,35 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr return true } +// beforeSave is invoked by stateify. +func (it *IPTables) beforeSave() { + // Ensure the reaper exits cleanly. + it.reaperDone <- struct{}{} + // Prevent others from modifying the connection table. + it.connections.mu.Lock() +} + +// afterLoad is invoked by stateify. +func (it *IPTables) afterLoad() { + it.startReaper(reaperDelay) +} + +// startReaper starts a goroutine that wakes up periodically to reap timed out +// connections. +func (it *IPTables) startReaper(interval time.Duration) { + go func() { // S/R-SAFE: reaperDone is signalled when iptables is saved. + bucket := 0 + for { + select { + case <-it.reaperDone: + return + case <-time.After(interval): + bucket, interval = it.connections.reapUnused(bucket, interval) + } + } + }() +} + // CheckPackets runs pkts through the rules for hook and returns a map of packets that // should not go forward. // diff --git a/pkg/tcpip/stack/iptables_state.go b/pkg/tcpip/stack/iptables_state.go new file mode 100644 index 000000000..529e02a07 --- /dev/null +++ b/pkg/tcpip/stack/iptables_state.go @@ -0,0 +1,40 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "time" +) + +// +stateify savable +type unixTime struct { + second int64 + nano int64 +} + +// saveLastUsed is invoked by stateify. +func (cn *conn) saveLastUsed() unixTime { + return unixTime{cn.lastUsed.Unix(), cn.lastUsed.UnixNano()} +} + +// loadLastUsed is invoked by stateify. +func (cn *conn) loadLastUsed(unix unixTime) { + cn.lastUsed = time.Unix(unix.second, unix.nano) +} + +// beforeSave is invoked by stateify. +func (ct *ConnTrack) beforeSave() { + ct.mu.Lock() +} diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index c528ec381..eb70e3104 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -78,6 +78,8 @@ const ( ) // IPTables holds all the tables for a netstack. +// +// +stateify savable type IPTables struct { // mu protects tables, priorities, and modified. mu sync.RWMutex @@ -97,10 +99,15 @@ type IPTables struct { modified bool connections ConnTrack + + // reaperDone can be signalled to stop the reaper goroutine. + reaperDone chan struct{} } // A Table defines a set of chains and hooks into the network stack. It is // really just a list of rules. +// +// +stateify savable type Table struct { // Rules holds the rules that make up the table. Rules []Rule @@ -130,6 +137,8 @@ func (table *Table) ValidHooks() uint32 { // contains zero or more matchers, each of which is a specification of which // packets this rule applies to. If there are no matchers in the rule, it // applies to any packet. +// +// +stateify savable type Rule struct { // Filter holds basic IP filtering fields common to every rule. Filter IPHeaderFilter @@ -142,6 +151,8 @@ type Rule struct { } // IPHeaderFilter holds basic IP filtering data common to every rule. +// +// +stateify savable type IPHeaderFilter struct { // Protocol matches the transport protocol. Protocol tcpip.TransportProtocolNumber diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index cdcfb8321..0aa815447 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -425,6 +425,7 @@ type Stack struct { handleLocal bool // tables are the iptables packet filtering and manipulation rules. + // TODO(gvisor.dev/issue/170): S/R this field. tables *IPTables // resumableEndpoints is a list of endpoints that need to be resumed if the diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go index 12bc1b5b5..558b06df0 100644 --- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go +++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go @@ -106,6 +106,11 @@ func (t *TCB) UpdateStateOutbound(tcp header.TCP) Result { return st } +// State returns the current state of the TCB. +func (t *TCB) State() Result { + return t.state +} + // IsAlive returns true as long as the connection is established(Alive) // or connecting state. func (t *TCB) IsAlive() bool { -- cgit v1.2.3 From fef90c61c6186c113cfdb0bbcf53f4ca70f9741a Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Wed, 15 Jul 2020 14:13:42 -0700 Subject: Fix minor bugs in a couple of interface IOCTLs. gVisor incorrectly returns the wrong ARP type for SIOGIFHWADDR. This breaks tcpdump as it tries to interpret the packets incorrectly. Similarly, SIOCETHTOOL is used by tcpdump to query interface properties which fails with an EINVAL since we don't implement it. For now change it to return EOPNOTSUPP to indicate that we don't support the query rather than return EINVAL. NOTE: ARPHRD types for link endpoints are distinct from NIC capabilities and NIC flags. In Linux all 3 exist eg. ARPHRD types are stored in dev->type field while NIC capabilities are more like the device features which can be queried using SIOCETHTOOL but not modified and NIC Flags are fields that can be modified from user space. eg. NIC status (UP/DOWN/MULTICAST/BROADCAST) etc. Updates #2746 PiperOrigin-RevId: 321436525 --- pkg/abi/linux/ioctl.go | 27 +++++++++-- pkg/abi/linux/netlink_route.go | 2 + pkg/sentry/socket/netstack/netstack.go | 74 ++++++++++++++++++------------ pkg/sentry/socket/netstack/stack.go | 22 ++++++--- pkg/tcpip/header/arp.go | 77 +++++++++++++++++++++----------- pkg/tcpip/link/channel/BUILD | 1 + pkg/tcpip/link/channel/channel.go | 6 +++ pkg/tcpip/link/fdbased/endpoint.go | 8 ++++ pkg/tcpip/link/loopback/loopback.go | 5 +++ pkg/tcpip/link/muxed/BUILD | 1 + pkg/tcpip/link/muxed/injectable.go | 6 +++ pkg/tcpip/link/nested/BUILD | 1 + pkg/tcpip/link/nested/nested.go | 6 +++ pkg/tcpip/link/qdisc/fifo/BUILD | 1 + pkg/tcpip/link/qdisc/fifo/endpoint.go | 6 +++ pkg/tcpip/link/sharedmem/sharedmem.go | 5 +++ pkg/tcpip/link/tun/device.go | 10 +++++ pkg/tcpip/link/waitable/BUILD | 2 + pkg/tcpip/link/waitable/waitable.go | 6 +++ pkg/tcpip/link/waitable/waitable_test.go | 6 +++ pkg/tcpip/network/ip_test.go | 9 +++- pkg/tcpip/stack/forwarder_test.go | 6 +++ pkg/tcpip/stack/nic_test.go | 5 +++ pkg/tcpip/stack/registration.go | 7 +++ pkg/tcpip/stack/stack.go | 6 +++ test/syscalls/linux/socket_netdevice.cc | 23 ++++++++++ 26 files changed, 263 insertions(+), 65 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go index 2062e6a4b..2c5e56ae5 100644 --- a/pkg/abi/linux/ioctl.go +++ b/pkg/abi/linux/ioctl.go @@ -67,10 +67,29 @@ const ( // ioctl(2) requests provided by uapi/linux/sockios.h const ( - SIOCGIFMEM = 0x891f - SIOCGIFPFLAGS = 0x8935 - SIOCGMIIPHY = 0x8947 - SIOCGMIIREG = 0x8948 + SIOCGIFNAME = 0x8910 + SIOCGIFCONF = 0x8912 + SIOCGIFFLAGS = 0x8913 + SIOCGIFADDR = 0x8915 + SIOCGIFDSTADDR = 0x8917 + SIOCGIFBRDADDR = 0x8919 + SIOCGIFNETMASK = 0x891b + SIOCGIFMETRIC = 0x891d + SIOCGIFMTU = 0x8921 + SIOCGIFMEM = 0x891f + SIOCGIFHWADDR = 0x8927 + SIOCGIFINDEX = 0x8933 + SIOCGIFPFLAGS = 0x8935 + SIOCGIFTXQLEN = 0x8942 + SIOCETHTOOL = 0x8946 + SIOCGMIIPHY = 0x8947 + SIOCGMIIREG = 0x8948 + SIOCGIFMAP = 0x8970 +) + +// ioctl(2) requests provided by uapi/asm-generic/sockios.h +const ( + SIOCGSTAMP = 0x8906 ) // ioctl(2) directions. Used to calculate requests number. diff --git a/pkg/abi/linux/netlink_route.go b/pkg/abi/linux/netlink_route.go index 40bec566c..ceda0a8d3 100644 --- a/pkg/abi/linux/netlink_route.go +++ b/pkg/abi/linux/netlink_route.go @@ -187,6 +187,8 @@ const ( // Device types, from uapi/linux/if_arp.h. const ( + ARPHRD_NONE = 65534 + ARPHRD_ETHER = 1 ARPHRD_LOOPBACK = 772 ) diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 78a842973..0b1be1bd2 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -2747,7 +2747,7 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy // sockets. // TODO(b/78348848): Add a commonEndpoint method to support SIOCGSTAMP. switch args[1].Int() { - case syscall.SIOCGSTAMP: + case linux.SIOCGSTAMP: s.readMu.Lock() defer s.readMu.Unlock() if !s.timestampValid { @@ -2788,18 +2788,19 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy // Ioctl performs a socket ioctl. func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { switch arg := int(args[1].Int()); arg { - case syscall.SIOCGIFFLAGS, - syscall.SIOCGIFADDR, - syscall.SIOCGIFBRDADDR, - syscall.SIOCGIFDSTADDR, - syscall.SIOCGIFHWADDR, - syscall.SIOCGIFINDEX, - syscall.SIOCGIFMAP, - syscall.SIOCGIFMETRIC, - syscall.SIOCGIFMTU, - syscall.SIOCGIFNAME, - syscall.SIOCGIFNETMASK, - syscall.SIOCGIFTXQLEN: + case linux.SIOCGIFFLAGS, + linux.SIOCGIFADDR, + linux.SIOCGIFBRDADDR, + linux.SIOCGIFDSTADDR, + linux.SIOCGIFHWADDR, + linux.SIOCGIFINDEX, + linux.SIOCGIFMAP, + linux.SIOCGIFMETRIC, + linux.SIOCGIFMTU, + linux.SIOCGIFNAME, + linux.SIOCGIFNETMASK, + linux.SIOCGIFTXQLEN, + linux.SIOCETHTOOL: var ifr linux.IFReq if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &ifr, usermem.IOOpts{ @@ -2815,7 +2816,7 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc }) return 0, err - case syscall.SIOCGIFCONF: + case linux.SIOCGIFCONF: // Return a list of interface addresses or the buffer size // necessary to hold the list. var ifc linux.IFConf @@ -2889,7 +2890,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe // SIOCGIFNAME uses ifr.ifr_ifindex rather than ifr.ifr_name to // identify a device. - if arg == syscall.SIOCGIFNAME { + if arg == linux.SIOCGIFNAME { // Gets the name of the interface given the interface index // stored in ifr_ifindex. index = int32(usermem.ByteOrder.Uint32(ifr.Data[:4])) @@ -2912,21 +2913,28 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe } switch arg { - case syscall.SIOCGIFINDEX: + case linux.SIOCGIFINDEX: // Copy out the index to the data. usermem.ByteOrder.PutUint32(ifr.Data[:], uint32(index)) - case syscall.SIOCGIFHWADDR: + case linux.SIOCGIFHWADDR: // Copy the hardware address out. - ifr.Data[0] = 6 // IEEE802.2 arp type. - ifr.Data[1] = 0 + // + // Refer: https://linux.die.net/man/7/netdevice + // SIOCGIFHWADDR, SIOCSIFHWADDR + // + // Get or set the hardware address of a device using + // ifr_hwaddr. The hardware address is specified in a struct + // sockaddr. sa_family contains the ARPHRD_* device type, + // sa_data the L2 hardware address starting from byte 0. Setting + // the hardware address is a privileged operation. + usermem.ByteOrder.PutUint16(ifr.Data[:], iface.DeviceType) n := copy(ifr.Data[2:], iface.Addr) for i := 2 + n; i < len(ifr.Data); i++ { ifr.Data[i] = 0 // Clear padding. } - usermem.ByteOrder.PutUint16(ifr.Data[:2], uint16(n)) - case syscall.SIOCGIFFLAGS: + case linux.SIOCGIFFLAGS: f, err := interfaceStatusFlags(stack, iface.Name) if err != nil { return err @@ -2935,7 +2943,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe // matches Linux behavior. usermem.ByteOrder.PutUint16(ifr.Data[:2], uint16(f)) - case syscall.SIOCGIFADDR: + case linux.SIOCGIFADDR: // Copy the IPv4 address out. for _, addr := range stack.InterfaceAddrs()[index] { // This ioctl is only compatible with AF_INET addresses. @@ -2946,32 +2954,32 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe break } - case syscall.SIOCGIFMETRIC: + case linux.SIOCGIFMETRIC: // Gets the metric of the device. As per netdevice(7), this // always just sets ifr_metric to 0. usermem.ByteOrder.PutUint32(ifr.Data[:4], 0) - case syscall.SIOCGIFMTU: + case linux.SIOCGIFMTU: // Gets the MTU of the device. usermem.ByteOrder.PutUint32(ifr.Data[:4], iface.MTU) - case syscall.SIOCGIFMAP: + case linux.SIOCGIFMAP: // Gets the hardware parameters of the device. // TODO(gvisor.dev/issue/505): Implement. - case syscall.SIOCGIFTXQLEN: + case linux.SIOCGIFTXQLEN: // Gets the transmit queue length of the device. // TODO(gvisor.dev/issue/505): Implement. - case syscall.SIOCGIFDSTADDR: + case linux.SIOCGIFDSTADDR: // Gets the destination address of a point-to-point device. // TODO(gvisor.dev/issue/505): Implement. - case syscall.SIOCGIFBRDADDR: + case linux.SIOCGIFBRDADDR: // Gets the broadcast address of a device. // TODO(gvisor.dev/issue/505): Implement. - case syscall.SIOCGIFNETMASK: + case linux.SIOCGIFNETMASK: // Gets the network mask of a device. for _, addr := range stack.InterfaceAddrs()[index] { // This ioctl is only compatible with AF_INET addresses. @@ -2988,6 +2996,14 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe break } + case linux.SIOCETHTOOL: + // Stubbed out for now, Ideally we should implement the required + // sub-commands for ETHTOOL + // + // See: + // https://github.com/torvalds/linux/blob/aa0c9086b40c17a7ad94425b3b70dd1fdd7497bf/net/core/dev_ioctl.c + return syserr.ErrEndpointOperation + default: // Not a valid call. return syserr.ErrInvalidArgument diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index 548442b96..67737ae87 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -15,6 +15,8 @@ package netstack import ( + "fmt" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/inet" @@ -40,19 +42,29 @@ func (s *Stack) SupportsIPv6() bool { return s.Stack.CheckNetworkProtocol(ipv6.ProtocolNumber) } +// Converts Netstack's ARPHardwareType to equivalent linux constants. +func toLinuxARPHardwareType(t header.ARPHardwareType) uint16 { + switch t { + case header.ARPHardwareNone: + return linux.ARPHRD_NONE + case header.ARPHardwareLoopback: + return linux.ARPHRD_LOOPBACK + case header.ARPHardwareEther: + return linux.ARPHRD_ETHER + default: + panic(fmt.Sprintf("unknown ARPHRD type: %d", t)) + } +} + // Interfaces implements inet.Stack.Interfaces. func (s *Stack) Interfaces() map[int32]inet.Interface { is := make(map[int32]inet.Interface) for id, ni := range s.Stack.NICInfo() { - var devType uint16 - if ni.Flags.Loopback { - devType = linux.ARPHRD_LOOPBACK - } is[int32(id)] = inet.Interface{ Name: ni.Name, Addr: []byte(ni.LinkAddress), Flags: uint32(nicStateFlagsToLinux(ni.Flags)), - DeviceType: devType, + DeviceType: toLinuxARPHardwareType(ni.ARPHardwareType), MTU: ni.MTU, } } diff --git a/pkg/tcpip/header/arp.go b/pkg/tcpip/header/arp.go index 718a4720a..83189676e 100644 --- a/pkg/tcpip/header/arp.go +++ b/pkg/tcpip/header/arp.go @@ -14,14 +14,33 @@ package header -import "gvisor.dev/gvisor/pkg/tcpip" +import ( + "encoding/binary" + + "gvisor.dev/gvisor/pkg/tcpip" +) const ( // ARPProtocolNumber is the ARP network protocol number. ARPProtocolNumber tcpip.NetworkProtocolNumber = 0x0806 // ARPSize is the size of an IPv4-over-Ethernet ARP packet. - ARPSize = 2 + 2 + 1 + 1 + 2 + 2*6 + 2*4 + ARPSize = 28 +) + +// ARPHardwareType is the hardware type for LinkEndpoint in an ARP header. +type ARPHardwareType uint16 + +// Typical ARP HardwareType values. Some of the constants have to be specific +// values as they are egressed on the wire in the HTYPE field of an ARP header. +const ( + ARPHardwareNone ARPHardwareType = 0 + // ARPHardwareEther specifically is the HTYPE for Ethernet as specified + // in the IANA list here: + // + // https://www.iana.org/assignments/arp-parameters/arp-parameters.xhtml#arp-parameters-2 + ARPHardwareEther ARPHardwareType = 1 + ARPHardwareLoopback ARPHardwareType = 2 ) // ARPOp is an ARP opcode. @@ -36,54 +55,64 @@ const ( // ARP is an ARP packet stored in a byte array as described in RFC 826. type ARP []byte -func (a ARP) hardwareAddressSpace() uint16 { return uint16(a[0])<<8 | uint16(a[1]) } -func (a ARP) protocolAddressSpace() uint16 { return uint16(a[2])<<8 | uint16(a[3]) } -func (a ARP) hardwareAddressSize() int { return int(a[4]) } -func (a ARP) protocolAddressSize() int { return int(a[5]) } +const ( + hTypeOffset = 0 + protocolOffset = 2 + haAddressSizeOffset = 4 + protoAddressSizeOffset = 5 + opCodeOffset = 6 + senderHAAddressOffset = 8 + senderProtocolAddressOffset = senderHAAddressOffset + EthernetAddressSize + targetHAAddressOffset = senderProtocolAddressOffset + IPv4AddressSize + targetProtocolAddressOffset = targetHAAddressOffset + EthernetAddressSize +) + +func (a ARP) hardwareAddressType() ARPHardwareType { + return ARPHardwareType(binary.BigEndian.Uint16(a[hTypeOffset:])) +} + +func (a ARP) protocolAddressSpace() uint16 { return binary.BigEndian.Uint16(a[protocolOffset:]) } +func (a ARP) hardwareAddressSize() int { return int(a[haAddressSizeOffset]) } +func (a ARP) protocolAddressSize() int { return int(a[protoAddressSizeOffset]) } // Op is the ARP opcode. -func (a ARP) Op() ARPOp { return ARPOp(a[6])<<8 | ARPOp(a[7]) } +func (a ARP) Op() ARPOp { return ARPOp(binary.BigEndian.Uint16(a[opCodeOffset:])) } // SetOp sets the ARP opcode. func (a ARP) SetOp(op ARPOp) { - a[6] = uint8(op >> 8) - a[7] = uint8(op) + binary.BigEndian.PutUint16(a[opCodeOffset:], uint16(op)) } // SetIPv4OverEthernet configures the ARP packet for IPv4-over-Ethernet. func (a ARP) SetIPv4OverEthernet() { - a[0], a[1] = 0, 1 // htypeEthernet - a[2], a[3] = 0x08, 0x00 // IPv4ProtocolNumber - a[4] = 6 // macSize - a[5] = uint8(IPv4AddressSize) + binary.BigEndian.PutUint16(a[hTypeOffset:], uint16(ARPHardwareEther)) + binary.BigEndian.PutUint16(a[protocolOffset:], uint16(IPv4ProtocolNumber)) + a[haAddressSizeOffset] = EthernetAddressSize + a[protoAddressSizeOffset] = uint8(IPv4AddressSize) } // HardwareAddressSender is the link address of the sender. // It is a view on to the ARP packet so it can be used to set the value. func (a ARP) HardwareAddressSender() []byte { - const s = 8 - return a[s : s+6] + return a[senderHAAddressOffset : senderHAAddressOffset+EthernetAddressSize] } // ProtocolAddressSender is the protocol address of the sender. // It is a view on to the ARP packet so it can be used to set the value. func (a ARP) ProtocolAddressSender() []byte { - const s = 8 + 6 - return a[s : s+4] + return a[senderProtocolAddressOffset : senderProtocolAddressOffset+IPv4AddressSize] } // HardwareAddressTarget is the link address of the target. // It is a view on to the ARP packet so it can be used to set the value. func (a ARP) HardwareAddressTarget() []byte { - const s = 8 + 6 + 4 - return a[s : s+6] + return a[targetHAAddressOffset : targetHAAddressOffset+EthernetAddressSize] } // ProtocolAddressTarget is the protocol address of the target. // It is a view on to the ARP packet so it can be used to set the value. func (a ARP) ProtocolAddressTarget() []byte { - const s = 8 + 6 + 4 + 6 - return a[s : s+4] + return a[targetProtocolAddressOffset : targetProtocolAddressOffset+IPv4AddressSize] } // IsValid reports whether this is an ARP packet for IPv4 over Ethernet. @@ -91,10 +120,8 @@ func (a ARP) IsValid() bool { if len(a) < ARPSize { return false } - const htypeEthernet = 1 - const macSize = 6 - return a.hardwareAddressSpace() == htypeEthernet && + return a.hardwareAddressType() == ARPHardwareEther && a.protocolAddressSpace() == uint16(IPv4ProtocolNumber) && - a.hardwareAddressSize() == macSize && + a.hardwareAddressSize() == EthernetAddressSize && a.protocolAddressSize() == IPv4AddressSize } diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD index b8b93e78e..39ca774ef 100644 --- a/pkg/tcpip/link/channel/BUILD +++ b/pkg/tcpip/link/channel/BUILD @@ -10,6 +10,7 @@ go_library( "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index 20b183da0..a2bb773d4 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -296,3 +297,8 @@ func (e *Endpoint) AddNotify(notify Notification) *NotificationHandle { func (e *Endpoint) RemoveNotify(handle *NotificationHandle) { e.q.RemoveNotify(handle) } + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*Endpoint) ARPHardwareType() header.ARPHardwareType { + return header.ARPHardwareNone +} diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index f34082e1a..32abe2a13 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -626,6 +626,14 @@ func (e *endpoint) GSOMaxSize() uint32 { return e.gsoMaxSize } +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (e *endpoint) ARPHardwareType() header.ARPHardwareType { + if e.hdrSize > 0 { + return header.ARPHardwareEther + } + return header.ARPHardwareNone +} + // InjectableEndpoint is an injectable fd-based endpoint. The endpoint writes // to the FD, but does not read from it. All reads come from injected packets. type InjectableEndpoint struct { diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 568c6874f..3b17d8c28 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -113,3 +113,8 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { return nil } + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*endpoint) ARPHardwareType() header.ARPHardwareType { + return header.ARPHardwareLoopback +} diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD index 82b441b79..e7493e5c5 100644 --- a/pkg/tcpip/link/muxed/BUILD +++ b/pkg/tcpip/link/muxed/BUILD @@ -9,6 +9,7 @@ go_library( deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index c69d6b7e9..c305d9e86 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -18,6 +18,7 @@ package muxed import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -129,6 +130,11 @@ func (m *InjectableEndpoint) Wait() { } } +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*InjectableEndpoint) ARPHardwareType() header.ARPHardwareType { + panic("unsupported operation") +} + // NewInjectableEndpoint creates a new multi-endpoint injectable endpoint. func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) *InjectableEndpoint { return &InjectableEndpoint{ diff --git a/pkg/tcpip/link/nested/BUILD b/pkg/tcpip/link/nested/BUILD index bdd5276ad..2cdb23475 100644 --- a/pkg/tcpip/link/nested/BUILD +++ b/pkg/tcpip/link/nested/BUILD @@ -12,6 +12,7 @@ go_library( "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go index 2998f9c4f..328bd048e 100644 --- a/pkg/tcpip/link/nested/nested.go +++ b/pkg/tcpip/link/nested/nested.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -129,3 +130,8 @@ func (e *Endpoint) GSOMaxSize() uint32 { } return 0 } + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType +func (e *Endpoint) ARPHardwareType() header.ARPHardwareType { + return e.child.ARPHardwareType() +} diff --git a/pkg/tcpip/link/qdisc/fifo/BUILD b/pkg/tcpip/link/qdisc/fifo/BUILD index 054c213bc..1d0079bd6 100644 --- a/pkg/tcpip/link/qdisc/fifo/BUILD +++ b/pkg/tcpip/link/qdisc/fifo/BUILD @@ -14,6 +14,7 @@ go_library( "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go index b5dfb7850..c84fe1bb9 100644 --- a/pkg/tcpip/link/qdisc/fifo/endpoint.go +++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -207,3 +208,8 @@ func (e *endpoint) Wait() { e.wg.Wait() } + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType +func (e *endpoint) ARPHardwareType() header.ARPHardwareType { + return e.lower.ARPHardwareType() +} diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 0374a2441..a36862c67 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -287,3 +287,8 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) { e.completed.Done() } + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType +func (*endpoint) ARPHardwareType() header.ARPHardwareType { + return header.ARPHardwareEther +} diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index 6bc9033d0..47446efec 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -139,6 +139,7 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE stack: s, nicID: id, name: name, + isTap: prefix == "tap", } endpoint.Endpoint.LinkEPCapabilities = linkCaps if endpoint.name == "" { @@ -348,6 +349,7 @@ type tunEndpoint struct { stack *stack.Stack nicID tcpip.NICID name string + isTap bool } // DecRef decrements refcount of e, removes NIC if refcount goes to 0. @@ -356,3 +358,11 @@ func (e *tunEndpoint) DecRef() { e.stack.RemoveNIC(e.nicID) }) } + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (e *tunEndpoint) ARPHardwareType() header.ARPHardwareType { + if e.isTap { + return header.ARPHardwareEther + } + return header.ARPHardwareNone +} diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD index 0956d2c65..ee84c3d96 100644 --- a/pkg/tcpip/link/waitable/BUILD +++ b/pkg/tcpip/link/waitable/BUILD @@ -12,6 +12,7 @@ go_library( "//pkg/gate", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/stack", ], ) @@ -25,6 +26,7 @@ go_test( deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index 949b3f2b2..24a8dc2eb 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/gate" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -147,3 +148,8 @@ func (e *Endpoint) WaitDispatch() { // Wait implements stack.LinkEndpoint.Wait. func (e *Endpoint) Wait() {} + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (e *Endpoint) ARPHardwareType() header.ARPHardwareType { + return e.lower.ARPHardwareType() +} diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index 63bf40562..ffb2354be 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -81,6 +82,11 @@ func (e *countedEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { return nil } +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*countedEndpoint) ARPHardwareType() header.ARPHardwareType { + panic("unimplemented") +} + // Wait implements stack.LinkEndpoint.Wait. func (*countedEndpoint) Wait() {} diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 7c8fb3e0a..a5b780ca2 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -172,14 +172,19 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Ne } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (*testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { panic("not implemented") } -func (t *testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error { +func (*testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error { return tcpip.ErrNotSupported } +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*testObject) ARPHardwareType() header.ARPHardwareType { + panic("not implemented") +} + func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index a6546cef0..eefb4b07f 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" ) const ( @@ -301,6 +302,11 @@ func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Er // Wait implements stack.LinkEndpoint.Wait. func (*fwdTestLinkEndpoint) Wait() {} +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*fwdTestLinkEndpoint) ARPHardwareType() header.ARPHardwareType { + panic("not implemented") +} + func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) { // Create a stack with the network protocol and two NICs. s := New(Options{ diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index 31f865260..3bc9fd831 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -84,6 +84,11 @@ func (e *testLinkEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { return tcpip.ErrNotSupported } +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*testLinkEndpoint) ARPHardwareType() header.ARPHardwareType { + panic("not implemented") +} + var _ NetworkEndpoint = (*testIPv6Endpoint)(nil) // An IPv6 NetworkEndpoint that throws away outgoing packets. diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 5cbc946b6..f260eeb7f 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -18,6 +18,7 @@ import ( "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/waiter" ) @@ -436,6 +437,12 @@ type LinkEndpoint interface { // Wait will not block if the endpoint hasn't started any goroutines // yet, even if it might later. Wait() + + // ARPHardwareType returns the ARPHRD_TYPE of the link endpoint. + // + // See: + // https://github.com/torvalds/linux/blob/aa0c9086b40c17a7ad94425b3b70dd1fdd7497bf/include/uapi/linux/if_arp.h#L30 + ARPHardwareType() header.ARPHardwareType } // InjectableLinkEndpoint is a LinkEndpoint where inbound packets are diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 0aa815447..2b7ece851 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1095,6 +1095,11 @@ type NICInfo struct { // Context is user-supplied data optionally supplied in CreateNICWithOptions. // See type NICOptions for more details. Context NICContext + + // ARPHardwareType holds the ARP Hardware type of the NIC. This is the + // value sent in haType field of an ARP Request sent by this NIC and the + // value expected in the haType field of an ARP response. + ARPHardwareType header.ARPHardwareType } // HasNIC returns true if the NICID is defined in the stack. @@ -1126,6 +1131,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { MTU: nic.linkEP.MTU(), Stats: nic.stats, Context: nic.context, + ARPHardwareType: nic.linkEP.ARPHardwareType(), } } return nics diff --git a/test/syscalls/linux/socket_netdevice.cc b/test/syscalls/linux/socket_netdevice.cc index 15d4b85a7..5f8d7f981 100644 --- a/test/syscalls/linux/socket_netdevice.cc +++ b/test/syscalls/linux/socket_netdevice.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include #include @@ -49,6 +50,7 @@ TEST(NetdeviceTest, Loopback) { // Check that the loopback is zero hardware address. ASSERT_THAT(ioctl(sock.get(), SIOCGIFHWADDR, &ifr), SyscallSucceeds()); + EXPECT_EQ(ifr.ifr_hwaddr.sa_family, ARPHRD_LOOPBACK); EXPECT_EQ(ifr.ifr_hwaddr.sa_data[0], 0); EXPECT_EQ(ifr.ifr_hwaddr.sa_data[1], 0); EXPECT_EQ(ifr.ifr_hwaddr.sa_data[2], 0); @@ -178,6 +180,27 @@ TEST(NetdeviceTest, InterfaceMTU) { EXPECT_GT(ifr.ifr_mtu, 0); } +TEST(NetdeviceTest, EthtoolGetTSInfo) { + FileDescriptor sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + + struct ethtool_ts_info tsi = {}; + tsi.cmd = ETHTOOL_GET_TS_INFO; // Get NIC's Timestamping capabilities. + + // Prepare the request. + struct ifreq ifr = {}; + snprintf(ifr.ifr_name, IFNAMSIZ, "lo"); + ifr.ifr_data = (void*)&tsi; + + // Check that SIOCGIFMTU returns a nonzero MTU. + if (IsRunningOnGvisor()) { + ASSERT_THAT(ioctl(sock.get(), SIOCETHTOOL, &ifr), + SyscallFailsWithErrno(EOPNOTSUPP)); + return; + } + ASSERT_THAT(ioctl(sock.get(), SIOCETHTOOL, &ifr), SyscallSucceeds()); +} + } // namespace } // namespace testing -- cgit v1.2.3 From e92f38ff0cd2e490637df2081fc8f75ddaf32937 Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Wed, 15 Jul 2020 16:33:46 -0700 Subject: iptables: remove check for NetworkHeader This is no longer necessary, as we always set NetworkHeader before calling iptables.Check. PiperOrigin-RevId: 321461978 --- pkg/tcpip/stack/iptables.go | 30 +++++++++--------------------- 1 file changed, 9 insertions(+), 21 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index f846ea2e5..bbf3b60e8 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -292,10 +292,9 @@ func (it *IPTables) startReaper(interval time.Duration) { // CheckPackets runs pkts through the rules for hook and returns a map of packets that // should not go forward. // -// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// -// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a -// precondition. +// Preconditions: +// - pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// - pkt.NetworkHeader is not nil. // // NOTE: unlike the Check API the returned map contains packets that should be // dropped. @@ -319,9 +318,9 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r * return drop, natPkts } -// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a -// precondition. +// Preconditions: +// - pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// - pkt.NetworkHeader is not nil. func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. @@ -366,23 +365,12 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId return chainDrop } -// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a -// precondition. +// Preconditions: +// - pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// - pkt.NetworkHeader is not nil. func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) (RuleVerdict, int) { rule := table.Rules[ruleIdx] - // If pkt.NetworkHeader hasn't been set yet, it will be contained in - // pkt.Data. - if pkt.NetworkHeader == nil { - var ok bool - pkt.NetworkHeader, ok = pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - // Precondition has been violated. - panic(fmt.Sprintf("iptables checks require IPv4 headers of at least %d bytes", header.IPv4MinimumSize)) - } - } - // Check whether the packet matches the IP header filter. if !rule.Filter.match(header.IPv4(pkt.NetworkHeader), hook, nicName) { // Continue on to the next rule. -- cgit v1.2.3 From 71bf90c55bd888f9b9c493533ca5e4b2b4b3d21d Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Wed, 22 Jul 2020 15:12:56 -0700 Subject: Support for receiving outbound packets in AF_PACKET. Updates #173 PiperOrigin-RevId: 322665518 --- pkg/abi/linux/socket.go | 9 +++ pkg/sentry/socket/netstack/netstack.go | 19 +++++ pkg/tcpip/link/channel/channel.go | 4 + pkg/tcpip/link/fdbased/endpoint.go | 36 ++++----- pkg/tcpip/link/fdbased/endpoint_test.go | 8 ++ pkg/tcpip/link/loopback/loopback.go | 3 + pkg/tcpip/link/muxed/injectable.go | 4 + pkg/tcpip/link/nested/nested.go | 15 ++++ pkg/tcpip/link/nested/nested_test.go | 4 + pkg/tcpip/link/packetsocket/BUILD | 14 ++++ pkg/tcpip/link/packetsocket/endpoint.go | 50 +++++++++++++ pkg/tcpip/link/qdisc/fifo/endpoint.go | 12 +++ pkg/tcpip/link/sharedmem/sharedmem.go | 21 ++++-- pkg/tcpip/link/sharedmem/sharedmem_test.go | 4 + pkg/tcpip/link/sniffer/sniffer.go | 5 ++ pkg/tcpip/link/tun/device.go | 43 +++++++---- pkg/tcpip/link/waitable/waitable.go | 14 ++++ pkg/tcpip/link/waitable/waitable_test.go | 9 +++ pkg/tcpip/network/ip_test.go | 5 ++ pkg/tcpip/stack/forwarder_test.go | 5 ++ pkg/tcpip/stack/nic.go | 30 ++++++-- pkg/tcpip/stack/nic_test.go | 5 ++ pkg/tcpip/stack/packet_buffer.go | 4 + pkg/tcpip/stack/registration.go | 16 +++- pkg/tcpip/tcpip.go | 25 +++++++ pkg/tcpip/transport/packet/endpoint.go | 52 +++++++++---- runsc/boot/BUILD | 1 + runsc/boot/network.go | 4 + test/syscalls/linux/packet_socket.cc | 116 +++++++++++++++++++++++++++++ 29 files changed, 471 insertions(+), 66 deletions(-) create mode 100644 pkg/tcpip/link/packetsocket/BUILD create mode 100644 pkg/tcpip/link/packetsocket/endpoint.go (limited to 'pkg/tcpip/stack') diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index 4a14ef691..95337c168 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -134,6 +134,15 @@ const ( SHUT_RDWR = 2 ) +// Packet types from +const ( + PACKET_HOST = 0 // To us + PACKET_BROADCAST = 1 // To all + PACKET_MULTICAST = 2 // To group + PACKET_OTHERHOST = 3 // To someone else + PACKET_OUTGOING = 4 // Outgoing of any type +) + // Socket options from socket.h. const ( SO_DEBUG = 1 diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 49a04e613..964ec8414 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -26,6 +26,7 @@ package netstack import ( "bytes" + "fmt" "io" "math" "reflect" @@ -2468,6 +2469,23 @@ func (s *socketOpsCommon) fillCmsgInq(cmsg *socket.ControlMessages) { cmsg.IP.Inq = int32(len(s.readView) + rcvBufUsed) } +func toLinuxPacketType(pktType tcpip.PacketType) uint8 { + switch pktType { + case tcpip.PacketHost: + return linux.PACKET_HOST + case tcpip.PacketOtherHost: + return linux.PACKET_OTHERHOST + case tcpip.PacketOutgoing: + return linux.PACKET_OUTGOING + case tcpip.PacketBroadcast: + return linux.PACKET_BROADCAST + case tcpip.PacketMulticast: + return linux.PACKET_MULTICAST + default: + panic(fmt.Sprintf("unknown packet type: %d", pktType)) + } +} + // nonBlockingRead issues a non-blocking read. // // TODO(b/78348848): Support timestamps for stream sockets. @@ -2526,6 +2544,7 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq switch v := addr.(type) { case *linux.SockAddrLink: v.Protocol = htons(uint16(s.linkPacketInfo.Protocol)) + v.PacketType = toLinuxPacketType(s.linkPacketInfo.PktType) } } diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index a2bb773d4..e12a5929b 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -302,3 +302,7 @@ func (e *Endpoint) RemoveNotify(handle *NotificationHandle) { func (*Endpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareNone } + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { +} diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index 6aa1badc7..c18bb91fb 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -386,26 +386,33 @@ const ( _VIRTIO_NET_HDR_GSO_TCPV6 = 4 ) -// WritePacket writes outbound packets to the file descriptor. If it is not -// currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { if e.hdrSize > 0 { // Add ethernet header if needed. eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) pkt.LinkHeader = buffer.View(eth) ethHdr := &header.EthernetFields{ - DstAddr: r.RemoteLinkAddress, + DstAddr: remote, Type: protocol, } // Preserve the src address if it's set in the route. - if r.LocalLinkAddress != "" { - ethHdr.SrcAddr = r.LocalLinkAddress + if local != "" { + ethHdr.SrcAddr = local } else { ethHdr.SrcAddr = e.addr } eth.Encode(ethHdr) } +} + +// WritePacket writes outbound packets to the file descriptor. If it is not +// currently writable, the packet is dropped. +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + if e.hdrSize > 0 { + e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) + } var builder iovec.Builder @@ -448,22 +455,8 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tc // Send a batch of packets through batchFD. mmsgHdrs := make([]rawfile.MMsgHdr, 0, len(batch)) for _, pkt := range batch { - var eth header.Ethernet if e.hdrSize > 0 { - // Add ethernet header if needed. - eth = make(header.Ethernet, header.EthernetMinimumSize) - ethHdr := &header.EthernetFields{ - DstAddr: pkt.EgressRoute.RemoteLinkAddress, - Type: pkt.NetworkProtocolNumber, - } - - // Preserve the src address if it's set in the route. - if pkt.EgressRoute.LocalLinkAddress != "" { - ethHdr.SrcAddr = pkt.EgressRoute.LocalLinkAddress - } else { - ethHdr.SrcAddr = e.addr - } - eth.Encode(ethHdr) + e.AddHeader(pkt.EgressRoute.LocalLinkAddress, pkt.EgressRoute.RemoteLinkAddress, pkt.NetworkProtocolNumber, pkt) } var vnetHdrBuf []byte @@ -493,7 +486,6 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tc var builder iovec.Builder builder.Add(vnetHdrBuf) - builder.Add(eth) builder.Add(pkt.Header.View()) for _, v := range pkt.Data.Views() { builder.Add(v) diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index 4bad930c7..7b995b85a 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -107,6 +107,10 @@ func (c *context) DeliverNetworkPacket(remote tcpip.LinkAddress, local tcpip.Lin c.ch <- packetInfo{remote, protocol, pkt} } +func (c *context) DeliverOutboundPacket(remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + panic("unimplemented") +} + func TestNoEthernetProperties(t *testing.T) { c := newContext(t, &Options{MTU: mtu}) defer c.cleanup() @@ -510,6 +514,10 @@ func (d *fakeNetworkDispatcher) DeliverNetworkPacket(remote, local tcpip.LinkAdd d.pkts = append(d.pkts, pkt) } +func (d *fakeNetworkDispatcher) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + panic("unimplemented") +} + func TestDispatchPacketFormat(t *testing.T) { for _, test := range []struct { name string diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 3b17d8c28..781cdd317 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -118,3 +118,6 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { func (*endpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareLoopback } + +func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { +} diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index c305d9e86..56a611825 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -135,6 +135,10 @@ func (*InjectableEndpoint) ARPHardwareType() header.ARPHardwareType { panic("unsupported operation") } +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (*InjectableEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { +} + // NewInjectableEndpoint creates a new multi-endpoint injectable endpoint. func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) *InjectableEndpoint { return &InjectableEndpoint{ diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go index 328bd048e..d40de54df 100644 --- a/pkg/tcpip/link/nested/nested.go +++ b/pkg/tcpip/link/nested/nested.go @@ -61,6 +61,16 @@ func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco } } +// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket. +func (e *Endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.mu.RLock() + d := e.dispatcher + e.mu.RUnlock() + if d != nil { + d.DeliverOutboundPacket(remote, local, protocol, pkt) + } +} + // Attach implements stack.LinkEndpoint. func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) { e.mu.Lock() @@ -135,3 +145,8 @@ func (e *Endpoint) GSOMaxSize() uint32 { func (e *Endpoint) ARPHardwareType() header.ARPHardwareType { return e.child.ARPHardwareType() } + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.child.AddHeader(local, remote, protocol, pkt) +} diff --git a/pkg/tcpip/link/nested/nested_test.go b/pkg/tcpip/link/nested/nested_test.go index c1a219f02..7d9249c1c 100644 --- a/pkg/tcpip/link/nested/nested_test.go +++ b/pkg/tcpip/link/nested/nested_test.go @@ -55,6 +55,10 @@ func (d *counterDispatcher) DeliverNetworkPacket(tcpip.LinkAddress, tcpip.LinkAd d.count++ } +func (d *counterDispatcher) DeliverOutboundPacket(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) { + panic("unimplemented") +} + func TestNestedLinkEndpoint(t *testing.T) { const emptyAddress = tcpip.LinkAddress("") diff --git a/pkg/tcpip/link/packetsocket/BUILD b/pkg/tcpip/link/packetsocket/BUILD new file mode 100644 index 000000000..6fff160ce --- /dev/null +++ b/pkg/tcpip/link/packetsocket/BUILD @@ -0,0 +1,14 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "packetsocket", + srcs = ["endpoint.go"], + visibility = ["//visibility:public"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/link/nested", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/link/packetsocket/endpoint.go b/pkg/tcpip/link/packetsocket/endpoint.go new file mode 100644 index 000000000..3922c2a04 --- /dev/null +++ b/pkg/tcpip/link/packetsocket/endpoint.go @@ -0,0 +1,50 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package packetsocket provides a link layer endpoint that provides the ability +// to loop outbound packets to any AF_PACKET sockets that may be interested in +// the outgoing packet. +package packetsocket + +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/link/nested" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +type endpoint struct { + nested.Endpoint +} + +// New creates a new packetsocket LinkEndpoint. +func New(lower stack.LinkEndpoint) stack.LinkEndpoint { + e := &endpoint{} + e.Endpoint.Init(lower, e) + return e +} + +// WritePacket implements stack.LinkEndpoint.WritePacket. +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, protocol, pkt) + return e.Endpoint.WritePacket(r, gso, protocol, pkt) +} + +// WritePackets implements stack.LinkEndpoint.WritePackets. +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + e.Endpoint.DeliverOutboundPacket(pkt.EgressRoute.RemoteLinkAddress, pkt.EgressRoute.LocalLinkAddress, pkt.NetworkProtocolNumber, pkt) + } + + return e.Endpoint.WritePackets(r, gso, pkts, proto) +} diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go index c84fe1bb9..467083239 100644 --- a/pkg/tcpip/link/qdisc/fifo/endpoint.go +++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go @@ -107,6 +107,11 @@ func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt) } +// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket. +func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.dispatcher.DeliverOutboundPacket(remote, local, protocol, pkt) +} + // Attach implements stack.LinkEndpoint.Attach. func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { e.dispatcher = dispatcher @@ -194,6 +199,8 @@ func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketB // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + // TODO(gvisor.dev/issue/3267/): Queue these packets as well once + // WriteRawPacket takes PacketBuffer instead of VectorisedView. return e.lower.WriteRawPacket(vv) } @@ -213,3 +220,8 @@ func (e *endpoint) Wait() { func (e *endpoint) ARPHardwareType() header.ARPHardwareType { return e.lower.ARPHardwareType() } + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.lower.AddHeader(local, remote, protocol, pkt) +} diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index a36862c67..507c76b76 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -183,22 +183,29 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress { return e.addr } -// WritePacket writes outbound packets to the file descriptor. If it is not -// currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - // Add the ethernet header here. +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + // Add ethernet header if needed. eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) pkt.LinkHeader = buffer.View(eth) ethHdr := &header.EthernetFields{ - DstAddr: r.RemoteLinkAddress, + DstAddr: remote, Type: protocol, } - if r.LocalLinkAddress != "" { - ethHdr.SrcAddr = r.LocalLinkAddress + + // Preserve the src address if it's set in the route. + if local != "" { + ethHdr.SrcAddr = local } else { ethHdr.SrcAddr = e.addr } eth.Encode(ethHdr) +} + +// WritePacket writes outbound packets to the file descriptor. If it is not +// currently writable, the packet is dropped. +func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) v := pkt.Data.ToView() // Transmit the packet. diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 28a2e88ba..8f3cd9449 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -143,6 +143,10 @@ func (c *testContext) DeliverNetworkPacket(remoteLinkAddr, localLinkAddr tcpip.L c.packetCh <- struct{}{} } +func (c *testContext) DeliverOutboundPacket(remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + panic("unimplemented") +} + func (c *testContext) cleanup() { c.ep.Close() closeFDs(&c.txCfg) diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index d9cd4e83a..509076643 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -123,6 +123,11 @@ func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco e.Endpoint.DeliverNetworkPacket(remote, local, protocol, pkt) } +// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket. +func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.Endpoint.DeliverOutboundPacket(remote, local, protocol, pkt) +} + func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { writer := e.writer if writer == nil && atomic.LoadUint32(&LogPackets) == 1 { diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index 47446efec..04ae58e59 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -272,21 +272,9 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) { if d.hasFlags(linux.IFF_TAP) { // Add ethernet header if not provided. if info.Pkt.LinkHeader == nil { - hdr := &header.EthernetFields{ - SrcAddr: info.Route.LocalLinkAddress, - DstAddr: info.Route.RemoteLinkAddress, - Type: info.Proto, - } - if hdr.SrcAddr == "" { - hdr.SrcAddr = d.endpoint.LinkAddress() - } - - eth := make(header.Ethernet, header.EthernetMinimumSize) - eth.Encode(hdr) - vv.AppendView(buffer.View(eth)) - } else { - vv.AppendView(info.Pkt.LinkHeader) + d.endpoint.AddHeader(info.Route.LocalLinkAddress, info.Route.RemoteLinkAddress, info.Proto, info.Pkt) } + vv.AppendView(info.Pkt.LinkHeader) } // Append upper headers. @@ -366,3 +354,30 @@ func (e *tunEndpoint) ARPHardwareType() header.ARPHardwareType { } return header.ARPHardwareNone } + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *tunEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + if !e.isTap { + return + } + eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) + pkt.LinkHeader = buffer.View(eth) + hdr := &header.EthernetFields{ + SrcAddr: local, + DstAddr: remote, + Type: protocol, + } + if hdr.SrcAddr == "" { + hdr.SrcAddr = e.LinkAddress() + } + + eth.Encode(hdr) +} + +// MaxHeaderLength returns the maximum size of the link layer header. +func (e *tunEndpoint) MaxHeaderLength() uint16 { + if e.isTap { + return header.EthernetMinimumSize + } + return 0 +} diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index 24a8dc2eb..b152a0f26 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -60,6 +60,15 @@ func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco e.dispatchGate.Leave() } +// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket. +func (e *Endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + if !e.dispatchGate.Enter() { + return + } + e.dispatcher.DeliverOutboundPacket(remote, local, protocol, pkt) + e.dispatchGate.Leave() +} + // Attach implements stack.LinkEndpoint.Attach. It saves the dispatcher and // registers with the lower endpoint as its dispatcher so that "e" is called // for inbound packets. @@ -153,3 +162,8 @@ func (e *Endpoint) Wait() {} func (e *Endpoint) ARPHardwareType() header.ARPHardwareType { return e.lower.ARPHardwareType() } + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.lower.AddHeader(local, remote, protocol, pkt) +} diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index ffb2354be..c448a888f 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -40,6 +40,10 @@ func (e *countedEndpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, e.dispatchCount++ } +func (e *countedEndpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + panic("unimplemented") +} + func (e *countedEndpoint) Attach(dispatcher stack.NetworkDispatcher) { e.attachCount++ e.dispatcher = dispatcher @@ -90,6 +94,11 @@ func (*countedEndpoint) ARPHardwareType() header.ARPHardwareType { // Wait implements stack.LinkEndpoint.Wait. func (*countedEndpoint) Wait() {} +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *countedEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + panic("unimplemented") +} + func TestWaitWrite(t *testing.T) { ep := &countedEndpoint{} wep := New(ep) diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index a5b780ca2..615bae648 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -185,6 +185,11 @@ func (*testObject) ARPHardwareType() header.ARPHardwareType { panic("not implemented") } +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (*testObject) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + panic("not implemented") +} + func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index eefb4b07f..bca1d940b 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -307,6 +307,11 @@ func (*fwdTestLinkEndpoint) ARPHardwareType() header.ARPHardwareType { panic("not implemented") } +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { + panic("not implemented") +} + func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) { // Create a stack with the network protocol and two NICs. s := New(Options{ diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 7b80534e6..fea0ce7e8 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1200,15 +1200,13 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp // Are any packet sockets listening for this network protocol? packetEPs := n.mu.packetEPs[protocol] - // Check whether there are packet sockets listening for every protocol. - // If we received a packet with protocol EthernetProtocolAll, then the - // previous for loop will have handled it. - if protocol != header.EthernetProtocolAll { - packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...) - } + // Add any other packet sockets that maybe listening for all protocols. + packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...) n.mu.RUnlock() for _, ep := range packetEPs { - ep.HandlePacket(n.id, local, protocol, pkt.Clone()) + p := pkt.Clone() + p.PktType = tcpip.PacketHost + ep.HandlePacket(n.id, local, protocol, p) } if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber { @@ -1311,6 +1309,24 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp } } +// DeliverOutboundPacket implements NetworkDispatcher.DeliverOutboundPacket. +func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { + n.mu.RLock() + // We do not deliver to protocol specific packet endpoints as on Linux + // only ETH_P_ALL endpoints get outbound packets. + // Add any other packet sockets that maybe listening for all protocols. + packetEPs := n.mu.packetEPs[header.EthernetProtocolAll] + n.mu.RUnlock() + for _, ep := range packetEPs { + p := pkt.Clone() + p.PktType = tcpip.PacketOutgoing + // Add the link layer header as outgoing packets are intercepted + // before the link layer header is created. + n.linkEP.AddHeader(local, remote, protocol, p) + ep.HandlePacket(n.id, local, protocol, p) + } +} + func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { // TODO(b/143425874) Decrease the TTL field in forwarded packets. // TODO(b/151227689): Avoid copying the packet when forwarding. We can do this diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index 3bc9fd831..c477e31d8 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -89,6 +89,11 @@ func (*testLinkEndpoint) ARPHardwareType() header.ARPHardwareType { panic("not implemented") } +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *testLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { + panic("not implemented") +} + var _ NetworkEndpoint = (*testIPv6Endpoint)(nil) // An IPv6 NetworkEndpoint that throws away outgoing packets. diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index e3556d5d2..5d6865e35 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -79,6 +79,10 @@ type PacketBuffer struct { // NatDone indicates if the packet has been manipulated as per NAT // iptables rule. NatDone bool + + // PktType indicates the SockAddrLink.PacketType of the packet as defined in + // https://www.man7.org/linux/man-pages/man7/packet.7.html. + PktType tcpip.PacketType } // Clone makes a copy of pk. It clones the Data field, which creates a new diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index f260eeb7f..cd4b7a449 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -330,8 +330,7 @@ type NetworkProtocol interface { } // NetworkDispatcher contains the methods used by the network stack to deliver -// packets to the appropriate network endpoint after it has been handled by -// the data link layer. +// inbound/outbound packets to the appropriate network/packet(if any) endpoints. type NetworkDispatcher interface { // DeliverNetworkPacket finds the appropriate network protocol endpoint // and hands the packet over for further processing. @@ -342,6 +341,16 @@ type NetworkDispatcher interface { // // DeliverNetworkPacket takes ownership of pkt. DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) + + // DeliverOutboundPacket is called by link layer when a packet is being + // sent out. + // + // pkt.LinkHeader may or may not be set before calling + // DeliverOutboundPacket. Some packets do not have link headers (e.g. + // packets sent via loopback), and won't have the field set. + // + // DeliverOutboundPacket takes ownership of pkt. + DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) } // LinkEndpointCapabilities is the type associated with the capabilities @@ -443,6 +452,9 @@ type LinkEndpoint interface { // See: // https://github.com/torvalds/linux/blob/aa0c9086b40c17a7ad94425b3b70dd1fdd7497bf/include/uapi/linux/if_arp.h#L30 ARPHardwareType() header.ARPHardwareType + + // AddHeader adds a link layer header to pkt if required. + AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) } // InjectableLinkEndpoint is a LinkEndpoint where inbound packets are diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 48ad56d4d..ff14a3b3c 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -316,6 +316,28 @@ const ( ShutdownWrite ) +// PacketType is used to indicate the destination of the packet. +type PacketType uint8 + +const ( + // PacketHost indicates a packet addressed to the local host. + PacketHost PacketType = iota + + // PacketOtherHost indicates an outgoing packet addressed to + // another host caught by a NIC in promiscuous mode. + PacketOtherHost + + // PacketOutgoing for a packet originating from the local host + // that is looped back to a packet socket. + PacketOutgoing + + // PacketBroadcast indicates a link layer broadcast packet. + PacketBroadcast + + // PacketMulticast indicates a link layer multicast packet. + PacketMulticast +) + // FullAddress represents a full transport node address, as required by the // Connect() and Bind() methods. // @@ -555,6 +577,9 @@ type Endpoint interface { type LinkPacketInfo struct { // Protocol is the NetworkProtocolNumber for the packet. Protocol NetworkProtocolNumber + + // PktType is used to indicate the destination of the packet. + PktType PacketType } // PacketEndpoint are additional methods that are only implemented by Packet diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 7b2083a09..8f167391f 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -441,6 +441,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, Addr: tcpip.Address(hdr.SourceAddress()), } packet.packetInfo.Protocol = netProto + packet.packetInfo.PktType = pkt.PktType } else { // Guess the would-be ethernet header. packet.senderAddr = tcpip.FullAddress{ @@ -448,30 +449,53 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, Addr: tcpip.Address(localAddr), } packet.packetInfo.Protocol = netProto + packet.packetInfo.PktType = pkt.PktType } if ep.cooked { // Cooked packets can simply be queued. - packet.data = pkt.Data + switch pkt.PktType { + case tcpip.PacketHost: + packet.data = pkt.Data + case tcpip.PacketOutgoing: + // Strip Link Header from the Header. + pkt.Header = buffer.NewPrependableFromView(pkt.Header.View()[len(pkt.LinkHeader):]) + combinedVV := pkt.Header.View().ToVectorisedView() + combinedVV.Append(pkt.Data) + packet.data = combinedVV + default: + panic(fmt.Sprintf("unexpected PktType in pkt: %+v", pkt)) + } + } else { // Raw packets need their ethernet headers prepended before // queueing. var linkHeader buffer.View - if len(pkt.LinkHeader) == 0 { - // We weren't provided with an actual ethernet header, - // so fake one. - ethFields := header.EthernetFields{ - SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}), - DstAddr: localAddr, - Type: netProto, + var combinedVV buffer.VectorisedView + if pkt.PktType != tcpip.PacketOutgoing { + if len(pkt.LinkHeader) == 0 { + // We weren't provided with an actual ethernet header, + // so fake one. + ethFields := header.EthernetFields{ + SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}), + DstAddr: localAddr, + Type: netProto, + } + fakeHeader := make(header.Ethernet, header.EthernetMinimumSize) + fakeHeader.Encode(ðFields) + linkHeader = buffer.View(fakeHeader) + } else { + linkHeader = append(buffer.View(nil), pkt.LinkHeader...) } - fakeHeader := make(header.Ethernet, header.EthernetMinimumSize) - fakeHeader.Encode(ðFields) - linkHeader = buffer.View(fakeHeader) - } else { - linkHeader = append(buffer.View(nil), pkt.LinkHeader...) + combinedVV = linkHeader.ToVectorisedView() + } + if pkt.PktType == tcpip.PacketOutgoing { + // For outgoing packets the Link, Network and Transport + // headers are in the pkt.Header fields normally unless + // a Raw socket is in use in which case pkt.Header could + // be nil. + combinedVV.AppendView(pkt.Header.View()) } - combinedVV := linkHeader.ToVectorisedView() combinedVV.Append(pkt.Data) packet.data = combinedVV } diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD index 55d45aaa6..9f52438c2 100644 --- a/runsc/boot/BUILD +++ b/runsc/boot/BUILD @@ -90,6 +90,7 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/link/fdbased", "//pkg/tcpip/link/loopback", + "//pkg/tcpip/link/packetsocket", "//pkg/tcpip/link/qdisc/fifo", "//pkg/tcpip/link/sniffer", "//pkg/tcpip/network/arp", diff --git a/runsc/boot/network.go b/runsc/boot/network.go index 14d2f56a5..4e1fa7665 100644 --- a/runsc/boot/network.go +++ b/runsc/boot/network.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/link/fdbased" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" + "gvisor.dev/gvisor/pkg/tcpip/link/packetsocket" "gvisor.dev/gvisor/pkg/tcpip/link/qdisc/fifo" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/network/arp" @@ -252,6 +253,9 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct linkEP = fifo.New(linkEP, runtime.GOMAXPROCS(0), 1000) } + // Enable support for AF_PACKET sockets to receive outgoing packets. + linkEP = packetsocket.New(linkEP) + log.Infof("Enabling interface %q with id %d on addresses %+v (%v) w/ %d channels", link.Name, nicID, link.Addresses, mac, link.NumChannels) if err := n.createNICWithAddrs(nicID, link.Name, linkEP, link.Addresses); err != nil { return err diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc index e94ddcb77..40aa9326d 100644 --- a/test/syscalls/linux/packet_socket.cc +++ b/test/syscalls/linux/packet_socket.cc @@ -417,6 +417,122 @@ TEST_P(CookedPacketTest, BindDrop) { EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 1000), SyscallSucceedsWithValue(0)); } +// Verify that we receive outbound packets. This test requires at least one +// non loopback interface so that we can actually capture an outgoing packet. +TEST_P(CookedPacketTest, ReceiveOutbound) { + // Only ETH_P_ALL sockets can receive outbound packets on linux. + SKIP_IF(GetParam() != ETH_P_ALL); + + // Let's use a simple IP payload: a UDP datagram. + FileDescriptor udp_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + + struct ifaddrs* if_addr_list = nullptr; + auto cleanup = Cleanup([&if_addr_list]() { freeifaddrs(if_addr_list); }); + + ASSERT_THAT(getifaddrs(&if_addr_list), SyscallSucceeds()); + + // Get interface other than loopback. + struct ifreq ifr = {}; + for (struct ifaddrs* i = if_addr_list; i; i = i->ifa_next) { + if (strcmp(i->ifa_name, "lo") != 0) { + strncpy(ifr.ifr_name, i->ifa_name, sizeof(ifr.ifr_name)); + break; + } + } + + // Skip if no interface is available other than loopback. + if (strlen(ifr.ifr_name) == 0) { + GTEST_SKIP(); + } + + // Get interface index and name. + EXPECT_THAT(ioctl(socket_, SIOCGIFINDEX, &ifr), SyscallSucceeds()); + EXPECT_NE(ifr.ifr_ifindex, 0); + int ifindex = ifr.ifr_ifindex; + + constexpr int kMACSize = 6; + char hwaddr[kMACSize]; + // Get interface address. + ASSERT_THAT(ioctl(socket_, SIOCGIFHWADDR, &ifr), SyscallSucceeds()); + ASSERT_THAT(ifr.ifr_hwaddr.sa_family, + AnyOf(Eq(ARPHRD_NONE), Eq(ARPHRD_ETHER))); + memcpy(hwaddr, ifr.ifr_hwaddr.sa_data, kMACSize); + + // Just send it to the google dns server 8.8.8.8. It's UDP we don't care + // if it actually gets to the DNS Server we just want to see that we receive + // it on our AF_PACKET socket. + // + // NOTE: We just want to pick an IP that is non-local to avoid having to + // handle ARP as this should cause the UDP packet to be sent to the default + // gateway configured for the system under test. Otherwise the only packet we + // will see is the ARP query unless we picked an IP which will actually + // resolve. The test is a bit brittle but this was the best compromise for + // now. + struct sockaddr_in dest = {}; + ASSERT_EQ(inet_pton(AF_INET, "8.8.8.8", &dest.sin_addr.s_addr), 1); + dest.sin_family = AF_INET; + dest.sin_port = kPort; + EXPECT_THAT(sendto(udp_sock.get(), kMessage, sizeof(kMessage), 0, + reinterpret_cast(&dest), sizeof(dest)), + SyscallSucceedsWithValue(sizeof(kMessage))); + + // Wait and make sure the socket receives the data. + struct pollfd pfd = {}; + pfd.fd = socket_; + pfd.events = POLLIN; + EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 1000), SyscallSucceedsWithValue(1)); + + // Now read and check that the packet is the one we just sent. + // Read and verify the data. + constexpr size_t packet_size = + sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kMessage); + char buf[64]; + struct sockaddr_ll src = {}; + socklen_t src_len = sizeof(src); + ASSERT_THAT(recvfrom(socket_, buf, sizeof(buf), 0, + reinterpret_cast(&src), &src_len), + SyscallSucceedsWithValue(packet_size)); + + // sockaddr_ll ends with an 8 byte physical address field, but ethernet + // addresses only use 6 bytes. Linux used to return sizeof(sockaddr_ll)-2 + // here, but since commit b2cf86e1563e33a14a1c69b3e508d15dc12f804c returns + // sizeof(sockaddr_ll). + ASSERT_THAT(src_len, AnyOf(Eq(sizeof(src)), Eq(sizeof(src) - 2))); + + // Verify the source address. + EXPECT_EQ(src.sll_family, AF_PACKET); + EXPECT_EQ(src.sll_ifindex, ifindex); + EXPECT_EQ(src.sll_halen, ETH_ALEN); + EXPECT_EQ(ntohs(src.sll_protocol), ETH_P_IP); + EXPECT_EQ(src.sll_pkttype, PACKET_OUTGOING); + // Verify the link address of the interface matches that of the non + // non loopback interface address we stored above. + for (int i = 0; i < src.sll_halen; i++) { + EXPECT_EQ(src.sll_addr[i], hwaddr[i]); + } + + // Verify the IP header. + struct iphdr ip = {}; + memcpy(&ip, buf, sizeof(ip)); + EXPECT_EQ(ip.ihl, 5); + EXPECT_EQ(ip.version, 4); + EXPECT_EQ(ip.tot_len, htons(packet_size)); + EXPECT_EQ(ip.protocol, IPPROTO_UDP); + EXPECT_EQ(ip.daddr, dest.sin_addr.s_addr); + EXPECT_NE(ip.saddr, htonl(INADDR_LOOPBACK)); + + // Verify the UDP header. + struct udphdr udp = {}; + memcpy(&udp, buf + sizeof(iphdr), sizeof(udp)); + EXPECT_EQ(udp.dest, kPort); + EXPECT_EQ(udp.len, htons(sizeof(udphdr) + sizeof(kMessage))); + + // Verify the payload. + char* payload = reinterpret_cast(buf + sizeof(iphdr) + sizeof(udphdr)); + EXPECT_EQ(strncmp(payload, kMessage, sizeof(kMessage)), 0); +} + // Bind with invalid address. TEST_P(CookedPacketTest, BindFail) { // Null address. -- cgit v1.2.3 From bd98f820141208d9f19b0e12dee93f6f6de3ac97 Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Wed, 22 Jul 2020 16:22:06 -0700 Subject: iptables: replace maps with arrays For iptables users, Check() is a hot path called for every packet one or more times. Let's avoid a bunch of map lookups. PiperOrigin-RevId: 322678699 --- pkg/sentry/socket/netfilter/netfilter.go | 21 ++-- pkg/tcpip/stack/iptables.go | 187 ++++++++++++++++++------------- pkg/tcpip/stack/iptables_types.go | 22 ++-- 3 files changed, 127 insertions(+), 103 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index f7abe77d3..1243143ea 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -342,10 +342,10 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { // TODO(gvisor.dev/issue/170): Support other tables. var table stack.Table switch replace.Name.String() { - case stack.TablenameFilter: + case stack.FilterTable: table = stack.EmptyFilterTable() - case stack.TablenameNat: - table = stack.EmptyNatTable() + case stack.NATTable: + table = stack.EmptyNATTable() default: nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String()) return syserr.ErrInvalidArgument @@ -431,6 +431,8 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { for hook, _ := range replace.HookEntry { if table.ValidHooks()&(1< Date: Fri, 5 Jun 2020 11:22:44 -0700 Subject: iptables: don't NAT existing connections Fixes a NAT bug that manifested as: - A SYN was sent from gVisor to another host, unaffected by iptables. - The corresponding SYN/ACK was NATted by a PREROUTING REDIRECT rule despite being part of the existing connection. - The socket that sent the SYN never received the SYN/ACK and thus a connection could not be established. We handle this (as Linux does) by tracking all connections, inserting a no-op conntrack rule for new connections with no rules of their own. Needed for istio support (#170). --- pkg/tcpip/stack/conntrack.go | 136 +++++++++++++++++++++++++++++------- pkg/tcpip/stack/iptables.go | 21 +++++- pkg/tcpip/stack/iptables_targets.go | 2 +- test/iptables/iptables_test.go | 7 ++ test/iptables/nat.go | 52 ++++++++++++++ 5 files changed, 189 insertions(+), 29 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index d39baf620..559a1c4dd 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -49,7 +49,8 @@ const ( type manipType int const ( - manipDstPrerouting manipType = iota + manipNone manipType = iota + manipDstPrerouting manipDstOutput ) @@ -113,13 +114,11 @@ type conn struct { // update the state of tcb. It is immutable. tcbHook Hook - // mu protects tcb. + // mu protects all mutable state. mu sync.Mutex `state:"nosave"` - // tcb is TCB control block. It is used to keep track of states // of tcp connection and is protected by mu. tcb tcpconntrack.TCB - // lastUsed is the last time the connection saw a relevant packet, and // is updated by each packet on the connection. It is protected by mu. lastUsed time.Time `state:".(unixTime)"` @@ -141,8 +140,26 @@ func (cn *conn) timedOut(now time.Time) bool { return now.Sub(cn.lastUsed) > defaultTimeout } +// update the connection tracking state. +// +// Precondition: ct.mu must be held. +func (ct *conn) updateLocked(tcpHeader header.TCP, hook Hook) { + // Update the state of tcb. tcb assumes it's always initialized on the + // client. However, we only need to know whether the connection is + // established or not, so the client/server distinction isn't important. + // TODO(gvisor.dev/issue/170): Add support in tcpconntrack to handle + // other tcp states. + if ct.tcb.IsEmpty() { + ct.tcb.Init(tcpHeader) + } else if hook == ct.tcbHook { + ct.tcb.UpdateStateOutbound(tcpHeader) + } else { + ct.tcb.UpdateStateInbound(tcpHeader) + } +} + // ConnTrack tracks all connections created for NAT rules. Most users are -// expected to only call handlePacket and createConnFor. +// expected to only call handlePacket, insertRedirectConn, and maybeInsertNoop. // // ConnTrack keeps all connections in a slice of buckets, each of which holds a // linked list of tuples. This gives us some desirable properties: @@ -248,8 +265,7 @@ func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) { return nil, dirOriginal } -// createConnFor creates a new conn for pkt. -func (ct *ConnTrack) createConnFor(pkt *PacketBuffer, hook Hook, rt RedirectTarget) *conn { +func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt RedirectTarget) *conn { tid, err := packetToTupleID(pkt) if err != nil { return nil @@ -272,10 +288,15 @@ func (ct *ConnTrack) createConnFor(pkt *PacketBuffer, hook Hook, rt RedirectTarg manip = manipDstOutput } conn := newConn(tid, replyTID, manip, hook) + ct.insertConn(conn) + return conn +} +// insertConn inserts conn into the appropriate table bucket. +func (ct *ConnTrack) insertConn(conn *conn) { // Lock the buckets in the correct order. - tupleBucket := ct.bucket(tid) - replyBucket := ct.bucket(replyTID) + tupleBucket := ct.bucket(conn.original.tupleID) + replyBucket := ct.bucket(conn.reply.tupleID) ct.mu.RLock() defer ct.mu.RUnlock() if tupleBucket < replyBucket { @@ -289,22 +310,37 @@ func (ct *ConnTrack) createConnFor(pkt *PacketBuffer, hook Hook, rt RedirectTarg ct.buckets[tupleBucket].mu.Lock() } - // Add the tuple to the map. - ct.buckets[tupleBucket].tuples.PushFront(&conn.original) - ct.buckets[replyBucket].tuples.PushFront(&conn.reply) + // Now that we hold the locks, ensure the tuple hasn't been inserted by + // another thread. + alreadyInserted := false + for other := ct.buckets[tupleBucket].tuples.Front(); other != nil; other = other.Next() { + if other.tupleID == conn.original.tupleID { + alreadyInserted = true + break + } + } + + if !alreadyInserted { + // Add the tuple to the map. + ct.buckets[tupleBucket].tuples.PushFront(&conn.original) + ct.buckets[replyBucket].tuples.PushFront(&conn.reply) + } // Unlocking can happen in any order. ct.buckets[tupleBucket].mu.Unlock() if tupleBucket != replyBucket { ct.buckets[replyBucket].mu.Unlock() } - - return conn } // handlePacketPrerouting manipulates ports for packets in Prerouting hook. // TODO(gvisor.dev/issue/170): Change address for Prerouting hook. func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) { + // If this is a noop entry, don't do anything. + if conn.manip == manipNone { + return + } + netHeader := header.IPv4(pkt.NetworkHeader) tcpHeader := header.TCP(pkt.TransportHeader) @@ -322,12 +358,22 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) { netHeader.SetSourceAddress(conn.original.dstAddr) } + // TODO(gvisor.dev/issue/170): TCP checksums aren't usually validated + // on inbound packets, so we don't recalculate them. However, we should + // support cases when they are validated, e.g. when we can't offload + // receive checksumming. + netHeader.SetChecksum(0) netHeader.SetChecksum(^netHeader.CalculateChecksum()) } // handlePacketOutput manipulates ports for packets in Output hook. func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir direction) { + // If this is a noop entry, don't do anything. + if conn.manip == manipNone { + return + } + netHeader := header.IPv4(pkt.NetworkHeader) tcpHeader := header.TCP(pkt.TransportHeader) @@ -362,20 +408,31 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d } // handlePacket will manipulate the port and address of the packet if the -// connection exists. -func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) { +// connection exists. Returns whether, after the packet traverses the tables, +// it should create a new entry in the table. +func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) bool { if pkt.NatDone { - return + return false } if hook != Prerouting && hook != Output { - return + return false + } + + // TODO(gvisor.dev/issue/170): Support other transport protocols. + if pkt.NetworkHeader == nil || header.IPv4(pkt.NetworkHeader).TransportProtocol() != header.TCPProtocolNumber { + return false } conn, dir := ct.connFor(pkt) + // Connection or Rule not found for the packet. if conn == nil { - // Connection not found for the packet or the packet is invalid. - return + return true + } + + tcpHeader := header.TCP(pkt.TransportHeader) + if tcpHeader == nil { + return false } switch hook { @@ -395,14 +452,39 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou // Mark the connection as having been used recently so it isn't reaped. conn.lastUsed = time.Now() // Update connection state. - if tcpHeader := header.TCP(pkt.TransportHeader); conn.tcb.IsEmpty() { - conn.tcb.Init(tcpHeader) - conn.tcbHook = hook - } else if hook == conn.tcbHook { - conn.tcb.UpdateStateOutbound(tcpHeader) - } else { - conn.tcb.UpdateStateInbound(tcpHeader) + conn.updateLocked(header.TCP(pkt.TransportHeader), hook) + + return false +} + +// maybeInsertNoop tries to insert a no-op connection entry to keep connections +// from getting clobbered when replies arrive. It only inserts if there isn't +// already a connection for pkt. +// +// This should be called after traversing iptables rules only, to ensure that +// pkt.NatDone is set correctly. +func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) { + // If there were a rule applying to this packet, it would be marked + // with NatDone. + if pkt.NatDone { + return + } + + // We only track TCP connections. + if pkt.NetworkHeader == nil || header.IPv4(pkt.NetworkHeader).TransportProtocol() != header.TCPProtocolNumber { + return + } + + // This is the first packet we're seeing for the TCP connection. Insert + // the noop entry (an identity mapping) so that the response doesn't + // get NATed, breaking the connection. + tid, err := packetToTupleID(pkt) + if err != nil { + return } + conn := newConn(tid, tid.reply(), manipNone, hook) + conn.updateLocked(header.TCP(pkt.TransportHeader), hook) + ct.insertConn(conn) } // bucket gets the conntrack bucket for a tupleID. diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 5f647c5fe..ca1dda695 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -245,13 +245,18 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr // Packets are manipulated only if connection and matching // NAT rule exists. - it.connections.handlePacket(pkt, hook, gso, r) + shouldTrack := it.connections.handlePacket(pkt, hook, gso, r) // Go through each table containing the hook. it.mu.RLock() defer it.mu.RUnlock() priorities := it.priorities[hook] for _, tableID := range priorities { + // If handlePacket already NATed the packet, we don't need to + // check the NAT table. + if tableID == natID && pkt.NatDone { + continue + } table := it.tables[tableID] ruleIdx := table.BuiltinChains[hook] switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict { @@ -281,6 +286,20 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr } } + // If this connection should be tracked, try to add an entry for it. If + // traversing the nat table didn't end in adding an entry, + // maybeInsertNoop will add a no-op entry for the connection. This is + // needeed when establishing connections so that the SYN/ACK reply to an + // outgoing SYN is delivered to the correct endpoint rather than being + // redirected by a prerouting rule. + // + // From the iptables documentation: "If there is no rule, a `null' + // binding is created: this usually does not map the packet, but exists + // to ensure we don't map another stream over an existing one." + if shouldTrack { + it.connections.maybeInsertNoop(pkt, hook) + } + // Every table returned Accept. return true } diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index d43f60c67..dc88033c7 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -153,7 +153,7 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso // Set up conection for matching NAT rule. Only the first // packet of the connection comes here. Other packets will be // manipulated in connection tracking. - if conn := ct.createConnFor(pkt, hook, rt); conn != nil { + if conn := ct.insertRedirectConn(pkt, hook, rt); conn != nil { ct.handlePacket(pkt, hook, gso, r) } default: diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go index f5ac79370..f303030aa 100644 --- a/test/iptables/iptables_test.go +++ b/test/iptables/iptables_test.go @@ -263,6 +263,13 @@ func TestNATPreRedirectTCPPort(t *testing.T) { singleTest(t, NATPreRedirectTCPPort{}) } +func TestNATPreRedirectTCPOutgoing(t *testing.T) { + singleTest(t, NATPreRedirectTCPOutgoing{}) +} + +func TestNATOutRedirectTCPIncoming(t *testing.T) { + singleTest(t, NATOutRedirectTCPIncoming{}) +} func TestNATOutRedirectUDPPort(t *testing.T) { singleTest(t, NATOutRedirectUDPPort{}) } diff --git a/test/iptables/nat.go b/test/iptables/nat.go index 8562b0820..149dec2bb 100644 --- a/test/iptables/nat.go +++ b/test/iptables/nat.go @@ -28,6 +28,8 @@ const ( func init() { RegisterTestCase(NATPreRedirectUDPPort{}) RegisterTestCase(NATPreRedirectTCPPort{}) + RegisterTestCase(NATPreRedirectTCPOutgoing{}) + RegisterTestCase(NATOutRedirectTCPIncoming{}) RegisterTestCase(NATOutRedirectUDPPort{}) RegisterTestCase(NATOutRedirectTCPPort{}) RegisterTestCase(NATDropUDP{}) @@ -91,6 +93,56 @@ func (NATPreRedirectTCPPort) LocalAction(ip net.IP) error { return connectTCP(ip, dropPort, sendloopDuration) } +// NATPreRedirectTCPOutgoing verifies that outgoing TCP connections aren't +// affected by PREROUTING connection tracking. +type NATPreRedirectTCPOutgoing struct{} + +// Name implements TestCase.Name. +func (NATPreRedirectTCPOutgoing) Name() string { + return "NATPreRedirectTCPOutgoing" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATPreRedirectTCPOutgoing) ContainerAction(ip net.IP) error { + // Redirect all incoming TCP traffic to a closed port. + if err := natTable("-A", "PREROUTING", "-p", "tcp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", dropPort)); err != nil { + return err + } + + // Establish a connection to the host process. + return connectTCP(ip, acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (NATPreRedirectTCPOutgoing) LocalAction(ip net.IP) error { + return listenTCP(acceptPort, sendloopDuration) +} + +// NATOutRedirectTCPIncoming verifies that incoming TCP connections aren't +// affected by OUTPUT connection tracking. +type NATOutRedirectTCPIncoming struct{} + +// Name implements TestCase.Name. +func (NATOutRedirectTCPIncoming) Name() string { + return "NATOutRedirectTCPIncoming" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATOutRedirectTCPIncoming) ContainerAction(ip net.IP) error { + // Redirect all outgoing TCP traffic to a closed port. + if err := natTable("-A", "OUTPUT", "-p", "tcp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", dropPort)); err != nil { + return err + } + + // Establish a connection to the host process. + return listenTCP(acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (NATOutRedirectTCPIncoming) LocalAction(ip net.IP) error { + return connectTCP(ip, acceptPort, sendloopDuration) +} + // NATOutRedirectUDPPort tests that packets are redirected to different port. type NATOutRedirectUDPPort struct{} -- cgit v1.2.3 From fb8be7e6273f5a646cdf48e38743a2507a4bf64f Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Thu, 9 Jul 2020 22:37:11 -0700 Subject: make connect(2) fail when dest is unreachable Previously, ICMP destination unreachable datagrams were ignored by TCP endpoints. This caused connect to hang when an intermediate router couldn't find a route to the host. This manifested as a Kokoro error when Docker IPv6 was enabled. The Ruby image test would try to install the sinatra gem and hang indefinitely attempting to use an IPv6 address. Fixes #3079. --- pkg/tcpip/header/icmpv4.go | 1 + pkg/tcpip/header/icmpv6.go | 11 +- pkg/tcpip/network/ipv4/icmp.go | 3 + pkg/tcpip/network/ipv6/icmp.go | 2 + pkg/tcpip/stack/registration.go | 5 +- pkg/tcpip/transport/tcp/connect.go | 6 + pkg/tcpip/transport/tcp/endpoint.go | 26 +++- pkg/test/dockerutil/exec.go | 1 - test/packetimpact/runner/packetimpact_test.go | 8 +- test/packetimpact/testbench/connections.go | 58 ++++++++- test/packetimpact/testbench/layers.go | 6 +- test/packetimpact/tests/BUILD | 10 ++ .../tests/tcp_network_unreachable_test.go | 139 +++++++++++++++++++++ 13 files changed, 262 insertions(+), 14 deletions(-) create mode 100644 test/packetimpact/tests/tcp_network_unreachable_test.go (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go index 7908c5744..1a631b31a 100644 --- a/pkg/tcpip/header/icmpv4.go +++ b/pkg/tcpip/header/icmpv4.go @@ -72,6 +72,7 @@ const ( // Values for ICMP code as defined in RFC 792. const ( ICMPv4TTLExceeded = 0 + ICMPv4HostUnreachable = 1 ICMPv4PortUnreachable = 3 ICMPv4FragmentationNeeded = 4 ) diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go index c7ee2de57..a13b4b809 100644 --- a/pkg/tcpip/header/icmpv6.go +++ b/pkg/tcpip/header/icmpv6.go @@ -110,9 +110,16 @@ const ( ICMPv6RedirectMsg ICMPv6Type = 137 ) -// Values for ICMP code as defined in RFC 4443. +// Values for ICMP destination unreachable code as defined in RFC 4443 section +// 3.1. const ( - ICMPv6PortUnreachable = 4 + ICMPv6NetworkUnreachable = 0 + ICMPv6Prohibited = 1 + ICMPv6BeyondScope = 2 + ICMPv6AddressUnreachable = 3 + ICMPv6PortUnreachable = 4 + ICMPv6Policy = 5 + ICMPv6RejectRoute = 6 ) // Type is the ICMP type field. diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 1b67aa066..83e71cb8c 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -129,6 +129,9 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { pkt.Data.TrimFront(header.ICMPv4MinimumSize) switch h.Code() { + case header.ICMPv4HostUnreachable: + e.handleControl(stack.ControlNoRoute, 0, pkt) + case header.ICMPv4PortUnreachable: e.handleControl(stack.ControlPortUnreachable, 0, pkt) diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 3b79749b5..ff1cb53dd 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -128,6 +128,8 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme } pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize) switch header.ICMPv6(hdr).Code() { + case header.ICMPv6NetworkUnreachable: + e.handleControl(stack.ControlNetworkUnreachable, 0, pkt) case header.ICMPv6PortUnreachable: e.handleControl(stack.ControlPortUnreachable, 0, pkt) } diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index cd4b7a449..9e1b2d25f 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -52,8 +52,11 @@ type TransportEndpointID struct { type ControlType int // The following are the allowed values for ControlType values. +// TODO(http://gvisor.dev/issue/3210): Support time exceeded messages. const ( - ControlPacketTooBig ControlType = iota + ControlNetworkUnreachable ControlType = iota + ControlNoRoute + ControlPacketTooBig ControlPortUnreachable ControlUnknown ) diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 81b740115..1798510bc 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -490,6 +490,9 @@ func (h *handshake) resolveRoute() *tcpip.Error { <-h.ep.undrain h.ep.mu.Lock() } + if n¬ifyError != 0 { + return h.ep.takeLastError() + } } // Wait for notification. @@ -616,6 +619,9 @@ func (h *handshake) execute() *tcpip.Error { <-h.ep.undrain h.ep.mu.Lock() } + if n¬ifyError != 0 { + return h.ep.takeLastError() + } case wakerForNewSegment: if err := h.processSegments(); err != nil { diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 83dc10ed0..0f7487963 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1209,6 +1209,14 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } +func (e *endpoint) takeLastError() *tcpip.Error { + e.lastErrorMu.Lock() + defer e.lastErrorMu.Unlock() + err := e.lastError + e.lastError = nil + return err +} + // Read reads data from the endpoint. func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { e.LockUser() @@ -1956,11 +1964,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { switch o := opt.(type) { case tcpip.ErrorOption: - e.lastErrorMu.Lock() - err := e.lastError - e.lastError = nil - e.lastErrorMu.Unlock() - return err + return e.takeLastError() case *tcpip.BindToDeviceOption: e.LockUser() @@ -2546,6 +2550,18 @@ func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.C e.sndBufMu.Unlock() e.notifyProtocolGoroutine(notifyMTUChanged) + + case stack.ControlNoRoute: + e.lastErrorMu.Lock() + e.lastError = tcpip.ErrNoRoute + e.lastErrorMu.Unlock() + e.notifyProtocolGoroutine(notifyError) + + case stack.ControlNetworkUnreachable: + e.lastErrorMu.Lock() + e.lastError = tcpip.ErrNetworkUnreachable + e.lastErrorMu.Unlock() + e.notifyProtocolGoroutine(notifyError) } } diff --git a/pkg/test/dockerutil/exec.go b/pkg/test/dockerutil/exec.go index 921d1da9e..4c739c9e9 100644 --- a/pkg/test/dockerutil/exec.go +++ b/pkg/test/dockerutil/exec.go @@ -87,7 +87,6 @@ func (c *Container) doExec(ctx context.Context, r ExecOpts, args []string) (Proc execid: resp.ID, conn: hijack, }, nil - } func (c *Container) execConfig(r ExecOpts, cmd []string) types.ExecConfig { diff --git a/test/packetimpact/runner/packetimpact_test.go b/test/packetimpact/runner/packetimpact_test.go index ff5f5c7f1..1a0221893 100644 --- a/test/packetimpact/runner/packetimpact_test.go +++ b/test/packetimpact/runner/packetimpact_test.go @@ -280,11 +280,13 @@ func TestOne(t *testing.T) { } // Because the Linux kernel receives the SYN-ACK but didn't send the SYN it - // will issue a RST. To prevent this IPtables can be used to filter out all + // will issue an RST. To prevent this IPtables can be used to filter out all // incoming packets. The raw socket that packetimpact tests use will still see // everything. - if logs, err := testbench.Exec(ctx, dockerutil.ExecOpts{}, "iptables", "-A", "INPUT", "-i", testNetDev, "-j", "DROP"); err != nil { - t.Fatalf("unable to Exec iptables on container %s: %s, logs from testbench:\n%s", testbench.Name, err, logs) + for _, bin := range []string{"iptables", "ip6tables"} { + if logs, err := testbench.Exec(ctx, dockerutil.ExecOpts{}, bin, "-A", "INPUT", "-i", testNetDev, "-p", "tcp", "-j", "DROP"); err != nil { + t.Fatalf("unable to Exec %s on container %s: %s, logs from testbench:\n%s", bin, testbench.Name, err, logs) + } } // FIXME(b/156449515): Some piece of the system has a race. The old diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go index 5d9cec73e..87ce58c24 100644 --- a/test/packetimpact/testbench/connections.go +++ b/test/packetimpact/testbench/connections.go @@ -41,7 +41,8 @@ func portFromSockaddr(sa unix.Sockaddr) (uint16, error) { return 0, fmt.Errorf("sockaddr type %T does not contain port", sa) } -// pickPort makes a new socket and returns the socket FD and port. The domain should be AF_INET or AF_INET6. The caller must close the FD when done with +// pickPort makes a new socket and returns the socket FD and port. The domain +// should be AF_INET or AF_INET6. The caller must close the FD when done with // the port if there is no error. func pickPort(domain, typ int) (fd int, port uint16, err error) { fd, err = unix.Socket(domain, typ, 0) @@ -1061,3 +1062,58 @@ func (conn *UDPIPv6) Close() { func (conn *UDPIPv6) Drain() { conn.sniffer.Drain() } + +// TCPIPv6 maintains the state for all the layers in a TCP/IPv6 connection. +type TCPIPv6 Connection + +// NewTCPIPv6 creates a new TCPIPv6 connection with reasonable defaults. +func NewTCPIPv6(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv6 { + etherState, err := newEtherState(Ether{}, Ether{}) + if err != nil { + t.Fatalf("can't make etherState: %s", err) + } + ipv6State, err := newIPv6State(IPv6{}, IPv6{}) + if err != nil { + t.Fatalf("can't make ipv6State: %s", err) + } + tcpState, err := newTCPState(unix.AF_INET6, outgoingTCP, incomingTCP) + if err != nil { + t.Fatalf("can't make tcpState: %s", err) + } + injector, err := NewInjector(t) + if err != nil { + t.Fatalf("can't make injector: %s", err) + } + sniffer, err := NewSniffer(t) + if err != nil { + t.Fatalf("can't make sniffer: %s", err) + } + + return TCPIPv6{ + layerStates: []layerState{etherState, ipv6State, tcpState}, + injector: injector, + sniffer: sniffer, + t: t, + } +} + +func (conn *TCPIPv6) SrcPort() uint16 { + state := conn.layerStates[2].(*tcpState) + return *state.out.SrcPort +} + +// ExpectData is a convenient method that expects a Layer and the Layer after +// it. If it doens't arrive in time, it returns nil. +func (conn *TCPIPv6) ExpectData(tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) { + expected := make([]Layer, len(conn.layerStates)) + expected[len(expected)-1] = tcp + if payload != nil { + expected = append(expected, payload) + } + return (*Connection)(conn).ExpectFrame(expected, timeout) +} + +// Close frees associated resources held by the TCPIPv6 connection. +func (conn *TCPIPv6) Close() { + (*Connection)(conn).Close() +} diff --git a/test/packetimpact/testbench/layers.go b/test/packetimpact/testbench/layers.go index 645f6c1a9..24aa46cce 100644 --- a/test/packetimpact/testbench/layers.go +++ b/test/packetimpact/testbench/layers.go @@ -805,7 +805,11 @@ func (l *ICMPv6) ToBytes() ([]byte, error) { // We need to search forward to find the IPv6 header. for prev := l.Prev(); prev != nil; prev = prev.Prev() { if ipv6, ok := prev.(*IPv6); ok { - h.SetChecksum(header.ICMPv6Checksum(h, *ipv6.SrcAddr, *ipv6.DstAddr, buffer.VectorisedView{})) + payload, err := payload(l) + if err != nil { + return nil, err + } + h.SetChecksum(header.ICMPv6Checksum(h, *ipv6.SrcAddr, *ipv6.DstAddr, payload)) break } } diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD index 6a07889be..27905dcff 100644 --- a/test/packetimpact/tests/BUILD +++ b/test/packetimpact/tests/BUILD @@ -219,6 +219,16 @@ packetimpact_go_test( ], ) +packetimpact_go_test( + name = "tcp_network_unreachable", + srcs = ["tcp_network_unreachable_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + packetimpact_go_test( name = "tcp_cork_mss", srcs = ["tcp_cork_mss_test.go"], diff --git a/test/packetimpact/tests/tcp_network_unreachable_test.go b/test/packetimpact/tests/tcp_network_unreachable_test.go new file mode 100644 index 000000000..868a08da8 --- /dev/null +++ b/test/packetimpact/tests/tcp_network_unreachable_test.go @@ -0,0 +1,139 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_synsent_reset_test + +import ( + "context" + "flag" + "net" + "syscall" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +// TestTCPSynSentUnreachable verifies that TCP connections fail immediately when +// an ICMP destination unreachable message is sent in response to the inital +// SYN. +func TestTCPSynSentUnreachable(t *testing.T) { + // Create the DUT and connection. + dut := testbench.NewDUT(t) + defer dut.TearDown() + clientFD, clientPort := dut.CreateBoundSocket(unix.SOCK_STREAM|unix.SOCK_NONBLOCK, unix.IPPROTO_TCP, net.ParseIP(testbench.RemoteIPv4)) + port := uint16(9001) + conn := testbench.NewTCPIPv4(t, testbench.TCP{SrcPort: &port, DstPort: &clientPort}, testbench.TCP{SrcPort: &clientPort, DstPort: &port}) + defer conn.Close() + + // Bring the DUT to SYN-SENT state with a non-blocking connect. + ctx, cancel := context.WithTimeout(context.Background(), testbench.RPCTimeout) + defer cancel() + sa := unix.SockaddrInet4{Port: int(port)} + copy(sa.Addr[:], net.IP(net.ParseIP(testbench.LocalIPv4)).To4()) + if _, err := dut.ConnectWithErrno(ctx, clientFD, &sa); err != syscall.Errno(unix.EINPROGRESS) { + t.Errorf("expected connect to fail with EINPROGRESS, but got %v", err) + } + + // Get the SYN. + tcpLayers, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, nil, time.Second) + if err != nil { + t.Fatalf("expected SYN: %s", err) + } + + // Send a host unreachable message. + rawConn := (*testbench.Connection)(&conn) + layers := rawConn.CreateFrame(nil) + layers = layers[:len(layers)-1] + const ipLayer = 1 + const tcpLayer = ipLayer + 1 + ip, ok := tcpLayers[ipLayer].(*testbench.IPv4) + if !ok { + t.Fatalf("expected %s to be IPv4", tcpLayers[ipLayer]) + } + tcp, ok := tcpLayers[tcpLayer].(*testbench.TCP) + if !ok { + t.Fatalf("expected %s to be TCP", tcpLayers[tcpLayer]) + } + var icmpv4 testbench.ICMPv4 = testbench.ICMPv4{Type: testbench.ICMPv4Type(header.ICMPv4DstUnreachable), Code: testbench.Uint8(header.ICMPv4HostUnreachable)} + layers = append(layers, &icmpv4, ip, tcp) + rawConn.SendFrameStateless(layers) + + if _, err = dut.ConnectWithErrno(ctx, clientFD, &sa); err != syscall.Errno(unix.EHOSTUNREACH) { + t.Errorf("expected connect to fail with EHOSTUNREACH, but got %v", err) + } +} + +// TestTCPSynSentUnreachable6 verifies that TCP connections fail immediately when +// an ICMP destination unreachable message is sent in response to the inital +// SYN. +func TestTCPSynSentUnreachable6(t *testing.T) { + // Create the DUT and connection. + dut := testbench.NewDUT(t) + defer dut.TearDown() + clientFD, clientPort := dut.CreateBoundSocket(unix.SOCK_STREAM|unix.SOCK_NONBLOCK, unix.IPPROTO_TCP, net.ParseIP(testbench.RemoteIPv6)) + conn := testbench.NewTCPIPv6(t, testbench.TCP{DstPort: &clientPort}, testbench.TCP{SrcPort: &clientPort}) + defer conn.Close() + + // Bring the DUT to SYN-SENT state with a non-blocking connect. + ctx, cancel := context.WithTimeout(context.Background(), testbench.RPCTimeout) + defer cancel() + sa := unix.SockaddrInet6{ + Port: int(conn.SrcPort()), + ZoneId: uint32(testbench.RemoteInterfaceID), + } + copy(sa.Addr[:], net.IP(net.ParseIP(testbench.LocalIPv6)).To16()) + if _, err := dut.ConnectWithErrno(ctx, clientFD, &sa); err != syscall.Errno(unix.EINPROGRESS) { + t.Errorf("expected connect to fail with EINPROGRESS, but got %v", err) + } + + // Get the SYN. + tcpLayers, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, nil, time.Second) + if err != nil { + t.Fatalf("expected SYN: %s", err) + } + + // Send a host unreachable message. + rawConn := (*testbench.Connection)(&conn) + layers := rawConn.CreateFrame(nil) + layers = layers[:len(layers)-1] + const ipLayer = 1 + const tcpLayer = ipLayer + 1 + ip, ok := tcpLayers[ipLayer].(*testbench.IPv6) + if !ok { + t.Fatalf("expected %s to be IPv6", tcpLayers[ipLayer]) + } + tcp, ok := tcpLayers[tcpLayer].(*testbench.TCP) + if !ok { + t.Fatalf("expected %s to be TCP", tcpLayers[tcpLayer]) + } + var icmpv6 testbench.ICMPv6 = testbench.ICMPv6{ + Type: testbench.ICMPv6Type(header.ICMPv6DstUnreachable), + Code: testbench.Uint8(header.ICMPv6NetworkUnreachable), + // Per RFC 4443 3.1, the payload contains 4 zeroed bytes. + Payload: []byte{0, 0, 0, 0}, + } + layers = append(layers, &icmpv6, ip, tcp) + rawConn.SendFrameStateless(layers) + + if _, err = dut.ConnectWithErrno(ctx, clientFD, &sa); err != syscall.Errno(unix.ENETUNREACH) { + t.Errorf("expected connect to fail with ENETUNREACH, but got %v", err) + } +} -- cgit v1.2.3 From dd530eeeff09128d4c2428e1d6f24205a29e661e Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Thu, 23 Jul 2020 15:44:09 -0700 Subject: iptables: use keyed array literals PiperOrigin-RevId: 322882426 --- pkg/tcpip/stack/iptables.go | 93 ++++++++++++++++++--------------------------- 1 file changed, 37 insertions(+), 56 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index ca1dda695..cbbae4224 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -58,8 +58,7 @@ const reaperDelay = 5 * time.Second func DefaultTables() *IPTables { return &IPTables{ tables: [numTables]Table{ - // NAT table. - Table{ + natID: Table{ Rules: []Rule{ Rule{Target: AcceptTarget{}}, Rule{Target: AcceptTarget{}}, @@ -68,22 +67,21 @@ func DefaultTables() *IPTables { Rule{Target: ErrorTarget{}}, }, BuiltinChains: [NumHooks]int{ - 0, // Prerouting. - 1, // Input. - HookUnset, // Forward. - 2, // Output. - 3, // Postrouting. + Prerouting: 0, + Input: 1, + Forward: HookUnset, + Output: 2, + Postrouting: 3, }, Underflows: [NumHooks]int{ - 0, // Prerouting. - 1, // Input. - HookUnset, // Forward. - 2, // Output. - 3, // Postrouting. + Prerouting: 0, + Input: 1, + Forward: HookUnset, + Output: 2, + Postrouting: 3, }, }, - // Mangle table. - Table{ + mangleID: Table{ Rules: []Rule{ Rule{Target: AcceptTarget{}}, Rule{Target: AcceptTarget{}}, @@ -94,15 +92,14 @@ func DefaultTables() *IPTables { Output: 1, }, Underflows: [NumHooks]int{ - 0, // Prerouting. - HookUnset, // Input. - HookUnset, // Forward. - 1, // Output. - HookUnset, // Postrouting. + Prerouting: 0, + Input: HookUnset, + Forward: HookUnset, + Output: 1, + Postrouting: HookUnset, }, }, - // Filter table. - Table{ + filterID: Table{ Rules: []Rule{ Rule{Target: AcceptTarget{}}, Rule{Target: AcceptTarget{}}, @@ -110,27 +107,25 @@ func DefaultTables() *IPTables { Rule{Target: ErrorTarget{}}, }, BuiltinChains: [NumHooks]int{ - HookUnset, // Prerouting. - Input: 0, // Input. - Forward: 1, // Forward. - Output: 2, // Output. - HookUnset, // Postrouting. + Prerouting: HookUnset, + Input: 0, + Forward: 1, + Output: 2, + Postrouting: HookUnset, }, Underflows: [NumHooks]int{ - HookUnset, // Prerouting. - 0, // Input. - 1, // Forward. - 2, // Output. - HookUnset, // Postrouting. + Prerouting: HookUnset, + Input: 0, + Forward: 1, + Output: 2, + Postrouting: HookUnset, }, }, }, priorities: [NumHooks][]tableID{ - []tableID{mangleID, natID}, // Prerouting. - []tableID{natID, filterID}, // Input. - []tableID{}, // Forward. - []tableID{mangleID, natID, filterID}, // Output. - []tableID{}, // Postrouting. + Prerouting: []tableID{mangleID, natID}, + Input: []tableID{natID, filterID}, + Output: []tableID{mangleID, natID, filterID}, }, connections: ConnTrack{ seed: generateRandUint32(), @@ -145,18 +140,12 @@ func EmptyFilterTable() Table { return Table{ Rules: []Rule{}, BuiltinChains: [NumHooks]int{ - HookUnset, - 0, - 0, - 0, - HookUnset, + Prerouting: HookUnset, + Postrouting: HookUnset, }, Underflows: [NumHooks]int{ - HookUnset, - 0, - 0, - 0, - HookUnset, + Prerouting: HookUnset, + Postrouting: HookUnset, }, } } @@ -167,18 +156,10 @@ func EmptyNATTable() Table { return Table{ Rules: []Rule{}, BuiltinChains: [NumHooks]int{ - 0, - 0, - HookUnset, - 0, - 0, + Forward: HookUnset, }, Underflows: [NumHooks]int{ - 0, - 0, - HookUnset, - 0, - 0, + Forward: HookUnset, }, } } -- cgit v1.2.3 From 82a5cada5944390e738a8b7235fb861965ca40f7 Mon Sep 17 00:00:00 2001 From: Sam Balana Date: Thu, 23 Jul 2020 17:59:12 -0700 Subject: Add AfterFunc to tcpip.Clock Changes the API of tcpip.Clock to also provide a method for scheduling and rescheduling work after a specified duration. This change also implements the AfterFunc method for existing implementations of tcpip.Clock. This is the groundwork required to mock time within tests. All references to CancellableTimer has been replaced with the tcpip.Job interface, allowing for custom implementations of scheduling work. This is a BREAKING CHANGE for clients that implement their own tcpip.Clock or use tcpip.CancellableTimer. Migration plan: 1. Add AfterFunc(d, f) to tcpip.Clock 2. Replace references of tcpip.CancellableTimer with tcpip.Job 3. Replace calls to tcpip.CancellableTimer#StopLocked with tcpip.Job#Cancel 4. Replace calls to tcpip.CancellableTimer#Reset with tcpip.Job#Schedule 5. Replace calls to tcpip.NewCancellableTimer with tcpip.NewJob. PiperOrigin-RevId: 322906897 --- pkg/sentry/kernel/kernel.go | 5 ++ pkg/sentry/kernel/time/BUILD | 1 + pkg/sentry/kernel/time/tcpip.go | 131 +++++++++++++++++++++++++++++ pkg/tcpip/stack/ndp.go | 140 +++++++++++++++---------------- pkg/tcpip/stack/ndp_test.go | 28 +++---- pkg/tcpip/stack/stack.go | 12 ++- pkg/tcpip/tcpip.go | 27 +++++- pkg/tcpip/time_unsafe.go | 30 ++++++- pkg/tcpip/timer.go | 147 ++++++++++++++++++++------------- pkg/tcpip/timer_test.go | 91 ++++++++++---------- pkg/tcpip/transport/icmp/endpoint.go | 2 +- pkg/tcpip/transport/packet/endpoint.go | 2 +- pkg/tcpip/transport/raw/endpoint.go | 2 +- pkg/tcpip/transport/udp/endpoint.go | 2 +- 14 files changed, 429 insertions(+), 191 deletions(-) create mode 100644 pkg/sentry/kernel/time/tcpip.go (limited to 'pkg/tcpip/stack') diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 240cd6fe0..15dae0f5b 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -1469,6 +1469,11 @@ func (k *Kernel) NowMonotonic() int64 { return now } +// AfterFunc implements tcpip.Clock.AfterFunc. +func (k *Kernel) AfterFunc(d time.Duration, f func()) tcpip.Timer { + return ktime.TcpipAfterFunc(k.realtimeClock, d, f) +} + // SetMemoryFile sets Kernel.mf. SetMemoryFile must be called before Init or // LoadFrom. func (k *Kernel) SetMemoryFile(mf *pgalloc.MemoryFile) { diff --git a/pkg/sentry/kernel/time/BUILD b/pkg/sentry/kernel/time/BUILD index 7ba7dc50c..2817aa3ba 100644 --- a/pkg/sentry/kernel/time/BUILD +++ b/pkg/sentry/kernel/time/BUILD @@ -6,6 +6,7 @@ go_library( name = "time", srcs = [ "context.go", + "tcpip.go", "time.go", ], visibility = ["//pkg/sentry:internal"], diff --git a/pkg/sentry/kernel/time/tcpip.go b/pkg/sentry/kernel/time/tcpip.go new file mode 100644 index 000000000..c4474c0cf --- /dev/null +++ b/pkg/sentry/kernel/time/tcpip.go @@ -0,0 +1,131 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package time + +import ( + "sync" + "time" +) + +// TcpipAfterFunc waits for duration to elapse according to clock then runs fn. +// The timer is started immediately and will fire exactly once. +func TcpipAfterFunc(clock Clock, duration time.Duration, fn func()) *TcpipTimer { + timer := &TcpipTimer{ + clock: clock, + } + timer.notifier = functionNotifier{ + fn: func() { + // tcpip.Timer.Stop() explicitly states that the function is called in a + // separate goroutine that Stop() does not synchronize with. + // Timer.Destroy() synchronizes with calls to TimerListener.Notify(). + // This is semantically meaningful because, in the former case, it's + // legal to call tcpip.Timer.Stop() while holding locks that may also be + // taken by the function, but this isn't so in the latter case. Most + // immediately, Timer calls TimerListener.Notify() while holding + // Timer.mu. A deadlock occurs without spawning a goroutine: + // T1: (Timer expires) + // => Timer.Tick() <- Timer.mu.Lock() called + // => TimerListener.Notify() + // => Timer.Stop() + // => Timer.Destroy() <- Timer.mu.Lock() called, deadlock! + // + // Spawning a goroutine avoids the deadlock: + // T1: (Timer expires) + // => Timer.Tick() <- Timer.mu.Lock() called + // => TimerListener.Notify() <- Launches T2 + // T2: + // => Timer.Stop() + // => Timer.Destroy() <- Timer.mu.Lock() called, blocks + // T1: + // => (returns) <- Timer.mu.Unlock() called + // T2: + // => (continues) <- No deadlock! + go func() { + timer.Stop() + fn() + }() + }, + } + timer.Reset(duration) + return timer +} + +// TcpipTimer is a resettable timer with variable duration expirations. +// Implements tcpip.Timer, which does not define a Destroy method; instead, all +// resources are released after timer expiration and calls to Timer.Stop. +// +// Must be created by AfterFunc. +type TcpipTimer struct { + // clock is the time source. clock is immutable. + clock Clock + + // notifier is called when the Timer expires. notifier is immutable. + notifier functionNotifier + + // mu protects t. + mu sync.Mutex + + // t stores the latest running Timer. This is replaced whenever Reset is + // called since Timer cannot be restarted once it has been Destroyed by Stop. + // + // This field is nil iff Stop has been called. + t *Timer +} + +// Stop implements tcpip.Timer.Stop. +func (r *TcpipTimer) Stop() bool { + r.mu.Lock() + defer r.mu.Unlock() + + if r.t == nil { + return false + } + _, lastSetting := r.t.Swap(Setting{}) + r.t.Destroy() + r.t = nil + return lastSetting.Enabled +} + +// Reset implements tcpip.Timer.Reset. +func (r *TcpipTimer) Reset(d time.Duration) { + r.mu.Lock() + defer r.mu.Unlock() + + if r.t == nil { + r.t = NewTimer(r.clock, &r.notifier) + } + + r.t.Swap(Setting{ + Enabled: true, + Period: 0, + Next: r.clock.Now().Add(d), + }) +} + +// functionNotifier is a TimerListener that runs a function. +// +// functionNotifier cannot be saved or loaded. +type functionNotifier struct { + fn func() +} + +// Notify implements ktime.TimerListener.Notify. +func (f *functionNotifier) Notify(uint64, Setting) (Setting, bool) { + f.fn() + return Setting{}, false +} + +// Destroy implements ktime.TimerListener.Destroy. +func (f *functionNotifier) Destroy() {} diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index e28c23d66..9dce11a97 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -469,7 +469,7 @@ type ndpState struct { rtrSolicit struct { // The timer used to send the next router solicitation message. - timer *time.Timer + timer tcpip.Timer // Used to let the Router Solicitation timer know that it has been stopped. // @@ -503,7 +503,7 @@ type ndpState struct { // to the DAD goroutine that DAD should stop. type dadState struct { // The DAD timer to send the next NS message, or resolve the address. - timer *time.Timer + timer tcpip.Timer // Used to let the DAD timer know that it has been stopped. // @@ -515,38 +515,38 @@ type dadState struct { // defaultRouterState holds data associated with a default router discovered by // a Router Advertisement (RA). type defaultRouterState struct { - // Timer to invalidate the default router. + // Job to invalidate the default router. // // Must not be nil. - invalidationTimer *tcpip.CancellableTimer + invalidationJob *tcpip.Job } // onLinkPrefixState holds data associated with an on-link prefix discovered by // a Router Advertisement's Prefix Information option (PI) when the NDP // configurations was configured to do so. type onLinkPrefixState struct { - // Timer to invalidate the on-link prefix. + // Job to invalidate the on-link prefix. // // Must not be nil. - invalidationTimer *tcpip.CancellableTimer + invalidationJob *tcpip.Job } // tempSLAACAddrState holds state associated with a temporary SLAAC address. type tempSLAACAddrState struct { - // Timer to deprecate the temporary SLAAC address. + // Job to deprecate the temporary SLAAC address. // // Must not be nil. - deprecationTimer *tcpip.CancellableTimer + deprecationJob *tcpip.Job - // Timer to invalidate the temporary SLAAC address. + // Job to invalidate the temporary SLAAC address. // // Must not be nil. - invalidationTimer *tcpip.CancellableTimer + invalidationJob *tcpip.Job - // Timer to regenerate the temporary SLAAC address. + // Job to regenerate the temporary SLAAC address. // // Must not be nil. - regenTimer *tcpip.CancellableTimer + regenJob *tcpip.Job createdAt time.Time @@ -561,15 +561,15 @@ type tempSLAACAddrState struct { // slaacPrefixState holds state associated with a SLAAC prefix. type slaacPrefixState struct { - // Timer to deprecate the prefix. + // Job to deprecate the prefix. // // Must not be nil. - deprecationTimer *tcpip.CancellableTimer + deprecationJob *tcpip.Job - // Timer to invalidate the prefix. + // Job to invalidate the prefix. // // Must not be nil. - invalidationTimer *tcpip.CancellableTimer + invalidationJob *tcpip.Job // Nonzero only when the address is not valid forever. validUntil time.Time @@ -651,12 +651,12 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref } var done bool - var timer *time.Timer + var timer tcpip.Timer // We initially start a timer to fire immediately because some of the DAD work // cannot be done while holding the NIC's lock. This is effectively the same // as starting a goroutine but we use a timer that fires immediately so we can // reset it for the next DAD iteration. - timer = time.AfterFunc(0, func() { + timer = ndp.nic.stack.Clock().AfterFunc(0, func() { ndp.nic.mu.Lock() defer ndp.nic.mu.Unlock() @@ -871,9 +871,9 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { case ok && rl != 0: // This is an already discovered default router. Update - // the invalidation timer. - rtr.invalidationTimer.StopLocked() - rtr.invalidationTimer.Reset(rl) + // the invalidation job. + rtr.invalidationJob.Cancel() + rtr.invalidationJob.Schedule(rl) ndp.defaultRouters[ip] = rtr case ok && rl == 0: @@ -950,7 +950,7 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { return } - rtr.invalidationTimer.StopLocked() + rtr.invalidationJob.Cancel() delete(ndp.defaultRouters, ip) // Let the integrator know a discovered default router is invalidated. @@ -979,12 +979,12 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { } state := defaultRouterState{ - invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { ndp.invalidateDefaultRouter(ip) }), } - state.invalidationTimer.Reset(rl) + state.invalidationJob.Schedule(rl) ndp.defaultRouters[ip] = state } @@ -1009,13 +1009,13 @@ func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) } state := onLinkPrefixState{ - invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { ndp.invalidateOnLinkPrefix(prefix) }), } if l < header.NDPInfiniteLifetime { - state.invalidationTimer.Reset(l) + state.invalidationJob.Schedule(l) } ndp.onLinkPrefixes[prefix] = state @@ -1033,7 +1033,7 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) { return } - s.invalidationTimer.StopLocked() + s.invalidationJob.Cancel() delete(ndp.onLinkPrefixes, prefix) // Let the integrator know a discovered on-link prefix is invalidated. @@ -1082,14 +1082,14 @@ func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformatio // This is an already discovered on-link prefix with a // new non-zero valid lifetime. // - // Update the invalidation timer. + // Update the invalidation job. - prefixState.invalidationTimer.StopLocked() + prefixState.invalidationJob.Cancel() if vl < header.NDPInfiniteLifetime { - // Prefix is valid for a finite lifetime, reset the timer to expire after + // Prefix is valid for a finite lifetime, schedule the job to execute after // the new valid lifetime. - prefixState.invalidationTimer.Reset(vl) + prefixState.invalidationJob.Schedule(vl) } ndp.onLinkPrefixes[prefix] = prefixState @@ -1154,7 +1154,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { } state := slaacPrefixState{ - deprecationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + deprecationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { state, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the deprecated SLAAC prefix %s", prefix)) @@ -1162,7 +1162,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { ndp.deprecateSLAACAddress(state.stableAddr.ref) }), - invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { state, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the invalidated SLAAC prefix %s", prefix)) @@ -1184,19 +1184,19 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { if !ndp.generateSLAACAddr(prefix, &state) { // We were unable to generate an address for the prefix, we do not nothing - // further as there is no reason to maintain state or timers for a prefix we + // further as there is no reason to maintain state or jobs for a prefix we // do not have an address for. return } - // Setup the initial timers to deprecate and invalidate prefix. + // Setup the initial jobs to deprecate and invalidate prefix. if pl < header.NDPInfiniteLifetime && pl != 0 { - state.deprecationTimer.Reset(pl) + state.deprecationJob.Schedule(pl) } if vl < header.NDPInfiniteLifetime { - state.invalidationTimer.Reset(vl) + state.invalidationJob.Schedule(vl) state.validUntil = now.Add(vl) } @@ -1428,7 +1428,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla } state := tempSLAACAddrState{ - deprecationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + deprecationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { prefixState, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to deprecate temporary address %s", prefix, generatedAddr)) @@ -1441,7 +1441,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla ndp.deprecateSLAACAddress(tempAddrState.ref) }), - invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { prefixState, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to invalidate temporary address %s", prefix, generatedAddr)) @@ -1454,7 +1454,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, generatedAddr.Address, tempAddrState) }), - regenTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + regenJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { prefixState, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to regenerate temporary address after %s", prefix, generatedAddr)) @@ -1481,9 +1481,9 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla ref: ref, } - state.deprecationTimer.Reset(pl) - state.invalidationTimer.Reset(vl) - state.regenTimer.Reset(pl - ndp.configs.RegenAdvanceDuration) + state.deprecationJob.Schedule(pl) + state.invalidationJob.Schedule(vl) + state.regenJob.Schedule(pl - ndp.configs.RegenAdvanceDuration) prefixState.generationAttempts++ prefixState.tempAddrs[generatedAddr.Address] = state @@ -1518,16 +1518,16 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat prefixState.stableAddr.ref.deprecated = false } - // If prefix was preferred for some finite lifetime before, stop the - // deprecation timer so it can be reset. - prefixState.deprecationTimer.StopLocked() + // If prefix was preferred for some finite lifetime before, cancel the + // deprecation job so it can be reset. + prefixState.deprecationJob.Cancel() now := time.Now() - // Reset the deprecation timer if prefix has a finite preferred lifetime. + // Schedule the deprecation job if prefix has a finite preferred lifetime. if pl < header.NDPInfiniteLifetime { if !deprecated { - prefixState.deprecationTimer.Reset(pl) + prefixState.deprecationJob.Schedule(pl) } prefixState.preferredUntil = now.Add(pl) } else { @@ -1546,9 +1546,9 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat // 3) Otherwise, reset the valid lifetime of the prefix to 2 hours. if vl >= header.NDPInfiniteLifetime { - // Handle the infinite valid lifetime separately as we do not keep a timer - // in this case. - prefixState.invalidationTimer.StopLocked() + // Handle the infinite valid lifetime separately as we do not schedule a + // job in this case. + prefixState.invalidationJob.Cancel() prefixState.validUntil = time.Time{} } else { var effectiveVl time.Duration @@ -1569,8 +1569,8 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat } if effectiveVl != 0 { - prefixState.invalidationTimer.StopLocked() - prefixState.invalidationTimer.Reset(effectiveVl) + prefixState.invalidationJob.Cancel() + prefixState.invalidationJob.Schedule(effectiveVl) prefixState.validUntil = now.Add(effectiveVl) } } @@ -1582,7 +1582,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat } // Note, we do not need to update the entries in the temporary address map - // after updating the timers because the timers are held as pointers. + // after updating the jobs because the jobs are held as pointers. var regenForAddr tcpip.Address allAddressesRegenerated := true for tempAddr, tempAddrState := range prefixState.tempAddrs { @@ -1596,14 +1596,14 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat } // If the address is no longer valid, invalidate it immediately. Otherwise, - // reset the invalidation timer. + // reset the invalidation job. newValidLifetime := validUntil.Sub(now) if newValidLifetime <= 0 { ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, tempAddr, tempAddrState) continue } - tempAddrState.invalidationTimer.StopLocked() - tempAddrState.invalidationTimer.Reset(newValidLifetime) + tempAddrState.invalidationJob.Cancel() + tempAddrState.invalidationJob.Schedule(newValidLifetime) // As per RFC 4941 section 3.3 step 4, the preferred lifetime of a temporary // address is the lower of the preferred lifetime of the stable address or @@ -1616,17 +1616,17 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat } // If the address is no longer preferred, deprecate it immediately. - // Otherwise, reset the deprecation timer. + // Otherwise, schedule the deprecation job again. newPreferredLifetime := preferredUntil.Sub(now) - tempAddrState.deprecationTimer.StopLocked() + tempAddrState.deprecationJob.Cancel() if newPreferredLifetime <= 0 { ndp.deprecateSLAACAddress(tempAddrState.ref) } else { tempAddrState.ref.deprecated = false - tempAddrState.deprecationTimer.Reset(newPreferredLifetime) + tempAddrState.deprecationJob.Schedule(newPreferredLifetime) } - tempAddrState.regenTimer.StopLocked() + tempAddrState.regenJob.Cancel() if tempAddrState.regenerated { } else { allAddressesRegenerated = false @@ -1637,7 +1637,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat // immediately after we finish iterating over the temporary addresses. regenForAddr = tempAddr } else { - tempAddrState.regenTimer.Reset(newPreferredLifetime - ndp.configs.RegenAdvanceDuration) + tempAddrState.regenJob.Schedule(newPreferredLifetime - ndp.configs.RegenAdvanceDuration) } } } @@ -1717,7 +1717,7 @@ func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPr ndp.cleanupSLAACPrefixResources(prefix, state) } -// cleanupSLAACPrefixResources cleansup a SLAAC prefix's timers and entry. +// cleanupSLAACPrefixResources cleans up a SLAAC prefix's jobs and entry. // // Panics if the SLAAC prefix is not known. // @@ -1729,8 +1729,8 @@ func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaa } state.stableAddr.ref = nil - state.deprecationTimer.StopLocked() - state.invalidationTimer.StopLocked() + state.deprecationJob.Cancel() + state.invalidationJob.Cancel() delete(ndp.slaacPrefixes, prefix) } @@ -1775,13 +1775,13 @@ func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWi } // cleanupTempSLAACAddrResourcesAndNotify cleans up a temporary SLAAC address's -// timers and entry. +// jobs and entry. // // The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) cleanupTempSLAACAddrResources(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) { - tempAddrState.deprecationTimer.StopLocked() - tempAddrState.invalidationTimer.StopLocked() - tempAddrState.regenTimer.StopLocked() + tempAddrState.deprecationJob.Cancel() + tempAddrState.invalidationJob.Cancel() + tempAddrState.regenJob.Cancel() delete(tempAddrs, tempAddr) } @@ -1860,7 +1860,7 @@ func (ndp *ndpState) startSolicitingRouters() { var done bool ndp.rtrSolicit.done = &done - ndp.rtrSolicit.timer = time.AfterFunc(delay, func() { + ndp.rtrSolicit.timer = ndp.nic.stack.Clock().AfterFunc(delay, func() { ndp.nic.mu.Lock() if done { // If we reach this point, it means that the RS timer fired after another diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 6f86abc98..644ba7c33 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -1254,7 +1254,7 @@ func TestRouterDiscovery(t *testing.T) { default: } - // Wait for lladdr2's router invalidation timer to fire. The lifetime + // Wait for lladdr2's router invalidation job to execute. The lifetime // of the router should have been updated to the most recent (smaller) // lifetime. // @@ -1271,7 +1271,7 @@ func TestRouterDiscovery(t *testing.T) { e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) expectRouterEvent(llAddr2, false) - // Wait for lladdr3's router invalidation timer to fire. The lifetime + // Wait for lladdr3's router invalidation job to execute. The lifetime // of the router should have been updated to the most recent (smaller) // lifetime. // @@ -1502,7 +1502,7 @@ func TestPrefixDiscovery(t *testing.T) { default: } - // Wait for prefix2's most recent invalidation timer plus some buffer to + // Wait for prefix2's most recent invalidation job plus some buffer to // expire. select { case e := <-ndpDisp.prefixC: @@ -2395,7 +2395,7 @@ func TestAutoGenTempAddrRegen(t *testing.T) { for _, addr := range tempAddrs { // Wait for a deprecation then invalidation event, or just an invalidation // event. We need to cover both cases but cannot deterministically hit both - // cases because the deprecation and invalidation timers could fire in any + // cases because the deprecation and invalidation jobs could execute in any // order. select { case e := <-ndpDisp.autoGenAddrC: @@ -2432,9 +2432,9 @@ func TestAutoGenTempAddrRegen(t *testing.T) { } } -// TestAutoGenTempAddrRegenTimerUpdates tests that a temporary address's -// regeneration timer gets updated when refreshing the address's lifetimes. -func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) { +// TestAutoGenTempAddrRegenJobUpdates tests that a temporary address's +// regeneration job gets updated when refreshing the address's lifetimes. +func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { const ( nicID = 1 regenAfter = 2 * time.Second @@ -2533,7 +2533,7 @@ func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) { // // A new temporary address should immediately be generated since the // regeneration time has already passed since the last address was generated - // - this regeneration does not depend on a timer. + // - this regeneration does not depend on a job. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) expectAutoGenAddrEvent(tempAddr2, newAddr) @@ -2559,11 +2559,11 @@ func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) { } // Set the maximum lifetimes for temporary addresses such that on the next - // RA, the regeneration timer gets reset. + // RA, the regeneration job gets scheduled again. // // The maximum lifetime is the sum of the minimum lifetimes for temporary // addresses + the time that has already passed since the last address was - // generated so that the regeneration timer is needed to generate the next + // generated so that the regeneration job is needed to generate the next // address. newLifetimes := newMinVLDuration + regenAfter + defaultAsyncNegativeEventTimeout ndpConfigs.MaxTempAddrValidLifetime = newLifetimes @@ -2993,9 +2993,9 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) { expectPrimaryAddr(addr2) } -// TestAutoGenAddrTimerDeprecation tests that an address is properly deprecated +// TestAutoGenAddrJobDeprecation tests that an address is properly deprecated // when its preferred lifetime expires. -func TestAutoGenAddrTimerDeprecation(t *testing.T) { +func TestAutoGenAddrJobDeprecation(t *testing.T) { const nicID = 1 const newMinVL = 2 newMinVLDuration := newMinVL * time.Second @@ -3513,8 +3513,8 @@ func TestAutoGenAddrRemoval(t *testing.T) { } expectAutoGenAddrEvent(addr, invalidatedAddr) - // Wait for the original valid lifetime to make sure the original timer - // got stopped/cleaned up. + // Wait for the original valid lifetime to make sure the original job got + // cancelled/cleaned up. select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly received an auto gen addr event") diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 2b7ece851..a6faa22c2 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -728,6 +728,11 @@ func New(opts Options) *Stack { return s } +// newJob returns a tcpip.Job using the Stack clock. +func (s *Stack) newJob(l sync.Locker, f func()) *tcpip.Job { + return tcpip.NewJob(s.clock, l, f) +} + // UniqueID returns a unique identifier. func (s *Stack) UniqueID() uint64 { return s.uniqueIDGenerator.UniqueID() @@ -801,9 +806,10 @@ func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h f } } -// NowNanoseconds implements tcpip.Clock.NowNanoseconds. -func (s *Stack) NowNanoseconds() int64 { - return s.clock.NowNanoseconds() +// Clock returns the Stack's clock for retrieving the current time and +// scheduling work. +func (s *Stack) Clock() tcpip.Clock { + return s.clock } // Stats returns a mutable copy of the current stats. diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index ff14a3b3c..21aafb0a2 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -192,7 +192,7 @@ func (e ErrSaveRejection) Error() string { return "save rejected due to unsupported networking state: " + e.Err.Error() } -// A Clock provides the current time. +// A Clock provides the current time and schedules work for execution. // // Times returned by a Clock should always be used for application-visible // time. Only monotonic times should be used for netstack internal timekeeping. @@ -203,6 +203,31 @@ type Clock interface { // NowMonotonic returns a monotonic time value. NowMonotonic() int64 + + // AfterFunc waits for the duration to elapse and then calls f in its own + // goroutine. It returns a Timer that can be used to cancel the call using + // its Stop method. + AfterFunc(d time.Duration, f func()) Timer +} + +// Timer represents a single event. A Timer must be created with +// Clock.AfterFunc. +type Timer interface { + // Stop prevents the Timer from firing. It returns true if the call stops the + // timer, false if the timer has already expired or been stopped. + // + // If Stop returns false, then the timer has already expired and the function + // f of Clock.AfterFunc(d, f) has been started in its own goroutine; Stop + // does not wait for f to complete before returning. If the caller needs to + // know whether f is completed, it must coordinate with f explicitly. + Stop() bool + + // Reset changes the timer to expire after duration d. + // + // Reset should be invoked only on stopped or expired timers. If the timer is + // known to have expired, Reset can be used directly. Otherwise, the caller + // must coordinate with the function f of Clock.AfterFunc(d, f). + Reset(d time.Duration) } // Address is a byte slice cast as a string that represents the address of a diff --git a/pkg/tcpip/time_unsafe.go b/pkg/tcpip/time_unsafe.go index 7f172f978..f32d58091 100644 --- a/pkg/tcpip/time_unsafe.go +++ b/pkg/tcpip/time_unsafe.go @@ -20,7 +20,7 @@ package tcpip import ( - _ "time" // Used with go:linkname. + "time" // Used with go:linkname. _ "unsafe" // Required for go:linkname. ) @@ -45,3 +45,31 @@ func (*StdClock) NowMonotonic() int64 { _, _, mono := now() return mono } + +// AfterFunc implements Clock.AfterFunc. +func (*StdClock) AfterFunc(d time.Duration, f func()) Timer { + return &stdTimer{ + t: time.AfterFunc(d, f), + } +} + +type stdTimer struct { + t *time.Timer +} + +var _ Timer = (*stdTimer)(nil) + +// Stop implements Timer.Stop. +func (st *stdTimer) Stop() bool { + return st.t.Stop() +} + +// Reset implements Timer.Reset. +func (st *stdTimer) Reset(d time.Duration) { + st.t.Reset(d) +} + +// NewStdTimer returns a Timer implemented with the time package. +func NewStdTimer(t *time.Timer) Timer { + return &stdTimer{t: t} +} diff --git a/pkg/tcpip/timer.go b/pkg/tcpip/timer.go index 5554c573f..f1dd7c310 100644 --- a/pkg/tcpip/timer.go +++ b/pkg/tcpip/timer.go @@ -20,50 +20,49 @@ import ( "gvisor.dev/gvisor/pkg/sync" ) -// cancellableTimerInstance is a specific instance of CancellableTimer. +// jobInstance is a specific instance of Job. // -// Different instances are created each time CancellableTimer is Reset so each -// timer has its own earlyReturn signal. This is to address a bug when a -// CancellableTimer is stopped and reset in quick succession resulting in a -// timer instance's earlyReturn signal being affected or seen by another timer -// instance. +// Different instances are created each time Job is scheduled so each timer has +// its own earlyReturn signal. This is to address a bug when a Job is stopped +// and reset in quick succession resulting in a timer instance's earlyReturn +// signal being affected or seen by another timer instance. // // Consider the following sceneario where timer instances share a common // earlyReturn signal (T1 creates, stops and resets a Cancellable timer under a // lock L; T2, T3, T4 and T5 are goroutines that handle the first (A), second // (B), third (C), and fourth (D) instance of the timer firing, respectively): // T1: Obtain L -// T1: Create a new CancellableTimer w/ lock L (create instance A) +// T1: Create a new Job w/ lock L (create instance A) // T2: instance A fires, blocked trying to obtain L. // T1: Attempt to stop instance A (set earlyReturn = true) -// T1: Reset timer (create instance B) +// T1: Schedule timer (create instance B) // T3: instance B fires, blocked trying to obtain L. // T1: Attempt to stop instance B (set earlyReturn = true) -// T1: Reset timer (create instance C) +// T1: Schedule timer (create instance C) // T4: instance C fires, blocked trying to obtain L. // T1: Attempt to stop instance C (set earlyReturn = true) -// T1: Reset timer (create instance D) +// T1: Schedule timer (create instance D) // T5: instance D fires, blocked trying to obtain L. // T1: Release L // -// Now that T1 has released L, any of the 4 timer instances can take L and check -// earlyReturn. If the timers simply check earlyReturn and then do nothing -// further, then instance D will never early return even though it was not -// requested to stop. If the timers reset earlyReturn before early returning, -// then all but one of the timers will do work when only one was expected to. -// If CancellableTimer resets earlyReturn when resetting, then all the timers +// Now that T1 has released L, any of the 4 timer instances can take L and +// check earlyReturn. If the timers simply check earlyReturn and then do +// nothing further, then instance D will never early return even though it was +// not requested to stop. If the timers reset earlyReturn before early +// returning, then all but one of the timers will do work when only one was +// expected to. If Job resets earlyReturn when resetting, then all the timers // will fire (again, when only one was expected to). // // To address the above concerns the simplest solution was to give each timer // its own earlyReturn signal. -type cancellableTimerInstance struct { - timer *time.Timer +type jobInstance struct { + timer Timer // Used to inform the timer to early return when it gets stopped while the // lock the timer tries to obtain when fired is held (T1 is a goroutine that // tries to cancel the timer and T2 is the goroutine that handles the timer // firing): - // T1: Obtain the lock, then call StopLocked() + // T1: Obtain the lock, then call Cancel() // T2: timer fires, and gets blocked on obtaining the lock // T1: Releases lock // T2: Obtains lock does unintended work @@ -74,29 +73,33 @@ type cancellableTimerInstance struct { earlyReturn *bool } -// stop stops the timer instance t from firing if it hasn't fired already. If it +// stop stops the job instance j from firing if it hasn't fired already. If it // has fired and is blocked at obtaining the lock, earlyReturn will be set to // true so that it will early return when it obtains the lock. -func (t *cancellableTimerInstance) stop() { - if t.timer != nil { - t.timer.Stop() - *t.earlyReturn = true +func (j *jobInstance) stop() { + if j.timer != nil { + j.timer.Stop() + *j.earlyReturn = true } } -// CancellableTimer is a timer that does some work and can be safely cancelled -// when it fires at the same time some "related work" is being done. +// Job represents some work that can be scheduled for execution. The work can +// be safely cancelled when it fires at the same time some "related work" is +// being done. // // The term "related work" is defined as some work that needs to be done while // holding some lock that the timer must also hold while doing some work. // -// Note, it is not safe to copy a CancellableTimer as its timer instance creates -// a closure over the address of the CancellableTimer. -type CancellableTimer struct { +// Note, it is not safe to copy a Job as its timer instance creates +// a closure over the address of the Job. +type Job struct { _ sync.NoCopy + // The clock used to schedule the backing timer + clock Clock + // The active instance of a cancellable timer. - instance cancellableTimerInstance + instance jobInstance // locker is the lock taken by the timer immediately after it fires and must // be held when attempting to stop the timer. @@ -113,59 +116,91 @@ type CancellableTimer struct { fn func() } -// StopLocked prevents the Timer from firing if it has not fired already. +// Cancel prevents the Job from executing if it has not executed already. // -// If the timer is blocked on obtaining the t.locker lock when StopLocked is -// called, it will early return instead of calling t.fn. +// Cancel requires appropriate locking to be in place for any resources managed +// by the Job. If the Job is blocked on obtaining the lock when Cancel is +// called, it will early return. // // Note, t will be modified. // -// t.locker MUST be locked. -func (t *CancellableTimer) StopLocked() { - t.instance.stop() +// j.locker MUST be locked. +func (j *Job) Cancel() { + j.instance.stop() // Nothing to do with the stopped instance anymore. - t.instance = cancellableTimerInstance{} + j.instance = jobInstance{} } -// Reset changes the timer to expire after duration d. +// Schedule schedules the Job for execution after duration d. This can be +// called on cancelled or completed Jobs to schedule them again. // -// Note, t will be modified. +// Schedule should be invoked only on unscheduled, cancelled, or completed +// Jobs. To be safe, callers should always call Cancel before calling Schedule. // -// Reset should only be called on stopped or expired timers. To be safe, callers -// should always call StopLocked before calling Reset. -func (t *CancellableTimer) Reset(d time.Duration) { +// Note, j will be modified. +func (j *Job) Schedule(d time.Duration) { // Create a new instance. earlyReturn := false // Capture the locker so that updating the timer does not cause a data race // when a timer fires and tries to obtain the lock (read the timer's locker). - locker := t.locker - t.instance = cancellableTimerInstance{ - timer: time.AfterFunc(d, func() { + locker := j.locker + j.instance = jobInstance{ + timer: j.clock.AfterFunc(d, func() { locker.Lock() defer locker.Unlock() if earlyReturn { // If we reach this point, it means that the timer fired while another - // goroutine called StopLocked while it had the lock. Simply return - // here and do nothing further. + // goroutine called Cancel while it had the lock. Simply return here + // and do nothing further. earlyReturn = false return } - t.fn() + j.fn() }), earlyReturn: &earlyReturn, } } -// NewCancellableTimer returns an unscheduled CancellableTimer with the given -// locker and fn. -// -// fn MUST NOT attempt to lock locker. -// -// Callers must call Reset to schedule the timer to fire. -func NewCancellableTimer(locker sync.Locker, fn func()) *CancellableTimer { - return &CancellableTimer{locker: locker, fn: fn} +// NewJob returns a new Job that can be used to schedule f to run in its own +// gorountine. l will be locked before calling f then unlocked after f returns. +// +// var clock tcpip.StdClock +// var mu sync.Mutex +// message := "foo" +// job := tcpip.NewJob(&clock, &mu, func() { +// fmt.Println(message) +// }) +// job.Schedule(time.Second) +// +// mu.Lock() +// message = "bar" +// mu.Unlock() +// +// // Output: bar +// +// f MUST NOT attempt to lock l. +// +// l MUST be locked prior to calling the returned job's Cancel(). +// +// var clock tcpip.StdClock +// var mu sync.Mutex +// message := "foo" +// job := tcpip.NewJob(&clock, &mu, func() { +// fmt.Println(message) +// }) +// job.Schedule(time.Second) +// +// mu.Lock() +// job.Cancel() +// mu.Unlock() +func NewJob(c Clock, l sync.Locker, f func()) *Job { + return &Job{ + clock: c, + locker: l, + fn: f, + } } diff --git a/pkg/tcpip/timer_test.go b/pkg/tcpip/timer_test.go index b4940e397..a82384c49 100644 --- a/pkg/tcpip/timer_test.go +++ b/pkg/tcpip/timer_test.go @@ -28,8 +28,8 @@ const ( longDuration = 1 * time.Second ) -func TestCancellableTimerReassignment(t *testing.T) { - var timer tcpip.CancellableTimer +func TestJobReschedule(t *testing.T) { + var clock tcpip.StdClock var wg sync.WaitGroup var lock sync.Mutex @@ -43,26 +43,27 @@ func TestCancellableTimerReassignment(t *testing.T) { // that has an active timer (even if it has been stopped as a stopped // timer may be blocked on a lock before it can check if it has been // stopped while another goroutine holds the same lock). - timer = *tcpip.NewCancellableTimer(&lock, func() { + job := tcpip.NewJob(&clock, &lock, func() { wg.Done() }) - timer.Reset(shortDuration) + job.Schedule(shortDuration) lock.Unlock() }() } wg.Wait() } -func TestCancellableTimerFire(t *testing.T) { +func TestJobExecution(t *testing.T) { t.Parallel() - ch := make(chan struct{}) + var clock tcpip.StdClock var lock sync.Mutex + ch := make(chan struct{}) - timer := tcpip.NewCancellableTimer(&lock, func() { + job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) - timer.Reset(shortDuration) + job.Schedule(shortDuration) // Wait for timer to fire. select { @@ -82,17 +83,18 @@ func TestCancellableTimerFire(t *testing.T) { func TestCancellableTimerResetFromLongDuration(t *testing.T) { t.Parallel() - ch := make(chan struct{}) + var clock tcpip.StdClock var lock sync.Mutex + ch := make(chan struct{}) - timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} }) - timer.Reset(middleDuration) + job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job.Schedule(middleDuration) lock.Lock() - timer.StopLocked() + job.Cancel() lock.Unlock() - timer.Reset(shortDuration) + job.Schedule(shortDuration) // Wait for timer to fire. select { @@ -109,16 +111,17 @@ func TestCancellableTimerResetFromLongDuration(t *testing.T) { } } -func TestCancellableTimerResetFromShortDuration(t *testing.T) { +func TestJobRescheduleFromShortDuration(t *testing.T) { t.Parallel() - ch := make(chan struct{}) + var clock tcpip.StdClock var lock sync.Mutex + ch := make(chan struct{}) lock.Lock() - timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} }) - timer.Reset(shortDuration) - timer.StopLocked() + job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job.Schedule(shortDuration) + job.Cancel() lock.Unlock() // Wait for timer to fire if it wasn't correctly stopped. @@ -128,7 +131,7 @@ func TestCancellableTimerResetFromShortDuration(t *testing.T) { case <-time.After(middleDuration): } - timer.Reset(shortDuration) + job.Schedule(shortDuration) // Wait for timer to fire. select { @@ -145,17 +148,18 @@ func TestCancellableTimerResetFromShortDuration(t *testing.T) { } } -func TestCancellableTimerImmediatelyStop(t *testing.T) { +func TestJobImmediatelyCancel(t *testing.T) { t.Parallel() - ch := make(chan struct{}) + var clock tcpip.StdClock var lock sync.Mutex + ch := make(chan struct{}) for i := 0; i < 1000; i++ { lock.Lock() - timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} }) - timer.Reset(shortDuration) - timer.StopLocked() + job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job.Schedule(shortDuration) + job.Cancel() lock.Unlock() } @@ -167,25 +171,26 @@ func TestCancellableTimerImmediatelyStop(t *testing.T) { } } -func TestCancellableTimerStoppedResetWithoutLock(t *testing.T) { +func TestJobCancelledRescheduleWithoutLock(t *testing.T) { t.Parallel() - ch := make(chan struct{}) + var clock tcpip.StdClock var lock sync.Mutex + ch := make(chan struct{}) lock.Lock() - timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} }) - timer.Reset(shortDuration) - timer.StopLocked() + job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job.Schedule(shortDuration) + job.Cancel() lock.Unlock() for i := 0; i < 10; i++ { - timer.Reset(middleDuration) + job.Schedule(middleDuration) lock.Lock() // Sleep until the timer fires and gets blocked trying to take the lock. time.Sleep(middleDuration * 2) - timer.StopLocked() + job.Cancel() lock.Unlock() } @@ -201,17 +206,18 @@ func TestCancellableTimerStoppedResetWithoutLock(t *testing.T) { func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) { t.Parallel() - ch := make(chan struct{}) + var clock tcpip.StdClock var lock sync.Mutex + ch := make(chan struct{}) lock.Lock() - timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} }) - timer.Reset(shortDuration) + job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job.Schedule(shortDuration) for i := 0; i < 10; i++ { // Sleep until the timer fires and gets blocked trying to take the lock. time.Sleep(middleDuration) - timer.StopLocked() - timer.Reset(shortDuration) + job.Cancel() + job.Schedule(shortDuration) } lock.Unlock() @@ -230,18 +236,19 @@ func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) { } } -func TestManyCancellableTimerResetUnderLock(t *testing.T) { +func TestManyJobReschedulesUnderLock(t *testing.T) { t.Parallel() - ch := make(chan struct{}) + var clock tcpip.StdClock var lock sync.Mutex + ch := make(chan struct{}) lock.Lock() - timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} }) - timer.Reset(shortDuration) + job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job.Schedule(shortDuration) for i := 0; i < 10; i++ { - timer.StopLocked() - timer.Reset(shortDuration) + job.Cancel() + job.Schedule(shortDuration) } lock.Unlock() diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 678f4e016..4612be4e7 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -797,7 +797,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk e.rcvList.PushBack(packet) e.rcvBufSize += packet.data.Size() - packet.timestamp = e.stack.NowNanoseconds() + packet.timestamp = e.stack.Clock().NowNanoseconds() e.rcvMu.Unlock() e.stats.PacketsReceived.Increment() diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 8f167391f..0e46e6355 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -499,7 +499,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, combinedVV.Append(pkt.Data) packet.data = combinedVV } - packet.timestampNS = ep.stack.NowNanoseconds() + packet.timestampNS = ep.stack.Clock().NowNanoseconds() ep.rcvList.PushBack(&packet) ep.rcvBufSize += packet.data.Size() diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index aefe0e2b2..f85a68554 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -700,7 +700,7 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { } combinedVV.Append(pkt.Data) packet.data = combinedVV - packet.timestampNS = e.stack.NowNanoseconds() + packet.timestampNS = e.stack.Clock().NowNanoseconds() e.rcvList.PushBack(packet) e.rcvBufSize += packet.data.Size() diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index a14643ae8..6e692da07 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1451,7 +1451,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk packet.tos, _ = header.IPv6(pkt.NetworkHeader).TOS() } - packet.timestamp = e.stack.NowNanoseconds() + packet.timestamp = e.stack.Clock().NowNanoseconds() e.rcvMu.Unlock() -- cgit v1.2.3 From 8dbf428a1236f5962077e2506bef365362b953d0 Mon Sep 17 00:00:00 2001 From: Sam Balana Date: Mon, 27 Jul 2020 15:16:16 -0700 Subject: Add ability to send unicast ARP requests and Neighbor Solicitations The previous implementation of LinkAddressRequest only supported sending broadcast ARP requests and multicast Neighbor Solicitations. The ability to send these packets as unicast is required for Neighbor Unreachability Detection. Tests: pkg/tcpip/network/arp:arp_test - TestLinkAddressRequest pkg/tcpip/network/ipv6:ipv6_test - TestLinkAddressRequest Updates #1889 Updates #1894 Updates #1895 Updates #1947 Updates #1948 Updates #1949 Updates #1950 PiperOrigin-RevId: 323451569 --- pkg/tcpip/network/arp/arp.go | 7 +++-- pkg/tcpip/network/arp/arp_test.go | 58 ++++++++++++++++++++++++++++++----- pkg/tcpip/network/ipv6/icmp.go | 8 +++-- pkg/tcpip/network/ipv6/icmp_test.go | 52 +++++++++++++++++++++++++++++-- pkg/tcpip/stack/forwarder_test.go | 18 ++++++----- pkg/tcpip/stack/linkaddrcache.go | 2 +- pkg/tcpip/stack/linkaddrcache_test.go | 2 +- pkg/tcpip/stack/nic_test.go | 2 +- pkg/tcpip/stack/registration.go | 7 +++-- 9 files changed, 128 insertions(+), 28 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index b0f57040c..31a242482 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -160,9 +160,12 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { } // LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest. -func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error { +func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error { r := &stack.Route{ - RemoteLinkAddress: header.EthernetBroadcastAddress, + RemoteLinkAddress: remoteLinkAddr, + } + if len(r.RemoteLinkAddress) == 0 { + r.RemoteLinkAddress = header.EthernetBroadcastAddress } hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.ARPSize) diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 66e67429c..a35a64a0f 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -32,10 +32,14 @@ import ( ) const ( - stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c") - stackAddr1 = tcpip.Address("\x0a\x00\x00\x01") - stackAddr2 = tcpip.Address("\x0a\x00\x00\x02") - stackAddrBad = tcpip.Address("\x0a\x00\x00\x03") + stackLinkAddr1 = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c") + stackLinkAddr2 = tcpip.LinkAddress("\x0b\x0b\x0c\x0c\x0d\x0d") + stackAddr1 = tcpip.Address("\x0a\x00\x00\x01") + stackAddr2 = tcpip.Address("\x0a\x00\x00\x02") + stackAddrBad = tcpip.Address("\x0a\x00\x00\x03") + + defaultChannelSize = 1 + defaultMTU = 65536 ) type testContext struct { @@ -50,8 +54,7 @@ func newTestContext(t *testing.T) *testContext { TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4()}, }) - const defaultMTU = 65536 - ep := channel.New(256, defaultMTU, stackLinkAddr) + ep := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr1) wep := stack.LinkEndpoint(ep) if testing.Verbose() { @@ -119,7 +122,7 @@ func TestDirectRequest(t *testing.T) { if !rep.IsValid() { t.Fatalf("invalid ARP response pi.Pkt.Header.UsedLength()=%d", pi.Pkt.Header.UsedLength()) } - if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want { + if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr1; got != want { t.Errorf("got HardwareAddressSender = %s, want = %s", got, want) } if got, want := tcpip.Address(rep.ProtocolAddressSender()), tcpip.Address(h.ProtocolAddressTarget()); got != want { @@ -144,3 +147,44 @@ func TestDirectRequest(t *testing.T) { t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto) } } + +func TestLinkAddressRequest(t *testing.T) { + tests := []struct { + name string + remoteLinkAddr tcpip.LinkAddress + expectLinkAddr tcpip.LinkAddress + }{ + { + name: "Unicast", + remoteLinkAddr: stackLinkAddr2, + expectLinkAddr: stackLinkAddr2, + }, + { + name: "Multicast", + remoteLinkAddr: "", + expectLinkAddr: header.EthernetBroadcastAddress, + }, + } + + for _, test := range tests { + p := arp.NewProtocol() + linkRes, ok := p.(stack.LinkAddressResolver) + if !ok { + t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver") + } + + linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr1) + if err := linkRes.LinkAddressRequest(stackAddr1, stackAddr2, test.remoteLinkAddr, linkEP); err != nil { + t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s", stackAddr1, stackAddr2, test.remoteLinkAddr, err) + } + + pkt, ok := linkEP.Read() + if !ok { + t.Fatal("expected to send a link address request") + } + + if got, want := pkt.Route.RemoteLinkAddress, test.expectLinkAddr; got != want { + t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", got, want) + } + } +} diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index ff1cb53dd..24600d877 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -504,7 +504,7 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { } // LinkAddressRequest implements stack.LinkAddressResolver. -func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error { +func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error { snaddr := header.SolicitedNodeAddr(addr) // TODO(b/148672031): Use stack.FindRoute instead of manually creating the @@ -513,8 +513,12 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. r := &stack.Route{ LocalAddress: localAddr, RemoteAddress: snaddr, - RemoteLinkAddress: header.EthernetAddressFromMulticastIPv6Address(snaddr), + RemoteLinkAddress: remoteLinkAddr, } + if len(r.RemoteLinkAddress) == 0 { + r.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(snaddr) + } + hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize) pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) pkt.SetType(header.ICMPv6NeighborSolicit) diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 52a01b44e..f86aaed1d 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -34,6 +34,9 @@ const ( linkAddr0 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") linkAddr2 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f") + + defaultChannelSize = 1 + defaultMTU = 65536 ) var ( @@ -257,8 +260,7 @@ func newTestContext(t *testing.T) *testContext { }), } - const defaultMTU = 65536 - c.linkEP0 = channel.New(256, defaultMTU, linkAddr0) + c.linkEP0 = channel.New(defaultChannelSize, defaultMTU, linkAddr0) wrappedEP0 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP0}) if testing.Verbose() { @@ -271,7 +273,7 @@ func newTestContext(t *testing.T) *testContext { t.Fatalf("AddAddress lladdr0: %v", err) } - c.linkEP1 = channel.New(256, defaultMTU, linkAddr1) + c.linkEP1 = channel.New(defaultChannelSize, defaultMTU, linkAddr1) wrappedEP1 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP1}) if err := c.s1.CreateNIC(1, wrappedEP1); err != nil { t.Fatalf("CreateNIC failed: %v", err) @@ -951,3 +953,47 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { }) } } + +func TestLinkAddressRequest(t *testing.T) { + snaddr := header.SolicitedNodeAddr(lladdr0) + mcaddr := header.EthernetAddressFromMulticastIPv6Address(snaddr) + + tests := []struct { + name string + remoteLinkAddr tcpip.LinkAddress + expectLinkAddr tcpip.LinkAddress + }{ + { + name: "Unicast", + remoteLinkAddr: linkAddr1, + expectLinkAddr: linkAddr1, + }, + { + name: "Multicast", + remoteLinkAddr: "", + expectLinkAddr: mcaddr, + }, + } + + for _, test := range tests { + p := NewProtocol() + linkRes, ok := p.(stack.LinkAddressResolver) + if !ok { + t.Fatalf("expected IPv6 protocol to implement stack.LinkAddressResolver") + } + + linkEP := channel.New(defaultChannelSize, defaultMTU, linkAddr0) + if err := linkRes.LinkAddressRequest(lladdr0, lladdr1, test.remoteLinkAddr, linkEP); err != nil { + t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s", lladdr0, lladdr1, test.remoteLinkAddr, err) + } + + pkt, ok := linkEP.Read() + if !ok { + t.Fatal("expected to send a link address request") + } + + if got, want := pkt.Route.RemoteLinkAddress, test.expectLinkAddr; got != want { + t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", got, want) + } + } +} diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index bca1d940b..c962693f5 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -121,10 +121,12 @@ func (*fwdTestNetworkEndpoint) Close() {} type fwdTestNetworkProtocol struct { addrCache *linkAddrCache addrResolveDelay time.Duration - onLinkAddressResolved func(cache *linkAddrCache, addr tcpip.Address) + onLinkAddressResolved func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool) } +var _ LinkAddressResolver = (*fwdTestNetworkProtocol)(nil) + func (f *fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber { return fwdTestNetNumber } @@ -174,10 +176,10 @@ func (f *fwdTestNetworkProtocol) Close() {} func (f *fwdTestNetworkProtocol) Wait() {} -func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP LinkEndpoint) *tcpip.Error { +func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error { if f.addrCache != nil && f.onLinkAddressResolved != nil { time.AfterFunc(f.addrResolveDelay, func() { - f.onLinkAddressResolved(f.addrCache, addr) + f.onLinkAddressResolved(f.addrCache, addr, remoteLinkAddr) }) } return nil @@ -405,7 +407,7 @@ func TestForwardingWithFakeResolver(t *testing.T) { // Create a network protocol with a fake resolver. proto := &fwdTestNetworkProtocol{ addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { + onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) { // Any address will be resolved to the link address "c". cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") }, @@ -463,7 +465,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { // Create a network protocol with a fake resolver. proto := &fwdTestNetworkProtocol{ addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { + onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) { // Only packets to address 3 will be resolved to the // link address "c". if addr == "\x03" { @@ -515,7 +517,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { // Create a network protocol with a fake resolver. proto := &fwdTestNetworkProtocol{ addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { + onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) { // Any packets will be resolved to the link address "c". cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") }, @@ -559,7 +561,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { // Create a network protocol with a fake resolver. proto := &fwdTestNetworkProtocol{ addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { + onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) { // Any packets will be resolved to the link address "c". cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") }, @@ -616,7 +618,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // Create a network protocol with a fake resolver. proto := &fwdTestNetworkProtocol{ addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { + onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) { // Any packets will be resolved to the link address "c". cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") }, diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index 403557fd7..6f73a0ce4 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -244,7 +244,7 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link for i := 0; ; i++ { // Send link request, then wait for the timeout limit and check // whether the request succeeded. - linkRes.LinkAddressRequest(k.Addr, localAddr, linkEP) + linkRes.LinkAddressRequest(k.Addr, localAddr, "" /* linkAddr */, linkEP) select { case now := <-time.After(c.resolutionTimeout): diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index 1baa498d0..b15b8d1cb 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -48,7 +48,7 @@ type testLinkAddressResolver struct { onLinkAddressRequest func() } -func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error { +func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error { time.AfterFunc(r.delay, func() { r.fakeRequest(addr) }) if f := r.onLinkAddressRequest; f != nil { f() diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index c477e31d8..a70792b50 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -243,7 +243,7 @@ func (*testIPv6Protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { } // LinkAddressRequest implements LinkAddressResolver. -func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error { +func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error { return nil } diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 9e1b2d25f..8604c4259 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -478,12 +478,13 @@ type InjectableLinkEndpoint interface { // A LinkAddressResolver is an extension to a NetworkProtocol that // can resolve link addresses. type LinkAddressResolver interface { - // LinkAddressRequest sends a request for the LinkAddress of addr. - // The request is sent on linkEP with localAddr as the source. + // LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts + // the request on the local network if remoteLinkAddr is the zero value. The + // request is sent on linkEP with localAddr as the source. // // A valid response will cause the discovery protocol's network // endpoint to call AddLinkAddress. - LinkAddressRequest(addr, localAddr tcpip.Address, linkEP LinkEndpoint) *tcpip.Error + LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error // ResolveStaticAddress attempts to resolve address without sending // requests. It either resolves the name immediately or returns the -- cgit v1.2.3 From f82dd8ddb477b8923d1db12654a62d55d613019c Mon Sep 17 00:00:00 2001 From: Fabricio Voznika Date: Tue, 28 Jul 2020 21:22:52 -0700 Subject: Redirect TODO to GitHub issues PiperOrigin-RevId: 323715260 --- pkg/sentry/socket/netstack/netstack.go | 4 ++-- pkg/tcpip/stack/nic.go | 2 +- pkg/tcpip/stack/stack_test.go | 6 +++--- pkg/tcpip/transport/packet/endpoint.go | 4 ++-- test/syscalls/linux/packet_socket.cc | 4 ++-- test/syscalls/linux/packet_socket_raw.cc | 4 ++-- 6 files changed, 12 insertions(+), 12 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 44b3fff46..f86e6cd7a 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -423,7 +423,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } - // TODO(b/129292371): Return protocol too. + // TODO(gvisor.dev/issue/173): Return protocol too. return tcpip.FullAddress{ NIC: tcpip.NICID(a.InterfaceIndex), Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]), @@ -2418,7 +2418,7 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) return &out, uint32(sockAddrInet6Size) case linux.AF_PACKET: - // TODO(b/129292371): Return protocol too. + // TODO(gvisor.dev/issue/173): Return protocol too. var out linux.SockAddrLink out.Family = linux.AF_PACKET out.InterfaceIndex = int32(addr.NIC) diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index fea0ce7e8..9256d4d43 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -181,7 +181,7 @@ func (n *NIC) disableLocked() *tcpip.Error { return nil } - // TODO(b/147015577): Should Routes that are currently bound to n be + // TODO(gvisor.dev/issue/1491): Should Routes that are currently bound to n be // invalidated? Currently, Routes will continue to work when a NIC is enabled // again, and applications may not know that the underlying NIC was ever // disabled. diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 7657a4101..101ca2206 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -867,9 +867,9 @@ func TestRouteWithDownNIC(t *testing.T) { // Writes with Routes that use NIC1 after being brought up should // succeed. // - // TODO(b/147015577): Should we instead completely invalidate all - // Routes that were bound to a NIC that was brought down at some - // point? + // TODO(gvisor.dev/issue/1491): Should we instead completely + // invalidate all Routes that were bound to a NIC that was brought + // down at some point? if err := upFn(s, nicID1); err != nil { t.Fatalf("test.upFn(_, %d): %s", nicID1, err) } diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 0e46e6355..df478115d 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -193,7 +193,7 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes } func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { - // TODO(b/129292371): Implement. + // TODO(gvisor.dev/issue/173): Implement. return 0, nil, tcpip.ErrInvalidOptionValue } @@ -432,7 +432,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, // Push new packet into receive list and increment the buffer size. var packet packet - // TODO(b/129292371): Return network protocol. + // TODO(gvisor.dev/issue/173): Return network protocol. if len(pkt.LinkHeader) > 0 { // Get info directly from the ethernet header. hdr := header.Ethernet(pkt.LinkHeader) diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc index 40aa9326d..861617ff7 100644 --- a/test/syscalls/linux/packet_socket.cc +++ b/test/syscalls/linux/packet_socket.cc @@ -188,7 +188,7 @@ void ReceiveMessage(int sock, int ifindex) { // sizeof(sockaddr_ll). ASSERT_THAT(src_len, AnyOf(Eq(sizeof(src)), Eq(sizeof(src) - 2))); - // TODO(b/129292371): Verify protocol once we return it. + // TODO(gvisor.dev/issue/173): Verify protocol once we return it. // Verify the source address. EXPECT_EQ(src.sll_family, AF_PACKET); EXPECT_EQ(src.sll_ifindex, ifindex); @@ -234,7 +234,7 @@ TEST_P(CookedPacketTest, Receive) { // Send via a packet socket. TEST_P(CookedPacketTest, Send) { - // TODO(b/129292371): Remove once we support packet socket writing. + // TODO(gvisor.dev/issue/173): Remove once we support packet socket writing. SKIP_IF(IsRunningOnGvisor()); // Let's send a UDP packet and receive it using a regular UDP socket. diff --git a/test/syscalls/linux/packet_socket_raw.cc b/test/syscalls/linux/packet_socket_raw.cc index 2fca9fe4d..a11a03415 100644 --- a/test/syscalls/linux/packet_socket_raw.cc +++ b/test/syscalls/linux/packet_socket_raw.cc @@ -195,7 +195,7 @@ TEST_P(RawPacketTest, Receive) { // sizeof(sockaddr_ll). ASSERT_THAT(src_len, AnyOf(Eq(sizeof(src)), Eq(sizeof(src) - 2))); - // TODO(b/129292371): Verify protocol once we return it. + // TODO(gvisor.dev/issue/173): Verify protocol once we return it. // Verify the source address. EXPECT_EQ(src.sll_family, AF_PACKET); EXPECT_EQ(src.sll_ifindex, GetLoopbackIndex()); @@ -240,7 +240,7 @@ TEST_P(RawPacketTest, Receive) { // Send via a packet socket. TEST_P(RawPacketTest, Send) { - // TODO(b/129292371): Remove once we support packet socket writing. + // TODO(gvisor.dev/issue/173): Remove once we support packet socket writing. SKIP_IF(IsRunningOnGvisor()); // Let's send a UDP packet and receive it using a regular UDP socket. -- cgit v1.2.3 From b00858d075fbc39c6fb2db9f903ac14f197566bd Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Thu, 30 Jul 2020 12:48:18 -0700 Subject: Use brodcast MAC for broadcast IPv4 packets When sending packets to a known network's broadcast address, use the broadcast MAC address. Test: - stack_test.TestOutgoingSubnetBroadcast - udp_test.TestOutgoingSubnetBroadcast PiperOrigin-RevId: 324062407 --- pkg/tcpip/stack/BUILD | 1 + pkg/tcpip/stack/route.go | 10 ++ pkg/tcpip/stack/stack.go | 15 ++- pkg/tcpip/stack/stack_test.go | 223 ++++++++++++++++++++++++++++++++++++ pkg/tcpip/tcpip.go | 26 +++++ pkg/tcpip/transport/udp/endpoint.go | 8 +- pkg/tcpip/transport/udp/udp_test.go | 189 ++++++++++++++++++++++++++++++ 7 files changed, 464 insertions(+), 8 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 6b9a6b316..2b4455318 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -110,6 +110,7 @@ go_test( "//pkg/tcpip/transport/udp", "//pkg/waiter", "@com_github_google_go_cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", ], ) diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index d65f8049e..91e0110f1 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -48,6 +48,10 @@ type Route struct { // Loop controls where WritePacket should send packets. Loop PacketLooping + + // directedBroadcast indicates whether this route is sending a directed + // broadcast packet. + directedBroadcast bool } // makeRoute initializes a new route. It takes ownership of the provided @@ -275,6 +279,12 @@ func (r *Route) Stack() *Stack { return r.ref.stack() } +// IsBroadcast returns true if the route is to send a broadcast packet. +func (r *Route) IsBroadcast() bool { + // Only IPv4 has a notion of broadcast. + return r.directedBroadcast || r.RemoteAddress == header.IPv4Broadcast +} + // ReverseRoute returns new route with given source and destination address. func (r *Route) ReverseRoute(src tcpip.Address, dst tcpip.Address) Route { return Route{ diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index a6faa22c2..36626e051 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1284,9 +1284,9 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n s.mu.RLock() defer s.mu.RUnlock() - isBroadcast := remoteAddr == header.IPv4Broadcast + isLocalBroadcast := remoteAddr == header.IPv4Broadcast isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr) - needRoute := !(isBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr)) + needRoute := !(isLocalBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr)) if id != 0 && !needRoute { if nic, ok := s.nics[id]; ok && nic.enabled() { if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil { @@ -1307,9 +1307,16 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n } r := makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()) - if needRoute { - r.NextHop = route.Gateway + r.directedBroadcast = route.Destination.IsBroadcast(remoteAddr) + + if len(route.Gateway) > 0 { + if needRoute { + r.NextHop = route.Gateway + } + } else if r.directedBroadcast { + r.RemoteLinkAddress = header.EthernetBroadcastAddress } + return r, nil } } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 101ca2206..f22062889 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -27,6 +27,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -3418,3 +3419,225 @@ func TestStackSendBufferSizeOption(t *testing.T) { }) } } + +func TestOutgoingSubnetBroadcast(t *testing.T) { + const ( + unspecifiedNICID = 0 + nicID1 = 1 + ) + + defaultAddr := tcpip.AddressWithPrefix{ + Address: header.IPv4Any, + PrefixLen: 0, + } + defaultSubnet := defaultAddr.Subnet() + ipv4Addr := tcpip.AddressWithPrefix{ + Address: "\xc0\xa8\x01\x3a", + PrefixLen: 24, + } + ipv4Subnet := ipv4Addr.Subnet() + ipv4SubnetBcast := ipv4Subnet.Broadcast() + ipv4Gateway := tcpip.Address("\xc0\xa8\x01\x01") + ipv4AddrPrefix31 := tcpip.AddressWithPrefix{ + Address: "\xc0\xa8\x01\x3a", + PrefixLen: 31, + } + ipv4Subnet31 := ipv4AddrPrefix31.Subnet() + ipv4Subnet31Bcast := ipv4Subnet31.Broadcast() + ipv4AddrPrefix32 := tcpip.AddressWithPrefix{ + Address: "\xc0\xa8\x01\x3a", + PrefixLen: 32, + } + ipv4Subnet32 := ipv4AddrPrefix32.Subnet() + ipv4Subnet32Bcast := ipv4Subnet32.Broadcast() + ipv6Addr := tcpip.AddressWithPrefix{ + Address: "\x20\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + PrefixLen: 64, + } + ipv6Subnet := ipv6Addr.Subnet() + ipv6SubnetBcast := ipv6Subnet.Broadcast() + remNetAddr := tcpip.AddressWithPrefix{ + Address: "\x64\x0a\x7b\x18", + PrefixLen: 24, + } + remNetSubnet := remNetAddr.Subnet() + remNetSubnetBcast := remNetSubnet.Broadcast() + + tests := []struct { + name string + nicAddr tcpip.ProtocolAddress + routes []tcpip.Route + remoteAddr tcpip.Address + expectedRoute stack.Route + }{ + // Broadcast to a locally attached subnet populates the broadcast MAC. + { + name: "IPv4 Broadcast to local subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4Addr, + }, + routes: []tcpip.Route{ + { + Destination: ipv4Subnet, + NIC: nicID1, + }, + }, + remoteAddr: ipv4SubnetBcast, + expectedRoute: stack.Route{ + LocalAddress: ipv4Addr.Address, + RemoteAddress: ipv4SubnetBcast, + RemoteLinkAddress: header.EthernetBroadcastAddress, + NetProto: header.IPv4ProtocolNumber, + Loop: stack.PacketOut, + }, + }, + // Broadcast to a locally attached /31 subnet does not populate the + // broadcast MAC. + { + name: "IPv4 Broadcast to local /31 subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4AddrPrefix31, + }, + routes: []tcpip.Route{ + { + Destination: ipv4Subnet31, + NIC: nicID1, + }, + }, + remoteAddr: ipv4Subnet31Bcast, + expectedRoute: stack.Route{ + LocalAddress: ipv4AddrPrefix31.Address, + RemoteAddress: ipv4Subnet31Bcast, + NetProto: header.IPv4ProtocolNumber, + Loop: stack.PacketOut, + }, + }, + // Broadcast to a locally attached /32 subnet does not populate the + // broadcast MAC. + { + name: "IPv4 Broadcast to local /32 subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4AddrPrefix32, + }, + routes: []tcpip.Route{ + { + Destination: ipv4Subnet32, + NIC: nicID1, + }, + }, + remoteAddr: ipv4Subnet32Bcast, + expectedRoute: stack.Route{ + LocalAddress: ipv4AddrPrefix32.Address, + RemoteAddress: ipv4Subnet32Bcast, + NetProto: header.IPv4ProtocolNumber, + Loop: stack.PacketOut, + }, + }, + // IPv6 has no notion of a broadcast. + { + name: "IPv6 'Broadcast' to local subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: ipv6Addr, + }, + routes: []tcpip.Route{ + { + Destination: ipv6Subnet, + NIC: nicID1, + }, + }, + remoteAddr: ipv6SubnetBcast, + expectedRoute: stack.Route{ + LocalAddress: ipv6Addr.Address, + RemoteAddress: ipv6SubnetBcast, + NetProto: header.IPv6ProtocolNumber, + Loop: stack.PacketOut, + }, + }, + // Broadcast to a remote subnet in the route table is send to the next-hop + // gateway. + { + name: "IPv4 Broadcast to remote subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4Addr, + }, + routes: []tcpip.Route{ + { + Destination: remNetSubnet, + Gateway: ipv4Gateway, + NIC: nicID1, + }, + }, + remoteAddr: remNetSubnetBcast, + expectedRoute: stack.Route{ + LocalAddress: ipv4Addr.Address, + RemoteAddress: remNetSubnetBcast, + NextHop: ipv4Gateway, + NetProto: header.IPv4ProtocolNumber, + Loop: stack.PacketOut, + }, + }, + // Broadcast to an unknown subnet follows the default route. Note that this + // is essentially just routing an unknown destination IP, because w/o any + // subnet prefix information a subnet broadcast address is just a normal IP. + { + name: "IPv4 Broadcast to unknown subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4Addr, + }, + routes: []tcpip.Route{ + { + Destination: defaultSubnet, + Gateway: ipv4Gateway, + NIC: nicID1, + }, + }, + remoteAddr: remNetSubnetBcast, + expectedRoute: stack.Route{ + LocalAddress: ipv4Addr.Address, + RemoteAddress: remNetSubnetBcast, + NextHop: ipv4Gateway, + NetProto: header.IPv4ProtocolNumber, + Loop: stack.PacketOut, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + }) + ep := channel.New(0, defaultMTU, "") + if err := s.CreateNIC(nicID1, ep); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } + if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err) + } + + s.SetRouteTable(test.routes) + + var netProto tcpip.NetworkProtocolNumber + switch l := len(test.remoteAddr); l { + case header.IPv4AddressSize: + netProto = header.IPv4ProtocolNumber + case header.IPv6AddressSize: + netProto = header.IPv6ProtocolNumber + default: + t.Fatalf("got unexpected address length = %d bytes", l) + } + + if r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, test.remoteAddr, netProto, false /* multicastLoop */); err != nil { + t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, test.remoteAddr, netProto, err) + } else if diff := cmp.Diff(r, test.expectedRoute, cmpopts.IgnoreUnexported(r)); diff != "" { + t.Errorf("route mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 21aafb0a2..a634b9b60 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -43,6 +43,9 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +// Using header.IPv4AddressSize would cause an import cycle. +const ipv4AddressSize = 4 + // Error represents an error in the netstack error space. Using a special type // ensures that errors outside of this space are not accidentally introduced. // @@ -320,6 +323,29 @@ func (s *Subnet) Broadcast() Address { return Address(addr) } +// IsBroadcast returns true if the address is considered a broadcast address. +func (s *Subnet) IsBroadcast(address Address) bool { + // Only IPv4 supports the notion of a broadcast address. + if len(address) != ipv4AddressSize { + return false + } + + // Normally, we would just compare address with the subnet's broadcast + // address but there is an exception where a simple comparison is not + // correct. This exception is for /31 and /32 IPv4 subnets where all + // addresses are considered valid host addresses. + // + // For /31 subnets, the case is easy. RFC 3021 Section 2.1 states that + // both addresses in a /31 subnet "MUST be interpreted as host addresses." + // + // For /32, the case is a bit more vague. RFC 3021 makes no mention of /32 + // subnets. However, the same reasoning applies - if an exception is not + // made, then there do not exist any host addresses in a /32 subnet. RFC + // 4632 Section 3.1 also vaguely implies this interpretation by referring + // to addresses in /32 subnets as "host routes." + return s.Prefix() <= 30 && s.Broadcast() == address +} + // Equal returns true if s equals o. // // Needed to use cmp.Equal on Subnet as its fields are unexported. diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 6e692da07..b7d735889 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -483,10 +483,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c nicID = e.BindNICID } - if to.Addr == header.IPv4Broadcast && !e.broadcast { - return 0, nil, tcpip.ErrBroadcastDisabled - } - dst, netProto, err := e.checkV4MappedLocked(*to) if err != nil { return 0, nil, err @@ -503,6 +499,10 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c resolve = route.Resolve } + if !e.broadcast && route.IsBroadcast() { + return 0, nil, tcpip.ErrBroadcastDisabled + } + if route.IsResolutionRequired() { if ch, err := resolve(nil); err != nil { if err == tcpip.ErrWouldBlock { diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 90781cf49..66e8911c8 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -2142,3 +2142,192 @@ func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEn c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want) } } + +func TestOutgoingSubnetBroadcast(t *testing.T) { + const nicID1 = 1 + + ipv4Addr := tcpip.AddressWithPrefix{ + Address: "\xc0\xa8\x01\x3a", + PrefixLen: 24, + } + ipv4Subnet := ipv4Addr.Subnet() + ipv4SubnetBcast := ipv4Subnet.Broadcast() + ipv4Gateway := tcpip.Address("\xc0\xa8\x01\x01") + ipv4AddrPrefix31 := tcpip.AddressWithPrefix{ + Address: "\xc0\xa8\x01\x3a", + PrefixLen: 31, + } + ipv4Subnet31 := ipv4AddrPrefix31.Subnet() + ipv4Subnet31Bcast := ipv4Subnet31.Broadcast() + ipv4AddrPrefix32 := tcpip.AddressWithPrefix{ + Address: "\xc0\xa8\x01\x3a", + PrefixLen: 32, + } + ipv4Subnet32 := ipv4AddrPrefix32.Subnet() + ipv4Subnet32Bcast := ipv4Subnet32.Broadcast() + ipv6Addr := tcpip.AddressWithPrefix{ + Address: "\x20\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + PrefixLen: 64, + } + ipv6Subnet := ipv6Addr.Subnet() + ipv6SubnetBcast := ipv6Subnet.Broadcast() + remNetAddr := tcpip.AddressWithPrefix{ + Address: "\x64\x0a\x7b\x18", + PrefixLen: 24, + } + remNetSubnet := remNetAddr.Subnet() + remNetSubnetBcast := remNetSubnet.Broadcast() + + tests := []struct { + name string + nicAddr tcpip.ProtocolAddress + routes []tcpip.Route + remoteAddr tcpip.Address + requiresBroadcastOpt bool + }{ + { + name: "IPv4 Broadcast to local subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4Addr, + }, + routes: []tcpip.Route{ + { + Destination: ipv4Subnet, + NIC: nicID1, + }, + }, + remoteAddr: ipv4SubnetBcast, + requiresBroadcastOpt: true, + }, + { + name: "IPv4 Broadcast to local /31 subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4AddrPrefix31, + }, + routes: []tcpip.Route{ + { + Destination: ipv4Subnet31, + NIC: nicID1, + }, + }, + remoteAddr: ipv4Subnet31Bcast, + requiresBroadcastOpt: false, + }, + { + name: "IPv4 Broadcast to local /32 subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4AddrPrefix32, + }, + routes: []tcpip.Route{ + { + Destination: ipv4Subnet32, + NIC: nicID1, + }, + }, + remoteAddr: ipv4Subnet32Bcast, + requiresBroadcastOpt: false, + }, + // IPv6 has no notion of a broadcast. + { + name: "IPv6 'Broadcast' to local subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: ipv6Addr, + }, + routes: []tcpip.Route{ + { + Destination: ipv6Subnet, + NIC: nicID1, + }, + }, + remoteAddr: ipv6SubnetBcast, + requiresBroadcastOpt: false, + }, + { + name: "IPv4 Broadcast to remote subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4Addr, + }, + routes: []tcpip.Route{ + { + Destination: remNetSubnet, + Gateway: ipv4Gateway, + NIC: nicID1, + }, + }, + remoteAddr: remNetSubnetBcast, + requiresBroadcastOpt: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + }) + e := channel.New(0, defaultMTU, "") + if err := s.CreateNIC(nicID1, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } + if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err) + } + + s.SetRouteTable(test.routes) + + var netProto tcpip.NetworkProtocolNumber + switch l := len(test.remoteAddr); l { + case header.IPv4AddressSize: + netProto = header.IPv4ProtocolNumber + case header.IPv6AddressSize: + netProto = header.IPv6ProtocolNumber + default: + t.Fatalf("got unexpected address length = %d bytes", l) + } + + wq := waiter.Queue{} + ep, err := s.NewEndpoint(udp.ProtocolNumber, netProto, &wq) + if err != nil { + t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, netProto, err) + } + defer ep.Close() + + data := tcpip.SlicePayload([]byte{1, 2, 3, 4}) + to := tcpip.FullAddress{ + Addr: test.remoteAddr, + Port: 80, + } + opts := tcpip.WriteOptions{To: &to} + expectedErrWithoutBcastOpt := tcpip.ErrBroadcastDisabled + if !test.requiresBroadcastOpt { + expectedErrWithoutBcastOpt = nil + } + + if n, _, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt { + t.Fatalf("got ep.Write(_, _) = (%d, _, %v), want = (_, _, %v)", n, err, expectedErrWithoutBcastOpt) + } + + if err := ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil { + t.Fatalf("got SetSockOptBool(BroadcastOption, true): %s", err) + } + + if n, _, err := ep.Write(data, opts); err != nil { + t.Fatalf("got ep.Write(_, _) = (%d, _, %s), want = (_, _, nil)", n, err) + } + + if err := ep.SetSockOptBool(tcpip.BroadcastOption, false); err != nil { + t.Fatalf("got SetSockOptBool(BroadcastOption, false): %s", err) + } + + if n, _, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt { + t.Fatalf("got ep.Write(_, _) = (%d, _, %v), want = (_, _, %v)", n, err, expectedErrWithoutBcastOpt) + } + }) + } +} -- cgit v1.2.3 From ab4bb3845579151314003e29aa4218f2290422ca Mon Sep 17 00:00:00 2001 From: Sam Balana Date: Thu, 30 Jul 2020 13:28:21 -0700 Subject: Implement neighbor unreachability detection for ARP and NDP. This change implements the Neighbor Unreachability Detection (NUD) state machine, as per RFC 4861 [1]. The state machine operates on a single neighbor in the local network. This requires the state machine to be implemented on each entry of the neighbor table. This change also adds, but does not expose, several APIs. The first API is for performing basic operations on the neighbor table: - Create a static entry - List all entries - Delete all entries - Remove an entry by address The second API is used for changing the NUD protocol constants on a per-NIC basis to allow Neighbor Discovery to operate over links with widely varying performance characteristics. See [RFC 4861 Section 10][2] for the list of constants. Finally, the last API is for allowing users to subscribe to NUD state changes. See [RFC 4861 Appendix C][3] for the list of edges. [1]: https://tools.ietf.org/html/rfc4861 [2]: https://tools.ietf.org/html/rfc4861#section-10 [3]: https://tools.ietf.org/html/rfc4861#appendix-C Tests: pkg/tcpip/stack:stack_test - TestNeighborCacheAddStaticEntryThenOverflow - TestNeighborCacheClear - TestNeighborCacheClearThenOverflow - TestNeighborCacheConcurrent - TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress - TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress - TestNeighborCacheEntry - TestNeighborCacheEntryNoLinkAddress - TestNeighborCacheGetConfig - TestNeighborCacheKeepFrequentlyUsed - TestNeighborCacheNotifiesWaker - TestNeighborCacheOverflow - TestNeighborCacheOverwriteWithStaticEntryThenOverflow - TestNeighborCacheRemoveEntry - TestNeighborCacheRemoveEntryThenOverflow - TestNeighborCacheRemoveStaticEntry - TestNeighborCacheRemoveStaticEntryThenOverflow - TestNeighborCacheRemoveWaker - TestNeighborCacheReplace - TestNeighborCacheResolutionFailed - TestNeighborCacheResolutionTimeout - TestNeighborCacheSetConfig - TestNeighborCacheStaticResolution - TestEntryAddsAndClearsWakers - TestEntryDelayToProbe - TestEntryDelayToReachableWhenSolicitedOverrideConfirmation - TestEntryDelayToReachableWhenUpperLevelConfirmation - TestEntryDelayToStaleWhenConfirmationWithDifferentAddress - TestEntryDelayToStaleWhenProbeWithDifferentAddress - TestEntryFailedGetsDeleted - TestEntryIncompleteToFailed - TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt - TestEntryIncompleteToReachable - TestEntryIncompleteToReachableWithRouterFlag - TestEntryIncompleteToStale - TestEntryInitiallyUnknown - TestEntryProbeToFailed - TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress - TestEntryProbeToReachableWhenSolicitedOverrideConfirmation - TestEntryProbeToStaleWhenConfirmationWithDifferentAddress - TestEntryProbeToStaleWhenProbeWithDifferentAddress - TestEntryReachableToStaleWhenConfirmationWithDifferentAddress - TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride - TestEntryReachableToStaleWhenProbeWithDifferentAddress - TestEntryReachableToStaleWhenTimeout - TestEntryStaleToDelay - TestEntryStaleToReachableWhenSolicitedOverrideConfirmation - TestEntryStaleToStaleWhenOverrideConfirmation - TestEntryStaleToStaleWhenProbeUpdateAddress - TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress - TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress - TestEntryStaysReachableWhenConfirmationWithRouterFlag - TestEntryStaysReachableWhenProbeWithSameAddress - TestEntryStaysStaleWhenProbeWithSameAddress - TestEntryUnknownToIncomplete - TestEntryUnknownToStale - TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress pkg/tcpip/stack:stack_x_test - TestDefaultNUDConfigurations - TestNUDConfigurationFailsForNotSupported - TestNUDConfigurationsBaseReachableTime - TestNUDConfigurationsDelayFirstProbeTime - TestNUDConfigurationsMaxMulticastProbes - TestNUDConfigurationsMaxRandomFactor - TestNUDConfigurationsMaxUnicastProbes - TestNUDConfigurationsMinRandomFactor - TestNUDConfigurationsRetransmitTimer - TestNUDConfigurationsUnreachableTime - TestNUDStateReachableTime - TestNUDStateRecomputeReachableTime - TestSetNUDConfigurationFailsForBadNICID - TestSetNUDConfigurationFailsForNotSupported [1]: https://tools.ietf.org/html/rfc4861 [2]: https://tools.ietf.org/html/rfc4861#section-10 [3]: https://tools.ietf.org/html/rfc4861#appendix-C Updates #1889 Updates #1894 Updates #1895 Updates #1947 Updates #1948 Updates #1949 Updates #1950 PiperOrigin-RevId: 324070795 --- WORKSPACE | 7 + go.mod | 2 +- go.sum | 3 + pkg/tcpip/stack/BUILD | 24 + pkg/tcpip/stack/fake_time_test.go | 209 +++ pkg/tcpip/stack/ndp.go | 16 - pkg/tcpip/stack/neighbor_cache.go | 335 ++++ pkg/tcpip/stack/neighbor_cache_test.go | 1752 +++++++++++++++++++ pkg/tcpip/stack/neighbor_entry.go | 482 ++++++ pkg/tcpip/stack/neighbor_entry_test.go | 2770 +++++++++++++++++++++++++++++++ pkg/tcpip/stack/neighborstate_string.go | 44 + pkg/tcpip/stack/nic.go | 33 + pkg/tcpip/stack/nud.go | 466 ++++++ pkg/tcpip/stack/nud_test.go | 795 +++++++++ pkg/tcpip/stack/stack.go | 48 +- 15 files changed, 6968 insertions(+), 18 deletions(-) create mode 100644 pkg/tcpip/stack/fake_time_test.go create mode 100644 pkg/tcpip/stack/neighbor_cache.go create mode 100644 pkg/tcpip/stack/neighbor_cache_test.go create mode 100644 pkg/tcpip/stack/neighbor_entry.go create mode 100644 pkg/tcpip/stack/neighbor_entry_test.go create mode 100644 pkg/tcpip/stack/neighborstate_string.go create mode 100644 pkg/tcpip/stack/nud.go create mode 100644 pkg/tcpip/stack/nud_test.go (limited to 'pkg/tcpip/stack') diff --git a/WORKSPACE b/WORKSPACE index 49f231755..058d1b306 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -1089,3 +1089,10 @@ go_repository( sum = "h1:iVjPR7a6H0tWELX5NxNe7bYopibicUzc7uPribsnS6o=", version = "v1.0.0", ) + +go_repository( + name = "com_github_dpjacques_clockwork", + importpath = "github.com/dpjacques/clockwork", + sum = "h1:7krODee+eIlZYoLiEDmP1kLFNCvd0bQ0eEXOympdN6U=", + version = "v0.1.1-0.20190114191937-d864eecc357b", +) diff --git a/go.mod b/go.mod index 0875b4ba0..2fcba5cc9 100644 --- a/go.mod +++ b/go.mod @@ -22,6 +22,7 @@ require ( github.com/docker/go-connections v0.3.0 // indirect github.com/docker/go-events v0.0.0-20190806004212-e31b211e4f1c // indirect github.com/docker/go-units v0.4.0 // indirect + github.com/dpjacques/clockwork v0.1.1-0.20190114191937-d864eecc357b // indirect github.com/godbus/dbus v0.0.0-20190422162347-ade71ed3457e // indirect github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079 // indirect github.com/gogo/googleapis v1.4.0 // indirect @@ -43,7 +44,6 @@ require ( github.com/vishvananda/netns v0.0.0-20200520041808-52d707b772fe // indirect go.uber.org/atomic v1.6.0 // indirect go.uber.org/multierr v1.2.0 // indirect - golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527 // indirect golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect golang.org/x/tools v0.0.0-20200707200213-416e8f4faf8a // indirect google.golang.org/grpc v1.29.0 // indirect diff --git a/go.sum b/go.sum index a90bca394..f98132971 100644 --- a/go.sum +++ b/go.sum @@ -74,6 +74,8 @@ github.com/docker/go-events v0.0.0-20190806004212-e31b211e4f1c h1:+pKlWGMw7gf6bQ github.com/docker/go-events v0.0.0-20190806004212-e31b211e4f1c/go.mod h1:Uw6UezgYA44ePAFQYUehOuCzmy5zmg/+nl2ZfMWGkpA= github.com/docker/go-units v0.4.0 h1:3uh0PgVws3nIA0Q+MwDC8yjEPf9zjRfZZWXZYDct3Tw= github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= +github.com/dpjacques/clockwork v0.1.1-0.20190114191937-d864eecc357b h1:7krODee+eIlZYoLiEDmP1kLFNCvd0bQ0eEXOympdN6U= +github.com/dpjacques/clockwork v0.1.1-0.20190114191937-d864eecc357b/go.mod h1:D8mP2A8vVT2GkXqPorSBmhnshhkFBYgzhA90KmJt25Y= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= @@ -288,6 +290,7 @@ golang.org/x/sys v0.0.0-20191210023423-ac6580df4449/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200113162924-86b910548bc1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200120151820-655fe14d7479/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200523222454-059865788121 h1:rITEj+UZHYC927n8GT97eC3zrpzXdb/voyeOuVKS46o= golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 2b4455318..1c58bed2d 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -15,6 +15,18 @@ go_template_instance( }, ) +go_template_instance( + name = "neighbor_entry_list", + out = "neighbor_entry_list.go", + package = "stack", + prefix = "neighborEntry", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*neighborEntry", + "Linker": "*neighborEntry", + }, +) + go_template_instance( name = "packet_buffer_list", out = "packet_buffer_list.go", @@ -53,7 +65,12 @@ go_library( "linkaddrcache.go", "linkaddrentry_list.go", "ndp.go", + "neighbor_cache.go", + "neighbor_entry.go", + "neighbor_entry_list.go", + "neighborstate_string.go", "nic.go", + "nud.go", "packet_buffer.go", "packet_buffer_list.go", "rand.go", @@ -89,6 +106,7 @@ go_test( size = "medium", srcs = [ "ndp_test.go", + "nud_test.go", "stack_test.go", "transport_demuxer_test.go", "transport_test.go", @@ -118,8 +136,11 @@ go_test( name = "stack_test", size = "small", srcs = [ + "fake_time_test.go", "forwarder_test.go", "linkaddrcache_test.go", + "neighbor_cache_test.go", + "neighbor_entry_test.go", "nic_test.go", ], library = ":stack", @@ -129,5 +150,8 @@ go_test( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "@com_github_dpjacques_clockwork//:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", ], ) diff --git a/pkg/tcpip/stack/fake_time_test.go b/pkg/tcpip/stack/fake_time_test.go new file mode 100644 index 000000000..92c8cb534 --- /dev/null +++ b/pkg/tcpip/stack/fake_time_test.go @@ -0,0 +1,209 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "container/heap" + "sync" + "time" + + "github.com/dpjacques/clockwork" + "gvisor.dev/gvisor/pkg/tcpip" +) + +type fakeClock struct { + clock clockwork.FakeClock + + // mu protects the fields below. + mu sync.RWMutex + + // times is min-heap of times. A heap is used for quick retrieval of the next + // upcoming time of scheduled work. + times *timeHeap + + // waitGroups stores one WaitGroup for all work scheduled to execute at the + // same time via AfterFunc. This allows parallel execution of all functions + // passed to AfterFunc scheduled for the same time. + waitGroups map[time.Time]*sync.WaitGroup +} + +func newFakeClock() *fakeClock { + return &fakeClock{ + clock: clockwork.NewFakeClock(), + times: &timeHeap{}, + waitGroups: make(map[time.Time]*sync.WaitGroup), + } +} + +var _ tcpip.Clock = (*fakeClock)(nil) + +// NowNanoseconds implements tcpip.Clock.NowNanoseconds. +func (fc *fakeClock) NowNanoseconds() int64 { + return fc.clock.Now().UnixNano() +} + +// NowMonotonic implements tcpip.Clock.NowMonotonic. +func (fc *fakeClock) NowMonotonic() int64 { + return fc.NowNanoseconds() +} + +// AfterFunc implements tcpip.Clock.AfterFunc. +func (fc *fakeClock) AfterFunc(d time.Duration, f func()) tcpip.Timer { + until := fc.clock.Now().Add(d) + wg := fc.addWait(until) + return &fakeTimer{ + clock: fc, + until: until, + timer: fc.clock.AfterFunc(d, func() { + defer wg.Done() + f() + }), + } +} + +// addWait adds an additional wait to the WaitGroup for parallel execution of +// all work scheduled for t. Returns a reference to the WaitGroup modified. +func (fc *fakeClock) addWait(t time.Time) *sync.WaitGroup { + fc.mu.RLock() + wg, ok := fc.waitGroups[t] + fc.mu.RUnlock() + + if ok { + wg.Add(1) + return wg + } + + fc.mu.Lock() + heap.Push(fc.times, t) + fc.mu.Unlock() + + wg = &sync.WaitGroup{} + wg.Add(1) + + fc.mu.Lock() + fc.waitGroups[t] = wg + fc.mu.Unlock() + + return wg +} + +// removeWait removes a wait from the WaitGroup for parallel execution of all +// work scheduled for t. +func (fc *fakeClock) removeWait(t time.Time) { + fc.mu.RLock() + defer fc.mu.RUnlock() + + wg := fc.waitGroups[t] + wg.Done() +} + +// advance executes all work that have been scheduled to execute within d from +// the current fake time. Blocks until all work has completed execution. +func (fc *fakeClock) advance(d time.Duration) { + // Block until all the work is done + until := fc.clock.Now().Add(d) + for { + fc.mu.Lock() + if fc.times.Len() == 0 { + fc.mu.Unlock() + return + } + + t := heap.Pop(fc.times).(time.Time) + if t.After(until) { + // No work to do + heap.Push(fc.times, t) + fc.mu.Unlock() + return + } + fc.mu.Unlock() + + diff := t.Sub(fc.clock.Now()) + fc.clock.Advance(diff) + + fc.mu.RLock() + wg := fc.waitGroups[t] + fc.mu.RUnlock() + + wg.Wait() + + fc.mu.Lock() + delete(fc.waitGroups, t) + fc.mu.Unlock() + } +} + +type fakeTimer struct { + clock *fakeClock + timer clockwork.Timer + + mu sync.RWMutex + until time.Time +} + +var _ tcpip.Timer = (*fakeTimer)(nil) + +// Reset implements tcpip.Timer.Reset. +func (ft *fakeTimer) Reset(d time.Duration) { + if !ft.timer.Reset(d) { + return + } + + ft.mu.Lock() + defer ft.mu.Unlock() + + ft.clock.removeWait(ft.until) + ft.until = ft.clock.clock.Now().Add(d) + ft.clock.addWait(ft.until) +} + +// Stop implements tcpip.Timer.Stop. +func (ft *fakeTimer) Stop() bool { + if !ft.timer.Stop() { + return false + } + + ft.mu.RLock() + defer ft.mu.RUnlock() + + ft.clock.removeWait(ft.until) + return true +} + +type timeHeap []time.Time + +var _ heap.Interface = (*timeHeap)(nil) + +func (h timeHeap) Len() int { + return len(h) +} + +func (h timeHeap) Less(i, j int) bool { + return h[i].Before(h[j]) +} + +func (h timeHeap) Swap(i, j int) { + h[i], h[j] = h[j], h[i] +} + +func (h *timeHeap) Push(x interface{}) { + *h = append(*h, x.(time.Time)) +} + +func (h *timeHeap) Pop() interface{} { + last := (*h)[len(*h)-1] + *h = (*h)[:len(*h)-1] + return last +} diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 9dce11a97..5174e639c 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -33,12 +33,6 @@ const ( // Default = 1 (from RFC 4862 section 5.1) defaultDupAddrDetectTransmits = 1 - // defaultRetransmitTimer is the default amount of time to wait between - // sending NDP Neighbor solicitation messages. - // - // Default = 1s (from RFC 4861 section 10). - defaultRetransmitTimer = time.Second - // defaultMaxRtrSolicitations is the default number of Router // Solicitation messages to send when a NIC becomes enabled. // @@ -79,16 +73,6 @@ const ( // Default = true. defaultAutoGenGlobalAddresses = true - // minimumRetransmitTimer is the minimum amount of time to wait between - // sending NDP Neighbor solicitation messages. Note, RFC 4861 does - // not impose a minimum Retransmit Timer, but we do here to make sure - // the messages are not sent all at once. We also come to this value - // because in the RetransmitTimer field of a Router Advertisement, a - // value of 0 means unspecified, so the smallest valid value is 1. - // Note, the unit of the RetransmitTimer field in the Router - // Advertisement is milliseconds. - minimumRetransmitTimer = time.Millisecond - // minimumRtrSolicitationInterval is the minimum amount of time to wait // between sending Router Solicitation messages. This limit is imposed // to make sure that Router Solicitation messages are not sent all at diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go new file mode 100644 index 000000000..1d37716c2 --- /dev/null +++ b/pkg/tcpip/stack/neighbor_cache.go @@ -0,0 +1,335 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "fmt" + "time" + + "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" +) + +const neighborCacheSize = 512 // max entries per interface + +// neighborCache maps IP addresses to link addresses. It uses the Least +// Recently Used (LRU) eviction strategy to implement a bounded cache for +// dynmically acquired entries. It contains the state machine and configuration +// for running Neighbor Unreachability Detection (NUD). +// +// There are two types of entries in the neighbor cache: +// 1. Dynamic entries are discovered automatically by neighbor discovery +// protocols (e.g. ARP, NDP). These protocols will attempt to reconfirm +// reachability with the device once the entry's state becomes Stale. +// 2. Static entries are explicitly added by a user and have no expiration. +// Their state is always Static. The amount of static entries stored in the +// cache is unbounded. +// +// neighborCache implements NUDHandler. +type neighborCache struct { + nic *NIC + state *NUDState + + // mu protects the fields below. + mu sync.RWMutex + + cache map[tcpip.Address]*neighborEntry + dynamic struct { + lru neighborEntryList + + // count tracks the amount of dynamic entries in the cache. This is + // needed since static entries do not count towards the LRU cache + // eviction strategy. + count uint16 + } +} + +var _ NUDHandler = (*neighborCache)(nil) + +// getOrCreateEntry retrieves a cache entry associated with addr. The +// returned entry is always refreshed in the cache (it is reachable via the +// map, and its place is bumped in LRU). +// +// If a matching entry exists in the cache, it is returned. If no matching +// entry exists and the cache is full, an existing entry is evicted via LRU, +// reset to state incomplete, and returned. If no matching entry exists and the +// cache is not full, a new entry with state incomplete is allocated and +// returned. +func (n *neighborCache) getOrCreateEntry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver) *neighborEntry { + n.mu.Lock() + defer n.mu.Unlock() + + if entry, ok := n.cache[remoteAddr]; ok { + entry.mu.RLock() + if entry.neigh.State != Static { + n.dynamic.lru.Remove(entry) + n.dynamic.lru.PushFront(entry) + } + entry.mu.RUnlock() + return entry + } + + // The entry that needs to be created must be dynamic since all static + // entries are directly added to the cache via addStaticEntry. + entry := newNeighborEntry(n.nic, remoteAddr, localAddr, n.state, linkRes) + if n.dynamic.count == neighborCacheSize { + e := n.dynamic.lru.Back() + e.mu.Lock() + + delete(n.cache, e.neigh.Addr) + n.dynamic.lru.Remove(e) + n.dynamic.count-- + + e.dispatchRemoveEventLocked() + e.setStateLocked(Unknown) + e.notifyWakersLocked() + e.mu.Unlock() + } + n.cache[remoteAddr] = entry + n.dynamic.lru.PushFront(entry) + n.dynamic.count++ + return entry +} + +// entry looks up the neighbor cache for translating address to link address +// (e.g. IP -> MAC). If the LinkEndpoint requests address resolution and there +// is a LinkAddressResolver registered with the network protocol, the cache +// attempts to resolve the address and returns ErrWouldBlock. If a Waker is +// provided, it will be notified when address resolution is complete (success +// or not). +// +// If address resolution is required, ErrNoLinkAddress and a notification +// channel is returned for the top level caller to block. Channel is closed +// once address resolution is complete (success or not). +func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, w *sleep.Waker) (NeighborEntry, <-chan struct{}, *tcpip.Error) { + if linkRes != nil { + if linkAddr, ok := linkRes.ResolveStaticAddress(remoteAddr); ok { + e := NeighborEntry{ + Addr: remoteAddr, + LocalAddr: localAddr, + LinkAddr: linkAddr, + State: Static, + UpdatedAt: time.Now(), + } + return e, nil, nil + } + } + + entry := n.getOrCreateEntry(remoteAddr, localAddr, linkRes) + entry.mu.Lock() + defer entry.mu.Unlock() + + switch s := entry.neigh.State; s { + case Reachable, Static: + return entry.neigh, nil, nil + + case Unknown, Incomplete, Stale, Delay, Probe: + entry.addWakerLocked(w) + + if entry.done == nil { + // Address resolution needs to be initiated. + if linkRes == nil { + return entry.neigh, nil, tcpip.ErrNoLinkAddress + } + entry.done = make(chan struct{}) + } + + entry.handlePacketQueuedLocked() + return entry.neigh, entry.done, tcpip.ErrWouldBlock + + case Failed: + return entry.neigh, nil, tcpip.ErrNoLinkAddress + + default: + panic(fmt.Sprintf("Invalid cache entry state: %s", s)) + } +} + +// removeWaker removes a waker that has been added when link resolution for +// addr was requested. +func (n *neighborCache) removeWaker(addr tcpip.Address, waker *sleep.Waker) { + n.mu.Lock() + if entry, ok := n.cache[addr]; ok { + delete(entry.wakers, waker) + } + n.mu.Unlock() +} + +// entries returns all entries in the neighbor cache. +func (n *neighborCache) entries() []NeighborEntry { + entries := make([]NeighborEntry, 0, len(n.cache)) + n.mu.RLock() + for _, entry := range n.cache { + entry.mu.RLock() + entries = append(entries, entry.neigh) + entry.mu.RUnlock() + } + n.mu.RUnlock() + return entries +} + +// addStaticEntry adds a static entry to the neighbor cache, mapping an IP +// address to a link address. If a dynamic entry exists in the neighbor cache +// with the same address, it will be replaced with this static entry. If a +// static entry exists with the same address but different link address, it +// will be updated with the new link address. If a static entry exists with the +// same address and link address, nothing will happen. +func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAddress) { + n.mu.Lock() + defer n.mu.Unlock() + + if entry, ok := n.cache[addr]; ok { + entry.mu.Lock() + if entry.neigh.State != Static { + // Dynamic entry found with the same address. + n.dynamic.lru.Remove(entry) + n.dynamic.count-- + } else if entry.neigh.LinkAddr == linkAddr { + // Static entry found with the same address and link address. + entry.mu.Unlock() + return + } else { + // Static entry found with the same address but different link address. + entry.neigh.LinkAddr = linkAddr + entry.dispatchChangeEventLocked(entry.neigh.State) + entry.mu.Unlock() + return + } + + // Notify that resolution has been interrupted, just in case the entry was + // in the Incomplete or Probe state. + entry.dispatchRemoveEventLocked() + entry.setStateLocked(Unknown) + entry.notifyWakersLocked() + entry.mu.Unlock() + } + + entry := newStaticNeighborEntry(n.nic, addr, linkAddr, n.state) + n.cache[addr] = entry +} + +// removeEntryLocked removes the specified entry from the neighbor cache. +func (n *neighborCache) removeEntryLocked(entry *neighborEntry) { + if entry.neigh.State != Static { + n.dynamic.lru.Remove(entry) + n.dynamic.count-- + } + if entry.neigh.State != Failed { + entry.dispatchRemoveEventLocked() + } + entry.setStateLocked(Unknown) + entry.notifyWakersLocked() + + delete(n.cache, entry.neigh.Addr) +} + +// removeEntry removes a dynamic or static entry by address from the neighbor +// cache. Returns true if the entry was found and deleted. +func (n *neighborCache) removeEntry(addr tcpip.Address) bool { + n.mu.Lock() + defer n.mu.Unlock() + + entry, ok := n.cache[addr] + if !ok { + return false + } + + entry.mu.Lock() + defer entry.mu.Unlock() + + n.removeEntryLocked(entry) + return true +} + +// clear removes all dynamic and static entries from the neighbor cache. +func (n *neighborCache) clear() { + n.mu.Lock() + defer n.mu.Unlock() + + for _, entry := range n.cache { + entry.mu.Lock() + entry.dispatchRemoveEventLocked() + entry.setStateLocked(Unknown) + entry.notifyWakersLocked() + entry.mu.Unlock() + } + + n.dynamic.lru = neighborEntryList{} + n.cache = make(map[tcpip.Address]*neighborEntry) + n.dynamic.count = 0 +} + +// config returns the NUD configuration. +func (n *neighborCache) config() NUDConfigurations { + return n.state.Config() +} + +// setConfig changes the NUD configuration. +// +// If config contains invalid NUD configuration values, it will be fixed to +// use default values for the erroneous values. +func (n *neighborCache) setConfig(config NUDConfigurations) { + config.resetInvalidFields() + n.state.SetConfig(config) +} + +// HandleProbe implements NUDHandler.HandleProbe by following the logic defined +// in RFC 4861 section 7.2.3. Validation of the probe is expected to be handled +// by the caller. +func (n *neighborCache) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress) { + entry := n.getOrCreateEntry(remoteAddr, localAddr, nil) + entry.mu.Lock() + entry.handleProbeLocked(remoteLinkAddr) + entry.mu.Unlock() +} + +// HandleConfirmation implements NUDHandler.HandleConfirmation by following the +// logic defined in RFC 4861 section 7.2.5. +// +// TODO(gvisor.dev/issue/2277): To protect against ARP poisoning and other +// attacks against NDP functions, Secure Neighbor Discovery (SEND) Protocol +// should be deployed where preventing access to the broadcast segment might +// not be possible. SEND uses RSA key pairs to produce cryptographically +// generated addresses, as defined in RFC 3972, Cryptographically Generated +// Addresses (CGA). This ensures that the claimed source of an NDP message is +// the owner of the claimed address. +func (n *neighborCache) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) { + n.mu.RLock() + entry, ok := n.cache[addr] + n.mu.RUnlock() + if ok { + entry.mu.Lock() + entry.handleConfirmationLocked(linkAddr, flags) + entry.mu.Unlock() + } + // The confirmation SHOULD be silently discarded if the recipient did not + // initiate any communication with the target. This is indicated if there is + // no matching entry for the remote address. +} + +// HandleUpperLevelConfirmation implements +// NUDHandler.HandleUpperLevelConfirmation by following the logic defined in +// RFC 4861 section 7.3.1. +func (n *neighborCache) HandleUpperLevelConfirmation(addr tcpip.Address) { + n.mu.RLock() + entry, ok := n.cache[addr] + n.mu.RUnlock() + if ok { + entry.mu.Lock() + entry.handleUpperLevelConfirmationLocked() + entry.mu.Unlock() + } +} diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go new file mode 100644 index 000000000..4cb2c9c6b --- /dev/null +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -0,0 +1,1752 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "bytes" + "encoding/binary" + "fmt" + "math" + "math/rand" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/tcpip" +) + +const ( + // entryStoreSize is the default number of entries that will be generated and + // added to the entry store. This number needs to be larger than the size of + // the neighbor cache to give ample opportunity for verifying behavior during + // cache overflows. Four times the size of the neighbor cache allows for + // three complete cache overflows. + entryStoreSize = 4 * neighborCacheSize + + // typicalLatency is the typical latency for an ARP or NDP packet to travel + // to a router and back. + typicalLatency = time.Millisecond + + // testEntryBroadcastAddr is a special address that indicates a packet should + // be sent to all nodes. + testEntryBroadcastAddr = tcpip.Address("broadcast") + + // testEntryLocalAddr is the source address of neighbor probes. + testEntryLocalAddr = tcpip.Address("local_addr") + + // testEntryBroadcastLinkAddr is a special link address sent back to + // multicast neighbor probes. + testEntryBroadcastLinkAddr = tcpip.LinkAddress("mac_broadcast") + + // infiniteDuration indicates that a task will not occur in our lifetime. + infiniteDuration = time.Duration(math.MaxInt64) +) + +// entryDiffOpts returns the options passed to cmp.Diff to compare neighbor +// entries. The UpdatedAt field is ignored due to a lack of a deterministic +// method to predict the time that an event will be dispatched. +func entryDiffOpts() []cmp.Option { + return []cmp.Option{ + cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAt"), + } +} + +// entryDiffOptsWithSort is like entryDiffOpts but also includes an option to +// sort slices of entries for cases where ordering must be ignored. +func entryDiffOptsWithSort() []cmp.Option { + return []cmp.Option{ + cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAt"), + cmpopts.SortSlices(func(a, b NeighborEntry) bool { + return strings.Compare(string(a.Addr), string(b.Addr)) < 0 + }), + } +} + +func newTestNeighborCache(nudDisp NUDDispatcher, config NUDConfigurations, clock tcpip.Clock) *neighborCache { + config.resetInvalidFields() + rng := rand.New(rand.NewSource(time.Now().UnixNano())) + return &neighborCache{ + nic: &NIC{ + stack: &Stack{ + clock: clock, + nudDisp: nudDisp, + }, + id: 1, + }, + state: NewNUDState(config, rng), + cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), + } +} + +// testEntryStore contains a set of IP to NeighborEntry mappings. +type testEntryStore struct { + mu sync.RWMutex + entriesMap map[tcpip.Address]NeighborEntry +} + +func toAddress(i int) tcpip.Address { + buf := new(bytes.Buffer) + binary.Write(buf, binary.BigEndian, uint8(1)) + binary.Write(buf, binary.BigEndian, uint8(0)) + binary.Write(buf, binary.BigEndian, uint16(i)) + return tcpip.Address(buf.String()) +} + +func toLinkAddress(i int) tcpip.LinkAddress { + buf := new(bytes.Buffer) + binary.Write(buf, binary.BigEndian, uint8(1)) + binary.Write(buf, binary.BigEndian, uint8(0)) + binary.Write(buf, binary.BigEndian, uint32(i)) + return tcpip.LinkAddress(buf.String()) +} + +// newTestEntryStore returns a testEntryStore pre-populated with entries. +func newTestEntryStore() *testEntryStore { + store := &testEntryStore{ + entriesMap: make(map[tcpip.Address]NeighborEntry), + } + for i := 0; i < entryStoreSize; i++ { + addr := toAddress(i) + linkAddr := toLinkAddress(i) + + store.entriesMap[addr] = NeighborEntry{ + Addr: addr, + LocalAddr: testEntryLocalAddr, + LinkAddr: linkAddr, + } + } + return store +} + +// size returns the number of entries in the store. +func (s *testEntryStore) size() int { + s.mu.RLock() + defer s.mu.RUnlock() + return len(s.entriesMap) +} + +// entry returns the entry at index i. Returns an empty entry and false if i is +// out of bounds. +func (s *testEntryStore) entry(i int) (NeighborEntry, bool) { + return s.entryByAddr(toAddress(i)) +} + +// entryByAddr returns the entry matching addr for situations when the index is +// not available. Returns an empty entry and false if no entries match addr. +func (s *testEntryStore) entryByAddr(addr tcpip.Address) (NeighborEntry, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + entry, ok := s.entriesMap[addr] + return entry, ok +} + +// entries returns all entries in the store. +func (s *testEntryStore) entries() []NeighborEntry { + entries := make([]NeighborEntry, 0, len(s.entriesMap)) + s.mu.RLock() + defer s.mu.RUnlock() + for i := 0; i < entryStoreSize; i++ { + addr := toAddress(i) + if entry, ok := s.entriesMap[addr]; ok { + entries = append(entries, entry) + } + } + return entries +} + +// set modifies the link addresses of an entry. +func (s *testEntryStore) set(i int, linkAddr tcpip.LinkAddress) { + addr := toAddress(i) + s.mu.Lock() + defer s.mu.Unlock() + if entry, ok := s.entriesMap[addr]; ok { + entry.LinkAddr = linkAddr + s.entriesMap[addr] = entry + } +} + +// testNeighborResolver implements LinkAddressResolver to emulate sending a +// neighbor probe. +type testNeighborResolver struct { + clock tcpip.Clock + neigh *neighborCache + entries *testEntryStore + delay time.Duration + onLinkAddressRequest func() +} + +var _ LinkAddressResolver = (*testNeighborResolver)(nil) + +func (r *testNeighborResolver) LinkAddressRequest(addr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error { + // Delay handling the request to emulate network latency. + r.clock.AfterFunc(r.delay, func() { + r.fakeRequest(addr) + }) + + // Execute post address resolution action, if available. + if f := r.onLinkAddressRequest; f != nil { + f() + } + return nil +} + +// fakeRequest emulates handling a response for a link address request. +func (r *testNeighborResolver) fakeRequest(addr tcpip.Address) { + if entry, ok := r.entries.entryByAddr(addr); ok { + r.neigh.HandleConfirmation(addr, entry.LinkAddr, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + } +} + +func (*testNeighborResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { + if addr == testEntryBroadcastAddr { + return testEntryBroadcastLinkAddr, true + } + return "", false +} + +func (*testNeighborResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { + return 0 +} + +type entryEvent struct { + nicID tcpip.NICID + address tcpip.Address + linkAddr tcpip.LinkAddress + state NeighborState +} + +func TestNeighborCacheGetConfig(t *testing.T) { + nudDisp := testNUDDispatcher{} + c := DefaultNUDConfigurations() + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, c, clock) + + if got, want := neigh.config(), c; got != want { + t.Errorf("got neigh.config() = %+v, want = %+v", got, want) + } + + // No events should have been dispatched. + nudDisp.mu.Lock() + defer nudDisp.mu.Unlock() + if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } +} + +func TestNeighborCacheSetConfig(t *testing.T) { + nudDisp := testNUDDispatcher{} + c := DefaultNUDConfigurations() + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, c, clock) + + c.MinRandomFactor = 1 + c.MaxRandomFactor = 1 + neigh.setConfig(c) + + if got, want := neigh.config(), c; got != want { + t.Errorf("got neigh.config() = %+v, want = %+v", got, want) + } + + // No events should have been dispatched. + nudDisp.mu.Lock() + defer nudDisp.mu.Unlock() + if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } +} + +func TestNeighborCacheEntry(t *testing.T) { + c := DefaultNUDConfigurations() + nudDisp := testNUDDispatcher{} + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, c, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + entry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + + clock.advance(typicalLatency) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + nudDisp.mu.Lock() + diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + nudDisp.events = nil + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != nil { + t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + } + + // No more events should have been dispatched. + nudDisp.mu.Lock() + defer nudDisp.mu.Unlock() + if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } +} + +// TestNeighborCacheEntryNoLinkAddress verifies calling entry() without a +// LinkAddressResolver returns ErrNoLinkAddress. +func TestNeighborCacheEntryNoLinkAddress(t *testing.T) { + nudDisp := testNUDDispatcher{} + c := DefaultNUDConfigurations() + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, c, clock) + store := newTestEntryStore() + + entry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, nil, nil) + if err != tcpip.ErrNoLinkAddress { + t.Errorf("got neigh.entry(%s, %s, nil, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress) + } + + // No events should have been dispatched. + nudDisp.mu.Lock() + defer nudDisp.mu.Unlock() + if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } +} + +func TestNeighborCacheRemoveEntry(t *testing.T) { + config := DefaultNUDConfigurations() + + nudDisp := testNUDDispatcher{} + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + entry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + + clock.advance(typicalLatency) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + nudDisp.mu.Lock() + diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + nudDisp.events = nil + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + neigh.removeEntry(entry.Addr) + + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestRemoved, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + nudDisp.mu.Lock() + diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } +} + +type testContext struct { + clock *fakeClock + neigh *neighborCache + store *testEntryStore + linkRes *testNeighborResolver + nudDisp *testNUDDispatcher +} + +func newTestContext(c NUDConfigurations) testContext { + nudDisp := &testNUDDispatcher{} + clock := newFakeClock() + neigh := newTestNeighborCache(nudDisp, c, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + return testContext{ + clock: clock, + neigh: neigh, + store: store, + linkRes: linkRes, + nudDisp: nudDisp, + } +} + +type overflowOptions struct { + startAtEntryIndex int + wantStaticEntries []NeighborEntry +} + +func (c *testContext) overflowCache(opts overflowOptions) error { + // Fill the neighbor cache to capacity to verify the LRU eviction strategy is + // working properly after the entry removal. + for i := opts.startAtEntryIndex; i < c.store.size(); i++ { + // Add a new entry + entry, ok := c.store.entry(i) + if !ok { + return fmt.Errorf("c.store.entry(%d) not found", i) + } + if _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil); err != tcpip.ErrWouldBlock { + return fmt.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + c.clock.advance(c.neigh.config().RetransmitTimer) + + var wantEvents []testEntryEventInfo + + // When beyond the full capacity, the cache will evict an entry as per the + // LRU eviction strategy. Note that the number of static entries should not + // affect the total number of dynamic entries that can be added. + if i >= neighborCacheSize+opts.startAtEntryIndex { + removedEntry, ok := c.store.entry(i - neighborCacheSize) + if !ok { + return fmt.Errorf("store.entry(%d) not found", i-neighborCacheSize) + } + wantEvents = append(wantEvents, testEntryEventInfo{ + EventType: entryTestRemoved, + NICID: 1, + Addr: removedEntry.Addr, + LinkAddr: removedEntry.LinkAddr, + State: Reachable, + }) + } + + wantEvents = append(wantEvents, testEntryEventInfo{ + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, testEntryEventInfo{ + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }) + + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + return fmt.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + // Expect to find only the most recent entries. The order of entries reported + // by entries() is undeterministic, so entries have to be sorted before + // comparison. + wantUnsortedEntries := opts.wantStaticEntries + for i := c.store.size() - neighborCacheSize; i < c.store.size(); i++ { + entry, ok := c.store.entry(i) + if !ok { + return fmt.Errorf("c.store.entry(%d) not found", i) + } + wantEntry := NeighborEntry{ + Addr: entry.Addr, + LocalAddr: entry.LocalAddr, + LinkAddr: entry.LinkAddr, + State: Reachable, + } + wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) + } + + if diff := cmp.Diff(c.neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" { + return fmt.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff) + } + + // No more events should have been dispatched. + c.nudDisp.mu.Lock() + defer c.nudDisp.mu.Unlock() + if diff := cmp.Diff(c.nudDisp.events, []testEntryEventInfo(nil)); diff != "" { + return fmt.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + return nil +} + +// TestNeighborCacheOverflow verifies that the LRU cache eviction strategy +// respects the dynamic entry count. +func TestNeighborCacheOverflow(t *testing.T) { + config := DefaultNUDConfigurations() + // Stay in Reachable so the cache can overflow + config.BaseReachableTime = infiniteDuration + config.MinRandomFactor = 1 + config.MaxRandomFactor = 1 + + c := newTestContext(config) + opts := overflowOptions{ + startAtEntryIndex: 0, + } + if err := c.overflowCache(opts); err != nil { + t.Errorf("c.overflowCache(%+v): %s", opts, err) + } +} + +// TestNeighborCacheRemoveEntryThenOverflow verifies that the LRU cache +// eviction strategy respects the dynamic entry count when an entry is removed. +func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { + config := DefaultNUDConfigurations() + // Stay in Reachable so the cache can overflow + config.BaseReachableTime = infiniteDuration + config.MinRandomFactor = 1 + config.MaxRandomFactor = 1 + + c := newTestContext(config) + + // Add a dynamic entry + entry, ok := c.store.entry(0) + if !ok { + t.Fatalf("c.store.entry(0) not found") + } + _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + c.clock.advance(c.neigh.config().RetransmitTimer) + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + // Remove the entry + c.neigh.removeEntry(entry.Addr) + + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestRemoved, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + opts := overflowOptions{ + startAtEntryIndex: 0, + } + if err := c.overflowCache(opts); err != nil { + t.Errorf("c.overflowCache(%+v): %s", opts, err) + } +} + +// TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress verifies that +// adding a duplicate static entry with the same link address does not dispatch +// any events. +func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) { + config := DefaultNUDConfigurations() + c := newTestContext(config) + + // Add a static entry + entry, ok := c.store.entry(0) + if !ok { + t.Fatalf("c.store.entry(0) not found") + } + staticLinkAddr := entry.LinkAddr + "static" + c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + // Remove the static entry that was just added + c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + + // No more events should have been dispatched. + c.nudDisp.mu.Lock() + defer c.nudDisp.mu.Unlock() + if diff := cmp.Diff(c.nudDisp.events, []testEntryEventInfo(nil)); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } +} + +// TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress verifies that +// adding a duplicate static entry with a different link address dispatches a +// change event. +func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) { + config := DefaultNUDConfigurations() + c := newTestContext(config) + + // Add a static entry + entry, ok := c.store.entry(0) + if !ok { + t.Fatalf("c.store.entry(0) not found") + } + staticLinkAddr := entry.LinkAddr + "static" + c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + // Add a duplicate entry with a different link address + staticLinkAddr += "duplicate" + c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, + } + c.nudDisp.mu.Lock() + defer c.nudDisp.mu.Unlock() + if diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } +} + +// TestNeighborCacheRemoveStaticEntryThenOverflow verifies that the LRU cache +// eviction strategy respects the dynamic entry count when a static entry is +// added then removed. In this case, the dynamic entry count shouldn't have +// been touched. +func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { + config := DefaultNUDConfigurations() + // Stay in Reachable so the cache can overflow + config.BaseReachableTime = infiniteDuration + config.MinRandomFactor = 1 + config.MaxRandomFactor = 1 + + c := newTestContext(config) + + // Add a static entry + entry, ok := c.store.entry(0) + if !ok { + t.Fatalf("c.store.entry(0) not found") + } + staticLinkAddr := entry.LinkAddr + "static" + c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + // Remove the static entry that was just added + c.neigh.removeEntry(entry.Addr) + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestRemoved, + NICID: 1, + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + opts := overflowOptions{ + startAtEntryIndex: 0, + } + if err := c.overflowCache(opts); err != nil { + t.Errorf("c.overflowCache(%+v): %s", opts, err) + } +} + +// TestNeighborCacheOverwriteWithStaticEntryThenOverflow verifies that the LRU +// cache eviction strategy keeps count of the dynamic entry count when an entry +// is overwritten by a static entry. Static entries should not count towards +// the size of the LRU cache. +func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { + config := DefaultNUDConfigurations() + // Stay in Reachable so the cache can overflow + config.BaseReachableTime = infiniteDuration + config.MinRandomFactor = 1 + config.MaxRandomFactor = 1 + + c := newTestContext(config) + + // Add a dynamic entry + entry, ok := c.store.entry(0) + if !ok { + t.Fatalf("c.store.entry(0) not found") + } + _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + c.clock.advance(typicalLatency) + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + // Override the entry with a static one using the same address + staticLinkAddr := entry.LinkAddr + "static" + c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestRemoved, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + opts := overflowOptions{ + startAtEntryIndex: 1, + wantStaticEntries: []NeighborEntry{ + { + Addr: entry.Addr, + LocalAddr: "", // static entries don't need a local address + LinkAddr: staticLinkAddr, + State: Static, + }, + }, + } + if err := c.overflowCache(opts); err != nil { + t.Errorf("c.overflowCache(%+v): %s", opts, err) + } +} + +func TestNeighborCacheNotifiesWaker(t *testing.T) { + config := DefaultNUDConfigurations() + + nudDisp := testNUDDispatcher{} + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + w := sleep.Waker{} + s := sleep.Sleeper{} + const wakerID = 1 + s.AddWaker(&w, wakerID) + + entry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, &w) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, %s, _, _ = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + if doneCh == nil { + t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr) + } + clock.advance(typicalLatency) + + select { + case <-doneCh: + default: + t.Fatal("expected notification from done channel") + } + + id, ok := s.Fetch(false /* block */) + if !ok { + t.Errorf("expected waker to be notified after neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr) + } + if id != wakerID { + t.Errorf("got s.Fetch(false) = %d, want = %d", id, wakerID) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + nudDisp.mu.Lock() + defer nudDisp.mu.Unlock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } +} + +func TestNeighborCacheRemoveWaker(t *testing.T) { + config := DefaultNUDConfigurations() + + nudDisp := testNUDDispatcher{} + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + w := sleep.Waker{} + s := sleep.Sleeper{} + const wakerID = 1 + s.AddWaker(&w, wakerID) + + entry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, &w) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, %s, _, _) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + if doneCh == nil { + t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr) + } + + // Remove the waker before the neighbor cache has the opportunity to send a + // notification. + neigh.removeWaker(entry.Addr, &w) + clock.advance(typicalLatency) + + select { + case <-doneCh: + default: + t.Fatal("expected notification from done channel") + } + + if id, ok := s.Fetch(false /* block */); ok { + t.Errorf("unexpected notification from waker with id %d", id) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + nudDisp.mu.Lock() + defer nudDisp.mu.Unlock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } +} + +func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { + config := DefaultNUDConfigurations() + // Stay in Reachable so the cache can overflow + config.BaseReachableTime = infiniteDuration + config.MinRandomFactor = 1 + config.MaxRandomFactor = 1 + + c := newTestContext(config) + + entry, ok := c.store.entry(0) + if !ok { + t.Fatalf("c.store.entry(0) not found") + } + c.neigh.addStaticEntry(entry.Addr, entry.LinkAddr) + e, _, err := c.neigh.entry(entry.Addr, "", nil, nil) + if err != nil { + t.Errorf("unexpected error from c.neigh.entry(%s, \"\", nil nil): %s", entry.Addr, err) + } + want := NeighborEntry{ + Addr: entry.Addr, + LocalAddr: "", // static entries don't need a local address + LinkAddr: entry.LinkAddr, + State: Static, + } + if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { + t.Errorf("c.neigh.entry(%s, \"\", nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Static, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + opts := overflowOptions{ + startAtEntryIndex: 1, + wantStaticEntries: []NeighborEntry{ + { + Addr: entry.Addr, + LocalAddr: "", // static entries don't need a local address + LinkAddr: entry.LinkAddr, + State: Static, + }, + }, + } + if err := c.overflowCache(opts); err != nil { + t.Errorf("c.overflowCache(%+v): %s", opts, err) + } +} + +func TestNeighborCacheClear(t *testing.T) { + config := DefaultNUDConfigurations() + + nudDisp := testNUDDispatcher{} + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + // Add a dynamic entry. + entry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + clock.advance(typicalLatency) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + nudDisp.mu.Lock() + diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + nudDisp.events = nil + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + // Add a static entry. + neigh.addStaticEntry(entryTestAddr1, entryTestLinkAddr1) + + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Static, + }, + } + nudDisp.mu.Lock() + diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + nudDisp.events = nil + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + // Clear shoud remove both dynamic and static entries. + neigh.clear() + + // Remove events dispatched from clear() have no deterministic order so they + // need to be sorted beforehand. + wantUnsortedEvents := []testEntryEventInfo{ + { + EventType: entryTestRemoved, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + { + EventType: entryTestRemoved, + NICID: 1, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Static, + }, + } + nudDisp.mu.Lock() + defer nudDisp.mu.Unlock() + if diff := cmp.Diff(nudDisp.events, wantUnsortedEvents, eventDiffOptsWithSort()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } +} + +// TestNeighborCacheClearThenOverflow verifies that the LRU cache eviction +// strategy keeps count of the dynamic entry count when all entries are +// cleared. +func TestNeighborCacheClearThenOverflow(t *testing.T) { + config := DefaultNUDConfigurations() + // Stay in Reachable so the cache can overflow + config.BaseReachableTime = infiniteDuration + config.MinRandomFactor = 1 + config.MaxRandomFactor = 1 + + c := newTestContext(config) + + // Add a dynamic entry + entry, ok := c.store.entry(0) + if !ok { + t.Fatalf("c.store.entry(0) not found") + } + _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + c.clock.advance(typicalLatency) + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + // Clear the cache. + c.neigh.clear() + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestRemoved, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + opts := overflowOptions{ + startAtEntryIndex: 0, + } + if err := c.overflowCache(opts); err != nil { + t.Errorf("c.overflowCache(%+v): %s", opts, err) + } +} + +func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { + config := DefaultNUDConfigurations() + // Stay in Reachable so the cache can overflow + config.BaseReachableTime = infiniteDuration + config.MinRandomFactor = 1 + config.MaxRandomFactor = 1 + + nudDisp := testNUDDispatcher{} + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + frequentlyUsedEntry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + + // The following logic is very similar to overflowCache, but + // periodically refreshes the frequently used entry. + + // Fill the neighbor cache to capacity + for i := 0; i < neighborCacheSize; i++ { + entry, ok := store.entry(i) + if !ok { + t.Fatalf("store.entry(%d) not found", i) + } + _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + clock.advance(typicalLatency) + select { + case <-doneCh: + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr) + } + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + nudDisp.mu.Lock() + diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + nudDisp.events = nil + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + // Keep adding more entries + for i := neighborCacheSize; i < store.size(); i++ { + // Periodically refresh the frequently used entry + if i%(neighborCacheSize/2) == 0 { + _, _, err := neigh.entry(frequentlyUsedEntry.Addr, frequentlyUsedEntry.LocalAddr, linkRes, nil) + if err != nil { + t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", frequentlyUsedEntry.Addr, frequentlyUsedEntry.LocalAddr, err) + } + } + + entry, ok := store.entry(i) + if !ok { + t.Fatalf("store.entry(%d) not found", i) + } + _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + clock.advance(typicalLatency) + select { + case <-doneCh: + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr) + } + + // An entry should have been removed, as per the LRU eviction strategy + removedEntry, ok := store.entry(i - neighborCacheSize + 1) + if !ok { + t.Fatalf("store.entry(%d) not found", i-neighborCacheSize+1) + } + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestRemoved, + NICID: 1, + Addr: removedEntry.Addr, + LinkAddr: removedEntry.LinkAddr, + State: Reachable, + }, + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + nudDisp.mu.Lock() + diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + nudDisp.events = nil + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + // Expect to find only the frequently used entry and the most recent entries. + // The order of entries reported by entries() is undeterministic, so entries + // have to be sorted before comparison. + wantUnsortedEntries := []NeighborEntry{ + { + Addr: frequentlyUsedEntry.Addr, + LocalAddr: frequentlyUsedEntry.LocalAddr, + LinkAddr: frequentlyUsedEntry.LinkAddr, + State: Reachable, + }, + } + + for i := store.size() - neighborCacheSize + 1; i < store.size(); i++ { + entry, ok := store.entry(i) + if !ok { + t.Fatalf("store.entry(%d) not found", i) + } + wantEntry := NeighborEntry{ + Addr: entry.Addr, + LocalAddr: entry.LocalAddr, + LinkAddr: entry.LinkAddr, + State: Reachable, + } + wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) + } + + if diff := cmp.Diff(neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" { + t.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff) + } + + // No more events should have been dispatched. + nudDisp.mu.Lock() + defer nudDisp.mu.Unlock() + if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } +} + +func TestNeighborCacheConcurrent(t *testing.T) { + const concurrentProcesses = 16 + + config := DefaultNUDConfigurations() + + nudDisp := testNUDDispatcher{} + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + storeEntries := store.entries() + for _, entry := range storeEntries { + var wg sync.WaitGroup + for r := 0; r < concurrentProcesses; r++ { + wg.Add(1) + go func(entry NeighborEntry) { + defer wg.Done() + e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != nil && err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, %s, _, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, entry.LocalAddr, e, err, tcpip.ErrWouldBlock) + } + }(entry) + } + + // Wait for all gorountines to send a request + wg.Wait() + + // Process all the requests for a single entry concurrently + clock.advance(typicalLatency) + } + + // All goroutines add in the same order and add more values than can fit in + // the cache. Our eviction strategy requires that the last entries are + // present, up to the size of the neighbor cache, and the rest are missing. + // The order of entries reported by entries() is undeterministic, so entries + // have to be sorted before comparison. + var wantUnsortedEntries []NeighborEntry + for i := store.size() - neighborCacheSize; i < store.size(); i++ { + entry, ok := store.entry(i) + if !ok { + t.Errorf("store.entry(%d) not found", i) + } + wantEntry := NeighborEntry{ + Addr: entry.Addr, + LocalAddr: entry.LocalAddr, + LinkAddr: entry.LinkAddr, + State: Reachable, + } + wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) + } + + if diff := cmp.Diff(neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" { + t.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff) + } +} + +func TestNeighborCacheReplace(t *testing.T) { + config := DefaultNUDConfigurations() + + nudDisp := testNUDDispatcher{} + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + // Add an entry + entry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + clock.advance(typicalLatency) + select { + case <-doneCh: + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr) + } + + // Verify the entry exists + e, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != nil { + t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + } + if doneCh != nil { + t.Errorf("unexpected done channel from neigh.entry(%s, %s, _, nil): %v", entry.Addr, entry.LocalAddr, doneCh) + } + if t.Failed() { + t.FailNow() + } + want := NeighborEntry{ + Addr: entry.Addr, + LocalAddr: entry.LocalAddr, + LinkAddr: entry.LinkAddr, + State: Reachable, + } + if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { + t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LinkAddr, diff) + } + + // Notify of a link address change + var updatedLinkAddr tcpip.LinkAddress + { + entry, ok := store.entry(1) + if !ok { + t.Fatalf("store.entry(1) not found") + } + updatedLinkAddr = entry.LinkAddr + } + store.set(0, updatedLinkAddr) + neigh.HandleConfirmation(entry.Addr, updatedLinkAddr, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + + // Requesting the entry again should start address resolution + { + _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + clock.advance(config.DelayFirstProbeTime + typicalLatency) + select { + case <-doneCh: + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr) + } + } + + // Verify the entry's new link address + { + e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + clock.advance(typicalLatency) + if err != nil { + t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + } + want = NeighborEntry{ + Addr: entry.Addr, + LocalAddr: entry.LocalAddr, + LinkAddr: updatedLinkAddr, + State: Reachable, + } + if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { + t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff) + } + } +} + +func TestNeighborCacheResolutionFailed(t *testing.T) { + config := DefaultNUDConfigurations() + + nudDisp := testNUDDispatcher{} + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, config, clock) + store := newTestEntryStore() + + var requestCount uint32 + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + onLinkAddressRequest: func() { + atomic.AddUint32(&requestCount, 1) + }, + } + + // First, sanity check that resolution is working + entry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + clock.advance(typicalLatency) + got, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != nil { + t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + } + want := NeighborEntry{ + Addr: entry.Addr, + LocalAddr: entry.LocalAddr, + LinkAddr: entry.LinkAddr, + State: Reachable, + } + if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" { + t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff) + } + + // Verify that address resolution for an unknown address returns ErrNoLinkAddress + before := atomic.LoadUint32(&requestCount) + + entry.Addr += "2" + if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes) + clock.advance(waitFor) + if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress { + t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress) + } + + maxAttempts := neigh.config().MaxUnicastProbes + if got, want := atomic.LoadUint32(&requestCount)-before, maxAttempts; got != want { + t.Errorf("got link address request count = %d, want = %d", got, want) + } +} + +// TestNeighborCacheResolutionTimeout simulates sending MaxMulticastProbes +// probes and not retrieving a confirmation before the duration defined by +// MaxMulticastProbes * RetransmitTimer. +func TestNeighborCacheResolutionTimeout(t *testing.T) { + config := DefaultNUDConfigurations() + config.RetransmitTimer = time.Millisecond // small enough to cause timeout + + clock := newFakeClock() + neigh := newTestNeighborCache(nil, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: time.Minute, // large enough to cause timeout + } + + entry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) + clock.advance(waitFor) + if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress { + t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress) + } +} + +// TestNeighborCacheStaticResolution checks that static link addresses are +// resolved immediately and don't send resolution requests. +func TestNeighborCacheStaticResolution(t *testing.T) { + config := DefaultNUDConfigurations() + clock := newFakeClock() + neigh := newTestNeighborCache(nil, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + got, _, err := neigh.entry(testEntryBroadcastAddr, testEntryLocalAddr, linkRes, nil) + if err != nil { + t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", testEntryBroadcastAddr, testEntryLocalAddr, err) + } + want := NeighborEntry{ + Addr: testEntryBroadcastAddr, + LocalAddr: testEntryLocalAddr, + LinkAddr: testEntryBroadcastLinkAddr, + State: Static, + } + if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" { + t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, testEntryLocalAddr, diff) + } +} + +func BenchmarkCacheClear(b *testing.B) { + b.StopTimer() + config := DefaultNUDConfigurations() + clock := &tcpip.StdClock{} + neigh := newTestNeighborCache(nil, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: 0, + } + + // Clear for every possible size of the cache + for cacheSize := 0; cacheSize < neighborCacheSize; cacheSize++ { + // Fill the neighbor cache to capacity. + for i := 0; i < cacheSize; i++ { + entry, ok := store.entry(i) + if !ok { + b.Fatalf("store.entry(%d) not found", i) + } + _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != tcpip.ErrWouldBlock { + b.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + if doneCh != nil { + <-doneCh + } + } + + b.StartTimer() + neigh.clear() + b.StopTimer() + } +} diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go new file mode 100644 index 000000000..0068cacb8 --- /dev/null +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -0,0 +1,482 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "fmt" + "sync" + "time" + + "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/tcpip" +) + +// NeighborEntry describes a neighboring device in the local network. +type NeighborEntry struct { + Addr tcpip.Address + LocalAddr tcpip.Address + LinkAddr tcpip.LinkAddress + State NeighborState + UpdatedAt time.Time +} + +// NeighborState defines the state of a NeighborEntry within the Neighbor +// Unreachability Detection state machine, as per RFC 4861 section 7.3.2. +type NeighborState uint8 + +const ( + // Unknown means reachability has not been verified yet. This is the initial + // state of entries that have been created automatically by the Neighbor + // Unreachability Detection state machine. + Unknown NeighborState = iota + // Incomplete means that there is an outstanding request to resolve the + // address. + Incomplete + // Reachable means the path to the neighbor is functioning properly for both + // receive and transmit paths. + Reachable + // Stale means reachability to the neighbor is unknown, but packets are still + // able to be transmitted to the possibly stale link address. + Stale + // Delay means reachability to the neighbor is unknown and pending + // confirmation from an upper-level protocol like TCP, but packets are still + // able to be transmitted to the possibly stale link address. + Delay + // Probe means a reachability confirmation is actively being sought by + // periodically retransmitting reachability probes until a reachability + // confirmation is received, or until the max amount of probes has been sent. + Probe + // Static describes entries that have been explicitly added by the user. They + // do not expire and are not deleted until explicitly removed. + Static + // Failed means traffic should not be sent to this neighbor since attempts of + // reachability have returned inconclusive. + Failed +) + +// neighborEntry implements a neighbor entry's individual node behavior, as per +// RFC 4861 section 7.3.3. Neighbor Unreachability Detection operates in +// parallel with the sending of packets to a neighbor, necessitating the +// entry's lock to be acquired for all operations. +type neighborEntry struct { + neighborEntryEntry + + nic *NIC + protocol tcpip.NetworkProtocolNumber + + // linkRes provides the functionality to send reachability probes, used in + // Neighbor Unreachability Detection. + linkRes LinkAddressResolver + + // nudState points to the Neighbor Unreachability Detection configuration. + nudState *NUDState + + // mu protects the fields below. + mu sync.RWMutex + + neigh NeighborEntry + + // wakers is a set of waiters for address resolution result. Anytime state + // transitions out of incomplete these waiters are notified. It is nil iff + // address resolution is ongoing and no clients are waiting for the result. + wakers map[*sleep.Waker]struct{} + + // done is used to allow callers to wait on address resolution. It is nil + // iff nudState is not Reachable and address resolution is not yet in + // progress. + done chan struct{} + + isRouter bool + job *tcpip.Job +} + +// newNeighborEntry creates a neighbor cache entry starting at the default +// state, Unknown. Transition out of Unknown by calling either +// `handlePacketQueuedLocked` or `handleProbeLocked` on the newly created +// neighborEntry. +func newNeighborEntry(nic *NIC, remoteAddr tcpip.Address, localAddr tcpip.Address, nudState *NUDState, linkRes LinkAddressResolver) *neighborEntry { + return &neighborEntry{ + nic: nic, + linkRes: linkRes, + nudState: nudState, + neigh: NeighborEntry{ + Addr: remoteAddr, + LocalAddr: localAddr, + State: Unknown, + }, + } +} + +// newStaticNeighborEntry creates a neighbor cache entry starting at the Static +// state. The entry can only transition out of Static by directly calling +// `setStateLocked`. +func newStaticNeighborEntry(nic *NIC, addr tcpip.Address, linkAddr tcpip.LinkAddress, state *NUDState) *neighborEntry { + if nic.stack.nudDisp != nil { + nic.stack.nudDisp.OnNeighborAdded(nic.id, addr, linkAddr, Static, time.Now()) + } + return &neighborEntry{ + nic: nic, + nudState: state, + neigh: NeighborEntry{ + Addr: addr, + LinkAddr: linkAddr, + State: Static, + UpdatedAt: time.Now(), + }, + } +} + +// addWaker adds w to the list of wakers waiting for address resolution. +// Assumes the entry has already been appropriately locked. +func (e *neighborEntry) addWakerLocked(w *sleep.Waker) { + if w == nil { + return + } + if e.wakers == nil { + e.wakers = make(map[*sleep.Waker]struct{}) + } + e.wakers[w] = struct{}{} +} + +// notifyWakersLocked notifies those waiting for address resolution, whether it +// succeeded or failed. Assumes the entry has already been appropriately locked. +func (e *neighborEntry) notifyWakersLocked() { + for w := range e.wakers { + w.Assert() + } + e.wakers = nil + if ch := e.done; ch != nil { + close(ch) + e.done = nil + } +} + +// dispatchAddEventLocked signals to stack's NUD Dispatcher that the entry has +// been added. +func (e *neighborEntry) dispatchAddEventLocked(nextState NeighborState) { + if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { + nudDisp.OnNeighborAdded(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, nextState, time.Now()) + } +} + +// dispatchChangeEventLocked signals to stack's NUD Dispatcher that the entry +// has changed state or link-layer address. +func (e *neighborEntry) dispatchChangeEventLocked(nextState NeighborState) { + if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { + nudDisp.OnNeighborChanged(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, nextState, time.Now()) + } +} + +// dispatchRemoveEventLocked signals to stack's NUD Dispatcher that the entry +// has been removed. +func (e *neighborEntry) dispatchRemoveEventLocked() { + if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { + nudDisp.OnNeighborRemoved(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, e.neigh.State, time.Now()) + } +} + +// setStateLocked transitions the entry to the specified state immediately. +// +// Follows the logic defined in RFC 4861 section 7.3.3. +// +// e.mu MUST be locked. +func (e *neighborEntry) setStateLocked(next NeighborState) { + // Cancel the previously scheduled action, if there is one. Entries in + // Unknown, Stale, or Static state do not have scheduled actions. + if timer := e.job; timer != nil { + timer.Cancel() + } + + prev := e.neigh.State + e.neigh.State = next + e.neigh.UpdatedAt = time.Now() + config := e.nudState.Config() + + switch next { + case Incomplete: + var retryCounter uint32 + var sendMulticastProbe func() + + sendMulticastProbe = func() { + if retryCounter == config.MaxMulticastProbes { + // "If no Neighbor Advertisement is received after + // MAX_MULTICAST_SOLICIT solicitations, address resolution has failed. + // The sender MUST return ICMP destination unreachable indications with + // code 3 (Address Unreachable) for each packet queued awaiting address + // resolution." - RFC 4861 section 7.2.2 + // + // There is no need to send an ICMP destination unreachable indication + // since the failure to resolve the address is expected to only occur + // on this node. Thus, redirecting traffic is currently not supported. + // + // "If the error occurs on a node other than the node originating the + // packet, an ICMP error message is generated. If the error occurs on + // the originating node, an implementation is not required to actually + // create and send an ICMP error packet to the source, as long as the + // upper-layer sender is notified through an appropriate mechanism + // (e.g. return value from a procedure call). Note, however, that an + // implementation may find it convenient in some cases to return errors + // to the sender by taking the offending packet, generating an ICMP + // error message, and then delivering it (locally) through the generic + // error-handling routines.' - RFC 4861 section 2.1 + e.dispatchRemoveEventLocked() + e.setStateLocked(Failed) + return + } + + if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, "", e.nic.linkEP); err != nil { + // There is no need to log the error here; the NUD implementation may + // assume a working link. A valid link should be the responsibility of + // the NIC/stack.LinkEndpoint. + e.dispatchRemoveEventLocked() + e.setStateLocked(Failed) + return + } + + retryCounter++ + e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe) + e.job.Schedule(config.RetransmitTimer) + } + + sendMulticastProbe() + + case Reachable: + e.job = e.nic.stack.newJob(&e.mu, func() { + e.dispatchChangeEventLocked(Stale) + e.setStateLocked(Stale) + }) + e.job.Schedule(e.nudState.ReachableTime()) + + case Delay: + e.job = e.nic.stack.newJob(&e.mu, func() { + e.dispatchChangeEventLocked(Probe) + e.setStateLocked(Probe) + }) + e.job.Schedule(config.DelayFirstProbeTime) + + case Probe: + var retryCounter uint32 + var sendUnicastProbe func() + + sendUnicastProbe = func() { + if retryCounter == config.MaxUnicastProbes { + e.dispatchRemoveEventLocked() + e.setStateLocked(Failed) + return + } + + if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, e.neigh.LinkAddr, e.nic.linkEP); err != nil { + e.dispatchRemoveEventLocked() + e.setStateLocked(Failed) + return + } + + retryCounter++ + if retryCounter == config.MaxUnicastProbes { + e.dispatchRemoveEventLocked() + e.setStateLocked(Failed) + return + } + + e.job = e.nic.stack.newJob(&e.mu, sendUnicastProbe) + e.job.Schedule(config.RetransmitTimer) + } + + sendUnicastProbe() + + case Failed: + e.notifyWakersLocked() + e.job = e.nic.stack.newJob(&e.mu, func() { + e.nic.neigh.removeEntryLocked(e) + }) + e.job.Schedule(config.UnreachableTime) + + case Unknown, Stale, Static: + // Do nothing + + default: + panic(fmt.Sprintf("Invalid state transition from %q to %q", prev, next)) + } +} + +// handlePacketQueuedLocked advances the state machine according to a packet +// being queued for outgoing transmission. +// +// Follows the logic defined in RFC 4861 section 7.3.3. +func (e *neighborEntry) handlePacketQueuedLocked() { + switch e.neigh.State { + case Unknown: + e.dispatchAddEventLocked(Incomplete) + e.setStateLocked(Incomplete) + + case Stale: + e.dispatchChangeEventLocked(Delay) + e.setStateLocked(Delay) + + case Incomplete, Reachable, Delay, Probe, Static, Failed: + // Do nothing + + default: + panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State)) + } +} + +// handleProbeLocked processes an incoming neighbor probe (e.g. ARP request or +// Neighbor Solicitation for ARP or NDP, respectively). +// +// Follows the logic defined in RFC 4861 section 7.2.3. +func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) { + // Probes MUST be silently discarded if the target address is tentative, does + // not exist, or not bound to the NIC as per RFC 4861 section 7.2.3. These + // checks MUST be done by the NetworkEndpoint. + + switch e.neigh.State { + case Unknown, Incomplete, Failed: + e.neigh.LinkAddr = remoteLinkAddr + e.dispatchAddEventLocked(Stale) + e.setStateLocked(Stale) + e.notifyWakersLocked() + + case Reachable, Delay, Probe: + if e.neigh.LinkAddr != remoteLinkAddr { + e.neigh.LinkAddr = remoteLinkAddr + e.dispatchChangeEventLocked(Stale) + e.setStateLocked(Stale) + } + + case Stale: + if e.neigh.LinkAddr != remoteLinkAddr { + e.neigh.LinkAddr = remoteLinkAddr + e.dispatchChangeEventLocked(Stale) + } + + case Static: + // Do nothing + + default: + panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State)) + } +} + +// handleConfirmationLocked processes an incoming neighbor confirmation +// (e.g. ARP reply or Neighbor Advertisement for ARP or NDP, respectively). +// +// Follows the state machine defined by RFC 4861 section 7.2.5. +// +// TODO(gvisor.dev/issue/2277): To protect against ARP poisoning and other +// attacks against NDP functions, Secure Neighbor Discovery (SEND) Protocol +// should be deployed where preventing access to the broadcast segment might +// not be possible. SEND uses RSA key pairs to produce Cryptographically +// Generated Addresses (CGA), as defined in RFC 3972. This ensures that the +// claimed source of an NDP message is the owner of the claimed address. +func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) { + switch e.neigh.State { + case Incomplete: + if len(linkAddr) == 0 { + // "If the link layer has addresses and no Target Link-Layer Address + // option is included, the receiving node SHOULD silently discard the + // received advertisement." - RFC 4861 section 7.2.5 + break + } + + e.neigh.LinkAddr = linkAddr + if flags.Solicited { + e.dispatchChangeEventLocked(Reachable) + e.setStateLocked(Reachable) + } else { + e.dispatchChangeEventLocked(Stale) + e.setStateLocked(Stale) + } + e.isRouter = flags.IsRouter + e.notifyWakersLocked() + + // "Note that the Override flag is ignored if the entry is in the + // INCOMPLETE state." - RFC 4861 section 7.2.5 + + case Reachable, Stale, Delay, Probe: + sameLinkAddr := e.neigh.LinkAddr == linkAddr + + if !sameLinkAddr { + if !flags.Override { + if e.neigh.State == Reachable { + e.dispatchChangeEventLocked(Stale) + e.setStateLocked(Stale) + } + break + } + + e.neigh.LinkAddr = linkAddr + + if !flags.Solicited { + if e.neigh.State != Stale { + e.dispatchChangeEventLocked(Stale) + e.setStateLocked(Stale) + } else { + // Notify the LinkAddr change, even though NUD state hasn't changed. + e.dispatchChangeEventLocked(e.neigh.State) + } + break + } + } + + if flags.Solicited && (flags.Override || sameLinkAddr) { + if e.neigh.State != Reachable { + e.dispatchChangeEventLocked(Reachable) + } + // Set state to Reachable again to refresh timers. + e.setStateLocked(Reachable) + e.notifyWakersLocked() + } + + if e.isRouter && !flags.IsRouter { + // "In those cases where the IsRouter flag changes from TRUE to FALSE as + // a result of this update, the node MUST remove that router from the + // Default Router List and update the Destination Cache entries for all + // destinations using that neighbor as a router as specified in Section + // 7.3.3. This is needed to detect when a node that is used as a router + // stops forwarding packets due to being configured as a host." + // - RFC 4861 section 7.2.5 + e.nic.mu.Lock() + e.nic.mu.ndp.invalidateDefaultRouter(e.neigh.Addr) + e.nic.mu.Unlock() + } + e.isRouter = flags.IsRouter + + case Unknown, Failed, Static: + // Do nothing + + default: + panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State)) + } +} + +// handleUpperLevelConfirmationLocked processes an incoming upper-level protocol +// (e.g. TCP acknowledgements) reachability confirmation. +func (e *neighborEntry) handleUpperLevelConfirmationLocked() { + switch e.neigh.State { + case Reachable, Stale, Delay, Probe: + if e.neigh.State != Reachable { + e.dispatchChangeEventLocked(Reachable) + // Set state to Reachable again to refresh timers. + } + e.setStateLocked(Reachable) + + case Unknown, Incomplete, Failed, Static: + // Do nothing + + default: + panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State)) + } +} diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go new file mode 100644 index 000000000..08c9ccd25 --- /dev/null +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -0,0 +1,2770 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "fmt" + "math" + "math/rand" + "strings" + "sync" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/tcpip" +) + +const ( + entryTestNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 + + entryTestNICID tcpip.NICID = 1 + entryTestAddr1 = tcpip.Address("\x00\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + entryTestAddr2 = tcpip.Address("\x00\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + + entryTestLinkAddr1 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x01") + entryTestLinkAddr2 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x02") + + // entryTestNetDefaultMTU is the MTU, in bytes, used throughout the tests, + // except where another value is explicitly used. It is chosen to match the + // MTU of loopback interfaces on Linux systems. + entryTestNetDefaultMTU = 65536 +) + +// eventDiffOpts are the options passed to cmp.Diff to compare entry events. +// The UpdatedAt field is ignored due to a lack of a deterministic method to +// predict the time that an event will be dispatched. +func eventDiffOpts() []cmp.Option { + return []cmp.Option{ + cmpopts.IgnoreFields(testEntryEventInfo{}, "UpdatedAt"), + } +} + +// eventDiffOptsWithSort is like eventDiffOpts but also includes an option to +// sort slices of events for cases where ordering must be ignored. +func eventDiffOptsWithSort() []cmp.Option { + return []cmp.Option{ + cmpopts.IgnoreFields(testEntryEventInfo{}, "UpdatedAt"), + cmpopts.SortSlices(func(a, b testEntryEventInfo) bool { + return strings.Compare(string(a.Addr), string(b.Addr)) < 0 + }), + } +} + +// The following unit tests exercise every state transition and verify its +// behavior with RFC 4681. +// +// | From | To | Cause | Action | Event | +// | ========== | ========== | ========================================== | =============== | ======= | +// | Unknown | Unknown | Confirmation w/ unknown address | | Added | +// | Unknown | Incomplete | Packet queued to unknown address | Send probe | Added | +// | Unknown | Stale | Probe w/ unknown address | | Added | +// | Incomplete | Incomplete | Retransmit timer expired | Send probe | Changed | +// | Incomplete | Reachable | Solicited confirmation | Notify wakers | Changed | +// | Incomplete | Stale | Unsolicited confirmation | Notify wakers | Changed | +// | Incomplete | Failed | Max probes sent without reply | Notify wakers | Removed | +// | Reachable | Reachable | Confirmation w/ different isRouter flag | Update IsRouter | | +// | Reachable | Stale | Reachable timer expired | | Changed | +// | Reachable | Stale | Probe or confirmation w/ different address | | Changed | +// | Stale | Reachable | Solicited override confirmation | Update LinkAddr | Changed | +// | Stale | Stale | Override confirmation | Update LinkAddr | Changed | +// | Stale | Stale | Probe w/ different address | Update LinkAddr | Changed | +// | Stale | Delay | Packet sent | | Changed | +// | Delay | Reachable | Upper-layer confirmation | | Changed | +// | Delay | Reachable | Solicited override confirmation | Update LinkAddr | Changed | +// | Delay | Stale | Probe or confirmation w/ different address | | Changed | +// | Delay | Probe | Delay timer expired | Send probe | Changed | +// | Probe | Reachable | Solicited override confirmation | Update LinkAddr | Changed | +// | Probe | Reachable | Solicited confirmation w/ same address | Notify wakers | Changed | +// | Probe | Stale | Probe or confirmation w/ different address | | Changed | +// | Probe | Probe | Retransmit timer expired | Send probe | Changed | +// | Probe | Failed | Max probes sent without reply | Notify wakers | Removed | +// | Failed | | Unreachability timer expired | Delete entry | | + +type testEntryEventType uint8 + +const ( + entryTestAdded testEntryEventType = iota + entryTestChanged + entryTestRemoved +) + +func (t testEntryEventType) String() string { + switch t { + case entryTestAdded: + return "add" + case entryTestChanged: + return "change" + case entryTestRemoved: + return "remove" + default: + return fmt.Sprintf("unknown (%d)", t) + } +} + +// Fields are exported for use with cmp.Diff. +type testEntryEventInfo struct { + EventType testEntryEventType + NICID tcpip.NICID + Addr tcpip.Address + LinkAddr tcpip.LinkAddress + State NeighborState + UpdatedAt time.Time +} + +func (e testEntryEventInfo) String() string { + return fmt.Sprintf("%s event for NIC #%d, addr=%q, linkAddr=%q, state=%q", e.EventType, e.NICID, e.Addr, e.LinkAddr, e.State) +} + +// testNUDDispatcher implements NUDDispatcher to validate the dispatching of +// events upon certain NUD state machine events. +type testNUDDispatcher struct { + mu sync.Mutex + events []testEntryEventInfo +} + +var _ NUDDispatcher = (*testNUDDispatcher)(nil) + +func (d *testNUDDispatcher) queueEvent(e testEntryEventInfo) { + d.mu.Lock() + defer d.mu.Unlock() + d.events = append(d.events, e) +} + +func (d *testNUDDispatcher) OnNeighborAdded(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) { + d.queueEvent(testEntryEventInfo{ + EventType: entryTestAdded, + NICID: nicID, + Addr: addr, + LinkAddr: linkAddr, + State: state, + UpdatedAt: updatedAt, + }) +} + +func (d *testNUDDispatcher) OnNeighborChanged(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) { + d.queueEvent(testEntryEventInfo{ + EventType: entryTestChanged, + NICID: nicID, + Addr: addr, + LinkAddr: linkAddr, + State: state, + UpdatedAt: updatedAt, + }) +} + +func (d *testNUDDispatcher) OnNeighborRemoved(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) { + d.queueEvent(testEntryEventInfo{ + EventType: entryTestRemoved, + NICID: nicID, + Addr: addr, + LinkAddr: linkAddr, + State: state, + UpdatedAt: updatedAt, + }) +} + +type entryTestLinkResolver struct { + mu sync.Mutex + probes []entryTestProbeInfo +} + +var _ LinkAddressResolver = (*entryTestLinkResolver)(nil) + +type entryTestProbeInfo struct { + RemoteAddress tcpip.Address + RemoteLinkAddress tcpip.LinkAddress + LocalAddress tcpip.Address +} + +func (p entryTestProbeInfo) String() string { + return fmt.Sprintf("probe with RemoteAddress=%q, RemoteLinkAddress=%q, LocalAddress=%q", p.RemoteAddress, p.RemoteLinkAddress, p.LocalAddress) +} + +// LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts +// to the local network if linkAddr is the zero value. +func (r *entryTestLinkResolver) LinkAddressRequest(addr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error { + p := entryTestProbeInfo{ + RemoteAddress: addr, + RemoteLinkAddress: linkAddr, + LocalAddress: localAddr, + } + r.mu.Lock() + defer r.mu.Unlock() + r.probes = append(r.probes, p) + return nil +} + +// ResolveStaticAddress attempts to resolve address without sending requests. +// It either resolves the name immediately or returns the empty LinkAddress. +func (r *entryTestLinkResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { + return "", false +} + +// LinkAddressProtocol returns the network protocol of the addresses this +// resolver can resolve. +func (r *entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { + return entryTestNetNumber +} + +func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *entryTestLinkResolver, *fakeClock) { + clock := newFakeClock() + disp := testNUDDispatcher{} + nic := NIC{ + id: entryTestNICID, + linkEP: nil, // entryTestLinkResolver doesn't use a LinkEndpoint + stack: &Stack{ + clock: clock, + nudDisp: &disp, + }, + } + + rng := rand.New(rand.NewSource(time.Now().UnixNano())) + nudState := NewNUDState(c, rng) + linkRes := entryTestLinkResolver{} + entry := newNeighborEntry(&nic, entryTestAddr1, entryTestAddr2, nudState, &linkRes) + + // Stub out ndpState to verify modification of default routers. + nic.mu.ndp = ndpState{ + nic: &nic, + defaultRouters: make(map[tcpip.Address]defaultRouterState), + } + + // Stub out the neighbor cache to verify deletion from the cache. + nic.neigh = &neighborCache{ + nic: &nic, + state: nudState, + cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), + } + nic.neigh.cache[entryTestAddr1] = entry + + return entry, &disp, &linkRes, clock +} + +// TestEntryInitiallyUnknown verifies that the state of a newly created +// neighborEntry is Unknown. +func TestEntryInitiallyUnknown(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + if got, want := e.neigh.State, Unknown; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + clock.advance(c.RetransmitTimer) + + // No probes should have been sent. + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil)) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + // No events should have been dispatched. + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Unknown; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + clock.advance(time.Hour) + + // No probes should have been sent. + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil)) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + // No events should have been dispatched. + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryUnknownToIncomplete(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Incomplete; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + } + { + nudDisp.mu.Lock() + diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } +} + +func TestEntryUnknownToStale(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handleProbeLocked(entryTestLinkAddr1) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + // No probes should have been sent. + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil)) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { + c := DefaultNUDConfigurations() + c.MaxMulticastProbes = 3 + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Incomplete; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + updatedAt := e.neigh.UpdatedAt + e.mu.Unlock() + + clock.advance(c.RetransmitTimer) + + // UpdatedAt should remain the same during address resolution. + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() + if got, want := e.neigh.UpdatedAt, updatedAt; got != want { + t.Errorf("got e.neigh.UpdatedAt = %q, want = %q", got, want) + } + e.mu.Unlock() + + clock.advance(c.RetransmitTimer) + + // UpdatedAt should change after failing address resolution. Timing out after + // sending the last probe transitions the entry to Failed. + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + clock.advance(c.RetransmitTimer) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestRemoved, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, notWant := e.neigh.UpdatedAt, updatedAt; got == notWant { + t.Errorf("expected e.neigh.UpdatedAt to change, got = %q", got) + } + e.mu.Unlock() +} + +func TestEntryIncompleteToReachable(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Incomplete; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +// TestEntryAddsAndClearsWakers verifies that wakers are added when +// addWakerLocked is called and cleared when address resolution finishes. In +// this case, address resolution will finish when transitioning from Incomplete +// to Reachable. +func TestEntryAddsAndClearsWakers(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + w := sleep.Waker{} + s := sleep.Sleeper{} + s.AddWaker(&w, 123) + defer s.Done() + + e.mu.Lock() + e.handlePacketQueuedLocked() + if got := e.wakers; got != nil { + t.Errorf("got e.wakers = %v, want = nil", got) + } + e.addWakerLocked(&w) + if got, want := w.IsAsserted(), false; got != want { + t.Errorf("waker.IsAsserted() = %t, want = %t", got, want) + } + if e.wakers == nil { + t.Error("expected e.wakers to be non-nil") + } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if e.wakers != nil { + t.Errorf("got e.wakers = %v, want = nil", e.wakers) + } + if got, want := w.IsAsserted(), true; got != want { + t.Errorf("waker.IsAsserted() = %t, want = %t", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Incomplete; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: true, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.isRouter, true; got != want { + t.Errorf("got e.isRouter = %t, want = %t", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + if diff := cmp.Diff(linkRes.probes, wantProbes); diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + linkRes.mu.Unlock() + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryIncompleteToStale(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Incomplete; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryIncompleteToFailed(t *testing.T) { + c := DefaultNUDConfigurations() + c.MaxMulticastProbes = 3 + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Incomplete; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + waitFor := c.RetransmitTimer * time.Duration(c.MaxMulticastProbes) + clock.advance(waitFor) + + wantProbes := []entryTestProbeInfo{ + // The Incomplete-to-Incomplete state transition is tested here by + // verifying that 3 reachability probes were sent. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestRemoved, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Failed; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +type testLocker struct{} + +var _ sync.Locker = (*testLocker)(nil) + +func (*testLocker) Lock() {} +func (*testLocker) Unlock() {} + +func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: true, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.isRouter, true; got != want { + t.Errorf("got e.isRouter = %t, want = %t", got, want) + } + e.nic.mu.ndp.defaultRouters[entryTestAddr1] = defaultRouterState{ + invalidationJob: e.nic.stack.newJob(&testLocker{}, func() {}), + } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.isRouter, false; got != want { + t.Errorf("got e.isRouter = %t, want = %t", got, want) + } + if _, ok := e.nic.mu.ndp.defaultRouters[entryTestAddr1]; ok { + t.Errorf("unexpected defaultRouter for %s", entryTestAddr1) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleProbeLocked(entryTestLinkAddr1) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryReachableToStaleWhenTimeout(t *testing.T) { + c := DefaultNUDConfigurations() + // Eliminate random factors from ReachableTime computation so the transition + // from Stale to Reachable will only take BaseReachableTime duration. + c.MinRandomFactor = 1 + c.MaxRandomFactor = 1 + + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + clock.advance(c.BaseReachableTime) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleProbeLocked(entryTestLinkAddr2) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleProbeLocked(entryTestLinkAddr1) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: true, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleProbeLocked(entryTestLinkAddr2) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryStaleToDelay(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Delay; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) { + c := DefaultNUDConfigurations() + // Eliminate random factors from ReachableTime computation so the transition + // from Stale to Reachable will only take BaseReachableTime duration. + c.MinRandomFactor = 1 + c.MaxRandomFactor = 1 + + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Delay; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleUpperLevelConfirmationLocked() + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + clock.advance(c.BaseReachableTime) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { + c := DefaultNUDConfigurations() + c.MaxMulticastProbes = 1 + // Eliminate random factors from ReachableTime computation so the transition + // from Stale to Reachable will only take BaseReachableTime duration. + c.MinRandomFactor = 1 + c.MaxRandomFactor = 1 + + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Delay; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: true, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + clock.advance(c.BaseReachableTime) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Delay; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Delay; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Delay; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleProbeLocked(entryTestLinkAddr2) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Delay; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryDelayToProbe(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Delay; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + clock.advance(c.DelayFirstProbeTime) + + wantProbes := []entryTestProbeInfo{ + // The first probe is caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + // The second probe is caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Probe; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + e.mu.Unlock() + + clock.advance(c.DelayFirstProbeTime) + + wantProbes := []entryTestProbeInfo{ + // The first probe is caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + // The second probe is caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() + if got, want := e.neigh.State, Probe; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleProbeLocked(entryTestLinkAddr2) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + e.mu.Unlock() + + clock.advance(c.DelayFirstProbeTime) + + wantProbes := []entryTestProbeInfo{ + // The first probe is caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + // The second probe is caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() + if got, want := e.neigh.State, Probe; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + e.mu.Unlock() + + clock.advance(c.DelayFirstProbeTime) + + wantProbes := []entryTestProbeInfo{ + // The first probe is caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + // The second probe is caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() + if got, want := e.neigh.State, Probe; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Probe; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { + c := DefaultNUDConfigurations() + // Eliminate random factors from ReachableTime computation so the transition + // from Stale to Reachable will only take BaseReachableTime duration. + c.MinRandomFactor = 1 + c.MaxRandomFactor = 1 + + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + e.mu.Unlock() + + clock.advance(c.DelayFirstProbeTime) + + wantProbes := []entryTestProbeInfo{ + // The first probe is caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + // The second probe is caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() + if got, want := e.neigh.State, Probe; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: true, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + clock.advance(c.BaseReachableTime) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testing.T) { + c := DefaultNUDConfigurations() + // Eliminate random factors from ReachableTime computation so the transition + // from Stale to Reachable will only take BaseReachableTime duration. + c.MinRandomFactor = 1 + c.MaxRandomFactor = 1 + + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + e.mu.Unlock() + + clock.advance(c.DelayFirstProbeTime) + + wantProbes := []entryTestProbeInfo{ + // The first probe is caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + // The second probe is caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() + if got, want := e.neigh.State, Probe; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + clock.advance(c.BaseReachableTime) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryProbeToFailed(t *testing.T) { + c := DefaultNUDConfigurations() + c.MaxMulticastProbes = 3 + c.MaxUnicastProbes = 3 + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + e.mu.Unlock() + + waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + clock.advance(waitFor) + + wantProbes := []entryTestProbeInfo{ + // The first probe is caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + // The next three probe are caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + { + EventType: entryTestRemoved, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Failed; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryFailedGetsDeleted(t *testing.T) { + c := DefaultNUDConfigurations() + c.MaxMulticastProbes = 3 + c.MaxUnicastProbes = 3 + e, nudDisp, linkRes, clock := entryTestSetup(c) + + // Verify the cache contains the entry. + if _, ok := e.nic.neigh.cache[entryTestAddr1]; !ok { + t.Errorf("expected entry %q to exist in the neighbor cache", entryTestAddr1) + } + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + e.mu.Unlock() + + waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + c.UnreachableTime + clock.advance(waitFor) + + wantProbes := []entryTestProbeInfo{ + // The first probe is caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + // The next three probe are caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + { + EventType: entryTestRemoved, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + // Verify the cache no longer contains the entry. + if _, ok := e.nic.neigh.cache[entryTestAddr1]; ok { + t.Errorf("entry %q should have been deleted from the neighbor cache", entryTestAddr1) + } +} diff --git a/pkg/tcpip/stack/neighborstate_string.go b/pkg/tcpip/stack/neighborstate_string.go new file mode 100644 index 000000000..aa7311ec6 --- /dev/null +++ b/pkg/tcpip/stack/neighborstate_string.go @@ -0,0 +1,44 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by "stringer -type NeighborState"; DO NOT EDIT. + +package stack + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[Unknown-0] + _ = x[Incomplete-1] + _ = x[Reachable-2] + _ = x[Stale-3] + _ = x[Delay-4] + _ = x[Probe-5] + _ = x[Static-6] + _ = x[Failed-7] +} + +const _NeighborState_name = "UnknownIncompleteReachableStaleDelayProbeStaticFailed" + +var _NeighborState_index = [...]uint8{0, 7, 17, 26, 31, 36, 41, 47, 53} + +func (i NeighborState) String() string { + if i >= NeighborState(len(_NeighborState_index)-1) { + return "NeighborState(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _NeighborState_name[_NeighborState_index[i]:_NeighborState_index[i+1]] +} diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 9256d4d43..f21066fce 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -16,6 +16,7 @@ package stack import ( "fmt" + "math/rand" "reflect" "sort" "strings" @@ -45,6 +46,7 @@ type NIC struct { context NICContext stats NICStats + neigh *neighborCache mu struct { sync.RWMutex @@ -141,6 +143,16 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC nic.mu.packetEPs[netProto.Number()] = []PacketEndpoint{} } + // Check for Neighbor Unreachability Detection support. + if ep.Capabilities()&CapabilityResolutionRequired != 0 && len(stack.linkAddrResolvers) != 0 { + rng := rand.New(rand.NewSource(stack.clock.NowNanoseconds())) + nic.neigh = &neighborCache{ + nic: nic, + state: NewNUDState(stack.nudConfigs, rng), + cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), + } + } + nic.linkEP.Attach(nic) return nic @@ -1540,6 +1552,27 @@ func (n *NIC) setNDPConfigs(c NDPConfigurations) { n.mu.Unlock() } +// NUDConfigs gets the NUD configurations for n. +func (n *NIC) NUDConfigs() (NUDConfigurations, *tcpip.Error) { + if n.neigh == nil { + return NUDConfigurations{}, tcpip.ErrNotSupported + } + return n.neigh.config(), nil +} + +// setNUDConfigs sets the NUD configurations for n. +// +// Note, if c contains invalid NUD configuration values, it will be fixed to +// use default values for the erroneous values. +func (n *NIC) setNUDConfigs(c NUDConfigurations) *tcpip.Error { + if n.neigh == nil { + return tcpip.ErrNotSupported + } + c.resetInvalidFields() + n.neigh.setConfig(c) + return nil +} + // handleNDPRA handles an NDP Router Advertisement message that arrived on n. func (n *NIC) handleNDPRA(ip tcpip.Address, ra header.NDPRouterAdvert) { n.mu.Lock() diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go new file mode 100644 index 000000000..f848d50ad --- /dev/null +++ b/pkg/tcpip/stack/nud.go @@ -0,0 +1,466 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "math" + "sync" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +const ( + // defaultBaseReachableTime is the default base duration for computing the + // random reachable time. + // + // Reachable time is the duration for which a neighbor is considered + // reachable after a positive reachability confirmation is received. It is a + // function of a uniformly distributed random value between the minimum and + // maximum random factors, multiplied by the base reachable time. Using a + // random component eliminates the possibility that Neighbor Unreachability + // Detection messages will synchronize with each other. + // + // Default taken from REACHABLE_TIME of RFC 4861 section 10. + defaultBaseReachableTime = 30 * time.Second + + // minimumBaseReachableTime is the minimum base duration for computing the + // random reachable time. + // + // Minimum = 1ms + minimumBaseReachableTime = time.Millisecond + + // defaultMinRandomFactor is the default minimum value of the random factor + // used for computing reachable time. + // + // Default taken from MIN_RANDOM_FACTOR of RFC 4861 section 10. + defaultMinRandomFactor = 0.5 + + // defaultMaxRandomFactor is the default maximum value of the random factor + // used for computing reachable time. + // + // The default value depends on the value of MinRandomFactor. + // If MinRandomFactor is less than MAX_RANDOM_FACTOR of RFC 4861 section 10, + // the value from the RFC will be used; otherwise, the default is + // MinRandomFactor multiplied by three. + defaultMaxRandomFactor = 1.5 + + // defaultRetransmitTimer is the default amount of time to wait between + // sending reachability probes. + // + // Default taken from RETRANS_TIMER of RFC 4861 section 10. + defaultRetransmitTimer = time.Second + + // minimumRetransmitTimer is the minimum amount of time to wait between + // sending reachability probes. + // + // Note, RFC 4861 does not impose a minimum Retransmit Timer, but we do here + // to make sure the messages are not sent all at once. We also come to this + // value because in the RetransmitTimer field of a Router Advertisement, a + // value of 0 means unspecified, so the smallest valid value is 1. Note, the + // unit of the RetransmitTimer field in the Router Advertisement is + // milliseconds. + minimumRetransmitTimer = time.Millisecond + + // defaultDelayFirstProbeTime is the default duration to wait for a + // non-Neighbor-Discovery related protocol to reconfirm reachability after + // entering the DELAY state. After this time, a reachability probe will be + // sent and the entry will transition to the PROBE state. + // + // Default taken from DELAY_FIRST_PROBE_TIME of RFC 4861 section 10. + defaultDelayFirstProbeTime = 5 * time.Second + + // defaultMaxMulticastProbes is the default number of reachabililty probes + // to send before concluding negative reachability and deleting the neighbor + // entry from the INCOMPLETE state. + // + // Default taken from MAX_MULTICAST_SOLICIT of RFC 4861 section 10. + defaultMaxMulticastProbes = 3 + + // defaultMaxUnicastProbes is the default number of reachability probes to + // send before concluding retransmission from within the PROBE state should + // cease and the entry SHOULD be deleted. + // + // Default taken from MAX_UNICASE_SOLICIT of RFC 4861 section 10. + defaultMaxUnicastProbes = 3 + + // defaultMaxAnycastDelayTime is the default time in which the stack SHOULD + // delay sending a response for a random time between 0 and this time, if the + // target address is an anycast address. + // + // Default taken from MAX_ANYCAST_DELAY_TIME of RFC 4861 section 10. + defaultMaxAnycastDelayTime = time.Second + + // defaultMaxReachbilityConfirmations is the default amount of unsolicited + // reachability confirmation messages a node MAY send to all-node multicast + // address when it determines its link-layer address has changed. + // + // Default taken from MAX_NEIGHBOR_ADVERTISEMENT of RFC 4861 section 10. + defaultMaxReachbilityConfirmations = 3 + + // defaultUnreachableTime is the default duration for how long an entry will + // remain in the FAILED state before being removed from the neighbor cache. + // + // Note, there is no equivalent protocol constant defined in RFC 4861. It + // leaves the specifics of any garbage collection mechanism up to the + // implementation. + defaultUnreachableTime = 5 * time.Second +) + +// NUDDispatcher is the interface integrators of netstack must implement to +// receive and handle NUD related events. +type NUDDispatcher interface { + // OnNeighborAdded will be called when a new entry is added to a NIC's (with + // ID nicID) neighbor table. + // + // This function is permitted to block indefinitely without interfering with + // the stack's operation. + // + // May be called concurrently. + OnNeighborAdded(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) + + // OnNeighborChanged will be called when an entry in a NIC's (with ID nicID) + // neighbor table changes state and/or link address. + // + // This function is permitted to block indefinitely without interfering with + // the stack's operation. + // + // May be called concurrently. + OnNeighborChanged(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) + + // OnNeighborRemoved will be called when an entry is removed from a NIC's + // (with ID nicID) neighbor table. + // + // This function is permitted to block indefinitely without interfering with + // the stack's operation. + // + // May be called concurrently. + OnNeighborRemoved(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) +} + +// ReachabilityConfirmationFlags describes the flags used within a reachability +// confirmation (e.g. ARP reply or Neighbor Advertisement for ARP or NDP, +// respectively). +type ReachabilityConfirmationFlags struct { + // Solicited indicates that the advertisement was sent in response to a + // reachability probe. + Solicited bool + + // Override indicates that the reachability confirmation should override an + // existing neighbor cache entry and update the cached link-layer address. + // When Override is not set the confirmation will not update a cached + // link-layer address, but will update an existing neighbor cache entry for + // which no link-layer address is known. + Override bool + + // IsRouter indicates that the sender is a router. + IsRouter bool +} + +// NUDHandler communicates external events to the Neighbor Unreachability +// Detection state machine, which is implemented per-interface. This is used by +// network endpoints to inform the Neighbor Cache of probes and confirmations. +type NUDHandler interface { + // HandleProbe processes an incoming neighbor probe (e.g. ARP request or + // Neighbor Solicitation for ARP or NDP, respectively). Validation of the + // probe needs to be performed before calling this function since the + // Neighbor Cache doesn't have access to view the NIC's assigned addresses. + HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress) + + // HandleConfirmation processes an incoming neighbor confirmation (e.g. ARP + // reply or Neighbor Advertisement for ARP or NDP, respectively). + HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) + + // HandleUpperLevelConfirmation processes an incoming upper-level protocol + // (e.g. TCP acknowledgements) reachability confirmation. + HandleUpperLevelConfirmation(addr tcpip.Address) +} + +// NUDConfigurations is the NUD configurations for the netstack. This is used +// by the neighbor cache to operate the NUD state machine on each device in the +// local network. +type NUDConfigurations struct { + // BaseReachableTime is the base duration for computing the random reachable + // time. + // + // Reachable time is the duration for which a neighbor is considered + // reachable after a positive reachability confirmation is received. It is a + // function of uniformly distributed random value between minRandomFactor and + // maxRandomFactor multiplied by baseReachableTime. Using a random component + // eliminates the possibility that Neighbor Unreachability Detection messages + // will synchronize with each other. + // + // After this time, a neighbor entry will transition from REACHABLE to STALE + // state. + // + // Must be greater than 0. + BaseReachableTime time.Duration + + // LearnBaseReachableTime enables learning BaseReachableTime during runtime + // from the neighbor discovery protocol, if supported. + // + // TODO(gvisor.dev/issue/2240): Implement this NUD configuration option. + LearnBaseReachableTime bool + + // MinRandomFactor is the minimum value of the random factor used for + // computing reachable time. + // + // See BaseReachbleTime for more information on computing the reachable time. + // + // Must be greater than 0. + MinRandomFactor float32 + + // MaxRandomFactor is the maximum value of the random factor used for + // computing reachabile time. + // + // See BaseReachbleTime for more information on computing the reachable time. + // + // Must be great than or equal to MinRandomFactor. + MaxRandomFactor float32 + + // RetransmitTimer is the duration between retransmission of reachability + // probes in the PROBE state. + RetransmitTimer time.Duration + + // LearnRetransmitTimer enables learning RetransmitTimer during runtime from + // the neighbor discovery protocol, if supported. + // + // TODO(gvisor.dev/issue/2241): Implement this NUD configuration option. + LearnRetransmitTimer bool + + // DelayFirstProbeTime is the duration to wait for a non-Neighbor-Discovery + // related protocol to reconfirm reachability after entering the DELAY state. + // After this time, a reachability probe will be sent and the entry will + // transition to the PROBE state. + // + // Must be greater than 0. + DelayFirstProbeTime time.Duration + + // MaxMulticastProbes is the number of reachability probes to send before + // concluding negative reachability and deleting the neighbor entry from the + // INCOMPLETE state. + // + // Must be greater than 0. + MaxMulticastProbes uint32 + + // MaxUnicastProbes is the number of reachability probes to send before + // concluding retransmission from within the PROBE state should cease and + // entry SHOULD be deleted. + // + // Must be greater than 0. + MaxUnicastProbes uint32 + + // MaxAnycastDelayTime is the time in which the stack SHOULD delay sending a + // response for a random time between 0 and this time, if the target address + // is an anycast address. + // + // TODO(gvisor.dev/issue/2242): Use this option when sending solicited + // neighbor confirmations to anycast addresses and proxying neighbor + // confirmations. + MaxAnycastDelayTime time.Duration + + // MaxReachabilityConfirmations is the number of unsolicited reachability + // confirmation messages a node MAY send to all-node multicast address when + // it determines its link-layer address has changed. + // + // TODO(gvisor.dev/issue/2246): Discuss if implementation of this NUD + // configuration option is necessary. + MaxReachabilityConfirmations uint32 + + // UnreachableTime describes how long an entry will remain in the FAILED + // state before being removed from the neighbor cache. + UnreachableTime time.Duration +} + +// DefaultNUDConfigurations returns a NUDConfigurations populated with default +// values defined by RFC 4861 section 10. +func DefaultNUDConfigurations() NUDConfigurations { + return NUDConfigurations{ + BaseReachableTime: defaultBaseReachableTime, + LearnBaseReachableTime: true, + MinRandomFactor: defaultMinRandomFactor, + MaxRandomFactor: defaultMaxRandomFactor, + RetransmitTimer: defaultRetransmitTimer, + LearnRetransmitTimer: true, + DelayFirstProbeTime: defaultDelayFirstProbeTime, + MaxMulticastProbes: defaultMaxMulticastProbes, + MaxUnicastProbes: defaultMaxUnicastProbes, + MaxAnycastDelayTime: defaultMaxAnycastDelayTime, + MaxReachabilityConfirmations: defaultMaxReachbilityConfirmations, + UnreachableTime: defaultUnreachableTime, + } +} + +// resetInvalidFields modifies an invalid NDPConfigurations with valid values. +// If invalid values are present in c, the corresponding default values will be +// used instead. This is needed to check, and conditionally fix, user-specified +// NUDConfigurations. +func (c *NUDConfigurations) resetInvalidFields() { + if c.BaseReachableTime < minimumBaseReachableTime { + c.BaseReachableTime = defaultBaseReachableTime + } + if c.MinRandomFactor <= 0 { + c.MinRandomFactor = defaultMinRandomFactor + } + if c.MaxRandomFactor < c.MinRandomFactor { + c.MaxRandomFactor = calcMaxRandomFactor(c.MinRandomFactor) + } + if c.RetransmitTimer < minimumRetransmitTimer { + c.RetransmitTimer = defaultRetransmitTimer + } + if c.DelayFirstProbeTime == 0 { + c.DelayFirstProbeTime = defaultDelayFirstProbeTime + } + if c.MaxMulticastProbes == 0 { + c.MaxMulticastProbes = defaultMaxMulticastProbes + } + if c.MaxUnicastProbes == 0 { + c.MaxUnicastProbes = defaultMaxUnicastProbes + } + if c.UnreachableTime == 0 { + c.UnreachableTime = defaultUnreachableTime + } +} + +// calcMaxRandomFactor calculates the maximum value of the random factor used +// for computing reachable time. This function is necessary for when the +// default specified in RFC 4861 section 10 is less than the current +// MinRandomFactor. +// +// Assumes minRandomFactor is positive since validation of the minimum value +// should come before the validation of the maximum. +func calcMaxRandomFactor(minRandomFactor float32) float32 { + if minRandomFactor > defaultMaxRandomFactor { + return minRandomFactor * 3 + } + return defaultMaxRandomFactor +} + +// A Rand is a source of random numbers. +type Rand interface { + // Float32 returns, as a float32, a pseudo-random number in [0.0,1.0). + Float32() float32 +} + +// NUDState stores states needed for calculating reachable time. +type NUDState struct { + rng Rand + + // mu protects the fields below. + // + // It is necessary for NUDState to handle its own locking since neighbor + // entries may access the NUD state from within the goroutine spawned by + // time.AfterFunc(). This goroutine may run concurrently with the main + // process for controlling the neighbor cache and would otherwise introduce + // race conditions if NUDState was not locked properly. + mu sync.RWMutex + + config NUDConfigurations + + // reachableTime is the duration to wait for a REACHABLE entry to + // transition into STALE after inactivity. This value is calculated with + // the algorithm defined in RFC 4861 section 6.3.2. + reachableTime time.Duration + + expiration time.Time + prevBaseReachableTime time.Duration + prevMinRandomFactor float32 + prevMaxRandomFactor float32 +} + +// NewNUDState returns new NUDState using c as configuration and the specified +// random number generator for use in recomputing ReachableTime. +func NewNUDState(c NUDConfigurations, rng Rand) *NUDState { + s := &NUDState{ + rng: rng, + } + s.config = c + return s +} + +// Config returns the NUD configuration. +func (s *NUDState) Config() NUDConfigurations { + s.mu.RLock() + defer s.mu.RUnlock() + return s.config +} + +// SetConfig replaces the existing NUD configurations with c. +func (s *NUDState) SetConfig(c NUDConfigurations) { + s.mu.Lock() + defer s.mu.Unlock() + s.config = c +} + +// ReachableTime returns the duration to wait for a REACHABLE entry to +// transition into STALE after inactivity. This value is recalculated for new +// values of BaseReachableTime, MinRandomFactor, and MaxRandomFactor using the +// algorithm defined in RFC 4861 section 6.3.2. +func (s *NUDState) ReachableTime() time.Duration { + s.mu.Lock() + defer s.mu.Unlock() + + if time.Now().After(s.expiration) || + s.config.BaseReachableTime != s.prevBaseReachableTime || + s.config.MinRandomFactor != s.prevMinRandomFactor || + s.config.MaxRandomFactor != s.prevMaxRandomFactor { + return s.recomputeReachableTimeLocked() + } + return s.reachableTime +} + +// recomputeReachableTimeLocked forces a recalculation of ReachableTime using +// the algorithm defined in RFC 4861 section 6.3.2. +// +// This SHOULD automatically be invoked during certain situations, as per +// RFC 4861 section 6.3.4: +// +// If the received Reachable Time value is non-zero, the host SHOULD set its +// BaseReachableTime variable to the received value. If the new value +// differs from the previous value, the host SHOULD re-compute a new random +// ReachableTime value. ReachableTime is computed as a uniformly +// distributed random value between MIN_RANDOM_FACTOR and MAX_RANDOM_FACTOR +// times the BaseReachableTime. Using a random component eliminates the +// possibility that Neighbor Unreachability Detection messages will +// synchronize with each other. +// +// In most cases, the advertised Reachable Time value will be the same in +// consecutive Router Advertisements, and a host's BaseReachableTime rarely +// changes. In such cases, an implementation SHOULD ensure that a new +// random value gets re-computed at least once every few hours. +// +// s.mu MUST be locked for writing. +func (s *NUDState) recomputeReachableTimeLocked() time.Duration { + s.prevBaseReachableTime = s.config.BaseReachableTime + s.prevMinRandomFactor = s.config.MinRandomFactor + s.prevMaxRandomFactor = s.config.MaxRandomFactor + + randomFactor := s.config.MinRandomFactor + s.rng.Float32()*(s.config.MaxRandomFactor-s.config.MinRandomFactor) + + // Check for overflow, given that minRandomFactor and maxRandomFactor are + // guaranteed to be positive numbers. + if float32(math.MaxInt64)/randomFactor < float32(s.config.BaseReachableTime) { + s.reachableTime = time.Duration(math.MaxInt64) + } else if randomFactor == 1 { + // Avoid loss of precision when a large base reachable time is used. + s.reachableTime = s.config.BaseReachableTime + } else { + reachableTime := int64(float32(s.config.BaseReachableTime) * randomFactor) + s.reachableTime = time.Duration(reachableTime) + } + + s.expiration = time.Now().Add(2 * time.Hour) + return s.reachableTime +} diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go new file mode 100644 index 000000000..2494ee610 --- /dev/null +++ b/pkg/tcpip/stack/nud_test.go @@ -0,0 +1,795 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack_test + +import ( + "math" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +const ( + defaultBaseReachableTime = 30 * time.Second + minimumBaseReachableTime = time.Millisecond + defaultMinRandomFactor = 0.5 + defaultMaxRandomFactor = 1.5 + defaultRetransmitTimer = time.Second + minimumRetransmitTimer = time.Millisecond + defaultDelayFirstProbeTime = 5 * time.Second + defaultMaxMulticastProbes = 3 + defaultMaxUnicastProbes = 3 + defaultMaxAnycastDelayTime = time.Second + defaultMaxReachbilityConfirmations = 3 + defaultUnreachableTime = 5 * time.Second + + defaultFakeRandomNum = 0.5 +) + +// fakeRand is a deterministic random number generator. +type fakeRand struct { + num float32 +} + +var _ stack.Rand = (*fakeRand)(nil) + +func (f *fakeRand) Float32() float32 { + return f.num +} + +// TestSetNUDConfigurationFailsForBadNICID tests to make sure we get an error if +// we attempt to update NUD configurations using an invalid NICID. +func TestSetNUDConfigurationFailsForBadNICID(t *testing.T) { + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The networking + // stack will only allocate neighbor caches if a protocol providing link + // address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + }) + + // No NIC with ID 1 yet. + config := stack.NUDConfigurations{} + if err := s.SetNUDConfigurations(1, config); err != tcpip.ErrUnknownNICID { + t.Fatalf("got s.SetNDPConfigurations(1, %+v) = %v, want = %s", config, err, tcpip.ErrUnknownNICID) + } +} + +// TestNUDConfigurationFailsForNotSupported tests to make sure we get a +// NotSupported error if we attempt to retrieve NUD configurations when the +// stack doesn't support NUD. +// +// The stack will report to not support NUD if a neighbor cache for a given NIC +// is not allocated. The networking stack will only allocate neighbor caches if +// a protocol providing link address resolution is specified (e.g. ARP, IPv6). +func TestNUDConfigurationFailsForNotSupported(t *testing.T) { + const nicID = 1 + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + NUDConfigs: stack.DefaultNUDConfigurations(), + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if _, err := s.NUDConfigurations(nicID); err != tcpip.ErrNotSupported { + t.Fatalf("got s.NDPConfigurations(%d) = %v, want = %s", nicID, err, tcpip.ErrNotSupported) + } +} + +// TestNUDConfigurationFailsForNotSupported tests to make sure we get a +// NotSupported error if we attempt to set NUD configurations when the stack +// doesn't support NUD. +// +// The stack will report to not support NUD if a neighbor cache for a given NIC +// is not allocated. The networking stack will only allocate neighbor caches if +// a protocol providing link address resolution is specified (e.g. ARP, IPv6). +func TestSetNUDConfigurationFailsForNotSupported(t *testing.T) { + const nicID = 1 + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + NUDConfigs: stack.DefaultNUDConfigurations(), + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + config := stack.NUDConfigurations{} + if err := s.SetNUDConfigurations(nicID, config); err != tcpip.ErrNotSupported { + t.Fatalf("got s.SetNDPConfigurations(%d, %+v) = %v, want = %s", nicID, config, err, tcpip.ErrNotSupported) + } +} + +// TestDefaultNUDConfigurationIsValid verifies that calling +// resetInvalidFields() on the result of DefaultNUDConfigurations() does not +// change anything. DefaultNUDConfigurations() should return a valid +// NUDConfigurations. +func TestDefaultNUDConfigurations(t *testing.T) { + const nicID = 1 + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The networking + // stack will only allocate neighbor caches if a protocol providing link + // address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NUDConfigs: stack.DefaultNUDConfigurations(), + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + c, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got, want := c, stack.DefaultNUDConfigurations(); got != want { + t.Errorf("got stack.NUDConfigurations(%d) = %+v, want = %+v", nicID, got, want) + } +} + +func TestNUDConfigurationsBaseReachableTime(t *testing.T) { + tests := []struct { + name string + baseReachableTime time.Duration + want time.Duration + }{ + // Invalid cases + { + name: "EqualToZero", + baseReachableTime: 0, + want: defaultBaseReachableTime, + }, + // Valid cases + { + name: "MoreThanZero", + baseReachableTime: time.Millisecond, + want: time.Millisecond, + }, + { + name: "MoreThanDefaultBaseReachableTime", + baseReachableTime: 2 * defaultBaseReachableTime, + want: 2 * defaultBaseReachableTime, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + const nicID = 1 + + c := stack.DefaultNUDConfigurations() + c.BaseReachableTime = test.baseReachableTime + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The + // networking stack will only allocate neighbor caches if a protocol + // providing link address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NUDConfigs: c, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + sc, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got := sc.BaseReachableTime; got != test.want { + t.Errorf("got BaseReachableTime = %q, want = %q", got, test.want) + } + }) + } +} + +func TestNUDConfigurationsMinRandomFactor(t *testing.T) { + tests := []struct { + name string + minRandomFactor float32 + want float32 + }{ + // Invalid cases + { + name: "LessThanZero", + minRandomFactor: -1, + want: defaultMinRandomFactor, + }, + { + name: "EqualToZero", + minRandomFactor: 0, + want: defaultMinRandomFactor, + }, + // Valid cases + { + name: "MoreThanZero", + minRandomFactor: 1, + want: 1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + const nicID = 1 + + c := stack.DefaultNUDConfigurations() + c.MinRandomFactor = test.minRandomFactor + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The + // networking stack will only allocate neighbor caches if a protocol + // providing link address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NUDConfigs: c, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + sc, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got := sc.MinRandomFactor; got != test.want { + t.Errorf("got MinRandomFactor = %f, want = %f", got, test.want) + } + }) + } +} + +func TestNUDConfigurationsMaxRandomFactor(t *testing.T) { + tests := []struct { + name string + minRandomFactor float32 + maxRandomFactor float32 + want float32 + }{ + // Invalid cases + { + name: "LessThanZero", + minRandomFactor: defaultMinRandomFactor, + maxRandomFactor: -1, + want: defaultMaxRandomFactor, + }, + { + name: "EqualToZero", + minRandomFactor: defaultMinRandomFactor, + maxRandomFactor: 0, + want: defaultMaxRandomFactor, + }, + { + name: "LessThanMinRandomFactor", + minRandomFactor: defaultMinRandomFactor, + maxRandomFactor: defaultMinRandomFactor * 0.99, + want: defaultMaxRandomFactor, + }, + { + name: "MoreThanMinRandomFactorWhenMinRandomFactorIsLargerThanMaxRandomFactorDefault", + minRandomFactor: defaultMaxRandomFactor * 2, + maxRandomFactor: defaultMaxRandomFactor, + want: defaultMaxRandomFactor * 6, + }, + // Valid cases + { + name: "EqualToMinRandomFactor", + minRandomFactor: defaultMinRandomFactor, + maxRandomFactor: defaultMinRandomFactor, + want: defaultMinRandomFactor, + }, + { + name: "MoreThanMinRandomFactor", + minRandomFactor: defaultMinRandomFactor, + maxRandomFactor: defaultMinRandomFactor * 1.1, + want: defaultMinRandomFactor * 1.1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + const nicID = 1 + + c := stack.DefaultNUDConfigurations() + c.MinRandomFactor = test.minRandomFactor + c.MaxRandomFactor = test.maxRandomFactor + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The + // networking stack will only allocate neighbor caches if a protocol + // providing link address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NUDConfigs: c, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + sc, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got := sc.MaxRandomFactor; got != test.want { + t.Errorf("got MaxRandomFactor = %f, want = %f", got, test.want) + } + }) + } +} + +func TestNUDConfigurationsRetransmitTimer(t *testing.T) { + tests := []struct { + name string + retransmitTimer time.Duration + want time.Duration + }{ + // Invalid cases + { + name: "EqualToZero", + retransmitTimer: 0, + want: defaultRetransmitTimer, + }, + { + name: "LessThanMinimumRetransmitTimer", + retransmitTimer: minimumRetransmitTimer - time.Nanosecond, + want: defaultRetransmitTimer, + }, + // Valid cases + { + name: "EqualToMinimumRetransmitTimer", + retransmitTimer: minimumRetransmitTimer, + want: minimumBaseReachableTime, + }, + { + name: "LargetThanMinimumRetransmitTimer", + retransmitTimer: 2 * minimumBaseReachableTime, + want: 2 * minimumBaseReachableTime, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + const nicID = 1 + + c := stack.DefaultNUDConfigurations() + c.RetransmitTimer = test.retransmitTimer + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The + // networking stack will only allocate neighbor caches if a protocol + // providing link address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NUDConfigs: c, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + sc, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got := sc.RetransmitTimer; got != test.want { + t.Errorf("got RetransmitTimer = %q, want = %q", got, test.want) + } + }) + } +} + +func TestNUDConfigurationsDelayFirstProbeTime(t *testing.T) { + tests := []struct { + name string + delayFirstProbeTime time.Duration + want time.Duration + }{ + // Invalid cases + { + name: "EqualToZero", + delayFirstProbeTime: 0, + want: defaultDelayFirstProbeTime, + }, + // Valid cases + { + name: "MoreThanZero", + delayFirstProbeTime: time.Millisecond, + want: time.Millisecond, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + const nicID = 1 + + c := stack.DefaultNUDConfigurations() + c.DelayFirstProbeTime = test.delayFirstProbeTime + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The + // networking stack will only allocate neighbor caches if a protocol + // providing link address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NUDConfigs: c, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + sc, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got := sc.DelayFirstProbeTime; got != test.want { + t.Errorf("got DelayFirstProbeTime = %q, want = %q", got, test.want) + } + }) + } +} + +func TestNUDConfigurationsMaxMulticastProbes(t *testing.T) { + tests := []struct { + name string + maxMulticastProbes uint32 + want uint32 + }{ + // Invalid cases + { + name: "EqualToZero", + maxMulticastProbes: 0, + want: defaultMaxMulticastProbes, + }, + // Valid cases + { + name: "MoreThanZero", + maxMulticastProbes: 1, + want: 1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + const nicID = 1 + + c := stack.DefaultNUDConfigurations() + c.MaxMulticastProbes = test.maxMulticastProbes + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The + // networking stack will only allocate neighbor caches if a protocol + // providing link address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NUDConfigs: c, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + sc, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got := sc.MaxMulticastProbes; got != test.want { + t.Errorf("got MaxMulticastProbes = %q, want = %q", got, test.want) + } + }) + } +} + +func TestNUDConfigurationsMaxUnicastProbes(t *testing.T) { + tests := []struct { + name string + maxUnicastProbes uint32 + want uint32 + }{ + // Invalid cases + { + name: "EqualToZero", + maxUnicastProbes: 0, + want: defaultMaxUnicastProbes, + }, + // Valid cases + { + name: "MoreThanZero", + maxUnicastProbes: 1, + want: 1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + const nicID = 1 + + c := stack.DefaultNUDConfigurations() + c.MaxUnicastProbes = test.maxUnicastProbes + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The + // networking stack will only allocate neighbor caches if a protocol + // providing link address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NUDConfigs: c, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + sc, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got := sc.MaxUnicastProbes; got != test.want { + t.Errorf("got MaxUnicastProbes = %q, want = %q", got, test.want) + } + }) + } +} + +func TestNUDConfigurationsUnreachableTime(t *testing.T) { + tests := []struct { + name string + unreachableTime time.Duration + want time.Duration + }{ + // Invalid cases + { + name: "EqualToZero", + unreachableTime: 0, + want: defaultUnreachableTime, + }, + // Valid cases + { + name: "MoreThanZero", + unreachableTime: time.Millisecond, + want: time.Millisecond, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + const nicID = 1 + + c := stack.DefaultNUDConfigurations() + c.UnreachableTime = test.unreachableTime + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The + // networking stack will only allocate neighbor caches if a protocol + // providing link address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NUDConfigs: c, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + sc, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got := sc.UnreachableTime; got != test.want { + t.Errorf("got UnreachableTime = %q, want = %q", got, test.want) + } + }) + } +} + +// TestNUDStateReachableTime verifies the correctness of the ReachableTime +// computation. +func TestNUDStateReachableTime(t *testing.T) { + tests := []struct { + name string + baseReachableTime time.Duration + minRandomFactor float32 + maxRandomFactor float32 + want time.Duration + }{ + { + name: "AllZeros", + baseReachableTime: 0, + minRandomFactor: 0, + maxRandomFactor: 0, + want: 0, + }, + { + name: "ZeroMaxRandomFactor", + baseReachableTime: time.Second, + minRandomFactor: 0, + maxRandomFactor: 0, + want: 0, + }, + { + name: "ZeroMinRandomFactor", + baseReachableTime: time.Second, + minRandomFactor: 0, + maxRandomFactor: 1, + want: time.Duration(defaultFakeRandomNum * float32(time.Second)), + }, + { + name: "FractionalRandomFactor", + baseReachableTime: time.Duration(math.MaxInt64), + minRandomFactor: 0.001, + maxRandomFactor: 0.002, + want: time.Duration((0.001 + (0.001 * defaultFakeRandomNum)) * float32(math.MaxInt64)), + }, + { + name: "MinAndMaxRandomFactorsEqual", + baseReachableTime: time.Second, + minRandomFactor: 1, + maxRandomFactor: 1, + want: time.Second, + }, + { + name: "MinAndMaxRandomFactorsDifferent", + baseReachableTime: time.Second, + minRandomFactor: 1, + maxRandomFactor: 2, + want: time.Duration((1.0 + defaultFakeRandomNum) * float32(time.Second)), + }, + { + name: "MaxInt64", + baseReachableTime: time.Duration(math.MaxInt64), + minRandomFactor: 1, + maxRandomFactor: 1, + want: time.Duration(math.MaxInt64), + }, + { + name: "Overflow", + baseReachableTime: time.Duration(math.MaxInt64), + minRandomFactor: 1.5, + maxRandomFactor: 1.5, + want: time.Duration(math.MaxInt64), + }, + { + name: "DoubleOverflow", + baseReachableTime: time.Duration(math.MaxInt64), + minRandomFactor: 2.5, + maxRandomFactor: 2.5, + want: time.Duration(math.MaxInt64), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := stack.NUDConfigurations{ + BaseReachableTime: test.baseReachableTime, + MinRandomFactor: test.minRandomFactor, + MaxRandomFactor: test.maxRandomFactor, + } + // A fake random number generator is used to ensure deterministic + // results. + rng := fakeRand{ + num: defaultFakeRandomNum, + } + s := stack.NewNUDState(c, &rng) + if got, want := s.ReachableTime(), test.want; got != want { + t.Errorf("got ReachableTime = %q, want = %q", got, want) + } + }) + } +} + +// TestNUDStateRecomputeReachableTime exercises the ReachableTime function +// twice to verify recomputation of reachable time when the min random factor, +// max random factor, or base reachable time changes. +func TestNUDStateRecomputeReachableTime(t *testing.T) { + const defaultBase = time.Second + const defaultMin = 2.0 * defaultMaxRandomFactor + const defaultMax = 3.0 * defaultMaxRandomFactor + + tests := []struct { + name string + baseReachableTime time.Duration + minRandomFactor float32 + maxRandomFactor float32 + want time.Duration + }{ + { + name: "BaseReachableTime", + baseReachableTime: 2 * defaultBase, + minRandomFactor: defaultMin, + maxRandomFactor: defaultMax, + want: time.Duration((defaultMin + (defaultMax-defaultMin)*defaultFakeRandomNum) * float32(2*defaultBase)), + }, + { + name: "MinRandomFactor", + baseReachableTime: defaultBase, + minRandomFactor: defaultMax, + maxRandomFactor: defaultMax, + want: time.Duration(defaultMax * float32(defaultBase)), + }, + { + name: "MaxRandomFactor", + baseReachableTime: defaultBase, + minRandomFactor: defaultMin, + maxRandomFactor: defaultMin, + want: time.Duration(defaultMin * float32(defaultBase)), + }, + { + name: "BothRandomFactor", + baseReachableTime: defaultBase, + minRandomFactor: 2 * defaultMin, + maxRandomFactor: 2 * defaultMax, + want: time.Duration((2*defaultMin + (2*defaultMax-2*defaultMin)*defaultFakeRandomNum) * float32(defaultBase)), + }, + { + name: "BaseReachableTimeAndBothRandomFactors", + baseReachableTime: 2 * defaultBase, + minRandomFactor: 2 * defaultMin, + maxRandomFactor: 2 * defaultMax, + want: time.Duration((2*defaultMin + (2*defaultMax-2*defaultMin)*defaultFakeRandomNum) * float32(2*defaultBase)), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := stack.DefaultNUDConfigurations() + c.BaseReachableTime = defaultBase + c.MinRandomFactor = defaultMin + c.MaxRandomFactor = defaultMax + + // A fake random number generator is used to ensure deterministic + // results. + rng := fakeRand{ + num: defaultFakeRandomNum, + } + s := stack.NewNUDState(c, &rng) + old := s.ReachableTime() + + if got, want := s.ReachableTime(), old; got != want { + t.Errorf("got ReachableTime = %q, want = %q", got, want) + } + + // Check for recomputation when changing the min random factor, the max + // random factor, the base reachability time, or any permutation of those + // three options. + c.BaseReachableTime = test.baseReachableTime + c.MinRandomFactor = test.minRandomFactor + c.MaxRandomFactor = test.maxRandomFactor + s.SetConfig(c) + + if got, want := s.ReachableTime(), test.want; got != want { + t.Errorf("got ReachableTime = %q, want = %q", got, want) + } + + // Verify that ReachableTime isn't recomputed when none of the + // configuration options change. The random factor is changed so that if + // a recompution were to occur, ReachableTime would change. + rng.num = defaultFakeRandomNum / 2.0 + if got, want := s.ReachableTime(), test.want; got != want { + t.Errorf("got ReachableTime = %q, want = %q", got, want) + } + }) + } +} diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 36626e051..3f07e4159 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -445,6 +445,9 @@ type Stack struct { // ndpConfigs is the default NDP configurations used by interfaces. ndpConfigs NDPConfigurations + // nudConfigs is the default NUD configurations used by interfaces. + nudConfigs NUDConfigurations + // autoGenIPv6LinkLocal determines whether or not the stack will attempt // to auto-generate an IPv6 link-local address for newly enabled non-loopback // NICs. See the AutoGenIPv6LinkLocal field of Options for more details. @@ -454,6 +457,10 @@ type Stack struct { // integrator NDP related events. ndpDisp NDPDispatcher + // nudDisp is the NUD event dispatcher that is used to send the netstack + // integrator NUD related events. + nudDisp NUDDispatcher + // uniqueIDGenerator is a generator of unique identifiers. uniqueIDGenerator UniqueID @@ -518,6 +525,9 @@ type Options struct { // before assigning an address to a NIC. NDPConfigs NDPConfigurations + // NUDConfigs is the default NUD configurations used by interfaces. + NUDConfigs NUDConfigurations + // AutoGenIPv6LinkLocal determines whether or not the stack will attempt to // auto-generate an IPv6 link-local address for newly enabled non-loopback // NICs. @@ -536,6 +546,10 @@ type Options struct { // receive NDP related events. NDPDisp NDPDispatcher + // NUDDisp is the NUD event dispatcher that an integrator can provide to + // receive NUD related events. + NUDDisp NUDDispatcher + // RawFactory produces raw endpoints. Raw endpoints are enabled only if // this is non-nil. RawFactory RawFactory @@ -670,6 +684,8 @@ func New(opts Options) *Stack { // Make sure opts.NDPConfigs contains valid values only. opts.NDPConfigs.validate() + opts.NUDConfigs.resetInvalidFields() + s := &Stack{ transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), @@ -685,9 +701,11 @@ func New(opts Options) *Stack { icmpRateLimiter: NewICMPRateLimiter(), seed: generateRandUint32(), ndpConfigs: opts.NDPConfigs, + nudConfigs: opts.NUDConfigs, autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal, uniqueIDGenerator: opts.UniqueID, ndpDisp: opts.NDPDisp, + nudDisp: opts.NUDDisp, opaqueIIDOpts: opts.OpaqueIIDOpts, tempIIDSeed: opts.TempIIDSeed, forwarder: newForwardQueue(), @@ -1869,10 +1887,38 @@ func (s *Stack) SetNDPConfigurations(id tcpip.NICID, c NDPConfigurations) *tcpip } nic.setNDPConfigs(c) - return nil } +// NUDConfigurations gets the per-interface NUD configurations. +func (s *Stack) NUDConfigurations(id tcpip.NICID) (NUDConfigurations, *tcpip.Error) { + s.mu.RLock() + nic, ok := s.nics[id] + s.mu.RUnlock() + + if !ok { + return NUDConfigurations{}, tcpip.ErrUnknownNICID + } + + return nic.NUDConfigs() +} + +// SetNUDConfigurations sets the per-interface NUD configurations. +// +// Note, if c contains invalid NUD configuration values, it will be fixed to +// use default values for the erroneous values. +func (s *Stack) SetNUDConfigurations(id tcpip.NICID, c NUDConfigurations) *tcpip.Error { + s.mu.RLock() + nic, ok := s.nics[id] + s.mu.RUnlock() + + if !ok { + return tcpip.ErrUnknownNICID + } + + return nic.setNUDConfigs(c) +} + // HandleNDPRA provides a NIC with ID id a validated NDP Router Advertisement // message that it needs to handle. func (s *Stack) HandleNDPRA(id tcpip.NICID, ip tcpip.Address, ra header.NDPRouterAdvert) *tcpip.Error { -- cgit v1.2.3 From 2a7b2a61e3ea32129c26eeaa6fab3d81a5d8ad6e Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Thu, 11 Jun 2020 20:33:56 -0700 Subject: iptables: support SO_ORIGINAL_DST Envoy (#170) uses this to get the original destination of redirected packets. --- pkg/abi/linux/netfilter.go | 8 +- pkg/abi/linux/socket.go | 4 +- pkg/sentry/socket/netstack/netstack.go | 17 ++++ pkg/sentry/strace/socket.go | 1 + pkg/tcpip/stack/conntrack.go | 26 ++++++ pkg/tcpip/stack/iptables.go | 11 ++- pkg/tcpip/tcpip.go | 4 + pkg/tcpip/transport/tcp/endpoint.go | 11 +++ test/iptables/BUILD | 1 + test/iptables/iptables_test.go | 8 ++ test/iptables/iptables_unsafe.go | 63 ++++++++++++++ test/iptables/iptables_util.go | 51 ++++++++++- test/iptables/nat.go | 152 ++++++++++++++++++++++++++++++++- 13 files changed, 344 insertions(+), 13 deletions(-) create mode 100644 test/iptables/iptables_unsafe.go (limited to 'pkg/tcpip/stack') diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go index a91f9f018..9c27f7bb2 100644 --- a/pkg/abi/linux/netfilter.go +++ b/pkg/abi/linux/netfilter.go @@ -59,7 +59,7 @@ var VerdictStrings = map[int32]string{ NF_RETURN: "RETURN", } -// Socket options. These correspond to values in +// Socket options for SOL_SOCKET. These correspond to values in // include/uapi/linux/netfilter_ipv4/ip_tables.h. const ( IPT_BASE_CTL = 64 @@ -74,6 +74,12 @@ const ( IPT_SO_GET_MAX = IPT_SO_GET_REVISION_TARGET ) +// Socket option for SOL_IP. This corresponds to the value in +// include/uapi/linux/netfilter_ipv4.h. +const ( + SO_ORIGINAL_DST = 80 +) + // Name lengths. These correspond to values in // include/uapi/linux/netfilter/x_tables.h. const ( diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index c24a8216e..d6946bb82 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -239,11 +239,13 @@ const SockAddrMax = 128 type InetAddr [4]byte // SockAddrInet is struct sockaddr_in, from uapi/linux/in.h. +// +// +marshal type SockAddrInet struct { Family uint16 Port uint16 Addr InetAddr - Zero [8]uint8 // pad to sizeof(struct sockaddr). + _ [8]uint8 // pad to sizeof(struct sockaddr). } // InetMulticastRequest is struct ip_mreq, from uapi/linux/in.h. diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index f86e6cd7a..31a168f7e 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1490,6 +1490,10 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (marsha vP := primitive.Int32(boolToInt32(v)) return &vP, nil + case linux.SO_ORIGINAL_DST: + // TODO(gvisor.dev/issue/170): ip6tables. + return nil, syserr.ErrInvalidArgument + default: emitUnimplementedEventIPv6(t, name) } @@ -1600,6 +1604,19 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in vP := primitive.Int32(boolToInt32(v)) return &vP, nil + case linux.SO_ORIGINAL_DST: + if outLen < int(binary.Size(linux.SockAddrInet{})) { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.OriginalDestinationOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress(v)) + return a.(*linux.SockAddrInet), nil + default: emitUnimplementedEventIP(t, name) } diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go index c0512de89..b51c4c941 100644 --- a/pkg/sentry/strace/socket.go +++ b/pkg/sentry/strace/socket.go @@ -521,6 +521,7 @@ var sockOptNames = map[uint64]abi.ValueSet{ linux.IP_ROUTER_ALERT: "IP_ROUTER_ALERT", linux.IP_PKTOPTIONS: "IP_PKTOPTIONS", linux.IP_MTU: "IP_MTU", + linux.SO_ORIGINAL_DST: "SO_ORIGINAL_DST", }, linux.SOL_SOCKET: { linux.SO_ERROR: "SO_ERROR", diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 559a1c4dd..470c265aa 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -240,7 +240,10 @@ func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) { if err != nil { return nil, dirOriginal } + return ct.connForTID(tid) +} +func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) { bucket := ct.bucket(tid) now := time.Now() @@ -604,3 +607,26 @@ func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bo return true } + +func (ct *ConnTrack) originalDst(epID TransportEndpointID) (tcpip.Address, uint16, *tcpip.Error) { + // Lookup the connection. The reply's original destination + // describes the original address. + tid := tupleID{ + srcAddr: epID.LocalAddress, + srcPort: epID.LocalPort, + dstAddr: epID.RemoteAddress, + dstPort: epID.RemotePort, + transProto: header.TCPProtocolNumber, + netProto: header.IPv4ProtocolNumber, + } + conn, _ := ct.connForTID(tid) + if conn == nil { + // Not a tracked connection. + return "", 0, tcpip.ErrNotConnected + } else if conn.manip == manipNone { + // Unmanipulated connection. + return "", 0, tcpip.ErrInvalidOptionValue + } + + return conn.original.dstAddr, conn.original.dstPort, nil +} diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index cbbae4224..110ba073d 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -218,19 +218,16 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr // Many users never configure iptables. Spare them the cost of rule // traversal if rules have never been set. it.mu.RLock() + defer it.mu.RUnlock() if !it.modified { - it.mu.RUnlock() return true } - it.mu.RUnlock() // Packets are manipulated only if connection and matching // NAT rule exists. shouldTrack := it.connections.handlePacket(pkt, hook, gso, r) // Go through each table containing the hook. - it.mu.RLock() - defer it.mu.RUnlock() priorities := it.priorities[hook] for _, tableID := range priorities { // If handlePacket already NATed the packet, we don't need to @@ -418,3 +415,9 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx // All the matchers matched, so run the target. return rule.Target.Action(pkt, &it.connections, hook, gso, r, address) } + +// OriginalDst returns the original destination of redirected connections. It +// returns an error if the connection doesn't exist or isn't redirected. +func (it *IPTables) OriginalDst(epID TransportEndpointID) (tcpip.Address, uint16, *tcpip.Error) { + return it.connections.originalDst(epID) +} diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index a634b9b60..45f59b60f 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -954,6 +954,10 @@ type DefaultTTLOption uint8 // classic BPF filter on a given endpoint. type SocketDetachFilterOption int +// OriginalDestinationOption is used to get the original destination address +// and port of a redirected packet. +type OriginalDestinationOption FullAddress + // IPPacketInfo is the message structure for IP_PKTINFO. // // +stateify savable diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 0f7487963..682687ebe 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -2017,6 +2017,17 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { *o = tcpip.TCPDeferAcceptOption(e.deferAccept) e.UnlockUser() + case *tcpip.OriginalDestinationOption: + ipt := e.stack.IPTables() + addr, port, err := ipt.OriginalDst(e.ID) + if err != nil { + return err + } + *o = tcpip.OriginalDestinationOption{ + Addr: addr, + Port: port, + } + default: return tcpip.ErrUnknownProtocolOption } diff --git a/test/iptables/BUILD b/test/iptables/BUILD index 40b63ebbe..66453772a 100644 --- a/test/iptables/BUILD +++ b/test/iptables/BUILD @@ -9,6 +9,7 @@ go_library( "filter_input.go", "filter_output.go", "iptables.go", + "iptables_unsafe.go", "iptables_util.go", "nat.go", ], diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go index 550b6198a..fda5f694f 100644 --- a/test/iptables/iptables_test.go +++ b/test/iptables/iptables_test.go @@ -371,3 +371,11 @@ func TestFilterAddrs(t *testing.T) { } } } + +func TestNATPreOriginalDst(t *testing.T) { + singleTest(t, NATPreOriginalDst{}) +} + +func TestNATOutOriginalDst(t *testing.T) { + singleTest(t, NATOutOriginalDst{}) +} diff --git a/test/iptables/iptables_unsafe.go b/test/iptables/iptables_unsafe.go new file mode 100644 index 000000000..bd85a8fea --- /dev/null +++ b/test/iptables/iptables_unsafe.go @@ -0,0 +1,63 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package iptables + +import ( + "fmt" + "syscall" + "unsafe" +) + +type originalDstError struct { + errno syscall.Errno +} + +func (e originalDstError) Error() string { + return fmt.Sprintf("errno (%d) when calling getsockopt(SO_ORIGINAL_DST): %v", int(e.errno), e.errno.Error()) +} + +// SO_ORIGINAL_DST gets the original destination of a redirected packet via +// getsockopt. +const SO_ORIGINAL_DST = 80 + +func originalDestination4(connfd int) (syscall.RawSockaddrInet4, error) { + var addr syscall.RawSockaddrInet4 + var addrLen uint32 = syscall.SizeofSockaddrInet4 + if errno := originalDestination(connfd, syscall.SOL_IP, unsafe.Pointer(&addr), &addrLen); errno != 0 { + return syscall.RawSockaddrInet4{}, originalDstError{errno} + } + return addr, nil +} + +func originalDestination6(connfd int) (syscall.RawSockaddrInet6, error) { + var addr syscall.RawSockaddrInet6 + var addrLen uint32 = syscall.SizeofSockaddrInet6 + if errno := originalDestination(connfd, syscall.SOL_IPV6, unsafe.Pointer(&addr), &addrLen); errno != 0 { + return syscall.RawSockaddrInet6{}, originalDstError{errno} + } + return addr, nil +} + +func originalDestination(connfd int, level uintptr, optval unsafe.Pointer, optlen *uint32) syscall.Errno { + _, _, errno := syscall.Syscall6( + syscall.SYS_GETSOCKOPT, + uintptr(connfd), + level, + SO_ORIGINAL_DST, + uintptr(optval), + uintptr(unsafe.Pointer(optlen)), + 0) + return errno +} diff --git a/test/iptables/iptables_util.go b/test/iptables/iptables_util.go index ca80a4b5f..5125fe47b 100644 --- a/test/iptables/iptables_util.go +++ b/test/iptables/iptables_util.go @@ -15,6 +15,8 @@ package iptables import ( + "encoding/binary" + "errors" "fmt" "net" "os/exec" @@ -218,17 +220,58 @@ func filterAddrs(addrs []string, ipv6 bool) []string { // getInterfaceName returns the name of the interface other than loopback. func getInterfaceName() (string, bool) { - var ifname string + iface, ok := getNonLoopbackInterface() + if !ok { + return "", false + } + return iface.Name, true +} + +func getInterfaceAddrs(ipv6 bool) ([]net.IP, error) { + iface, ok := getNonLoopbackInterface() + if !ok { + return nil, errors.New("no non-loopback interface found") + } + addrs, err := iface.Addrs() + if err != nil { + return nil, err + } + + // Get only IPv4 or IPv6 addresses. + ips := make([]net.IP, 0, len(addrs)) + for _, addr := range addrs { + parts := strings.Split(addr.String(), "/") + var ip net.IP + // To16() returns IPv4 addresses as IPv4-mapped IPv6 addresses. + // So we check whether To4() returns nil to test whether the + // address is v4 or v6. + if v4 := net.ParseIP(parts[0]).To4(); ipv6 && v4 == nil { + ip = net.ParseIP(parts[0]).To16() + } else { + ip = v4 + } + if ip != nil { + ips = append(ips, ip) + } + } + return ips, nil +} + +func getNonLoopbackInterface() (net.Interface, bool) { if interfaces, err := net.Interfaces(); err == nil { for _, intf := range interfaces { if intf.Name != "lo" { - ifname = intf.Name - break + return intf, true } } } + return net.Interface{}, false +} - return ifname, ifname != "" +func htons(x uint16) uint16 { + buf := make([]byte, 2) + binary.BigEndian.PutUint16(buf, x) + return binary.LittleEndian.Uint16(buf) } func localIP(ipv6 bool) string { diff --git a/test/iptables/nat.go b/test/iptables/nat.go index ac0d91bb2..b7fea2527 100644 --- a/test/iptables/nat.go +++ b/test/iptables/nat.go @@ -18,12 +18,11 @@ import ( "errors" "fmt" "net" + "syscall" "time" ) -const ( - redirectPort = 42 -) +const redirectPort = 42 func init() { RegisterTestCase(NATPreRedirectUDPPort{}) @@ -42,6 +41,8 @@ func init() { RegisterTestCase(NATOutRedirectInvert{}) RegisterTestCase(NATRedirectRequiresProtocol{}) RegisterTestCase(NATLoopbackSkipsPrerouting{}) + RegisterTestCase(NATPreOriginalDst{}) + RegisterTestCase(NATOutOriginalDst{}) } // NATPreRedirectUDPPort tests that packets are redirected to different port. @@ -471,6 +472,151 @@ func (NATLoopbackSkipsPrerouting) LocalAction(ip net.IP, ipv6 bool) error { return nil } +// NATPreOriginalDst tests that SO_ORIGINAL_DST returns the pre-NAT destination +// of PREROUTING NATted packets. +type NATPreOriginalDst struct{} + +// Name implements TestCase.Name. +func (NATPreOriginalDst) Name() string { + return "NATPreOriginalDst" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATPreOriginalDst) ContainerAction(ip net.IP, ipv6 bool) error { + // Redirect incoming TCP connections to acceptPort. + if err := natTable(ipv6, "-A", "PREROUTING", + "-p", "tcp", + "--destination-port", fmt.Sprintf("%d", dropPort), + "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", acceptPort)); err != nil { + return err + } + + addrs, err := getInterfaceAddrs(ipv6) + if err != nil { + return err + } + return listenForRedirectedConn(ipv6, addrs) +} + +// LocalAction implements TestCase.LocalAction. +func (NATPreOriginalDst) LocalAction(ip net.IP, ipv6 bool) error { + return connectTCP(ip, dropPort, sendloopDuration) +} + +// NATOutOriginalDst tests that SO_ORIGINAL_DST returns the pre-NAT destination +// of OUTBOUND NATted packets. +type NATOutOriginalDst struct{} + +// Name implements TestCase.Name. +func (NATOutOriginalDst) Name() string { + return "NATOutOriginalDst" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATOutOriginalDst) ContainerAction(ip net.IP, ipv6 bool) error { + // Redirect incoming TCP connections to acceptPort. + if err := natTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", acceptPort)); err != nil { + return err + } + + connCh := make(chan error) + go func() { + connCh <- connectTCP(ip, dropPort, sendloopDuration) + }() + + if err := listenForRedirectedConn(ipv6, []net.IP{ip}); err != nil { + return err + } + return <-connCh +} + +// LocalAction implements TestCase.LocalAction. +func (NATOutOriginalDst) LocalAction(ip net.IP, ipv6 bool) error { + // No-op. + return nil +} + +func listenForRedirectedConn(ipv6 bool, originalDsts []net.IP) error { + // The net package doesn't give guarantee access to the connection's + // underlying FD, and thus we cannot call getsockopt. We have to use + // traditional syscalls for SO_ORIGINAL_DST. + + // Create the listening socket, bind, listen, and accept. + family := syscall.AF_INET + if ipv6 { + family = syscall.AF_INET6 + } + sockfd, err := syscall.Socket(family, syscall.SOCK_STREAM, 0) + if err != nil { + return err + } + defer syscall.Close(sockfd) + + var bindAddr syscall.Sockaddr + if ipv6 { + bindAddr = &syscall.SockaddrInet6{ + Port: acceptPort, + Addr: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, // in6addr_any + } + } else { + bindAddr = &syscall.SockaddrInet4{ + Port: acceptPort, + Addr: [4]byte{0, 0, 0, 0}, // INADDR_ANY + } + } + if err := syscall.Bind(sockfd, bindAddr); err != nil { + return err + } + + if err := syscall.Listen(sockfd, 1); err != nil { + return err + } + + connfd, _, err := syscall.Accept(sockfd) + if err != nil { + return err + } + defer syscall.Close(connfd) + + // Verify that, despite listening on acceptPort, SO_ORIGINAL_DST + // indicates the packet was sent to originalDst:dropPort. + if ipv6 { + got, err := originalDestination6(connfd) + if err != nil { + return err + } + // The original destination could be any of our IPs. + for _, dst := range originalDsts { + want := syscall.RawSockaddrInet6{ + Family: syscall.AF_INET6, + Port: htons(dropPort), + } + copy(want.Addr[:], dst.To16()) + if got == want { + return nil + } + } + return fmt.Errorf("SO_ORIGINAL_DST returned %+v, but wanted one of %+v (note: port numbers are in network byte order)", got, originalDsts) + } else { + got, err := originalDestination4(connfd) + if err != nil { + return err + } + // The original destination could be any of our IPs. + for _, dst := range originalDsts { + want := syscall.RawSockaddrInet4{ + Family: syscall.AF_INET, + Port: htons(dropPort), + } + copy(want.Addr[:], dst.To4()) + if got == want { + return nil + } + } + return fmt.Errorf("SO_ORIGINAL_DST returned %+v, but wanted one of %+v (note: port numbers are in network byte order)", got, originalDsts) + } +} + // loopbackTests runs an iptables rule and ensures that packets sent to // dest:dropPort are received by localhost:acceptPort. func loopbackTest(ipv6 bool, dest net.IP, args ...string) error { -- cgit v1.2.3 From 0e6f7a12c29efa52581c38ca30637b133556a6cf Mon Sep 17 00:00:00 2001 From: Nayana Bidari Date: Tue, 4 Aug 2020 20:57:28 -0700 Subject: Update variables for implementation of RACK in TCP RACK (Recent Acknowledgement) is a new loss detection algorithm in TCP. These are the fields which should be stored on connections to implement RACK algorithm. PiperOrigin-RevId: 324948703 --- pkg/tcpip/stack/stack.go | 13 +++++ pkg/tcpip/transport/tcp/BUILD | 3 ++ pkg/tcpip/transport/tcp/accept.go | 2 +- pkg/tcpip/transport/tcp/connect.go | 13 +++-- pkg/tcpip/transport/tcp/endpoint.go | 15 ++++-- pkg/tcpip/transport/tcp/rack.go | 82 ++++++++++++++++++++++++++++++++ pkg/tcpip/transport/tcp/rack_state.go | 29 +++++++++++ pkg/tcpip/transport/tcp/snd.go | 40 ++++++++++------ pkg/tcpip/transport/tcp/tcp_rack_test.go | 74 ++++++++++++++++++++++++++++ 9 files changed, 248 insertions(+), 23 deletions(-) create mode 100644 pkg/tcpip/transport/tcp/rack.go create mode 100644 pkg/tcpip/transport/tcp/rack_state.go create mode 100644 pkg/tcpip/transport/tcp/tcp_rack_test.go (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 3f07e4159..7189e8e7e 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -73,6 +73,16 @@ type TCPCubicState struct { WEst float64 } +// TCPRACKState is used to hold a copy of the internal RACK state when the +// TCPProbeFunc is invoked. +type TCPRACKState struct { + XmitTime time.Time + EndSequence seqnum.Value + FACK seqnum.Value + RTT time.Duration + Reord bool +} + // TCPEndpointID is the unique 4 tuple that identifies a given endpoint. type TCPEndpointID struct { // LocalPort is the local port associated with the endpoint. @@ -212,6 +222,9 @@ type TCPSenderState struct { // Cubic holds the state related to CUBIC congestion control. Cubic TCPCubicState + + // RACKState holds the state related to RACK loss detection algorithm. + RACKState TCPRACKState } // TCPSACKInfo holds TCP SACK related information for a given TCP endpoint. diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index e860ee484..234fb95ce 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -40,6 +40,8 @@ go_library( "endpoint_state.go", "forwarder.go", "protocol.go", + "rack.go", + "rack_state.go", "rcv.go", "rcv_state.go", "reno.go", @@ -83,6 +85,7 @@ go_test( "dual_stack_test.go", "sack_scoreboard_test.go", "tcp_noracedetector_test.go", + "tcp_rack_test.go", "tcp_sack_test.go", "tcp_test.go", "tcp_timestamp_test.go", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 6e00e5526..913ea6535 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -521,7 +521,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { synOpts := header.TCPSynOptions{ WS: -1, TS: opts.TS, - TSVal: tcpTimeStamp(timeStampOffset()), + TSVal: tcpTimeStamp(time.Now(), timeStampOffset()), TSEcr: opts.TSVal, MSS: mssForRoute(&s.route), } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 6e5e55b6f..8dd759ba2 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -1166,13 +1166,18 @@ func (e *endpoint) handleSegments(fastPath bool) *tcpip.Error { return nil } -// handleSegment handles a given segment and notifies the worker goroutine if -// if the connection should be terminated. -func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) { - // Invoke the tcp probe if installed. +func (e *endpoint) probeSegment() { if e.probe != nil { e.probe(e.completeState()) } +} + +// handleSegment handles a given segment and notifies the worker goroutine if +// if the connection should be terminated. +func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) { + // Invoke the tcp probe if installed. The tcp probe function will update + // the TCPEndpointState after the segment is processed. + defer e.probeSegment() if s.flagIsSet(header.TCPFlagRst) { if ok, err := e.handleReset(s); !ok { diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 682687ebe..39ea38fe6 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -2692,15 +2692,14 @@ func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) { // timestamp returns the timestamp value to be used in the TSVal field of the // timestamp option for outgoing TCP segments for a given endpoint. func (e *endpoint) timestamp() uint32 { - return tcpTimeStamp(e.tsOffset) + return tcpTimeStamp(time.Now(), e.tsOffset) } // tcpTimeStamp returns a timestamp offset by the provided offset. This is // not inlined above as it's used when SYN cookies are in use and endpoint // is not created at the time when the SYN cookie is sent. -func tcpTimeStamp(offset uint32) uint32 { - now := time.Now() - return uint32(now.Unix()*1000+int64(now.Nanosecond()/1e6)) + offset +func tcpTimeStamp(curTime time.Time, offset uint32) uint32 { + return uint32(curTime.Unix()*1000+int64(curTime.Nanosecond()/1e6)) + offset } // timeStampOffset returns a randomized timestamp offset to be used when sending @@ -2843,6 +2842,14 @@ func (e *endpoint) completeState() stack.TCPEndpointState { WEst: cubic.wEst, } } + + rc := e.snd.rc + s.Sender.RACKState = stack.TCPRACKState{ + XmitTime: rc.xmitTime, + EndSequence: rc.endSequence, + FACK: rc.fack, + RTT: rc.rtt, + } return s } diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go new file mode 100644 index 000000000..d969ca23a --- /dev/null +++ b/pkg/tcpip/transport/tcp/rack.go @@ -0,0 +1,82 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "time" + + "gvisor.dev/gvisor/pkg/tcpip/seqnum" +) + +// RACK is a loss detection algorithm used in TCP to detect packet loss and +// reordering using transmission timestamp of the packets instead of packet or +// sequence counts. To use RACK, SACK should be enabled on the connection. + +// rackControl stores the rack related fields. +// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-6.1 +// +// +stateify savable +type rackControl struct { + // xmitTime is the latest transmission timestamp of rackControl.seg. + xmitTime time.Time `state:".(unixTime)"` + + // endSequence is the ending TCP sequence number of rackControl.seg. + endSequence seqnum.Value + + // fack is the highest selectively or cumulatively acknowledged + // sequence. + fack seqnum.Value + + // rtt is the RTT of the most recently delivered packet on the + // connection (either cumulatively acknowledged or selectively + // acknowledged) that was not marked invalid as a possible spurious + // retransmission. + rtt time.Duration +} + +// Update will update the RACK related fields when an ACK has been received. +// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 +func (rc *rackControl) Update(seg *segment, ackSeg *segment, srtt time.Duration, offset uint32) { + rtt := time.Now().Sub(seg.xmitTime) + + // If the ACK is for a retransmitted packet, do not update if it is a + // spurious inference which is determined by below checks: + // 1. When Timestamping option is available, if the TSVal is less than the + // transmit time of the most recent retransmitted packet. + // 2. When RTT calculated for the packet is less than the smoothed RTT + // for the connection. + // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 + // step 2 + if seg.xmitCount > 1 { + if ackSeg.parsedOptions.TS && ackSeg.parsedOptions.TSEcr != 0 { + if ackSeg.parsedOptions.TSEcr < tcpTimeStamp(seg.xmitTime, offset) { + return + } + } + if rtt < srtt { + return + } + } + + rc.rtt = rtt + // Update rc.xmitTime and rc.endSequence to the transmit time and + // ending sequence number of the packet which has been acknowledged + // most recently. + endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) + if rc.xmitTime.Before(seg.xmitTime) || (seg.xmitTime.Equal(rc.xmitTime) && rc.endSequence.LessThan(endSeq)) { + rc.xmitTime = seg.xmitTime + rc.endSequence = endSeq + } +} diff --git a/pkg/tcpip/transport/tcp/rack_state.go b/pkg/tcpip/transport/tcp/rack_state.go new file mode 100644 index 000000000..c9dc7e773 --- /dev/null +++ b/pkg/tcpip/transport/tcp/rack_state.go @@ -0,0 +1,29 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "time" +) + +// saveXmitTime is invoked by stateify. +func (rc *rackControl) saveXmitTime() unixTime { + return unixTime{rc.xmitTime.Unix(), rc.xmitTime.UnixNano()} +} + +// loadXmitTime is invoked by stateify. +func (rc *rackControl) loadXmitTime(unix unixTime) { + rc.xmitTime = time.Unix(unix.second, unix.nano) +} diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 5862c32f2..c55589c45 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -191,6 +191,10 @@ type sender struct { // cc is the congestion control algorithm in use for this sender. cc congestionControl + + // rc has the fields needed for implementing RACK loss detection + // algorithm. + rc rackControl } // rtt is a synchronization wrapper used to appease stateify. See the comment @@ -1272,21 +1276,21 @@ func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) { // handleRcvdSegment is called when a segment is received; it is responsible for // updating the send-related state. -func (s *sender) handleRcvdSegment(seg *segment) { +func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // Check if we can extract an RTT measurement from this ack. - if !seg.parsedOptions.TS && s.rttMeasureSeqNum.LessThan(seg.ackNumber) { + if !rcvdSeg.parsedOptions.TS && s.rttMeasureSeqNum.LessThan(rcvdSeg.ackNumber) { s.updateRTO(time.Now().Sub(s.rttMeasureTime)) s.rttMeasureSeqNum = s.sndNxt } // Update Timestamp if required. See RFC7323, section-4.3. - if s.ep.sendTSOk && seg.parsedOptions.TS { - s.ep.updateRecentTimestamp(seg.parsedOptions.TSVal, s.maxSentAck, seg.sequenceNumber) + if s.ep.sendTSOk && rcvdSeg.parsedOptions.TS { + s.ep.updateRecentTimestamp(rcvdSeg.parsedOptions.TSVal, s.maxSentAck, rcvdSeg.sequenceNumber) } // Insert SACKBlock information into our scoreboard. if s.ep.sackPermitted { - for _, sb := range seg.parsedOptions.SACKBlocks { + for _, sb := range rcvdSeg.parsedOptions.SACKBlocks { // Only insert the SACK block if the following holds // true: // * SACK block acks data after the ack number in the @@ -1299,27 +1303,27 @@ func (s *sender) handleRcvdSegment(seg *segment) { // NOTE: This check specifically excludes DSACK blocks // which have start/end before sndUna and are used to // indicate spurious retransmissions. - if seg.ackNumber.LessThan(sb.Start) && s.sndUna.LessThan(sb.Start) && sb.End.LessThanEq(s.sndNxt) && !s.ep.scoreboard.IsSACKED(sb) { + if rcvdSeg.ackNumber.LessThan(sb.Start) && s.sndUna.LessThan(sb.Start) && sb.End.LessThanEq(s.sndNxt) && !s.ep.scoreboard.IsSACKED(sb) { s.ep.scoreboard.Insert(sb) - seg.hasNewSACKInfo = true + rcvdSeg.hasNewSACKInfo = true } } s.SetPipe() } // Count the duplicates and do the fast retransmit if needed. - rtx := s.checkDuplicateAck(seg) + rtx := s.checkDuplicateAck(rcvdSeg) // Stash away the current window size. - s.sndWnd = seg.window + s.sndWnd = rcvdSeg.window - ack := seg.ackNumber + ack := rcvdSeg.ackNumber // Disable zero window probing if remote advertizes a non-zero receive // window. This can be with an ACK to the zero window probe (where the // acknumber refers to the already acknowledged byte) OR to any previously // unacknowledged segment. - if s.zeroWindowProbing && seg.window > 0 && + if s.zeroWindowProbing && rcvdSeg.window > 0 && (ack == s.sndUna || (ack-1).InRange(s.sndUna, s.sndNxt)) { s.disableZeroWindowProbing() } @@ -1344,10 +1348,10 @@ func (s *sender) handleRcvdSegment(seg *segment) { // averaged RTT measurement only if the segment acknowledges // some new data, i.e., only if it advances the left edge of // the send window. - if s.ep.sendTSOk && seg.parsedOptions.TSEcr != 0 { + if s.ep.sendTSOk && rcvdSeg.parsedOptions.TSEcr != 0 { // TSVal/Ecr values sent by Netstack are at a millisecond // granularity. - elapsed := time.Duration(s.ep.timestamp()-seg.parsedOptions.TSEcr) * time.Millisecond + elapsed := time.Duration(s.ep.timestamp()-rcvdSeg.parsedOptions.TSEcr) * time.Millisecond s.updateRTO(elapsed) } @@ -1361,6 +1365,9 @@ func (s *sender) handleRcvdSegment(seg *segment) { ackLeft := acked originalOutstanding := s.outstanding + s.rtt.Lock() + srtt := s.rtt.srtt + s.rtt.Unlock() for ackLeft > 0 { // We use logicalLen here because we can have FIN // segments (which are always at the end of list) that @@ -1380,6 +1387,11 @@ func (s *sender) handleRcvdSegment(seg *segment) { s.writeNext = seg.Next() } + // Update the RACK fields if SACK is enabled. + if s.ep.sackPermitted { + s.rc.Update(seg, rcvdSeg, srtt, s.ep.tsOffset) + } + s.writeList.Remove(seg) // if SACK is enabled then Only reduce outstanding if @@ -1435,7 +1447,7 @@ func (s *sender) handleRcvdSegment(seg *segment) { // that the window opened up, or the congestion window was inflated due // to a duplicate ack during fast recovery. This will also re-enable // the retransmit timer if needed. - if !s.ep.sackPermitted || s.fr.active || s.dupAckCount == 0 || seg.hasNewSACKInfo { + if !s.ep.sackPermitted || s.fr.active || s.dupAckCount == 0 || rcvdSeg.hasNewSACKInfo { s.sendData() } } diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go new file mode 100644 index 000000000..e03f101e8 --- /dev/null +++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go @@ -0,0 +1,74 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_test + +import ( + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" +) + +// TestRACKUpdate tests the RACK related fields are updated when an ACK is +// received on a SACK enabled connection. +func TestRACKUpdate(t *testing.T) { + const maxPayload = 10 + const tsOptionSize = 12 + const maxTCPOptionSize = 40 + + c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxTCPOptionSize+maxPayload)) + defer c.Cleanup() + + var xmitTime time.Time + c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { + // Validate that the endpoint Sender.RACKState is what we expect. + if state.Sender.RACKState.XmitTime.Before(xmitTime) { + t.Fatalf("RACK transmit time failed to update when an ACK is received") + } + + gotSeq := state.Sender.RACKState.EndSequence + wantSeq := state.Sender.SndNxt + if !gotSeq.LessThanEq(wantSeq) || gotSeq.LessThan(wantSeq) { + t.Fatalf("RACK sequence number failed to update, got: %v, but want: %v", gotSeq, wantSeq) + } + + if state.Sender.RACKState.RTT == 0 { + t.Fatalf("RACK RTT failed to update when an ACK is received") + } + }) + setStackSACKPermitted(t, c, true) + createConnectedWithSACKAndTS(c) + + data := buffer.NewView(maxPayload) + for i := range data { + data[i] = byte(i) + } + + // Write the data. + xmitTime = time.Now() + if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + bytesRead := 0 + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) + bytesRead += maxPayload + c.SendAck(790, bytesRead) + time.Sleep(200 * time.Millisecond) +} -- cgit v1.2.3 From e7b232a5b881c8930db968a7709ebd93cb2cab04 Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Wed, 5 Aug 2020 14:09:26 -0700 Subject: Prefer RLock over Lock in functions that don't need Lock(). Updates #231 PiperOrigin-RevId: 325097683 --- pkg/tcpip/stack/stack.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 7189e8e7e..5b19c5d59 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1985,8 +1985,8 @@ func generateRandInt64() int64 { // FindNetworkEndpoint returns the network endpoint for the given address. func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, address tcpip.Address) (NetworkEndpoint, *tcpip.Error) { - s.mu.Lock() - defer s.mu.Unlock() + s.mu.RLock() + defer s.mu.RUnlock() for _, nic := range s.nics { id := NetworkEndpointID{address} @@ -2005,8 +2005,8 @@ func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, addres // FindNICNameFromID returns the name of the nic for the given NICID. func (s *Stack) FindNICNameFromID(id tcpip.NICID) string { - s.mu.Lock() - defer s.mu.Unlock() + s.mu.RLock() + defer s.mu.RUnlock() nic, ok := s.nics[id] if !ok { -- cgit v1.2.3 From 90a2d4e8238a9a92b77d363439485d3e8b2211ac Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Wed, 5 Aug 2020 17:30:39 -0700 Subject: Support receiving broadcast IPv4 packets Test: integration_test.TestIncomingSubnetBroadcast PiperOrigin-RevId: 325135617 --- pkg/tcpip/stack/nic.go | 48 ++++ pkg/tcpip/stack/stack_test.go | 2 +- pkg/tcpip/tests/integration/BUILD | 21 ++ .../tests/integration/multicast_broadcast_test.go | 254 +++++++++++++++++++++ 4 files changed, 324 insertions(+), 1 deletion(-) create mode 100644 pkg/tcpip/tests/integration/BUILD create mode 100644 pkg/tcpip/tests/integration/multicast_broadcast_test.go (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index f21066fce..ae4d241de 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -609,6 +609,9 @@ func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.A // If none exists a temporary one may be created if we are in promiscuous mode // or spoofing. Promiscuous mode will only be checked if promiscuous is true. // Similarly, spoofing will only be checked if spoofing is true. +// +// If the address is the IPv4 broadcast address for an endpoint's network, that +// endpoint will be returned. func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, tempRef getRefBehaviour) *referencedNetworkEndpoint { n.mu.RLock() @@ -633,6 +636,16 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t } } + // Check if address is a broadcast address for the endpoint's network. + // + // Only IPv4 has a notion of broadcast addresses. + if protocol == header.IPv4ProtocolNumber { + if ref := n.getRefForBroadcastRLocked(address); ref != nil { + n.mu.RUnlock() + return ref + } + } + // A usable reference was not found, create a temporary one if requested by // the caller or if the address is found in the NIC's subnets. createTempEP := spoofingOrPromiscuous @@ -670,8 +683,34 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t return ref } +// getRefForBroadcastLocked returns an endpoint where address is the IPv4 +// broadcast address for the endpoint's network. +// +// n.mu MUST be read locked. +func (n *NIC) getRefForBroadcastRLocked(address tcpip.Address) *referencedNetworkEndpoint { + for _, ref := range n.mu.endpoints { + // Only IPv4 has a notion of broadcast addresses. + if ref.protocol != header.IPv4ProtocolNumber { + continue + } + + addr := ref.addrWithPrefix() + subnet := addr.Subnet() + if subnet.IsBroadcast(address) && ref.tryIncRef() { + return ref + } + } + + return nil +} + /// getRefOrCreateTempLocked returns an existing endpoint for address or creates /// and returns a temporary endpoint. +// +// If the address is the IPv4 broadcast address for an endpoint's network, that +// endpoint will be returned. +// +// n.mu must be write locked. func (n *NIC) getRefOrCreateTempLocked(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint { if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok { // No need to check the type as we are ok with expired endpoints at this @@ -685,6 +724,15 @@ func (n *NIC) getRefOrCreateTempLocked(protocol tcpip.NetworkProtocolNumber, add n.removeEndpointLocked(ref) } + // Check if address is a broadcast address for an endpoint's network. + // + // Only IPv4 has a notion of broadcast addresses. + if protocol == header.IPv4ProtocolNumber { + if ref := n.getRefForBroadcastRLocked(address); ref != nil { + return ref + } + } + // Add a new temporary endpoint. netProto, ok := n.stack.networkProtocols[protocol] if !ok { diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index f22062889..63ddbc44e 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -1704,7 +1704,7 @@ func testNicForAddressRange(t *testing.T, nicID tcpip.NICID, s *stack.Stack, sub // Trying the next address should always fail since it is outside the range. if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addrBytes)); gotNicID != 0 { - t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addrBytes), gotNicID, 0) + t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = 0", fakeNetNumber, tcpip.Address(addrBytes), gotNicID) } } diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD new file mode 100644 index 000000000..7fff30462 --- /dev/null +++ b/pkg/tcpip/tests/integration/BUILD @@ -0,0 +1,21 @@ +load("//tools:defs.bzl", "go_test") + +package(licenses = ["notice"]) + +go_test( + name = "integration_test", + size = "small", + srcs = ["multicast_broadcast_test.go"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/network/ipv6", + "//pkg/tcpip/stack", + "//pkg/tcpip/transport/udp", + "//pkg/waiter", + "@com_github_google_go_cmp//cmp:go_default_library", + ], +) diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go new file mode 100644 index 000000000..4a860a805 --- /dev/null +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -0,0 +1,254 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package integration_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" +) + +const defaultMTU = 1280 + +func TestIncomingSubnetBroadcast(t *testing.T) { + const ( + nicID = 1 + remotePort = 5555 + localPort = 80 + ttl = 255 + ) + + data := []byte{1, 2, 3, 4} + + // Local IPv4 subnet: 192.168.1.58/24 + ipv4Addr := tcpip.AddressWithPrefix{ + Address: "\xc0\xa8\x01\x3a", + PrefixLen: 24, + } + ipv4Subnet := ipv4Addr.Subnet() + ipv4SubnetBcast := ipv4Subnet.Broadcast() + + // Local IPv6 subnet: 200a::1/64 + ipv6Addr := tcpip.AddressWithPrefix{ + Address: "\x20\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + PrefixLen: 64, + } + ipv6Subnet := ipv6Addr.Subnet() + ipv6SubnetBcast := ipv6Subnet.Broadcast() + + // Remote addrs. + remoteIPv4Addr := tcpip.Address("\x64\x0a\x7b\x18") + remoteIPv6Addr := tcpip.Address("\x20\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + + rxIPv4UDP := func(e *channel.Endpoint, dst tcpip.Address) { + payloadLen := header.UDPMinimumSize + len(data) + totalLen := header.IPv4MinimumSize + payloadLen + hdr := buffer.NewPrependable(totalLen) + u := header.UDP(hdr.Prepend(payloadLen)) + u.Encode(&header.UDPFields{ + SrcPort: remotePort, + DstPort: localPort, + Length: uint16(payloadLen), + }) + copy(u.Payload(), data) + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, remoteIPv4Addr, dst, uint16(payloadLen)) + sum = header.Checksum(data, sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: uint16(totalLen), + Protocol: uint8(udp.ProtocolNumber), + TTL: ttl, + SrcAddr: remoteIPv4Addr, + DstAddr: dst, + }) + + e.InjectInbound(header.IPv4ProtocolNumber, &stack.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + } + + rxIPv6UDP := func(e *channel.Endpoint, dst tcpip.Address) { + payloadLen := header.UDPMinimumSize + len(data) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadLen) + u := header.UDP(hdr.Prepend(payloadLen)) + u.Encode(&header.UDPFields{ + SrcPort: remotePort, + DstPort: localPort, + Length: uint16(payloadLen), + }) + copy(u.Payload(), data) + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, remoteIPv6Addr, dst, uint16(payloadLen)) + sum = header.Checksum(data, sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLen), + NextHeader: uint8(udp.ProtocolNumber), + HopLimit: ttl, + SrcAddr: remoteIPv6Addr, + DstAddr: dst, + }) + + e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + } + + tests := []struct { + name string + bindAddr tcpip.Address + dstAddr tcpip.Address + expectRx bool + }{ + { + name: "IPv4 unicast binding to unicast", + bindAddr: ipv4Addr.Address, + dstAddr: ipv4Addr.Address, + expectRx: true, + }, + { + name: "IPv4 unicast binding to broadcast", + bindAddr: header.IPv4Broadcast, + dstAddr: ipv4Addr.Address, + expectRx: false, + }, + { + name: "IPv4 unicast binding to wildcard", + dstAddr: ipv4Addr.Address, + expectRx: true, + }, + + { + name: "IPv4 directed broadcast binding to subnet broadcast", + bindAddr: ipv4SubnetBcast, + dstAddr: ipv4SubnetBcast, + expectRx: true, + }, + { + name: "IPv4 directed broadcast binding to broadcast", + bindAddr: header.IPv4Broadcast, + dstAddr: ipv4SubnetBcast, + expectRx: false, + }, + { + name: "IPv4 directed broadcast binding to wildcard", + dstAddr: ipv4SubnetBcast, + expectRx: true, + }, + + { + name: "IPv4 broadcast binding to broadcast", + bindAddr: header.IPv4Broadcast, + dstAddr: header.IPv4Broadcast, + expectRx: true, + }, + { + name: "IPv4 broadcast binding to subnet broadcast", + bindAddr: ipv4SubnetBcast, + dstAddr: header.IPv4Broadcast, + expectRx: false, + }, + { + name: "IPv4 broadcast binding to wildcard", + dstAddr: ipv4SubnetBcast, + expectRx: true, + }, + + // IPv6 has no notion of a broadcast. + { + name: "IPv6 unicast binding to wildcard", + dstAddr: ipv6Addr.Address, + expectRx: true, + }, + { + name: "IPv6 broadcast-like address binding to wildcard", + dstAddr: ipv6SubnetBcast, + expectRx: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + }) + e := channel.New(0, defaultMTU, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr} + if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv4ProtoAddr, err) + } + ipv6ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: ipv6Addr} + if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv6ProtoAddr, err) + } + + var netproto tcpip.NetworkProtocolNumber + var rxUDP func(*channel.Endpoint, tcpip.Address) + switch l := len(test.dstAddr); l { + case header.IPv4AddressSize: + netproto = header.IPv4ProtocolNumber + rxUDP = rxIPv4UDP + case header.IPv6AddressSize: + netproto = header.IPv6ProtocolNumber + rxUDP = rxIPv6UDP + default: + t.Fatalf("got unexpected address length = %d bytes", l) + } + + wq := waiter.Queue{} + ep, err := s.NewEndpoint(udp.ProtocolNumber, netproto, &wq) + if err != nil { + t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, netproto, err) + } + defer ep.Close() + + bindAddr := tcpip.FullAddress{Addr: test.bindAddr, Port: localPort} + if err := ep.Bind(bindAddr); err != nil { + t.Fatalf("ep.Bind(%+v): %s", bindAddr, err) + } + + rxUDP(e, test.dstAddr) + if gotPayload, _, err := ep.Read(nil); test.expectRx { + if err != nil { + t.Fatalf("Read(nil): %s", err) + } + if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" { + t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) + } + } else { + if err != tcpip.ErrWouldBlock { + t.Fatalf("got Read(nil) = (%x, _, %v), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock) + } + } + }) + } +} -- cgit v1.2.3 From fc4dd3ef455975a033714052b12ebebc85e937d5 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Thu, 6 Aug 2020 01:29:32 -0700 Subject: Join IPv4 all-systems group on NIC enable Test: - stack_test.TestJoinLeaveMulticastOnNICEnableDisable - integration_test.TestIncomingMulticastAndBroadcast PiperOrigin-RevId: 325185259 --- pkg/tcpip/header/ipv4.go | 5 + pkg/tcpip/stack/ndp_test.go | 8 +- pkg/tcpip/stack/nic.go | 12 ++ pkg/tcpip/stack/stack_test.go | 191 +++++++++++++-------- .../tests/integration/multicast_broadcast_test.go | 22 ++- test/packetimpact/tests/BUILD | 2 - 6 files changed, 161 insertions(+), 79 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index 62ac932bb..d0d1efd0d 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -101,6 +101,11 @@ const ( // IPv4Version is the version of the ipv4 protocol. IPv4Version = 4 + // IPv4AllSystems is the all systems IPv4 multicast address as per + // IANA's IPv4 Multicast Address Space Registry. See + // https://www.iana.org/assignments/multicast-addresses/multicast-addresses.xhtml. + IPv4AllSystems tcpip.Address = "\xe0\x00\x00\x01" + // IPv4Broadcast is the broadcast address of the IPv4 procotol. IPv4Broadcast tcpip.Address = "\xff\xff\xff\xff" diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 644ba7c33..5d286ccbc 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -1689,13 +1689,7 @@ func containsV6Addr(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) AddressWithPrefix: item, } - for _, i := range list { - if i == protocolAddress { - return true - } - } - - return false + return containsAddr(list, protocolAddress) } // TestNoAutoGenAddr tests that SLAAC is not performed when configured not to. diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index ae4d241de..eaaf756cd 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -217,6 +217,11 @@ func (n *NIC) disableLocked() *tcpip.Error { } if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok { + // The NIC may have already left the multicast group. + if err := n.leaveGroupLocked(header.IPv4AllSystems, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress { + return err + } + // The address may have already been removed. if err := n.removePermanentAddressLocked(ipv4BroadcastAddr.AddressWithPrefix.Address); err != nil && err != tcpip.ErrBadLocalAddress { return err @@ -255,6 +260,13 @@ func (n *NIC) enable() *tcpip.Error { if _, err := n.addAddressLocked(ipv4BroadcastAddr, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil { return err } + + // As per RFC 1122 section 3.3.7, all hosts should join the all-hosts + // multicast group. Note, the IANA calls the all-hosts multicast group the + // all-systems multicast group. + if err := n.joinGroupLocked(header.IPv4ProtocolNumber, header.IPv4AllSystems); err != nil { + return err + } } // Join the IPv6 All-Nodes Multicast group if the stack is configured to diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 63ddbc44e..0b6deda02 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -277,6 +277,17 @@ func (l *linkEPWithMockedAttach) isAttached() bool { return l.attached } +// Checks to see if list contains an address. +func containsAddr(list []tcpip.ProtocolAddress, item tcpip.ProtocolAddress) bool { + for _, i := range list { + if i == item { + return true + } + } + + return false +} + func TestNetworkReceive(t *testing.T) { // Create a stack with the fake network protocol, one nic, and two // addresses attached to it: 1 & 2. @@ -3089,6 +3100,13 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { func TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable(t *testing.T) { const nicID = 1 + broadcastAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: header.IPv4Broadcast, + PrefixLen: 32, + }, + } e := loopback.New() s := stack.New(stack.Options{ @@ -3099,49 +3117,41 @@ func TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable(t *testing.T) { t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err) } - allStackAddrs := s.AllAddresses() - allNICAddrs, ok := allStackAddrs[nicID] - if !ok { - t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) - } - if l := len(allNICAddrs); l != 0 { - t.Fatalf("got len(allNICAddrs) = %d, want = 0", l) + { + allStackAddrs := s.AllAddresses() + if allNICAddrs, ok := allStackAddrs[nicID]; !ok { + t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) + } else if containsAddr(allNICAddrs, broadcastAddr) { + t.Fatalf("got allNICAddrs = %+v, don't want = %+v", allNICAddrs, broadcastAddr) + } } // Enabling the NIC should add the IPv4 broadcast address. if err := s.EnableNIC(nicID); err != nil { t.Fatalf("s.EnableNIC(%d): %s", nicID, err) } - allStackAddrs = s.AllAddresses() - allNICAddrs, ok = allStackAddrs[nicID] - if !ok { - t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) - } - if l := len(allNICAddrs); l != 1 { - t.Fatalf("got len(allNICAddrs) = %d, want = 1", l) - } - want := tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: header.IPv4Broadcast, - PrefixLen: 32, - }, - } - if allNICAddrs[0] != want { - t.Fatalf("got allNICAddrs[0] = %+v, want = %+v", allNICAddrs[0], want) + + { + allStackAddrs := s.AllAddresses() + if allNICAddrs, ok := allStackAddrs[nicID]; !ok { + t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) + } else if !containsAddr(allNICAddrs, broadcastAddr) { + t.Fatalf("got allNICAddrs = %+v, want = %+v", allNICAddrs, broadcastAddr) + } } // Disabling the NIC should remove the IPv4 broadcast address. if err := s.DisableNIC(nicID); err != nil { t.Fatalf("s.DisableNIC(%d): %s", nicID, err) } - allStackAddrs = s.AllAddresses() - allNICAddrs, ok = allStackAddrs[nicID] - if !ok { - t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) - } - if l := len(allNICAddrs); l != 0 { - t.Fatalf("got len(allNICAddrs) = %d, want = 0", l) + + { + allStackAddrs := s.AllAddresses() + if allNICAddrs, ok := allStackAddrs[nicID]; !ok { + t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) + } else if containsAddr(allNICAddrs, broadcastAddr) { + t.Fatalf("got allNICAddrs = %+v, don't want = %+v", allNICAddrs, broadcastAddr) + } } } @@ -3189,50 +3199,93 @@ func TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval(t *testing.T) { } } -func TestJoinLeaveAllNodesMulticastOnNICEnableDisable(t *testing.T) { +func TestJoinLeaveMulticastOnNICEnableDisable(t *testing.T) { const nicID = 1 - e := loopback.New() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - }) - nicOpts := stack.NICOptions{Disabled: true} - if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { - t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err) + tests := []struct { + name string + proto tcpip.NetworkProtocolNumber + addr tcpip.Address + }{ + { + name: "IPv6 All-Nodes", + proto: header.IPv6ProtocolNumber, + addr: header.IPv6AllNodesMulticastAddress, + }, + { + name: "IPv4 All-Systems", + proto: header.IPv4ProtocolNumber, + addr: header.IPv4AllSystems, + }, } - // Should not be in the IPv6 all-nodes multicast group yet because the NIC has - // not been enabled yet. - isInGroup, err := s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress) - if err != nil { - t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err) - } - if isInGroup { - t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, header.IPv6AllNodesMulticastAddress) - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e := loopback.New() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + }) + nicOpts := stack.NICOptions{Disabled: true} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err) + } - // The all-nodes multicast group should be joined when the NIC is enabled. - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - isInGroup, err = s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress) - if err != nil { - t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err) - } - if !isInGroup { - t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, header.IPv6AllNodesMulticastAddress) - } + // Should not be in the multicast group yet because the NIC has not been + // enabled yet. + if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) + } else if isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, test.addr) + } - // The all-nodes multicast group should be left when the NIC is disabled. - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID, err) - } - isInGroup, err = s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress) - if err != nil { - t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err) - } - if isInGroup { - t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, header.IPv6AllNodesMulticastAddress) + // The all-nodes multicast group should be joined when the NIC is enabled. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + + if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) + } else if !isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, test.addr) + } + + // The multicast group should be left when the NIC is disabled. + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + + if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) + } else if isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, test.addr) + } + + // The all-nodes multicast group should be joined when the NIC is enabled. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + + if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) + } else if !isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, test.addr) + } + + // Leaving the group before disabling the NIC should not cause an error. + if err := s.LeaveGroup(test.proto, nicID, test.addr); err != nil { + t.Fatalf("s.LeaveGroup(%d, %d, %s): %s", test.proto, nicID, test.addr, err) + } + + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + + if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) + } else if isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, test.addr) + } + }) } } diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index 4a860a805..d9b2d147a 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -31,7 +31,9 @@ import ( const defaultMTU = 1280 -func TestIncomingSubnetBroadcast(t *testing.T) { +// TestIncomingMulticastAndBroadcast tests receiving a packet destined to some +// multicast or broadcast address. +func TestIncomingMulticastAndBroadcast(t *testing.T) { const ( nicID = 1 remotePort = 5555 @@ -179,6 +181,24 @@ func TestIncomingSubnetBroadcast(t *testing.T) { expectRx: true, }, + { + name: "IPv4 all-systems multicast binding to all-systems multicast", + bindAddr: header.IPv4AllSystems, + dstAddr: header.IPv4AllSystems, + expectRx: true, + }, + { + name: "IPv4 all-systems multicast binding to wildcard", + dstAddr: header.IPv4AllSystems, + expectRx: true, + }, + { + name: "IPv4 all-systems multicast binding to unicast", + bindAddr: ipv4Addr.Address, + dstAddr: header.IPv4AllSystems, + expectRx: false, + }, + // IPv6 has no notion of a broadcast. { name: "IPv6 unicast binding to wildcard", diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD index 0c2a05380..74658fea0 100644 --- a/test/packetimpact/tests/BUILD +++ b/test/packetimpact/tests/BUILD @@ -40,8 +40,6 @@ packetimpact_go_test( packetimpact_go_test( name = "udp_recv_mcast_bcast", srcs = ["udp_recv_mcast_bcast_test.go"], - # TODO(b/152813495): Fix netstack then remove the line below. - expect_netstack_failure = True, deps = [ "//pkg/tcpip", "//pkg/tcpip/header", -- cgit v1.2.3 From 94447aeab3d20400680f624e4b84e7b6fc0aae0b Mon Sep 17 00:00:00 2001 From: Sam Balana Date: Fri, 7 Aug 2020 15:05:13 -0700 Subject: Fix panic during Address Resolution of neighbor entry created by NS When a Neighbor Solicitation is received, a neighbor entry is created with the remote host's link layer address, but without a link layer address resolver. If the host decides to send a packet addressed to the IP address of that neighbor entry, Address Resolution starts with a nil pointer to the link layer address resolver. This causes the netstack to panic and crash. This change ensures that when a packet is sent in that situation, the link layer address resolver will be set before Address Resolution begins. Tests: pkg/tcpip/stack:stack_test + TestEntryUnknownToStaleToProbeToReachable - TestNeighborCacheEntryNoLinkAddress Updates #1889 Updates #1894 Updates #1895 Updates #1947 Updates #1948 Updates #1949 Updates #1950 PiperOrigin-RevId: 325516471 --- pkg/tcpip/stack/neighbor_cache.go | 22 ++++--- pkg/tcpip/stack/neighbor_cache_test.go | 32 +---------- pkg/tcpip/stack/neighbor_entry_test.go | 102 ++++++++++++++++++++++++++++++++- pkg/tcpip/stack/nud.go | 2 +- 4 files changed, 115 insertions(+), 43 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go index 1d37716c2..27e1feec0 100644 --- a/pkg/tcpip/stack/neighbor_cache.go +++ b/pkg/tcpip/stack/neighbor_cache.go @@ -115,17 +115,15 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr, localAddr tcpip.Address, li // channel is returned for the top level caller to block. Channel is closed // once address resolution is complete (success or not). func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, w *sleep.Waker) (NeighborEntry, <-chan struct{}, *tcpip.Error) { - if linkRes != nil { - if linkAddr, ok := linkRes.ResolveStaticAddress(remoteAddr); ok { - e := NeighborEntry{ - Addr: remoteAddr, - LocalAddr: localAddr, - LinkAddr: linkAddr, - State: Static, - UpdatedAt: time.Now(), - } - return e, nil, nil + if linkAddr, ok := linkRes.ResolveStaticAddress(remoteAddr); ok { + e := NeighborEntry{ + Addr: remoteAddr, + LocalAddr: localAddr, + LinkAddr: linkAddr, + State: Static, + UpdatedAt: time.Now(), } + return e, nil, nil } entry := n.getOrCreateEntry(remoteAddr, localAddr, linkRes) @@ -289,8 +287,8 @@ func (n *neighborCache) setConfig(config NUDConfigurations) { // HandleProbe implements NUDHandler.HandleProbe by following the logic defined // in RFC 4861 section 7.2.3. Validation of the probe is expected to be handled // by the caller. -func (n *neighborCache) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress) { - entry := n.getOrCreateEntry(remoteAddr, localAddr, nil) +func (n *neighborCache) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) { + entry := n.getOrCreateEntry(remoteAddr, localAddr, linkRes) entry.mu.Lock() entry.handleProbeLocked(remoteLinkAddr) entry.mu.Unlock() diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index 4cb2c9c6b..b4fa69e3e 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -335,32 +335,6 @@ func TestNeighborCacheEntry(t *testing.T) { } } -// TestNeighborCacheEntryNoLinkAddress verifies calling entry() without a -// LinkAddressResolver returns ErrNoLinkAddress. -func TestNeighborCacheEntryNoLinkAddress(t *testing.T) { - nudDisp := testNUDDispatcher{} - c := DefaultNUDConfigurations() - clock := newFakeClock() - neigh := newTestNeighborCache(&nudDisp, c, clock) - store := newTestEntryStore() - - entry, ok := store.entry(0) - if !ok { - t.Fatalf("store.entry(0) not found") - } - _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, nil, nil) - if err != tcpip.ErrNoLinkAddress { - t.Errorf("got neigh.entry(%s, %s, nil, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress) - } - - // No events should have been dispatched. - nudDisp.mu.Lock() - defer nudDisp.mu.Unlock() - if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) - } -} - func TestNeighborCacheRemoveEntry(t *testing.T) { config := DefaultNUDConfigurations() @@ -1048,9 +1022,9 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { t.Fatalf("c.store.entry(0) not found") } c.neigh.addStaticEntry(entry.Addr, entry.LinkAddr) - e, _, err := c.neigh.entry(entry.Addr, "", nil, nil) + e, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil) if err != nil { - t.Errorf("unexpected error from c.neigh.entry(%s, \"\", nil nil): %s", entry.Addr, err) + t.Errorf("unexpected error from c.neigh.entry(%s, \"\", _, nil): %s", entry.Addr, err) } want := NeighborEntry{ Addr: entry.Addr, @@ -1059,7 +1033,7 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { State: Static, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("c.neigh.entry(%s, \"\", nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + t.Errorf("c.neigh.entry(%s, \"\", _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } wantEvents := []testEntryEventInfo{ diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index 08c9ccd25..b769fb2fa 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -236,7 +236,7 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e rng := rand.New(rand.NewSource(time.Now().UnixNano())) nudState := NewNUDState(c, rng) linkRes := entryTestLinkResolver{} - entry := newNeighborEntry(&nic, entryTestAddr1, entryTestAddr2, nudState, &linkRes) + entry := newNeighborEntry(&nic, entryTestAddr1 /* remoteAddr */, entryTestAddr2 /* localAddr */, nudState, &linkRes) // Stub out ndpState to verify modification of default routers. nic.mu.ndp = ndpState{ @@ -2344,6 +2344,106 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { nudDisp.mu.Unlock() } +// TestEntryUnknownToStaleToProbeToReachable exercises the following scenario: +// 1. Probe is received +// 2. Entry is created in Stale +// 3. Packet is queued on the entry +// 4. Entry transitions to Delay then Probe +// 5. Probe is sent +func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) { + c := DefaultNUDConfigurations() + // Eliminate random factors from ReachableTime computation so the transition + // from Probe to Reachable will only take BaseReachableTime duration. + c.MinRandomFactor = 1 + c.MaxRandomFactor = 1 + + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handleProbeLocked(entryTestLinkAddr1) + e.handlePacketQueuedLocked() + e.mu.Unlock() + + clock.advance(c.DelayFirstProbeTime) + + wantProbes := []entryTestProbeInfo{ + // Probe caused by the Delay-to-Probe transition + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() + if got, want := e.neigh.State, Probe; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: true, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + clock.advance(c.BaseReachableTime) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { c := DefaultNUDConfigurations() // Eliminate random factors from ReachableTime computation so the transition diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go index f848d50ad..e1ec15487 100644 --- a/pkg/tcpip/stack/nud.go +++ b/pkg/tcpip/stack/nud.go @@ -177,7 +177,7 @@ type NUDHandler interface { // Neighbor Solicitation for ARP or NDP, respectively). Validation of the // probe needs to be performed before calling this function since the // Neighbor Cache doesn't have access to view the NIC's assigned addresses. - HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress) + HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) // HandleConfirmation processes an incoming neighbor confirmation (e.g. ARP // reply or Neighbor Advertisement for ARP or NDP, respectively). -- cgit v1.2.3 From b404b5c255214a37d7f787f9fe24bb8e22509eb4 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Sat, 8 Aug 2020 17:43:15 -0700 Subject: Use unicast source for ICMP echo replies Packets MUST NOT use a non-unicast source address for ICMP Echo Replies. Test: integration_test.TestPingMulticastBroadcast PiperOrigin-RevId: 325634380 --- pkg/tcpip/network/ipv4/icmp.go | 20 ++ pkg/tcpip/network/ipv6/icmp.go | 20 ++ pkg/tcpip/stack/BUILD | 1 + pkg/tcpip/stack/route.go | 24 ++- pkg/tcpip/stack/stack_test.go | 48 +++++ pkg/tcpip/tests/integration/BUILD | 1 + .../tests/integration/multicast_broadcast_test.go | 208 ++++++++++++++++++--- pkg/tcpip/transport/udp/endpoint.go | 2 +- 8 files changed, 299 insertions(+), 25 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 83e71cb8c..94803a359 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -96,6 +96,26 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { NetworkHeader: append(buffer.View(nil), pkt.NetworkHeader...), }) + remoteLinkAddr := r.RemoteLinkAddress + + // As per RFC 1122 section 3.2.1.3, when a host sends any datagram, the IP + // source address MUST be one of its own IP addresses (but not a broadcast + // or multicast address). + localAddr := r.LocalAddress + if r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) { + localAddr = "" + } + + r, err := r.Stack().FindRoute(e.NICID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) + if err != nil { + // If we cannot find a route to the destination, silently drop the packet. + return + } + defer r.Release() + + // Use the remote link address from the incoming packet. + r.ResolveWith(remoteLinkAddr) + vv := pkt.Data.Clone(nil) vv.TrimFront(header.ICMPv4MinimumSize) hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize) diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 24600d877..ded91d83a 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -389,6 +389,26 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme received.Invalid.Increment() return } + + remoteLinkAddr := r.RemoteLinkAddress + + // As per RFC 4291 section 2.7, multicast addresses must not be used as + // source addresses in IPv6 packets. + localAddr := r.LocalAddress + if header.IsV6MulticastAddress(r.LocalAddress) { + localAddr = "" + } + + r, err := r.Stack().FindRoute(e.NICID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) + if err != nil { + // If we cannot find a route to the destination, silently drop the packet. + return + } + defer r.Release() + + // Use the link address from the source of the original packet. + r.ResolveWith(remoteLinkAddr) + pkt.Data.TrimFront(header.ICMPv6EchoMinimumSize) hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize) packet := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 1c58bed2d..bfc7a0c7c 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -121,6 +121,7 @@ go_test( "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", + "//pkg/tcpip/network/arp", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/ports", diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 91e0110f1..9ce0a2c22 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -110,6 +110,12 @@ func (r *Route) GSOMaxSize() uint32 { return 0 } +// ResolveWith immediately resolves a route with the specified remote link +// address. +func (r *Route) ResolveWith(addr tcpip.LinkAddress) { + r.RemoteLinkAddress = addr +} + // Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in // case address resolution requires blocking, e.g. wait for ARP reply. Waker is // notified when address resolution is complete (success or not). @@ -279,12 +285,26 @@ func (r *Route) Stack() *Stack { return r.ref.stack() } -// IsBroadcast returns true if the route is to send a broadcast packet. -func (r *Route) IsBroadcast() bool { +// IsOutboundBroadcast returns true if the route is for an outbound broadcast +// packet. +func (r *Route) IsOutboundBroadcast() bool { // Only IPv4 has a notion of broadcast. return r.directedBroadcast || r.RemoteAddress == header.IPv4Broadcast } +// IsInboundBroadcast returns true if the route is for an inbound broadcast +// packet. +func (r *Route) IsInboundBroadcast() bool { + // Only IPv4 has a notion of broadcast. + if r.LocalAddress == header.IPv4Broadcast { + return true + } + + addr := r.ref.addrWithPrefix() + subnet := addr.Subnet() + return subnet.IsBroadcast(r.LocalAddress) +} + // ReverseRoute returns new route with given source and destination address. func (r *Route) ReverseRoute(src tcpip.Address, dst tcpip.Address) Route { return Route{ diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 0b6deda02..fe1c1b8a4 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -21,6 +21,7 @@ import ( "bytes" "fmt" "math" + "net" "sort" "strings" "testing" @@ -34,6 +35,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" + "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -3694,3 +3696,49 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { }) } } + +func TestResolveWith(t *testing.T) { + const ( + unspecifiedNICID = 0 + nicID = 1 + ) + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()}, + }) + ep := channel.New(0, defaultMTU, "") + ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired + if err := s.CreateNIC(nicID, ep); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + addr := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("192.168.1.58").To4()), + PrefixLen: 24, + }, + } + if err := s.AddProtocolAddress(nicID, addr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, addr, err) + } + + s.SetRouteTable([]tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}}) + + remoteAddr := tcpip.Address(net.ParseIP("192.168.1.59").To4()) + r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, remoteAddr, header.IPv4ProtocolNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, remoteAddr, header.IPv4ProtocolNumber, err) + } + defer r.Release() + + // Should initially require resolution. + if !r.IsResolutionRequired() { + t.Fatal("got r.IsResolutionRequired() = false, want = true") + } + + // Manually resolving the route should no longer require resolution. + r.ResolveWith("\x01") + if r.IsResolutionRequired() { + t.Fatal("got r.IsResolutionRequired() = true, want = false") + } +} diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index 7fff30462..6d52af98a 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -14,6 +14,7 @@ go_test( "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", + "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/udp", "//pkg/waiter", "@com_github_google_go_cmp//cmp:go_default_library", diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index d9b2d147a..0ff3a2b89 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -15,6 +15,7 @@ package integration_test import ( + "net" "testing" "github.com/google/go-cmp/cmp" @@ -25,11 +26,195 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" ) -const defaultMTU = 1280 +const ( + defaultMTU = 1280 + ttl = 255 +) + +var ( + ipv4Addr = tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("192.168.1.58").To4()), + PrefixLen: 24, + } + ipv4Subnet = ipv4Addr.Subnet() + ipv4SubnetBcast = ipv4Subnet.Broadcast() + + ipv6Addr = tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("200a::1").To16()), + PrefixLen: 64, + } + ipv6Subnet = ipv6Addr.Subnet() + ipv6SubnetBcast = ipv6Subnet.Broadcast() + + // Remote addrs. + remoteIPv4Addr = tcpip.Address(net.ParseIP("10.0.0.1").To4()) + remoteIPv6Addr = tcpip.Address(net.ParseIP("200b::1").To16()) +) + +// TestPingMulticastBroadcast tests that responding to an Echo Request destined +// to a multicast or broadcast address uses a unicast source address for the +// reply. +func TestPingMulticastBroadcast(t *testing.T) { + const nicID = 1 + + rxIPv4ICMP := func(e *channel.Endpoint, dst tcpip.Address) { + totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize + hdr := buffer.NewPrependable(totalLen) + pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + pkt.SetType(header.ICMPv4Echo) + pkt.SetCode(0) + pkt.SetChecksum(0) + pkt.SetChecksum(^header.Checksum(pkt, 0)) + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: uint16(totalLen), + Protocol: uint8(icmp.ProtocolNumber4), + TTL: ttl, + SrcAddr: remoteIPv4Addr, + DstAddr: dst, + }) + + e.InjectInbound(header.IPv4ProtocolNumber, &stack.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + } + + rxIPv6ICMP := func(e *channel.Endpoint, dst tcpip.Address) { + totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize + hdr := buffer.NewPrependable(totalLen) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) + pkt.SetType(header.ICMPv6EchoRequest) + pkt.SetCode(0) + pkt.SetChecksum(0) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, remoteIPv6Addr, dst, buffer.VectorisedView{})) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: header.ICMPv6MinimumSize, + NextHeader: uint8(icmp.ProtocolNumber6), + HopLimit: ttl, + SrcAddr: remoteIPv6Addr, + DstAddr: dst, + }) + + e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + } + + tests := []struct { + name string + dstAddr tcpip.Address + }{ + { + name: "IPv4 unicast", + dstAddr: ipv4Addr.Address, + }, + { + name: "IPv4 directed broadcast", + dstAddr: ipv4SubnetBcast, + }, + { + name: "IPv4 broadcast", + dstAddr: header.IPv4Broadcast, + }, + { + name: "IPv4 all-systems multicast", + dstAddr: header.IPv4AllSystems, + }, + { + name: "IPv6 unicast", + dstAddr: ipv6Addr.Address, + }, + { + name: "IPv6 all-nodes multicast", + dstAddr: header.IPv6AllNodesMulticastAddress, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ipv4Proto := ipv4.NewProtocol() + ipv6Proto := ipv6.NewProtocol() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4Proto, ipv6Proto}, + TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4(), icmp.NewProtocol6()}, + }) + // We only expect a single packet in response to our ICMP Echo Request. + e := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr} + if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv4ProtoAddr, err) + } + ipv6ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: ipv6Addr} + if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv6ProtoAddr, err) + } + + // Default routes for IPv4 and IPv6 so ICMP can find a route to the remote + // node when attempting to send the ICMP Echo Reply. + s.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: header.IPv6EmptySubnet, + NIC: nicID, + }, + tcpip.Route{ + Destination: header.IPv4EmptySubnet, + NIC: nicID, + }, + }) + + var rxICMP func(*channel.Endpoint, tcpip.Address) + var expectedSrc tcpip.Address + var expectedDst tcpip.Address + var proto stack.NetworkProtocol + switch l := len(test.dstAddr); l { + case header.IPv4AddressSize: + rxICMP = rxIPv4ICMP + expectedSrc = ipv4Addr.Address + expectedDst = remoteIPv4Addr + proto = ipv4Proto + case header.IPv6AddressSize: + rxICMP = rxIPv6ICMP + expectedSrc = ipv6Addr.Address + expectedDst = remoteIPv6Addr + proto = ipv6Proto + default: + t.Fatalf("got unexpected address length = %d bytes", l) + } + + rxICMP(e, test.dstAddr) + pkt, ok := e.Read() + if !ok { + t.Fatal("expected ICMP response") + } + + if pkt.Route.LocalAddress != expectedSrc { + t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, expectedSrc) + } + if pkt.Route.RemoteAddress != expectedDst { + t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, expectedDst) + } + + src, dst := proto.ParseAddresses(pkt.Pkt.NetworkHeader) + if src != expectedSrc { + t.Errorf("got pkt source = %s, want = %s", src, expectedSrc) + } + if dst != expectedDst { + t.Errorf("got pkt destination = %s, want = %s", dst, expectedDst) + } + }) + } + +} // TestIncomingMulticastAndBroadcast tests receiving a packet destined to some // multicast or broadcast address. @@ -38,31 +223,10 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { nicID = 1 remotePort = 5555 localPort = 80 - ttl = 255 ) data := []byte{1, 2, 3, 4} - // Local IPv4 subnet: 192.168.1.58/24 - ipv4Addr := tcpip.AddressWithPrefix{ - Address: "\xc0\xa8\x01\x3a", - PrefixLen: 24, - } - ipv4Subnet := ipv4Addr.Subnet() - ipv4SubnetBcast := ipv4Subnet.Broadcast() - - // Local IPv6 subnet: 200a::1/64 - ipv6Addr := tcpip.AddressWithPrefix{ - Address: "\x20\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", - PrefixLen: 64, - } - ipv6Subnet := ipv6Addr.Subnet() - ipv6SubnetBcast := ipv6Subnet.Broadcast() - - // Remote addrs. - remoteIPv4Addr := tcpip.Address("\x64\x0a\x7b\x18") - remoteIPv6Addr := tcpip.Address("\x20\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - rxIPv4UDP := func(e *channel.Endpoint, dst tcpip.Address) { payloadLen := header.UDPMinimumSize + len(data) totalLen := header.IPv4MinimumSize + payloadLen diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index b7d735889..444b5b01c 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -499,7 +499,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c resolve = route.Resolve } - if !e.broadcast && route.IsBroadcast() { + if !e.broadcast && route.IsOutboundBroadcast() { return 0, nil, tcpip.ErrBroadcastDisabled } -- cgit v1.2.3 From 8e31f0dc57d44fb463441f6156fba5c240369dfe Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Mon, 10 Aug 2020 19:32:48 -0700 Subject: Set the NetworkProtocolNumber of all PacketBuffers. NetworkEndpoints set the number on outgoing packets in Write() and NetworkProtocols set them on incoming packets in Parse(). Needed for #3549. PiperOrigin-RevId: 325938745 --- pkg/tcpip/network/ipv4/ipv4.go | 26 ++++++++++++++++---------- pkg/tcpip/network/ipv4/ipv4_test.go | 6 +++++- pkg/tcpip/network/ipv6/ipv6.go | 3 +++ pkg/tcpip/stack/packet_buffer.go | 10 +++++++--- pkg/tcpip/stack/registration.go | 4 ++-- 5 files changed, 33 insertions(+), 16 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 6c4f0ae3e..9ff27a363 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -173,9 +173,10 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, newPayload := pkt.Data.Clone(nil) newPayload.CapLength(innerMTU) if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{ - Header: pkt.Header, - Data: newPayload, - NetworkHeader: buffer.View(h), + Header: pkt.Header, + Data: newPayload, + NetworkHeader: buffer.View(h), + NetworkProtocolNumber: header.IPv4ProtocolNumber, }); err != nil { return err } @@ -192,9 +193,10 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, newPayloadLength := outerMTU - pkt.Header.UsedLength() newPayload.CapLength(newPayloadLength) if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{ - Header: pkt.Header, - Data: newPayload, - NetworkHeader: buffer.View(h), + Header: pkt.Header, + Data: newPayload, + NetworkHeader: buffer.View(h), + NetworkProtocolNumber: header.IPv4ProtocolNumber, }); err != nil { return err } @@ -206,9 +208,10 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, startOfHdr.TrimBack(pkt.Header.UsedLength() - outerMTU) emptyVV := buffer.NewVectorisedView(0, []buffer.View{}) if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{ - Header: startOfHdr, - Data: emptyVV, - NetworkHeader: buffer.View(h), + Header: startOfHdr, + Data: emptyVV, + NetworkHeader: buffer.View(h), + NetworkProtocolNumber: header.IPv4ProtocolNumber, }); err != nil { return err } @@ -249,10 +252,11 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) pkt.NetworkHeader = buffer.View(ip) + pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber - nicName := e.stack.FindNICNameFromID(e.NICID()) // iptables filtering. All packets that reach here are locally // generated. + nicName := e.stack.FindNICNameFromID(e.NICID()) ipt := e.stack.IPTables() if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok { // iptables is telling us to drop the packet. @@ -304,6 +308,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe for pkt := pkts.Front(); pkt != nil; { ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) pkt.NetworkHeader = buffer.View(ip) + pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber pkt = pkt.Next() } @@ -570,6 +575,7 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu parseTransportHeader = false } + pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber pkt.NetworkHeader = hdr pkt.Data.TrimFront(len(hdr)) pkt.Data.CapLength(int(ipHdr.TotalLength()) - len(hdr)) diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index ded97ac64..63e2c36c2 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -150,6 +150,9 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI if got, want := packet.Header.AvailableLength(), sourcePacketInfo.Header.AvailableLength()-header.IPv4MinimumSize; got != want { t.Errorf("fragment #%d should have the same available space for prepending as source: got %d, want %d", i, got, want) } + if got, want := packet.NetworkProtocolNumber, sourcePacketInfo.NetworkProtocolNumber; got != want { + t.Errorf("fragment #%d has wrong network protocol number: got %d, want %d", i, got, want) + } if i < len(packets)-1 { sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, offset) } else { @@ -285,7 +288,8 @@ func TestFragmentation(t *testing.T) { source := &stack.PacketBuffer{ Header: hdr, // Save the source payload because WritePacket will modify it. - Data: payload.Clone(nil), + Data: payload.Clone(nil), + NetworkProtocolNumber: header.IPv4ProtocolNumber, } c := buildContext(t, nil, ft.mtu) err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{ diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 4a0b53c45..d7d7fc611 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -117,6 +117,7 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) pkt.NetworkHeader = buffer.View(ip) + pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber if r.Loop&stack.PacketLoop != 0 { // The inbound path expects the network header to still be in @@ -152,6 +153,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe for pb := pkts.Front(); pb != nil; pb = pb.Next() { ip := e.addIPHeader(r, &pb.Header, pb.Data.Size(), params) pb.NetworkHeader = buffer.View(ip) + pb.NetworkProtocolNumber = header.IPv6ProtocolNumber } n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) @@ -586,6 +588,7 @@ traverseExtensions: } ipHdr = header.IPv6(hdr) + pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber pkt.NetworkHeader = hdr pkt.Data.TrimFront(len(hdr)) pkt.Data.CapLength(int(ipHdr.PayloadLength())) diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 5d6865e35..9e871f968 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -62,6 +62,11 @@ type PacketBuffer struct { NetworkHeader buffer.View TransportHeader buffer.View + // NetworkProtocol is only valid when NetworkHeader is set. + // TODO(gvisor.dev/issue/3574): Remove the separately passed protocol + // numbers in registration APIs that take a PacketBuffer. + NetworkProtocolNumber tcpip.NetworkProtocolNumber + // Hash is the transport layer hash of this packet. A value of zero // indicates no valid hash has been set. Hash uint32 @@ -72,9 +77,8 @@ type PacketBuffer struct { // The following fields are only set by the qdisc layer when the packet // is added to a queue. - EgressRoute *Route - GSOOptions *GSO - NetworkProtocolNumber tcpip.NetworkProtocolNumber + EgressRoute *Route + GSOOptions *GSO // NatDone indicates if the packet has been manipulated as per NAT // iptables rule. diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 8604c4259..4570e8969 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -249,8 +249,8 @@ type NetworkEndpoint interface { MaxHeaderLength() uint16 // WritePacket writes a packet to the given destination address and - // protocol. It takes ownership of pkt. pkt.TransportHeader must have already - // been set. + // protocol. It takes ownership of pkt. pkt.TransportHeader must have + // already been set. WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error // WritePackets writes packets to the given destination address and -- cgit v1.2.3 From 47515f475167ffa23267ca0b9d1b39e7907587d6 Mon Sep 17 00:00:00 2001 From: Ting-Yu Wang Date: Thu, 13 Aug 2020 13:07:03 -0700 Subject: Migrate to PacketHeader API for PacketBuffer. Formerly, when a packet is constructed or parsed, all headers are set by the client code. This almost always involved prepending to pk.Header buffer or trimming pk.Data portion. This is known to prone to bugs, due to the complexity and number of the invariants assumed across netstack to maintain. In the new PacketHeader API, client will call Push()/Consume() method to construct/parse an outgoing/incoming packet. All invariants, such as slicing and trimming, are maintained by the API itself. NewPacketBuffer() is introduced to create new PacketBuffer. Zero value is no longer valid. PacketBuffer now assumes the packet is a concatenation of following portions: * LinkHeader * NetworkHeader * TransportHeader * Data Any of them could be empty, or zero-length. PiperOrigin-RevId: 326507688 --- pkg/sentry/socket/netfilter/tcp_matcher.go | 4 +- pkg/sentry/socket/netfilter/udp_matcher.go | 4 +- pkg/tcpip/buffer/view.go | 10 + pkg/tcpip/link/channel/channel.go | 4 +- pkg/tcpip/link/fdbased/BUILD | 1 + pkg/tcpip/link/fdbased/endpoint.go | 14 +- pkg/tcpip/link/fdbased/endpoint_test.go | 134 ++++--- pkg/tcpip/link/fdbased/mmap.go | 16 +- pkg/tcpip/link/fdbased/packet_dispatchers.go | 45 ++- pkg/tcpip/link/loopback/loopback.go | 21 +- pkg/tcpip/link/muxed/injectable_test.go | 25 +- pkg/tcpip/link/nested/nested_test.go | 4 +- pkg/tcpip/link/sharedmem/sharedmem.go | 25 +- pkg/tcpip/link/sharedmem/sharedmem_test.go | 135 +++---- pkg/tcpip/link/sharedmem/tx.go | 14 +- pkg/tcpip/link/sniffer/sniffer.go | 19 +- pkg/tcpip/link/tun/device.go | 23 +- pkg/tcpip/link/waitable/waitable_test.go | 12 +- pkg/tcpip/network/arp/arp.go | 26 +- pkg/tcpip/network/arp/arp_test.go | 8 +- pkg/tcpip/network/ip_test.go | 80 +++-- pkg/tcpip/network/ipv4/icmp.go | 38 +- pkg/tcpip/network/ipv4/ipv4.go | 188 +++++----- pkg/tcpip/network/ipv4/ipv4_test.go | 72 ++-- pkg/tcpip/network/ipv6/icmp.go | 55 ++- pkg/tcpip/network/ipv6/icmp_test.go | 46 +-- pkg/tcpip/network/ipv6/ipv6.go | 47 +-- pkg/tcpip/network/ipv6/ipv6_test.go | 16 +- pkg/tcpip/network/ipv6/ndp_test.go | 29 +- pkg/tcpip/stack/BUILD | 2 + pkg/tcpip/stack/conntrack.go | 31 +- pkg/tcpip/stack/forwarder_test.go | 65 ++-- pkg/tcpip/stack/headertype_string.go | 39 ++ pkg/tcpip/stack/iptables.go | 2 +- pkg/tcpip/stack/iptables_targets.go | 9 +- pkg/tcpip/stack/ndp.go | 32 +- pkg/tcpip/stack/ndp_test.go | 22 +- pkg/tcpip/stack/nic.go | 40 +-- pkg/tcpip/stack/nic_test.go | 4 +- pkg/tcpip/stack/packet_buffer.go | 260 ++++++++++++-- pkg/tcpip/stack/packet_buffer_test.go | 397 +++++++++++++++++++++ pkg/tcpip/stack/route.go | 5 +- pkg/tcpip/stack/stack_test.go | 61 ++-- pkg/tcpip/stack/transport_demuxer_test.go | 14 +- pkg/tcpip/stack/transport_test.go | 54 ++- .../tests/integration/multicast_broadcast_test.go | 18 +- pkg/tcpip/transport/icmp/endpoint.go | 39 +- pkg/tcpip/transport/packet/endpoint.go | 35 +- pkg/tcpip/transport/raw/endpoint.go | 30 +- pkg/tcpip/transport/tcp/connect.go | 30 +- pkg/tcpip/transport/tcp/protocol.go | 17 +- pkg/tcpip/transport/tcp/segment.go | 2 +- pkg/tcpip/transport/tcp/testing/context/context.go | 28 +- pkg/tcpip/transport/udp/endpoint.go | 26 +- pkg/tcpip/transport/udp/protocol.go | 57 ++- pkg/tcpip/transport/udp/udp_test.go | 58 ++- 56 files changed, 1568 insertions(+), 924 deletions(-) create mode 100644 pkg/tcpip/stack/headertype_string.go create mode 100644 pkg/tcpip/stack/packet_buffer_test.go (limited to 'pkg/tcpip/stack') diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index 4f98ee2d5..0bfd6c1f4 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -97,7 +97,7 @@ func (*TCPMatcher) Name() string { // Match implements Matcher.Match. func (tm *TCPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { - netHeader := header.IPv4(pkt.NetworkHeader) + netHeader := header.IPv4(pkt.NetworkHeader().View()) if netHeader.TransportProtocol() != header.TCPProtocolNumber { return false, false @@ -111,7 +111,7 @@ func (tm *TCPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceN return false, false } - tcpHeader := header.TCP(pkt.TransportHeader) + tcpHeader := header.TCP(pkt.TransportHeader().View()) if len(tcpHeader) < header.TCPMinimumSize { // There's no valid TCP header here, so we drop the packet immediately. return false, true diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index 3f20fc891..7ed05461d 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -94,7 +94,7 @@ func (*UDPMatcher) Name() string { // Match implements Matcher.Match. func (um *UDPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { - netHeader := header.IPv4(pkt.NetworkHeader) + netHeader := header.IPv4(pkt.NetworkHeader().View()) // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved // into the stack.Check codepath as matchers are added. @@ -110,7 +110,7 @@ func (um *UDPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceN return false, false } - udpHeader := header.UDP(pkt.TransportHeader) + udpHeader := header.UDP(pkt.TransportHeader().View()) if len(udpHeader) < header.UDPMinimumSize { // There's no valid UDP header here, so we drop the packet immediately. return false, true diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index 9a3c5d6c3..ea0c5413d 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -65,6 +65,16 @@ func (v View) ToVectorisedView() VectorisedView { return NewVectorisedView(len(v), []View{v}) } +// IsEmpty returns whether v is of length zero. +func (v View) IsEmpty() bool { + return len(v) == 0 +} + +// Size returns the length of v. +func (v View) Size() int { + return len(v) +} + // VectorisedView is a vectorised version of View using non contiguous memory. // It supports all the convenience methods supported by View. // diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index e12a5929b..c95aef63c 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -274,7 +274,9 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { p := PacketInfo{ - Pkt: &stack.PacketBuffer{Data: vv}, + Pkt: stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + }), Proto: 0, GSO: nil, } diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD index 507b44abc..10072eac1 100644 --- a/pkg/tcpip/link/fdbased/BUILD +++ b/pkg/tcpip/link/fdbased/BUILD @@ -37,5 +37,6 @@ go_test( "//pkg/tcpip/header", "//pkg/tcpip/link/rawfile", "//pkg/tcpip/stack", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index c18bb91fb..975309fc8 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -390,8 +390,7 @@ const ( func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { if e.hdrSize > 0 { // Add ethernet header if needed. - eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) - pkt.LinkHeader = buffer.View(eth) + eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize)) ethHdr := &header.EthernetFields{ DstAddr: remote, Type: protocol, @@ -420,7 +419,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne if e.Capabilities()&stack.CapabilityHardwareGSO != 0 { vnetHdr := virtioNetHdr{} if gso != nil { - vnetHdr.hdrLen = uint16(pkt.Header.UsedLength()) + vnetHdr.hdrLen = uint16(pkt.HeaderSize()) if gso.NeedsCsum { vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM vnetHdr.csumStart = header.EthernetMinimumSize + gso.L3HdrLen @@ -443,11 +442,9 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne builder.Add(vnetHdrBuf) } - builder.Add(pkt.Header.View()) - for _, v := range pkt.Data.Views() { + for _, v := range pkt.Views() { builder.Add(v) } - return rawfile.NonBlockingWriteIovec(fd, builder.Build()) } @@ -463,7 +460,7 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tc if e.Capabilities()&stack.CapabilityHardwareGSO != 0 { vnetHdr := virtioNetHdr{} if pkt.GSOOptions != nil { - vnetHdr.hdrLen = uint16(pkt.Header.UsedLength()) + vnetHdr.hdrLen = uint16(pkt.HeaderSize()) if pkt.GSOOptions.NeedsCsum { vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM vnetHdr.csumStart = header.EthernetMinimumSize + pkt.GSOOptions.L3HdrLen @@ -486,8 +483,7 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tc var builder iovec.Builder builder.Add(vnetHdrBuf) - builder.Add(pkt.Header.View()) - for _, v := range pkt.Data.Views() { + for _, v := range pkt.Views() { builder.Add(v) } iovecs := builder.Build() diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index 7b995b85a..709f829c8 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -26,6 +26,7 @@ import ( "time" "unsafe" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -43,9 +44,36 @@ const ( ) type packetInfo struct { - raddr tcpip.LinkAddress - proto tcpip.NetworkProtocolNumber - contents *stack.PacketBuffer + Raddr tcpip.LinkAddress + Proto tcpip.NetworkProtocolNumber + Contents *stack.PacketBuffer +} + +type packetContents struct { + LinkHeader buffer.View + NetworkHeader buffer.View + TransportHeader buffer.View + Data buffer.View +} + +func checkPacketInfoEqual(t *testing.T, got, want packetInfo) { + t.Helper() + if diff := cmp.Diff( + want, got, + cmp.Transformer("ExtractPacketBuffer", func(pk *stack.PacketBuffer) *packetContents { + if pk == nil { + return nil + } + return &packetContents{ + LinkHeader: pk.LinkHeader().View(), + NetworkHeader: pk.NetworkHeader().View(), + TransportHeader: pk.TransportHeader().View(), + Data: pk.Data.ToView(), + } + }), + ); diff != "" { + t.Errorf("unexpected packetInfo (-want +got):\n%s", diff) + } } type context struct { @@ -159,19 +187,28 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash u RemoteLinkAddress: raddr, } - // Build header. - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()) + 100) - b := hdr.Prepend(100) - for i := range b { - b[i] = uint8(rand.Intn(256)) + // Build payload. + payload := buffer.NewView(plen) + if _, err := rand.Read(payload); err != nil { + t.Fatalf("rand.Read(payload): %s", err) } - // Build payload and write. - payload := make(buffer.View, plen) - for i := range payload { - payload[i] = uint8(rand.Intn(256)) + // Build packet buffer. + const netHdrLen = 100 + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()) + netHdrLen, + Data: payload.ToVectorisedView(), + }) + pkt.Hash = hash + + // Build header. + b := pkt.NetworkHeader().Push(netHdrLen) + if _, err := rand.Read(b); err != nil { + t.Fatalf("rand.Read(b): %s", err) } - want := append(hdr.View(), payload...) + + // Write. + want := append(append(buffer.View(nil), b...), payload...) var gso *stack.GSO if gsoMaxSize != 0 { gso = &stack.GSO{ @@ -183,11 +220,7 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash u L3HdrLen: header.IPv4MaximumHeaderSize, } } - if err := c.ep.WritePacket(r, gso, proto, &stack.PacketBuffer{ - Header: hdr, - Data: payload.ToVectorisedView(), - Hash: hash, - }); err != nil { + if err := c.ep.WritePacket(r, gso, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -296,13 +329,14 @@ func TestPreserveSrcAddress(t *testing.T) { LocalLinkAddress: baddr, } - // WritePacket panics given a prependable with anything less than - // the minimum size of the ethernet header. - hdr := buffer.NewPrependable(header.EthernetMinimumSize) - if err := c.ep.WritePacket(r, nil /* gso */, proto, &stack.PacketBuffer{ - Header: hdr, - Data: buffer.VectorisedView{}, - }); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + // WritePacket panics given a prependable with anything less than + // the minimum size of the ethernet header. + // TODO(b/153685824): Figure out if this should use c.ep.MaxHeaderLength(). + ReserveHeaderBytes: header.EthernetMinimumSize, + Data: buffer.VectorisedView{}, + }) + if err := c.ep.WritePacket(r, nil /* gso */, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -331,24 +365,25 @@ func TestDeliverPacket(t *testing.T) { defer c.cleanup() // Build packet. - b := make([]byte, plen) - all := b - for i := range b { - b[i] = uint8(rand.Intn(256)) + all := make([]byte, plen) + if _, err := rand.Read(all); err != nil { + t.Fatalf("rand.Read(all): %s", err) } - - var hdr header.Ethernet - if !eth { - // So that it looks like an IPv4 packet. - b[0] = 0x40 - } else { - hdr = make(header.Ethernet, header.EthernetMinimumSize) + // Make it look like an IPv4 packet. + all[0] = 0x40 + + wantPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.EthernetMinimumSize, + Data: buffer.NewViewFromBytes(all).ToVectorisedView(), + }) + if eth { + hdr := header.Ethernet(wantPkt.LinkHeader().Push(header.EthernetMinimumSize)) hdr.Encode(&header.EthernetFields{ SrcAddr: raddr, DstAddr: laddr, Type: proto, }) - all = append(hdr, b...) + all = append(hdr, all...) } // Write packet via the file descriptor. @@ -360,24 +395,15 @@ func TestDeliverPacket(t *testing.T) { select { case pi := <-c.ch: want := packetInfo{ - raddr: raddr, - proto: proto, - contents: &stack.PacketBuffer{ - Data: buffer.View(b).ToVectorisedView(), - LinkHeader: buffer.View(hdr), - }, + Raddr: raddr, + Proto: proto, + Contents: wantPkt, } if !eth { - want.proto = header.IPv4ProtocolNumber - want.raddr = "" - } - // want.contents.Data will be a single - // view, so make pi do the same for the - // DeepEqual check. - pi.contents.Data = pi.contents.Data.ToView().ToVectorisedView() - if !reflect.DeepEqual(want, pi) { - t.Fatalf("Unexpected received packet: %+v, want %+v", pi, want) + want.Proto = header.IPv4ProtocolNumber + want.Raddr = "" } + checkPacketInfoEqual(t, pi, want) case <-time.After(10 * time.Second): t.Fatalf("Timed out waiting for packet") } @@ -572,8 +598,8 @@ func TestDispatchPacketFormat(t *testing.T) { t.Fatalf("len(sink.pkts) = %d, want %d", got, want) } pkt := sink.pkts[0] - if got, want := len(pkt.LinkHeader), header.EthernetMinimumSize; got != want { - t.Errorf("len(pkt.LinkHeader) = %d, want %d", got, want) + if got, want := pkt.LinkHeader().View().Size(), header.EthernetMinimumSize; got != want { + t.Errorf("pkt.LinkHeader().View().Size() = %d, want %d", got, want) } if got, want := pkt.Data.Size(), 4; got != want { t.Errorf("pkt.Data.Size() = %d, want %d", got, want) diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go index 2dfd29aa9..c475dda20 100644 --- a/pkg/tcpip/link/fdbased/mmap.go +++ b/pkg/tcpip/link/fdbased/mmap.go @@ -18,6 +18,7 @@ package fdbased import ( "encoding/binary" + "fmt" "syscall" "golang.org/x/sys/unix" @@ -170,10 +171,9 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) { var ( p tcpip.NetworkProtocolNumber remote, local tcpip.LinkAddress - eth header.Ethernet ) if d.e.hdrSize > 0 { - eth = header.Ethernet(pkt) + eth := header.Ethernet(pkt) p = eth.Type() remote = eth.SourceAddress() local = eth.DestinationAddress() @@ -190,10 +190,14 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) { } } - pkt = pkt[d.e.hdrSize:] - d.e.dispatcher.DeliverNetworkPacket(remote, local, p, &stack.PacketBuffer{ - Data: buffer.View(pkt).ToVectorisedView(), - LinkHeader: buffer.View(eth), + pbuf := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.View(pkt).ToVectorisedView(), }) + if d.e.hdrSize > 0 { + if _, ok := pbuf.LinkHeader().Consume(d.e.hdrSize); !ok { + panic(fmt.Sprintf("LinkHeader().Consume(%d) must succeed", d.e.hdrSize)) + } + } + d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pbuf) return true, nil } diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go index d8f2504b3..8c3ca86d6 100644 --- a/pkg/tcpip/link/fdbased/packet_dispatchers.go +++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go @@ -103,7 +103,7 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { d.allocateViews(BufConfig) n, err := rawfile.BlockingReadv(d.fd, d.iovecs) - if err != nil { + if n == 0 || err != nil { return false, err } if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { @@ -111,17 +111,22 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { // isn't used and it isn't in a view. n -= virtioNetHdrSize } - if n <= d.e.hdrSize { - return false, nil - } + + used := d.capViews(n, BufConfig) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)), + }) var ( p tcpip.NetworkProtocolNumber remote, local tcpip.LinkAddress - eth header.Ethernet ) if d.e.hdrSize > 0 { - eth = header.Ethernet(d.views[0][:header.EthernetMinimumSize]) + hdr, ok := pkt.LinkHeader().Consume(d.e.hdrSize) + if !ok { + return false, nil + } + eth := header.Ethernet(hdr) p = eth.Type() remote = eth.SourceAddress() local = eth.DestinationAddress() @@ -138,13 +143,6 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { } } - used := d.capViews(n, BufConfig) - pkt := &stack.PacketBuffer{ - Data: buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)), - LinkHeader: buffer.View(eth), - } - pkt.Data.TrimFront(d.e.hdrSize) - d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pkt) // Prepare e.views for another packet: release used views. @@ -268,17 +266,22 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 { n -= virtioNetHdrSize } - if n <= d.e.hdrSize { - return false, nil - } + + used := d.capViews(k, int(n), BufConfig) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)), + }) var ( p tcpip.NetworkProtocolNumber remote, local tcpip.LinkAddress - eth header.Ethernet ) if d.e.hdrSize > 0 { - eth = header.Ethernet(d.views[k][0][:header.EthernetMinimumSize]) + hdr, ok := pkt.LinkHeader().Consume(d.e.hdrSize) + if !ok { + return false, nil + } + eth := header.Ethernet(hdr) p = eth.Type() remote = eth.SourceAddress() local = eth.DestinationAddress() @@ -295,12 +298,6 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { } } - used := d.capViews(k, int(n), BufConfig) - pkt := &stack.PacketBuffer{ - Data: buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)), - LinkHeader: buffer.View(eth), - } - pkt.Data.TrimFront(d.e.hdrSize) d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pkt) // Prepare e.views for another packet: release used views. diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 781cdd317..38aa694e4 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -77,16 +77,16 @@ func (*endpoint) Wait() {} // WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound // packets to the network-layer dispatcher. func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) - views[0] = pkt.Header.View() - views = append(views, pkt.Data.Views()...) + // Construct data as the unparsed portion for the loopback packet. + data := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) // Because we're immediately turning around and writing the packet back // to the rx path, we intentionally don't preserve the remote and local // link addresses from the stack.Route we're passed. - e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, &stack.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), + newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: data, }) + e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, newPkt) return nil } @@ -98,18 +98,17 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + }) // There should be an ethernet header at the beginning of vv. - hdr, ok := vv.PullUp(header.EthernetMinimumSize) + hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize) if !ok { // Reject the packet if it's shorter than an ethernet header. return tcpip.ErrBadAddress } linkHeader := header.Ethernet(hdr) - vv.TrimFront(len(linkHeader)) - e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, linkHeader.Type(), &stack.PacketBuffer{ - Data: vv, - LinkHeader: buffer.View(linkHeader), - }) + e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, linkHeader.Type(), pkt) return nil } diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go index 0744f66d6..3e4afcdad 100644 --- a/pkg/tcpip/link/muxed/injectable_test.go +++ b/pkg/tcpip/link/muxed/injectable_test.go @@ -46,14 +46,14 @@ func TestInjectableEndpointRawDispatch(t *testing.T) { func TestInjectableEndpointDispatch(t *testing.T) { endpoint, sock, dstIP := makeTestInjectableEndpoint(t) - hdr := buffer.NewPrependable(1) - hdr.Prepend(1)[0] = 0xFA + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: 1, + Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(), + }) + pkt.TransportHeader().Push(1)[0] = 0xFA packetRoute := stack.Route{RemoteAddress: dstIP} - endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(), - }) + endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, pkt) buf := make([]byte, 6500) bytesRead, err := sock.Read(buf) @@ -67,13 +67,14 @@ func TestInjectableEndpointDispatch(t *testing.T) { func TestInjectableEndpointDispatchHdrOnly(t *testing.T) { endpoint, sock, dstIP := makeTestInjectableEndpoint(t) - hdr := buffer.NewPrependable(1) - hdr.Prepend(1)[0] = 0xFA - packetRoute := stack.Route{RemoteAddress: dstIP} - endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buffer.NewView(0).ToVectorisedView(), + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: 1, + Data: buffer.NewView(0).ToVectorisedView(), }) + pkt.TransportHeader().Push(1)[0] = 0xFA + packetRoute := stack.Route{RemoteAddress: dstIP} + endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, pkt) buf := make([]byte, 6500) bytesRead, err := sock.Read(buf) if err != nil { diff --git a/pkg/tcpip/link/nested/nested_test.go b/pkg/tcpip/link/nested/nested_test.go index 7d9249c1c..c1f9d308c 100644 --- a/pkg/tcpip/link/nested/nested_test.go +++ b/pkg/tcpip/link/nested/nested_test.go @@ -87,7 +87,7 @@ func TestNestedLinkEndpoint(t *testing.T) { t.Error("After attach, nestedEP.IsAttached() = false, want = true") } - nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, &stack.PacketBuffer{}) + nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if disp.count != 1 { t.Errorf("After first packet with dispatcher attached, got disp.count = %d, want = 1", disp.count) } @@ -101,7 +101,7 @@ func TestNestedLinkEndpoint(t *testing.T) { } disp.count = 0 - nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, &stack.PacketBuffer{}) + nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if disp.count != 0 { t.Errorf("After second packet with dispatcher detached, got disp.count = %d, want = 0", disp.count) } diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 507c76b76..7fb8a6c49 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -186,8 +186,7 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress { // AddHeader implements stack.LinkEndpoint.AddHeader. func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { // Add ethernet header if needed. - eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) - pkt.LinkHeader = buffer.View(eth) + eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize)) ethHdr := &header.EthernetFields{ DstAddr: remote, Type: protocol, @@ -207,10 +206,10 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) - v := pkt.Data.ToView() + views := pkt.Views() // Transmit the packet. e.mu.Lock() - ok := e.tx.transmit(pkt.Header.View(), v) + ok := e.tx.transmit(views...) e.mu.Unlock() if !ok { @@ -227,10 +226,10 @@ func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts stack.PacketB // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - v := vv.ToView() + views := vv.Views() // Transmit the packet. e.mu.Lock() - ok := e.tx.transmit(v, buffer.View{}) + ok := e.tx.transmit(views...) e.mu.Unlock() if !ok { @@ -276,16 +275,18 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) { rxb[i].Size = e.bufferSize } - if n < header.EthernetMinimumSize { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.View(b).ToVectorisedView(), + }) + + hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize) + if !ok { continue } + eth := header.Ethernet(hdr) // Send packet up the stack. - eth := header.Ethernet(b[:header.EthernetMinimumSize]) - d.DeliverNetworkPacket(eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), &stack.PacketBuffer{ - Data: buffer.View(b[header.EthernetMinimumSize:]).ToVectorisedView(), - LinkHeader: buffer.View(eth), - }) + d.DeliverNetworkPacket(eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), pkt) } // Clean state. diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 8f3cd9449..22d5c97f1 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -266,21 +266,23 @@ func TestSimpleSend(t *testing.T) { for iters := 1000; iters > 0; iters-- { func() { + hdrLen, dataLen := rand.Intn(10000), rand.Intn(10000) + // Prepare and send packet. - n := rand.Intn(10000) - hdr := buffer.NewPrependable(n + int(c.ep.MaxHeaderLength())) - hdrBuf := hdr.Prepend(n) + hdrBuf := buffer.NewView(hdrLen) randomFill(hdrBuf) - n = rand.Intn(10000) - buf := buffer.NewView(n) - randomFill(buf) + data := buffer.NewView(dataLen) + randomFill(data) + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: hdrLen + int(c.ep.MaxHeaderLength()), + Data: data.ToVectorisedView(), + }) + copy(pkt.NetworkHeader().Push(hdrLen), hdrBuf) proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, nil /* gso */, proto, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -317,7 +319,7 @@ func TestSimpleSend(t *testing.T) { // Compare contents skipping the ethernet header added by the // endpoint. - merged := append(hdrBuf, buf...) + merged := append(hdrBuf, data...) if uint32(len(contents)) < pi.Size { t.Fatalf("Sum of buffers is less than packet size: %v < %v", len(contents), pi.Size) } @@ -344,14 +346,14 @@ func TestPreserveSrcAddressInSend(t *testing.T) { LocalLinkAddress: newLocalLinkAddress, } - // WritePacket panics given a prependable with anything less than - // the minimum size of the ethernet header. - hdr := buffer.NewPrependable(header.EthernetMinimumSize) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + // WritePacket panics given a prependable with anything less than + // the minimum size of the ethernet header. + ReserveHeaderBytes: header.EthernetMinimumSize, + }) proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000)) - if err := c.ep.WritePacket(&r, nil /* gso */, proto, &stack.PacketBuffer{ - Header: hdr, - }); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -403,12 +405,12 @@ func TestFillTxQueue(t *testing.T) { // until the tx queue if full. ids := make(map[uint64]struct{}) for i := queuePipeSize / 40; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -422,11 +424,11 @@ func TestFillTxQueue(t *testing.T) { } // Next attempt to write must fail. - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != want { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } } @@ -450,11 +452,11 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { // Send two packets so that the id slice has at least two slots. for i := 2; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } } @@ -473,11 +475,11 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { // until the tx queue if full. ids := make(map[uint64]struct{}) for i := queuePipeSize / 40; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -491,11 +493,11 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { } // Next attempt to write must fail. - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != want { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } } @@ -517,11 +519,11 @@ func TestFillTxMemory(t *testing.T) { // we fill the memory. ids := make(map[uint64]struct{}) for i := queueDataSize / bufferSize; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -536,11 +538,11 @@ func TestFillTxMemory(t *testing.T) { } // Next attempt to write must fail. - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), }) + err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt) if want := tcpip.ErrWouldBlock; err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } @@ -564,11 +566,11 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { // Each packet is uses up one buffer, so write as many as possible // until there is only one buffer left. for i := queueDataSize/bufferSize - 1; i > 0; i-- { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } @@ -579,23 +581,22 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { // Attempt to write a two-buffer packet. It must fail. { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - uu := buffer.NewView(bufferSize).ToVectorisedView() - if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: uu, - }); err != want { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buffer.NewView(bufferSize).ToVectorisedView(), + }) + if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != want { t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want) } } // Attempt to write the one-buffer packet again. It must succeed. { - hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength())) - if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - Data: buf.ToVectorisedView(), - }); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(c.ep.MaxHeaderLength()), + Data: buf.ToVectorisedView(), + }) + if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil { t.Fatalf("WritePacket failed unexpectedly: %v", err) } } diff --git a/pkg/tcpip/link/sharedmem/tx.go b/pkg/tcpip/link/sharedmem/tx.go index 6b8d7859d..44f421c2d 100644 --- a/pkg/tcpip/link/sharedmem/tx.go +++ b/pkg/tcpip/link/sharedmem/tx.go @@ -18,6 +18,7 @@ import ( "math" "syscall" + "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue" ) @@ -76,9 +77,9 @@ func (t *tx) cleanup() { syscall.Munmap(t.data) } -// transmit sends a packet made up of up to two buffers. Returns a boolean that -// specifies whether the packet was successfully transmitted. -func (t *tx) transmit(a, b []byte) bool { +// transmit sends a packet made of bufs. Returns a boolean that specifies +// whether the packet was successfully transmitted. +func (t *tx) transmit(bufs ...buffer.View) bool { // Pull completions from the tx queue and add their buffers back to the // pool so that we can reuse them. for { @@ -93,7 +94,10 @@ func (t *tx) transmit(a, b []byte) bool { } bSize := t.bufs.entrySize - total := uint32(len(a) + len(b)) + total := uint32(0) + for _, data := range bufs { + total += uint32(len(data)) + } bufCount := (total + bSize - 1) / bSize // Allocate enough buffers to hold all the data. @@ -115,7 +119,7 @@ func (t *tx) transmit(a, b []byte) bool { // Copy data into allocated buffers. nBuf := buf var dBuf []byte - for _, data := range [][]byte{a, b} { + for _, data := range bufs { for len(data) > 0 { if len(dBuf) == 0 { dBuf = t.data[nBuf.Offset:][:nBuf.Size] diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 509076643..4fb127978 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -134,7 +134,7 @@ func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.Netw logPacket(prefix, protocol, pkt, gso) } if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 { - totalLength := pkt.Header.UsedLength() + pkt.Data.Size() + totalLength := pkt.Size() length := totalLength if max := int(e.maxPCAPLen); length > max { length = max @@ -155,12 +155,11 @@ func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.Netw length -= n } } - write(pkt.Header.View()) - for _, view := range pkt.Data.Views() { + for _, v := range pkt.Views() { if length == 0 { break } - write(view) + write(v) } } } @@ -185,9 +184,9 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - e.dumpPacket("send", nil, 0, &stack.PacketBuffer{ + e.dumpPacket("send", nil, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: vv, - }) + })) return e.Endpoint.WriteRawPacket(vv) } @@ -201,12 +200,8 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P var fragmentOffset uint16 var moreFragments bool - // Create a clone of pkt, including any headers if present. Avoid allocating - // backing memory for the clone. - views := [8]buffer.View{} - vv := buffer.NewVectorisedView(0, views[:0]) - vv.AppendView(pkt.Header.View()) - vv.Append(pkt.Data) + // Examine the packet using a new VV. Backing storage must not be written. + vv := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) switch protocol { case header.IPv4ProtocolNumber: diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index 22b0a12bd..3b1510a33 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -215,12 +215,11 @@ func (d *Device) Write(data []byte) (int64, error) { remote = tcpip.LinkAddress(zeroMAC[:]) } - pkt := &stack.PacketBuffer{ - Data: buffer.View(data).ToVectorisedView(), - } - if ethHdr != nil { - pkt.LinkHeader = buffer.View(ethHdr) - } + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: len(ethHdr), + Data: buffer.View(data).ToVectorisedView(), + }) + copy(pkt.LinkHeader().Push(len(ethHdr)), ethHdr) endpoint.InjectLinkAddr(protocol, remote, pkt) return dataLen, nil } @@ -265,21 +264,22 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) { // If the packet does not already have link layer header, and the route // does not exist, we can't compute it. This is possibly a raw packet, tun // device doesn't support this at the moment. - if info.Pkt.LinkHeader == nil && info.Route.RemoteLinkAddress == "" { + if info.Pkt.LinkHeader().View().IsEmpty() && info.Route.RemoteLinkAddress == "" { return nil, false } // Ethernet header (TAP only). if d.hasFlags(linux.IFF_TAP) { // Add ethernet header if not provided. - if info.Pkt.LinkHeader == nil { + if info.Pkt.LinkHeader().View().IsEmpty() { d.endpoint.AddHeader(info.Route.LocalLinkAddress, info.Route.RemoteLinkAddress, info.Proto, info.Pkt) } - vv.AppendView(info.Pkt.LinkHeader) + vv.AppendView(info.Pkt.LinkHeader().View()) } // Append upper headers. - vv.AppendView(buffer.View(info.Pkt.Header.View()[len(info.Pkt.LinkHeader):])) + vv.AppendView(info.Pkt.NetworkHeader().View()) + vv.AppendView(info.Pkt.TransportHeader().View()) // Append data payload. vv.Append(info.Pkt.Data) @@ -361,8 +361,7 @@ func (e *tunEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip. if !e.isTap { return } - eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) - pkt.LinkHeader = buffer.View(eth) + eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize)) hdr := &header.EthernetFields{ SrcAddr: local, DstAddr: remote, diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index c448a888f..94827fc56 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -104,21 +104,21 @@ func TestWaitWrite(t *testing.T) { wep := New(ep) // Write and check that it goes through. - wep.WritePacket(nil, nil /* gso */, 0, &stack.PacketBuffer{}) + wep.WritePacket(nil, nil /* gso */, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if want := 1; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } // Wait on dispatches, then try to write. It must go through. wep.WaitDispatch() - wep.WritePacket(nil, nil /* gso */, 0, &stack.PacketBuffer{}) + wep.WritePacket(nil, nil /* gso */, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if want := 2; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } // Wait on writes, then try to write. It must not go through. wep.WaitWrite() - wep.WritePacket(nil, nil /* gso */, 0, &stack.PacketBuffer{}) + wep.WritePacket(nil, nil /* gso */, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if want := 2; ep.writeCount != want { t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want) } @@ -135,21 +135,21 @@ func TestWaitDispatch(t *testing.T) { } // Dispatch and check that it goes through. - ep.dispatcher.DeliverNetworkPacket("", "", 0, &stack.PacketBuffer{}) + ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if want := 1; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } // Wait on writes, then try to dispatch. It must go through. wep.WaitWrite() - ep.dispatcher.DeliverNetworkPacket("", "", 0, &stack.PacketBuffer{}) + ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if want := 2; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } // Wait on dispatches, then try to dispatch. It must not go through. wep.WaitDispatch() - ep.dispatcher.DeliverNetworkPacket("", "", 0, &stack.PacketBuffer{}) + ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.NewPacketBuffer(stack.PacketBufferOptions{})) if want := 2; ep.dispatchCount != want { t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want) } diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 31a242482..1ad788a17 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -99,7 +99,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu } func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { - h := header.ARP(pkt.NetworkHeader) + h := header.ARP(pkt.NetworkHeader().View()) if !h.IsValid() { return } @@ -110,17 +110,17 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { if e.linkAddrCache.CheckLocalAddress(e.nicID, header.IPv4ProtocolNumber, localAddr) == 0 { return // we have no useful answer, ignore the request } - hdr := buffer.NewPrependable(int(e.linkEP.MaxHeaderLength()) + header.ARPSize) - packet := header.ARP(hdr.Prepend(header.ARPSize)) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(e.linkEP.MaxHeaderLength()) + header.ARPSize, + }) + packet := header.ARP(pkt.NetworkHeader().Push(header.ARPSize)) packet.SetIPv4OverEthernet() packet.SetOp(header.ARPReply) copy(packet.HardwareAddressSender(), r.LocalLinkAddress[:]) copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget()) copy(packet.HardwareAddressTarget(), h.HardwareAddressSender()) copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender()) - e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - }) + _ = e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) fallthrough // also fill the cache from requests case header.ARPReply: addr := tcpip.Address(h.ProtocolAddressSender()) @@ -168,17 +168,17 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAdd r.RemoteLinkAddress = header.EthernetBroadcastAddress } - hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.ARPSize) - h := header.ARP(hdr.Prepend(header.ARPSize)) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.ARPSize, + }) + h := header.ARP(pkt.NetworkHeader().Push(header.ARPSize)) h.SetIPv4OverEthernet() h.SetOp(header.ARPRequest) copy(h.HardwareAddressSender(), linkEP.LinkAddress()) copy(h.ProtocolAddressSender(), localAddr) copy(h.ProtocolAddressTarget(), addr) - return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - }) + return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) } // ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress. @@ -210,12 +210,10 @@ func (*protocol) Wait() {} // Parse implements stack.NetworkProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { - hdr, ok := pkt.Data.PullUp(header.ARPSize) + _, ok = pkt.NetworkHeader().Consume(header.ARPSize) if !ok { return 0, false, false } - pkt.NetworkHeader = hdr - pkt.Data.TrimFront(header.ARPSize) return 0, false, true } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index a35a64a0f..c2c3e6891 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -106,9 +106,9 @@ func TestDirectRequest(t *testing.T) { inject := func(addr tcpip.Address) { copy(h.ProtocolAddressTarget(), addr) - c.linkEP.InjectInbound(arp.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: v.ToVectorisedView(), - }) + })) } for i, address := range []tcpip.Address{stackAddr1, stackAddr2} { @@ -118,9 +118,9 @@ func TestDirectRequest(t *testing.T) { if pi.Proto != arp.ProtocolNumber { t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto) } - rep := header.ARP(pi.Pkt.Header.View()) + rep := header.ARP(pi.Pkt.NetworkHeader().View()) if !rep.IsValid() { - t.Fatalf("invalid ARP response pi.Pkt.Header.UsedLength()=%d", pi.Pkt.Header.UsedLength()) + t.Fatalf("invalid ARP response: len = %d; response = %x", len(rep), rep) } if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr1; got != want { t.Errorf("got HardwareAddressSender = %s, want = %s", got, want) diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 615bae648..e6768258a 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -156,13 +156,13 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Ne var dstAddr tcpip.Address if t.v4 { - h := header.IPv4(pkt.Header.View()) + h := header.IPv4(pkt.NetworkHeader().View()) prot = tcpip.TransportProtocolNumber(h.Protocol()) srcAddr = h.SourceAddress() dstAddr = h.DestinationAddress() } else { - h := header.IPv6(pkt.Header.View()) + h := header.IPv6(pkt.NetworkHeader().View()) prot = tcpip.TransportProtocolNumber(h.NextHeader()) srcAddr = h.SourceAddress() dstAddr = h.DestinationAddress() @@ -243,8 +243,11 @@ func TestIPv4Send(t *testing.T) { payload[i] = uint8(i) } - // Allocate the header buffer. - hdr := buffer.NewPrependable(int(ep.MaxHeaderLength())) + // Setup the packet buffer. + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(ep.MaxHeaderLength()), + Data: payload.ToVectorisedView(), + }) // Issue the write. o.protocol = 123 @@ -260,10 +263,7 @@ func TestIPv4Send(t *testing.T) { Protocol: 123, TTL: 123, TOS: stack.DefaultTOS, - }, &stack.PacketBuffer{ - Header: hdr, - Data: payload.ToVectorisedView(), - }); err != nil { + }, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } } @@ -303,9 +303,13 @@ func TestIPv4Receive(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - pkt := stack.PacketBuffer{Data: view.ToVectorisedView()} - proto.Parse(&pkt) - ep.HandlePacket(&r, &pkt) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: view.ToVectorisedView(), + }) + if _, _, ok := proto.Parse(pkt); !ok { + t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) + } + ep.HandlePacket(&r, pkt) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -455,17 +459,25 @@ func TestIPv4FragmentationReceive(t *testing.T) { } // Send first segment. - pkt := stack.PacketBuffer{Data: frag1.ToVectorisedView()} - proto.Parse(&pkt) - ep.HandlePacket(&r, &pkt) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: frag1.ToVectorisedView(), + }) + if _, _, ok := proto.Parse(pkt); !ok { + t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) + } + ep.HandlePacket(&r, pkt) if o.dataCalls != 0 { t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls) } // Send second segment. - pkt = stack.PacketBuffer{Data: frag2.ToVectorisedView()} - proto.Parse(&pkt) - ep.HandlePacket(&r, &pkt) + pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: frag2.ToVectorisedView(), + }) + if _, _, ok := proto.Parse(pkt); !ok { + t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) + } + ep.HandlePacket(&r, pkt) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -485,8 +497,11 @@ func TestIPv6Send(t *testing.T) { payload[i] = uint8(i) } - // Allocate the header buffer. - hdr := buffer.NewPrependable(int(ep.MaxHeaderLength())) + // Setup the packet buffer. + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(ep.MaxHeaderLength()), + Data: payload.ToVectorisedView(), + }) // Issue the write. o.protocol = 123 @@ -502,10 +517,7 @@ func TestIPv6Send(t *testing.T) { Protocol: 123, TTL: 123, TOS: stack.DefaultTOS, - }, &stack.PacketBuffer{ - Header: hdr, - Data: payload.ToVectorisedView(), - }); err != nil { + }, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } } @@ -545,9 +557,13 @@ func TestIPv6Receive(t *testing.T) { t.Fatalf("could not find route: %v", err) } - pkt := stack.PacketBuffer{Data: view.ToVectorisedView()} - proto.Parse(&pkt) - ep.HandlePacket(&r, &pkt) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: view.ToVectorisedView(), + }) + if _, _, ok := proto.Parse(pkt); !ok { + t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) + } + ep.HandlePacket(&r, pkt) if o.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) } @@ -673,11 +689,9 @@ func TestIPv6ReceiveControl(t *testing.T) { // becomes Data. func truncatedPacket(view buffer.View, trunc, netHdrLen int) *stack.PacketBuffer { v := view[:len(view)-trunc] - if len(v) < netHdrLen { - return &stack.PacketBuffer{Data: v.ToVectorisedView()} - } - return &stack.PacketBuffer{ - NetworkHeader: v[:netHdrLen], - Data: v[netHdrLen:].ToVectorisedView(), - } + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: v.ToVectorisedView(), + }) + _, _ = pkt.NetworkHeader().Consume(netHdrLen) + return pkt } diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 94803a359..067d770f3 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -89,12 +89,14 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { return } + // Make a copy of data before pkt gets sent to raw socket. + // DeliverTransportPacket will take ownership of pkt. + replyData := pkt.Data.Clone(nil) + replyData.TrimFront(header.ICMPv4MinimumSize) + // It's possible that a raw socket expects to receive this. h.SetChecksum(wantChecksum) - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, &stack.PacketBuffer{ - Data: pkt.Data.Clone(nil), - NetworkHeader: append(buffer.View(nil), pkt.NetworkHeader...), - }) + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) remoteLinkAddr := r.RemoteLinkAddress @@ -116,24 +118,26 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { // Use the remote link address from the incoming packet. r.ResolveWith(remoteLinkAddr) - vv := pkt.Data.Clone(nil) - vv.TrimFront(header.ICMPv4MinimumSize) - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize) - pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) - copy(pkt, h) - pkt.SetType(header.ICMPv4EchoReply) - pkt.SetChecksum(0) - pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0))) + // Prepare a reply packet. + icmpHdr := make(header.ICMPv4, header.ICMPv4MinimumSize) + copy(icmpHdr, h) + icmpHdr.SetType(header.ICMPv4EchoReply) + icmpHdr.SetChecksum(0) + icmpHdr.SetChecksum(^header.Checksum(icmpHdr, header.ChecksumVV(replyData, 0))) + dataVV := buffer.View(icmpHdr).ToVectorisedView() + dataVV.Append(replyData) + replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: dataVV, + }) + + // Send out the reply packet. sent := stats.ICMP.V4PacketsSent if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS, - }, &stack.PacketBuffer{ - Header: hdr, - Data: vv, - TransportHeader: buffer.View(pkt), - }); err != nil { + }, replyPkt); err != nil { sent.Dropped.Increment() return } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 9ff27a363..3cd48ceb3 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -21,7 +21,6 @@ package ipv4 import ( - "fmt" "sync/atomic" "gvisor.dev/gvisor/pkg/tcpip" @@ -127,14 +126,12 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { } // writePacketFragments calls e.linkEP.WritePacket with each packet fragment to -// write. It assumes that the IP header is entirely in pkt.Header but does not -// assume that only the IP header is in pkt.Header. It assumes that the input -// packet's stated length matches the length of the header+payload. mtu -// includes the IP header and options. This does not support the DontFragment -// IP flag. +// write. It assumes that the IP header is already present in pkt.NetworkHeader. +// pkt.TransportHeader may be set. mtu includes the IP header and options. This +// does not support the DontFragment IP flag. func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt *stack.PacketBuffer) *tcpip.Error { // This packet is too big, it needs to be fragmented. - ip := header.IPv4(pkt.Header.View()) + ip := header.IPv4(pkt.NetworkHeader().View()) flags := ip.Flags() // Update mtu to take into account the header, which will exist in all @@ -148,88 +145,84 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, outerMTU := innerMTU + int(ip.HeaderLength()) offset := ip.FragmentOffset() - originalAvailableLength := pkt.Header.AvailableLength() + + // Keep the length reserved for link-layer, we need to create fragments with + // the same reserved length. + reservedForLink := pkt.AvailableHeaderBytes() + + // Destroy the packet, pull all payloads out for fragmentation. + transHeader, data := pkt.TransportHeader().View(), pkt.Data + + // Where possible, the first fragment that is sent has the same + // number of bytes reserved for header as the input packet. The link-layer + // endpoint may depend on this for looking at, eg, L4 headers. + transFitsFirst := len(transHeader) <= innerMTU + for i := 0; i < n; i++ { - // Where possible, the first fragment that is sent has the same - // pkt.Header.UsedLength() as the input packet. The link-layer - // endpoint may depend on this for looking at, eg, L4 headers. - h := ip - if i > 0 { - pkt.Header = buffer.NewPrependable(int(ip.HeaderLength()) + originalAvailableLength) - h = header.IPv4(pkt.Header.Prepend(int(ip.HeaderLength()))) - copy(h, ip[:ip.HeaderLength()]) + reserve := reservedForLink + int(ip.HeaderLength()) + if i == 0 && transFitsFirst { + // Reserve for transport header if it's going to be put in the first + // fragment. + reserve += len(transHeader) + } + fragPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: reserve, + }) + fragPkt.NetworkProtocolNumber = header.IPv4ProtocolNumber + + // Copy data for the fragment. + avail := innerMTU + + if n := len(transHeader); n > 0 { + if n > avail { + n = avail + } + if i == 0 && transFitsFirst { + copy(fragPkt.TransportHeader().Push(n), transHeader) + } else { + fragPkt.Data.AppendView(transHeader[:n:n]) + } + transHeader = transHeader[n:] + avail -= n } + + if avail > 0 { + n := data.Size() + if n > avail { + n = avail + } + data.ReadToVV(&fragPkt.Data, n) + avail -= n + } + + copied := uint16(innerMTU - avail) + + // Set lengths in header and calculate checksum. + h := header.IPv4(fragPkt.NetworkHeader().Push(len(ip))) + copy(h, ip) if i != n-1 { h.SetTotalLength(uint16(outerMTU)) h.SetFlagsFragmentOffset(flags|header.IPv4FlagMoreFragments, offset) } else { - h.SetTotalLength(uint16(h.HeaderLength()) + uint16(pkt.Data.Size())) + h.SetTotalLength(uint16(h.HeaderLength()) + copied) h.SetFlagsFragmentOffset(flags, offset) } h.SetChecksum(0) h.SetChecksum(^h.CalculateChecksum()) - offset += uint16(innerMTU) - if i > 0 { - newPayload := pkt.Data.Clone(nil) - newPayload.CapLength(innerMTU) - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{ - Header: pkt.Header, - Data: newPayload, - NetworkHeader: buffer.View(h), - NetworkProtocolNumber: header.IPv4ProtocolNumber, - }); err != nil { - return err - } - r.Stats().IP.PacketsSent.Increment() - pkt.Data.TrimFront(newPayload.Size()) - continue - } - // Special handling for the first fragment because it comes - // from the header. - if outerMTU >= pkt.Header.UsedLength() { - // This fragment can fit all of pkt.Header and possibly - // some of pkt.Data, too. - newPayload := pkt.Data.Clone(nil) - newPayloadLength := outerMTU - pkt.Header.UsedLength() - newPayload.CapLength(newPayloadLength) - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{ - Header: pkt.Header, - Data: newPayload, - NetworkHeader: buffer.View(h), - NetworkProtocolNumber: header.IPv4ProtocolNumber, - }); err != nil { - return err - } - r.Stats().IP.PacketsSent.Increment() - pkt.Data.TrimFront(newPayloadLength) - } else { - // The fragment is too small to fit all of pkt.Header. - startOfHdr := pkt.Header - startOfHdr.TrimBack(pkt.Header.UsedLength() - outerMTU) - emptyVV := buffer.NewVectorisedView(0, []buffer.View{}) - if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{ - Header: startOfHdr, - Data: emptyVV, - NetworkHeader: buffer.View(h), - NetworkProtocolNumber: header.IPv4ProtocolNumber, - }); err != nil { - return err - } - r.Stats().IP.PacketsSent.Increment() - // Add the unused bytes of pkt.Header into the pkt.Data - // that remains to be sent. - restOfHdr := pkt.Header.View()[outerMTU:] - tmp := buffer.NewVectorisedView(len(restOfHdr), []buffer.View{buffer.NewViewFromBytes(restOfHdr)}) - tmp.Append(pkt.Data) - pkt.Data = tmp + offset += copied + + // Send out the fragment. + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, fragPkt); err != nil { + return err } + r.Stats().IP.PacketsSent.Increment() } return nil } -func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv4 { - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - length := uint16(hdr.UsedLength() + payloadSize) +func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) { + ip := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize)) + length := uint16(pkt.Size()) // RFC 6864 section 4.3 mandates uniqueness of ID values for non-atomic // datagrams. Since the DF bit is never being set here, all datagrams // are non-atomic and need an ID. @@ -245,14 +238,12 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS DstAddr: r.RemoteAddress, }) ip.SetChecksum(^ip.CalculateChecksum()) - return ip + pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber } // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { - ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) - pkt.NetworkHeader = buffer.View(ip) - pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber + e.addIPHeader(r, pkt, params) // iptables filtering. All packets that reach here are locally // generated. @@ -269,7 +260,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // only NATted packets, but removing this check short circuits broadcasts // before they are sent out to other hosts. if pkt.NatDone { - netHeader := header.IPv4(pkt.NetworkHeader) + netHeader := header.IPv4(pkt.NetworkHeader().View()) ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()) if err == nil { route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) @@ -286,7 +277,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw if r.Loop&stack.PacketOut == 0 { return nil } - if pkt.Header.UsedLength()+pkt.Data.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) { + if pkt.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) { return e.writePacketFragments(r, gso, int(e.linkEP.MTU()), pkt) } if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { @@ -306,9 +297,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } for pkt := pkts.Front(); pkt != nil; { - ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) - pkt.NetworkHeader = buffer.View(ip) - pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber + e.addIPHeader(r, pkt, params) pkt = pkt.Next() } @@ -333,7 +322,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe continue } if _, ok := natPkts[pkt]; ok { - netHeader := header.IPv4(pkt.NetworkHeader) + netHeader := header.IPv4(pkt.NetworkHeader().View()) if ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()); err == nil { src := netHeader.SourceAddress() dst := netHeader.DestinationAddress() @@ -402,17 +391,14 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu r.Stats().IP.PacketsSent.Increment() - ip = ip[:ip.HeaderLength()] - pkt.Header = buffer.NewPrependableFromView(buffer.View(ip)) - pkt.Data.TrimFront(int(ip.HeaderLength())) return e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) } // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { - h := header.IPv4(pkt.NetworkHeader) - if !h.IsValid(pkt.Data.Size() + len(pkt.NetworkHeader) + len(pkt.TransportHeader)) { + h := header.IPv4(pkt.NetworkHeader().View()) + if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() return } @@ -426,7 +412,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { } if h.More() || h.FragmentOffset() != 0 { - if pkt.Data.Size()+len(pkt.TransportHeader) == 0 { + if pkt.Data.Size()+pkt.TransportHeader().View().Size() == 0 { // Drop the packet as it's marked as a fragment but has // no payload. r.Stats().IP.MalformedPacketsReceived.Increment() @@ -470,7 +456,6 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { } p := h.TransportProtocol() if p == header.ICMPv4ProtocolNumber { - pkt.NetworkHeader.CapLength(int(h.HeaderLength())) e.handleICMP(r, pkt) return } @@ -560,14 +545,19 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu } ipHdr := header.IPv4(hdr) - // If there are options, pull those into hdr as well. - if headerLen := int(ipHdr.HeaderLength()); headerLen > header.IPv4MinimumSize && headerLen <= pkt.Data.Size() { - hdr, ok = pkt.Data.PullUp(headerLen) - if !ok { - panic(fmt.Sprintf("There are only %d bytes in pkt.Data, but there should be at least %d", pkt.Data.Size(), headerLen)) - } - ipHdr = header.IPv4(hdr) + // Header may have options, determine the true header length. + headerLen := int(ipHdr.HeaderLength()) + if headerLen < header.IPv4MinimumSize { + // TODO(gvisor.dev/issue/2404): Per RFC 791, IHL needs to be at least 5 in + // order for the packet to be valid. Figure out if we want to reject this + // case. + headerLen = header.IPv4MinimumSize + } + hdr, ok = pkt.NetworkHeader().Consume(headerLen) + if !ok { + return 0, false, false } + ipHdr = header.IPv4(hdr) // If this is a fragment, don't bother parsing the transport header. parseTransportHeader := true @@ -576,8 +566,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu } pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber - pkt.NetworkHeader = hdr - pkt.Data.TrimFront(len(hdr)) pkt.Data.CapLength(int(ipHdr.TotalLength()) - len(hdr)) return ipHdr.TransportProtocol(), parseTransportHeader, true } diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 63e2c36c2..afd3ac06d 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -17,6 +17,7 @@ package ipv4_test import ( "bytes" "encoding/hex" + "fmt" "math/rand" "testing" @@ -91,15 +92,11 @@ func TestExcludeBroadcast(t *testing.T) { }) } -// makeHdrAndPayload generates a randomize packet. hdrLength indicates how much +// makeRandPkt generates a randomize packet. hdrLength indicates how much // data should already be in the header before WritePacket. extraLength // indicates how much extra space should be in the header. The payload is made // from many Views of the sizes listed in viewSizes. -func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer.Prependable, buffer.VectorisedView) { - hdr := buffer.NewPrependable(hdrLength + extraLength) - hdr.Prepend(hdrLength) - rand.Read(hdr.View()) - +func makeRandPkt(hdrLength int, extraLength int, viewSizes []int) *stack.PacketBuffer { var views []buffer.View totalLength := 0 for _, s := range viewSizes { @@ -108,8 +105,16 @@ func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer. views = append(views, newView) totalLength += s } - payload := buffer.NewVectorisedView(totalLength, views) - return hdr, payload + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: hdrLength + extraLength, + Data: buffer.NewVectorisedView(totalLength, views), + }) + pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber + if _, err := rand.Read(pkt.TransportHeader().Push(hdrLength)); err != nil { + panic(fmt.Sprintf("rand.Read: %s", err)) + } + return pkt } // comparePayloads compared the contents of all the packets against the contents @@ -117,9 +122,9 @@ func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer. func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketInfo *stack.PacketBuffer, mtu uint32) { t.Helper() // Make a complete array of the sourcePacketInfo packet. - source := header.IPv4(packets[0].Header.View()[:header.IPv4MinimumSize]) - source = append(source, sourcePacketInfo.Header.View()...) - source = append(source, sourcePacketInfo.Data.ToView()...) + source := header.IPv4(packets[0].NetworkHeader().View()[:header.IPv4MinimumSize]) + vv := buffer.NewVectorisedView(sourcePacketInfo.Size(), sourcePacketInfo.Views()) + source = append(source, vv.ToView()...) // Make a copy of the IP header, which will be modified in some fields to make // an expected header. @@ -132,8 +137,7 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI var reassembledPayload []byte for i, packet := range packets { // Confirm that the packet is valid. - allBytes := packet.Header.View().ToVectorisedView() - allBytes.Append(packet.Data) + allBytes := buffer.NewVectorisedView(packet.Size(), packet.Views()) ip := header.IPv4(allBytes.ToView()) if !ip.IsValid(len(ip)) { t.Errorf("IP packet is invalid:\n%s", hex.Dump(ip)) @@ -144,10 +148,17 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI if got, want := len(ip), int(mtu); got > want { t.Errorf("fragment is too large, got %d want %d", got, want) } - if got, want := packet.Header.UsedLength(), sourcePacketInfo.Header.UsedLength()+header.IPv4MinimumSize; i == 0 && want < int(mtu) && got != want { - t.Errorf("first fragment hdr parts should have unmodified length if possible: got %d, want %d", got, want) + if i == 0 { + got := packet.NetworkHeader().View().Size() + packet.TransportHeader().View().Size() + // sourcePacketInfo does not have NetworkHeader added, simulate one. + want := header.IPv4MinimumSize + sourcePacketInfo.TransportHeader().View().Size() + // Check that it kept the transport header in packet.TransportHeader if + // it fits in the first fragment. + if want < int(mtu) && got != want { + t.Errorf("first fragment hdr parts should have unmodified length if possible: got %d, want %d", got, want) + } } - if got, want := packet.Header.AvailableLength(), sourcePacketInfo.Header.AvailableLength()-header.IPv4MinimumSize; got != want { + if got, want := packet.AvailableHeaderBytes(), sourcePacketInfo.AvailableHeaderBytes()-header.IPv4MinimumSize; got != want { t.Errorf("fragment #%d should have the same available space for prepending as source: got %d, want %d", i, got, want) } if got, want := packet.NetworkProtocolNumber, sourcePacketInfo.NetworkProtocolNumber; got != want { @@ -284,22 +295,14 @@ func TestFragmentation(t *testing.T) { for _, ft := range fragTests { t.Run(ft.description, func(t *testing.T) { - hdr, payload := makeHdrAndPayload(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes) - source := &stack.PacketBuffer{ - Header: hdr, - // Save the source payload because WritePacket will modify it. - Data: payload.Clone(nil), - NetworkProtocolNumber: header.IPv4ProtocolNumber, - } + pkt := makeRandPkt(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes) + source := pkt.Clone() c := buildContext(t, nil, ft.mtu) err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS, - }, &stack.PacketBuffer{ - Header: hdr, - Data: payload, - }) + }, pkt) if err != nil { t.Errorf("err got %v, want %v", err, nil) } @@ -344,16 +347,13 @@ func TestFragmentationErrors(t *testing.T) { for _, ft := range fragTests { t.Run(ft.description, func(t *testing.T) { - hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes) + pkt := makeRandPkt(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes) c := buildContext(t, ft.packetCollectorErrors, ft.mtu) err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS, - }, &stack.PacketBuffer{ - Header: hdr, - Data: payload, - }) + }, pkt) for i := 0; i < len(ft.packetCollectorErrors)-1; i++ { if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want { t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want) @@ -472,9 +472,9 @@ func TestInvalidFragments(t *testing.T) { s.CreateNIC(nicID, sniffer.New(ep)) for _, pkt := range tc.packets { - ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, &stack.PacketBuffer{ + ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}), - }) + })) } if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), tc.wantMalformedIPPackets; got != want { @@ -859,9 +859,9 @@ func TestReceiveFragments(t *testing.T) { vv := hdr.View().ToVectorisedView() vv.AppendView(frag.payload) - e.InjectInbound(header.IPv4ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: vv, - }) + })) } if got, want := s.Stats().UDP.PacketsReceived.Value(), uint64(len(test.expectedPayloads)); got != want { diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index ded91d83a..39ae19295 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -83,7 +83,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme return } h := header.ICMPv6(v) - iph := header.IPv6(pkt.NetworkHeader) + iph := header.IPv6(pkt.NetworkHeader().View()) // Validate ICMPv6 checksum before processing the packet. // @@ -276,8 +276,10 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme optsSerializer := header.NDPOptionsSerializer{ header.NDPTargetLinkLayerAddressOption(r.LocalLinkAddress), } - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length())) - packet := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length()), + }) + packet := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6NeighborAdvertSize)) packet.SetType(header.ICMPv6NeighborAdvert) na := header.NDPNeighborAdvert(packet.NDPPayload()) na.SetSolicitedFlag(solicited) @@ -293,9 +295,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // // The IP Hop Limit field has a value of 255, i.e., the packet // could not possibly have been forwarded by a router. - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - }); err != nil { + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, pkt); err != nil { sent.Dropped.Increment() return } @@ -384,7 +384,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme case header.ICMPv6EchoRequest: received.EchoRequest.Increment() - icmpHdr, ok := pkt.Data.PullUp(header.ICMPv6EchoMinimumSize) + icmpHdr, ok := pkt.TransportHeader().Consume(header.ICMPv6EchoMinimumSize) if !ok { received.Invalid.Increment() return @@ -409,16 +409,15 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // Use the link address from the source of the original packet. r.ResolveWith(remoteLinkAddr) - pkt.Data.TrimFront(header.ICMPv6EchoMinimumSize) - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize) - packet := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) + replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize, + Data: pkt.Data, + }) + packet := header.ICMPv6(replyPkt.TransportHeader().Push(header.ICMPv6EchoMinimumSize)) copy(packet, icmpHdr) packet.SetType(header.ICMPv6EchoReply) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - Data: pkt.Data, - }); err != nil { + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, replyPkt); err != nil { sent.Dropped.Increment() return } @@ -539,17 +538,19 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAdd r.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(snaddr) } - hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) - pkt.SetType(header.ICMPv6NeighborSolicit) - copy(pkt[icmpV6OptOffset-len(addr):], addr) - pkt[icmpV6OptOffset] = ndpOptSrcLinkAddr - pkt[icmpV6LengthOffset] = 1 - copy(pkt[icmpV6LengthOffset+1:], linkEP.LinkAddress()) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - - length := uint16(hdr.UsedLength()) - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize, + }) + icmpHdr := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6NeighborAdvertSize)) + icmpHdr.SetType(header.ICMPv6NeighborSolicit) + copy(icmpHdr[icmpV6OptOffset-len(addr):], addr) + icmpHdr[icmpV6OptOffset] = ndpOptSrcLinkAddr + icmpHdr[icmpV6LengthOffset] = 1 + copy(icmpHdr[icmpV6LengthOffset+1:], linkEP.LinkAddress()) + icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + + length := uint16(pkt.Size()) + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: length, NextHeader: uint8(header.ICMPv6ProtocolNumber), @@ -559,9 +560,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAdd }) // TODO(stijlist): count this in ICMP stats. - return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, &stack.PacketBuffer{ - Header: hdr, - }) + return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) } // ResolveStaticAddress implements stack.LinkAddressResolver. diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index f86aaed1d..2a2f7de01 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -183,7 +183,11 @@ func TestICMPCounts(t *testing.T) { } handleIPv6Payload := func(icmp header.ICMPv6) { - ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize)) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.IPv6MinimumSize, + Data: buffer.View(icmp).ToVectorisedView(), + }) + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: uint16(len(icmp)), NextHeader: uint8(header.ICMPv6ProtocolNumber), @@ -191,10 +195,7 @@ func TestICMPCounts(t *testing.T) { SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - ep.HandlePacket(&r, &stack.PacketBuffer{ - NetworkHeader: buffer.View(ip), - Data: buffer.View(icmp).ToVectorisedView(), - }) + ep.HandlePacket(&r, pkt) } for _, typ := range types { @@ -323,12 +324,10 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header. pi, _ := args.src.ReadContext(context.Background()) { - views := []buffer.View{pi.Pkt.Header.View(), pi.Pkt.Data.ToView()} - size := pi.Pkt.Header.UsedLength() + pi.Pkt.Data.Size() - vv := buffer.NewVectorisedView(size, views) - args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), &stack.PacketBuffer{ - Data: vv, + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.NewVectorisedView(pi.Pkt.Size(), pi.Pkt.Views()), }) + args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), pkt) } if pi.Proto != ProtocolNumber { @@ -340,7 +339,9 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header. t.Errorf("got remote link address = %s, want = %s", pi.Route.RemoteLinkAddress, args.remoteLinkAddr) } - ipv6 := header.IPv6(pi.Pkt.Header.View()) + // Pull the full payload since network header. Needed for header.IPv6 to + // extract its payload. + ipv6 := header.IPv6(stack.PayloadSince(pi.Pkt.NetworkHeader())) transProto := tcpip.TransportProtocolNumber(ipv6.NextHeader()) if transProto != header.ICMPv6ProtocolNumber { t.Errorf("unexpected transport protocol number %d", transProto) @@ -558,9 +559,10 @@ func TestICMPChecksumValidationSimple(t *testing.T) { SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}), }) + e.InjectInbound(ProtocolNumber, pkt) } stats := s.Stats().ICMP.V6PacketsReceived @@ -719,12 +721,12 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) { icmpSize := size + payloadSize hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) - pkt := header.ICMPv6(hdr.Prepend(icmpSize)) - pkt.SetType(typ) - payloadFn(pkt.Payload()) + icmpHdr := header.ICMPv6(hdr.Prepend(icmpSize)) + icmpHdr.SetType(typ) + payloadFn(icmpHdr.Payload()) if checksum { - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, lladdr1, lladdr0, buffer.VectorisedView{})) } ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) @@ -735,9 +737,10 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), }) + e.InjectInbound(ProtocolNumber, pkt) } stats := s.Stats().ICMP.V6PacketsReceived @@ -895,14 +898,14 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + size) - pkt := header.ICMPv6(hdr.Prepend(size)) - pkt.SetType(typ) + icmpHdr := header.ICMPv6(hdr.Prepend(size)) + icmpHdr.SetType(typ) payload := buffer.NewView(payloadSize) payloadFn(payload) if checksum { - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, payload.ToVectorisedView())) + icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, lladdr1, lladdr0, payload.ToVectorisedView())) } ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) @@ -913,9 +916,10 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { SrcAddr: lladdr1, DstAddr: lladdr0, }) - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}), }) + e.InjectInbound(ProtocolNumber, pkt) } stats := s.Stats().ICMP.V6PacketsReceived diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index d7d7fc611..0ade655b2 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -99,9 +99,9 @@ func (e *endpoint) GSOMaxSize() uint32 { return 0 } -func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv6 { - length := uint16(hdr.UsedLength() + payloadSize) - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) +func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) { + length := uint16(pkt.Size()) + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: length, NextHeader: uint8(params.Protocol), @@ -110,26 +110,20 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - return ip + pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber } // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { - ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) - pkt.NetworkHeader = buffer.View(ip) - pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber + e.addIPHeader(r, pkt, params) if r.Loop&stack.PacketLoop != 0 { - // The inbound path expects the network header to still be in - // the PacketBuffer's Data field. - views := make([]buffer.View, 1, 1+len(pkt.Data.Views())) - views[0] = pkt.Header.View() - views = append(views, pkt.Data.Views()...) loopedR := r.MakeLoopedRoute() - e.HandlePacket(&loopedR, &stack.PacketBuffer{ - Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views), - }) + e.HandlePacket(&loopedR, stack.NewPacketBuffer(stack.PacketBufferOptions{ + // The inbound path expects an unparsed packet. + Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), + })) loopedR.Release() } @@ -151,9 +145,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } for pb := pkts.Front(); pb != nil; pb = pb.Next() { - ip := e.addIPHeader(r, &pb.Header, pb.Data.Size(), params) - pb.NetworkHeader = buffer.View(ip) - pb.NetworkProtocolNumber = header.IPv6ProtocolNumber + e.addIPHeader(r, pb, params) } n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) @@ -171,8 +163,8 @@ func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuff // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { - h := header.IPv6(pkt.NetworkHeader) - if !h.IsValid(pkt.Data.Size() + len(pkt.NetworkHeader) + len(pkt.TransportHeader)) { + h := header.IPv6(pkt.NetworkHeader().View()) + if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { r.Stats().IP.MalformedPacketsReceived.Increment() return } @@ -181,8 +173,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // - Any IPv6 header bytes after the first 40 (i.e. extensions). // - The transport header, if present. // - Any other payload data. - vv := pkt.NetworkHeader[header.IPv6MinimumSize:].ToVectorisedView() - vv.AppendView(pkt.TransportHeader) + vv := pkt.NetworkHeader().View()[header.IPv6MinimumSize:].ToVectorisedView() + vv.AppendView(pkt.TransportHeader().View()) vv.Append(pkt.Data) it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), vv) hasFragmentHeader := false @@ -410,7 +402,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // // For reassembled fragments, pkt.TransportHeader is unset, so this is a // no-op and pkt.Data begins with the transport header. - extHdr.Buf.TrimFront(len(pkt.TransportHeader)) + extHdr.Buf.TrimFront(pkt.TransportHeader().View().Size()) pkt.Data = extHdr.Buf if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber { @@ -581,17 +573,14 @@ traverseExtensions: } } - // Put the IPv6 header with extensions in pkt.NetworkHeader. - hdr, ok = pkt.Data.PullUp(header.IPv6MinimumSize + extensionsSize) + // Put the IPv6 header with extensions in pkt.NetworkHeader(). + hdr, ok = pkt.NetworkHeader().Consume(header.IPv6MinimumSize + extensionsSize) if !ok { panic(fmt.Sprintf("pkt.Data should have at least %d bytes, but only has %d.", header.IPv6MinimumSize+extensionsSize, pkt.Data.Size())) } ipHdr = header.IPv6(hdr) - - pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber - pkt.NetworkHeader = hdr - pkt.Data.TrimFront(len(hdr)) pkt.Data.CapLength(int(ipHdr.PayloadLength())) + pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber return nextHdr, foundNext, true } diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 3d65814de..081afb051 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -65,9 +65,9 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst DstAddr: dst, }) - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) stats := s.Stats().ICMP.V6PacketsReceived @@ -123,9 +123,9 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst DstAddr: dst, }) - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) stat := s.Stats().UDP.PacketsReceived @@ -637,9 +637,9 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { DstAddr: addr2, }) - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) stats := s.Stats().UDP.PacketsReceived @@ -1469,9 +1469,9 @@ func TestReceiveIPv6Fragments(t *testing.T) { vv := hdr.View().ToVectorisedView() vv.Append(f.data) - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: vv, - }) + })) } if got, want := s.Stats().UDP.PacketsReceived.Value(), uint64(len(test.expectedPayloads)); got != want { diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 64239ce9a..fe159b24f 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -136,9 +136,9 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil) if linkAddr != test.expectedLinkAddr { @@ -380,9 +380,9 @@ func TestNeighorSolicitationResponse(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, &stack.PacketBuffer{ + e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) if test.nsInvalid { if got := invalid.Value(); got != 1 { @@ -410,7 +410,7 @@ func TestNeighorSolicitationResponse(t *testing.T) { t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr) } - checker.IPv6(t, p.Pkt.Header.View(), + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(test.naSrc), checker.DstAddr(test.naDst), checker.TTL(header.NDPHopLimit), @@ -497,9 +497,9 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil) if linkAddr != test.expectedLinkAddr { @@ -560,7 +560,11 @@ func TestNDPValidation(t *testing.T) { nextHdr = uint8(header.IPv6FragmentExtHdrIdentifier) } - ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize + len(extensions))) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.IPv6MinimumSize + len(extensions), + Data: payload.ToVectorisedView(), + }) + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + len(extensions))) ip.Encode(&header.IPv6Fields{ PayloadLength: uint16(len(payload) + len(extensions)), NextHeader: nextHdr, @@ -571,10 +575,7 @@ func TestNDPValidation(t *testing.T) { if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) { t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n) } - ep.HandlePacket(r, &stack.PacketBuffer{ - NetworkHeader: buffer.View(ip), - Data: payload.ToVectorisedView(), - }) + ep.HandlePacket(r, pkt) } var tllData [header.NDPLinkLayerAddressSize]byte @@ -885,9 +886,9 @@ func TestRouterAdvertValidation(t *testing.T) { t.Fatalf("got rxRA = %d, want = 0", got) } - e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) if got := rxRA.Value(); got != 1 { t.Fatalf("got rxRA = %d, want = 1", got) diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index bfc7a0c7c..900938dd1 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -57,6 +57,7 @@ go_library( "conntrack.go", "dhcpv6configurationfromndpra_string.go", "forwarder.go", + "headertype_string.go", "icmp_rate_limit.go", "iptables.go", "iptables_state.go", @@ -143,6 +144,7 @@ go_test( "neighbor_cache_test.go", "neighbor_entry_test.go", "nic_test.go", + "packet_buffer_test.go", ], library = ":stack", deps = [ diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 470c265aa..7dd344b4f 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -199,12 +199,12 @@ type bucket struct { func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) { // TODO(gvisor.dev/issue/170): Need to support for other // protocols as well. - netHeader := header.IPv4(pkt.NetworkHeader) - if netHeader == nil || netHeader.TransportProtocol() != header.TCPProtocolNumber { + netHeader := header.IPv4(pkt.NetworkHeader().View()) + if len(netHeader) < header.IPv4MinimumSize || netHeader.TransportProtocol() != header.TCPProtocolNumber { return tupleID{}, tcpip.ErrUnknownProtocol } - tcpHeader := header.TCP(pkt.TransportHeader) - if tcpHeader == nil { + tcpHeader := header.TCP(pkt.TransportHeader().View()) + if len(tcpHeader) < header.TCPMinimumSize { return tupleID{}, tcpip.ErrUnknownProtocol } @@ -344,8 +344,8 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) { return } - netHeader := header.IPv4(pkt.NetworkHeader) - tcpHeader := header.TCP(pkt.TransportHeader) + netHeader := header.IPv4(pkt.NetworkHeader().View()) + tcpHeader := header.TCP(pkt.TransportHeader().View()) // For prerouting redirection, packets going in the original direction // have their destinations modified and replies have their sources @@ -377,8 +377,8 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d return } - netHeader := header.IPv4(pkt.NetworkHeader) - tcpHeader := header.TCP(pkt.TransportHeader) + netHeader := header.IPv4(pkt.NetworkHeader().View()) + tcpHeader := header.TCP(pkt.TransportHeader().View()) // For output redirection, packets going in the original direction // have their destinations modified and replies have their sources @@ -396,8 +396,7 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d // Calculate the TCP checksum and set it. tcpHeader.SetChecksum(0) - hdr := &pkt.Header - length := uint16(pkt.Data.Size()+hdr.UsedLength()) - uint16(netHeader.HeaderLength()) + length := uint16(pkt.Size()) - uint16(netHeader.HeaderLength()) xsum := r.PseudoHeaderChecksum(header.TCPProtocolNumber, length) if gso != nil && gso.NeedsCsum { tcpHeader.SetChecksum(xsum) @@ -423,7 +422,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou } // TODO(gvisor.dev/issue/170): Support other transport protocols. - if pkt.NetworkHeader == nil || header.IPv4(pkt.NetworkHeader).TransportProtocol() != header.TCPProtocolNumber { + if nh := pkt.NetworkHeader().View(); nh.IsEmpty() || header.IPv4(nh).TransportProtocol() != header.TCPProtocolNumber { return false } @@ -433,8 +432,8 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou return true } - tcpHeader := header.TCP(pkt.TransportHeader) - if tcpHeader == nil { + tcpHeader := header.TCP(pkt.TransportHeader().View()) + if len(tcpHeader) < header.TCPMinimumSize { return false } @@ -455,7 +454,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou // Mark the connection as having been used recently so it isn't reaped. conn.lastUsed = time.Now() // Update connection state. - conn.updateLocked(header.TCP(pkt.TransportHeader), hook) + conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook) return false } @@ -474,7 +473,7 @@ func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) { } // We only track TCP connections. - if pkt.NetworkHeader == nil || header.IPv4(pkt.NetworkHeader).TransportProtocol() != header.TCPProtocolNumber { + if nh := pkt.NetworkHeader().View(); nh.IsEmpty() || header.IPv4(nh).TransportProtocol() != header.TCPProtocolNumber { return } @@ -486,7 +485,7 @@ func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) { return } conn := newConn(tid, tid.reply(), manipNone, hook) - conn.updateLocked(header.TCP(pkt.TransportHeader), hook) + conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook) ct.insertConn(conn) } diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index c962693f5..944f622fd 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -75,7 +75,7 @@ func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID { func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) { // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), pkt) + f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) } func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { @@ -97,7 +97,7 @@ func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNu func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error { // Add the protocol's header to the packet and send it to the link // endpoint. - b := pkt.Header.Prepend(fwdTestNetHeaderLen) + b := pkt.NetworkHeader().Push(fwdTestNetHeaderLen) b[dstAddrOffset] = r.RemoteAddress[0] b[srcAddrOffset] = f.id.LocalAddress[0] b[protocolNumberOffset] = byte(params.Protocol) @@ -144,13 +144,11 @@ func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Add } func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) { - netHeader, ok := pkt.Data.PullUp(fwdTestNetHeaderLen) + netHeader, ok := pkt.NetworkHeader().Consume(fwdTestNetHeaderLen) if !ok { return 0, false, false } - pkt.NetworkHeader = netHeader - pkt.Data.TrimFront(fwdTestNetHeaderLen) - return tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), true, true + return tcpip.TransportProtocolNumber(netHeader[protocolNumberOffset]), true, true } func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) (NetworkEndpoint, *tcpip.Error) { @@ -290,7 +288,7 @@ func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuffer // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { p := fwdTestPacketInfo{ - Pkt: &PacketBuffer{Data: vv}, + Pkt: NewPacketBuffer(PacketBufferOptions{Data: vv}), } select { @@ -382,9 +380,9 @@ func TestForwardingWithStaticResolver(t *testing.T) { // forwarded to NIC 2. buf := buffer.NewView(30) buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) var p fwdTestPacketInfo @@ -419,9 +417,9 @@ func TestForwardingWithFakeResolver(t *testing.T) { // forwarded to NIC 2. buf := buffer.NewView(30) buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) var p fwdTestPacketInfo @@ -450,9 +448,9 @@ func TestForwardingWithNoResolver(t *testing.T) { // forwarded to NIC 2. buf := buffer.NewView(30) buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) select { case <-ep2.C: @@ -480,17 +478,17 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { // not be forwarded. buf := buffer.NewView(30) buf[dstAddrOffset] = 4 - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) // Inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. buf = buffer.NewView(30) buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) var p fwdTestPacketInfo @@ -500,8 +498,8 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { t.Fatal("packet not forwarded") } - if p.Pkt.NetworkHeader[dstAddrOffset] != 3 { - t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", p.Pkt.NetworkHeader[dstAddrOffset]) + if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset]) } // Test that the address resolution happened correctly. @@ -529,9 +527,9 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { for i := 0; i < 2; i++ { buf := buffer.NewView(30) buf[dstAddrOffset] = 3 - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) } for i := 0; i < 2; i++ { @@ -543,8 +541,8 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { t.Fatal("packet not forwarded") } - if p.Pkt.NetworkHeader[dstAddrOffset] != 3 { - t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", p.Pkt.NetworkHeader[dstAddrOffset]) + if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset]) } // Test that the address resolution happened correctly. @@ -575,9 +573,9 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { buf[dstAddrOffset] = 3 // Set the packet sequence number. binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i)) - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) } for i := 0; i < maxPendingPacketsPerResolution; i++ { @@ -589,13 +587,14 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { t.Fatal("packet not forwarded") } - if b := p.Pkt.Header.View(); b[dstAddrOffset] != 3 { + b := PayloadSince(p.Pkt.NetworkHeader()) + if b[dstAddrOffset] != 3 { t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset]) } - seqNumBuf, ok := p.Pkt.Data.PullUp(2) // The sequence number is a uint16 (2 bytes). - if !ok { - t.Fatalf("p.Pkt.Data is too short to hold a sequence number: %d", p.Pkt.Data.Size()) + if len(b) < fwdTestNetHeaderLen+2 { + t.Fatalf("packet is too short to hold a sequence number: len(b) = %d", b) } + seqNumBuf := b[fwdTestNetHeaderLen:] // The first 5 packets should not be forwarded so the sequence number should // start with 5. @@ -632,9 +631,9 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // maxPendingResolutions + 7). buf := buffer.NewView(30) buf[dstAddrOffset] = byte(3 + i) - ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{ + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) } for i := 0; i < maxPendingResolutions; i++ { @@ -648,8 +647,8 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // The first 5 packets (address 3 to 7) should not be forwarded // because their address resolutions are interrupted. - if p.Pkt.NetworkHeader[dstAddrOffset] < 8 { - t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", p.Pkt.NetworkHeader[dstAddrOffset]) + if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] < 8 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", nh[dstAddrOffset]) } // Test that the address resolution happened correctly. diff --git a/pkg/tcpip/stack/headertype_string.go b/pkg/tcpip/stack/headertype_string.go new file mode 100644 index 000000000..5efddfaaf --- /dev/null +++ b/pkg/tcpip/stack/headertype_string.go @@ -0,0 +1,39 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at // +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by "stringer -type headerType ."; DO NOT EDIT. + +package stack + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[linkHeader-0] + _ = x[networkHeader-1] + _ = x[transportHeader-2] + _ = x[numHeaderType-3] +} + +const _headerType_name = "linkHeadernetworkHeadertransportHeadernumHeaderType" + +var _headerType_index = [...]uint8{0, 10, 23, 38, 51} + +func (i headerType) String() string { + if i < 0 || i >= headerType(len(_headerType_index)-1) { + return "headerType(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _headerType_name[_headerType_index[i]:_headerType_index[i+1]] +} diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 110ba073d..c37da814f 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -394,7 +394,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx rule := table.Rules[ruleIdx] // Check whether the packet matches the IP header filter. - if !rule.Filter.match(header.IPv4(pkt.NetworkHeader), hook, nicName) { + if !rule.Filter.match(header.IPv4(pkt.NetworkHeader().View()), hook, nicName) { // Continue on to the next rule. return RuleJump, ruleIdx + 1 } diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index dc88033c7..5f1b2af64 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -99,7 +99,7 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso } // Drop the packet if network and transport header are not set. - if pkt.NetworkHeader == nil || pkt.TransportHeader == nil { + if pkt.NetworkHeader().View().IsEmpty() || pkt.TransportHeader().View().IsEmpty() { return RuleDrop, 0 } @@ -118,17 +118,16 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso // TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if // we need to change dest address (for OUTPUT chain) or ports. - netHeader := header.IPv4(pkt.NetworkHeader) + netHeader := header.IPv4(pkt.NetworkHeader().View()) switch protocol := netHeader.TransportProtocol(); protocol { case header.UDPProtocolNumber: - udpHeader := header.UDP(pkt.TransportHeader) + udpHeader := header.UDP(pkt.TransportHeader().View()) udpHeader.SetDestinationPort(rt.MinPort) // Calculate UDP checksum and set it. if hook == Output { udpHeader.SetChecksum(0) - hdr := &pkt.Header - length := uint16(pkt.Data.Size()+hdr.UsedLength()) - uint16(netHeader.HeaderLength()) + length := uint16(pkt.Size()) - uint16(netHeader.HeaderLength()) // Only calculate the checksum if offloading isn't supported. if r.Capabilities()&CapabilityTXChecksumOffload == 0 { diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 5174e639c..93567806b 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -746,12 +746,16 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, ref *referencedNetworkEnd panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.nic.ID())) } - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborSolicitMinimumSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize)) - pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + icmpData := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize)) + icmpData.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(icmpData.NDPPayload()) ns.SetTargetAddress(addr) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + + pkt := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: buffer.View(icmpData).ToVectorisedView(), + }) sent := r.Stats().ICMP.V6PacketsSent if err := r.WritePacket(nil, @@ -759,7 +763,7 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, ref *referencedNetworkEnd Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS, - }, &PacketBuffer{Header: hdr}, + }, pkt, ); err != nil { sent.Dropped.Increment() return err @@ -1897,12 +1901,16 @@ func (ndp *ndpState) startSolicitingRouters() { } } payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + int(optsSerializer.Length()) - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + payloadSize) - pkt := header.ICMPv6(hdr.Prepend(payloadSize)) - pkt.SetType(header.ICMPv6RouterSolicit) - rs := header.NDPRouterSolicit(pkt.NDPPayload()) + icmpData := header.ICMPv6(buffer.NewView(payloadSize)) + icmpData.SetType(header.ICMPv6RouterSolicit) + rs := header.NDPRouterSolicit(icmpData.NDPPayload()) rs.Options().Serialize(optsSerializer) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + + pkt := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: buffer.View(icmpData).ToVectorisedView(), + }) sent := r.Stats().ICMP.V6PacketsSent if err := r.WritePacket(nil, @@ -1910,7 +1918,7 @@ func (ndp *ndpState) startSolicitingRouters() { Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS, - }, &PacketBuffer{Header: hdr}, + }, pkt, ); err != nil { sent.Dropped.Increment() log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.nic.ID(), err) diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 5d286ccbc..21bf53010 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -541,7 +541,7 @@ func TestDADResolve(t *testing.T) { // As per RFC 4861 section 4.3, a possible option is the Source Link // Layer option, but this option MUST NOT be included when the source // address of the packet is the unspecified address. - checker.IPv6(t, p.Pkt.Header.View(), + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(header.IPv6Any), checker.DstAddr(snmc), checker.TTL(header.NDPHopLimit), @@ -550,8 +550,8 @@ func TestDADResolve(t *testing.T) { checker.NDPNSOptions(nil), )) - if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want { - t.Errorf("got p.Pkt.Header.AvailableLength() = %d; want = %d", l, want) + if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { + t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want) } } }) @@ -667,9 +667,10 @@ func TestDADFail(t *testing.T) { // Receive a packet to simulate multiple nodes owning or // attempting to own the same address. hdr := test.makeBuf(addr1) - e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), }) + e.InjectInbound(header.IPv6ProtocolNumber, pkt) stat := test.getStat(s.Stats().ICMP.V6PacketsReceived) if got := stat.Value(); got != 1 { @@ -1024,7 +1025,9 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo DstAddr: header.IPv6AllNodesMulticastAddress, }) - return &stack.PacketBuffer{Data: hdr.View().ToVectorisedView()} + return stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + }) } // raBufWithOpts returns a valid NDP Router Advertisement with options. @@ -5134,16 +5137,15 @@ func TestRouterSolicitation(t *testing.T) { t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) } - checker.IPv6(t, - p.Pkt.Header.View(), + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(test.expectedSrcAddr), checker.DstAddr(header.IPv6AllRoutersMulticastAddress), checker.TTL(header.NDPHopLimit), checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)), ) - if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want { - t.Errorf("got p.Pkt.Header.AvailableLength() = %d; want = %d", l, want) + if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { + t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want) } } waitForNothing := func(timeout time.Duration) { @@ -5288,7 +5290,7 @@ func TestStopStartSolicitingRouters(t *testing.T) { if p.Proto != header.IPv6ProtocolNumber { t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) } - checker.IPv6(t, p.Pkt.Header.View(), + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(header.IPv6Any), checker.DstAddr(header.IPv6AllRoutersMulticastAddress), checker.TTL(header.NDPHopLimit), diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index eaaf756cd..2315ea5b9 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1299,7 +1299,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp } } - src, dst := netProto.ParseAddresses(pkt.NetworkHeader) + src, dst := netProto.ParseAddresses(pkt.NetworkHeader().View()) if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil { // The source address is one of our own, so we never should have gotten a @@ -1401,24 +1401,19 @@ func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { // TODO(b/143425874) Decrease the TTL field in forwarded packets. - // TODO(b/151227689): Avoid copying the packet when forwarding. We can do this - // by having lower layers explicity write each header instead of just - // pkt.Header. - // pkt may have set its NetworkHeader and TransportHeader. If we're - // forwarding, we'll have to copy them into pkt.Header. - pkt.Header = buffer.NewPrependable(int(n.linkEP.MaxHeaderLength()) + len(pkt.NetworkHeader) + len(pkt.TransportHeader)) - if n := copy(pkt.Header.Prepend(len(pkt.TransportHeader)), pkt.TransportHeader); n != len(pkt.TransportHeader) { - panic(fmt.Sprintf("copied %d bytes, expected %d", n, len(pkt.TransportHeader))) - } - if n := copy(pkt.Header.Prepend(len(pkt.NetworkHeader)), pkt.NetworkHeader); n != len(pkt.NetworkHeader) { - panic(fmt.Sprintf("copied %d bytes, expected %d", n, len(pkt.NetworkHeader))) - } + // pkt may have set its header and may not have enough headroom for link-layer + // header for the other link to prepend. Here we create a new packet to + // forward. + fwdPkt := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: int(n.linkEP.MaxHeaderLength()), + Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), + }) - // WritePacket takes ownership of pkt, calculate numBytes first. - numBytes := pkt.Header.UsedLength() + pkt.Data.Size() + // WritePacket takes ownership of fwdPkt, calculate numBytes first. + numBytes := fwdPkt.Size() - if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil { + if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, fwdPkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return } @@ -1443,34 +1438,31 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // validly formed. n.stack.demux.deliverRawPacket(r, protocol, pkt) - // TransportHeader is nil only when pkt is an ICMP packet or was reassembled + // TransportHeader is empty only when pkt is an ICMP packet or was reassembled // from fragments. - if pkt.TransportHeader == nil { + if pkt.TransportHeader().View().IsEmpty() { // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a // full explanation. if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber { // ICMP packets may be longer, but until icmp.Parse is implemented, here // we parse it using the minimum size. - transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize()) - if !ok { + if _, ok := pkt.TransportHeader().Consume(transProto.MinimumPacketSize()); !ok { n.stack.stats.MalformedRcvdPackets.Increment() return } - pkt.TransportHeader = transHeader - pkt.Data.TrimFront(len(pkt.TransportHeader)) } else { // This is either a bad packet or was re-assembled from fragments. transProto.Parse(pkt) } } - if len(pkt.TransportHeader) < transProto.MinimumPacketSize() { + if pkt.TransportHeader().View().Size() < transProto.MinimumPacketSize() { n.stack.stats.MalformedRcvdPackets.Increment() return } - srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader) + srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader().View()) if err != nil { n.stack.stats.MalformedRcvdPackets.Increment() return diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index a70792b50..0870c8d9c 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -311,7 +311,9 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { t.FailNow() } - nic.DeliverNetworkPacket("", "", 0, &PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()}) + nic.DeliverNetworkPacket("", "", 0, NewPacketBuffer(PacketBufferOptions{ + Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView(), + })) if got := nic.stats.DisabledRx.Packets.Value(); got != 1 { t.Errorf("got DisabledRx.Packets = %d, want = 1", got) diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 9e871f968..17b8beebb 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -14,16 +14,43 @@ package stack import ( + "fmt" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" ) +type headerType int + +const ( + linkHeader headerType = iota + networkHeader + transportHeader + numHeaderType +) + +// PacketBufferOptions specifies options for PacketBuffer creation. +type PacketBufferOptions struct { + // ReserveHeaderBytes is the number of bytes to reserve for headers. Total + // number of bytes pushed onto the headers must not exceed this value. + ReserveHeaderBytes int + + // Data is the initial unparsed data for the new packet. If set, it will be + // owned by the new packet. + Data buffer.VectorisedView +} + // A PacketBuffer contains all the data of a network packet. // // As a PacketBuffer traverses up the stack, it may be necessary to pass it to -// multiple endpoints. Clone() should be called in such cases so that -// modifications to the Data field do not affect other copies. +// multiple endpoints. +// +// The whole packet is expected to be a series of bytes in the following order: +// LinkHeader, NetworkHeader, TransportHeader, and Data. Any of them can be +// empty. Use of PacketBuffer in any other order is unsupported. +// +// PacketBuffer must be created with NewPacketBuffer. type PacketBuffer struct { _ sync.NoCopy @@ -31,36 +58,27 @@ type PacketBuffer struct { // PacketBuffers. PacketBufferEntry - // Data holds the payload of the packet. For inbound packets, it also - // holds the headers, which are consumed as the packet moves up the - // stack. Headers are guaranteed not to be split across views. + // Data holds the payload of the packet. + // + // For inbound packets, Data is initially the whole packet. Then gets moved to + // headers via PacketHeader.Consume, when the packet is being parsed. // - // The bytes backing Data are immutable, but Data itself may be trimmed - // or otherwise modified. + // For outbound packets, Data is the innermost layer, defined by the protocol. + // Headers are pushed in front of it via PacketHeader.Push. + // + // The bytes backing Data are immutable, a.k.a. users shouldn't write to its + // backing storage. Data buffer.VectorisedView - // Header holds the headers of outbound packets. As a packet is passed - // down the stack, each layer adds to Header. Note that forwarded - // packets don't populate Headers on their way out -- their headers and - // payload are never parsed out and remain in Data. - // - // TODO(gvisor.dev/issue/170): Forwarded packets don't currently - // populate Header, but should. This will be doable once early parsing - // (https://github.com/google/gvisor/pull/1995) is supported. - Header buffer.Prependable + // headers stores metadata about each header. + headers [numHeaderType]headerInfo - // These fields are used by both inbound and outbound packets. They - // typically overlap with the Data and Header fields. + // header is the internal storage for outbound packets. Headers will be pushed + // (prepended) on this storage as the packet is being constructed. // - // The bytes backing these views are immutable. Each field may be nil - // if either it has not been set yet or no such header exists (e.g. - // packets sent via loopback may not have a link header). - // - // These fields may be Views into other slices (either Data or Header). - // SR dosen't support this, so deep copies are necessary in some cases. - LinkHeader buffer.View - NetworkHeader buffer.View - TransportHeader buffer.View + // TODO(gvisor.dev/issue/2404): Switch to an implementation that header and + // data are held in the same underlying buffer storage. + header buffer.Prependable // NetworkProtocol is only valid when NetworkHeader is set. // TODO(gvisor.dev/issue/3574): Remove the separately passed protocol @@ -89,20 +107,137 @@ type PacketBuffer struct { PktType tcpip.PacketType } -// Clone makes a copy of pk. It clones the Data field, which creates a new -// VectorisedView but does not deep copy the underlying bytes. -// -// Clone also does not deep copy any of its other fields. +// NewPacketBuffer creates a new PacketBuffer with opts. +func NewPacketBuffer(opts PacketBufferOptions) *PacketBuffer { + pk := &PacketBuffer{ + Data: opts.Data, + } + if opts.ReserveHeaderBytes != 0 { + pk.header = buffer.NewPrependable(opts.ReserveHeaderBytes) + } + return pk +} + +// ReservedHeaderBytes returns the number of bytes initially reserved for +// headers. +func (pk *PacketBuffer) ReservedHeaderBytes() int { + return pk.header.UsedLength() + pk.header.AvailableLength() +} + +// AvailableHeaderBytes returns the number of bytes currently available for +// headers. This is relevant to PacketHeader.Push method only. +func (pk *PacketBuffer) AvailableHeaderBytes() int { + return pk.header.AvailableLength() +} + +// LinkHeader returns the handle to link-layer header. +func (pk *PacketBuffer) LinkHeader() PacketHeader { + return PacketHeader{ + pk: pk, + typ: linkHeader, + } +} + +// NetworkHeader returns the handle to network-layer header. +func (pk *PacketBuffer) NetworkHeader() PacketHeader { + return PacketHeader{ + pk: pk, + typ: networkHeader, + } +} + +// TransportHeader returns the handle to transport-layer header. +func (pk *PacketBuffer) TransportHeader() PacketHeader { + return PacketHeader{ + pk: pk, + typ: transportHeader, + } +} + +// HeaderSize returns the total size of all headers in bytes. +func (pk *PacketBuffer) HeaderSize() int { + // Note for inbound packets (Consume called), headers are not stored in + // pk.header. Thus, calculation of size of each header is needed. + var size int + for i := range pk.headers { + size += len(pk.headers[i].buf) + } + return size +} + +// Size returns the size of packet in bytes. +func (pk *PacketBuffer) Size() int { + return pk.HeaderSize() + pk.Data.Size() +} + +// Views returns the underlying storage of the whole packet. +func (pk *PacketBuffer) Views() []buffer.View { + // Optimization for outbound packets that headers are in pk.header. + useHeader := true + for i := range pk.headers { + if !canUseHeader(&pk.headers[i]) { + useHeader = false + break + } + } + + dataViews := pk.Data.Views() + + var vs []buffer.View + if useHeader { + vs = make([]buffer.View, 0, 1+len(dataViews)) + vs = append(vs, pk.header.View()) + } else { + vs = make([]buffer.View, 0, len(pk.headers)+len(dataViews)) + for i := range pk.headers { + if v := pk.headers[i].buf; len(v) > 0 { + vs = append(vs, v) + } + } + } + return append(vs, dataViews...) +} + +func canUseHeader(h *headerInfo) bool { + // h.offset will be negative if the header was pushed in to prependable + // portion, or doesn't matter when it's empty. + return len(h.buf) == 0 || h.offset < 0 +} + +func (pk *PacketBuffer) push(typ headerType, size int) buffer.View { + h := &pk.headers[typ] + if h.buf != nil { + panic(fmt.Sprintf("push must not be called twice: type %s", typ)) + } + h.buf = buffer.View(pk.header.Prepend(size)) + h.offset = -pk.header.UsedLength() + return h.buf +} + +func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consumed bool) { + h := &pk.headers[typ] + if h.buf != nil { + panic(fmt.Sprintf("consume must not be called twice: type %s", typ)) + } + v, ok := pk.Data.PullUp(size) + if !ok { + return + } + pk.Data.TrimFront(size) + h.buf = v + return h.buf, true +} + +// Clone makes a shallow copy of pk. // -// FIXME(b/153685824): Data gets copied but not other header references. +// Clone should be called in such cases so that no modifications is done to +// underlying packet payload. func (pk *PacketBuffer) Clone() *PacketBuffer { - return &PacketBuffer{ + newPk := &PacketBuffer{ PacketBufferEntry: pk.PacketBufferEntry, Data: pk.Data.Clone(nil), - Header: pk.Header, - LinkHeader: pk.LinkHeader, - NetworkHeader: pk.NetworkHeader, - TransportHeader: pk.TransportHeader, + headers: pk.headers, + header: pk.header, Hash: pk.Hash, Owner: pk.Owner, EgressRoute: pk.EgressRoute, @@ -110,4 +245,55 @@ func (pk *PacketBuffer) Clone() *PacketBuffer { NetworkProtocolNumber: pk.NetworkProtocolNumber, NatDone: pk.NatDone, } + return newPk +} + +// headerInfo stores metadata about a header in a packet. +type headerInfo struct { + // buf is the memorized slice for both prepended and consumed header. + // When header is prepended, buf serves as memorized value, which is a slice + // of pk.header. When header is consumed, buf is the slice pulled out from + // pk.Data, which is the only place to hold this header. + buf buffer.View + + // offset will be a negative number denoting the offset where this header is + // from the end of pk.header, if it is prepended. Otherwise, zero. + offset int +} + +// PacketHeader is a handle object to a header in the underlying packet. +type PacketHeader struct { + pk *PacketBuffer + typ headerType +} + +// View returns the underlying storage of h. +func (h PacketHeader) View() buffer.View { + return h.pk.headers[h.typ].buf +} + +// Push pushes size bytes in the front of its residing packet, and returns the +// backing storage. Callers may only call one of Push or Consume once on each +// header in the lifetime of the underlying packet. +func (h PacketHeader) Push(size int) buffer.View { + return h.pk.push(h.typ, size) +} + +// Consume moves the first size bytes of the unparsed data portion in the packet +// to h, and returns the backing storage. In the case of data is shorter than +// size, consumed will be false, and the state of h will not be affected. +// Callers may only call one of Push or Consume once on each header in the +// lifetime of the underlying packet. +func (h PacketHeader) Consume(size int) (v buffer.View, consumed bool) { + return h.pk.consume(h.typ, size) +} + +// PayloadSince returns packet payload starting from and including a particular +// header. This method isn't optimized and should be used in test only. +func PayloadSince(h PacketHeader) buffer.View { + var v buffer.View + for _, hinfo := range h.pk.headers[h.typ:] { + v = append(v, hinfo.buf...) + } + return append(v, h.pk.Data.ToView()...) } diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go new file mode 100644 index 000000000..c6fa8da5f --- /dev/null +++ b/pkg/tcpip/stack/packet_buffer_test.go @@ -0,0 +1,397 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at // +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "bytes" + "testing" + + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +func TestPacketHeaderPush(t *testing.T) { + for _, test := range []struct { + name string + reserved int + link []byte + network []byte + transport []byte + data []byte + }{ + { + name: "construct empty packet", + }, + { + name: "construct link header only packet", + reserved: 60, + link: makeView(10), + }, + { + name: "construct link and network header only packet", + reserved: 60, + link: makeView(10), + network: makeView(20), + }, + { + name: "construct header only packet", + reserved: 60, + link: makeView(10), + network: makeView(20), + transport: makeView(30), + }, + { + name: "construct data only packet", + data: makeView(40), + }, + { + name: "construct L3 packet", + reserved: 60, + network: makeView(20), + transport: makeView(30), + data: makeView(40), + }, + { + name: "construct L2 packet", + reserved: 60, + link: makeView(10), + network: makeView(20), + transport: makeView(30), + data: makeView(40), + }, + } { + t.Run(test.name, func(t *testing.T) { + pk := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: test.reserved, + // Make a copy of data to make sure our truth data won't be taint by + // PacketBuffer. + Data: buffer.NewViewFromBytes(test.data).ToVectorisedView(), + }) + + allHdrSize := len(test.link) + len(test.network) + len(test.transport) + + // Check the initial values for packet. + checkInitialPacketBuffer(t, pk, PacketBufferOptions{ + ReserveHeaderBytes: test.reserved, + Data: buffer.View(test.data).ToVectorisedView(), + }) + + // Push headers. + if v := test.transport; len(v) > 0 { + copy(pk.TransportHeader().Push(len(v)), v) + } + if v := test.network; len(v) > 0 { + copy(pk.NetworkHeader().Push(len(v)), v) + } + if v := test.link; len(v) > 0 { + copy(pk.LinkHeader().Push(len(v)), v) + } + + // Check the after values for packet. + if got, want := pk.ReservedHeaderBytes(), test.reserved; got != want { + t.Errorf("After pk.ReservedHeaderBytes() = %d, want %d", got, want) + } + if got, want := pk.AvailableHeaderBytes(), test.reserved-allHdrSize; got != want { + t.Errorf("After pk.AvailableHeaderBytes() = %d, want %d", got, want) + } + if got, want := pk.HeaderSize(), allHdrSize; got != want { + t.Errorf("After pk.HeaderSize() = %d, want %d", got, want) + } + if got, want := pk.Size(), allHdrSize+len(test.data); got != want { + t.Errorf("After pk.Size() = %d, want %d", got, want) + } + checkViewEqual(t, "After pk.Data.Views()", concatViews(pk.Data.Views()...), test.data) + checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), + concatViews(test.link, test.network, test.transport, test.data)) + // Check the after values for each header. + checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), test.link) + checkPacketHeader(t, "After pk.NetworkHeader", pk.NetworkHeader(), test.network) + checkPacketHeader(t, "After pk.TransportHeader", pk.TransportHeader(), test.transport) + // Check the after values for PayloadSince. + checkViewEqual(t, "After PayloadSince(LinkHeader)", + PayloadSince(pk.LinkHeader()), + concatViews(test.link, test.network, test.transport, test.data)) + checkViewEqual(t, "After PayloadSince(NetworkHeader)", + PayloadSince(pk.NetworkHeader()), + concatViews(test.network, test.transport, test.data)) + checkViewEqual(t, "After PayloadSince(TransportHeader)", + PayloadSince(pk.TransportHeader()), + concatViews(test.transport, test.data)) + }) + } +} + +func TestPacketHeaderConsume(t *testing.T) { + for _, test := range []struct { + name string + data []byte + link int + network int + transport int + }{ + { + name: "parse L2 packet", + data: concatViews(makeView(10), makeView(20), makeView(30), makeView(40)), + link: 10, + network: 20, + transport: 30, + }, + { + name: "parse L3 packet", + data: concatViews(makeView(20), makeView(30), makeView(40)), + network: 20, + transport: 30, + }, + } { + t.Run(test.name, func(t *testing.T) { + pk := NewPacketBuffer(PacketBufferOptions{ + // Make a copy of data to make sure our truth data won't be taint by + // PacketBuffer. + Data: buffer.NewViewFromBytes(test.data).ToVectorisedView(), + }) + + // Check the initial values for packet. + checkInitialPacketBuffer(t, pk, PacketBufferOptions{ + Data: buffer.View(test.data).ToVectorisedView(), + }) + + // Consume headers. + if size := test.link; size > 0 { + if _, ok := pk.LinkHeader().Consume(size); !ok { + t.Fatalf("pk.LinkHeader().Consume() = false, want true") + } + } + if size := test.network; size > 0 { + if _, ok := pk.NetworkHeader().Consume(size); !ok { + t.Fatalf("pk.NetworkHeader().Consume() = false, want true") + } + } + if size := test.transport; size > 0 { + if _, ok := pk.TransportHeader().Consume(size); !ok { + t.Fatalf("pk.TransportHeader().Consume() = false, want true") + } + } + + allHdrSize := test.link + test.network + test.transport + + // Check the after values for packet. + if got, want := pk.ReservedHeaderBytes(), 0; got != want { + t.Errorf("After pk.ReservedHeaderBytes() = %d, want %d", got, want) + } + if got, want := pk.AvailableHeaderBytes(), 0; got != want { + t.Errorf("After pk.AvailableHeaderBytes() = %d, want %d", got, want) + } + if got, want := pk.HeaderSize(), allHdrSize; got != want { + t.Errorf("After pk.HeaderSize() = %d, want %d", got, want) + } + if got, want := pk.Size(), len(test.data); got != want { + t.Errorf("After pk.Size() = %d, want %d", got, want) + } + // After state of pk. + var ( + link = test.data[:test.link] + network = test.data[test.link:][:test.network] + transport = test.data[test.link+test.network:][:test.transport] + payload = test.data[allHdrSize:] + ) + checkViewEqual(t, "After pk.Data.Views()", concatViews(pk.Data.Views()...), payload) + checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), test.data) + // Check the after values for each header. + checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), link) + checkPacketHeader(t, "After pk.NetworkHeader", pk.NetworkHeader(), network) + checkPacketHeader(t, "After pk.TransportHeader", pk.TransportHeader(), transport) + // Check the after values for PayloadSince. + checkViewEqual(t, "After PayloadSince(LinkHeader)", + PayloadSince(pk.LinkHeader()), + concatViews(link, network, transport, payload)) + checkViewEqual(t, "After PayloadSince(NetworkHeader)", + PayloadSince(pk.NetworkHeader()), + concatViews(network, transport, payload)) + checkViewEqual(t, "After PayloadSince(TransportHeader)", + PayloadSince(pk.TransportHeader()), + concatViews(transport, payload)) + }) + } +} + +func TestPacketHeaderConsumeDataTooShort(t *testing.T) { + data := makeView(10) + + pk := NewPacketBuffer(PacketBufferOptions{ + // Make a copy of data to make sure our truth data won't be taint by + // PacketBuffer. + Data: buffer.NewViewFromBytes(data).ToVectorisedView(), + }) + + // Consume should fail if pkt.Data is too short. + if _, ok := pk.LinkHeader().Consume(11); ok { + t.Fatalf("pk.LinkHeader().Consume() = _, true; want _, false") + } + if _, ok := pk.NetworkHeader().Consume(11); ok { + t.Fatalf("pk.NetworkHeader().Consume() = _, true; want _, false") + } + if _, ok := pk.TransportHeader().Consume(11); ok { + t.Fatalf("pk.TransportHeader().Consume() = _, true; want _, false") + } + + // Check packet should look the same as initial packet. + checkInitialPacketBuffer(t, pk, PacketBufferOptions{ + Data: buffer.View(data).ToVectorisedView(), + }) +} + +func TestPacketHeaderPushCalledAtMostOnce(t *testing.T) { + const headerSize = 10 + + pk := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: headerSize * int(numHeaderType), + }) + + for _, h := range []PacketHeader{ + pk.TransportHeader(), + pk.NetworkHeader(), + pk.LinkHeader(), + } { + t.Run("PushedTwice/"+h.typ.String(), func(t *testing.T) { + h.Push(headerSize) + + defer func() { recover() }() + h.Push(headerSize) + t.Fatal("Second push should have panicked") + }) + } +} + +func TestPacketHeaderConsumeCalledAtMostOnce(t *testing.T) { + const headerSize = 10 + + pk := NewPacketBuffer(PacketBufferOptions{ + Data: makeView(headerSize * int(numHeaderType)).ToVectorisedView(), + }) + + for _, h := range []PacketHeader{ + pk.LinkHeader(), + pk.NetworkHeader(), + pk.TransportHeader(), + } { + t.Run("ConsumedTwice/"+h.typ.String(), func(t *testing.T) { + if _, ok := h.Consume(headerSize); !ok { + t.Fatal("First consume should succeed") + } + + defer func() { recover() }() + h.Consume(headerSize) + t.Fatal("Second consume should have panicked") + }) + } +} + +func TestPacketHeaderPushThenConsumePanics(t *testing.T) { + const headerSize = 10 + + pk := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: headerSize * int(numHeaderType), + }) + + for _, h := range []PacketHeader{ + pk.TransportHeader(), + pk.NetworkHeader(), + pk.LinkHeader(), + } { + t.Run(h.typ.String(), func(t *testing.T) { + h.Push(headerSize) + + defer func() { recover() }() + h.Consume(headerSize) + t.Fatal("Consume should have panicked") + }) + } +} + +func TestPacketHeaderConsumeThenPushPanics(t *testing.T) { + const headerSize = 10 + + pk := NewPacketBuffer(PacketBufferOptions{ + Data: makeView(headerSize * int(numHeaderType)).ToVectorisedView(), + }) + + for _, h := range []PacketHeader{ + pk.LinkHeader(), + pk.NetworkHeader(), + pk.TransportHeader(), + } { + t.Run(h.typ.String(), func(t *testing.T) { + h.Consume(headerSize) + + defer func() { recover() }() + h.Push(headerSize) + t.Fatal("Push should have panicked") + }) + } +} + +func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferOptions) { + t.Helper() + reserved := opts.ReserveHeaderBytes + if got, want := pk.ReservedHeaderBytes(), reserved; got != want { + t.Errorf("Initial pk.ReservedHeaderBytes() = %d, want %d", got, want) + } + if got, want := pk.AvailableHeaderBytes(), reserved; got != want { + t.Errorf("Initial pk.AvailableHeaderBytes() = %d, want %d", got, want) + } + if got, want := pk.HeaderSize(), 0; got != want { + t.Errorf("Initial pk.HeaderSize() = %d, want %d", got, want) + } + data := opts.Data.ToView() + if got, want := pk.Size(), len(data); got != want { + t.Errorf("Initial pk.Size() = %d, want %d", got, want) + } + checkViewEqual(t, "Initial pk.Data.Views()", concatViews(pk.Data.Views()...), data) + checkViewEqual(t, "Initial pk.Views()", concatViews(pk.Views()...), data) + // Check the initial values for each header. + checkPacketHeader(t, "Initial pk.LinkHeader", pk.LinkHeader(), nil) + checkPacketHeader(t, "Initial pk.NetworkHeader", pk.NetworkHeader(), nil) + checkPacketHeader(t, "Initial pk.TransportHeader", pk.TransportHeader(), nil) + // Check the initial valies for PayloadSince. + checkViewEqual(t, "Initial PayloadSince(LinkHeader)", + PayloadSince(pk.LinkHeader()), data) + checkViewEqual(t, "Initial PayloadSince(NetworkHeader)", + PayloadSince(pk.NetworkHeader()), data) + checkViewEqual(t, "Initial PayloadSince(TransportHeader)", + PayloadSince(pk.TransportHeader()), data) +} + +func checkPacketHeader(t *testing.T, name string, h PacketHeader, want []byte) { + t.Helper() + checkViewEqual(t, name+".View()", h.View(), want) +} + +func checkViewEqual(t *testing.T, what string, got, want buffer.View) { + t.Helper() + if !bytes.Equal(got, want) { + t.Errorf("%s = %x, want %x", what, got, want) + } +} + +func makeView(size int) buffer.View { + b := byte(size) + return bytes.Repeat([]byte{b}, size) +} + +func concatViews(views ...buffer.View) buffer.View { + var all buffer.View + for _, v := range views { + all = append(all, v...) + } + return all +} diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 9ce0a2c22..e267bebb0 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -173,7 +173,7 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuf } // WritePacket takes ownership of pkt, calculate numBytes first. - numBytes := pkt.Header.UsedLength() + pkt.Data.Size() + numBytes := pkt.Size() err := r.ref.ep.WritePacket(r, gso, params, pkt) if err != nil { @@ -203,8 +203,7 @@ func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHead writtenBytes := 0 for i, pb := 0, pkts.Front(); i < n && pb != nil; i, pb = i+1, pb.Next() { - writtenBytes += pb.Header.UsedLength() - writtenBytes += pb.Data.Size() + writtenBytes += pb.Size() } r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes)) diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index fe1c1b8a4..0273b3c63 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -102,7 +102,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuff f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ // Handle control packets. - if pkt.NetworkHeader[protocolNumberOffset] == uint8(fakeControlProtocol) { + if pkt.NetworkHeader().View()[protocolNumberOffset] == uint8(fakeControlProtocol) { nb, ok := pkt.Data.PullUp(fakeNetHeaderLen) if !ok { return @@ -118,7 +118,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuff } // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), pkt) + f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) } func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { @@ -143,10 +143,10 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params // Add the protocol's header to the packet and send it to the link // endpoint. - pkt.NetworkHeader = pkt.Header.Prepend(fakeNetHeaderLen) - pkt.NetworkHeader[dstAddrOffset] = r.RemoteAddress[0] - pkt.NetworkHeader[srcAddrOffset] = f.id.LocalAddress[0] - pkt.NetworkHeader[protocolNumberOffset] = byte(params.Protocol) + hdr := pkt.NetworkHeader().Push(fakeNetHeaderLen) + hdr[dstAddrOffset] = r.RemoteAddress[0] + hdr[srcAddrOffset] = f.id.LocalAddress[0] + hdr[protocolNumberOffset] = byte(params.Protocol) if r.Loop&stack.PacketLoop != 0 { f.HandlePacket(r, pkt) @@ -249,12 +249,10 @@ func (*fakeNetworkProtocol) Wait() {} // Parse implements TransportProtocol.Parse. func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) { - hdr, ok := pkt.Data.PullUp(fakeNetHeaderLen) + hdr, ok := pkt.NetworkHeader().Consume(fakeNetHeaderLen) if !ok { return 0, false, false } - pkt.NetworkHeader = hdr - pkt.Data.TrimFront(fakeNetHeaderLen) return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true } @@ -315,9 +313,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet with wrong address is not delivered. buf[dstAddrOffset] = 3 - ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeNet.packetCount[1] != 0 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) } @@ -327,9 +325,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to first endpoint. buf[dstAddrOffset] = 1 - ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -339,9 +337,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to second endpoint. buf[dstAddrOffset] = 2 - ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -350,9 +348,9 @@ func TestNetworkReceive(t *testing.T) { } // Make sure packet is not delivered if protocol number is wrong. - ep.InjectInbound(fakeNetNumber-1, &stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber-1, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -362,9 +360,9 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet that is too small is dropped. buf.CapLength(2) - ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -383,11 +381,10 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro } func send(r stack.Route, payload buffer.View) *tcpip.Error { - hdr := buffer.NewPrependable(int(r.MaxHeaderLength())) - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - Data: payload.ToVectorisedView(), - }) + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: payload.ToVectorisedView(), + })) } func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View) { @@ -442,9 +439,9 @@ func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte b func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) { t.Helper() - ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if got := fakeNet.PacketCount(localAddrByte); got != want { t.Errorf("receive packet count: got = %d, want %d", got, want) } @@ -2285,9 +2282,9 @@ func TestNICStats(t *testing.T) { // Send a packet to address 1. buf := buffer.NewView(30) - ep1.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep1.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want { t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) } @@ -2367,9 +2364,9 @@ func TestNICForwarding(t *testing.T) { // Send a packet to dstAddr. buf := buffer.NewView(30) buf[dstAddrOffset] = dstAddr[0] - ep1.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep1.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) pkt, ok := ep2.Read() if !ok { @@ -2377,8 +2374,8 @@ func TestNICForwarding(t *testing.T) { } // Test that the link's MaxHeaderLength is honoured. - if capacity, want := pkt.Pkt.Header.AvailableLength(), int(test.headerLen); capacity != want { - t.Errorf("got Header.AvailableLength() = %d, want = %d", capacity, want) + if capacity, want := pkt.Pkt.AvailableHeaderBytes(), int(test.headerLen); capacity != want { + t.Errorf("got LinkHeader.AvailableLength() = %d, want = %d", capacity, want) } // Test that forwarding increments Tx stats correctly. diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 73dada928..1339edc2d 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -128,11 +128,10 @@ func (c *testContext) sendV4Packet(payload []byte, h *headers, linkEpID tcpip.NI u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ - Data: buf.ToVectorisedView(), - NetworkHeader: buffer.View(ip), - TransportHeader: buffer.View(u), + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buf.ToVectorisedView(), }) + c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, pkt) } func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NICID) { @@ -166,11 +165,10 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI u.SetChecksum(^u.CalculateChecksum(xsum)) // Inject packet. - c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ - Data: buf.ToVectorisedView(), - NetworkHeader: buffer.View(ip), - TransportHeader: buffer.View(u), + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buf.ToVectorisedView(), }) + c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, pkt) } func TestTransportDemuxerRegister(t *testing.T) { diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 7e8b84867..6c6e44468 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -84,16 +84,16 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions return 0, nil, tcpip.ErrNoRoute } - hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength()) + fakeTransHeaderLen) - hdr.Prepend(fakeTransHeaderLen) v, err := p.FullPayload() if err != nil { return 0, nil, err } - if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - Data: buffer.View(v).ToVectorisedView(), - }); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(f.route.MaxHeaderLength()) + fakeTransHeaderLen, + Data: buffer.View(v).ToVectorisedView(), + }) + _ = pkt.TransportHeader().Push(fakeTransHeaderLen) + if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, pkt); err != nil { return 0, nil, err } @@ -328,13 +328,8 @@ func (*fakeTransportProtocol) Wait() {} // Parse implements TransportProtocol.Parse. func (*fakeTransportProtocol) Parse(pkt *stack.PacketBuffer) bool { - hdr, ok := pkt.Data.PullUp(fakeTransHeaderLen) - if !ok { - return false - } - pkt.TransportHeader = hdr - pkt.Data.TrimFront(fakeTransHeaderLen) - return true + _, ok := pkt.TransportHeader().Consume(fakeTransHeaderLen) + return ok } func fakeTransFactory() stack.TransportProtocol { @@ -382,9 +377,9 @@ func TestTransportReceive(t *testing.T) { // Make sure packet with wrong protocol is not delivered. buf[0] = 1 buf[2] = 0 - linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeTrans.packetCount != 0 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) } @@ -393,9 +388,9 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 3 buf[2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeTrans.packetCount != 0 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) } @@ -404,9 +399,9 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 2 buf[2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeTrans.packetCount != 1 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 1) } @@ -459,9 +454,9 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 0 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = 0 - linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeTrans.controlCount != 0 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) } @@ -470,9 +465,9 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 3 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeTrans.controlCount != 0 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) } @@ -481,9 +476,9 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 2 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if fakeTrans.controlCount != 1 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 1) } @@ -636,9 +631,9 @@ func TestTransportForwarding(t *testing.T) { req[0] = 1 req[1] = 3 req[2] = byte(fakeTransNumber) - ep2.InjectInbound(fakeNetNumber, &stack.PacketBuffer{ + ep2.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: req.ToVectorisedView(), - }) + })) aep, _, err := ep.Accept() if err != nil || aep == nil { @@ -655,10 +650,11 @@ func TestTransportForwarding(t *testing.T) { t.Fatal("Response packet not forwarded") } - if dst := p.Pkt.NetworkHeader[0]; dst != 3 { + nh := stack.PayloadSince(p.Pkt.NetworkHeader()) + if dst := nh[0]; dst != 3 { t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) } - if src := p.Pkt.NetworkHeader[1]; src != 1 { + if src := nh[1]; src != 1 { t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) } } diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index 0ff3a2b89..9f0dd4d6d 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -80,9 +80,9 @@ func TestPingMulticastBroadcast(t *testing.T) { DstAddr: dst, }) - e.InjectInbound(header.IPv4ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) } rxIPv6ICMP := func(e *channel.Endpoint, dst tcpip.Address) { @@ -102,9 +102,9 @@ func TestPingMulticastBroadcast(t *testing.T) { DstAddr: dst, }) - e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) } tests := []struct { @@ -204,7 +204,7 @@ func TestPingMulticastBroadcast(t *testing.T) { t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, expectedDst) } - src, dst := proto.ParseAddresses(pkt.Pkt.NetworkHeader) + src, dst := proto.ParseAddresses(pkt.Pkt.NetworkHeader().View()) if src != expectedSrc { t.Errorf("got pkt source = %s, want = %s", src, expectedSrc) } @@ -252,9 +252,9 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { DstAddr: dst, }) - e.InjectInbound(header.IPv4ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) } rxIPv6UDP := func(e *channel.Endpoint, dst tcpip.Address) { @@ -280,9 +280,9 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { DstAddr: dst, }) - e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) } tests := []struct { diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 4612be4e7..bd6f49eb8 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -430,9 +430,12 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi return tcpip.ErrInvalidEndpointState } - hdr := buffer.NewPrependable(header.ICMPv4MinimumSize + int(r.MaxHeaderLength())) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.ICMPv4MinimumSize + int(r.MaxHeaderLength()), + }) + pkt.Owner = owner - icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + icmpv4 := header.ICMPv4(pkt.TransportHeader().Push(header.ICMPv4MinimumSize)) copy(icmpv4, data) // Set the ident to the user-specified port. Sequence number should // already be set by the user. @@ -447,15 +450,12 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi icmpv4.SetChecksum(0) icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0))) + pkt.Data = data.ToVectorisedView() + if ttl == 0 { ttl = r.DefaultTTL() } - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - Data: data.ToVectorisedView(), - TransportHeader: buffer.View(icmpv4), - Owner: owner, - }) + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt) } func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error { @@ -463,9 +463,11 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err return tcpip.ErrInvalidEndpointState } - hdr := buffer.NewPrependable(header.ICMPv6MinimumSize + int(r.MaxHeaderLength())) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.ICMPv6MinimumSize + int(r.MaxHeaderLength()), + }) - icmpv6 := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) + icmpv6 := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6MinimumSize)) copy(icmpv6, data) // Set the ident. Sequence number is provided by the user. icmpv6.SetIdent(ident) @@ -477,15 +479,12 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err dataVV := data.ToVectorisedView() icmpv6.SetChecksum(header.ICMPv6Checksum(icmpv6, r.LocalAddress, r.RemoteAddress, dataVV)) + pkt.Data = dataVV if ttl == 0 { ttl = r.DefaultTTL() } - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - Data: dataVV, - TransportHeader: buffer.View(icmpv6), - }) + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt) } // checkV4MappedLocked determines the effective network protocol and converts @@ -748,14 +747,18 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // Only accept echo replies. switch e.NetProto { case header.IPv4ProtocolNumber: - h := header.ICMPv4(pkt.TransportHeader) + h := header.ICMPv4(pkt.TransportHeader().View()) + // TODO(b/129292233): Determine if len(h) check is still needed after early + // parsing. if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } case header.IPv6ProtocolNumber: - h := header.ICMPv6(pkt.TransportHeader) + h := header.ICMPv6(pkt.TransportHeader().View()) + // TODO(b/129292233): Determine if len(h) check is still needed after early + // parsing. if len(h) < header.ICMPv6MinimumSize || h.Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() @@ -791,7 +794,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk } // ICMP socket's data includes ICMP header. - packet.data = pkt.TransportHeader.ToVectorisedView() + packet.data = pkt.TransportHeader().View().ToVectorisedView() packet.data.Append(pkt.Data) e.rcvList.PushBack(packet) diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index df478115d..1b03ad6bb 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -433,9 +433,9 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, // Push new packet into receive list and increment the buffer size. var packet packet // TODO(gvisor.dev/issue/173): Return network protocol. - if len(pkt.LinkHeader) > 0 { + if !pkt.LinkHeader().View().IsEmpty() { // Get info directly from the ethernet header. - hdr := header.Ethernet(pkt.LinkHeader) + hdr := header.Ethernet(pkt.LinkHeader().View()) packet.senderAddr = tcpip.FullAddress{ NIC: nicID, Addr: tcpip.Address(hdr.SourceAddress()), @@ -458,9 +458,14 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, case tcpip.PacketHost: packet.data = pkt.Data case tcpip.PacketOutgoing: - // Strip Link Header from the Header. - pkt.Header = buffer.NewPrependableFromView(pkt.Header.View()[len(pkt.LinkHeader):]) - combinedVV := pkt.Header.View().ToVectorisedView() + // Strip Link Header. + var combinedVV buffer.VectorisedView + if v := pkt.NetworkHeader().View(); !v.IsEmpty() { + combinedVV.AppendView(v) + } + if v := pkt.TransportHeader().View(); !v.IsEmpty() { + combinedVV.AppendView(v) + } combinedVV.Append(pkt.Data) packet.data = combinedVV default: @@ -471,9 +476,8 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, // Raw packets need their ethernet headers prepended before // queueing. var linkHeader buffer.View - var combinedVV buffer.VectorisedView if pkt.PktType != tcpip.PacketOutgoing { - if len(pkt.LinkHeader) == 0 { + if pkt.LinkHeader().View().IsEmpty() { // We weren't provided with an actual ethernet header, // so fake one. ethFields := header.EthernetFields{ @@ -485,19 +489,14 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, fakeHeader.Encode(ðFields) linkHeader = buffer.View(fakeHeader) } else { - linkHeader = append(buffer.View(nil), pkt.LinkHeader...) + linkHeader = append(buffer.View(nil), pkt.LinkHeader().View()...) } - combinedVV = linkHeader.ToVectorisedView() - } - if pkt.PktType == tcpip.PacketOutgoing { - // For outgoing packets the Link, Network and Transport - // headers are in the pkt.Header fields normally unless - // a Raw socket is in use in which case pkt.Header could - // be nil. - combinedVV.AppendView(pkt.Header.View()) + combinedVV := linkHeader.ToVectorisedView() + combinedVV.Append(pkt.Data) + packet.data = combinedVV + } else { + packet.data = buffer.NewVectorisedView(pkt.Size(), pkt.Views()) } - combinedVV.Append(pkt.Data) - packet.data = combinedVV } packet.timestampNS = ep.stack.Clock().NowNanoseconds() diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index f85a68554..edc2b5b61 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -352,18 +352,23 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, } if e.hdrIncluded { - if err := route.WriteHeaderIncludedPacket(&stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.View(payloadBytes).ToVectorisedView(), - }); err != nil { + }) + if err := route.WriteHeaderIncludedPacket(pkt); err != nil { return 0, nil, err } } else { - hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength())) - if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - Data: buffer.View(payloadBytes).ToVectorisedView(), - Owner: e.owner, - }); err != nil { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(route.MaxHeaderLength()), + Data: buffer.View(payloadBytes).ToVectorisedView(), + }) + pkt.Owner = e.owner + if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + Protocol: e.TransProto, + TTL: route.DefaultTTL(), + TOS: stack.DefaultTOS, + }, pkt); err != nil { return 0, nil, err } } @@ -691,12 +696,13 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { // slice. Save/restore doesn't support overlapping slices and will fail. var combinedVV buffer.VectorisedView if e.TransportEndpointInfo.NetProto == header.IPv4ProtocolNumber { - headers := make(buffer.View, 0, len(pkt.NetworkHeader)+len(pkt.TransportHeader)) - headers = append(headers, pkt.NetworkHeader...) - headers = append(headers, pkt.TransportHeader...) + network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View() + headers := make(buffer.View, 0, len(network)+len(transport)) + headers = append(headers, network...) + headers = append(headers, transport...) combinedVV = headers.ToVectorisedView() } else { - combinedVV = append(buffer.View(nil), pkt.TransportHeader...).ToVectorisedView() + combinedVV = append(buffer.View(nil), pkt.TransportHeader().View()...).ToVectorisedView() } combinedVV.Append(pkt.Data) packet.data = combinedVV diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 46702906b..290172ac9 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -746,11 +746,7 @@ func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedV func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *stack.GSO) { optLen := len(tf.opts) - hdr := &pkt.Header - packetSize := pkt.Data.Size() - // Initialize the header. - tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen)) - pkt.TransportHeader = buffer.View(tcp) + tcp := header.TCP(pkt.TransportHeader().Push(header.TCPMinimumSize + optLen)) tcp.Encode(&header.TCPFields{ SrcPort: tf.id.LocalPort, DstPort: tf.id.RemotePort, @@ -762,8 +758,7 @@ func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *sta }) copy(tcp[header.TCPMinimumSize:], tf.opts) - length := uint16(hdr.UsedLength() + packetSize) - xsum := r.PseudoHeaderChecksum(ProtocolNumber, length) + xsum := r.PseudoHeaderChecksum(ProtocolNumber, uint16(pkt.Size())) // Only calculate the checksum if offloading isn't supported. if gso != nil && gso.NeedsCsum { // This is called CHECKSUM_PARTIAL in the Linux kernel. We @@ -801,17 +796,18 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso packetSize = size } size -= packetSize - var pkt stack.PacketBuffer - pkt.Header = buffer.NewPrependable(hdrSize) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: hdrSize, + }) pkt.Hash = tf.txHash pkt.Owner = owner pkt.EgressRoute = r pkt.GSOOptions = gso pkt.NetworkProtocolNumber = r.NetworkProtocolNumber() data.ReadToVV(&pkt.Data, packetSize) - buildTCPHdr(r, tf, &pkt, gso) + buildTCPHdr(r, tf, pkt, gso) tf.seq = tf.seq.Add(seqnum.Size(packetSize)) - pkts.PushBack(&pkt) + pkts.PushBack(pkt) } if tf.ttl == 0 { @@ -837,12 +833,12 @@ func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stac return sendTCPBatch(r, tf, data, gso, owner) } - pkt := &stack.PacketBuffer{ - Header: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen), - Data: data, - Hash: tf.txHash, - Owner: owner, - } + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen, + Data: data, + }) + pkt.Hash = tf.txHash + pkt.Owner = owner buildTCPHdr(r, tf, pkt, gso) if tf.ttl == 0 { diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 49a673b42..c5afa2680 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -21,7 +21,6 @@ package tcp import ( - "fmt" "runtime" "strings" "time" @@ -547,22 +546,22 @@ func (p *protocol) SynRcvdCounter() *synRcvdCounter { // Parse implements stack.TransportProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) bool { - hdr, ok := pkt.Data.PullUp(header.TCPMinimumSize) + // TCP header is variable length, peek at it first. + hdrLen := header.TCPMinimumSize + hdr, ok := pkt.Data.PullUp(hdrLen) if !ok { return false } // If the header has options, pull those up as well. if offset := int(header.TCP(hdr).DataOffset()); offset > header.TCPMinimumSize && offset <= pkt.Data.Size() { - hdr, ok = pkt.Data.PullUp(offset) - if !ok { - panic(fmt.Sprintf("There should be at least %d bytes in pkt.Data.", offset)) - } + // TODO(gvisor.dev/issue/2404): Figure out whether to reject this kind of + // packets. + hdrLen = offset } - pkt.TransportHeader = hdr - pkt.Data.TrimFront(len(hdr)) - return true + _, ok = pkt.TransportHeader().Consume(hdrLen) + return ok } // NewProtocol returns a TCP transport protocol. diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index bb60dc29d..94307d31a 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -68,7 +68,7 @@ func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketB route: r.Clone(), } s.data = pkt.Data.Clone(s.views[:]) - s.hdr = header.TCP(pkt.TransportHeader) + s.hdr = header.TCP(pkt.TransportHeader().View()) s.rcvdTime = time.Now() return s } diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 37e7767d6..927bc71e0 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -257,8 +257,8 @@ func (c *Context) GetPacket() []byte { c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) } - hdr := p.Pkt.Header.View() - b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + b := vv.ToView() if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize { c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize) @@ -284,8 +284,8 @@ func (c *Context) GetPacketNonBlocking() []byte { c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) } - hdr := p.Pkt.Header.View() - b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + b := vv.ToView() checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) return b @@ -318,9 +318,10 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt copy(icmp[header.ICMPv4PayloadOffset:], p2) // Inject packet. - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), }) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt) } // BuildSegment builds a TCP segment based on the given Headers and payload. @@ -374,26 +375,29 @@ func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcp // SendSegment sends a TCP segment that has already been built and written to a // buffer.VectorisedView. func (c *Context) SendSegment(s buffer.VectorisedView) { - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: s, }) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt) } // SendPacket builds and sends a TCP segment(with the provided payload & TCP // headers) in an IPv4 packet via the link layer endpoint. func (c *Context) SendPacket(payload []byte, h *Headers) { - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: c.BuildSegment(payload, h), }) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt) } // SendPacketWithAddrs builds and sends a TCP segment(with the provided payload // & TCPheaders) in an IPv4 packet via the link layer endpoint using the // provided source and destination IPv4 addresses. func (c *Context) SendPacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) { - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: c.BuildSegmentWithAddrs(payload, h, src, dst), }) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt) } // SendAck sends an ACK packet. @@ -514,9 +518,8 @@ func (c *Context) GetV6Packet() []byte { if p.Proto != ipv6.ProtocolNumber { c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber) } - b := make([]byte, p.Pkt.Header.UsedLength()+p.Pkt.Data.Size()) - copy(b, p.Pkt.Header.View()) - copy(b[p.Pkt.Header.UsedLength():], p.Pkt.Data.ToView()) + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + b := vv.ToView() checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr)) return b @@ -566,9 +569,10 @@ func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcp t.SetChecksum(^t.CalculateChecksum(xsum)) // Inject packet. - c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), }) + c.linkEP.InjectInbound(ipv6.ProtocolNumber, pkt) } // CreateConnected creates a connected TCP endpoint. diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 4a2b6c03a..73608783c 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -986,13 +986,16 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { // sendUDP sends a UDP segment via the provided network endpoint and under the // provided identity. func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner, noChecksum bool) *tcpip.Error { - // Allocate a buffer for the UDP header. - hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength())) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()), + Data: data, + }) + pkt.Owner = owner - // Initialize the header. - udp := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + // Initialize the UDP header. + udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize)) - length := uint16(hdr.UsedLength() + data.Size()) + length := uint16(pkt.Size()) udp.Encode(&header.UDPFields{ SrcPort: localPort, DstPort: remotePort, @@ -1019,12 +1022,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u Protocol: ProtocolNumber, TTL: ttl, TOS: tos, - }, &stack.PacketBuffer{ - Header: hdr, - Data: data, - TransportHeader: buffer.View(udp), - Owner: owner, - }); err != nil { + }, pkt); err != nil { r.Stats().UDP.PacketSendErrors.Increment() return err } @@ -1372,7 +1370,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // endpoint. func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // Get the header then trim it from the view. - hdr := header.UDP(pkt.TransportHeader) + hdr := header.UDP(pkt.TransportHeader().View()) if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize { // Malformed packet. e.stack.Stats().UDP.MalformedPacketsReceived.Increment() @@ -1443,9 +1441,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // Save any useful information from the network header to the packet. switch r.NetProto { case header.IPv4ProtocolNumber: - packet.tos, _ = header.IPv4(pkt.NetworkHeader).TOS() + packet.tos, _ = header.IPv4(pkt.NetworkHeader().View()).TOS() case header.IPv6ProtocolNumber: - packet.tos, _ = header.IPv6(pkt.NetworkHeader).TOS() + packet.tos, _ = header.IPv6(pkt.NetworkHeader().View()).TOS() } // TODO(gvisor.dev/issue/3556): r.LocalAddress may be a multicast or broadcast diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 0e7464e3a..63d4bed7c 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -82,7 +82,7 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { - hdr := header.UDP(pkt.TransportHeader) + hdr := header.UDP(pkt.TransportHeader().View()) if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize { // Malformed packet. r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() @@ -130,7 +130,7 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans } headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize available := int(mtu) - headerLen - payloadLen := len(pkt.NetworkHeader) + len(pkt.TransportHeader) + pkt.Data.Size() + payloadLen := pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size() + pkt.Data.Size() if payloadLen > available { payloadLen = available } @@ -139,22 +139,21 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans // For example, a raw or packet socket may use what UDP // considers an unreachable destination. Thus we deep copy pkt // to prevent multiple ownership and SR errors. - newHeader := append(buffer.View(nil), pkt.NetworkHeader...) - newHeader = append(newHeader, pkt.TransportHeader...) + newHeader := append(buffer.View(nil), pkt.NetworkHeader().View()...) + newHeader = append(newHeader, pkt.TransportHeader().View()...) payload := newHeader.ToVectorisedView() payload.AppendView(pkt.Data.ToView()) payload.CapLength(payloadLen) - hdr := buffer.NewPrependable(headerLen) - pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) - pkt.SetType(header.ICMPv4DstUnreachable) - pkt.SetCode(header.ICMPv4PortUnreachable) - pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload)) - r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - TransportHeader: buffer.View(pkt), - Data: payload, + icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: headerLen, + Data: payload, }) + icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize)) + icmpHdr.SetType(header.ICMPv4DstUnreachable) + icmpHdr.SetCode(header.ICMPv4PortUnreachable) + icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data)) + r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, icmpPkt) case header.IPv6AddressSize: if !r.Stack().AllowICMPMessage() { @@ -175,24 +174,24 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans } headerLen := int(r.MaxHeaderLength()) + header.ICMPv6DstUnreachableMinimumSize available := int(mtu) - headerLen - payloadLen := len(pkt.NetworkHeader) + len(pkt.TransportHeader) + pkt.Data.Size() + network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View() + payloadLen := len(network) + len(transport) + pkt.Data.Size() if payloadLen > available { payloadLen = available } - payload := buffer.NewVectorisedView(len(pkt.NetworkHeader)+len(pkt.TransportHeader), []buffer.View{pkt.NetworkHeader, pkt.TransportHeader}) + payload := buffer.NewVectorisedView(len(network)+len(transport), []buffer.View{network, transport}) payload.Append(pkt.Data) payload.CapLength(payloadLen) - hdr := buffer.NewPrependable(headerLen) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6DstUnreachableMinimumSize)) - pkt.SetType(header.ICMPv6DstUnreachable) - pkt.SetCode(header.ICMPv6PortUnreachable) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload)) - r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ - Header: hdr, - TransportHeader: buffer.View(pkt), - Data: payload, + icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: headerLen, + Data: payload, }) + icmpHdr := header.ICMPv6(icmpPkt.TransportHeader().Push(header.ICMPv6DstUnreachableMinimumSize)) + icmpHdr.SetType(header.ICMPv6DstUnreachable) + icmpHdr.SetCode(header.ICMPv6PortUnreachable) + icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, r.LocalAddress, r.RemoteAddress, icmpPkt.Data)) + r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, icmpPkt) } return true } @@ -215,14 +214,8 @@ func (*protocol) Wait() {} // Parse implements stack.TransportProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) bool { - h, ok := pkt.Data.PullUp(header.UDPMinimumSize) - if !ok { - // Packet is too small - return false - } - pkt.TransportHeader = h - pkt.Data.TrimFront(header.UDPMinimumSize) - return true + _, ok := pkt.TransportHeader().Consume(header.UDPMinimumSize) + return ok } // NewProtocol returns a UDP transport protocol. diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 1a32622ca..71776d6db 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -388,8 +388,8 @@ func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.Netw c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto()) } - hdr := p.Pkt.Header.View() - b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + b := vv.ToView() h := flow.header4Tuple(outgoing) checkers = append( @@ -410,14 +410,14 @@ func (c *testContext) injectPacket(flow testFlow, payload []byte) { h := flow.header4Tuple(incoming) if flow.isV4() { buf := c.buildV4Packet(payload, &h) - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) } else { buf := c.buildV6Packet(payload, &h) - c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) } } @@ -804,9 +804,9 @@ func TestV4ReadSelfSource(t *testing.T) { h.srcAddr = h.dstAddr buf := c.buildV4Packet(payload, &h) - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != tt.wantInvalidSource { t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource) @@ -1766,9 +1766,8 @@ func TestV4UnknownDestination(t *testing.T) { return } - var pkt []byte - pkt = append(pkt, p.Pkt.Header.View()...) - pkt = append(pkt, p.Pkt.Data.ToView()...) + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + pkt := vv.ToView() if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want { t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) } @@ -1844,9 +1843,8 @@ func TestV6UnknownDestination(t *testing.T) { return } - var pkt []byte - pkt = append(pkt, p.Pkt.Header.View()...) - pkt = append(pkt, p.Pkt.Data.ToView()...) + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + pkt := vv.ToView() if got, want := len(pkt), header.IPv6MinimumMTU; got > want { t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) } @@ -1897,9 +1895,9 @@ func TestIncrementMalformedPacketsReceived(t *testing.T) { u := header.UDP(buf[header.IPv6MinimumSize:]) u.SetLength(u.Length() + 1) - c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) const want = 1 if got := c.s.Stats().UDP.MalformedPacketsReceived.Value(); got != want { @@ -1952,9 +1950,9 @@ func TestShortHeader(t *testing.T) { copy(buf[header.IPv6MinimumSize:], udpHdr) // Inject packet. - c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) if got, want := c.s.Stats().MalformedRcvdPackets.Value(), uint64(1); got != want { t.Errorf("got c.s.Stats().MalformedRcvdPackets.Value() = %d, want = %d", got, want) @@ -1986,9 +1984,9 @@ func TestIncrementChecksumErrorsV4(t *testing.T) { } } - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) const want = 1 if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { @@ -2019,9 +2017,9 @@ func TestIncrementChecksumErrorsV6(t *testing.T) { u := header.UDP(buf[header.IPv6MinimumSize:]) u.SetChecksum(u.Checksum() + 1) - c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) const want = 1 if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { @@ -2049,9 +2047,9 @@ func TestPayloadModifiedV4(t *testing.T) { buf := c.buildV4Packet(payload, &h) // Modify the payload so that the checksum value in the UDP header will be incorrect. buf[len(buf)-1]++ - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) const want = 1 if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { @@ -2079,9 +2077,9 @@ func TestPayloadModifiedV6(t *testing.T) { buf := c.buildV6Packet(payload, &h) // Modify the payload so that the checksum value in the UDP header will be incorrect. buf[len(buf)-1]++ - c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) const want = 1 if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { @@ -2110,9 +2108,9 @@ func TestChecksumZeroV4(t *testing.T) { // Set the checksum field in the UDP header to zero. u := header.UDP(buf[header.IPv4MinimumSize:]) u.SetChecksum(0) - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) const want = 0 if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { @@ -2141,9 +2139,9 @@ func TestChecksumZeroV6(t *testing.T) { // Set the checksum field in the UDP header to zero. u := header.UDP(buf[header.IPv6MinimumSize:]) u.SetChecksum(0) - c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), - }) + })) const want = 1 if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { -- cgit v1.2.3 From 1736b2208f7eeec56531a9877ca53dc784fed544 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Fri, 14 Aug 2020 17:27:23 -0700 Subject: Use a single NetworkEndpoint per NIC per protocol The NetworkEndpoint does not need to be created for each address. Most of the work the NetworkEndpoint does is address agnostic. PiperOrigin-RevId: 326759605 --- pkg/tcpip/network/BUILD | 1 + pkg/tcpip/network/arp/arp.go | 15 +------ pkg/tcpip/network/ip_test.go | 77 +++++++++++++++++---------------- pkg/tcpip/network/ipv4/icmp.go | 7 +-- pkg/tcpip/network/ipv4/ipv4.go | 20 +-------- pkg/tcpip/network/ipv6/icmp.go | 7 +-- pkg/tcpip/network/ipv6/icmp_test.go | 6 +-- pkg/tcpip/network/ipv6/ipv6.go | 20 ++------- pkg/tcpip/network/ipv6/ndp_test.go | 5 +-- pkg/tcpip/stack/forwarder_test.go | 18 ++------ pkg/tcpip/stack/ndp.go | 8 ++-- pkg/tcpip/stack/nic.go | 86 +++++++++++++++++-------------------- pkg/tcpip/stack/nic_test.go | 30 ++++--------- pkg/tcpip/stack/registration.go | 8 +--- pkg/tcpip/stack/stack.go | 6 +-- pkg/tcpip/stack/stack_test.go | 20 ++------- pkg/tcpip/transport/udp/udp_test.go | 10 +---- 17 files changed, 123 insertions(+), 221 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD index 6a4839fb8..46083925c 100644 --- a/pkg/tcpip/network/BUILD +++ b/pkg/tcpip/network/BUILD @@ -12,6 +12,7 @@ go_test( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 1ad788a17..920872c3f 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -66,14 +66,6 @@ func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { return e.linkEP.Capabilities() } -func (e *endpoint) ID() *stack.NetworkEndpointID { - return &stack.NetworkEndpointID{ProtocolAddress} -} - -func (e *endpoint) PrefixLen() int { - return 0 -} - func (e *endpoint) MaxHeaderLength() uint16 { return e.linkEP.MaxHeaderLength() + header.ARPSize } @@ -142,16 +134,13 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress } -func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint, st *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) { - if addrWithPrefix.Address != ProtocolAddress { - return nil, tcpip.ErrBadLocalAddress - } +func (p *protocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint { return &endpoint{ protocol: p, nicID: nicID, linkEP: sender, linkAddrCache: linkAddrCache, - }, nil + } } // LinkAddressProtocol implements stack.LinkAddressResolver.LinkAddressProtocol. diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 491d936a1..9007346fe 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" @@ -41,6 +42,7 @@ const ( ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00" ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03" + nicID = 1 ) // testObject implements two interfaces: LinkEndpoint and TransportDispatcher. @@ -195,15 +197,15 @@ func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()}, }) - s.CreateNIC(1, loopback.New()) - s.AddAddress(1, ipv4.ProtocolNumber, local) + s.CreateNIC(nicID, loopback.New()) + s.AddAddress(nicID, ipv4.ProtocolNumber, local) s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv4EmptySubnet, Gateway: ipv4Gateway, NIC: 1, }}) - return s.FindRoute(1, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */) + return s.FindRoute(nicID, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */) } func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { @@ -211,31 +213,45 @@ func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()}, }) - s.CreateNIC(1, loopback.New()) - s.AddAddress(1, ipv6.ProtocolNumber, local) + s.CreateNIC(nicID, loopback.New()) + s.AddAddress(nicID, ipv6.ProtocolNumber, local) s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv6EmptySubnet, Gateway: ipv6Gateway, NIC: 1, }}) - return s.FindRoute(1, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */) + return s.FindRoute(nicID, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */) } -func buildDummyStack() *stack.Stack { - return stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, +func buildDummyStack(t *testing.T) *stack.Stack { + t.Helper() + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()}, }) + e := channel.New(0, 1280, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, localIpv4Addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, localIpv4Addr, err) + } + + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, localIpv6Addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, localIpv6Addr, err) + } + + return s } func TestIPv4Send(t *testing.T) { o := testObject{t: t, v4: true} proto := ipv4.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, nil, &o, buildDummyStack()) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } + ep := proto.NewEndpoint(nicID, nil, nil, &o, buildDummyStack(t)) + defer ep.Close() // Allocate and initialize the payload view. payload := buffer.NewView(100) @@ -271,10 +287,8 @@ func TestIPv4Send(t *testing.T) { func TestIPv4Receive(t *testing.T) { o := testObject{t: t, v4: true} proto := ipv4.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil, buildDummyStack()) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } + ep := proto.NewEndpoint(nicID, nil, &o, nil, buildDummyStack(t)) + defer ep.Close() totalLen := header.IPv4MinimumSize + 30 view := buffer.NewView(totalLen) @@ -343,10 +357,7 @@ func TestIPv4ReceiveControl(t *testing.T) { t.Run(c.name, func(t *testing.T) { o := testObject{t: t} proto := ipv4.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil, buildDummyStack()) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } + ep := proto.NewEndpoint(nicID, nil, &o, nil, buildDummyStack(t)) defer ep.Close() const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize @@ -407,10 +418,8 @@ func TestIPv4ReceiveControl(t *testing.T) { func TestIPv4FragmentationReceive(t *testing.T) { o := testObject{t: t, v4: true} proto := ipv4.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil, buildDummyStack()) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } + ep := proto.NewEndpoint(nicID, nil, &o, nil, buildDummyStack(t)) + defer ep.Close() totalLen := header.IPv4MinimumSize + 24 @@ -486,10 +495,8 @@ func TestIPv4FragmentationReceive(t *testing.T) { func TestIPv6Send(t *testing.T) { o := testObject{t: t} proto := ipv6.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, nil, &o, buildDummyStack()) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } + ep := proto.NewEndpoint(nicID, nil, &o, channel.New(0, 1280, ""), buildDummyStack(t)) + defer ep.Close() // Allocate and initialize the payload view. payload := buffer.NewView(100) @@ -525,10 +532,8 @@ func TestIPv6Send(t *testing.T) { func TestIPv6Receive(t *testing.T) { o := testObject{t: t} proto := ipv6.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, &o, nil, buildDummyStack()) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } + ep := proto.NewEndpoint(nicID, nil, &o, nil, buildDummyStack(t)) + defer ep.Close() totalLen := header.IPv6MinimumSize + 30 view := buffer.NewView(totalLen) @@ -606,11 +611,7 @@ func TestIPv6ReceiveControl(t *testing.T) { t.Run(c.name, func(t *testing.T) { o := testObject{t: t} proto := ipv6.NewProtocol() - ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, &o, nil, buildDummyStack()) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - + ep := proto.NewEndpoint(nicID, nil, &o, nil, buildDummyStack(t)) defer ep.Close() dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 067d770f3..b5659a36b 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -37,8 +37,9 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack // false. // // Drop packet if it doesn't have the basic IPv4 header or if the - // original source address doesn't match the endpoint's address. - if hdr.SourceAddress() != e.id.LocalAddress { + // original source address doesn't match an address we own. + src := hdr.SourceAddress() + if e.stack.CheckLocalAddress(e.NICID(), ProtocolNumber, src) == 0 { return } @@ -53,7 +54,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack // Skip the ip header, then deliver control message. pkt.Data.TrimFront(hlen) p := hdr.TransportProtocol() - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + e.dispatcher.DeliverTransportControlPacket(src, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 3cd48ceb3..79872ec9a 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -52,8 +52,6 @@ const ( type endpoint struct { nicID tcpip.NICID - id stack.NetworkEndpointID - prefixLen int linkEP stack.LinkEndpoint dispatcher stack.TransportDispatcher protocol *protocol @@ -61,18 +59,14 @@ type endpoint struct { } // NewEndpoint creates a new ipv4 endpoint. -func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) { - e := &endpoint{ +func (p *protocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint { + return &endpoint{ nicID: nicID, - id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, - prefixLen: addrWithPrefix.PrefixLen, linkEP: linkEP, dispatcher: dispatcher, protocol: p, stack: st, } - - return e, nil } // DefaultTTL is the default time-to-live value for this endpoint. @@ -96,16 +90,6 @@ func (e *endpoint) NICID() tcpip.NICID { return e.nicID } -// ID returns the ipv4 endpoint ID. -func (e *endpoint) ID() *stack.NetworkEndpointID { - return &e.id -} - -// PrefixLen returns the ipv4 endpoint subnet prefix length in bits. -func (e *endpoint) PrefixLen() int { - return e.prefixLen -} - // MaxHeaderLength returns the maximum length needed by ipv4 headers (and // underlying protocols). func (e *endpoint) MaxHeaderLength() uint16 { diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 39ae19295..66d3a953a 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -39,8 +39,9 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack // is truncated, which would cause IsValid to return false. // // Drop packet if it doesn't have the basic IPv6 header or if the - // original source address doesn't match the endpoint's address. - if hdr.SourceAddress() != e.id.LocalAddress { + // original source address doesn't match an address we own. + src := hdr.SourceAddress() + if e.stack.CheckLocalAddress(e.NICID(), ProtocolNumber, src) == 0 { return } @@ -67,7 +68,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack } // Deliver the control packet to the transport endpoint. - e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + e.dispatcher.DeliverTransportControlPacket(src, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragmentHeader bool) { diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 2a2f7de01..9e4eeea77 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -114,10 +114,8 @@ func TestICMPCounts(t *testing.T) { if netProto == nil { t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) } - ep, err := netProto.NewEndpoint(0, tcpip.AddressWithPrefix{lladdr1, netProto.DefaultPrefixLen()}, &stubLinkAddressCache{}, &stubDispatcher{}, nil, s) - if err != nil { - t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err) - } + ep := netProto.NewEndpoint(0, &stubLinkAddressCache{}, &stubDispatcher{}, nil, s) + defer ep.Close() r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) if err != nil { diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 0ade655b2..0eafe9790 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -46,12 +46,11 @@ const ( type endpoint struct { nicID tcpip.NICID - id stack.NetworkEndpointID - prefixLen int linkEP stack.LinkEndpoint linkAddrCache stack.LinkAddressCache dispatcher stack.TransportDispatcher protocol *protocol + stack *stack.Stack } // DefaultTTL is the default hop limit for this endpoint. @@ -70,16 +69,6 @@ func (e *endpoint) NICID() tcpip.NICID { return e.nicID } -// ID returns the ipv6 endpoint ID. -func (e *endpoint) ID() *stack.NetworkEndpointID { - return &e.id -} - -// PrefixLen returns the ipv6 endpoint subnet prefix length in bits. -func (e *endpoint) PrefixLen() int { - return e.prefixLen -} - // Capabilities implements stack.NetworkEndpoint.Capabilities. func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { return e.linkEP.Capabilities() @@ -464,16 +453,15 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { } // NewEndpoint creates a new ipv6 endpoint. -func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) { +func (p *protocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint { return &endpoint{ nicID: nicID, - id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, - prefixLen: addrWithPrefix.PrefixLen, linkEP: linkEP, linkAddrCache: linkAddrCache, dispatcher: dispatcher, protocol: p, - }, nil + stack: st, + } } // SetOption implements NetworkProtocol.SetOption. diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 2efa82e60..af71a7d6b 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -63,10 +63,7 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) } - ep, err := netProto.NewEndpoint(0, tcpip.AddressWithPrefix{rlladdr, netProto.DefaultPrefixLen()}, &stubLinkAddressCache{}, &stubDispatcher{}, nil, s) - if err != nil { - t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err) - } + ep := netProto.NewEndpoint(0, &stubLinkAddressCache{}, &stubDispatcher{}, nil, s) return s, ep } diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index 944f622fd..5a684eb9d 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -46,8 +46,6 @@ const ( // protocol. They're all one byte fields to simplify parsing. type fwdTestNetworkEndpoint struct { nicID tcpip.NICID - id NetworkEndpointID - prefixLen int proto *fwdTestNetworkProtocol dispatcher TransportDispatcher ep LinkEndpoint @@ -61,18 +59,10 @@ func (f *fwdTestNetworkEndpoint) NICID() tcpip.NICID { return f.nicID } -func (f *fwdTestNetworkEndpoint) PrefixLen() int { - return f.prefixLen -} - func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 { return 123 } -func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID { - return &f.id -} - func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) { // Dispatch the packet to the transport protocol. f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) @@ -99,7 +89,7 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkH // endpoint. b := pkt.NetworkHeader().Push(fwdTestNetHeaderLen) b[dstAddrOffset] = r.RemoteAddress[0] - b[srcAddrOffset] = f.id.LocalAddress[0] + b[srcAddrOffset] = r.LocalAddress[0] b[protocolNumberOffset] = byte(params.Protocol) return f.ep.WritePacket(r, gso, fwdTestNetNumber, pkt) @@ -151,15 +141,13 @@ func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocol return tcpip.TransportProtocolNumber(netHeader[protocolNumberOffset]), true, true } -func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) (NetworkEndpoint, *tcpip.Error) { +func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) NetworkEndpoint { return &fwdTestNetworkEndpoint{ nicID: nicID, - id: NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, - prefixLen: addrWithPrefix.PrefixLen, proto: f, dispatcher: dispatcher, ep: ep, - }, nil + } } func (f *fwdTestNetworkProtocol) SetOption(option interface{}) *tcpip.Error { diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 93567806b..b0873d1af 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -728,7 +728,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref func (ndp *ndpState) sendDADPacket(addr tcpip.Address, ref *referencedNetworkEndpoint) *tcpip.Error { snmc := header.SolicitedNodeAddr(addr) - r := makeRoute(header.IPv6ProtocolNumber, ref.ep.ID().LocalAddress, snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false) + r := makeRoute(header.IPv6ProtocolNumber, ref.address(), snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false) defer r.Release() // Route should resolve immediately since snmc is a multicast address so a @@ -1353,7 +1353,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla return false } - stableAddr := prefixState.stableAddr.ref.ep.ID().LocalAddress + stableAddr := prefixState.stableAddr.ref.address() now := time.Now() // As per RFC 4941 section 3.3 step 4, the valid lifetime of a temporary @@ -1690,7 +1690,7 @@ func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPr prefix := addr.Subnet() state, ok := ndp.slaacPrefixes[prefix] - if !ok || state.stableAddr.ref == nil || addr.Address != state.stableAddr.ref.ep.ID().LocalAddress { + if !ok || state.stableAddr.ref == nil || addr.Address != state.stableAddr.ref.address() { return } @@ -1867,7 +1867,7 @@ func (ndp *ndpState) startSolicitingRouters() { } ndp.nic.mu.Unlock() - localAddr := ref.ep.ID().LocalAddress + localAddr := ref.address() r := makeRoute(header.IPv6ProtocolNumber, localAddr, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false) defer r.Release() diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 2315ea5b9..10d2b7964 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -45,8 +45,9 @@ type NIC struct { linkEP LinkEndpoint context NICContext - stats NICStats - neigh *neighborCache + stats NICStats + neigh *neighborCache + networkEndpoints map[tcpip.NetworkProtocolNumber]NetworkEndpoint mu struct { sync.RWMutex @@ -114,12 +115,13 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC // of IPv6 is supported on this endpoint's LinkEndpoint. nic := &NIC{ - stack: stack, - id: id, - name: name, - linkEP: ep, - context: ctx, - stats: makeNICStats(), + stack: stack, + id: id, + name: name, + linkEP: ep, + context: ctx, + stats: makeNICStats(), + networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), } nic.mu.primary = make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint) nic.mu.endpoints = make(map[NetworkEndpointID]*referencedNetworkEndpoint) @@ -140,7 +142,9 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC nic.mu.packetEPs[netProto] = []PacketEndpoint{} } for _, netProto := range stack.networkProtocols { - nic.mu.packetEPs[netProto.Number()] = []PacketEndpoint{} + netNum := netProto.Number() + nic.mu.packetEPs[netNum] = nil + nic.networkEndpoints[netNum] = netProto.NewEndpoint(id, stack, nic, ep, stack) } // Check for Neighbor Unreachability Detection support. @@ -205,7 +209,7 @@ func (n *NIC) disableLocked() *tcpip.Error { // Stop DAD for all the unicast IPv6 endpoints that are in the // permanentTentative state. for _, r := range n.mu.endpoints { - if addr := r.ep.ID().LocalAddress; r.getKind() == permanentTentative && header.IsV6UnicastAddress(addr) { + if addr := r.address(); r.getKind() == permanentTentative && header.IsV6UnicastAddress(addr) { n.mu.ndp.stopDuplicateAddressDetection(addr) } } @@ -300,7 +304,7 @@ func (n *NIC) enable() *tcpip.Error { // Addresses may have aleady completed DAD but in the time since the NIC was // last enabled, other devices may have acquired the same addresses. for _, r := range n.mu.endpoints { - addr := r.ep.ID().LocalAddress + addr := r.address() if k := r.getKind(); (k != permanent && k != permanentTentative) || !header.IsV6UnicastAddress(addr) { continue } @@ -362,6 +366,11 @@ func (n *NIC) remove() *tcpip.Error { } } + // Release any resources the network endpoint may hold. + for _, ep := range n.networkEndpoints { + ep.Close() + } + // Detach from link endpoint, so no packet comes in. n.linkEP.Attach(nil) @@ -510,7 +519,7 @@ func (n *NIC) primaryIPv6EndpointRLocked(remoteAddr tcpip.Address) *referencedNe continue } - addr := r.ep.ID().LocalAddress + addr := r.address() scope, err := header.ScopeForIPv6Address(addr) if err != nil { // Should never happen as we got r from the primary IPv6 endpoint list and @@ -539,10 +548,10 @@ func (n *NIC) primaryIPv6EndpointRLocked(remoteAddr tcpip.Address) *referencedNe sb := cs[j] // Prefer same address as per RFC 6724 section 5 rule 1. - if sa.ref.ep.ID().LocalAddress == remoteAddr { + if sa.ref.address() == remoteAddr { return true } - if sb.ref.ep.ID().LocalAddress == remoteAddr { + if sb.ref.address() == remoteAddr { return false } @@ -819,17 +828,11 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar } } - netProto, ok := n.stack.networkProtocols[protocolAddress.Protocol] + ep, ok := n.networkEndpoints[protocolAddress.Protocol] if !ok { return nil, tcpip.ErrUnknownProtocol } - // Create the new network endpoint. - ep, err := netProto.NewEndpoint(n.id, protocolAddress.AddressWithPrefix, n.stack, n, n.linkEP, n.stack) - if err != nil { - return nil, err - } - isIPv6Unicast := protocolAddress.Protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(protocolAddress.AddressWithPrefix.Address) // If the address is an IPv6 address and it is a permanent address, @@ -842,6 +845,7 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar ref := &referencedNetworkEndpoint{ refs: 1, + addr: protocolAddress.AddressWithPrefix, ep: ep, nic: n, protocol: protocolAddress.Protocol, @@ -898,7 +902,7 @@ func (n *NIC) AllAddresses() []tcpip.ProtocolAddress { defer n.mu.RUnlock() addrs := make([]tcpip.ProtocolAddress, 0, len(n.mu.endpoints)) - for nid, ref := range n.mu.endpoints { + for _, ref := range n.mu.endpoints { // Don't include tentative, expired or temporary endpoints to // avoid confusion and prevent the caller from using those. switch ref.getKind() { @@ -907,11 +911,8 @@ func (n *NIC) AllAddresses() []tcpip.ProtocolAddress { } addrs = append(addrs, tcpip.ProtocolAddress{ - Protocol: ref.protocol, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: nid.LocalAddress, - PrefixLen: ref.ep.PrefixLen(), - }, + Protocol: ref.protocol, + AddressWithPrefix: ref.addrWithPrefix(), }) } return addrs @@ -934,11 +935,8 @@ func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress { } addrs = append(addrs, tcpip.ProtocolAddress{ - Protocol: proto, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: ref.ep.ID().LocalAddress, - PrefixLen: ref.ep.PrefixLen(), - }, + Protocol: proto, + AddressWithPrefix: ref.addrWithPrefix(), }) } } @@ -969,10 +967,7 @@ func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWit } if !ref.deprecated { - return tcpip.AddressWithPrefix{ - Address: ref.ep.ID().LocalAddress, - PrefixLen: ref.ep.PrefixLen(), - } + return ref.addrWithPrefix() } if deprecatedEndpoint == nil { @@ -981,10 +976,7 @@ func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWit } if deprecatedEndpoint != nil { - return tcpip.AddressWithPrefix{ - Address: deprecatedEndpoint.ep.ID().LocalAddress, - PrefixLen: deprecatedEndpoint.ep.PrefixLen(), - } + return deprecatedEndpoint.addrWithPrefix() } return tcpip.AddressWithPrefix{} @@ -1048,7 +1040,7 @@ func (n *NIC) insertPrimaryEndpointLocked(r *referencedNetworkEndpoint, peb Prim } func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) { - id := *r.ep.ID() + id := NetworkEndpointID{LocalAddress: r.address()} // Nothing to do if the reference has already been replaced with a different // one. This happens in the case where 1) this endpoint's ref count hit zero @@ -1072,8 +1064,6 @@ func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) { break } } - - r.ep.Close() } func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) { @@ -1718,6 +1708,7 @@ const ( type referencedNetworkEndpoint struct { ep NetworkEndpoint + addr tcpip.AddressWithPrefix nic *NIC protocol tcpip.NetworkProtocolNumber @@ -1743,11 +1734,12 @@ type referencedNetworkEndpoint struct { deprecated bool } +func (r *referencedNetworkEndpoint) address() tcpip.Address { + return r.addr.Address +} + func (r *referencedNetworkEndpoint) addrWithPrefix() tcpip.AddressWithPrefix { - return tcpip.AddressWithPrefix{ - Address: r.ep.ID().LocalAddress, - PrefixLen: r.ep.PrefixLen(), - } + return r.addr } func (r *referencedNetworkEndpoint) getKind() networkEndpointKind { diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index 0870c8d9c..d312a79eb 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -101,11 +101,9 @@ var _ NetworkEndpoint = (*testIPv6Endpoint)(nil) // We use this instead of ipv6.endpoint because the ipv6 package depends on // the stack package which this test lives in, causing a cyclic dependency. type testIPv6Endpoint struct { - nicID tcpip.NICID - id NetworkEndpointID - prefixLen int - linkEP LinkEndpoint - protocol *testIPv6Protocol + nicID tcpip.NICID + linkEP LinkEndpoint + protocol *testIPv6Protocol } // DefaultTTL implements NetworkEndpoint.DefaultTTL. @@ -146,16 +144,6 @@ func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) *tcpip return tcpip.ErrNotSupported } -// ID implements NetworkEndpoint.ID. -func (e *testIPv6Endpoint) ID() *NetworkEndpointID { - return &e.id -} - -// PrefixLen implements NetworkEndpoint.PrefixLen. -func (e *testIPv6Endpoint) PrefixLen() int { - return e.prefixLen -} - // NICID implements NetworkEndpoint.NICID. func (e *testIPv6Endpoint) NICID() tcpip.NICID { return e.nicID @@ -204,14 +192,12 @@ func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) } // NewEndpoint implements NetworkProtocol.NewEndpoint. -func (p *testIPv6Protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, _ LinkAddressCache, _ TransportDispatcher, linkEP LinkEndpoint, _ *Stack) (NetworkEndpoint, *tcpip.Error) { +func (p *testIPv6Protocol) NewEndpoint(nicID tcpip.NICID, _ LinkAddressCache, _ TransportDispatcher, linkEP LinkEndpoint, _ *Stack) NetworkEndpoint { return &testIPv6Endpoint{ - nicID: nicID, - id: NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, - prefixLen: addrWithPrefix.PrefixLen, - linkEP: linkEP, - protocol: p, - }, nil + nicID: nicID, + linkEP: linkEP, + protocol: p, + } } // SetOption implements NetworkProtocol.SetOption. diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 4570e8969..aca2f77f8 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -262,12 +262,6 @@ type NetworkEndpoint interface { // header to the given destination address. It takes ownership of pkt. WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error - // ID returns the network protocol endpoint ID. - ID() *NetworkEndpointID - - // PrefixLen returns the network endpoint's subnet prefix length in bits. - PrefixLen() int - // NICID returns the id of the NIC this endpoint belongs to. NICID() tcpip.NICID @@ -304,7 +298,7 @@ type NetworkProtocol interface { ParseAddresses(v buffer.View) (src, dst tcpip.Address) // NewEndpoint creates a new endpoint of this protocol. - NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint, st *Stack) (NetworkEndpoint, *tcpip.Error) + NewEndpoint(nicID tcpip.NICID, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint, st *Stack) NetworkEndpoint // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 5b19c5d59..9a1c8e409 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1321,7 +1321,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n if id != 0 && !needRoute { if nic, ok := s.nics[id]; ok && nic.enabled() { if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil { - return makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()), nil + return makeRoute(netProto, ref.address(), remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()), nil } } } else { @@ -1334,10 +1334,10 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n if len(remoteAddr) == 0 { // If no remote address was provided, then the route // provided will refer to the link local address. - remoteAddr = ref.ep.ID().LocalAddress + remoteAddr = ref.address() } - r := makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()) + r := makeRoute(netProto, ref.address(), remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()) r.directedBroadcast = route.Destination.IsBroadcast(remoteAddr) if len(route.Gateway) > 0 { diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 0273b3c63..b5a603098 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -70,8 +70,6 @@ const ( // protocol. They're all one byte fields to simplify parsing. type fakeNetworkEndpoint struct { nicID tcpip.NICID - id stack.NetworkEndpointID - prefixLen int proto *fakeNetworkProtocol dispatcher stack.TransportDispatcher ep stack.LinkEndpoint @@ -85,21 +83,13 @@ func (f *fakeNetworkEndpoint) NICID() tcpip.NICID { return f.nicID } -func (f *fakeNetworkEndpoint) PrefixLen() int { - return f.prefixLen -} - func (*fakeNetworkEndpoint) DefaultTTL() uint8 { return 123 } -func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID { - return &f.id -} - func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // Increment the received packet count in the protocol descriptor. - f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ + f.proto.packetCount[int(r.LocalAddress[0])%len(f.proto.packetCount)]++ // Handle control packets. if pkt.NetworkHeader().View()[protocolNumberOffset] == uint8(fakeControlProtocol) { @@ -145,7 +135,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params // endpoint. hdr := pkt.NetworkHeader().Push(fakeNetHeaderLen) hdr[dstAddrOffset] = r.RemoteAddress[0] - hdr[srcAddrOffset] = f.id.LocalAddress[0] + hdr[srcAddrOffset] = r.LocalAddress[0] hdr[protocolNumberOffset] = byte(params.Protocol) if r.Loop&stack.PacketLoop != 0 { @@ -208,15 +198,13 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1]) } -func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint, _ *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) { +func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint, _ *stack.Stack) stack.NetworkEndpoint { return &fakeNetworkEndpoint{ nicID: nicID, - id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, - prefixLen: addrWithPrefix.PrefixLen, proto: f, dispatcher: dispatcher, ep: ep, - }, nil + } } func (f *fakeNetworkProtocol) SetOption(option interface{}) *tcpip.Error { diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 71776d6db..f87d99d5a 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -1469,13 +1469,10 @@ func TestTTL(t *testing.T) { } else { p = ipv6.NewProtocol() } - ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil, stack.New(stack.Options{ + ep := p.NewEndpoint(0, nil, nil, nil, stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, })) - if err != nil { - t.Fatal(err) - } wantTTL = ep.DefaultTTL() ep.Close() } @@ -1505,13 +1502,10 @@ func TestSetTTL(t *testing.T) { } else { p = ipv6.NewProtocol() } - ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil, stack.New(stack.Options{ + ep := p.NewEndpoint(0, nil, nil, nil, stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, })) - if err != nil { - t.Fatal(err) - } ep.Close() testWrite(c, flow, checker.TTL(wantTTL)) -- cgit v1.2.3 From 9a7b5830aa063895f67ca0fdf653a46906374613 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Sat, 15 Aug 2020 00:04:30 -0700 Subject: Don't support address ranges Previously the netstack supported assignment of a range of addresses. This feature is not used so remove it. PiperOrigin-RevId: 326791119 --- pkg/tcpip/stack/nic.go | 63 ++------------ pkg/tcpip/stack/stack.go | 29 ------- pkg/tcpip/stack/stack_test.go | 194 ------------------------------------------ 3 files changed, 8 insertions(+), 278 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 10d2b7964..8a9a085f0 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -51,13 +51,12 @@ type NIC struct { mu struct { sync.RWMutex - enabled bool - spoofing bool - promiscuous bool - primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint - endpoints map[NetworkEndpointID]*referencedNetworkEndpoint - addressRanges []tcpip.Subnet - mcastJoins map[NetworkEndpointID]uint32 + enabled bool + spoofing bool + promiscuous bool + primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint + endpoints map[NetworkEndpointID]*referencedNetworkEndpoint + mcastJoins map[NetworkEndpointID]uint32 // packetEPs is protected by mu, but the contained PacketEndpoint // values are not. packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint @@ -670,25 +669,6 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t // A usable reference was not found, create a temporary one if requested by // the caller or if the address is found in the NIC's subnets. createTempEP := spoofingOrPromiscuous - if !createTempEP { - for _, sn := range n.mu.addressRanges { - // Skip the subnet address. - if address == sn.ID() { - continue - } - // For now just skip the broadcast address, until we support it. - // FIXME(b/137608825): Add support for sending/receiving directed - // (subnet) broadcast. - if address == sn.Broadcast() { - continue - } - if sn.Contains(address) { - createTempEP = true - break - } - } - } - n.mu.RUnlock() if !createTempEP { @@ -982,38 +962,11 @@ func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWit return tcpip.AddressWithPrefix{} } -// AddAddressRange adds a range of addresses to n, so that it starts accepting -// packets targeted at the given addresses and network protocol. The range is -// given by a subnet address, and all addresses contained in the subnet are -// used except for the subnet address itself and the subnet's broadcast -// address. -func (n *NIC) AddAddressRange(protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) { - n.mu.Lock() - n.mu.addressRanges = append(n.mu.addressRanges, subnet) - n.mu.Unlock() -} - -// RemoveAddressRange removes the given address range from n. -func (n *NIC) RemoveAddressRange(subnet tcpip.Subnet) { - n.mu.Lock() - - // Use the same underlying array. - tmp := n.mu.addressRanges[:0] - for _, sub := range n.mu.addressRanges { - if sub != subnet { - tmp = append(tmp, sub) - } - } - n.mu.addressRanges = tmp - - n.mu.Unlock() -} - // AddressRanges returns the Subnets associated with this NIC. func (n *NIC) AddressRanges() []tcpip.Subnet { n.mu.RLock() defer n.mu.RUnlock() - sns := make([]tcpip.Subnet, 0, len(n.mu.addressRanges)+len(n.mu.endpoints)) + sns := make([]tcpip.Subnet, 0, len(n.mu.endpoints)) for nid := range n.mu.endpoints { sn, err := tcpip.NewSubnet(nid.LocalAddress, tcpip.AddressMask(strings.Repeat("\xff", len(nid.LocalAddress)))) if err != nil { @@ -1023,7 +976,7 @@ func (n *NIC) AddressRanges() []tcpip.Subnet { } sns = append(sns, sn) } - return append(sns, n.mu.addressRanges...) + return sns } // insertPrimaryEndpointLocked adds r to n's primary endpoint list as required diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 9a1c8e409..ae44cd5da 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1230,35 +1230,6 @@ func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tc return nic.AddAddress(protocolAddress, peb) } -// AddAddressRange adds a range of addresses to the specified NIC. The range is -// given by a subnet address, and all addresses contained in the subnet are -// used except for the subnet address itself and the subnet's broadcast -// address. -func (s *Stack) AddAddressRange(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) *tcpip.Error { - s.mu.RLock() - defer s.mu.RUnlock() - - if nic, ok := s.nics[id]; ok { - nic.AddAddressRange(protocol, subnet) - return nil - } - - return tcpip.ErrUnknownNICID -} - -// RemoveAddressRange removes the range of addresses from the specified NIC. -func (s *Stack) RemoveAddressRange(id tcpip.NICID, subnet tcpip.Subnet) *tcpip.Error { - s.mu.RLock() - defer s.mu.RUnlock() - - if nic, ok := s.nics[id]; ok { - nic.RemoveAddressRange(subnet) - return nil - } - - return tcpip.ErrUnknownNICID -} - // RemoveAddress removes an existing network-layer address from the specified // NIC. func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error { diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index b5a603098..106645c50 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -23,7 +23,6 @@ import ( "math" "net" "sort" - "strings" "testing" "time" @@ -1641,149 +1640,6 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { } } -// Add a range of addresses, then check that a packet is delivered. -func TestAddressRangeAcceptsMatchingPacket(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - - buf := buffer.NewView(30) - - const localAddrByte byte = 0x01 - buf[dstAddrOffset] = localAddrByte - subnet, err := tcpip.NewSubnet(tcpip.Address("\x00"), tcpip.AddressMask("\xF0")) - if err != nil { - t.Fatal("NewSubnet failed:", err) - } - if err := s.AddAddressRange(1, fakeNetNumber, subnet); err != nil { - t.Fatal("AddAddressRange failed:", err) - } - - testRecv(t, fakeNet, localAddrByte, ep, buf) -} - -func testNicForAddressRange(t *testing.T, nicID tcpip.NICID, s *stack.Stack, subnet tcpip.Subnet, rangeExists bool) { - t.Helper() - - // Loop over all addresses and check them. - numOfAddresses := 1 << uint(8-subnet.Prefix()) - if numOfAddresses < 1 || numOfAddresses > 255 { - t.Fatalf("got numOfAddresses = %d, want = [1 .. 255] (subnet=%s)", numOfAddresses, subnet) - } - - addrBytes := []byte(subnet.ID()) - for i := 0; i < numOfAddresses; i++ { - addr := tcpip.Address(addrBytes) - wantNicID := nicID - // The subnet and broadcast addresses are skipped. - if !rangeExists || addr == subnet.ID() || addr == subnet.Broadcast() { - wantNicID = 0 - } - if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, addr); gotNicID != wantNicID { - t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, addr, gotNicID, wantNicID) - } - addrBytes[0]++ - } - - // Trying the next address should always fail since it is outside the range. - if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addrBytes)); gotNicID != 0 { - t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = 0", fakeNetNumber, tcpip.Address(addrBytes), gotNicID) - } -} - -// Set a range of addresses, then remove it again, and check at each step that -// CheckLocalAddress returns the correct NIC for each address or zero if not -// existent. -func TestCheckLocalAddressForSubnet(t *testing.T) { - const nicID tcpip.NICID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicID}}) - } - - subnet, err := tcpip.NewSubnet(tcpip.Address("\xa0"), tcpip.AddressMask("\xf0")) - if err != nil { - t.Fatal("NewSubnet failed:", err) - } - - testNicForAddressRange(t, nicID, s, subnet, false /* rangeExists */) - - if err := s.AddAddressRange(nicID, fakeNetNumber, subnet); err != nil { - t.Fatal("AddAddressRange failed:", err) - } - - testNicForAddressRange(t, nicID, s, subnet, true /* rangeExists */) - - if err := s.RemoveAddressRange(nicID, subnet); err != nil { - t.Fatal("RemoveAddressRange failed:", err) - } - - testNicForAddressRange(t, nicID, s, subnet, false /* rangeExists */) -} - -// Set a range of addresses, then send a packet to a destination outside the -// range and then check it doesn't get delivered. -func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - { - subnet, err := tcpip.NewSubnet("\x00", "\x00") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - - buf := buffer.NewView(30) - - const localAddrByte byte = 0x01 - buf[dstAddrOffset] = localAddrByte - subnet, err := tcpip.NewSubnet(tcpip.Address("\x10"), tcpip.AddressMask("\xF0")) - if err != nil { - t.Fatal("NewSubnet failed:", err) - } - if err := s.AddAddressRange(1, fakeNetNumber, subnet); err != nil { - t.Fatal("AddAddressRange failed:", err) - } - testFailingRecv(t, fakeNet, localAddrByte, ep, buf) -} - func TestNetworkOptions(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, @@ -1827,56 +1683,6 @@ func TestNetworkOptions(t *testing.T) { } } -func stackContainsAddressRange(s *stack.Stack, id tcpip.NICID, addrRange tcpip.Subnet) bool { - ranges, ok := s.NICAddressRanges()[id] - if !ok { - return false - } - for _, r := range ranges { - if r == addrRange { - return true - } - } - return false -} - -func TestAddresRangeAddRemove(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, - }) - ep := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep); err != nil { - t.Fatal("CreateNIC failed:", err) - } - - addr := tcpip.Address("\x01\x01\x01\x01") - mask := tcpip.AddressMask(strings.Repeat("\xff", len(addr))) - addrRange, err := tcpip.NewSubnet(addr, mask) - if err != nil { - t.Fatal("NewSubnet failed:", err) - } - - if got, want := stackContainsAddressRange(s, 1, addrRange), false; got != want { - t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want) - } - - if err := s.AddAddressRange(1, fakeNetNumber, addrRange); err != nil { - t.Fatal("AddAddressRange failed:", err) - } - - if got, want := stackContainsAddressRange(s, 1, addrRange), true; got != want { - t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want) - } - - if err := s.RemoveAddressRange(1, addrRange); err != nil { - t.Fatal("RemoveAddressRange failed:", err) - } - - if got, want := stackContainsAddressRange(s, 1, addrRange), false; got != want { - t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want) - } -} - func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { for _, addrLen := range []int{4, 16} { t.Run(fmt.Sprintf("addrLen=%d", addrLen), func(t *testing.T) { -- cgit v1.2.3 From e3c4bbd10a93ac824af3204c520437f3d5ff470c Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Mon, 17 Aug 2020 12:27:59 -0700 Subject: Remove address range functions Should have been removed in cl/326791119 https://github.com/google/gvisor/commit/9a7b5830aa063895f67ca0fdf653a46906374613 PiperOrigin-RevId: 327074156 --- pkg/tcpip/stack/nic.go | 18 ------------------ pkg/tcpip/stack/stack.go | 13 ------------- 2 files changed, 31 deletions(-) (limited to 'pkg/tcpip/stack') diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 8a9a085f0..728292782 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -19,7 +19,6 @@ import ( "math/rand" "reflect" "sort" - "strings" "sync/atomic" "gvisor.dev/gvisor/pkg/sync" @@ -962,23 +961,6 @@ func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWit return tcpip.AddressWithPrefix{} } -// AddressRanges returns the Subnets associated with this NIC. -func (n *NIC) AddressRanges() []tcpip.Subnet { - n.mu.RLock() - defer n.mu.RUnlock() - sns := make([]tcpip.Subnet, 0, len(n.mu.endpoints)) - for nid := range n.mu.endpoints { - sn, err := tcpip.NewSubnet(nid.LocalAddress, tcpip.AddressMask(strings.Repeat("\xff", len(nid.LocalAddress)))) - if err != nil { - // This should never happen as the mask has been carefully crafted to - // match the address. - panic("Invalid endpoint subnet: " + err.Error()) - } - sns = append(sns, sn) - } - return sns -} - // insertPrimaryEndpointLocked adds r to n's primary endpoint list as required // by peb. // diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index ae44cd5da..a3f87c8af 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1102,19 +1102,6 @@ func (s *Stack) removeNICLocked(id tcpip.NICID) *tcpip.Error { return nic.remove() } -// NICAddressRanges returns a map of NICIDs to their associated subnets. -func (s *Stack) NICAddressRanges() map[tcpip.NICID][]tcpip.Subnet { - s.mu.RLock() - defer s.mu.RUnlock() - - nics := map[tcpip.NICID][]tcpip.Subnet{} - - for id, nic := range s.nics { - nics[id] = append(nics[id], nic.AddressRanges()...) - } - return nics -} - // NICInfo captures the name and addresses assigned to a NIC. type NICInfo struct { Name string -- cgit v1.2.3