diff options
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- | pkg/tcpip/stack/BUILD | 3 | ||||
-rw-r--r-- | pkg/tcpip/stack/fake_time_test.go | 209 | ||||
-rw-r--r-- | pkg/tcpip/stack/forwarder_test.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/stack/ndp.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/stack/ndp_test.go | 12 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_cache_test.go | 67 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_entry.go | 3 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_entry_test.go | 47 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 34 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic_test.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/stack/packet_buffer.go | 29 | ||||
-rw-r--r-- | pkg/tcpip/stack/registration.go | 45 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 72 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 13 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_test.go | 6 |
15 files changed, 208 insertions, 345 deletions
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 900938dd1..7f1d79115 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -138,7 +138,6 @@ go_test( name = "stack_test", size = "small", srcs = [ - "fake_time_test.go", "forwarder_test.go", "linkaddrcache_test.go", "neighbor_cache_test.go", @@ -152,8 +151,8 @@ go_test( "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", - "@com_github_dpjacques_clockwork//:go_default_library", "@com_github_google_go_cmp//cmp:go_default_library", "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", ], diff --git a/pkg/tcpip/stack/fake_time_test.go b/pkg/tcpip/stack/fake_time_test.go deleted file mode 100644 index 92c8cb534..000000000 --- a/pkg/tcpip/stack/fake_time_test.go +++ /dev/null @@ -1,209 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stack - -import ( - "container/heap" - "sync" - "time" - - "github.com/dpjacques/clockwork" - "gvisor.dev/gvisor/pkg/tcpip" -) - -type fakeClock struct { - clock clockwork.FakeClock - - // mu protects the fields below. - mu sync.RWMutex - - // times is min-heap of times. A heap is used for quick retrieval of the next - // upcoming time of scheduled work. - times *timeHeap - - // waitGroups stores one WaitGroup for all work scheduled to execute at the - // same time via AfterFunc. This allows parallel execution of all functions - // passed to AfterFunc scheduled for the same time. - waitGroups map[time.Time]*sync.WaitGroup -} - -func newFakeClock() *fakeClock { - return &fakeClock{ - clock: clockwork.NewFakeClock(), - times: &timeHeap{}, - waitGroups: make(map[time.Time]*sync.WaitGroup), - } -} - -var _ tcpip.Clock = (*fakeClock)(nil) - -// NowNanoseconds implements tcpip.Clock.NowNanoseconds. -func (fc *fakeClock) NowNanoseconds() int64 { - return fc.clock.Now().UnixNano() -} - -// NowMonotonic implements tcpip.Clock.NowMonotonic. -func (fc *fakeClock) NowMonotonic() int64 { - return fc.NowNanoseconds() -} - -// AfterFunc implements tcpip.Clock.AfterFunc. -func (fc *fakeClock) AfterFunc(d time.Duration, f func()) tcpip.Timer { - until := fc.clock.Now().Add(d) - wg := fc.addWait(until) - return &fakeTimer{ - clock: fc, - until: until, - timer: fc.clock.AfterFunc(d, func() { - defer wg.Done() - f() - }), - } -} - -// addWait adds an additional wait to the WaitGroup for parallel execution of -// all work scheduled for t. Returns a reference to the WaitGroup modified. -func (fc *fakeClock) addWait(t time.Time) *sync.WaitGroup { - fc.mu.RLock() - wg, ok := fc.waitGroups[t] - fc.mu.RUnlock() - - if ok { - wg.Add(1) - return wg - } - - fc.mu.Lock() - heap.Push(fc.times, t) - fc.mu.Unlock() - - wg = &sync.WaitGroup{} - wg.Add(1) - - fc.mu.Lock() - fc.waitGroups[t] = wg - fc.mu.Unlock() - - return wg -} - -// removeWait removes a wait from the WaitGroup for parallel execution of all -// work scheduled for t. -func (fc *fakeClock) removeWait(t time.Time) { - fc.mu.RLock() - defer fc.mu.RUnlock() - - wg := fc.waitGroups[t] - wg.Done() -} - -// advance executes all work that have been scheduled to execute within d from -// the current fake time. Blocks until all work has completed execution. -func (fc *fakeClock) advance(d time.Duration) { - // Block until all the work is done - until := fc.clock.Now().Add(d) - for { - fc.mu.Lock() - if fc.times.Len() == 0 { - fc.mu.Unlock() - return - } - - t := heap.Pop(fc.times).(time.Time) - if t.After(until) { - // No work to do - heap.Push(fc.times, t) - fc.mu.Unlock() - return - } - fc.mu.Unlock() - - diff := t.Sub(fc.clock.Now()) - fc.clock.Advance(diff) - - fc.mu.RLock() - wg := fc.waitGroups[t] - fc.mu.RUnlock() - - wg.Wait() - - fc.mu.Lock() - delete(fc.waitGroups, t) - fc.mu.Unlock() - } -} - -type fakeTimer struct { - clock *fakeClock - timer clockwork.Timer - - mu sync.RWMutex - until time.Time -} - -var _ tcpip.Timer = (*fakeTimer)(nil) - -// Reset implements tcpip.Timer.Reset. -func (ft *fakeTimer) Reset(d time.Duration) { - if !ft.timer.Reset(d) { - return - } - - ft.mu.Lock() - defer ft.mu.Unlock() - - ft.clock.removeWait(ft.until) - ft.until = ft.clock.clock.Now().Add(d) - ft.clock.addWait(ft.until) -} - -// Stop implements tcpip.Timer.Stop. -func (ft *fakeTimer) Stop() bool { - if !ft.timer.Stop() { - return false - } - - ft.mu.RLock() - defer ft.mu.RUnlock() - - ft.clock.removeWait(ft.until) - return true -} - -type timeHeap []time.Time - -var _ heap.Interface = (*timeHeap)(nil) - -func (h timeHeap) Len() int { - return len(h) -} - -func (h timeHeap) Less(i, j int) bool { - return h[i].Before(h[j]) -} - -func (h timeHeap) Swap(i, j int) { - h[i], h[j] = h[j], h[i] -} - -func (h *timeHeap) Push(x interface{}) { - *h = append(*h, x.(time.Time)) -} - -func (h *timeHeap) Pop() interface{} { - last := (*h)[len(*h)-1] - *h = (*h)[:len(*h)-1] - return last -} diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index 54759091a..e30927821 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -145,6 +145,10 @@ func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocol return tcpip.TransportProtocolNumber(netHeader[protocolNumberOffset]), true, true } +func (*fwdTestNetworkProtocol) ReturnError(*Route, tcpip.ICMPReason, *PacketBuffer) *tcpip.Error { + return nil +} + func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, _ LinkAddressCache, _ NUDHandler, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) NetworkEndpoint { return &fwdTestNetworkEndpoint{ nicID: nicID, @@ -316,7 +320,7 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborC } // Enable forwarding. - s.SetForwarding(true) + s.SetForwarding(proto.Number(), true) // NIC 1 has the link address "a", and added the network address 1. ep1 = &fwdTestLinkEndpoint{ diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index b0873d1af..97ca00d16 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -817,7 +817,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { // per-interface basis; it is a stack-wide configuration, so we check // stack's forwarding flag to determine if the NIC is a routing // interface. - if !ndp.configs.HandleRAs || ndp.nic.stack.forwarding { + if !ndp.configs.HandleRAs || ndp.nic.stack.Forwarding(header.IPv6ProtocolNumber) { return } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 67dc5364f..5e43a9b0b 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -1120,7 +1120,7 @@ func TestNoRouterDiscovery(t *testing.T) { }, NDPDisp: &ndpDisp, }) - s.SetForwarding(forwarding) + s.SetForwarding(ipv6.ProtocolNumber, forwarding) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(1) = %s", err) @@ -1365,7 +1365,7 @@ func TestNoPrefixDiscovery(t *testing.T) { }, NDPDisp: &ndpDisp, }) - s.SetForwarding(forwarding) + s.SetForwarding(ipv6.ProtocolNumber, forwarding) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(1) = %s", err) @@ -1723,7 +1723,7 @@ func TestNoAutoGenAddr(t *testing.T) { }, NDPDisp: &ndpDisp, }) - s.SetForwarding(forwarding) + s.SetForwarding(ipv6.ProtocolNumber, forwarding) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(1) = %s", err) @@ -4640,7 +4640,7 @@ func TestCleanupNDPState(t *testing.T) { name: "Enable forwarding", cleanupFn: func(t *testing.T, s *stack.Stack) { t.Helper() - s.SetForwarding(true) + s.SetForwarding(ipv6.ProtocolNumber, true) }, keepAutoGenLinkLocal: true, maxAutoGenAddrEvents: 4, @@ -5286,11 +5286,11 @@ func TestStopStartSolicitingRouters(t *testing.T) { name: "Enable and disable forwarding", startFn: func(t *testing.T, s *stack.Stack) { t.Helper() - s.SetForwarding(false) + s.SetForwarding(ipv6.ProtocolNumber, false) }, stopFn: func(t *testing.T, s *stack.Stack, _ bool) { t.Helper() - s.SetForwarding(true) + s.SetForwarding(ipv6.ProtocolNumber, true) }, }, diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index b4fa69e3e..a0b7da5cd 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -30,6 +30,7 @@ import ( "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/faketime" ) const ( @@ -239,7 +240,7 @@ type entryEvent struct { func TestNeighborCacheGetConfig(t *testing.T) { nudDisp := testNUDDispatcher{} c := DefaultNUDConfigurations() - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(&nudDisp, c, clock) if got, want := neigh.config(), c; got != want { @@ -257,7 +258,7 @@ func TestNeighborCacheGetConfig(t *testing.T) { func TestNeighborCacheSetConfig(t *testing.T) { nudDisp := testNUDDispatcher{} c := DefaultNUDConfigurations() - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(&nudDisp, c, clock) c.MinRandomFactor = 1 @@ -279,7 +280,7 @@ func TestNeighborCacheSetConfig(t *testing.T) { func TestNeighborCacheEntry(t *testing.T) { c := DefaultNUDConfigurations() nudDisp := testNUDDispatcher{} - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(&nudDisp, c, clock) store := newTestEntryStore() linkRes := &testNeighborResolver{ @@ -298,7 +299,7 @@ func TestNeighborCacheEntry(t *testing.T) { t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - clock.advance(typicalLatency) + clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ { @@ -339,7 +340,7 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { config := DefaultNUDConfigurations() nudDisp := testNUDDispatcher{} - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(&nudDisp, config, clock) store := newTestEntryStore() linkRes := &testNeighborResolver{ @@ -358,7 +359,7 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - clock.advance(typicalLatency) + clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ { @@ -409,7 +410,7 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { } type testContext struct { - clock *fakeClock + clock *faketime.ManualClock neigh *neighborCache store *testEntryStore linkRes *testNeighborResolver @@ -418,7 +419,7 @@ type testContext struct { func newTestContext(c NUDConfigurations) testContext { nudDisp := &testNUDDispatcher{} - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(nudDisp, c, clock) store := newTestEntryStore() linkRes := &testNeighborResolver{ @@ -454,7 +455,7 @@ func (c *testContext) overflowCache(opts overflowOptions) error { if _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil); err != tcpip.ErrWouldBlock { return fmt.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - c.clock.advance(c.neigh.config().RetransmitTimer) + c.clock.Advance(c.neigh.config().RetransmitTimer) var wantEvents []testEntryEventInfo @@ -567,7 +568,7 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { if err != tcpip.ErrWouldBlock { t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - c.clock.advance(c.neigh.config().RetransmitTimer) + c.clock.Advance(c.neigh.config().RetransmitTimer) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -803,7 +804,7 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { if err != tcpip.ErrWouldBlock { t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - c.clock.advance(typicalLatency) + c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -876,7 +877,7 @@ func TestNeighborCacheNotifiesWaker(t *testing.T) { config := DefaultNUDConfigurations() nudDisp := testNUDDispatcher{} - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(&nudDisp, config, clock) store := newTestEntryStore() linkRes := &testNeighborResolver{ @@ -902,7 +903,7 @@ func TestNeighborCacheNotifiesWaker(t *testing.T) { if doneCh == nil { t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr) } - clock.advance(typicalLatency) + clock.Advance(typicalLatency) select { case <-doneCh: @@ -944,7 +945,7 @@ func TestNeighborCacheRemoveWaker(t *testing.T) { config := DefaultNUDConfigurations() nudDisp := testNUDDispatcher{} - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(&nudDisp, config, clock) store := newTestEntryStore() linkRes := &testNeighborResolver{ @@ -974,7 +975,7 @@ func TestNeighborCacheRemoveWaker(t *testing.T) { // Remove the waker before the neighbor cache has the opportunity to send a // notification. neigh.removeWaker(entry.Addr, &w) - clock.advance(typicalLatency) + clock.Advance(typicalLatency) select { case <-doneCh: @@ -1073,7 +1074,7 @@ func TestNeighborCacheClear(t *testing.T) { config := DefaultNUDConfigurations() nudDisp := testNUDDispatcher{} - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(&nudDisp, config, clock) store := newTestEntryStore() linkRes := &testNeighborResolver{ @@ -1092,7 +1093,7 @@ func TestNeighborCacheClear(t *testing.T) { if err != tcpip.ErrWouldBlock { t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - clock.advance(typicalLatency) + clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ { @@ -1188,7 +1189,7 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { if err != tcpip.ErrWouldBlock { t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - c.clock.advance(typicalLatency) + c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -1249,7 +1250,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { config.MaxRandomFactor = 1 nudDisp := testNUDDispatcher{} - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(&nudDisp, config, clock) store := newTestEntryStore() linkRes := &testNeighborResolver{ @@ -1277,7 +1278,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { if err != tcpip.ErrWouldBlock { t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - clock.advance(typicalLatency) + clock.Advance(typicalLatency) select { case <-doneCh: default: @@ -1325,7 +1326,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { if err != tcpip.ErrWouldBlock { t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - clock.advance(typicalLatency) + clock.Advance(typicalLatency) select { case <-doneCh: default: @@ -1412,7 +1413,7 @@ func TestNeighborCacheConcurrent(t *testing.T) { config := DefaultNUDConfigurations() nudDisp := testNUDDispatcher{} - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(&nudDisp, config, clock) store := newTestEntryStore() linkRes := &testNeighborResolver{ @@ -1440,7 +1441,7 @@ func TestNeighborCacheConcurrent(t *testing.T) { wg.Wait() // Process all the requests for a single entry concurrently - clock.advance(typicalLatency) + clock.Advance(typicalLatency) } // All goroutines add in the same order and add more values than can fit in @@ -1472,7 +1473,7 @@ func TestNeighborCacheReplace(t *testing.T) { config := DefaultNUDConfigurations() nudDisp := testNUDDispatcher{} - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(&nudDisp, config, clock) store := newTestEntryStore() linkRes := &testNeighborResolver{ @@ -1491,7 +1492,7 @@ func TestNeighborCacheReplace(t *testing.T) { if err != tcpip.ErrWouldBlock { t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - clock.advance(typicalLatency) + clock.Advance(typicalLatency) select { case <-doneCh: default: @@ -1541,7 +1542,7 @@ func TestNeighborCacheReplace(t *testing.T) { if err != tcpip.ErrWouldBlock { t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - clock.advance(config.DelayFirstProbeTime + typicalLatency) + clock.Advance(config.DelayFirstProbeTime + typicalLatency) select { case <-doneCh: default: @@ -1552,7 +1553,7 @@ func TestNeighborCacheReplace(t *testing.T) { // Verify the entry's new link address { e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) - clock.advance(typicalLatency) + clock.Advance(typicalLatency) if err != nil { t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) } @@ -1572,7 +1573,7 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { config := DefaultNUDConfigurations() nudDisp := testNUDDispatcher{} - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(&nudDisp, config, clock) store := newTestEntryStore() @@ -1595,7 +1596,7 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } - clock.advance(typicalLatency) + clock.Advance(typicalLatency) got, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) if err != nil { t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) @@ -1618,7 +1619,7 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes) - clock.advance(waitFor) + clock.Advance(waitFor) if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress { t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress) } @@ -1636,7 +1637,7 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { config := DefaultNUDConfigurations() config.RetransmitTimer = time.Millisecond // small enough to cause timeout - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(nil, config, clock) store := newTestEntryStore() linkRes := &testNeighborResolver{ @@ -1654,7 +1655,7 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) } waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) - clock.advance(waitFor) + clock.Advance(waitFor) if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress { t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress) } @@ -1664,7 +1665,7 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { // resolved immediately and don't send resolution requests. func TestNeighborCacheStaticResolution(t *testing.T) { config := DefaultNUDConfigurations() - clock := newFakeClock() + clock := faketime.NewManualClock() neigh := newTestNeighborCache(nil, config, clock) store := newTestEntryStore() linkRes := &testNeighborResolver{ diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index 0068cacb8..213646160 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -73,8 +73,7 @@ const ( type neighborEntry struct { neighborEntryEntry - nic *NIC - protocol tcpip.NetworkProtocolNumber + nic *NIC // linkRes provides the functionality to send reachability probes, used in // Neighbor Unreachability Detection. diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index b769fb2fa..e530ec7ea 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -27,6 +27,7 @@ import ( "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/faketime" ) const ( @@ -221,8 +222,8 @@ func (r *entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumbe return entryTestNetNumber } -func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *entryTestLinkResolver, *fakeClock) { - clock := newFakeClock() +func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *entryTestLinkResolver, *faketime.ManualClock) { + clock := faketime.NewManualClock() disp := testNUDDispatcher{} nic := NIC{ id: entryTestNICID, @@ -267,7 +268,7 @@ func TestEntryInitiallyUnknown(t *testing.T) { } e.mu.Unlock() - clock.advance(c.RetransmitTimer) + clock.Advance(c.RetransmitTimer) // No probes should have been sent. linkRes.mu.Lock() @@ -300,7 +301,7 @@ func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) { } e.mu.Unlock() - clock.advance(time.Hour) + clock.Advance(time.Hour) // No probes should have been sent. linkRes.mu.Lock() @@ -410,7 +411,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { updatedAt := e.neigh.UpdatedAt e.mu.Unlock() - clock.advance(c.RetransmitTimer) + clock.Advance(c.RetransmitTimer) // UpdatedAt should remain the same during address resolution. wantProbes := []entryTestProbeInfo{ @@ -439,7 +440,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { } e.mu.Unlock() - clock.advance(c.RetransmitTimer) + clock.Advance(c.RetransmitTimer) // UpdatedAt should change after failing address resolution. Timing out after // sending the last probe transitions the entry to Failed. @@ -459,7 +460,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { } } - clock.advance(c.RetransmitTimer) + clock.Advance(c.RetransmitTimer) wantEvents := []testEntryEventInfo{ { @@ -748,7 +749,7 @@ func TestEntryIncompleteToFailed(t *testing.T) { e.mu.Unlock() waitFor := c.RetransmitTimer * time.Duration(c.MaxMulticastProbes) - clock.advance(waitFor) + clock.Advance(waitFor) wantProbes := []entryTestProbeInfo{ // The Incomplete-to-Incomplete state transition is tested here by @@ -983,7 +984,7 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } - clock.advance(c.BaseReachableTime) + clock.Advance(c.BaseReachableTime) wantEvents := []testEntryEventInfo{ { @@ -1612,7 +1613,7 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } - clock.advance(c.BaseReachableTime) + clock.Advance(c.BaseReachableTime) wantEvents := []testEntryEventInfo{ { @@ -1706,7 +1707,7 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } - clock.advance(c.BaseReachableTime) + clock.Advance(c.BaseReachableTime) wantEvents := []testEntryEventInfo{ { @@ -1989,7 +1990,7 @@ func TestEntryDelayToProbe(t *testing.T) { } e.mu.Unlock() - clock.advance(c.DelayFirstProbeTime) + clock.Advance(c.DelayFirstProbeTime) wantProbes := []entryTestProbeInfo{ // The first probe is caused by the Unknown-to-Incomplete transition. @@ -2069,7 +2070,7 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { e.handlePacketQueuedLocked() e.mu.Unlock() - clock.advance(c.DelayFirstProbeTime) + clock.Advance(c.DelayFirstProbeTime) wantProbes := []entryTestProbeInfo{ // The first probe is caused by the Unknown-to-Incomplete transition. @@ -2166,7 +2167,7 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { e.handlePacketQueuedLocked() e.mu.Unlock() - clock.advance(c.DelayFirstProbeTime) + clock.Advance(c.DelayFirstProbeTime) wantProbes := []entryTestProbeInfo{ // The first probe is caused by the Unknown-to-Incomplete transition. @@ -2267,7 +2268,7 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { e.handlePacketQueuedLocked() e.mu.Unlock() - clock.advance(c.DelayFirstProbeTime) + clock.Advance(c.DelayFirstProbeTime) wantProbes := []entryTestProbeInfo{ // The first probe is caused by the Unknown-to-Incomplete transition. @@ -2364,7 +2365,7 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) { e.handlePacketQueuedLocked() e.mu.Unlock() - clock.advance(c.DelayFirstProbeTime) + clock.Advance(c.DelayFirstProbeTime) wantProbes := []entryTestProbeInfo{ // Probe caused by the Delay-to-Probe transition @@ -2398,7 +2399,7 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) { } e.mu.Unlock() - clock.advance(c.BaseReachableTime) + clock.Advance(c.BaseReachableTime) wantEvents := []testEntryEventInfo{ { @@ -2463,7 +2464,7 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { e.handlePacketQueuedLocked() e.mu.Unlock() - clock.advance(c.DelayFirstProbeTime) + clock.Advance(c.DelayFirstProbeTime) wantProbes := []entryTestProbeInfo{ // The first probe is caused by the Unknown-to-Incomplete transition. @@ -2503,7 +2504,7 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { } e.mu.Unlock() - clock.advance(c.BaseReachableTime) + clock.Advance(c.BaseReachableTime) wantEvents := []testEntryEventInfo{ { @@ -2575,7 +2576,7 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin e.handlePacketQueuedLocked() e.mu.Unlock() - clock.advance(c.DelayFirstProbeTime) + clock.Advance(c.DelayFirstProbeTime) wantProbes := []entryTestProbeInfo{ // The first probe is caused by the Unknown-to-Incomplete transition. @@ -2612,7 +2613,7 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin } e.mu.Unlock() - clock.advance(c.BaseReachableTime) + clock.Advance(c.BaseReachableTime) wantEvents := []testEntryEventInfo{ { @@ -2682,7 +2683,7 @@ func TestEntryProbeToFailed(t *testing.T) { e.mu.Unlock() waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) - clock.advance(waitFor) + clock.Advance(waitFor) wantProbes := []entryTestProbeInfo{ // The first probe is caused by the Unknown-to-Incomplete transition. @@ -2787,7 +2788,7 @@ func TestEntryFailedGetsDeleted(t *testing.T) { e.mu.Unlock() waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + c.UnreachableTime - clock.advance(waitFor) + clock.Advance(waitFor) wantProbes := []entryTestProbeInfo{ // The first probe is caused by the Unknown-to-Incomplete transition. diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 204bfc433..06d70dd1c 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -337,7 +337,7 @@ func (n *NIC) enable() *tcpip.Error { // does. That is, routers do not learn from RAs (e.g. on-link prefixes // and default routers). Therefore, soliciting RAs from other routers on // a link is unnecessary for routers. - if !n.stack.forwarding { + if !n.stack.Forwarding(header.IPv6ProtocolNumber) { n.mu.ndp.startSolicitingRouters() } @@ -1242,9 +1242,9 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp local = n.linkEP.LinkAddress() } - // Are any packet sockets listening for this network protocol? + // Are any packet type sockets listening for this network protocol? packetEPs := n.mu.packetEPs[protocol] - // Add any other packet sockets that maybe listening for all protocols. + // Add any other packet type sockets that may be listening for all protocols. packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...) n.mu.RUnlock() for _, ep := range packetEPs { @@ -1265,6 +1265,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp return } if hasTransportHdr { + pkt.TransportProtocolNumber = transProtoNum // Parse the transport header if present. if state, ok := n.stack.transportProtocols[transProtoNum]; ok { state.proto.Parse(pkt) @@ -1303,7 +1304,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp // packet and forward it to the NIC. // // TODO: Should we be forwarding the packet even if promiscuous? - if n.stack.Forwarding() { + if n.stack.Forwarding(protocol) { r, err := n.stack.FindRoute(0, "", dst, protocol, false /* multicastLoop */) if err != nil { n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() @@ -1330,6 +1331,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp // n doesn't have a destination endpoint. // Send the packet out of n. // TODO(b/128629022): move this logic to route.WritePacket. + // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease the TTL field for ipv4/ipv6. if ch, err := r.Resolve(nil); err != nil { if err == tcpip.ErrWouldBlock { n.stack.forwarder.enqueue(ch, n, &r, protocol, pkt) @@ -1452,10 +1454,28 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN } } - // We could not find an appropriate destination for this packet, so - // deliver it to the global handler. - if !transProto.HandleUnknownDestinationPacket(r, id, pkt) { + // We could not find an appropriate destination for this packet so + // give the protocol specific error handler a chance to handle it. + // If it doesn't handle it then we should do so. + switch transProto.HandleUnknownDestinationPacket(r, id, pkt) { + case UnknownDestinationPacketMalformed: n.stack.stats.MalformedRcvdPackets.Increment() + case UnknownDestinationPacketUnhandled: + // As per RFC: 1122 Section 3.2.2.1 A host SHOULD generate Destination + // Unreachable messages with code: + // 3 (Port Unreachable), when the designated transport protocol + // (e.g., UDP) is unable to demultiplex the datagram but has no + // protocol mechanism to inform the sender. + np, ok := n.stack.networkProtocols[r.NetProto] + if !ok { + // For this to happen stack.makeRoute() must have been called with the + // incorrect protocol number. Since we have successfully completed + // network layer processing this should be impossible. + panic(fmt.Sprintf("expected stack to have a NetworkProtocol for proto = %d", r.NetProto)) + } + + _ = np.ReturnError(r, &tcpip.ICMPReasonPortUnreachable{}, pkt) + case UnknownDestinationPacketHandled: } } diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index dd6474297..ef6e63b3e 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -221,6 +221,11 @@ func (*testIPv6Protocol) Parse(*PacketBuffer) (tcpip.TransportProtocolNumber, bo return 0, false, false } +// ReturnError implements NetworkProtocol.ReturnError. +func (*testIPv6Protocol) ReturnError(*Route, tcpip.ICMPReason, *PacketBuffer) *tcpip.Error { + return nil +} + var _ LinkAddressResolver = (*testIPv6Protocol)(nil) // LinkAddressProtocol implements LinkAddressResolver. diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 1932aaeb7..a7d9d59fa 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -80,11 +80,17 @@ type PacketBuffer struct { // data are held in the same underlying buffer storage. header buffer.Prependable - // NetworkProtocolNumber is only valid when NetworkHeader is set. + // NetworkProtocolNumber is only valid when NetworkHeader().View().IsEmpty() + // returns false. // TODO(gvisor.dev/issue/3574): Remove the separately passed protocol // numbers in registration APIs that take a PacketBuffer. NetworkProtocolNumber tcpip.NetworkProtocolNumber + // TransportProtocol is only valid if it is non zero. + // TODO(gvisor.dev/issue/3810): This and the network protocol number should + // be moved into the headerinfo. This should resolve the validity issue. + TransportProtocolNumber tcpip.TransportProtocolNumber + // Hash is the transport layer hash of this packet. A value of zero // indicates no valid hash has been set. Hash uint32 @@ -234,16 +240,17 @@ func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consum // underlying packet payload. func (pk *PacketBuffer) Clone() *PacketBuffer { newPk := &PacketBuffer{ - PacketBufferEntry: pk.PacketBufferEntry, - Data: pk.Data.Clone(nil), - headers: pk.headers, - header: pk.header, - Hash: pk.Hash, - Owner: pk.Owner, - EgressRoute: pk.EgressRoute, - GSOOptions: pk.GSOOptions, - NetworkProtocolNumber: pk.NetworkProtocolNumber, - NatDone: pk.NatDone, + PacketBufferEntry: pk.PacketBufferEntry, + Data: pk.Data.Clone(nil), + headers: pk.headers, + header: pk.header, + Hash: pk.Hash, + Owner: pk.Owner, + EgressRoute: pk.EgressRoute, + GSOOptions: pk.GSOOptions, + NetworkProtocolNumber: pk.NetworkProtocolNumber, + NatDone: pk.NatDone, + TransportProtocolNumber: pk.TransportProtocolNumber, } return newPk } diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 4fa86a3ac..77640cd8a 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -125,6 +125,26 @@ type PacketEndpoint interface { HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *PacketBuffer) } +// UnknownDestinationPacketDisposition enumerates the possible return vaues from +// HandleUnknownDestinationPacket(). +type UnknownDestinationPacketDisposition int + +const ( + // UnknownDestinationPacketMalformed denotes that the packet was malformed + // and no further processing should be attempted other than updating + // statistics. + UnknownDestinationPacketMalformed UnknownDestinationPacketDisposition = iota + + // UnknownDestinationPacketUnhandled tells the caller that the packet was + // well formed but that the issue was not handled and the stack should take + // the default action. + UnknownDestinationPacketUnhandled + + // UnknownDestinationPacketHandled tells the caller that it should do + // no further processing. + UnknownDestinationPacketHandled +) + // TransportProtocol is the interface that needs to be implemented by transport // protocols (e.g., tcp, udp) that want to be part of the networking stack. type TransportProtocol interface { @@ -147,14 +167,12 @@ type TransportProtocol interface { ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) // HandleUnknownDestinationPacket handles packets targeted at this - // protocol but that don't match any existing endpoint. For example, - // it is targeted at a port that have no listeners. + // protocol that don't match any existing endpoint. For example, + // it is targeted at a port that has no listeners. // - // The return value indicates whether the packet was well-formed (for - // stats purposes only). - // - // HandleUnknownDestinationPacket takes ownership of pkt. - HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool + // HandleUnknownDestinationPacket takes ownership of pkt if it handles + // the issue. + HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) UnknownDestinationPacketDisposition // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the @@ -324,6 +342,19 @@ type NetworkProtocol interface { // does not encapsulate anything). // - Whether pkt.Data was large enough to parse and set pkt.NetworkHeader. Parse(pkt *PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) + + // ReturnError attempts to send a suitable error message to the sender + // of a received packet. + // - pkt holds the problematic packet. + // - reason indicates what the reason for wanting a message is. + // - route is the routing information for the received packet + // ReturnError returns an error if the send failed and nil on success. + // Note that deciding to deliberately send no message is a success. + // + // TODO(gvisor.dev/issues/3871): This method should be removed or simplified + // after all (or all but one) of the ICMP error dispatch occurs through the + // protocol specific modules. May become SendPortNotFound(r, pkt). + ReturnError(r *Route, reason tcpip.ICMPReason, pkt *PacketBuffer) *tcpip.Error } // NetworkDispatcher contains the methods used by the network stack to deliver diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 6a683545d..e7b7e95d4 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -144,10 +144,7 @@ type TCPReceiverState struct { // PendingBufUsed is the number of bytes pending in the receive // queue. - PendingBufUsed seqnum.Size - - // PendingBufSize is the size of the socket receive buffer. - PendingBufSize seqnum.Size + PendingBufUsed int } // TCPSenderState holds a copy of the internal state of the sender for @@ -405,6 +402,13 @@ type Stack struct { networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver + // forwarding contains the whether packet forwarding is enabled or not for + // different network protocols. + forwarding struct { + sync.RWMutex + protocols map[tcpip.NetworkProtocolNumber]bool + } + // rawFactory creates raw endpoints. If nil, raw endpoints are // disabled. It is set during Stack creation and is immutable. rawFactory RawFactory @@ -415,9 +419,8 @@ type Stack struct { linkAddrCache *linkAddrCache - mu sync.RWMutex - nics map[tcpip.NICID]*NIC - forwarding bool + mu sync.RWMutex + nics map[tcpip.NICID]*NIC // cleanupEndpointsMu protects cleanupEndpoints. cleanupEndpointsMu sync.Mutex @@ -749,6 +752,7 @@ func New(opts Options) *Stack { Max: DefaultMaxBufferSize, }, } + s.forwarding.protocols = make(map[tcpip.NetworkProtocolNumber]bool) // Add specified network protocols. for _, netProto := range opts.NetworkProtocols { @@ -866,46 +870,42 @@ func (s *Stack) Stats() tcpip.Stats { return s.stats } -// SetForwarding enables or disables the packet forwarding between NICs. -// -// When forwarding becomes enabled, any host-only state on all NICs will be -// cleaned up and if IPv6 is enabled, NDP Router Solicitations will be started. -// When forwarding becomes disabled and if IPv6 is enabled, NDP Router -// Solicitations will be stopped. -func (s *Stack) SetForwarding(enable bool) { - // TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward. - s.mu.Lock() - defer s.mu.Unlock() +// SetForwarding enables or disables packet forwarding between NICs. +func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) { + s.forwarding.Lock() + defer s.forwarding.Unlock() - // If forwarding status didn't change, do nothing further. - if s.forwarding == enable { + // If this stack does not support the protocol, do nothing. + if _, ok := s.networkProtocols[protocol]; !ok { return } - s.forwarding = enable - - // If this stack does not support IPv6, do nothing further. - if _, ok := s.networkProtocols[header.IPv6ProtocolNumber]; !ok { + // If the forwarding value for this protocol hasn't changed then do + // nothing. + if forwarding := s.forwarding.protocols[protocol]; forwarding == enable { return } - if enable { - for _, nic := range s.nics { - nic.becomeIPv6Router() - } - } else { - for _, nic := range s.nics { - nic.becomeIPv6Host() + s.forwarding.protocols[protocol] = enable + + if protocol == header.IPv6ProtocolNumber { + if enable { + for _, nic := range s.nics { + nic.becomeIPv6Router() + } + } else { + for _, nic := range s.nics { + nic.becomeIPv6Host() + } } } } -// Forwarding returns if the packet forwarding between NICs is enabled. -func (s *Stack) Forwarding() bool { - // TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward. - s.mu.RLock() - defer s.mu.RUnlock() - return s.forwarding +// Forwarding returns if packet forwarding between NICs is enabled. +func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { + s.forwarding.RLock() + defer s.forwarding.RUnlock() + return s.forwarding.protocols[protocol] } // SetRouteTable assigns the route table to be used by this stack. It diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 60b54c244..9ef6787c6 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -216,13 +216,18 @@ func (f *fakeNetworkProtocol) Option(option tcpip.GettableNetworkProtocolOption) } } -// Close implements TransportProtocol.Close. +// ReturnError implements NetworkProtocol.ReturnError +func (*fakeNetworkProtocol) ReturnError(*stack.Route, tcpip.ICMPReason, *stack.PacketBuffer) *tcpip.Error { + return nil +} + +// Close implements NetworkProtocol.Close. func (*fakeNetworkProtocol) Close() {} -// Wait implements TransportProtocol.Wait. +// Wait implements NetworkProtocol.Wait. func (*fakeNetworkProtocol) Wait() {} -// Parse implements TransportProtocol.Parse. +// Parse implements NetworkProtocol.Parse. func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) { hdr, ok := pkt.NetworkHeader().Consume(fakeNetHeaderLen) if !ok { @@ -2091,7 +2096,7 @@ func TestNICForwarding(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, }) - s.SetForwarding(true) + s.SetForwarding(fakeNetNumber, true) ep1 := channel.New(10, defaultMTU, "") if err := s.CreateNIC(nicID1, ep1); err != nil { diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index ef3457e32..cbb34d224 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -287,8 +287,8 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp return 0, 0, nil } -func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) bool { - return true +func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { + return stack.UnknownDestinationPacketHandled } func (f *fakeTransportProtocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error { @@ -549,7 +549,7 @@ func TestTransportForwarding(t *testing.T) { NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, }) - s.SetForwarding(true) + s.SetForwarding(fakeNetNumber, true) // TODO(b/123449044): Change this to a channel NIC. ep1 := loopback.New() |