summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r--pkg/tcpip/stack/BUILD3
-rw-r--r--pkg/tcpip/stack/fake_time_test.go209
-rw-r--r--pkg/tcpip/stack/forwarder_test.go6
-rw-r--r--pkg/tcpip/stack/ndp.go2
-rw-r--r--pkg/tcpip/stack/ndp_test.go12
-rw-r--r--pkg/tcpip/stack/neighbor_cache_test.go67
-rw-r--r--pkg/tcpip/stack/neighbor_entry.go3
-rw-r--r--pkg/tcpip/stack/neighbor_entry_test.go47
-rw-r--r--pkg/tcpip/stack/nic.go34
-rw-r--r--pkg/tcpip/stack/nic_test.go5
-rw-r--r--pkg/tcpip/stack/packet_buffer.go29
-rw-r--r--pkg/tcpip/stack/registration.go45
-rw-r--r--pkg/tcpip/stack/stack.go72
-rw-r--r--pkg/tcpip/stack/stack_test.go13
-rw-r--r--pkg/tcpip/stack/transport_test.go6
15 files changed, 208 insertions, 345 deletions
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 900938dd1..7f1d79115 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -138,7 +138,6 @@ go_test(
name = "stack_test",
size = "small",
srcs = [
- "fake_time_test.go",
"forwarder_test.go",
"linkaddrcache_test.go",
"neighbor_cache_test.go",
@@ -152,8 +151,8 @@ go_test(
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/faketime",
"//pkg/tcpip/header",
- "@com_github_dpjacques_clockwork//:go_default_library",
"@com_github_google_go_cmp//cmp:go_default_library",
"@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
],
diff --git a/pkg/tcpip/stack/fake_time_test.go b/pkg/tcpip/stack/fake_time_test.go
deleted file mode 100644
index 92c8cb534..000000000
--- a/pkg/tcpip/stack/fake_time_test.go
+++ /dev/null
@@ -1,209 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package stack
-
-import (
- "container/heap"
- "sync"
- "time"
-
- "github.com/dpjacques/clockwork"
- "gvisor.dev/gvisor/pkg/tcpip"
-)
-
-type fakeClock struct {
- clock clockwork.FakeClock
-
- // mu protects the fields below.
- mu sync.RWMutex
-
- // times is min-heap of times. A heap is used for quick retrieval of the next
- // upcoming time of scheduled work.
- times *timeHeap
-
- // waitGroups stores one WaitGroup for all work scheduled to execute at the
- // same time via AfterFunc. This allows parallel execution of all functions
- // passed to AfterFunc scheduled for the same time.
- waitGroups map[time.Time]*sync.WaitGroup
-}
-
-func newFakeClock() *fakeClock {
- return &fakeClock{
- clock: clockwork.NewFakeClock(),
- times: &timeHeap{},
- waitGroups: make(map[time.Time]*sync.WaitGroup),
- }
-}
-
-var _ tcpip.Clock = (*fakeClock)(nil)
-
-// NowNanoseconds implements tcpip.Clock.NowNanoseconds.
-func (fc *fakeClock) NowNanoseconds() int64 {
- return fc.clock.Now().UnixNano()
-}
-
-// NowMonotonic implements tcpip.Clock.NowMonotonic.
-func (fc *fakeClock) NowMonotonic() int64 {
- return fc.NowNanoseconds()
-}
-
-// AfterFunc implements tcpip.Clock.AfterFunc.
-func (fc *fakeClock) AfterFunc(d time.Duration, f func()) tcpip.Timer {
- until := fc.clock.Now().Add(d)
- wg := fc.addWait(until)
- return &fakeTimer{
- clock: fc,
- until: until,
- timer: fc.clock.AfterFunc(d, func() {
- defer wg.Done()
- f()
- }),
- }
-}
-
-// addWait adds an additional wait to the WaitGroup for parallel execution of
-// all work scheduled for t. Returns a reference to the WaitGroup modified.
-func (fc *fakeClock) addWait(t time.Time) *sync.WaitGroup {
- fc.mu.RLock()
- wg, ok := fc.waitGroups[t]
- fc.mu.RUnlock()
-
- if ok {
- wg.Add(1)
- return wg
- }
-
- fc.mu.Lock()
- heap.Push(fc.times, t)
- fc.mu.Unlock()
-
- wg = &sync.WaitGroup{}
- wg.Add(1)
-
- fc.mu.Lock()
- fc.waitGroups[t] = wg
- fc.mu.Unlock()
-
- return wg
-}
-
-// removeWait removes a wait from the WaitGroup for parallel execution of all
-// work scheduled for t.
-func (fc *fakeClock) removeWait(t time.Time) {
- fc.mu.RLock()
- defer fc.mu.RUnlock()
-
- wg := fc.waitGroups[t]
- wg.Done()
-}
-
-// advance executes all work that have been scheduled to execute within d from
-// the current fake time. Blocks until all work has completed execution.
-func (fc *fakeClock) advance(d time.Duration) {
- // Block until all the work is done
- until := fc.clock.Now().Add(d)
- for {
- fc.mu.Lock()
- if fc.times.Len() == 0 {
- fc.mu.Unlock()
- return
- }
-
- t := heap.Pop(fc.times).(time.Time)
- if t.After(until) {
- // No work to do
- heap.Push(fc.times, t)
- fc.mu.Unlock()
- return
- }
- fc.mu.Unlock()
-
- diff := t.Sub(fc.clock.Now())
- fc.clock.Advance(diff)
-
- fc.mu.RLock()
- wg := fc.waitGroups[t]
- fc.mu.RUnlock()
-
- wg.Wait()
-
- fc.mu.Lock()
- delete(fc.waitGroups, t)
- fc.mu.Unlock()
- }
-}
-
-type fakeTimer struct {
- clock *fakeClock
- timer clockwork.Timer
-
- mu sync.RWMutex
- until time.Time
-}
-
-var _ tcpip.Timer = (*fakeTimer)(nil)
-
-// Reset implements tcpip.Timer.Reset.
-func (ft *fakeTimer) Reset(d time.Duration) {
- if !ft.timer.Reset(d) {
- return
- }
-
- ft.mu.Lock()
- defer ft.mu.Unlock()
-
- ft.clock.removeWait(ft.until)
- ft.until = ft.clock.clock.Now().Add(d)
- ft.clock.addWait(ft.until)
-}
-
-// Stop implements tcpip.Timer.Stop.
-func (ft *fakeTimer) Stop() bool {
- if !ft.timer.Stop() {
- return false
- }
-
- ft.mu.RLock()
- defer ft.mu.RUnlock()
-
- ft.clock.removeWait(ft.until)
- return true
-}
-
-type timeHeap []time.Time
-
-var _ heap.Interface = (*timeHeap)(nil)
-
-func (h timeHeap) Len() int {
- return len(h)
-}
-
-func (h timeHeap) Less(i, j int) bool {
- return h[i].Before(h[j])
-}
-
-func (h timeHeap) Swap(i, j int) {
- h[i], h[j] = h[j], h[i]
-}
-
-func (h *timeHeap) Push(x interface{}) {
- *h = append(*h, x.(time.Time))
-}
-
-func (h *timeHeap) Pop() interface{} {
- last := (*h)[len(*h)-1]
- *h = (*h)[:len(*h)-1]
- return last
-}
diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go
index 54759091a..e30927821 100644
--- a/pkg/tcpip/stack/forwarder_test.go
+++ b/pkg/tcpip/stack/forwarder_test.go
@@ -145,6 +145,10 @@ func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocol
return tcpip.TransportProtocolNumber(netHeader[protocolNumberOffset]), true, true
}
+func (*fwdTestNetworkProtocol) ReturnError(*Route, tcpip.ICMPReason, *PacketBuffer) *tcpip.Error {
+ return nil
+}
+
func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, _ LinkAddressCache, _ NUDHandler, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) NetworkEndpoint {
return &fwdTestNetworkEndpoint{
nicID: nicID,
@@ -316,7 +320,7 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborC
}
// Enable forwarding.
- s.SetForwarding(true)
+ s.SetForwarding(proto.Number(), true)
// NIC 1 has the link address "a", and added the network address 1.
ep1 = &fwdTestLinkEndpoint{
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go
index b0873d1af..97ca00d16 100644
--- a/pkg/tcpip/stack/ndp.go
+++ b/pkg/tcpip/stack/ndp.go
@@ -817,7 +817,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
// per-interface basis; it is a stack-wide configuration, so we check
// stack's forwarding flag to determine if the NIC is a routing
// interface.
- if !ndp.configs.HandleRAs || ndp.nic.stack.forwarding {
+ if !ndp.configs.HandleRAs || ndp.nic.stack.Forwarding(header.IPv6ProtocolNumber) {
return
}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 67dc5364f..5e43a9b0b 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -1120,7 +1120,7 @@ func TestNoRouterDiscovery(t *testing.T) {
},
NDPDisp: &ndpDisp,
})
- s.SetForwarding(forwarding)
+ s.SetForwarding(ipv6.ProtocolNumber, forwarding)
if err := s.CreateNIC(1, e); err != nil {
t.Fatalf("CreateNIC(1) = %s", err)
@@ -1365,7 +1365,7 @@ func TestNoPrefixDiscovery(t *testing.T) {
},
NDPDisp: &ndpDisp,
})
- s.SetForwarding(forwarding)
+ s.SetForwarding(ipv6.ProtocolNumber, forwarding)
if err := s.CreateNIC(1, e); err != nil {
t.Fatalf("CreateNIC(1) = %s", err)
@@ -1723,7 +1723,7 @@ func TestNoAutoGenAddr(t *testing.T) {
},
NDPDisp: &ndpDisp,
})
- s.SetForwarding(forwarding)
+ s.SetForwarding(ipv6.ProtocolNumber, forwarding)
if err := s.CreateNIC(1, e); err != nil {
t.Fatalf("CreateNIC(1) = %s", err)
@@ -4640,7 +4640,7 @@ func TestCleanupNDPState(t *testing.T) {
name: "Enable forwarding",
cleanupFn: func(t *testing.T, s *stack.Stack) {
t.Helper()
- s.SetForwarding(true)
+ s.SetForwarding(ipv6.ProtocolNumber, true)
},
keepAutoGenLinkLocal: true,
maxAutoGenAddrEvents: 4,
@@ -5286,11 +5286,11 @@ func TestStopStartSolicitingRouters(t *testing.T) {
name: "Enable and disable forwarding",
startFn: func(t *testing.T, s *stack.Stack) {
t.Helper()
- s.SetForwarding(false)
+ s.SetForwarding(ipv6.ProtocolNumber, false)
},
stopFn: func(t *testing.T, s *stack.Stack, _ bool) {
t.Helper()
- s.SetForwarding(true)
+ s.SetForwarding(ipv6.ProtocolNumber, true)
},
},
diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go
index b4fa69e3e..a0b7da5cd 100644
--- a/pkg/tcpip/stack/neighbor_cache_test.go
+++ b/pkg/tcpip/stack/neighbor_cache_test.go
@@ -30,6 +30,7 @@ import (
"github.com/google/go-cmp/cmp/cmpopts"
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
)
const (
@@ -239,7 +240,7 @@ type entryEvent struct {
func TestNeighborCacheGetConfig(t *testing.T) {
nudDisp := testNUDDispatcher{}
c := DefaultNUDConfigurations()
- clock := newFakeClock()
+ clock := faketime.NewManualClock()
neigh := newTestNeighborCache(&nudDisp, c, clock)
if got, want := neigh.config(), c; got != want {
@@ -257,7 +258,7 @@ func TestNeighborCacheGetConfig(t *testing.T) {
func TestNeighborCacheSetConfig(t *testing.T) {
nudDisp := testNUDDispatcher{}
c := DefaultNUDConfigurations()
- clock := newFakeClock()
+ clock := faketime.NewManualClock()
neigh := newTestNeighborCache(&nudDisp, c, clock)
c.MinRandomFactor = 1
@@ -279,7 +280,7 @@ func TestNeighborCacheSetConfig(t *testing.T) {
func TestNeighborCacheEntry(t *testing.T) {
c := DefaultNUDConfigurations()
nudDisp := testNUDDispatcher{}
- clock := newFakeClock()
+ clock := faketime.NewManualClock()
neigh := newTestNeighborCache(&nudDisp, c, clock)
store := newTestEntryStore()
linkRes := &testNeighborResolver{
@@ -298,7 +299,7 @@ func TestNeighborCacheEntry(t *testing.T) {
t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
}
- clock.advance(typicalLatency)
+ clock.Advance(typicalLatency)
wantEvents := []testEntryEventInfo{
{
@@ -339,7 +340,7 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
config := DefaultNUDConfigurations()
nudDisp := testNUDDispatcher{}
- clock := newFakeClock()
+ clock := faketime.NewManualClock()
neigh := newTestNeighborCache(&nudDisp, config, clock)
store := newTestEntryStore()
linkRes := &testNeighborResolver{
@@ -358,7 +359,7 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
}
- clock.advance(typicalLatency)
+ clock.Advance(typicalLatency)
wantEvents := []testEntryEventInfo{
{
@@ -409,7 +410,7 @@ func TestNeighborCacheRemoveEntry(t *testing.T) {
}
type testContext struct {
- clock *fakeClock
+ clock *faketime.ManualClock
neigh *neighborCache
store *testEntryStore
linkRes *testNeighborResolver
@@ -418,7 +419,7 @@ type testContext struct {
func newTestContext(c NUDConfigurations) testContext {
nudDisp := &testNUDDispatcher{}
- clock := newFakeClock()
+ clock := faketime.NewManualClock()
neigh := newTestNeighborCache(nudDisp, c, clock)
store := newTestEntryStore()
linkRes := &testNeighborResolver{
@@ -454,7 +455,7 @@ func (c *testContext) overflowCache(opts overflowOptions) error {
if _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil); err != tcpip.ErrWouldBlock {
return fmt.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
}
- c.clock.advance(c.neigh.config().RetransmitTimer)
+ c.clock.Advance(c.neigh.config().RetransmitTimer)
var wantEvents []testEntryEventInfo
@@ -567,7 +568,7 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) {
if err != tcpip.ErrWouldBlock {
t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
}
- c.clock.advance(c.neigh.config().RetransmitTimer)
+ c.clock.Advance(c.neigh.config().RetransmitTimer)
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -803,7 +804,7 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) {
if err != tcpip.ErrWouldBlock {
t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
}
- c.clock.advance(typicalLatency)
+ c.clock.Advance(typicalLatency)
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -876,7 +877,7 @@ func TestNeighborCacheNotifiesWaker(t *testing.T) {
config := DefaultNUDConfigurations()
nudDisp := testNUDDispatcher{}
- clock := newFakeClock()
+ clock := faketime.NewManualClock()
neigh := newTestNeighborCache(&nudDisp, config, clock)
store := newTestEntryStore()
linkRes := &testNeighborResolver{
@@ -902,7 +903,7 @@ func TestNeighborCacheNotifiesWaker(t *testing.T) {
if doneCh == nil {
t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr)
}
- clock.advance(typicalLatency)
+ clock.Advance(typicalLatency)
select {
case <-doneCh:
@@ -944,7 +945,7 @@ func TestNeighborCacheRemoveWaker(t *testing.T) {
config := DefaultNUDConfigurations()
nudDisp := testNUDDispatcher{}
- clock := newFakeClock()
+ clock := faketime.NewManualClock()
neigh := newTestNeighborCache(&nudDisp, config, clock)
store := newTestEntryStore()
linkRes := &testNeighborResolver{
@@ -974,7 +975,7 @@ func TestNeighborCacheRemoveWaker(t *testing.T) {
// Remove the waker before the neighbor cache has the opportunity to send a
// notification.
neigh.removeWaker(entry.Addr, &w)
- clock.advance(typicalLatency)
+ clock.Advance(typicalLatency)
select {
case <-doneCh:
@@ -1073,7 +1074,7 @@ func TestNeighborCacheClear(t *testing.T) {
config := DefaultNUDConfigurations()
nudDisp := testNUDDispatcher{}
- clock := newFakeClock()
+ clock := faketime.NewManualClock()
neigh := newTestNeighborCache(&nudDisp, config, clock)
store := newTestEntryStore()
linkRes := &testNeighborResolver{
@@ -1092,7 +1093,7 @@ func TestNeighborCacheClear(t *testing.T) {
if err != tcpip.ErrWouldBlock {
t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
}
- clock.advance(typicalLatency)
+ clock.Advance(typicalLatency)
wantEvents := []testEntryEventInfo{
{
@@ -1188,7 +1189,7 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) {
if err != tcpip.ErrWouldBlock {
t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
}
- c.clock.advance(typicalLatency)
+ c.clock.Advance(typicalLatency)
wantEvents := []testEntryEventInfo{
{
EventType: entryTestAdded,
@@ -1249,7 +1250,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
config.MaxRandomFactor = 1
nudDisp := testNUDDispatcher{}
- clock := newFakeClock()
+ clock := faketime.NewManualClock()
neigh := newTestNeighborCache(&nudDisp, config, clock)
store := newTestEntryStore()
linkRes := &testNeighborResolver{
@@ -1277,7 +1278,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
if err != tcpip.ErrWouldBlock {
t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
}
- clock.advance(typicalLatency)
+ clock.Advance(typicalLatency)
select {
case <-doneCh:
default:
@@ -1325,7 +1326,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
if err != tcpip.ErrWouldBlock {
t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
}
- clock.advance(typicalLatency)
+ clock.Advance(typicalLatency)
select {
case <-doneCh:
default:
@@ -1412,7 +1413,7 @@ func TestNeighborCacheConcurrent(t *testing.T) {
config := DefaultNUDConfigurations()
nudDisp := testNUDDispatcher{}
- clock := newFakeClock()
+ clock := faketime.NewManualClock()
neigh := newTestNeighborCache(&nudDisp, config, clock)
store := newTestEntryStore()
linkRes := &testNeighborResolver{
@@ -1440,7 +1441,7 @@ func TestNeighborCacheConcurrent(t *testing.T) {
wg.Wait()
// Process all the requests for a single entry concurrently
- clock.advance(typicalLatency)
+ clock.Advance(typicalLatency)
}
// All goroutines add in the same order and add more values than can fit in
@@ -1472,7 +1473,7 @@ func TestNeighborCacheReplace(t *testing.T) {
config := DefaultNUDConfigurations()
nudDisp := testNUDDispatcher{}
- clock := newFakeClock()
+ clock := faketime.NewManualClock()
neigh := newTestNeighborCache(&nudDisp, config, clock)
store := newTestEntryStore()
linkRes := &testNeighborResolver{
@@ -1491,7 +1492,7 @@ func TestNeighborCacheReplace(t *testing.T) {
if err != tcpip.ErrWouldBlock {
t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
}
- clock.advance(typicalLatency)
+ clock.Advance(typicalLatency)
select {
case <-doneCh:
default:
@@ -1541,7 +1542,7 @@ func TestNeighborCacheReplace(t *testing.T) {
if err != tcpip.ErrWouldBlock {
t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
}
- clock.advance(config.DelayFirstProbeTime + typicalLatency)
+ clock.Advance(config.DelayFirstProbeTime + typicalLatency)
select {
case <-doneCh:
default:
@@ -1552,7 +1553,7 @@ func TestNeighborCacheReplace(t *testing.T) {
// Verify the entry's new link address
{
e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
- clock.advance(typicalLatency)
+ clock.Advance(typicalLatency)
if err != nil {
t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
}
@@ -1572,7 +1573,7 @@ func TestNeighborCacheResolutionFailed(t *testing.T) {
config := DefaultNUDConfigurations()
nudDisp := testNUDDispatcher{}
- clock := newFakeClock()
+ clock := faketime.NewManualClock()
neigh := newTestNeighborCache(&nudDisp, config, clock)
store := newTestEntryStore()
@@ -1595,7 +1596,7 @@ func TestNeighborCacheResolutionFailed(t *testing.T) {
if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock {
t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
}
- clock.advance(typicalLatency)
+ clock.Advance(typicalLatency)
got, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
if err != nil {
t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
@@ -1618,7 +1619,7 @@ func TestNeighborCacheResolutionFailed(t *testing.T) {
t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
}
waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes)
- clock.advance(waitFor)
+ clock.Advance(waitFor)
if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress {
t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress)
}
@@ -1636,7 +1637,7 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) {
config := DefaultNUDConfigurations()
config.RetransmitTimer = time.Millisecond // small enough to cause timeout
- clock := newFakeClock()
+ clock := faketime.NewManualClock()
neigh := newTestNeighborCache(nil, config, clock)
store := newTestEntryStore()
linkRes := &testNeighborResolver{
@@ -1654,7 +1655,7 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) {
t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
}
waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes)
- clock.advance(waitFor)
+ clock.Advance(waitFor)
if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress {
t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress)
}
@@ -1664,7 +1665,7 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) {
// resolved immediately and don't send resolution requests.
func TestNeighborCacheStaticResolution(t *testing.T) {
config := DefaultNUDConfigurations()
- clock := newFakeClock()
+ clock := faketime.NewManualClock()
neigh := newTestNeighborCache(nil, config, clock)
store := newTestEntryStore()
linkRes := &testNeighborResolver{
diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go
index 0068cacb8..213646160 100644
--- a/pkg/tcpip/stack/neighbor_entry.go
+++ b/pkg/tcpip/stack/neighbor_entry.go
@@ -73,8 +73,7 @@ const (
type neighborEntry struct {
neighborEntryEntry
- nic *NIC
- protocol tcpip.NetworkProtocolNumber
+ nic *NIC
// linkRes provides the functionality to send reachability probes, used in
// Neighbor Unreachability Detection.
diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go
index b769fb2fa..e530ec7ea 100644
--- a/pkg/tcpip/stack/neighbor_entry_test.go
+++ b/pkg/tcpip/stack/neighbor_entry_test.go
@@ -27,6 +27,7 @@ import (
"github.com/google/go-cmp/cmp/cmpopts"
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/faketime"
)
const (
@@ -221,8 +222,8 @@ func (r *entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumbe
return entryTestNetNumber
}
-func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *entryTestLinkResolver, *fakeClock) {
- clock := newFakeClock()
+func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *entryTestLinkResolver, *faketime.ManualClock) {
+ clock := faketime.NewManualClock()
disp := testNUDDispatcher{}
nic := NIC{
id: entryTestNICID,
@@ -267,7 +268,7 @@ func TestEntryInitiallyUnknown(t *testing.T) {
}
e.mu.Unlock()
- clock.advance(c.RetransmitTimer)
+ clock.Advance(c.RetransmitTimer)
// No probes should have been sent.
linkRes.mu.Lock()
@@ -300,7 +301,7 @@ func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) {
}
e.mu.Unlock()
- clock.advance(time.Hour)
+ clock.Advance(time.Hour)
// No probes should have been sent.
linkRes.mu.Lock()
@@ -410,7 +411,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
updatedAt := e.neigh.UpdatedAt
e.mu.Unlock()
- clock.advance(c.RetransmitTimer)
+ clock.Advance(c.RetransmitTimer)
// UpdatedAt should remain the same during address resolution.
wantProbes := []entryTestProbeInfo{
@@ -439,7 +440,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
}
e.mu.Unlock()
- clock.advance(c.RetransmitTimer)
+ clock.Advance(c.RetransmitTimer)
// UpdatedAt should change after failing address resolution. Timing out after
// sending the last probe transitions the entry to Failed.
@@ -459,7 +460,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
}
}
- clock.advance(c.RetransmitTimer)
+ clock.Advance(c.RetransmitTimer)
wantEvents := []testEntryEventInfo{
{
@@ -748,7 +749,7 @@ func TestEntryIncompleteToFailed(t *testing.T) {
e.mu.Unlock()
waitFor := c.RetransmitTimer * time.Duration(c.MaxMulticastProbes)
- clock.advance(waitFor)
+ clock.Advance(waitFor)
wantProbes := []entryTestProbeInfo{
// The Incomplete-to-Incomplete state transition is tested here by
@@ -983,7 +984,7 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
- clock.advance(c.BaseReachableTime)
+ clock.Advance(c.BaseReachableTime)
wantEvents := []testEntryEventInfo{
{
@@ -1612,7 +1613,7 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
- clock.advance(c.BaseReachableTime)
+ clock.Advance(c.BaseReachableTime)
wantEvents := []testEntryEventInfo{
{
@@ -1706,7 +1707,7 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
}
- clock.advance(c.BaseReachableTime)
+ clock.Advance(c.BaseReachableTime)
wantEvents := []testEntryEventInfo{
{
@@ -1989,7 +1990,7 @@ func TestEntryDelayToProbe(t *testing.T) {
}
e.mu.Unlock()
- clock.advance(c.DelayFirstProbeTime)
+ clock.Advance(c.DelayFirstProbeTime)
wantProbes := []entryTestProbeInfo{
// The first probe is caused by the Unknown-to-Incomplete transition.
@@ -2069,7 +2070,7 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) {
e.handlePacketQueuedLocked()
e.mu.Unlock()
- clock.advance(c.DelayFirstProbeTime)
+ clock.Advance(c.DelayFirstProbeTime)
wantProbes := []entryTestProbeInfo{
// The first probe is caused by the Unknown-to-Incomplete transition.
@@ -2166,7 +2167,7 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
e.handlePacketQueuedLocked()
e.mu.Unlock()
- clock.advance(c.DelayFirstProbeTime)
+ clock.Advance(c.DelayFirstProbeTime)
wantProbes := []entryTestProbeInfo{
// The first probe is caused by the Unknown-to-Incomplete transition.
@@ -2267,7 +2268,7 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) {
e.handlePacketQueuedLocked()
e.mu.Unlock()
- clock.advance(c.DelayFirstProbeTime)
+ clock.Advance(c.DelayFirstProbeTime)
wantProbes := []entryTestProbeInfo{
// The first probe is caused by the Unknown-to-Incomplete transition.
@@ -2364,7 +2365,7 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) {
e.handlePacketQueuedLocked()
e.mu.Unlock()
- clock.advance(c.DelayFirstProbeTime)
+ clock.Advance(c.DelayFirstProbeTime)
wantProbes := []entryTestProbeInfo{
// Probe caused by the Delay-to-Probe transition
@@ -2398,7 +2399,7 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) {
}
e.mu.Unlock()
- clock.advance(c.BaseReachableTime)
+ clock.Advance(c.BaseReachableTime)
wantEvents := []testEntryEventInfo{
{
@@ -2463,7 +2464,7 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
e.handlePacketQueuedLocked()
e.mu.Unlock()
- clock.advance(c.DelayFirstProbeTime)
+ clock.Advance(c.DelayFirstProbeTime)
wantProbes := []entryTestProbeInfo{
// The first probe is caused by the Unknown-to-Incomplete transition.
@@ -2503,7 +2504,7 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
}
e.mu.Unlock()
- clock.advance(c.BaseReachableTime)
+ clock.Advance(c.BaseReachableTime)
wantEvents := []testEntryEventInfo{
{
@@ -2575,7 +2576,7 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin
e.handlePacketQueuedLocked()
e.mu.Unlock()
- clock.advance(c.DelayFirstProbeTime)
+ clock.Advance(c.DelayFirstProbeTime)
wantProbes := []entryTestProbeInfo{
// The first probe is caused by the Unknown-to-Incomplete transition.
@@ -2612,7 +2613,7 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin
}
e.mu.Unlock()
- clock.advance(c.BaseReachableTime)
+ clock.Advance(c.BaseReachableTime)
wantEvents := []testEntryEventInfo{
{
@@ -2682,7 +2683,7 @@ func TestEntryProbeToFailed(t *testing.T) {
e.mu.Unlock()
waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes)
- clock.advance(waitFor)
+ clock.Advance(waitFor)
wantProbes := []entryTestProbeInfo{
// The first probe is caused by the Unknown-to-Incomplete transition.
@@ -2787,7 +2788,7 @@ func TestEntryFailedGetsDeleted(t *testing.T) {
e.mu.Unlock()
waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + c.UnreachableTime
- clock.advance(waitFor)
+ clock.Advance(waitFor)
wantProbes := []entryTestProbeInfo{
// The first probe is caused by the Unknown-to-Incomplete transition.
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 204bfc433..06d70dd1c 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -337,7 +337,7 @@ func (n *NIC) enable() *tcpip.Error {
// does. That is, routers do not learn from RAs (e.g. on-link prefixes
// and default routers). Therefore, soliciting RAs from other routers on
// a link is unnecessary for routers.
- if !n.stack.forwarding {
+ if !n.stack.Forwarding(header.IPv6ProtocolNumber) {
n.mu.ndp.startSolicitingRouters()
}
@@ -1242,9 +1242,9 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
local = n.linkEP.LinkAddress()
}
- // Are any packet sockets listening for this network protocol?
+ // Are any packet type sockets listening for this network protocol?
packetEPs := n.mu.packetEPs[protocol]
- // Add any other packet sockets that maybe listening for all protocols.
+ // Add any other packet type sockets that may be listening for all protocols.
packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...)
n.mu.RUnlock()
for _, ep := range packetEPs {
@@ -1265,6 +1265,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
return
}
if hasTransportHdr {
+ pkt.TransportProtocolNumber = transProtoNum
// Parse the transport header if present.
if state, ok := n.stack.transportProtocols[transProtoNum]; ok {
state.proto.Parse(pkt)
@@ -1303,7 +1304,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
// packet and forward it to the NIC.
//
// TODO: Should we be forwarding the packet even if promiscuous?
- if n.stack.Forwarding() {
+ if n.stack.Forwarding(protocol) {
r, err := n.stack.FindRoute(0, "", dst, protocol, false /* multicastLoop */)
if err != nil {
n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment()
@@ -1330,6 +1331,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
// n doesn't have a destination endpoint.
// Send the packet out of n.
// TODO(b/128629022): move this logic to route.WritePacket.
+ // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease the TTL field for ipv4/ipv6.
if ch, err := r.Resolve(nil); err != nil {
if err == tcpip.ErrWouldBlock {
n.stack.forwarder.enqueue(ch, n, &r, protocol, pkt)
@@ -1452,10 +1454,28 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
}
}
- // We could not find an appropriate destination for this packet, so
- // deliver it to the global handler.
- if !transProto.HandleUnknownDestinationPacket(r, id, pkt) {
+ // We could not find an appropriate destination for this packet so
+ // give the protocol specific error handler a chance to handle it.
+ // If it doesn't handle it then we should do so.
+ switch transProto.HandleUnknownDestinationPacket(r, id, pkt) {
+ case UnknownDestinationPacketMalformed:
n.stack.stats.MalformedRcvdPackets.Increment()
+ case UnknownDestinationPacketUnhandled:
+ // As per RFC: 1122 Section 3.2.2.1 A host SHOULD generate Destination
+ // Unreachable messages with code:
+ // 3 (Port Unreachable), when the designated transport protocol
+ // (e.g., UDP) is unable to demultiplex the datagram but has no
+ // protocol mechanism to inform the sender.
+ np, ok := n.stack.networkProtocols[r.NetProto]
+ if !ok {
+ // For this to happen stack.makeRoute() must have been called with the
+ // incorrect protocol number. Since we have successfully completed
+ // network layer processing this should be impossible.
+ panic(fmt.Sprintf("expected stack to have a NetworkProtocol for proto = %d", r.NetProto))
+ }
+
+ _ = np.ReturnError(r, &tcpip.ICMPReasonPortUnreachable{}, pkt)
+ case UnknownDestinationPacketHandled:
}
}
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index dd6474297..ef6e63b3e 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -221,6 +221,11 @@ func (*testIPv6Protocol) Parse(*PacketBuffer) (tcpip.TransportProtocolNumber, bo
return 0, false, false
}
+// ReturnError implements NetworkProtocol.ReturnError.
+func (*testIPv6Protocol) ReturnError(*Route, tcpip.ICMPReason, *PacketBuffer) *tcpip.Error {
+ return nil
+}
+
var _ LinkAddressResolver = (*testIPv6Protocol)(nil)
// LinkAddressProtocol implements LinkAddressResolver.
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 1932aaeb7..a7d9d59fa 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -80,11 +80,17 @@ type PacketBuffer struct {
// data are held in the same underlying buffer storage.
header buffer.Prependable
- // NetworkProtocolNumber is only valid when NetworkHeader is set.
+ // NetworkProtocolNumber is only valid when NetworkHeader().View().IsEmpty()
+ // returns false.
// TODO(gvisor.dev/issue/3574): Remove the separately passed protocol
// numbers in registration APIs that take a PacketBuffer.
NetworkProtocolNumber tcpip.NetworkProtocolNumber
+ // TransportProtocol is only valid if it is non zero.
+ // TODO(gvisor.dev/issue/3810): This and the network protocol number should
+ // be moved into the headerinfo. This should resolve the validity issue.
+ TransportProtocolNumber tcpip.TransportProtocolNumber
+
// Hash is the transport layer hash of this packet. A value of zero
// indicates no valid hash has been set.
Hash uint32
@@ -234,16 +240,17 @@ func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consum
// underlying packet payload.
func (pk *PacketBuffer) Clone() *PacketBuffer {
newPk := &PacketBuffer{
- PacketBufferEntry: pk.PacketBufferEntry,
- Data: pk.Data.Clone(nil),
- headers: pk.headers,
- header: pk.header,
- Hash: pk.Hash,
- Owner: pk.Owner,
- EgressRoute: pk.EgressRoute,
- GSOOptions: pk.GSOOptions,
- NetworkProtocolNumber: pk.NetworkProtocolNumber,
- NatDone: pk.NatDone,
+ PacketBufferEntry: pk.PacketBufferEntry,
+ Data: pk.Data.Clone(nil),
+ headers: pk.headers,
+ header: pk.header,
+ Hash: pk.Hash,
+ Owner: pk.Owner,
+ EgressRoute: pk.EgressRoute,
+ GSOOptions: pk.GSOOptions,
+ NetworkProtocolNumber: pk.NetworkProtocolNumber,
+ NatDone: pk.NatDone,
+ TransportProtocolNumber: pk.TransportProtocolNumber,
}
return newPk
}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 4fa86a3ac..77640cd8a 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -125,6 +125,26 @@ type PacketEndpoint interface {
HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
}
+// UnknownDestinationPacketDisposition enumerates the possible return vaues from
+// HandleUnknownDestinationPacket().
+type UnknownDestinationPacketDisposition int
+
+const (
+ // UnknownDestinationPacketMalformed denotes that the packet was malformed
+ // and no further processing should be attempted other than updating
+ // statistics.
+ UnknownDestinationPacketMalformed UnknownDestinationPacketDisposition = iota
+
+ // UnknownDestinationPacketUnhandled tells the caller that the packet was
+ // well formed but that the issue was not handled and the stack should take
+ // the default action.
+ UnknownDestinationPacketUnhandled
+
+ // UnknownDestinationPacketHandled tells the caller that it should do
+ // no further processing.
+ UnknownDestinationPacketHandled
+)
+
// TransportProtocol is the interface that needs to be implemented by transport
// protocols (e.g., tcp, udp) that want to be part of the networking stack.
type TransportProtocol interface {
@@ -147,14 +167,12 @@ type TransportProtocol interface {
ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error)
// HandleUnknownDestinationPacket handles packets targeted at this
- // protocol but that don't match any existing endpoint. For example,
- // it is targeted at a port that have no listeners.
+ // protocol that don't match any existing endpoint. For example,
+ // it is targeted at a port that has no listeners.
//
- // The return value indicates whether the packet was well-formed (for
- // stats purposes only).
- //
- // HandleUnknownDestinationPacket takes ownership of pkt.
- HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool
+ // HandleUnknownDestinationPacket takes ownership of pkt if it handles
+ // the issue.
+ HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) UnknownDestinationPacketDisposition
// SetOption allows enabling/disabling protocol specific features.
// SetOption returns an error if the option is not supported or the
@@ -324,6 +342,19 @@ type NetworkProtocol interface {
// does not encapsulate anything).
// - Whether pkt.Data was large enough to parse and set pkt.NetworkHeader.
Parse(pkt *PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool)
+
+ // ReturnError attempts to send a suitable error message to the sender
+ // of a received packet.
+ // - pkt holds the problematic packet.
+ // - reason indicates what the reason for wanting a message is.
+ // - route is the routing information for the received packet
+ // ReturnError returns an error if the send failed and nil on success.
+ // Note that deciding to deliberately send no message is a success.
+ //
+ // TODO(gvisor.dev/issues/3871): This method should be removed or simplified
+ // after all (or all but one) of the ICMP error dispatch occurs through the
+ // protocol specific modules. May become SendPortNotFound(r, pkt).
+ ReturnError(r *Route, reason tcpip.ICMPReason, pkt *PacketBuffer) *tcpip.Error
}
// NetworkDispatcher contains the methods used by the network stack to deliver
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 6a683545d..e7b7e95d4 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -144,10 +144,7 @@ type TCPReceiverState struct {
// PendingBufUsed is the number of bytes pending in the receive
// queue.
- PendingBufUsed seqnum.Size
-
- // PendingBufSize is the size of the socket receive buffer.
- PendingBufSize seqnum.Size
+ PendingBufUsed int
}
// TCPSenderState holds a copy of the internal state of the sender for
@@ -405,6 +402,13 @@ type Stack struct {
networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol
linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver
+ // forwarding contains the whether packet forwarding is enabled or not for
+ // different network protocols.
+ forwarding struct {
+ sync.RWMutex
+ protocols map[tcpip.NetworkProtocolNumber]bool
+ }
+
// rawFactory creates raw endpoints. If nil, raw endpoints are
// disabled. It is set during Stack creation and is immutable.
rawFactory RawFactory
@@ -415,9 +419,8 @@ type Stack struct {
linkAddrCache *linkAddrCache
- mu sync.RWMutex
- nics map[tcpip.NICID]*NIC
- forwarding bool
+ mu sync.RWMutex
+ nics map[tcpip.NICID]*NIC
// cleanupEndpointsMu protects cleanupEndpoints.
cleanupEndpointsMu sync.Mutex
@@ -749,6 +752,7 @@ func New(opts Options) *Stack {
Max: DefaultMaxBufferSize,
},
}
+ s.forwarding.protocols = make(map[tcpip.NetworkProtocolNumber]bool)
// Add specified network protocols.
for _, netProto := range opts.NetworkProtocols {
@@ -866,46 +870,42 @@ func (s *Stack) Stats() tcpip.Stats {
return s.stats
}
-// SetForwarding enables or disables the packet forwarding between NICs.
-//
-// When forwarding becomes enabled, any host-only state on all NICs will be
-// cleaned up and if IPv6 is enabled, NDP Router Solicitations will be started.
-// When forwarding becomes disabled and if IPv6 is enabled, NDP Router
-// Solicitations will be stopped.
-func (s *Stack) SetForwarding(enable bool) {
- // TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward.
- s.mu.Lock()
- defer s.mu.Unlock()
+// SetForwarding enables or disables packet forwarding between NICs.
+func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) {
+ s.forwarding.Lock()
+ defer s.forwarding.Unlock()
- // If forwarding status didn't change, do nothing further.
- if s.forwarding == enable {
+ // If this stack does not support the protocol, do nothing.
+ if _, ok := s.networkProtocols[protocol]; !ok {
return
}
- s.forwarding = enable
-
- // If this stack does not support IPv6, do nothing further.
- if _, ok := s.networkProtocols[header.IPv6ProtocolNumber]; !ok {
+ // If the forwarding value for this protocol hasn't changed then do
+ // nothing.
+ if forwarding := s.forwarding.protocols[protocol]; forwarding == enable {
return
}
- if enable {
- for _, nic := range s.nics {
- nic.becomeIPv6Router()
- }
- } else {
- for _, nic := range s.nics {
- nic.becomeIPv6Host()
+ s.forwarding.protocols[protocol] = enable
+
+ if protocol == header.IPv6ProtocolNumber {
+ if enable {
+ for _, nic := range s.nics {
+ nic.becomeIPv6Router()
+ }
+ } else {
+ for _, nic := range s.nics {
+ nic.becomeIPv6Host()
+ }
}
}
}
-// Forwarding returns if the packet forwarding between NICs is enabled.
-func (s *Stack) Forwarding() bool {
- // TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward.
- s.mu.RLock()
- defer s.mu.RUnlock()
- return s.forwarding
+// Forwarding returns if packet forwarding between NICs is enabled.
+func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool {
+ s.forwarding.RLock()
+ defer s.forwarding.RUnlock()
+ return s.forwarding.protocols[protocol]
}
// SetRouteTable assigns the route table to be used by this stack. It
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 60b54c244..9ef6787c6 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -216,13 +216,18 @@ func (f *fakeNetworkProtocol) Option(option tcpip.GettableNetworkProtocolOption)
}
}
-// Close implements TransportProtocol.Close.
+// ReturnError implements NetworkProtocol.ReturnError
+func (*fakeNetworkProtocol) ReturnError(*stack.Route, tcpip.ICMPReason, *stack.PacketBuffer) *tcpip.Error {
+ return nil
+}
+
+// Close implements NetworkProtocol.Close.
func (*fakeNetworkProtocol) Close() {}
-// Wait implements TransportProtocol.Wait.
+// Wait implements NetworkProtocol.Wait.
func (*fakeNetworkProtocol) Wait() {}
-// Parse implements TransportProtocol.Parse.
+// Parse implements NetworkProtocol.Parse.
func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) {
hdr, ok := pkt.NetworkHeader().Consume(fakeNetHeaderLen)
if !ok {
@@ -2091,7 +2096,7 @@ func TestNICForwarding(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
})
- s.SetForwarding(true)
+ s.SetForwarding(fakeNetNumber, true)
ep1 := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(nicID1, ep1); err != nil {
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index ef3457e32..cbb34d224 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -287,8 +287,8 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp
return 0, 0, nil
}
-func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) bool {
- return true
+func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition {
+ return stack.UnknownDestinationPacketHandled
}
func (f *fakeTransportProtocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error {
@@ -549,7 +549,7 @@ func TestTransportForwarding(t *testing.T) {
NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
})
- s.SetForwarding(true)
+ s.SetForwarding(fakeNetNumber, true)
// TODO(b/123449044): Change this to a channel NIC.
ep1 := loopback.New()