diff options
Diffstat (limited to 'pkg/tcpip')
-rw-r--r-- | pkg/tcpip/stack/stack.go | 14 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 84 | ||||
-rw-r--r-- | pkg/tcpip/tcpip.go | 11 |
3 files changed, 106 insertions, 3 deletions
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index d3f75cb36..e8f1c110e 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -830,6 +830,20 @@ func (s *Stack) AddRoute(route tcpip.Route) { s.routeTable = append(s.routeTable, route) } +// RemoveRoutes removes matching routes from the route table. +func (s *Stack) RemoveRoutes(match func(tcpip.Route) bool) { + s.mu.Lock() + defer s.mu.Unlock() + + var filteredRoutes []tcpip.Route + for _, route := range s.routeTable { + if !match(route) { + filteredRoutes = append(filteredRoutes, route) + } + } + s.routeTable = filteredRoutes +} + // NewEndpoint creates a new transport layer endpoint of the given protocol. func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { t, ok := s.transportProtocols[transport] diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index e75f58c64..4eed4ced4 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -3672,3 +3672,87 @@ func TestGetMainNICAddressWhenNICDisabled(t *testing.T) { t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix) } } + +// TestAddRoute tests Stack.AddRoute +func TestAddRoute(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{}) + + subnet1, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + + subnet2, err := tcpip.NewSubnet("\x01", "\x01") + if err != nil { + t.Fatal(err) + } + + expected := []tcpip.Route{ + {Destination: subnet1, Gateway: "\x00", NIC: 1}, + {Destination: subnet2, Gateway: "\x00", NIC: 1}, + } + + // Initialize the route table with one route. + s.SetRouteTable([]tcpip.Route{expected[0]}) + + // Add another route. + s.AddRoute(expected[1]) + + rt := s.GetRouteTable() + if got, want := len(rt), len(expected); got != want { + t.Fatalf("Unexpected route table length got = %d, want = %d", got, want) + } + for i, route := range rt { + if got, want := route, expected[i]; got != want { + t.Fatalf("Unexpected route got = %#v, want = %#v", got, want) + } + } +} + +// TestRemoveRoutes tests Stack.RemoveRoutes +func TestRemoveRoutes(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{}) + + addressToRemove := tcpip.Address("\x01") + subnet1, err := tcpip.NewSubnet(addressToRemove, "\x01") + if err != nil { + t.Fatal(err) + } + + subnet2, err := tcpip.NewSubnet(addressToRemove, "\x01") + if err != nil { + t.Fatal(err) + } + + subnet3, err := tcpip.NewSubnet("\x02", "\x02") + if err != nil { + t.Fatal(err) + } + + // Initialize the route table with three routes. + s.SetRouteTable([]tcpip.Route{ + {Destination: subnet1, Gateway: "\x00", NIC: 1}, + {Destination: subnet2, Gateway: "\x00", NIC: 1}, + {Destination: subnet3, Gateway: "\x00", NIC: 1}, + }) + + // Remove routes with the specific address. + s.RemoveRoutes(func(r tcpip.Route) bool { + return r.Destination.ID() == addressToRemove + }) + + expected := []tcpip.Route{{Destination: subnet3, Gateway: "\x00", NIC: 1}} + rt := s.GetRouteTable() + if got, want := len(rt), len(expected); got != want { + t.Fatalf("Unexpected route table length got = %d, want = %d", got, want) + } + for i, route := range rt { + if got, want := route, expected[i]; got != want { + t.Fatalf("Unexpected route got = %#v, want = %#v", got, want) + } + } +} diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index a7d54d3b9..ac4d39d3e 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -356,10 +356,9 @@ func (s *Subnet) IsBroadcast(address Address) bool { return s.Prefix() <= 30 && s.Broadcast() == address } -// Equal returns true if s equals o. -// -// Needed to use cmp.Equal on Subnet as its fields are unexported. +// Equal returns true if this Subnet is equal to the given Subnet. func (s Subnet) Equal(o Subnet) bool { + // If this changes, update Route.Equal accordingly. return s == o } @@ -1260,6 +1259,12 @@ func (r Route) String() string { return out.String() } +// Equal returns true if the given Route is equal to this Route. +func (r Route) Equal(to Route) bool { + // NOTE: This relies on the fact that r.Destination == to.Destination + return r == to +} + // TransportProtocolNumber is the number of a transport protocol. type TransportProtocolNumber uint32 |