From 59e2c9f16a9a4cce2ecf8b6449a47316fdf76ca2 Mon Sep 17 00:00:00 2001 From: Ian Lewis Date: Tue, 27 Oct 2020 00:16:14 -0700 Subject: Add basic address deletion to netlink Updates #3921 PiperOrigin-RevId: 339195417 --- pkg/sentry/inet/inet.go | 6 +- pkg/sentry/inet/test_stack.go | 21 +++++++ pkg/sentry/socket/hostinet/stack.go | 13 +++-- pkg/sentry/socket/netlink/route/protocol.go | 56 +++++++++++++++++++ pkg/sentry/socket/netstack/stack.go | 85 ++++++++++++++++++++++------- pkg/tcpip/stack/stack.go | 14 +++++ pkg/tcpip/stack/stack_test.go | 84 ++++++++++++++++++++++++++++ pkg/tcpip/tcpip.go | 11 +++- 8 files changed, 262 insertions(+), 28 deletions(-) (limited to 'pkg') diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go index fbe6d6aa6..f31277d30 100644 --- a/pkg/sentry/inet/inet.go +++ b/pkg/sentry/inet/inet.go @@ -32,9 +32,13 @@ type Stack interface { InterfaceAddrs() map[int32][]InterfaceAddr // AddInterfaceAddr adds an address to the network interface identified by - // index. + // idx. AddInterfaceAddr(idx int32, addr InterfaceAddr) error + // RemoveInterfaceAddr removes an address from the network interface + // identified by idx. + RemoveInterfaceAddr(idx int32, addr InterfaceAddr) error + // SupportsIPv6 returns true if the stack supports IPv6 connectivity. SupportsIPv6() bool diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go index 1779cc6f3..9ebeba8a3 100644 --- a/pkg/sentry/inet/test_stack.go +++ b/pkg/sentry/inet/test_stack.go @@ -15,6 +15,9 @@ package inet import ( + "bytes" + "fmt" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -58,6 +61,24 @@ func (s *TestStack) AddInterfaceAddr(idx int32, addr InterfaceAddr) error { return nil } +// RemoveInterfaceAddr implements Stack.RemoveInterfaceAddr. +func (s *TestStack) RemoveInterfaceAddr(idx int32, addr InterfaceAddr) error { + interfaceAddrs, ok := s.InterfaceAddrsMap[idx] + if !ok { + return fmt.Errorf("unknown idx: %d", idx) + } + + var filteredAddrs []InterfaceAddr + for _, interfaceAddr := range interfaceAddrs { + if !bytes.Equal(interfaceAddr.Addr, addr.Addr) { + filteredAddrs = append(filteredAddrs, addr) + } + } + s.InterfaceAddrsMap[idx] = filteredAddrs + + return nil +} + // SupportsIPv6 implements Stack.SupportsIPv6. func (s *TestStack) SupportsIPv6() bool { return s.SupportsIPv6Flag diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index 144ed593c..7e7857ac3 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -324,7 +324,12 @@ func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr { } // AddInterfaceAddr implements inet.Stack.AddInterfaceAddr. -func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { +func (s *Stack) AddInterfaceAddr(int32, inet.InterfaceAddr) error { + return syserror.EACCES +} + +// RemoveInterfaceAddr implements inet.Stack.RemoveInterfaceAddr. +func (s *Stack) RemoveInterfaceAddr(int32, inet.InterfaceAddr) error { return syserror.EACCES } @@ -359,7 +364,7 @@ func (s *Stack) TCPSACKEnabled() (bool, error) { } // SetTCPSACKEnabled implements inet.Stack.SetTCPSACKEnabled. -func (s *Stack) SetTCPSACKEnabled(enabled bool) error { +func (s *Stack) SetTCPSACKEnabled(bool) error { return syserror.EACCES } @@ -369,7 +374,7 @@ func (s *Stack) TCPRecovery() (inet.TCPLossRecovery, error) { } // SetTCPRecovery implements inet.Stack.SetTCPRecovery. -func (s *Stack) SetTCPRecovery(recovery inet.TCPLossRecovery) error { +func (s *Stack) SetTCPRecovery(inet.TCPLossRecovery) error { return syserror.EACCES } @@ -495,6 +500,6 @@ func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { } // SetForwarding implements inet.Stack.SetForwarding. -func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error { +func (s *Stack) SetForwarding(tcpip.NetworkProtocolNumber, bool) error { return syserror.EACCES } diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go index c71cce064..22216158e 100644 --- a/pkg/sentry/socket/netlink/route/protocol.go +++ b/pkg/sentry/socket/netlink/route/protocol.go @@ -423,6 +423,11 @@ func (p *Protocol) newAddr(ctx context.Context, msg *netlink.Message, ms *netlin } attrs = rest + // NOTE: A netlink message will contain multiple header attributes. + // Both the IFA_ADDRESS and IFA_LOCAL attributes are typically sent + // with IFA_ADDRESS being a prefix address and IFA_LOCAL being the + // local interface address. We add the local interface address here + // and ignore the IFA_ADDRESS. switch ahdr.Type { case linux.IFA_LOCAL: err := stack.AddInterfaceAddr(int32(ifa.Index), inet.InterfaceAddr{ @@ -439,8 +444,57 @@ func (p *Protocol) newAddr(ctx context.Context, msg *netlink.Message, ms *netlin } else if err != nil { return syserr.ErrInvalidArgument } + case linux.IFA_ADDRESS: + default: + return syserr.ErrNotSupported + } + } + return nil +} + +// delAddr handles RTM_DELADDR requests. +func (p *Protocol) delAddr(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error { + stack := inet.StackFromContext(ctx) + if stack == nil { + // No network stack. + return syserr.ErrProtocolNotSupported + } + + var ifa linux.InterfaceAddrMessage + attrs, ok := msg.GetData(&ifa) + if !ok { + return syserr.ErrInvalidArgument + } + + for !attrs.Empty() { + ahdr, value, rest, ok := attrs.ParseFirst() + if !ok { + return syserr.ErrInvalidArgument + } + attrs = rest + + // NOTE: A netlink message will contain multiple header attributes. + // Both the IFA_ADDRESS and IFA_LOCAL attributes are typically sent + // with IFA_ADDRESS being a prefix address and IFA_LOCAL being the + // local interface address. We use the local interface address to + // remove the address and ignore the IFA_ADDRESS. + switch ahdr.Type { + case linux.IFA_LOCAL: + err := stack.RemoveInterfaceAddr(int32(ifa.Index), inet.InterfaceAddr{ + Family: ifa.Family, + PrefixLen: ifa.PrefixLen, + Flags: ifa.Flags, + Addr: value, + }) + if err != nil { + return syserr.ErrInvalidArgument + } + case linux.IFA_ADDRESS: + default: + return syserr.ErrNotSupported } } + return nil } @@ -485,6 +539,8 @@ func (p *Protocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms return p.dumpRoutes(ctx, msg, ms) case linux.RTM_NEWADDR: return p.newAddr(ctx, msg, ms) + case linux.RTM_DELADDR: + return p.delAddr(ctx, msg, ms) default: return syserr.ErrNotSupported } diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index 1028d2a6e..fa9ac9059 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -100,56 +100,101 @@ func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr { return nicAddrs } -// AddInterfaceAddr implements inet.Stack.AddInterfaceAddr. -func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { +// convertAddr converts an InterfaceAddr to a ProtocolAddress. +func convertAddr(addr inet.InterfaceAddr) (tcpip.ProtocolAddress, error) { var ( - protocol tcpip.NetworkProtocolNumber - address tcpip.Address + protocol tcpip.NetworkProtocolNumber + address tcpip.Address + protocolAddress tcpip.ProtocolAddress ) switch addr.Family { case linux.AF_INET: - if len(addr.Addr) < header.IPv4AddressSize { - return syserror.EINVAL + if len(addr.Addr) != header.IPv4AddressSize { + return protocolAddress, syserror.EINVAL } if addr.PrefixLen > header.IPv4AddressSize*8 { - return syserror.EINVAL + return protocolAddress, syserror.EINVAL } protocol = ipv4.ProtocolNumber - address = tcpip.Address(addr.Addr[:header.IPv4AddressSize]) - + address = tcpip.Address(addr.Addr) case linux.AF_INET6: - if len(addr.Addr) < header.IPv6AddressSize { - return syserror.EINVAL + if len(addr.Addr) != header.IPv6AddressSize { + return protocolAddress, syserror.EINVAL } if addr.PrefixLen > header.IPv6AddressSize*8 { - return syserror.EINVAL + return protocolAddress, syserror.EINVAL } protocol = ipv6.ProtocolNumber - address = tcpip.Address(addr.Addr[:header.IPv6AddressSize]) - + address = tcpip.Address(addr.Addr) default: - return syserror.ENOTSUP + return protocolAddress, syserror.ENOTSUP } - protocolAddress := tcpip.ProtocolAddress{ + protocolAddress = tcpip.ProtocolAddress{ Protocol: protocol, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: address, PrefixLen: int(addr.PrefixLen), }, } + return protocolAddress, nil +} + +// AddInterfaceAddr implements inet.Stack.AddInterfaceAddr. +func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { + protocolAddress, err := convertAddr(addr) + if err != nil { + return err + } // Attach address to interface. - if err := s.Stack.AddProtocolAddressWithOptions(tcpip.NICID(idx), protocolAddress, stack.CanBePrimaryEndpoint); err != nil { + nicID := tcpip.NICID(idx) + if err := s.Stack.AddProtocolAddressWithOptions(nicID, protocolAddress, stack.CanBePrimaryEndpoint); err != nil { + return syserr.TranslateNetstackError(err).ToError() + } + + // Add route for local network if it doesn't exist already. + localRoute := tcpip.Route{ + Destination: protocolAddress.AddressWithPrefix.Subnet(), + Gateway: "", // No gateway for local network. + NIC: nicID, + } + + for _, rt := range s.Stack.GetRouteTable() { + if rt.Equal(localRoute) { + return nil + } + } + + // Local route does not exist yet. Add it. + s.Stack.AddRoute(localRoute) + + return nil +} + +// RemoveInterfaceAddr implements inet.Stack.RemoveInterfaceAddr. +func (s *Stack) RemoveInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { + protocolAddress, err := convertAddr(addr) + if err != nil { + return err + } + + // Remove addresses matching the address and prefix. + nicID := tcpip.NICID(idx) + if err := s.Stack.RemoveAddress(nicID, protocolAddress.AddressWithPrefix.Address); err != nil { return syserr.TranslateNetstackError(err).ToError() } - // Add route for local network. - s.Stack.AddRoute(tcpip.Route{ + // Remove the corresponding local network route if it exists. + localRoute := tcpip.Route{ Destination: protocolAddress.AddressWithPrefix.Subnet(), Gateway: "", // No gateway for local network. - NIC: tcpip.NICID(idx), + NIC: nicID, + } + s.Stack.RemoveRoutes(func(rt tcpip.Route) bool { + return rt.Equal(localRoute) }) + return nil } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index d3f75cb36..e8f1c110e 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -830,6 +830,20 @@ func (s *Stack) AddRoute(route tcpip.Route) { s.routeTable = append(s.routeTable, route) } +// RemoveRoutes removes matching routes from the route table. +func (s *Stack) RemoveRoutes(match func(tcpip.Route) bool) { + s.mu.Lock() + defer s.mu.Unlock() + + var filteredRoutes []tcpip.Route + for _, route := range s.routeTable { + if !match(route) { + filteredRoutes = append(filteredRoutes, route) + } + } + s.routeTable = filteredRoutes +} + // NewEndpoint creates a new transport layer endpoint of the given protocol. func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { t, ok := s.transportProtocols[transport] diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index e75f58c64..4eed4ced4 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -3672,3 +3672,87 @@ func TestGetMainNICAddressWhenNICDisabled(t *testing.T) { t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix) } } + +// TestAddRoute tests Stack.AddRoute +func TestAddRoute(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{}) + + subnet1, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + + subnet2, err := tcpip.NewSubnet("\x01", "\x01") + if err != nil { + t.Fatal(err) + } + + expected := []tcpip.Route{ + {Destination: subnet1, Gateway: "\x00", NIC: 1}, + {Destination: subnet2, Gateway: "\x00", NIC: 1}, + } + + // Initialize the route table with one route. + s.SetRouteTable([]tcpip.Route{expected[0]}) + + // Add another route. + s.AddRoute(expected[1]) + + rt := s.GetRouteTable() + if got, want := len(rt), len(expected); got != want { + t.Fatalf("Unexpected route table length got = %d, want = %d", got, want) + } + for i, route := range rt { + if got, want := route, expected[i]; got != want { + t.Fatalf("Unexpected route got = %#v, want = %#v", got, want) + } + } +} + +// TestRemoveRoutes tests Stack.RemoveRoutes +func TestRemoveRoutes(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{}) + + addressToRemove := tcpip.Address("\x01") + subnet1, err := tcpip.NewSubnet(addressToRemove, "\x01") + if err != nil { + t.Fatal(err) + } + + subnet2, err := tcpip.NewSubnet(addressToRemove, "\x01") + if err != nil { + t.Fatal(err) + } + + subnet3, err := tcpip.NewSubnet("\x02", "\x02") + if err != nil { + t.Fatal(err) + } + + // Initialize the route table with three routes. + s.SetRouteTable([]tcpip.Route{ + {Destination: subnet1, Gateway: "\x00", NIC: 1}, + {Destination: subnet2, Gateway: "\x00", NIC: 1}, + {Destination: subnet3, Gateway: "\x00", NIC: 1}, + }) + + // Remove routes with the specific address. + s.RemoveRoutes(func(r tcpip.Route) bool { + return r.Destination.ID() == addressToRemove + }) + + expected := []tcpip.Route{{Destination: subnet3, Gateway: "\x00", NIC: 1}} + rt := s.GetRouteTable() + if got, want := len(rt), len(expected); got != want { + t.Fatalf("Unexpected route table length got = %d, want = %d", got, want) + } + for i, route := range rt { + if got, want := route, expected[i]; got != want { + t.Fatalf("Unexpected route got = %#v, want = %#v", got, want) + } + } +} diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index a7d54d3b9..ac4d39d3e 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -356,10 +356,9 @@ func (s *Subnet) IsBroadcast(address Address) bool { return s.Prefix() <= 30 && s.Broadcast() == address } -// Equal returns true if s equals o. -// -// Needed to use cmp.Equal on Subnet as its fields are unexported. +// Equal returns true if this Subnet is equal to the given Subnet. func (s Subnet) Equal(o Subnet) bool { + // If this changes, update Route.Equal accordingly. return s == o } @@ -1260,6 +1259,12 @@ func (r Route) String() string { return out.String() } +// Equal returns true if the given Route is equal to this Route. +func (r Route) Equal(to Route) bool { + // NOTE: This relies on the fact that r.Destination == to.Destination + return r == to +} + // TransportProtocolNumber is the number of a transport protocol. type TransportProtocolNumber uint32 -- cgit v1.2.3