summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r--pkg/tcpip/stack/conntrack.go4
-rw-r--r--pkg/tcpip/stack/forwarder_test.go683
-rw-r--r--pkg/tcpip/stack/iptables.go131
-rw-r--r--pkg/tcpip/stack/iptables_types.go70
-rw-r--r--pkg/tcpip/stack/linkaddrcache_test.go77
-rw-r--r--pkg/tcpip/stack/ndp_test.go792
-rw-r--r--pkg/tcpip/stack/nic.go125
-rw-r--r--pkg/tcpip/stack/nic_test.go6
-rw-r--r--pkg/tcpip/stack/nud_test.go16
-rw-r--r--pkg/tcpip/stack/packet_buffer.go2
-rw-r--r--pkg/tcpip/stack/registration.go12
-rw-r--r--pkg/tcpip/stack/route.go42
-rw-r--r--pkg/tcpip/stack/stack.go219
-rw-r--r--pkg/tcpip/stack/stack_test.go73
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go18
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go4
-rw-r--r--pkg/tcpip/stack/transport_test.go93
17 files changed, 1437 insertions, 930 deletions
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 7dd344b4f..836682ea0 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -572,7 +572,9 @@ func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, tim
// reapTupleLocked tries to remove tuple and its reply from the table. It
// returns whether the tuple's connection has timed out.
//
-// Preconditions: ct.mu is locked for reading and bucket is locked.
+// Preconditions:
+// * ct.mu is locked for reading.
+// * bucket is locked.
func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bool {
if !tuple.conn.timedOut(now) {
return false
diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go
index 9dff23623..38c5bac71 100644
--- a/pkg/tcpip/stack/forwarder_test.go
+++ b/pkg/tcpip/stack/forwarder_test.go
@@ -16,6 +16,7 @@ package stack
import (
"encoding/binary"
+ "math"
"testing"
"time"
@@ -25,8 +26,9 @@ import (
)
const (
- fwdTestNetHeaderLen = 12
- fwdTestNetDefaultPrefixLen = 8
+ fwdTestNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32
+ fwdTestNetHeaderLen = 12
+ fwdTestNetDefaultPrefixLen = 8
// fwdTestNetDefaultMTU is the MTU, in bytes, used throughout the tests,
// except where another value is explicitly used. It is chosen to match
@@ -49,6 +51,8 @@ type fwdTestNetworkEndpoint struct {
ep LinkEndpoint
}
+var _ NetworkEndpoint = (*fwdTestNetworkEndpoint)(nil)
+
func (f *fwdTestNetworkEndpoint) MTU() uint32 {
return f.ep.MTU() - uint32(f.MaxHeaderLength())
}
@@ -90,7 +94,7 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkH
b[srcAddrOffset] = r.LocalAddress[0]
b[protocolNumberOffset] = byte(params.Protocol)
- return f.ep.WritePacket(r, gso, fakeNetNumber, pkt)
+ return f.ep.WritePacket(r, gso, fwdTestNetNumber, pkt)
}
// WritePackets implements LinkEndpoint.WritePackets.
@@ -108,15 +112,17 @@ func (*fwdTestNetworkEndpoint) Close() {}
// resolution.
type fwdTestNetworkProtocol struct {
addrCache *linkAddrCache
+ neigh *neighborCache
addrResolveDelay time.Duration
- onLinkAddressResolved func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress)
+ onLinkAddressResolved func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress)
onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool)
}
+var _ NetworkProtocol = (*fwdTestNetworkProtocol)(nil)
var _ LinkAddressResolver = (*fwdTestNetworkProtocol)(nil)
func (f *fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
- return fakeNetNumber
+ return fwdTestNetNumber
}
func (f *fwdTestNetworkProtocol) MinimumPacketSize() int {
@@ -139,7 +145,7 @@ func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocol
return tcpip.TransportProtocolNumber(netHeader[protocolNumberOffset]), true, true
}
-func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) NetworkEndpoint {
+func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, _ LinkAddressCache, _ NUDHandler, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) NetworkEndpoint {
return &fwdTestNetworkEndpoint{
nicID: nicID,
proto: f,
@@ -148,22 +154,22 @@ func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache Li
}
}
-func (f *fwdTestNetworkProtocol) SetOption(option interface{}) *tcpip.Error {
+func (*fwdTestNetworkProtocol) SetOption(tcpip.SettableNetworkProtocolOption) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-func (f *fwdTestNetworkProtocol) Option(option interface{}) *tcpip.Error {
+func (*fwdTestNetworkProtocol) Option(tcpip.GettableNetworkProtocolOption) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-func (f *fwdTestNetworkProtocol) Close() {}
+func (*fwdTestNetworkProtocol) Close() {}
-func (f *fwdTestNetworkProtocol) Wait() {}
+func (*fwdTestNetworkProtocol) Wait() {}
func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error {
- if f.addrCache != nil && f.onLinkAddressResolved != nil {
+ if f.onLinkAddressResolved != nil {
time.AfterFunc(f.addrResolveDelay, func() {
- f.onLinkAddressResolved(f.addrCache, addr, remoteLinkAddr)
+ f.onLinkAddressResolved(f.addrCache, f.neigh, addr, remoteLinkAddr)
})
}
return nil
@@ -176,8 +182,8 @@ func (f *fwdTestNetworkProtocol) ResolveStaticAddress(addr tcpip.Address) (tcpip
return "", false
}
-func (f *fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
- return fakeNetNumber
+func (*fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+ return fwdTestNetNumber
}
// fwdTestPacketInfo holds all the information about an outbound packet.
@@ -298,13 +304,16 @@ func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protoco
panic("not implemented")
}
-func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) {
+func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborCache bool) (ep1, ep2 *fwdTestLinkEndpoint) {
// Create a stack with the network protocol and two NICs.
s := New(Options{
NetworkProtocols: []NetworkProtocol{proto},
+ UseNeighborCache: useNeighborCache,
})
- proto.addrCache = s.linkAddrCache
+ if !useNeighborCache {
+ proto.addrCache = s.linkAddrCache
+ }
// Enable forwarding.
s.SetForwarding(proto.Number(), true)
@@ -318,7 +327,7 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f
if err := s.CreateNIC(1, ep1); err != nil {
t.Fatal("CreateNIC #1 failed:", err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
+ if err := s.AddAddress(1, fwdTestNetNumber, "\x01"); err != nil {
t.Fatal("AddAddress #1 failed:", err)
}
@@ -331,10 +340,19 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f
if err := s.CreateNIC(2, ep2); err != nil {
t.Fatal("CreateNIC #2 failed:", err)
}
- if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
+ if err := s.AddAddress(2, fwdTestNetNumber, "\x02"); err != nil {
t.Fatal("AddAddress #2 failed:", err)
}
+ if useNeighborCache {
+ // Control the neighbor cache for NIC 2.
+ nic, ok := s.nics[2]
+ if !ok {
+ t.Fatal("failed to get the neighbor cache for NIC 2")
+ }
+ proto.neigh = nic.neigh
+ }
+
// Route all packets to NIC 2.
{
subnet, err := tcpip.NewSubnet("\x00", "\x00")
@@ -348,79 +366,129 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f
}
func TestForwardingWithStaticResolver(t *testing.T) {
- // Create a network protocol with a static resolver.
- proto := &fwdTestNetworkProtocol{
- onResolveStaticAddress:
- // The network address 3 is resolved to the link address "c".
- func(addr tcpip.Address) (tcpip.LinkAddress, bool) {
- if addr == "\x03" {
- return "c", true
- }
- return "", false
+ tests := []struct {
+ name string
+ useNeighborCache bool
+ }{
+ {
+ name: "linkAddrCache",
+ useNeighborCache: false,
+ },
+ {
+ name: "neighborCache",
+ useNeighborCache: true,
},
}
- ep1, ep2 := fwdTestNetFactory(t, proto)
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ // Create a network protocol with a static resolver.
+ proto := &fwdTestNetworkProtocol{
+ onResolveStaticAddress:
+ // The network address 3 is resolved to the link address "c".
+ func(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+ if addr == "\x03" {
+ return "c", true
+ }
+ return "", false
+ },
+ }
- // Inject an inbound packet to address 3 on NIC 1, and see if it is
- // forwarded to NIC 2.
- buf := buffer.NewView(30)
- buf[dstAddrOffset] = 3
- ep1.InjectInbound(fakeNetNumber, NewPacketBuffer(PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
+ ep1, ep2 := fwdTestNetFactory(t, proto, test.useNeighborCache)
- var p fwdTestPacketInfo
+ // Inject an inbound packet to address 3 on NIC 1, and see if it is
+ // forwarded to NIC 2.
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
- select {
- case p = <-ep2.C:
- default:
- t.Fatal("packet not forwarded")
- }
+ var p fwdTestPacketInfo
- // Test that the static address resolution happened correctly.
- if p.RemoteLinkAddress != "c" {
- t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
- }
- if p.LocalLinkAddress != "b" {
- t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ select {
+ case p = <-ep2.C:
+ default:
+ t.Fatal("packet not forwarded")
+ }
+
+ // Test that the static address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
+ })
}
}
func TestForwardingWithFakeResolver(t *testing.T) {
- // Create a network protocol with a fake resolver.
- proto := &fwdTestNetworkProtocol{
- addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) {
- // Any address will be resolved to the link address "c".
- cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ tests := []struct {
+ name string
+ useNeighborCache bool
+ proto *fwdTestNetworkProtocol
+ }{
+ {
+ name: "linkAddrCache",
+ useNeighborCache: false,
+ proto: &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) {
+ // Any address will be resolved to the link address "c".
+ cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ },
+ },
+ },
+ {
+ name: "neighborCache",
+ useNeighborCache: true,
+ proto: &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) {
+ t.Helper()
+ if len(remoteLinkAddr) != 0 {
+ t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr)
+ }
+ // Any address will be resolved to the link address "c".
+ neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ },
+ },
},
}
- ep1, ep2 := fwdTestNetFactory(t, proto)
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache)
- // Inject an inbound packet to address 3 on NIC 1, and see if it is
- // forwarded to NIC 2.
- buf := buffer.NewView(30)
- buf[dstAddrOffset] = 3
- ep1.InjectInbound(fakeNetNumber, NewPacketBuffer(PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
+ // Inject an inbound packet to address 3 on NIC 1, and see if it is
+ // forwarded to NIC 2.
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
- var p fwdTestPacketInfo
+ var p fwdTestPacketInfo
- select {
- case p = <-ep2.C:
- case <-time.After(time.Second):
- t.Fatal("packet not forwarded")
- }
+ select {
+ case p = <-ep2.C:
+ case <-time.After(time.Second):
+ t.Fatal("packet not forwarded")
+ }
- // Test that the address resolution happened correctly.
- if p.RemoteLinkAddress != "c" {
- t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
- }
- if p.LocalLinkAddress != "b" {
- t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ // Test that the address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
+ })
}
}
@@ -428,13 +496,15 @@ func TestForwardingWithNoResolver(t *testing.T) {
// Create a network protocol without a resolver.
proto := &fwdTestNetworkProtocol{}
- ep1, ep2 := fwdTestNetFactory(t, proto)
+ // Whether or not we use the neighbor cache here does not matter since
+ // neither linkAddrCache nor neighborCache will be used.
+ ep1, ep2 := fwdTestNetFactory(t, proto, false /* useNeighborCache */)
// inject an inbound packet to address 3 on NIC 1, and see if it is
// forwarded to NIC 2.
buf := buffer.NewView(30)
buf[dstAddrOffset] = 3
- ep1.InjectInbound(fakeNetNumber, NewPacketBuffer(PacketBufferOptions{
+ ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
Data: buf.ToVectorisedView(),
}))
@@ -446,203 +516,334 @@ func TestForwardingWithNoResolver(t *testing.T) {
}
func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
- // Create a network protocol with a fake resolver.
- proto := &fwdTestNetworkProtocol{
- addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) {
- // Only packets to address 3 will be resolved to the
- // link address "c".
- if addr == "\x03" {
- cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
- }
+ tests := []struct {
+ name string
+ useNeighborCache bool
+ proto *fwdTestNetworkProtocol
+ }{
+ {
+ name: "linkAddrCache",
+ useNeighborCache: false,
+ proto: &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) {
+ // Only packets to address 3 will be resolved to the
+ // link address "c".
+ if addr == "\x03" {
+ cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ }
+ },
+ },
+ },
+ {
+ name: "neighborCache",
+ useNeighborCache: true,
+ proto: &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) {
+ t.Helper()
+ if len(remoteLinkAddr) != 0 {
+ t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr)
+ }
+ // Only packets to address 3 will be resolved to the
+ // link address "c".
+ if addr == "\x03" {
+ neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ }
+ },
+ },
},
}
- ep1, ep2 := fwdTestNetFactory(t, proto)
-
- // Inject an inbound packet to address 4 on NIC 1. This packet should
- // not be forwarded.
- buf := buffer.NewView(30)
- buf[dstAddrOffset] = 4
- ep1.InjectInbound(fakeNetNumber, NewPacketBuffer(PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-
- // Inject an inbound packet to address 3 on NIC 1, and see if it is
- // forwarded to NIC 2.
- buf = buffer.NewView(30)
- buf[dstAddrOffset] = 3
- ep1.InjectInbound(fakeNetNumber, NewPacketBuffer(PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
-
- var p fwdTestPacketInfo
-
- select {
- case p = <-ep2.C:
- case <-time.After(time.Second):
- t.Fatal("packet not forwarded")
- }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache)
+
+ // Inject an inbound packet to address 4 on NIC 1. This packet should
+ // not be forwarded.
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = 4
+ ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+
+ // Inject an inbound packet to address 3 on NIC 1, and see if it is
+ // forwarded to NIC 2.
+ buf = buffer.NewView(30)
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+
+ var p fwdTestPacketInfo
+
+ select {
+ case p = <-ep2.C:
+ case <-time.After(time.Second):
+ t.Fatal("packet not forwarded")
+ }
- if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 {
- t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset])
- }
+ if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 {
+ t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset])
+ }
- // Test that the address resolution happened correctly.
- if p.RemoteLinkAddress != "c" {
- t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
- }
- if p.LocalLinkAddress != "b" {
- t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ // Test that the address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
+ })
}
}
func TestForwardingWithFakeResolverTwoPackets(t *testing.T) {
- // Create a network protocol with a fake resolver.
- proto := &fwdTestNetworkProtocol{
- addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) {
- // Any packets will be resolved to the link address "c".
- cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ tests := []struct {
+ name string
+ useNeighborCache bool
+ proto *fwdTestNetworkProtocol
+ }{
+ {
+ name: "linkAddrCache",
+ useNeighborCache: false,
+ proto: &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) {
+ // Any packets will be resolved to the link address "c".
+ cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ },
+ },
+ },
+ {
+ name: "neighborCache",
+ useNeighborCache: true,
+ proto: &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) {
+ t.Helper()
+ if len(remoteLinkAddr) != 0 {
+ t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr)
+ }
+ // Any packets will be resolved to the link address "c".
+ neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ },
+ },
},
}
- ep1, ep2 := fwdTestNetFactory(t, proto)
-
- // Inject two inbound packets to address 3 on NIC 1.
- for i := 0; i < 2; i++ {
- buf := buffer.NewView(30)
- buf[dstAddrOffset] = 3
- ep1.InjectInbound(fakeNetNumber, NewPacketBuffer(PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
- }
-
- for i := 0; i < 2; i++ {
- var p fwdTestPacketInfo
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache)
- select {
- case p = <-ep2.C:
- case <-time.After(time.Second):
- t.Fatal("packet not forwarded")
- }
-
- if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 {
- t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset])
- }
+ // Inject two inbound packets to address 3 on NIC 1.
+ for i := 0; i < 2; i++ {
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+ }
- // Test that the address resolution happened correctly.
- if p.RemoteLinkAddress != "c" {
- t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
- }
- if p.LocalLinkAddress != "b" {
- t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
- }
+ for i := 0; i < 2; i++ {
+ var p fwdTestPacketInfo
+
+ select {
+ case p = <-ep2.C:
+ case <-time.After(time.Second):
+ t.Fatal("packet not forwarded")
+ }
+
+ if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 {
+ t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset])
+ }
+
+ // Test that the address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
+ }
+ })
}
}
func TestForwardingWithFakeResolverManyPackets(t *testing.T) {
- // Create a network protocol with a fake resolver.
- proto := &fwdTestNetworkProtocol{
- addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) {
- // Any packets will be resolved to the link address "c".
- cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ tests := []struct {
+ name string
+ useNeighborCache bool
+ proto *fwdTestNetworkProtocol
+ }{
+ {
+ name: "linkAddrCache",
+ useNeighborCache: false,
+ proto: &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) {
+ // Any packets will be resolved to the link address "c".
+ cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ },
+ },
+ },
+ {
+ name: "neighborCache",
+ useNeighborCache: true,
+ proto: &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) {
+ t.Helper()
+ if len(remoteLinkAddr) != 0 {
+ t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr)
+ }
+ // Any packets will be resolved to the link address "c".
+ neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ },
+ },
},
}
- ep1, ep2 := fwdTestNetFactory(t, proto)
-
- for i := 0; i < maxPendingPacketsPerResolution+5; i++ {
- // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1.
- buf := buffer.NewView(30)
- buf[dstAddrOffset] = 3
- // Set the packet sequence number.
- binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i))
- ep1.InjectInbound(fakeNetNumber, NewPacketBuffer(PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
- }
-
- for i := 0; i < maxPendingPacketsPerResolution; i++ {
- var p fwdTestPacketInfo
-
- select {
- case p = <-ep2.C:
- case <-time.After(time.Second):
- t.Fatal("packet not forwarded")
- }
-
- b := PayloadSince(p.Pkt.NetworkHeader())
- if b[dstAddrOffset] != 3 {
- t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset])
- }
- if len(b) < fwdTestNetHeaderLen+2 {
- t.Fatalf("packet is too short to hold a sequence number: len(b) = %d", b)
- }
- seqNumBuf := b[fwdTestNetHeaderLen:]
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache)
- // The first 5 packets should not be forwarded so the sequence number should
- // start with 5.
- want := uint16(i + 5)
- if n := binary.BigEndian.Uint16(seqNumBuf); n != want {
- t.Fatalf("got the packet #%d, want = #%d", n, want)
- }
+ for i := 0; i < maxPendingPacketsPerResolution+5; i++ {
+ // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1.
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = 3
+ // Set the packet sequence number.
+ binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i))
+ ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+ }
- // Test that the address resolution happened correctly.
- if p.RemoteLinkAddress != "c" {
- t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
- }
- if p.LocalLinkAddress != "b" {
- t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
- }
+ for i := 0; i < maxPendingPacketsPerResolution; i++ {
+ var p fwdTestPacketInfo
+
+ select {
+ case p = <-ep2.C:
+ case <-time.After(time.Second):
+ t.Fatal("packet not forwarded")
+ }
+
+ b := PayloadSince(p.Pkt.NetworkHeader())
+ if b[dstAddrOffset] != 3 {
+ t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset])
+ }
+ if len(b) < fwdTestNetHeaderLen+2 {
+ t.Fatalf("packet is too short to hold a sequence number: len(b) = %d", b)
+ }
+ seqNumBuf := b[fwdTestNetHeaderLen:]
+
+ // The first 5 packets should not be forwarded so the sequence number should
+ // start with 5.
+ want := uint16(i + 5)
+ if n := binary.BigEndian.Uint16(seqNumBuf); n != want {
+ t.Fatalf("got the packet #%d, want = #%d", n, want)
+ }
+
+ // Test that the address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
+ }
+ })
}
}
func TestForwardingWithFakeResolverManyResolutions(t *testing.T) {
- // Create a network protocol with a fake resolver.
- proto := &fwdTestNetworkProtocol{
- addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) {
- // Any packets will be resolved to the link address "c".
- cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ tests := []struct {
+ name string
+ useNeighborCache bool
+ proto *fwdTestNetworkProtocol
+ }{
+ {
+ name: "linkAddrCache",
+ useNeighborCache: false,
+ proto: &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) {
+ // Any packets will be resolved to the link address "c".
+ cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ },
+ },
+ },
+ {
+ name: "neighborCache",
+ useNeighborCache: true,
+ proto: &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) {
+ t.Helper()
+ if len(remoteLinkAddr) != 0 {
+ t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr)
+ }
+ // Any packets will be resolved to the link address "c".
+ neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ },
+ },
},
}
- ep1, ep2 := fwdTestNetFactory(t, proto)
-
- for i := 0; i < maxPendingResolutions+5; i++ {
- // Inject inbound 'maxPendingResolutions + 5' packets on NIC 1.
- // Each packet has a different destination address (3 to
- // maxPendingResolutions + 7).
- buf := buffer.NewView(30)
- buf[dstAddrOffset] = byte(3 + i)
- ep1.InjectInbound(fakeNetNumber, NewPacketBuffer(PacketBufferOptions{
- Data: buf.ToVectorisedView(),
- }))
- }
-
- for i := 0; i < maxPendingResolutions; i++ {
- var p fwdTestPacketInfo
-
- select {
- case p = <-ep2.C:
- case <-time.After(time.Second):
- t.Fatal("packet not forwarded")
- }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache)
- // The first 5 packets (address 3 to 7) should not be forwarded
- // because their address resolutions are interrupted.
- if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] < 8 {
- t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", nh[dstAddrOffset])
- }
+ for i := 0; i < maxPendingResolutions+5; i++ {
+ // Inject inbound 'maxPendingResolutions + 5' packets on NIC 1.
+ // Each packet has a different destination address (3 to
+ // maxPendingResolutions + 7).
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = byte(3 + i)
+ ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+ }
- // Test that the address resolution happened correctly.
- if p.RemoteLinkAddress != "c" {
- t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
- }
- if p.LocalLinkAddress != "b" {
- t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
- }
+ for i := 0; i < maxPendingResolutions; i++ {
+ var p fwdTestPacketInfo
+
+ select {
+ case p = <-ep2.C:
+ case <-time.After(time.Second):
+ t.Fatal("packet not forwarded")
+ }
+
+ // The first 5 packets (address 3 to 7) should not be forwarded
+ // because their address resolutions are interrupted.
+ if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] < 8 {
+ t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", nh[dstAddrOffset])
+ }
+
+ // Test that the address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
+ }
+ })
}
}
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index c37da814f..4a521eca9 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -57,7 +57,72 @@ const reaperDelay = 5 * time.Second
// all packets.
func DefaultTables() *IPTables {
return &IPTables{
- tables: [numTables]Table{
+ v4Tables: [numTables]Table{
+ natID: Table{
+ Rules: []Rule{
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: ErrorTarget{}},
+ },
+ BuiltinChains: [NumHooks]int{
+ Prerouting: 0,
+ Input: 1,
+ Forward: HookUnset,
+ Output: 2,
+ Postrouting: 3,
+ },
+ Underflows: [NumHooks]int{
+ Prerouting: 0,
+ Input: 1,
+ Forward: HookUnset,
+ Output: 2,
+ Postrouting: 3,
+ },
+ },
+ mangleID: Table{
+ Rules: []Rule{
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: ErrorTarget{}},
+ },
+ BuiltinChains: [NumHooks]int{
+ Prerouting: 0,
+ Output: 1,
+ },
+ Underflows: [NumHooks]int{
+ Prerouting: 0,
+ Input: HookUnset,
+ Forward: HookUnset,
+ Output: 1,
+ Postrouting: HookUnset,
+ },
+ },
+ filterID: Table{
+ Rules: []Rule{
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: ErrorTarget{}},
+ },
+ BuiltinChains: [NumHooks]int{
+ Prerouting: HookUnset,
+ Input: 0,
+ Forward: 1,
+ Output: 2,
+ Postrouting: HookUnset,
+ },
+ Underflows: [NumHooks]int{
+ Prerouting: HookUnset,
+ Input: 0,
+ Forward: 1,
+ Output: 2,
+ Postrouting: HookUnset,
+ },
+ },
+ },
+ v6Tables: [numTables]Table{
natID: Table{
Rules: []Rule{
Rule{Target: AcceptTarget{}},
@@ -165,18 +230,21 @@ func EmptyNATTable() Table {
}
// GetTable returns a table by name.
-func (it *IPTables) GetTable(name string) (Table, bool) {
+func (it *IPTables) GetTable(name string, ipv6 bool) (Table, bool) {
id, ok := nameToID[name]
if !ok {
return Table{}, false
}
it.mu.RLock()
defer it.mu.RUnlock()
- return it.tables[id], true
+ if ipv6 {
+ return it.v6Tables[id], true
+ }
+ return it.v4Tables[id], true
}
// ReplaceTable replaces or inserts table by name.
-func (it *IPTables) ReplaceTable(name string, table Table) *tcpip.Error {
+func (it *IPTables) ReplaceTable(name string, table Table, ipv6 bool) *tcpip.Error {
id, ok := nameToID[name]
if !ok {
return tcpip.ErrInvalidOptionValue
@@ -190,7 +258,11 @@ func (it *IPTables) ReplaceTable(name string, table Table) *tcpip.Error {
it.startReaper(reaperDelay)
}
it.modified = true
- it.tables[id] = table
+ if ipv6 {
+ it.v6Tables[id] = table
+ } else {
+ it.v4Tables[id] = table
+ }
return nil
}
@@ -213,8 +285,15 @@ const (
// should continue traversing the network stack and false when it should be
// dropped.
//
+// TODO(gvisor.dev/issue/170): PacketBuffer should hold the GSO and route, from
+// which address and nicName can be gathered. Currently, address is only
+// needed for prerouting and nicName is only needed for output.
+//
// Precondition: pkt.NetworkHeader is set.
-func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, address tcpip.Address, nicName string) bool {
+func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) bool {
+ if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber {
+ return true
+ }
// Many users never configure iptables. Spare them the cost of rule
// traversal if rules have never been set.
it.mu.RLock()
@@ -235,9 +314,14 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr
if tableID == natID && pkt.NatDone {
continue
}
- table := it.tables[tableID]
+ var table Table
+ if pkt.NetworkProtocolNumber == header.IPv6ProtocolNumber {
+ table = it.v6Tables[tableID]
+ } else {
+ table = it.v4Tables[tableID]
+ }
ruleIdx := table.BuiltinChains[hook]
- switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict {
+ switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, nicName); verdict {
// If the table returns Accept, move on to the next table.
case chainAccept:
continue
@@ -248,7 +332,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr
// Any Return from a built-in chain means we have to
// call the underflow.
underflow := table.Rules[table.Underflows[hook]]
- switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, gso, r, address); v {
+ switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, gso, r, preroutingAddr); v {
case RuleAccept:
continue
case RuleDrop:
@@ -315,8 +399,8 @@ func (it *IPTables) startReaper(interval time.Duration) {
// should not go forward.
//
// Preconditions:
-// - pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
-// - pkt.NetworkHeader is not nil.
+// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
+// * pkt.NetworkHeader is not nil.
//
// NOTE: unlike the Check API the returned map contains packets that should be
// dropped.
@@ -341,13 +425,13 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *
}
// Preconditions:
-// - pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
-// - pkt.NetworkHeader is not nil.
-func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) chainVerdict {
+// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
+// * pkt.NetworkHeader is not nil.
+func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) chainVerdict {
// Start from ruleIdx and walk the list of rules until a rule gives us
// a verdict.
for ruleIdx < len(table.Rules) {
- switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict {
+ switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, nicName); verdict {
case RuleAccept:
return chainAccept
@@ -364,7 +448,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId
ruleIdx++
continue
}
- switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, address, nicName); verdict {
+ switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, preroutingAddr, nicName); verdict {
case chainAccept:
return chainAccept
case chainDrop:
@@ -388,13 +472,13 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId
}
// Preconditions:
-// - pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
-// - pkt.NetworkHeader is not nil.
-func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) (RuleVerdict, int) {
+// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
+// * pkt.NetworkHeader is not nil.
+func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) (RuleVerdict, int) {
rule := table.Rules[ruleIdx]
// Check whether the packet matches the IP header filter.
- if !rule.Filter.match(header.IPv4(pkt.NetworkHeader().View()), hook, nicName) {
+ if !rule.Filter.match(pkt, hook, nicName) {
// Continue on to the next rule.
return RuleJump, ruleIdx + 1
}
@@ -413,11 +497,16 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx
}
// All the matchers matched, so run the target.
- return rule.Target.Action(pkt, &it.connections, hook, gso, r, address)
+ return rule.Target.Action(pkt, &it.connections, hook, gso, r, preroutingAddr)
}
// OriginalDst returns the original destination of redirected connections. It
// returns an error if the connection doesn't exist or isn't redirected.
func (it *IPTables) OriginalDst(epID TransportEndpointID) (tcpip.Address, uint16, *tcpip.Error) {
+ it.mu.RLock()
+ defer it.mu.RUnlock()
+ if !it.modified {
+ return "", 0, tcpip.ErrNotConnected
+ }
return it.connections.originalDst(epID)
}
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index 73274ada9..093ee6881 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -15,6 +15,7 @@
package stack
import (
+ "fmt"
"strings"
"sync"
@@ -81,26 +82,25 @@ const (
//
// +stateify savable
type IPTables struct {
- // mu protects tables, priorities, and modified.
+ // mu protects v4Tables, v6Tables, and modified.
mu sync.RWMutex
-
- // tables maps tableIDs to tables. Holds builtin tables only, not user
- // tables. mu must be locked for accessing.
- tables [numTables]Table
-
- // priorities maps each hook to a list of table names. The order of the
- // list is the order in which each table should be visited for that
- // hook. mu needs to be locked for accessing.
- priorities [NumHooks][]tableID
-
+ // v4Tables and v6tables map tableIDs to tables. They hold builtin
+ // tables only, not user tables. mu must be locked for accessing.
+ v4Tables [numTables]Table
+ v6Tables [numTables]Table
// modified is whether tables have been modified at least once. It is
// used to elide the iptables performance overhead for workloads that
// don't utilize iptables.
modified bool
+ // priorities maps each hook to a list of table names. The order of the
+ // list is the order in which each table should be visited for that
+ // hook. It is immutable.
+ priorities [NumHooks][]tableID
+
connections ConnTrack
- // reaperDone can be signalled to stop the reaper goroutine.
+ // reaperDone can be signaled to stop the reaper goroutine.
reaperDone chan struct{}
}
@@ -148,13 +148,18 @@ type Rule struct {
Target Target
}
-// IPHeaderFilter holds basic IP filtering data common to every rule.
+// IPHeaderFilter performs basic IP header matching common to every rule.
//
// +stateify savable
type IPHeaderFilter struct {
// Protocol matches the transport protocol.
Protocol tcpip.TransportProtocolNumber
+ // CheckProtocol determines whether the Protocol field should be
+ // checked during matching.
+ // TODO(gvisor.dev/issue/3549): Check this field during matching.
+ CheckProtocol bool
+
// Dst matches the destination IP address.
Dst tcpip.Address
@@ -191,16 +196,43 @@ type IPHeaderFilter struct {
OutputInterfaceInvert bool
}
-// match returns whether hdr matches the filter.
-func (fl IPHeaderFilter) match(hdr header.IPv4, hook Hook, nicName string) bool {
- // TODO(gvisor.dev/issue/170): Support other fields of the filter.
+// match returns whether pkt matches the filter.
+//
+// Preconditions: pkt.NetworkHeader is set and is at least of the minimal IPv4
+// or IPv6 header length.
+func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, nicName string) bool {
+ // Extract header fields.
+ var (
+ // TODO(gvisor.dev/issue/170): Support other filter fields.
+ transProto tcpip.TransportProtocolNumber
+ dstAddr tcpip.Address
+ srcAddr tcpip.Address
+ )
+ switch proto := pkt.NetworkProtocolNumber; proto {
+ case header.IPv4ProtocolNumber:
+ hdr := header.IPv4(pkt.NetworkHeader().View())
+ transProto = hdr.TransportProtocol()
+ dstAddr = hdr.DestinationAddress()
+ srcAddr = hdr.SourceAddress()
+
+ case header.IPv6ProtocolNumber:
+ hdr := header.IPv6(pkt.NetworkHeader().View())
+ transProto = hdr.TransportProtocol()
+ dstAddr = hdr.DestinationAddress()
+ srcAddr = hdr.SourceAddress()
+
+ default:
+ panic(fmt.Sprintf("unknown network protocol with EtherType: %d", proto))
+ }
+
// Check the transport protocol.
- if fl.Protocol != 0 && fl.Protocol != hdr.TransportProtocol() {
+ if fl.CheckProtocol && fl.Protocol != transProto {
return false
}
- // Check the source and destination IPs.
- if !filterAddress(hdr.DestinationAddress(), fl.DstMask, fl.Dst, fl.DstInvert) || !filterAddress(hdr.SourceAddress(), fl.SrcMask, fl.Src, fl.SrcInvert) {
+ // Check the addresses.
+ if !filterAddress(dstAddr, fl.DstMask, fl.Dst, fl.DstInvert) ||
+ !filterAddress(srcAddr, fl.SrcMask, fl.Src, fl.SrcInvert) {
return false
}
diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go
index b15b8d1cb..33806340e 100644
--- a/pkg/tcpip/stack/linkaddrcache_test.go
+++ b/pkg/tcpip/stack/linkaddrcache_test.go
@@ -16,6 +16,7 @@ package stack
import (
"fmt"
+ "math"
"sync/atomic"
"testing"
"time"
@@ -191,7 +192,13 @@ func TestCacheReplace(t *testing.T) {
}
func TestCacheResolution(t *testing.T) {
- c := newLinkAddrCache(1<<63-1, 250*time.Millisecond, 1)
+ // There is a race condition causing this test to fail when the executor
+ // takes longer than the resolution timeout to call linkAddrCache.get. This
+ // is especially common when this test is run with gotsan.
+ //
+ // Using a large resolution timeout decreases the probability of experiencing
+ // this race condition and does not affect how long this test takes to run.
+ c := newLinkAddrCache(1<<63-1, math.MaxInt64, 1)
linkRes := &testLinkAddressResolver{cache: c}
for i, ta := range testAddrs {
got, err := getBlocking(c, ta.addr, linkRes)
@@ -275,3 +282,71 @@ func TestStaticResolution(t *testing.T) {
t.Errorf("c.get(%q)=%q, want %q", string(addr), string(got), string(want))
}
}
+
+// TestCacheWaker verifies that RemoveWaker removes a waker previously added
+// through get().
+func TestCacheWaker(t *testing.T) {
+ c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
+
+ // First, sanity check that wakers are working.
+ {
+ linkRes := &testLinkAddressResolver{cache: c}
+ s := sleep.Sleeper{}
+ defer s.Done()
+
+ const wakerID = 1
+ w := sleep.Waker{}
+ s.AddWaker(&w, wakerID)
+
+ e := testAddrs[0]
+
+ if _, _, err := c.get(e.addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.get(%q, _, _, _, _) = %s, want = %s", e.addr.Addr, err, tcpip.ErrWouldBlock)
+ }
+ id, ok := s.Fetch(true /* block */)
+ if !ok {
+ t.Fatal("got s.Fetch(true) = (_, false), want = (_, true)")
+ }
+ if id != wakerID {
+ t.Fatalf("got s.Fetch(true) = (%d, %t), want = (%d, true)", id, ok, wakerID)
+ }
+
+ if got, _, err := c.get(e.addr, linkRes, "", nil, nil); err != nil {
+ t.Fatalf("c.get(%q, _, _, _, _): %s", e.addr.Addr, err)
+ } else if got != e.linkAddr {
+ t.Fatalf("got c.get(%q) = %q, want = %q", e.addr.Addr, got, e.linkAddr)
+ }
+ }
+
+ // Check that RemoveWaker works.
+ {
+ linkRes := &testLinkAddressResolver{cache: c}
+ s := sleep.Sleeper{}
+ defer s.Done()
+
+ const wakerID = 2 // different than the ID used in the sanity check
+ w := sleep.Waker{}
+ s.AddWaker(&w, wakerID)
+
+ e := testAddrs[1]
+ linkRes.onLinkAddressRequest = func() {
+ // Remove the waker before the linkAddrCache has the opportunity to send
+ // a notification.
+ c.removeWaker(e.addr, &w)
+ }
+
+ if _, _, err := c.get(e.addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.get(%q, _, _, _, _) = %s, want = %s", e.addr.Addr, err, tcpip.ErrWouldBlock)
+ }
+
+ if got, err := getBlocking(c, e.addr, linkRes); err != nil {
+ t.Fatalf("c.get(%q, _, _, _, _): %s", e.addr.Addr, err)
+ } else if got != e.linkAddr {
+ t.Fatalf("c.get(%q) = %q, want = %q", e.addr.Addr, got, e.linkAddr)
+ }
+
+ if id, ok := s.Fetch(false /* block */); ok {
+ t.Fatalf("unexpected notification from waker with id %d", id)
+ }
+ }
+}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 1a6724c31..5e43a9b0b 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -2787,7 +2787,7 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) {
// stack.Stack will have a default route through the router (llAddr3) installed
// and a static link-address (linkAddr3) added to the link address cache for the
// router.
-func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*ndpDispatcher, *channel.Endpoint, *stack.Stack) {
+func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID, useNeighborCache bool) (*ndpDispatcher, *channel.Endpoint, *stack.Stack) {
t.Helper()
ndpDisp := &ndpDispatcher{
autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
@@ -2800,7 +2800,8 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*nd
HandleRAs: true,
AutoGenGlobalAddresses: true,
},
- NDPDisp: ndpDisp,
+ NDPDisp: ndpDisp,
+ UseNeighborCache: useNeighborCache,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -2810,7 +2811,11 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*nd
Gateway: llAddr3,
NIC: nicID,
}})
- s.AddLinkAddress(nicID, llAddr3, linkAddr3)
+ if useNeighborCache {
+ s.AddStaticNeighbor(nicID, llAddr3, linkAddr3)
+ } else {
+ s.AddLinkAddress(nicID, llAddr3, linkAddr3)
+ }
return ndpDisp, e, s
}
@@ -2884,110 +2889,128 @@ func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullA
// TestAutoGenAddrDeprecateFromPI tests deprecating a SLAAC address when
// receiving a PI with 0 preferred lifetime.
func TestAutoGenAddrDeprecateFromPI(t *testing.T) {
- const nicID = 1
+ stacks := []struct {
+ name string
+ useNeighborCache bool
+ }{
+ {
+ name: "linkAddrCache",
+ useNeighborCache: false,
+ },
+ {
+ name: "neighborCache",
+ useNeighborCache: true,
+ },
+ }
- prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
- prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
+ for _, stackTyp := range stacks {
+ t.Run(stackTyp.name, func(t *testing.T) {
+ const nicID = 1
- ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
+ prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
+ prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
- expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
- t.Helper()
+ ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID, stackTyp.useNeighborCache)
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
}
- default:
- t.Fatal("expected addr auto gen event")
- }
- }
- expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
- t.Helper()
+ expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
+ t.Helper()
- if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
- t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
- } else if got != addr {
- t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
- }
+ if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
+ } else if got != addr {
+ t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
+ }
- if got := addrForNewConnection(t, s); got != addr.Address {
- t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
- }
- }
+ if got := addrForNewConnection(t, s); got != addr.Address {
+ t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
+ }
+ }
- // Receive PI for prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
- expectAutoGenAddrEvent(addr1, newAddr)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should have %s in the list of addresses", addr1)
- }
- expectPrimaryAddr(addr1)
+ // Receive PI for prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
+ expectAutoGenAddrEvent(addr1, newAddr)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should have %s in the list of addresses", addr1)
+ }
+ expectPrimaryAddr(addr1)
- // Deprecate addr for prefix1 immedaitely.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
- expectAutoGenAddrEvent(addr1, deprecatedAddr)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should have %s in the list of addresses", addr1)
- }
- // addr should still be the primary endpoint as there are no other addresses.
- expectPrimaryAddr(addr1)
+ // Deprecate addr for prefix1 immedaitely.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr1, deprecatedAddr)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should have %s in the list of addresses", addr1)
+ }
+ // addr should still be the primary endpoint as there are no other addresses.
+ expectPrimaryAddr(addr1)
- // Refresh lifetimes of addr generated from prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- default:
- }
- expectPrimaryAddr(addr1)
+ // Refresh lifetimes of addr generated from prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ expectPrimaryAddr(addr1)
- // Receive PI for prefix2.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
- expectAutoGenAddrEvent(addr2, newAddr)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
- expectPrimaryAddr(addr2)
+ // Receive PI for prefix2.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ expectPrimaryAddr(addr2)
- // Deprecate addr for prefix2 immedaitely.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
- expectAutoGenAddrEvent(addr2, deprecatedAddr)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
- // addr1 should be the primary endpoint now since addr2 is deprecated but
- // addr1 is not.
- expectPrimaryAddr(addr1)
- // addr2 is deprecated but if explicitly requested, it should be used.
- fullAddr2 := tcpip.FullAddress{Addr: addr2.Address, NIC: nicID}
- if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address {
- t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address)
- }
+ // Deprecate addr for prefix2 immedaitely.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr2, deprecatedAddr)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ // addr1 should be the primary endpoint now since addr2 is deprecated but
+ // addr1 is not.
+ expectPrimaryAddr(addr1)
+ // addr2 is deprecated but if explicitly requested, it should be used.
+ fullAddr2 := tcpip.FullAddress{Addr: addr2.Address, NIC: nicID}
+ if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address {
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address)
+ }
- // Another PI w/ 0 preferred lifetime should not result in a deprecation
- // event.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- default:
- }
- expectPrimaryAddr(addr1)
- if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address {
- t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address)
- }
+ // Another PI w/ 0 preferred lifetime should not result in a deprecation
+ // event.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ expectPrimaryAddr(addr1)
+ if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address {
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address)
+ }
- // Refresh lifetimes of addr generated from prefix2.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- default:
+ // Refresh lifetimes of addr generated from prefix2.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ expectPrimaryAddr(addr2)
+ })
}
- expectPrimaryAddr(addr2)
}
// TestAutoGenAddrJobDeprecation tests that an address is properly deprecated
@@ -2996,217 +3019,236 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
const nicID = 1
const newMinVL = 2
newMinVLDuration := newMinVL * time.Second
- saved := stack.MinPrefixInformationValidLifetimeForUpdate
- defer func() {
- stack.MinPrefixInformationValidLifetimeForUpdate = saved
- }()
- stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
- prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
- prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
+ stacks := []struct {
+ name string
+ useNeighborCache bool
+ }{
+ {
+ name: "linkAddrCache",
+ useNeighborCache: false,
+ },
+ {
+ name: "neighborCache",
+ useNeighborCache: true,
+ },
+ }
- ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
+ for _, stackTyp := range stacks {
+ t.Run(stackTyp.name, func(t *testing.T) {
+ saved := stack.MinPrefixInformationValidLifetimeForUpdate
+ defer func() {
+ stack.MinPrefixInformationValidLifetimeForUpdate = saved
+ }()
+ stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
- expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
- t.Helper()
+ prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
+ prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID, stackTyp.useNeighborCache)
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
}
- default:
- t.Fatal("expected addr auto gen event")
- }
- }
- expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) {
- t.Helper()
+ expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) {
+ t.Helper()
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(timeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
}
- case <-time.After(timeout):
- t.Fatal("timed out waiting for addr auto gen event")
- }
- }
- expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
- t.Helper()
+ expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
+ t.Helper()
- if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
- t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
- } else if got != addr {
- t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
- }
+ if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
+ } else if got != addr {
+ t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
+ }
- if got := addrForNewConnection(t, s); got != addr.Address {
- t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
- }
- }
+ if got := addrForNewConnection(t, s); got != addr.Address {
+ t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
+ }
+ }
- // Receive PI for prefix2.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
- expectAutoGenAddrEvent(addr2, newAddr)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
- expectPrimaryAddr(addr2)
+ // Receive PI for prefix2.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ expectPrimaryAddr(addr2)
- // Receive a PI for prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 90))
- expectAutoGenAddrEvent(addr1, newAddr)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should have %s in the list of addresses", addr1)
- }
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
- expectPrimaryAddr(addr1)
+ // Receive a PI for prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 90))
+ expectAutoGenAddrEvent(addr1, newAddr)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should have %s in the list of addresses", addr1)
+ }
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ expectPrimaryAddr(addr1)
- // Refresh lifetime for addr of prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- default:
- }
- expectPrimaryAddr(addr1)
+ // Refresh lifetime for addr of prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ expectPrimaryAddr(addr1)
- // Wait for addr of prefix1 to be deprecated.
- expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should not have %s in the list of addresses", addr1)
- }
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
- // addr2 should be the primary endpoint now since addr1 is deprecated but
- // addr2 is not.
- expectPrimaryAddr(addr2)
- // addr1 is deprecated but if explicitly requested, it should be used.
- fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID}
- if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
- t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address)
- }
+ // Wait for addr of prefix1 to be deprecated.
+ expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should not have %s in the list of addresses", addr1)
+ }
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ // addr2 should be the primary endpoint now since addr1 is deprecated but
+ // addr2 is not.
+ expectPrimaryAddr(addr2)
+ // addr1 is deprecated but if explicitly requested, it should be used.
+ fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID}
+ if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address)
+ }
- // Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make
- // sure we do not get a deprecation event again.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- default:
- }
- expectPrimaryAddr(addr2)
- if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
- t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address)
- }
+ // Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make
+ // sure we do not get a deprecation event again.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ expectPrimaryAddr(addr2)
+ if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address)
+ }
- // Refresh lifetimes for addr of prefix1.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- default:
- }
- // addr1 is the primary endpoint again since it is non-deprecated now.
- expectPrimaryAddr(addr1)
+ // Refresh lifetimes for addr of prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ // addr1 is the primary endpoint again since it is non-deprecated now.
+ expectPrimaryAddr(addr1)
- // Wait for addr of prefix1 to be deprecated.
- expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout)
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should not have %s in the list of addresses", addr1)
- }
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
- // addr2 should be the primary endpoint now since it is not deprecated.
- expectPrimaryAddr(addr2)
- if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
- t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address)
- }
+ // Wait for addr of prefix1 to be deprecated.
+ expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should not have %s in the list of addresses", addr1)
+ }
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ // addr2 should be the primary endpoint now since it is not deprecated.
+ expectPrimaryAddr(addr2)
+ if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address)
+ }
- // Wait for addr of prefix1 to be invalidated.
- expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncPositiveEventTimeout)
- if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should not have %s in the list of addresses", addr1)
- }
- if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should have %s in the list of addresses", addr2)
- }
- expectPrimaryAddr(addr2)
+ // Wait for addr of prefix1 to be invalidated.
+ expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncPositiveEventTimeout)
+ if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should not have %s in the list of addresses", addr1)
+ }
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ expectPrimaryAddr(addr2)
- // Refresh both lifetimes for addr of prefix2 to the same value.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, newMinVL, newMinVL))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- default:
- }
+ // Refresh both lifetimes for addr of prefix2 to the same value.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, newMinVL, newMinVL))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
- // Wait for a deprecation then invalidation events, or just an invalidation
- // event. We need to cover both cases but cannot deterministically hit both
- // cases because the deprecation and invalidation handlers could be handled in
- // either deprecation then invalidation, or invalidation then deprecation
- // (which should be cancelled by the invalidation handler).
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr2, deprecatedAddr); diff == "" {
- // If we get a deprecation event first, we should get an invalidation
- // event almost immediately after.
+ // Wait for a deprecation then invalidation events, or just an invalidation
+ // event. We need to cover both cases but cannot deterministically hit both
+ // cases because the deprecation and invalidation handlers could be handled in
+ // either deprecation then invalidation, or invalidation then deprecation
+ // (which should be cancelled by the invalidation handler).
select {
case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ if diff := checkAutoGenAddrEvent(e, addr2, deprecatedAddr); diff == "" {
+ // If we get a deprecation event first, we should get an invalidation
+ // event almost immediately after.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" {
+ // If we get an invalidation event first, we should not get a deprecation
+ // event after.
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ case <-time.After(defaultAsyncNegativeEventTimeout):
+ }
+ } else {
+ t.Fatalf("got unexpected auto-generated event")
}
- case <-time.After(defaultAsyncPositiveEventTimeout):
+ case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for addr auto gen event")
}
- } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" {
- // If we get an invalidation event first, we should not get a deprecation
- // event after.
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto-generated event")
- case <-time.After(defaultAsyncNegativeEventTimeout):
+ if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should not have %s in the list of addresses", addr1)
+ }
+ if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should not have %s in the list of addresses", addr2)
+ }
+ // Should not have any primary endpoints.
+ if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
+ } else if want := (tcpip.AddressWithPrefix{}); got != want {
+ t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, want)
+ }
+ wq := waiter.Queue{}
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ defer close(ch)
+ ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
+ }
+ defer ep.Close()
+ if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
+ t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
}
- } else {
- t.Fatalf("got unexpected auto-generated event")
- }
- case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout):
- t.Fatal("timed out waiting for addr auto gen event")
- }
- if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
- t.Fatalf("should not have %s in the list of addresses", addr1)
- }
- if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
- t.Fatalf("should not have %s in the list of addresses", addr2)
- }
- // Should not have any primary endpoints.
- if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
- t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
- } else if want := (tcpip.AddressWithPrefix{}); got != want {
- t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, want)
- }
- wq := waiter.Queue{}
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.EventIn)
- defer wq.EventUnregister(&we)
- defer close(ch)
- ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq)
- if err != nil {
- t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
- }
- defer ep.Close()
- if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
- t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
- }
- if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute {
- t.Errorf("got ep.Connect(%+v) = %v, want = %s", dstAddr, err, tcpip.ErrNoRoute)
+ if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute {
+ t.Errorf("got ep.Connect(%+v) = %s, want = %s", dstAddr, err, tcpip.ErrNoRoute)
+ }
+ })
}
}
@@ -3524,110 +3566,128 @@ func TestAutoGenAddrRemoval(t *testing.T) {
func TestAutoGenAddrAfterRemoval(t *testing.T) {
const nicID = 1
- prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
- prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
- ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
-
- expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
- t.Helper()
-
- select {
- case e := <-ndpDisp.autoGenAddrC:
- if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
- t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
- }
- default:
- t.Fatal("expected addr auto gen event")
- }
+ stacks := []struct {
+ name string
+ useNeighborCache bool
+ }{
+ {
+ name: "linkAddrCache",
+ useNeighborCache: false,
+ },
+ {
+ name: "neighborCache",
+ useNeighborCache: true,
+ },
}
- expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
- t.Helper()
+ for _, stackTyp := range stacks {
+ t.Run(stackTyp.name, func(t *testing.T) {
+ prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
+ prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
+ ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID, stackTyp.useNeighborCache)
- if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
- t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
- } else if got != addr {
- t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
- }
-
- if got := addrForNewConnection(t, s); got != addr.Address {
- t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
- }
- }
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
- // Receive a PI to auto-generate addr1 with a large valid and preferred
- // lifetime.
- const largeLifetimeSeconds = 999
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix1, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
- expectAutoGenAddrEvent(addr1, newAddr)
- expectPrimaryAddr(addr1)
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
- // Add addr2 as a static address.
- protoAddr2 := tcpip.ProtocolAddress{
- Protocol: header.IPv6ProtocolNumber,
- AddressWithPrefix: addr2,
- }
- if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil {
- t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err)
- }
- // addr2 should be more preferred now since it is at the front of the primary
- // list.
- expectPrimaryAddr(addr2)
+ expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
+ t.Helper()
- // Get a route using addr2 to increment its reference count then remove it
- // to leave it in the permanentExpired state.
- r, err := s.FindRoute(nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, false)
- if err != nil {
- t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, err)
- }
- defer r.Release()
- if err := s.RemoveAddress(nicID, addr2.Address); err != nil {
- t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, addr2.Address, err)
- }
- // addr1 should be preferred again since addr2 is in the expired state.
- expectPrimaryAddr(addr1)
+ if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
+ } else if got != addr {
+ t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
+ }
- // Receive a PI to auto-generate addr2 as valid and preferred.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
- expectAutoGenAddrEvent(addr2, newAddr)
- // addr2 should be more preferred now that it is closer to the front of the
- // primary list and not deprecated.
- expectPrimaryAddr(addr2)
+ if got := addrForNewConnection(t, s); got != addr.Address {
+ t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
+ }
+ }
- // Removing the address should result in an invalidation event immediately.
- // It should still be in the permanentExpired state because r is still held.
- //
- // We remove addr2 here to make sure addr2 was marked as a SLAAC address
- // (it was previously marked as a static address).
- if err := s.RemoveAddress(1, addr2.Address); err != nil {
- t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err)
- }
- expectAutoGenAddrEvent(addr2, invalidatedAddr)
- // addr1 should be more preferred since addr2 is in the expired state.
- expectPrimaryAddr(addr1)
+ // Receive a PI to auto-generate addr1 with a large valid and preferred
+ // lifetime.
+ const largeLifetimeSeconds = 999
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix1, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
+ expectAutoGenAddrEvent(addr1, newAddr)
+ expectPrimaryAddr(addr1)
- // Receive a PI to auto-generate addr2 as valid and deprecated.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, 0))
- expectAutoGenAddrEvent(addr2, newAddr)
- // addr1 should still be more preferred since addr2 is deprecated, even though
- // it is closer to the front of the primary list.
- expectPrimaryAddr(addr1)
+ // Add addr2 as a static address.
+ protoAddr2 := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addr2,
+ }
+ if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil {
+ t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err)
+ }
+ // addr2 should be more preferred now since it is at the front of the primary
+ // list.
+ expectPrimaryAddr(addr2)
- // Receive a PI to refresh addr2's preferred lifetime.
- e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Fatal("unexpectedly got an auto gen addr event")
- default:
- }
- // addr2 should be more preferred now that it is not deprecated.
- expectPrimaryAddr(addr2)
+ // Get a route using addr2 to increment its reference count then remove it
+ // to leave it in the permanentExpired state.
+ r, err := s.FindRoute(nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, false)
+ if err != nil {
+ t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, err)
+ }
+ defer r.Release()
+ if err := s.RemoveAddress(nicID, addr2.Address); err != nil {
+ t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, addr2.Address, err)
+ }
+ // addr1 should be preferred again since addr2 is in the expired state.
+ expectPrimaryAddr(addr1)
+
+ // Receive a PI to auto-generate addr2 as valid and preferred.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ // addr2 should be more preferred now that it is closer to the front of the
+ // primary list and not deprecated.
+ expectPrimaryAddr(addr2)
+
+ // Removing the address should result in an invalidation event immediately.
+ // It should still be in the permanentExpired state because r is still held.
+ //
+ // We remove addr2 here to make sure addr2 was marked as a SLAAC address
+ // (it was previously marked as a static address).
+ if err := s.RemoveAddress(1, addr2.Address); err != nil {
+ t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err)
+ }
+ expectAutoGenAddrEvent(addr2, invalidatedAddr)
+ // addr1 should be more preferred since addr2 is in the expired state.
+ expectPrimaryAddr(addr1)
+
+ // Receive a PI to auto-generate addr2 as valid and deprecated.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, 0))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ // addr1 should still be more preferred since addr2 is deprecated, even though
+ // it is closer to the front of the primary list.
+ expectPrimaryAddr(addr1)
+
+ // Receive a PI to refresh addr2's preferred lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto gen addr event")
+ default:
+ }
+ // addr2 should be more preferred now that it is not deprecated.
+ expectPrimaryAddr(addr2)
- if err := s.RemoveAddress(1, addr2.Address); err != nil {
- t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err)
+ if err := s.RemoveAddress(1, addr2.Address); err != nil {
+ t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err)
+ }
+ expectAutoGenAddrEvent(addr2, invalidatedAddr)
+ expectPrimaryAddr(addr1)
+ })
}
- expectAutoGenAddrEvent(addr2, invalidatedAddr)
- expectPrimaryAddr(addr1)
}
// TestAutoGenAddrStaticConflict tests that if SLAAC generates an address that
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index e74d2562a..be274773c 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -21,6 +21,7 @@ import (
"sort"
"sync/atomic"
+ "gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -135,24 +136,33 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
}
nic.mu.ndp.initializeTempAddrState()
- // Register supported packet endpoint protocols.
- for _, netProto := range header.Ethertypes {
- nic.mu.packetEPs[netProto] = []PacketEndpoint{}
- }
- for _, netProto := range stack.networkProtocols {
- netNum := netProto.Number()
- nic.mu.packetEPs[netNum] = nil
- nic.networkEndpoints[netNum] = netProto.NewEndpoint(id, stack, nic, ep, stack)
- }
-
// Check for Neighbor Unreachability Detection support.
- if ep.Capabilities()&CapabilityResolutionRequired != 0 && len(stack.linkAddrResolvers) != 0 {
+ var nud NUDHandler
+ if ep.Capabilities()&CapabilityResolutionRequired != 0 && len(stack.linkAddrResolvers) != 0 && stack.useNeighborCache {
rng := rand.New(rand.NewSource(stack.clock.NowNanoseconds()))
nic.neigh = &neighborCache{
nic: nic,
state: NewNUDState(stack.nudConfigs, rng),
cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize),
}
+
+ // An interface value that holds a nil pointer but non-nil type is not the
+ // same as the nil interface. Because of this, nud must only be assignd if
+ // nic.neigh is non-nil since a nil reference to a neighborCache is not
+ // valid.
+ //
+ // See https://golang.org/doc/faq#nil_error for more information.
+ nud = nic.neigh
+ }
+
+ // Register supported packet and network endpoint protocols.
+ for _, netProto := range header.Ethertypes {
+ nic.mu.packetEPs[netProto] = []PacketEndpoint{}
+ }
+ for _, netProto := range stack.networkProtocols {
+ netNum := netProto.Number()
+ nic.mu.packetEPs[netNum] = nil
+ nic.networkEndpoints[netNum] = netProto.NewEndpoint(id, stack, nud, nic, ep, stack)
}
nic.linkEP.Attach(nic)
@@ -431,7 +441,7 @@ func (n *NIC) setSpoofing(enable bool) {
// If an IPv6 primary endpoint is requested, Source Address Selection (as
// defined by RFC 6724 section 5) will be performed.
func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) *referencedNetworkEndpoint {
- if protocol == header.IPv6ProtocolNumber && remoteAddr != "" {
+ if protocol == header.IPv6ProtocolNumber && len(remoteAddr) != 0 {
return n.primaryIPv6Endpoint(remoteAddr)
}
@@ -655,22 +665,15 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
}
}
- // Check if address is a broadcast address for the endpoint's network.
- //
- // Only IPv4 has a notion of broadcast addresses.
if protocol == header.IPv4ProtocolNumber {
- if ref := n.getRefForBroadcastRLocked(address); ref != nil {
+ if ref := n.getIPv4RefForBroadcastOrLoopbackRLocked(address); ref != nil {
n.mu.RUnlock()
return ref
}
}
-
- // A usable reference was not found, create a temporary one if requested by
- // the caller or if the address is found in the NIC's subnets.
- createTempEP := spoofingOrPromiscuous
n.mu.RUnlock()
- if !createTempEP {
+ if !spoofingOrPromiscuous {
return nil
}
@@ -683,20 +686,21 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
return ref
}
-// getRefForBroadcastLocked returns an endpoint where address is the IPv4
-// broadcast address for the endpoint's network.
+// getRefForBroadcastOrLoopbackRLocked returns an endpoint whose address is the
+// broadcast address for the endpoint's network or an address in the endpoint's
+// subnet if the NIC is a loopback interface. This matches linux behaviour.
//
-// n.mu MUST be read locked.
-func (n *NIC) getRefForBroadcastRLocked(address tcpip.Address) *referencedNetworkEndpoint {
+// n.mu MUST be read or write locked.
+func (n *NIC) getIPv4RefForBroadcastOrLoopbackRLocked(address tcpip.Address) *referencedNetworkEndpoint {
for _, ref := range n.mu.endpoints {
- // Only IPv4 has a notion of broadcast addresses.
+ // Only IPv4 has a notion of broadcast addresses or considers the loopback
+ // interface bound to an address's whole subnet (on linux).
if ref.protocol != header.IPv4ProtocolNumber {
continue
}
- addr := ref.addrWithPrefix()
- subnet := addr.Subnet()
- if subnet.IsBroadcast(address) && ref.tryIncRef() {
+ subnet := ref.addrWithPrefix().Subnet()
+ if (subnet.IsBroadcast(address) || (n.isLoopback() && subnet.Contains(address))) && ref.isValidForOutgoingRLocked() && ref.tryIncRef() {
return ref
}
}
@@ -724,11 +728,8 @@ func (n *NIC) getRefOrCreateTempLocked(protocol tcpip.NetworkProtocolNumber, add
n.removeEndpointLocked(ref)
}
- // Check if address is a broadcast address for an endpoint's network.
- //
- // Only IPv4 has a notion of broadcast addresses.
if protocol == header.IPv4ProtocolNumber {
- if ref := n.getRefForBroadcastRLocked(address); ref != nil {
+ if ref := n.getIPv4RefForBroadcastOrLoopbackRLocked(address); ref != nil {
return ref
}
}
@@ -833,10 +834,11 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
deprecated: deprecated,
}
- // Set up cache if link address resolution exists for this protocol.
+ // Set up resolver if link address resolution exists for this protocol.
if n.linkEP.Capabilities()&CapabilityResolutionRequired != 0 {
- if _, ok := n.stack.linkAddrResolvers[protocolAddress.Protocol]; ok {
+ if linkRes, ok := n.stack.linkAddrResolvers[protocolAddress.Protocol]; ok {
ref.linkCache = n.stack
+ ref.linkRes = linkRes
}
}
@@ -1071,6 +1073,51 @@ func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error {
return n.removePermanentAddressLocked(addr)
}
+func (n *NIC) neighbors() ([]NeighborEntry, *tcpip.Error) {
+ if n.neigh == nil {
+ return nil, tcpip.ErrNotSupported
+ }
+
+ return n.neigh.entries(), nil
+}
+
+func (n *NIC) removeWaker(addr tcpip.Address, w *sleep.Waker) {
+ if n.neigh == nil {
+ return
+ }
+
+ n.neigh.removeWaker(addr, w)
+}
+
+func (n *NIC) addStaticNeighbor(addr tcpip.Address, linkAddress tcpip.LinkAddress) *tcpip.Error {
+ if n.neigh == nil {
+ return tcpip.ErrNotSupported
+ }
+
+ n.neigh.addStaticEntry(addr, linkAddress)
+ return nil
+}
+
+func (n *NIC) removeNeighbor(addr tcpip.Address) *tcpip.Error {
+ if n.neigh == nil {
+ return tcpip.ErrNotSupported
+ }
+
+ if !n.neigh.removeEntry(addr) {
+ return tcpip.ErrBadAddress
+ }
+ return nil
+}
+
+func (n *NIC) clearNeighbors() *tcpip.Error {
+ if n.neigh == nil {
+ return tcpip.ErrNotSupported
+ }
+
+ n.neigh.clear()
+ return nil
+}
+
// joinGroup adds a new endpoint for the given multicast address, if none
// exists yet. Otherwise it just increments its count.
func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
@@ -1235,14 +1282,14 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
return
}
- // TODO(gvisor.dev/issue/170): Not supporting iptables for IPv6 yet.
// Loopback traffic skips the prerouting chain.
- if protocol == header.IPv4ProtocolNumber && !n.isLoopback() {
+ if !n.isLoopback() {
// iptables filtering.
ipt := n.stack.IPTables()
address := n.primaryAddress(protocol)
if ok := ipt.Check(Prerouting, pkt, nil, nil, address.Address, ""); !ok {
// iptables is telling us to drop the packet.
+ n.stack.stats.IP.IPTablesPreroutingDropped.Increment()
return
}
}
@@ -1652,6 +1699,10 @@ type referencedNetworkEndpoint struct {
// protocol. Set to nil otherwise.
linkCache LinkAddressCache
+ // linkRes is set if link address resolution is enabled for this protocol.
+ // Set to nil otherwise.
+ linkRes LinkAddressResolver
+
// refs is counting references held for this endpoint. When refs hits zero it
// triggers the automatic removal of the endpoint from the NIC.
refs int32
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index d312a79eb..dd6474297 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -192,7 +192,7 @@ func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address)
}
// NewEndpoint implements NetworkProtocol.NewEndpoint.
-func (p *testIPv6Protocol) NewEndpoint(nicID tcpip.NICID, _ LinkAddressCache, _ TransportDispatcher, linkEP LinkEndpoint, _ *Stack) NetworkEndpoint {
+func (p *testIPv6Protocol) NewEndpoint(nicID tcpip.NICID, _ LinkAddressCache, _ NUDHandler, _ TransportDispatcher, linkEP LinkEndpoint, _ *Stack) NetworkEndpoint {
return &testIPv6Endpoint{
nicID: nicID,
linkEP: linkEP,
@@ -201,12 +201,12 @@ func (p *testIPv6Protocol) NewEndpoint(nicID tcpip.NICID, _ LinkAddressCache, _
}
// SetOption implements NetworkProtocol.SetOption.
-func (*testIPv6Protocol) SetOption(interface{}) *tcpip.Error {
+func (*testIPv6Protocol) SetOption(tcpip.SettableNetworkProtocolOption) *tcpip.Error {
return nil
}
// Option implements NetworkProtocol.Option.
-func (*testIPv6Protocol) Option(interface{}) *tcpip.Error {
+func (*testIPv6Protocol) Option(tcpip.GettableNetworkProtocolOption) *tcpip.Error {
return nil
}
diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go
index 2494ee610..2b97e5972 100644
--- a/pkg/tcpip/stack/nud_test.go
+++ b/pkg/tcpip/stack/nud_test.go
@@ -61,6 +61,7 @@ func TestSetNUDConfigurationFailsForBadNICID(t *testing.T) {
// stack will only allocate neighbor caches if a protocol providing link
// address resolution is specified (e.g. ARP or IPv6).
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ UseNeighborCache: true,
})
// No NIC with ID 1 yet.
@@ -84,7 +85,8 @@ func TestNUDConfigurationFailsForNotSupported(t *testing.T) {
e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
s := stack.New(stack.Options{
- NUDConfigs: stack.DefaultNUDConfigurations(),
+ NUDConfigs: stack.DefaultNUDConfigurations(),
+ UseNeighborCache: true,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -108,7 +110,8 @@ func TestSetNUDConfigurationFailsForNotSupported(t *testing.T) {
e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
s := stack.New(stack.Options{
- NUDConfigs: stack.DefaultNUDConfigurations(),
+ NUDConfigs: stack.DefaultNUDConfigurations(),
+ UseNeighborCache: true,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -136,6 +139,7 @@ func TestDefaultNUDConfigurations(t *testing.T) {
// address resolution is specified (e.g. ARP or IPv6).
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
NUDConfigs: stack.DefaultNUDConfigurations(),
+ UseNeighborCache: true,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -190,6 +194,7 @@ func TestNUDConfigurationsBaseReachableTime(t *testing.T) {
// providing link address resolution is specified (e.g. ARP or IPv6).
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
NUDConfigs: c,
+ UseNeighborCache: true,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -246,6 +251,7 @@ func TestNUDConfigurationsMinRandomFactor(t *testing.T) {
// providing link address resolution is specified (e.g. ARP or IPv6).
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
NUDConfigs: c,
+ UseNeighborCache: true,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -325,6 +331,7 @@ func TestNUDConfigurationsMaxRandomFactor(t *testing.T) {
// providing link address resolution is specified (e.g. ARP or IPv6).
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
NUDConfigs: c,
+ UseNeighborCache: true,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -386,6 +393,7 @@ func TestNUDConfigurationsRetransmitTimer(t *testing.T) {
// providing link address resolution is specified (e.g. ARP or IPv6).
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
NUDConfigs: c,
+ UseNeighborCache: true,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -437,6 +445,7 @@ func TestNUDConfigurationsDelayFirstProbeTime(t *testing.T) {
// providing link address resolution is specified (e.g. ARP or IPv6).
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
NUDConfigs: c,
+ UseNeighborCache: true,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -488,6 +497,7 @@ func TestNUDConfigurationsMaxMulticastProbes(t *testing.T) {
// providing link address resolution is specified (e.g. ARP or IPv6).
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
NUDConfigs: c,
+ UseNeighborCache: true,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -539,6 +549,7 @@ func TestNUDConfigurationsMaxUnicastProbes(t *testing.T) {
// providing link address resolution is specified (e.g. ARP or IPv6).
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
NUDConfigs: c,
+ UseNeighborCache: true,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
@@ -590,6 +601,7 @@ func TestNUDConfigurationsUnreachableTime(t *testing.T) {
// providing link address resolution is specified (e.g. ARP or IPv6).
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
NUDConfigs: c,
+ UseNeighborCache: true,
})
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 17b8beebb..1932aaeb7 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -80,7 +80,7 @@ type PacketBuffer struct {
// data are held in the same underlying buffer storage.
header buffer.Prependable
- // NetworkProtocol is only valid when NetworkHeader is set.
+ // NetworkProtocolNumber is only valid when NetworkHeader is set.
// TODO(gvisor.dev/issue/3574): Remove the separately passed protocol
// numbers in registration APIs that take a PacketBuffer.
NetworkProtocolNumber tcpip.NetworkProtocolNumber
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index aca2f77f8..4fa86a3ac 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -159,12 +159,12 @@ type TransportProtocol interface {
// SetOption allows enabling/disabling protocol specific features.
// SetOption returns an error if the option is not supported or the
// provided option value is invalid.
- SetOption(option interface{}) *tcpip.Error
+ SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error
// Option allows retrieving protocol specific option values.
// Option returns an error if the option is not supported or the
// provided option value is invalid.
- Option(option interface{}) *tcpip.Error
+ Option(option tcpip.GettableTransportProtocolOption) *tcpip.Error
// Close requests that any worker goroutines owned by the protocol
// stop.
@@ -298,17 +298,17 @@ type NetworkProtocol interface {
ParseAddresses(v buffer.View) (src, dst tcpip.Address)
// NewEndpoint creates a new endpoint of this protocol.
- NewEndpoint(nicID tcpip.NICID, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint, st *Stack) NetworkEndpoint
+ NewEndpoint(nicID tcpip.NICID, linkAddrCache LinkAddressCache, nud NUDHandler, dispatcher TransportDispatcher, sender LinkEndpoint, st *Stack) NetworkEndpoint
// SetOption allows enabling/disabling protocol specific features.
// SetOption returns an error if the option is not supported or the
// provided option value is invalid.
- SetOption(option interface{}) *tcpip.Error
+ SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error
// Option allows retrieving protocol specific option values.
// Option returns an error if the option is not supported or the
// provided option value is invalid.
- Option(option interface{}) *tcpip.Error
+ Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error
// Close requests that any worker goroutines owned by the protocol
// stop.
@@ -488,7 +488,7 @@ type LinkAddressResolver interface {
ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool)
// LinkAddressProtocol returns the network protocol of the
- // addresses this this resolver can resolve.
+ // addresses this resolver can resolve.
LinkAddressProtocol() tcpip.NetworkProtocolNumber
}
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index e267bebb0..2cbbf0de8 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -48,10 +48,6 @@ type Route struct {
// Loop controls where WritePacket should send packets.
Loop PacketLooping
-
- // directedBroadcast indicates whether this route is sending a directed
- // broadcast packet.
- directedBroadcast bool
}
// makeRoute initializes a new route. It takes ownership of the provided
@@ -141,6 +137,16 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
}
nextAddr = r.RemoteAddress
}
+
+ if r.ref.nic.neigh != nil {
+ entry, ch, err := r.ref.nic.neigh.entry(nextAddr, r.LocalAddress, r.ref.linkRes, waker)
+ if err != nil {
+ return ch, err
+ }
+ r.RemoteLinkAddress = entry.LinkAddr
+ return nil, nil
+ }
+
linkAddr, ch, err := r.ref.linkCache.GetLinkAddress(r.ref.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker)
if err != nil {
return ch, err
@@ -155,6 +161,12 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) {
if nextAddr == "" {
nextAddr = r.RemoteAddress
}
+
+ if r.ref.nic.neigh != nil {
+ r.ref.nic.neigh.removeWaker(nextAddr, waker)
+ return
+ }
+
r.ref.linkCache.RemoveWaker(r.ref.nic.ID(), nextAddr, waker)
}
@@ -163,6 +175,9 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) {
//
// The NIC r uses must not be locked.
func (r *Route) IsResolutionRequired() bool {
+ if r.ref.nic.neigh != nil {
+ return r.ref.isValidForOutgoing() && r.ref.linkRes != nil && r.RemoteLinkAddress == ""
+ }
return r.ref.isValidForOutgoing() && r.ref.linkCache != nil && r.RemoteLinkAddress == ""
}
@@ -284,24 +299,27 @@ func (r *Route) Stack() *Stack {
return r.ref.stack()
}
+func (r *Route) isV4Broadcast(addr tcpip.Address) bool {
+ if addr == header.IPv4Broadcast {
+ return true
+ }
+
+ subnet := r.ref.addrWithPrefix().Subnet()
+ return subnet.IsBroadcast(addr)
+}
+
// IsOutboundBroadcast returns true if the route is for an outbound broadcast
// packet.
func (r *Route) IsOutboundBroadcast() bool {
// Only IPv4 has a notion of broadcast.
- return r.directedBroadcast || r.RemoteAddress == header.IPv4Broadcast
+ return r.isV4Broadcast(r.RemoteAddress)
}
// IsInboundBroadcast returns true if the route is for an inbound broadcast
// packet.
func (r *Route) IsInboundBroadcast() bool {
// Only IPv4 has a notion of broadcast.
- if r.LocalAddress == header.IPv4Broadcast {
- return true
- }
-
- addr := r.ref.addrWithPrefix()
- subnet := addr.Subnet()
- return subnet.IsBroadcast(r.LocalAddress)
+ return r.isV4Broadcast(r.LocalAddress)
}
// ReverseRoute returns new route with given source and destination address.
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 814b3e94a..68cf77de2 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -22,7 +22,6 @@ package stack
import (
"bytes"
"encoding/binary"
- "math"
mathrand "math/rand"
"sync/atomic"
"time"
@@ -51,41 +50,6 @@ const (
DefaultTOS = 0
)
-const (
- // fakeNetNumber is used as a protocol number in tests.
- //
- // This constant should match fakeNetNumber in stack_test.go.
- fakeNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32
-)
-
-type forwardingFlag uint32
-
-// Packet forwarding flags. Forwarding settings for different network protocols
-// are stored as bit flags in an uint32 number.
-const (
- forwardingIPv4 forwardingFlag = 1 << iota
- forwardingIPv6
-
- // forwardingFake is used to test package forwarding with a fake protocol.
- forwardingFake
-)
-
-func getForwardingFlag(protocol tcpip.NetworkProtocolNumber) forwardingFlag {
- var flag forwardingFlag
- switch protocol {
- case header.IPv4ProtocolNumber:
- flag = forwardingIPv4
- case header.IPv6ProtocolNumber:
- flag = forwardingIPv6
- case fakeNetNumber:
- // This network protocol number is used to test packet forwarding.
- flag = forwardingFake
- default:
- // We only support forwarding for IPv4 and IPv6.
- }
- return flag
-}
-
type transportProtocolState struct {
proto TransportProtocol
defaultHandler func(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool
@@ -284,7 +248,7 @@ type RcvBufAutoTuneParams struct {
// was started.
MeasureTime time.Time
- // CopiedBytes is the number of bytes copied to userspace since
+ // CopiedBytes is the number of bytes copied to user space since
// this measure began.
CopiedBytes int
@@ -441,6 +405,13 @@ type Stack struct {
networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol
linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver
+ // forwarding contains the whether packet forwarding is enabled or not for
+ // different network protocols.
+ forwarding struct {
+ sync.RWMutex
+ protocols map[tcpip.NetworkProtocolNumber]bool
+ }
+
// rawFactory creates raw endpoints. If nil, raw endpoints are
// disabled. It is set during Stack creation and is immutable.
rawFactory RawFactory
@@ -454,14 +425,9 @@ type Stack struct {
mu sync.RWMutex
nics map[tcpip.NICID]*NIC
- // forwarding contains the enable bits for packet forwarding for different
- // network protocols.
- forwarding struct {
- sync.RWMutex
- flag forwardingFlag
- }
-
- cleanupEndpoints map[TransportEndpoint]struct{}
+ // cleanupEndpointsMu protects cleanupEndpoints.
+ cleanupEndpointsMu sync.Mutex
+ cleanupEndpoints map[TransportEndpoint]struct{}
// route is the route table passed in by the user via SetRouteTable(),
// it is used by FindRoute() to build a route for a specific
@@ -472,7 +438,7 @@ type Stack struct {
// If not nil, then any new endpoints will have this probe function
// invoked everytime they receive a TCP segment.
- tcpProbeFunc TCPProbeFunc
+ tcpProbeFunc atomic.Value // TCPProbeFunc
// clock is used to generate user-visible times.
clock tcpip.Clock
@@ -504,6 +470,10 @@ type Stack struct {
// nudConfigs is the default NUD configurations used by interfaces.
nudConfigs NUDConfigurations
+ // useNeighborCache indicates whether ARP and NDP packets should be handled
+ // by the NIC's neighborCache instead of linkAddrCache.
+ useNeighborCache bool
+
// autoGenIPv6LinkLocal determines whether or not the stack will attempt
// to auto-generate an IPv6 link-local address for newly enabled non-loopback
// NICs. See the AutoGenIPv6LinkLocal field of Options for more details.
@@ -584,6 +554,13 @@ type Options struct {
// NUDConfigs is the default NUD configurations used by interfaces.
NUDConfigs NUDConfigurations
+ // UseNeighborCache indicates whether ARP and NDP packets should be handled
+ // by the Neighbor Unreachability Detection (NUD) state machine. This flag
+ // also enables the APIs for inspecting and modifying the neighbor table via
+ // NUDDispatcher and the following Stack methods: Neighbors, RemoveNeighbor,
+ // and ClearNeighbors.
+ UseNeighborCache bool
+
// AutoGenIPv6LinkLocal determines whether or not the stack will attempt to
// auto-generate an IPv6 link-local address for newly enabled non-loopback
// NICs.
@@ -758,6 +735,7 @@ func New(opts Options) *Stack {
seed: generateRandUint32(),
ndpConfigs: opts.NDPConfigs,
nudConfigs: opts.NUDConfigs,
+ useNeighborCache: opts.UseNeighborCache,
autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal,
uniqueIDGenerator: opts.UniqueID,
ndpDisp: opts.NDPDisp,
@@ -777,6 +755,7 @@ func New(opts Options) *Stack {
Max: DefaultMaxBufferSize,
},
}
+ s.forwarding.protocols = make(map[tcpip.NetworkProtocolNumber]bool)
// Add specified network protocols.
for _, netProto := range opts.NetworkProtocols {
@@ -816,7 +795,7 @@ func (s *Stack) UniqueID() uint64 {
// options. This method returns an error if the protocol is not supported or
// option is not supported by the protocol implementation or the provided value
// is incorrect.
-func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, option interface{}) *tcpip.Error {
+func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, option tcpip.SettableNetworkProtocolOption) *tcpip.Error {
netProto, ok := s.networkProtocols[network]
if !ok {
return tcpip.ErrUnknownProtocol
@@ -833,7 +812,7 @@ func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, op
// if err != nil {
// ...
// }
-func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, option interface{}) *tcpip.Error {
+func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, option tcpip.GettableNetworkProtocolOption) *tcpip.Error {
netProto, ok := s.networkProtocols[network]
if !ok {
return tcpip.ErrUnknownProtocol
@@ -845,7 +824,7 @@ func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, optio
// options. This method returns an error if the protocol is not supported or
// option is not supported by the protocol implementation or the provided value
// is incorrect.
-func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumber, option interface{}) *tcpip.Error {
+func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumber, option tcpip.SettableTransportProtocolOption) *tcpip.Error {
transProtoState, ok := s.transportProtocols[transport]
if !ok {
return tcpip.ErrUnknownProtocol
@@ -860,7 +839,7 @@ func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumb
// if err := s.TransportProtocolOption(tcpip.TCPProtocolNumber, &v); err != nil {
// ...
// }
-func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, option interface{}) *tcpip.Error {
+func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, option tcpip.GettableTransportProtocolOption) *tcpip.Error {
transProtoState, ok := s.transportProtocols[transport]
if !ok {
return tcpip.ErrUnknownProtocol
@@ -904,23 +883,14 @@ func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool)
return
}
- flag := getForwardingFlag(protocol)
-
// If the forwarding value for this protocol hasn't changed then do
// nothing.
- if s.forwarding.flag&getForwardingFlag(protocol) != 0 == enable {
+ if forwarding := s.forwarding.protocols[protocol]; forwarding == enable {
return
}
- var newValue forwardingFlag
- if enable {
- newValue = s.forwarding.flag | flag
- } else {
- newValue = s.forwarding.flag & ^flag
- }
- s.forwarding.flag = newValue
+ s.forwarding.protocols[protocol] = enable
- // Enable or disable NDP for IPv6.
if protocol == header.IPv6ProtocolNumber {
if enable {
for _, nic := range s.nics {
@@ -938,7 +908,7 @@ func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool)
func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool {
s.forwarding.RLock()
defer s.forwarding.RUnlock()
- return s.forwarding.flag&getForwardingFlag(protocol) != 0
+ return s.forwarding.protocols[protocol]
}
// SetRouteTable assigns the route table to be used by this stack. It
@@ -1257,8 +1227,8 @@ func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tc
s.mu.RLock()
defer s.mu.RUnlock()
- nic := s.nics[id]
- if nic == nil {
+ nic, ok := s.nics[id]
+ if !ok {
return tcpip.ErrUnknownNICID
}
@@ -1344,13 +1314,11 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
}
r := makeRoute(netProto, ref.address(), remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback())
- r.directedBroadcast = route.Destination.IsBroadcast(remoteAddr)
-
if len(route.Gateway) > 0 {
if needRoute {
r.NextHop = route.Gateway
}
- } else if r.directedBroadcast {
+ } else if subnet := ref.addrWithPrefix().Subnet(); subnet.IsBroadcast(remoteAddr) {
r.RemoteLinkAddress = header.EthernetBroadcastAddress
}
@@ -1383,8 +1351,8 @@ func (s *Stack) CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProto
// If a NIC is specified, we try to find the address there only.
if nicID != 0 {
- nic := s.nics[nicID]
- if nic == nil {
+ nic, ok := s.nics[nicID]
+ if !ok {
return 0
}
@@ -1415,8 +1383,8 @@ func (s *Stack) SetPromiscuousMode(nicID tcpip.NICID, enable bool) *tcpip.Error
s.mu.RLock()
defer s.mu.RUnlock()
- nic := s.nics[nicID]
- if nic == nil {
+ nic, ok := s.nics[nicID]
+ if !ok {
return tcpip.ErrUnknownNICID
}
@@ -1431,8 +1399,8 @@ func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) *tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
- nic := s.nics[nicID]
- if nic == nil {
+ nic, ok := s.nics[nicID]
+ if !ok {
return tcpip.ErrUnknownNICID
}
@@ -1464,8 +1432,33 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address,
return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.linkEP, waker)
}
-// RemoveWaker implements LinkAddressCache.RemoveWaker.
+// Neighbors returns all IP to MAC address associations.
+func (s *Stack) Neighbors(nicID tcpip.NICID) ([]NeighborEntry, *tcpip.Error) {
+ s.mu.RLock()
+ nic, ok := s.nics[nicID]
+ s.mu.RUnlock()
+
+ if !ok {
+ return nil, tcpip.ErrUnknownNICID
+ }
+
+ return nic.neighbors()
+}
+
+// RemoveWaker removes a waker that has been added when link resolution for
+// addr was requested.
func (s *Stack) RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) {
+ if s.useNeighborCache {
+ s.mu.RLock()
+ nic, ok := s.nics[nicID]
+ s.mu.RUnlock()
+
+ if ok {
+ nic.removeWaker(addr, waker)
+ }
+ return
+ }
+
s.mu.RLock()
defer s.mu.RUnlock()
@@ -1475,6 +1468,47 @@ func (s *Stack) RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.
}
}
+// AddStaticNeighbor statically associates an IP address to a MAC address.
+func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) *tcpip.Error {
+ s.mu.RLock()
+ nic, ok := s.nics[nicID]
+ s.mu.RUnlock()
+
+ if !ok {
+ return tcpip.ErrUnknownNICID
+ }
+
+ return nic.addStaticNeighbor(addr, linkAddr)
+}
+
+// RemoveNeighbor removes an IP to MAC address association previously created
+// either automically or by AddStaticNeighbor. Returns ErrBadAddress if there
+// is no association with the provided address.
+func (s *Stack) RemoveNeighbor(nicID tcpip.NICID, addr tcpip.Address) *tcpip.Error {
+ s.mu.RLock()
+ nic, ok := s.nics[nicID]
+ s.mu.RUnlock()
+
+ if !ok {
+ return tcpip.ErrUnknownNICID
+ }
+
+ return nic.removeNeighbor(addr)
+}
+
+// ClearNeighbors removes all IP to MAC address associations.
+func (s *Stack) ClearNeighbors(nicID tcpip.NICID) *tcpip.Error {
+ s.mu.RLock()
+ nic, ok := s.nics[nicID]
+ s.mu.RUnlock()
+
+ if !ok {
+ return tcpip.ErrUnknownNICID
+ }
+
+ return nic.clearNeighbors()
+}
+
// RegisterTransportEndpoint registers the given endpoint with the stack
// transport dispatcher. Received packets that match the provided id will be
// delivered to the given endpoint; specifying a nic is optional, but
@@ -1498,10 +1532,9 @@ func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip
// StartTransportEndpointCleanup removes the endpoint with the given id from
// the stack transport dispatcher. It also transitions it to the cleanup stage.
func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
- s.mu.Lock()
- defer s.mu.Unlock()
-
+ s.cleanupEndpointsMu.Lock()
s.cleanupEndpoints[ep] = struct{}{}
+ s.cleanupEndpointsMu.Unlock()
s.demux.unregisterEndpoint(netProtos, protocol, id, ep, flags, bindToDevice)
}
@@ -1509,9 +1542,9 @@ func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcp
// CompleteTransportEndpointCleanup removes the endpoint from the cleanup
// stage.
func (s *Stack) CompleteTransportEndpointCleanup(ep TransportEndpoint) {
- s.mu.Lock()
+ s.cleanupEndpointsMu.Lock()
delete(s.cleanupEndpoints, ep)
- s.mu.Unlock()
+ s.cleanupEndpointsMu.Unlock()
}
// FindTransportEndpoint finds an endpoint that most closely matches the provided
@@ -1554,23 +1587,23 @@ func (s *Stack) RegisteredEndpoints() []TransportEndpoint {
// CleanupEndpoints returns endpoints currently in the cleanup state.
func (s *Stack) CleanupEndpoints() []TransportEndpoint {
- s.mu.Lock()
+ s.cleanupEndpointsMu.Lock()
es := make([]TransportEndpoint, 0, len(s.cleanupEndpoints))
for e := range s.cleanupEndpoints {
es = append(es, e)
}
- s.mu.Unlock()
+ s.cleanupEndpointsMu.Unlock()
return es
}
// RestoreCleanupEndpoints adds endpoints to cleanup tracking. This is useful
// for restoring a stack after a save.
func (s *Stack) RestoreCleanupEndpoints(es []TransportEndpoint) {
- s.mu.Lock()
+ s.cleanupEndpointsMu.Lock()
for _, e := range es {
s.cleanupEndpoints[e] = struct{}{}
}
- s.mu.Unlock()
+ s.cleanupEndpointsMu.Unlock()
}
// Close closes all currently registered transport endpoints.
@@ -1765,18 +1798,17 @@ func (s *Stack) TransportProtocolInstance(num tcpip.TransportProtocolNumber) Tra
// guarantee provided on which probe will be invoked. Ideally this should only
// be called once per stack.
func (s *Stack) AddTCPProbe(probe TCPProbeFunc) {
- s.mu.Lock()
- s.tcpProbeFunc = probe
- s.mu.Unlock()
+ s.tcpProbeFunc.Store(probe)
}
// GetTCPProbe returns the TCPProbeFunc if installed with AddTCPProbe, nil
// otherwise.
func (s *Stack) GetTCPProbe() TCPProbeFunc {
- s.mu.Lock()
- p := s.tcpProbeFunc
- s.mu.Unlock()
- return p
+ p := s.tcpProbeFunc.Load()
+ if p == nil {
+ return nil
+ }
+ return p.(TCPProbeFunc)
}
// RemoveTCPProbe removes an installed TCP probe.
@@ -1785,9 +1817,8 @@ func (s *Stack) GetTCPProbe() TCPProbeFunc {
// have a probe attached. Endpoints already created will continue to invoke
// TCP probe.
func (s *Stack) RemoveTCPProbe() {
- s.mu.Lock()
- s.tcpProbeFunc = nil
- s.mu.Unlock()
+ // This must be TCPProbeFunc(nil) because atomic.Value.Store(nil) panics.
+ s.tcpProbeFunc.Store(TCPProbeFunc(nil))
}
// JoinGroup joins the given multicast group on the given NIC.
@@ -2009,7 +2040,7 @@ func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, addres
return nil, tcpip.ErrBadAddress
}
-// FindNICNameFromID returns the name of the nic for the given NICID.
+// FindNICNameFromID returns the name of the NIC for the given NICID.
func (s *Stack) FindNICNameFromID(id tcpip.NICID) string {
s.mu.RLock()
defer s.mu.RUnlock()
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index f168be402..7669ba672 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -42,9 +42,6 @@ import (
)
const (
- // fakeNetNumber is used as a protocol number in tests.
- //
- // This constant should match fakeNetNumber in stack.go.
fakeNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32
fakeNetHeaderLen = 12
fakeDefaultPrefixLen = 8
@@ -161,23 +158,13 @@ func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack
func (*fakeNetworkEndpoint) Close() {}
-type fakeNetGoodOption bool
-
-type fakeNetBadOption bool
-
-type fakeNetInvalidValueOption int
-
-type fakeNetOptions struct {
- good bool
-}
-
// fakeNetworkProtocol is a network-layer protocol descriptor. It aggregates the
// number of packets sent and received via endpoints of this protocol. The index
// where packets are added is given by the packet's destination address MOD 10.
type fakeNetworkProtocol struct {
packetCount [10]int
sendPacketCount [10]int
- opts fakeNetOptions
+ defaultTTL uint8
}
func (f *fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
@@ -200,7 +187,7 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres
return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1])
}
-func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint, _ *stack.Stack) stack.NetworkEndpoint {
+func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint, _ *stack.Stack) stack.NetworkEndpoint {
return &fakeNetworkEndpoint{
nicID: nicID,
proto: f,
@@ -209,22 +196,20 @@ func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack
}
}
-func (f *fakeNetworkProtocol) SetOption(option interface{}) *tcpip.Error {
+func (f *fakeNetworkProtocol) SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error {
switch v := option.(type) {
- case fakeNetGoodOption:
- f.opts.good = bool(v)
+ case *tcpip.DefaultTTLOption:
+ f.defaultTTL = uint8(*v)
return nil
- case fakeNetInvalidValueOption:
- return tcpip.ErrInvalidOptionValue
default:
return tcpip.ErrUnknownProtocolOption
}
}
-func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error {
+func (f *fakeNetworkProtocol) Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error {
switch v := option.(type) {
- case *fakeNetGoodOption:
- *v = fakeNetGoodOption(f.opts.good)
+ case *tcpip.DefaultTTLOption:
+ *v = tcpip.DefaultTTLOption(f.defaultTTL)
return nil
default:
return tcpip.ErrUnknownProtocolOption
@@ -1643,46 +1628,24 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) {
}
}
-func TestNetworkOptions(t *testing.T) {
+func TestNetworkOption(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
TransportProtocols: []stack.TransportProtocol{},
})
- // Try an unsupported network protocol.
- if err := s.SetNetworkProtocolOption(tcpip.NetworkProtocolNumber(99999), fakeNetGoodOption(false)); err != tcpip.ErrUnknownProtocol {
- t.Fatalf("SetNetworkProtocolOption(fakeNet2, blah, false) = %v, want = tcpip.ErrUnknownProtocol", err)
+ opt := tcpip.DefaultTTLOption(5)
+ if err := s.SetNetworkProtocolOption(fakeNetNumber, &opt); err != nil {
+ t.Fatalf("s.SetNetworkProtocolOption(%d, &%T(%d)): %s", fakeNetNumber, opt, opt, err)
}
- testCases := []struct {
- option interface{}
- wantErr *tcpip.Error
- verifier func(t *testing.T, p stack.NetworkProtocol)
- }{
- {fakeNetGoodOption(true), nil, func(t *testing.T, p stack.NetworkProtocol) {
- t.Helper()
- fakeNet := p.(*fakeNetworkProtocol)
- if fakeNet.opts.good != true {
- t.Fatalf("fakeNet.opts.good = false, want = true")
- }
- var v fakeNetGoodOption
- if err := s.NetworkProtocolOption(fakeNetNumber, &v); err != nil {
- t.Fatalf("s.NetworkProtocolOption(fakeNetNumber, &v) = %v, want = nil, where v is option %T", v, err)
- }
- if v != true {
- t.Fatalf("s.NetworkProtocolOption(fakeNetNumber, &v) returned v = %v, want = true", v)
- }
- }},
- {fakeNetBadOption(true), tcpip.ErrUnknownProtocolOption, nil},
- {fakeNetInvalidValueOption(1), tcpip.ErrInvalidOptionValue, nil},
+ var optGot tcpip.DefaultTTLOption
+ if err := s.NetworkProtocolOption(fakeNetNumber, &optGot); err != nil {
+ t.Fatalf("s.NetworkProtocolOption(%d, &%T): %s", fakeNetNumber, optGot, err)
}
- for _, tc := range testCases {
- if got := s.SetNetworkProtocolOption(fakeNetNumber, tc.option); got != tc.wantErr {
- t.Errorf("s.SetNetworkProtocolOption(fakeNet, %v) = %v, want = %v", tc.option, got, tc.wantErr)
- }
- if tc.verifier != nil {
- tc.verifier(t, s.NetworkProtocolInstance(fakeNetNumber))
- }
+
+ if opt != optGot {
+ t.Errorf("got optGot = %d, want = %d", optGot, opt)
}
}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index b902c6ca9..0774b5382 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -165,7 +165,7 @@ func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, p
// If this is a broadcast or multicast datagram, deliver the datagram to all
// endpoints bound to the right device.
- if isMulticastOrBroadcast(id.LocalAddress) {
+ if isInboundMulticastOrBroadcast(r) {
mpep.handlePacketAll(r, id, pkt)
epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
return
@@ -526,7 +526,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
// If the packet is a UDP broadcast or multicast, then find all matching
// transport endpoints.
- if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) {
+ if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(r) {
eps.mu.RLock()
destEPs := eps.findAllEndpointsLocked(id)
eps.mu.RUnlock()
@@ -546,7 +546,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
// If the packet is a TCP packet with a non-unicast source or destination
// address, then do nothing further and instruct the caller to do the same.
- if protocol == header.TCPProtocolNumber && (!isUnicast(r.LocalAddress) || !isUnicast(r.RemoteAddress)) {
+ if protocol == header.TCPProtocolNumber && (!isInboundUnicast(r) || !isOutboundUnicast(r)) {
// TCP can only be used to communicate between a single source and a
// single destination; the addresses must be unicast.
r.Stats().TCP.InvalidSegmentsReceived.Increment()
@@ -677,10 +677,14 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN
eps.mu.Unlock()
}
-func isMulticastOrBroadcast(addr tcpip.Address) bool {
- return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr)
+func isInboundMulticastOrBroadcast(r *Route) bool {
+ return r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || header.IsV6MulticastAddress(r.LocalAddress)
}
-func isUnicast(addr tcpip.Address) bool {
- return addr != header.IPv4Any && addr != header.IPv6Any && !isMulticastOrBroadcast(addr)
+func isInboundUnicast(r *Route) bool {
+ return r.LocalAddress != header.IPv4Any && r.LocalAddress != header.IPv6Any && !isInboundMulticastOrBroadcast(r)
+}
+
+func isOutboundUnicast(r *Route) bool {
+ return r.RemoteAddress != header.IPv4Any && r.RemoteAddress != header.IPv6Any && !r.IsOutboundBroadcast() && !header.IsV4MulticastAddress(r.RemoteAddress) && !header.IsV6MulticastAddress(r.RemoteAddress)
}
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index 1339edc2d..4d6d62eec 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -312,8 +312,8 @@ func TestBindToDeviceDistribution(t *testing.T) {
t.Fatalf("SetSockOptBool(ReusePortOption, %t) on endpoint %d failed: %s", endpoint.reuse, i, err)
}
bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice)
- if err := ep.SetSockOpt(bindToDeviceOption); err != nil {
- t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %s", bindToDeviceOption, i, err)
+ if err := ep.SetSockOpt(&bindToDeviceOption); err != nil {
+ t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", bindToDeviceOption, bindToDeviceOption, i, err)
}
var dstAddr tcpip.Address
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index fa4b14ba6..64e44bc99 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -53,11 +53,11 @@ func (f *fakeTransportEndpoint) Info() tcpip.EndpointInfo {
return &f.TransportEndpointInfo
}
-func (f *fakeTransportEndpoint) Stats() tcpip.EndpointStats {
+func (*fakeTransportEndpoint) Stats() tcpip.EndpointStats {
return nil
}
-func (f *fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {}
+func (*fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {}
func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint {
return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
@@ -100,12 +100,12 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions
return int64(len(v)), nil, nil
}
-func (f *fakeTransportEndpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
+func (*fakeTransportEndpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
return 0, tcpip.ControlMessages{}, nil
}
// SetSockOpt sets a socket option. Currently not supported.
-func (*fakeTransportEndpoint) SetSockOpt(interface{}) *tcpip.Error {
+func (*fakeTransportEndpoint) SetSockOpt(tcpip.SettableSocketOption) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
@@ -130,11 +130,7 @@ func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.E
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (*fakeTransportEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
- switch opt.(type) {
- case tcpip.ErrorOption:
- return nil
- }
+func (*fakeTransportEndpoint) GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
@@ -169,7 +165,7 @@ func (f *fakeTransportEndpoint) UniqueID() uint64 {
return f.uniqueID
}
-func (f *fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error {
+func (*fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error {
return nil
}
@@ -184,7 +180,7 @@ func (*fakeTransportEndpoint) Listen(int) *tcpip.Error {
return nil
}
-func (f *fakeTransportEndpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
if len(f.acceptQueue) == 0 {
return nil, nil, nil
}
@@ -239,19 +235,19 @@ func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, s
f.proto.controlCount++
}
-func (f *fakeTransportEndpoint) State() uint32 {
+func (*fakeTransportEndpoint) State() uint32 {
return 0
}
-func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) {}
+func (*fakeTransportEndpoint) ModerateRecvBuf(copied int) {}
-func (f *fakeTransportEndpoint) IPTables() (stack.IPTables, error) {
- return stack.IPTables{}, nil
-}
+func (*fakeTransportEndpoint) Resume(*stack.Stack) {}
-func (f *fakeTransportEndpoint) Resume(*stack.Stack) {}
+func (*fakeTransportEndpoint) Wait() {}
-func (f *fakeTransportEndpoint) Wait() {}
+func (*fakeTransportEndpoint) LastError() *tcpip.Error {
+ return nil
+}
type fakeTransportGoodOption bool
@@ -295,22 +291,20 @@ func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack
return true
}
-func (f *fakeTransportProtocol) SetOption(option interface{}) *tcpip.Error {
+func (f *fakeTransportProtocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error {
switch v := option.(type) {
- case fakeTransportGoodOption:
- f.opts.good = bool(v)
+ case *tcpip.TCPModerateReceiveBufferOption:
+ f.opts.good = bool(*v)
return nil
- case fakeTransportInvalidValueOption:
- return tcpip.ErrInvalidOptionValue
default:
return tcpip.ErrUnknownProtocolOption
}
}
-func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error {
+func (f *fakeTransportProtocol) Option(option tcpip.GettableTransportProtocolOption) *tcpip.Error {
switch v := option.(type) {
- case *fakeTransportGoodOption:
- *v = fakeTransportGoodOption(f.opts.good)
+ case *tcpip.TCPModerateReceiveBufferOption:
+ *v = tcpip.TCPModerateReceiveBufferOption(f.opts.good)
return nil
default:
return tcpip.ErrUnknownProtocolOption
@@ -537,41 +531,16 @@ func TestTransportOptions(t *testing.T) {
TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
})
- // Try an unsupported transport protocol.
- if err := s.SetTransportProtocolOption(tcpip.TransportProtocolNumber(99999), fakeTransportGoodOption(false)); err != tcpip.ErrUnknownProtocol {
- t.Fatalf("SetTransportProtocolOption(fakeTrans2, blah, false) = %v, want = tcpip.ErrUnknownProtocol", err)
- }
-
- testCases := []struct {
- option interface{}
- wantErr *tcpip.Error
- verifier func(t *testing.T, p stack.TransportProtocol)
- }{
- {fakeTransportGoodOption(true), nil, func(t *testing.T, p stack.TransportProtocol) {
- t.Helper()
- fakeTrans := p.(*fakeTransportProtocol)
- if fakeTrans.opts.good != true {
- t.Fatalf("fakeTrans.opts.good = false, want = true")
- }
- var v fakeTransportGoodOption
- if err := s.TransportProtocolOption(fakeTransNumber, &v); err != nil {
- t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &v) = %v, want = nil, where v is option %T", v, err)
- }
- if v != true {
- t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &v) returned v = %v, want = true", v)
- }
-
- }},
- {fakeTransportBadOption(true), tcpip.ErrUnknownProtocolOption, nil},
- {fakeTransportInvalidValueOption(1), tcpip.ErrInvalidOptionValue, nil},
- }
- for _, tc := range testCases {
- if got := s.SetTransportProtocolOption(fakeTransNumber, tc.option); got != tc.wantErr {
- t.Errorf("s.SetTransportProtocolOption(fakeTrans, %v) = %v, want = %v", tc.option, got, tc.wantErr)
- }
- if tc.verifier != nil {
- tc.verifier(t, s.TransportProtocolInstance(fakeTransNumber))
- }
+ v := tcpip.TCPModerateReceiveBufferOption(true)
+ if err := s.SetTransportProtocolOption(fakeTransNumber, &v); err != nil {
+ t.Errorf("s.SetTransportProtocolOption(fakeTrans, &%T(%t)): %s", v, v, err)
+ }
+ v = false
+ if err := s.TransportProtocolOption(fakeTransNumber, &v); err != nil {
+ t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &%T): %s", v, err)
+ }
+ if !v {
+ t.Fatalf("got tcpip.TCPModerateReceiveBufferOption = false, want = true")
}
}
@@ -635,7 +604,7 @@ func TestTransportForwarding(t *testing.T) {
Data: req.ToVectorisedView(),
}))
- aep, _, err := ep.Accept()
+ aep, _, err := ep.Accept(nil)
if err != nil || aep == nil {
t.Fatalf("Accept failed: %v, %v", aep, err)
}