summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/network
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/network')
-rw-r--r--pkg/tcpip/network/arp/arp.go20
-rw-r--r--pkg/tcpip/network/ip_test.go155
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go4
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go35
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go10
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go8
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go34
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go2
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go70
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go3
10 files changed, 179 insertions, 162 deletions
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 4bb7a417c..b47a7be51 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -96,14 +96,6 @@ func (e *endpoint) MTU() uint32 {
return lmtu - uint32(e.MaxHeaderLength())
}
-func (e *endpoint) NICID() tcpip.NICID {
- return e.nic.ID()
-}
-
-func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
- return e.linkEP.Capabilities()
-}
-
func (e *endpoint) MaxHeaderLength() uint16 {
return e.linkEP.MaxHeaderLength() + header.ARPSize
}
@@ -145,15 +137,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
localAddr := tcpip.Address(h.ProtocolAddressTarget())
if e.nud == nil {
- if e.linkAddrCache.CheckLocalAddress(e.NICID(), header.IPv4ProtocolNumber, localAddr) == 0 {
+ if e.linkAddrCache.CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 {
return // we have no useful answer, ignore the request
}
addr := tcpip.Address(h.ProtocolAddressSender())
linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
- e.linkAddrCache.AddLinkAddress(e.NICID(), addr, linkAddr)
+ e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr)
} else {
- if r.Stack().CheckLocalAddress(e.NICID(), header.IPv4ProtocolNumber, localAddr) == 0 {
+ if r.Stack().CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 {
return // we have no useful answer, ignore the request
}
@@ -179,7 +171,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
if e.nud == nil {
- e.linkAddrCache.AddLinkAddress(e.NICID(), addr, linkAddr)
+ e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr)
return
}
@@ -211,11 +203,11 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress
}
-func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint {
+func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
e := &endpoint{
protocol: p,
nic: nic,
- linkEP: sender,
+ linkEP: nic.LinkEndpoint(),
linkAddrCache: linkAddrCache,
nud: nud,
}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 66450f896..56a56362e 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -252,6 +252,8 @@ func buildDummyStack(t *testing.T) *stack.Stack {
var _ stack.NetworkInterface = (*testInterface)(nil)
type testInterface struct {
+ tester testObject
+
mu struct {
sync.RWMutex
disabled bool
@@ -282,6 +284,10 @@ func (t *testInterface) setEnabled(v bool) {
t.mu.disabled = !v
}
+func (t *testInterface) LinkEndpoint() stack.LinkEndpoint {
+ return &t.tester
+}
+
func TestEnableWhenNICDisabled(t *testing.T) {
tests := []struct {
name string
@@ -312,7 +318,7 @@ func TestEnableWhenNICDisabled(t *testing.T) {
// We pass nil for all parameters except the NetworkInterface and Stack
// since Enable only depends on these.
- ep := p.NewEndpoint(&nic, nil, nil, nil, nil, s)
+ ep := p.NewEndpoint(&nic, nil, nil, nil)
// The endpoint should initially be disabled, regardless the NIC's enabled
// status.
@@ -365,10 +371,15 @@ func TestEnableWhenNICDisabled(t *testing.T) {
}
func TestIPv4Send(t *testing.T) {
- o := testObject{t: t, v4: true}
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
- ep := proto.NewEndpoint(&testInterface{}, nil, nil, nil, &o, s)
+ nic := testInterface{
+ tester: testObject{
+ t: t,
+ v4: true,
+ },
+ }
+ ep := proto.NewEndpoint(&nic, nil, nil, nil)
defer ep.Close()
// Allocate and initialize the payload view.
@@ -384,10 +395,10 @@ func TestIPv4Send(t *testing.T) {
})
// Issue the write.
- o.protocol = 123
- o.srcAddr = localIpv4Addr
- o.dstAddr = remoteIpv4Addr
- o.contents = payload
+ nic.tester.protocol = 123
+ nic.tester.srcAddr = localIpv4Addr
+ nic.tester.dstAddr = remoteIpv4Addr
+ nic.tester.contents = payload
r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr)
if err != nil {
@@ -403,10 +414,15 @@ func TestIPv4Send(t *testing.T) {
}
func TestIPv4Receive(t *testing.T) {
- o := testObject{t: t, v4: true}
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
- ep := proto.NewEndpoint(&testInterface{}, nil, nil, &o, nil, s)
+ nic := testInterface{
+ tester: testObject{
+ t: t,
+ v4: true,
+ },
+ }
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester)
defer ep.Close()
if err := ep.Enable(); err != nil {
@@ -431,10 +447,10 @@ func TestIPv4Receive(t *testing.T) {
}
// Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
- o.protocol = 10
- o.srcAddr = remoteIpv4Addr
- o.dstAddr = localIpv4Addr
- o.contents = view[header.IPv4MinimumSize:totalLen]
+ nic.tester.protocol = 10
+ nic.tester.srcAddr = remoteIpv4Addr
+ nic.tester.dstAddr = localIpv4Addr
+ nic.tester.contents = view[header.IPv4MinimumSize:totalLen]
r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr)
if err != nil {
@@ -447,8 +463,8 @@ func TestIPv4Receive(t *testing.T) {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
ep.HandlePacket(&r, pkt)
- if o.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
+ if nic.tester.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls)
}
}
@@ -478,10 +494,14 @@ func TestIPv4ReceiveControl(t *testing.T) {
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
- o := testObject{t: t}
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
- ep := proto.NewEndpoint(&testInterface{}, nil, nil, &o, nil, s)
+ nic := testInterface{
+ tester: testObject{
+ t: t,
+ },
+ }
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester)
defer ep.Close()
if err := ep.Enable(); err != nil {
@@ -528,26 +548,31 @@ func TestIPv4ReceiveControl(t *testing.T) {
// Give packet to IPv4 endpoint, dispatcher will validate that
// it's ok.
- o.protocol = 10
- o.srcAddr = remoteIpv4Addr
- o.dstAddr = localIpv4Addr
- o.contents = view[dataOffset:]
- o.typ = c.expectedTyp
- o.extra = c.expectedExtra
+ nic.tester.protocol = 10
+ nic.tester.srcAddr = remoteIpv4Addr
+ nic.tester.dstAddr = localIpv4Addr
+ nic.tester.contents = view[dataOffset:]
+ nic.tester.typ = c.expectedTyp
+ nic.tester.extra = c.expectedExtra
ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv4MinimumSize))
- if want := c.expectedCount; o.controlCalls != want {
- t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want)
+ if want := c.expectedCount; nic.tester.controlCalls != want {
+ t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.tester.controlCalls, want)
}
})
}
}
func TestIPv4FragmentationReceive(t *testing.T) {
- o := testObject{t: t, v4: true}
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
- ep := proto.NewEndpoint(&testInterface{}, nil, nil, &o, nil, s)
+ nic := testInterface{
+ tester: testObject{
+ t: t,
+ v4: true,
+ },
+ }
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester)
defer ep.Close()
if err := ep.Enable(); err != nil {
@@ -590,10 +615,10 @@ func TestIPv4FragmentationReceive(t *testing.T) {
}
// Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
- o.protocol = 10
- o.srcAddr = remoteIpv4Addr
- o.dstAddr = localIpv4Addr
- o.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...)
+ nic.tester.protocol = 10
+ nic.tester.srcAddr = remoteIpv4Addr
+ nic.tester.dstAddr = localIpv4Addr
+ nic.tester.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...)
r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr)
if err != nil {
@@ -608,8 +633,8 @@ func TestIPv4FragmentationReceive(t *testing.T) {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
ep.HandlePacket(&r, pkt)
- if o.dataCalls != 0 {
- t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls)
+ if nic.tester.dataCalls != 0 {
+ t.Fatalf("Bad number of data calls: got %x, want 0", nic.tester.dataCalls)
}
// Send second segment.
@@ -620,16 +645,20 @@ func TestIPv4FragmentationReceive(t *testing.T) {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
ep.HandlePacket(&r, pkt)
- if o.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
+ if nic.tester.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls)
}
}
func TestIPv6Send(t *testing.T) {
- o := testObject{t: t}
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber)
- ep := proto.NewEndpoint(&testInterface{}, nil, nil, &o, channel.New(0, 1280, ""), s)
+ nic := testInterface{
+ tester: testObject{
+ t: t,
+ },
+ }
+ ep := proto.NewEndpoint(&nic, nil, nil, nil)
defer ep.Close()
if err := ep.Enable(); err != nil {
@@ -649,10 +678,10 @@ func TestIPv6Send(t *testing.T) {
})
// Issue the write.
- o.protocol = 123
- o.srcAddr = localIpv6Addr
- o.dstAddr = remoteIpv6Addr
- o.contents = payload
+ nic.tester.protocol = 123
+ nic.tester.srcAddr = localIpv6Addr
+ nic.tester.dstAddr = remoteIpv6Addr
+ nic.tester.contents = payload
r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr)
if err != nil {
@@ -668,10 +697,14 @@ func TestIPv6Send(t *testing.T) {
}
func TestIPv6Receive(t *testing.T) {
- o := testObject{t: t}
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber)
- ep := proto.NewEndpoint(&testInterface{}, nil, nil, &o, nil, s)
+ nic := testInterface{
+ tester: testObject{
+ t: t,
+ },
+ }
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester)
defer ep.Close()
if err := ep.Enable(); err != nil {
@@ -695,10 +728,10 @@ func TestIPv6Receive(t *testing.T) {
}
// Give packet to ipv6 endpoint, dispatcher will validate that it's ok.
- o.protocol = 10
- o.srcAddr = remoteIpv6Addr
- o.dstAddr = localIpv6Addr
- o.contents = view[header.IPv6MinimumSize:totalLen]
+ nic.tester.protocol = 10
+ nic.tester.srcAddr = remoteIpv6Addr
+ nic.tester.dstAddr = localIpv6Addr
+ nic.tester.contents = view[header.IPv6MinimumSize:totalLen]
r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr)
if err != nil {
@@ -712,8 +745,8 @@ func TestIPv6Receive(t *testing.T) {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
ep.HandlePacket(&r, pkt)
- if o.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
+ if nic.tester.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls)
}
}
@@ -752,10 +785,14 @@ func TestIPv6ReceiveControl(t *testing.T) {
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
- o := testObject{t: t}
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber)
- ep := proto.NewEndpoint(&testInterface{}, nil, nil, &o, nil, s)
+ nic := testInterface{
+ tester: testObject{
+ t: t,
+ },
+ }
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester)
defer ep.Close()
if err := ep.Enable(); err != nil {
@@ -814,19 +851,19 @@ func TestIPv6ReceiveControl(t *testing.T) {
// Give packet to IPv6 endpoint, dispatcher will validate that
// it's ok.
- o.protocol = 10
- o.srcAddr = remoteIpv6Addr
- o.dstAddr = localIpv6Addr
- o.contents = view[dataOffset:]
- o.typ = c.expectedTyp
- o.extra = c.expectedExtra
+ nic.tester.protocol = 10
+ nic.tester.srcAddr = remoteIpv6Addr
+ nic.tester.dstAddr = localIpv6Addr
+ nic.tester.contents = view[dataOffset:]
+ nic.tester.typ = c.expectedTyp
+ nic.tester.extra = c.expectedExtra
// Set ICMPv6 checksum.
icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{}))
ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv6MinimumSize))
- if want := c.expectedCount; o.controlCalls != want {
- t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want)
+ if want := c.expectedCount; nic.tester.controlCalls != want {
+ t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.tester.controlCalls, want)
}
})
}
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 3e5cf2ad9..5c4f715d7 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -40,7 +40,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack
// Drop packet if it doesn't have the basic IPv4 header or if the
// original source address doesn't match an address we own.
src := hdr.SourceAddress()
- if e.stack.CheckLocalAddress(e.NICID(), ProtocolNumber, src) == 0 {
+ if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, src) == 0 {
return
}
@@ -110,7 +110,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
localAddr = ""
}
- r, err := r.Stack().FindRoute(e.NICID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
+ r, err := r.Stack().FindRoute(e.nic.ID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
if err != nil {
// If we cannot find a route to the destination, silently drop the packet.
return
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 41f6914b9..7adf0fac3 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -59,7 +59,6 @@ type endpoint struct {
linkEP stack.LinkEndpoint
dispatcher stack.TransportDispatcher
protocol *protocol
- stack *stack.Stack
// enabled is set to 1 when the enpoint is enabled and 0 when it is
// disabled.
@@ -75,13 +74,12 @@ type endpoint struct {
}
// NewEndpoint creates a new ipv4 endpoint.
-func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint {
+func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
e := &endpoint{
nic: nic,
- linkEP: linkEP,
+ linkEP: nic.LinkEndpoint(),
dispatcher: dispatcher,
protocol: p,
- stack: st,
}
e.mu.addressableEndpointState.Init(e)
return e
@@ -173,16 +171,6 @@ func (e *endpoint) MTU() uint32 {
return calculateMTU(e.linkEP.MTU())
}
-// Capabilities implements stack.NetworkEndpoint.
-func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
- return e.linkEP.Capabilities()
-}
-
-// NICID returns the ID of the NIC this endpoint belongs to.
-func (e *endpoint) NICID() tcpip.NICID {
- return e.nic.ID()
-}
-
// MaxHeaderLength returns the maximum length needed by ipv4 headers (and
// underlying protocols).
func (e *endpoint) MaxHeaderLength() uint16 {
@@ -324,8 +312,8 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
// iptables filtering. All packets that reach here are locally
// generated.
- nicName := e.stack.FindNICNameFromID(e.NICID())
- ipt := e.stack.IPTables()
+ nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ ipt := e.protocol.stack.IPTables()
if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok {
// iptables is telling us to drop the packet.
r.Stats().IP.IPTablesOutputDropped.Increment()
@@ -341,7 +329,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
// short circuits broadcasts before they are sent out to other hosts.
if pkt.NatDone {
netHeader := header.IPv4(pkt.NetworkHeader().View())
- ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress())
+ ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress())
if err == nil {
route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
ep.HandlePacket(&route, pkt)
@@ -381,10 +369,10 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
pkt = pkt.Next()
}
- nicName := e.stack.FindNICNameFromID(e.NICID())
+ nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
// iptables filtering. All packets that reach here are locally
// generated.
- ipt := e.stack.IPTables()
+ ipt := e.protocol.stack.IPTables()
dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName)
if len(dropped) == 0 && len(natPkts) == 0 {
// Fast path: If no packets are to be dropped then we can just invoke the
@@ -404,7 +392,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
if _, ok := natPkts[pkt]; ok {
netHeader := header.IPv4(pkt.NetworkHeader().View())
- if ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()); err == nil {
+ if ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()); err == nil {
src := netHeader.SourceAddress()
dst := netHeader.DestinationAddress()
route := r.ReverseRoute(src, dst)
@@ -493,7 +481,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// iptables filtering. All packets that reach here are intended for
// this machine and will not be forwarded.
- ipt := e.stack.IPTables()
+ ipt := e.protocol.stack.IPTables()
if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok {
// iptables is telling us to drop the packet.
r.Stats().IP.IPTablesInputDropped.Increment()
@@ -677,6 +665,8 @@ var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
var _ stack.NetworkProtocol = (*protocol)(nil)
type protocol struct {
+ stack *stack.Stack
+
// defaultTTL is the current default TTL for the protocol. Only the
// uint8 portion of it is meaningful.
//
@@ -799,7 +789,7 @@ func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV ui
}
// NewProtocol returns an IPv4 network protocol.
-func NewProtocol(*stack.Stack) stack.NetworkProtocol {
+func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
ids := make([]uint32, buckets)
// Randomly initialize hashIV and the ids.
@@ -810,6 +800,7 @@ func NewProtocol(*stack.Stack) stack.NetworkProtocol {
hashIV := r[buckets]
return &protocol{
+ stack: s,
ids: ids,
hashIV: hashIV,
defaultTTL: DefaultTTL,
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 270439b5c..4b4b483cc 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -41,7 +41,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack
// Drop packet if it doesn't have the basic IPv6 header or if the
// original source address doesn't match an address we own.
src := hdr.SourceAddress()
- if e.stack.CheckLocalAddress(e.NICID(), ProtocolNumber, src) == 0 {
+ if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, src) == 0 {
return
}
@@ -248,7 +248,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// section 5.4.3.
// Is the NS targeting us?
- if r.Stack().CheckLocalAddress(e.NICID(), ProtocolNumber, targetAddr) == 0 {
+ if r.Stack().CheckLocalAddress(e.nic.ID(), ProtocolNumber, targetAddr) == 0 {
return
}
@@ -283,7 +283,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
} else if e.nud != nil {
e.nud.HandleProbe(r.RemoteAddress, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol)
} else {
- e.linkAddrCache.AddLinkAddress(e.NICID(), r.RemoteAddress, sourceLinkAddr)
+ e.linkAddrCache.AddLinkAddress(e.nic.ID(), r.RemoteAddress, sourceLinkAddr)
}
// ICMPv6 Neighbor Solicit messages are always sent to
@@ -410,7 +410,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// address cache with the link address for the target of the message.
if len(targetLinkAddr) != 0 {
if e.nud == nil {
- e.linkAddrCache.AddLinkAddress(e.NICID(), targetAddr, targetLinkAddr)
+ e.linkAddrCache.AddLinkAddress(e.nic.ID(), targetAddr, targetLinkAddr)
return
}
@@ -438,7 +438,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
localAddr = ""
}
- r, err := r.Stack().FindRoute(e.NICID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
+ r, err := r.Stack().FindRoute(e.nic.ID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
if err != nil {
// If we cannot find a route to the destination, silently drop the packet.
return
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index b4e8a077f..5472ceb46 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -123,6 +123,10 @@ func (*testInterface) Enabled() bool {
return true
}
+func (*testInterface) LinkEndpoint() stack.LinkEndpoint {
+ return nil
+}
+
func TestICMPCounts(t *testing.T) {
tests := []struct {
name string
@@ -170,7 +174,7 @@ func TestICMPCounts(t *testing.T) {
if netProto == nil {
t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
}
- ep := netProto.NewEndpoint(&testInterface{}, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{}, nil, s)
+ ep := netProto.NewEndpoint(&testInterface{}, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{})
defer ep.Close()
if err := ep.Enable(); err != nil {
@@ -312,7 +316,7 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
if netProto == nil {
t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
}
- ep := netProto.NewEndpoint(&testInterface{}, nil, &stubNUDHandler{}, &stubDispatcher{}, nil, s)
+ ep := netProto.NewEndpoint(&testInterface{}, nil, &stubNUDHandler{}, &stubDispatcher{})
defer ep.Close()
if err := ep.Enable(); err != nil {
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 75b27a4cf..d1ad7acb7 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -351,16 +351,6 @@ func (e *endpoint) MTU() uint32 {
return calculateMTU(e.linkEP.MTU())
}
-// NICID returns the ID of the NIC this endpoint belongs to.
-func (e *endpoint) NICID() tcpip.NICID {
- return e.nic.ID()
-}
-
-// Capabilities implements stack.NetworkEndpoint.
-func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
- return e.linkEP.Capabilities()
-}
-
// MaxHeaderLength returns the maximum length needed by ipv6 headers (and
// underlying protocols).
func (e *endpoint) MaxHeaderLength() uint16 {
@@ -395,8 +385,8 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
// iptables filtering. All packets that reach here are locally
// generated.
- nicName := e.stack.FindNICNameFromID(e.NICID())
- ipt := e.stack.IPTables()
+ nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ ipt := e.protocol.stack.IPTables()
if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok {
// iptables is telling us to drop the packet.
r.Stats().IP.IPTablesOutputDropped.Increment()
@@ -412,7 +402,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
// short circuits broadcasts before they are sent out to other hosts.
if pkt.NatDone {
netHeader := header.IPv6(pkt.NetworkHeader().View())
- if ep, err := e.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil {
+ if ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil {
route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
ep.HandlePacket(&route, pkt)
return nil
@@ -455,8 +445,8 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
// iptables filtering. All packets that reach here are locally
// generated.
- nicName := e.stack.FindNICNameFromID(e.NICID())
- ipt := e.stack.IPTables()
+ nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ ipt := e.protocol.stack.IPTables()
dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName)
if len(dropped) == 0 && len(natPkts) == 0 {
// Fast path: If no packets are to be dropped then we can just invoke the
@@ -476,7 +466,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
if _, ok := natPkts[pkt]; ok {
netHeader := header.IPv6(pkt.NetworkHeader().View())
- if ep, err := e.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil {
+ if ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil {
src := netHeader.SourceAddress()
dst := netHeader.DestinationAddress()
route := r.ReverseRoute(src, dst)
@@ -531,7 +521,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// iptables filtering. All packets that reach here are intended for
// this machine and need not be forwarded.
- ipt := e.stack.IPTables()
+ ipt := e.protocol.stack.IPTables()
if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok {
// iptables is telling us to drop the packet.
r.Stats().IP.IPTablesInputDropped.Increment()
@@ -1084,6 +1074,8 @@ var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
var _ stack.NetworkProtocol = (*protocol)(nil)
type protocol struct {
+ stack *stack.Stack
+
mu struct {
sync.RWMutex
@@ -1147,15 +1139,14 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
}
// NewEndpoint creates a new ipv6 endpoint.
-func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint {
+func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
e := &endpoint{
nic: nic,
- linkEP: linkEP,
+ linkEP: nic.LinkEndpoint(),
linkAddrCache: linkAddrCache,
nud: nud,
dispatcher: dispatcher,
protocol: p,
- stack: st,
}
e.mu.addressableEndpointState.Init(e)
e.mu.ndp = ndpState{
@@ -1312,8 +1303,9 @@ type Options struct {
func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
opts.NDPConfigs.validate()
- return func(*stack.Stack) stack.NetworkProtocol {
+ return func(s *stack.Stack) stack.NetworkProtocol {
p := &protocol{
+ stack: s,
fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
ndpDisp: opts.NDPDisp,
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index 3495a8b19..d85b5c00f 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -1901,7 +1901,7 @@ func TestClearEndpointFromProtocolOnClose(t *testing.T) {
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
})
proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol)
- ep := proto.NewEndpoint(&testInterface{}, nil, nil, nil, nil, nil).(*endpoint)
+ ep := proto.NewEndpoint(&testInterface{}, nil, nil, nil).(*endpoint)
{
proto.mu.Lock()
_, hasEP := proto.mu.eps[ep]
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index 1b5c61b80..84c082852 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -627,7 +627,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
if addressEndpoint.GetKind() != stack.PermanentTentative {
// The endpoint should be marked as tentative since we are starting DAD.
- panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.ep.NICID()))
+ panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.ep.nic.ID()))
}
// Should not attempt to perform DAD on an address that is currently in the
@@ -639,7 +639,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
// address, or its reference count would have been increased without doing
// the work that would have been done for an address that was brand new.
// See endpoint.addAddressLocked.
- panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.ep.NICID()))
+ panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.ep.nic.ID()))
}
remaining := ndp.configs.DupAddrDetectTransmits
@@ -649,7 +649,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
// Consider DAD to have resolved even if no DAD messages were actually
// transmitted.
if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
- ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.NICID(), addr, true, nil)
+ ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, true, nil)
}
return nil
@@ -661,7 +661,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
// cannot be done while holding the IPv6 endpoint's lock. This is effectively
// the same as starting a goroutine but we use a timer that fires immediately
// so we can reset it for the next DAD iteration.
- timer = ndp.ep.stack.Clock().AfterFunc(0, func() {
+ timer = ndp.ep.protocol.stack.Clock().AfterFunc(0, func() {
ndp.ep.mu.Lock()
defer ndp.ep.mu.Unlock()
@@ -676,7 +676,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
if addressEndpoint.GetKind() != stack.PermanentTentative {
// The endpoint should still be marked as tentative since we are still
// performing DAD on it.
- panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.ep.NICID()))
+ panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.ep.nic.ID()))
}
dadDone := remaining == 0
@@ -721,7 +721,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
delete(ndp.dad, addr)
if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
- ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.NICID(), addr, dadDone, err)
+ ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, dadDone, err)
}
// If DAD resolved for a stable SLAAC address, attempt generation of a
@@ -750,7 +750,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) *tcpip.Error {
snmc := header.SolicitedNodeAddr(addr)
- r, err := ndp.ep.stack.FindRoute(ndp.ep.NICID(), header.IPv6Any, snmc, ProtocolNumber, false /* multicastLoop */)
+ r, err := ndp.ep.protocol.stack.FindRoute(ndp.ep.nic.ID(), header.IPv6Any, snmc, ProtocolNumber, false /* multicastLoop */)
if err != nil {
return err
}
@@ -766,9 +766,9 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.Add
return err
}
- panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.ep.NICID(), err))
+ panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.ep.nic.ID(), err))
} else if c != nil {
- panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.ep.NICID()))
+ panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.ep.nic.ID()))
}
icmpData := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize))
@@ -824,7 +824,7 @@ func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) {
// Let the integrator know DAD did not resolve.
if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
- ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.NICID(), addr, false, nil)
+ ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, false, nil)
}
}
@@ -861,7 +861,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
if ndp.dhcpv6Configuration != configuration {
ndp.dhcpv6Configuration = configuration
- ndpDisp.OnDHCPv6Configuration(ndp.ep.NICID(), configuration)
+ ndpDisp.OnDHCPv6Configuration(ndp.ep.nic.ID(), configuration)
}
}
@@ -908,7 +908,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
}
addrs, _ := opt.Addresses()
- ndp.ep.protocol.ndpDisp.OnRecursiveDNSServerOption(ndp.ep.NICID(), addrs, opt.Lifetime())
+ ndp.ep.protocol.ndpDisp.OnRecursiveDNSServerOption(ndp.ep.nic.ID(), addrs, opt.Lifetime())
case header.NDPDNSSearchList:
if ndp.ep.protocol.ndpDisp == nil {
@@ -916,7 +916,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
}
domainNames, _ := opt.DomainNames()
- ndp.ep.protocol.ndpDisp.OnDNSSearchListOption(ndp.ep.NICID(), domainNames, opt.Lifetime())
+ ndp.ep.protocol.ndpDisp.OnDNSSearchListOption(ndp.ep.nic.ID(), domainNames, opt.Lifetime())
case header.NDPPrefixInformation:
prefix := opt.Subnet()
@@ -965,7 +965,7 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) {
// Let the integrator know a discovered default router is invalidated.
if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
- ndpDisp.OnDefaultRouterInvalidated(ndp.ep.NICID(), ip)
+ ndpDisp.OnDefaultRouterInvalidated(ndp.ep.nic.ID(), ip)
}
}
@@ -982,14 +982,14 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) {
}
// Inform the integrator when we discovered a default router.
- if !ndpDisp.OnDefaultRouterDiscovered(ndp.ep.NICID(), ip) {
+ if !ndpDisp.OnDefaultRouterDiscovered(ndp.ep.nic.ID(), ip) {
// Informed by the integrator to not remember the router, do
// nothing further.
return
}
state := defaultRouterState{
- invalidationJob: ndp.ep.stack.NewJob(&ndp.ep.mu, func() {
+ invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
ndp.invalidateDefaultRouter(ip)
}),
}
@@ -1012,14 +1012,14 @@ func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration)
}
// Inform the integrator when we discovered an on-link prefix.
- if !ndpDisp.OnOnLinkPrefixDiscovered(ndp.ep.NICID(), prefix) {
+ if !ndpDisp.OnOnLinkPrefixDiscovered(ndp.ep.nic.ID(), prefix) {
// Informed by the integrator to not remember the prefix, do
// nothing further.
return
}
state := onLinkPrefixState{
- invalidationJob: ndp.ep.stack.NewJob(&ndp.ep.mu, func() {
+ invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
ndp.invalidateOnLinkPrefix(prefix)
}),
}
@@ -1048,7 +1048,7 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) {
// Let the integrator know a discovered on-link prefix is invalidated.
if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
- ndpDisp.OnOnLinkPrefixInvalidated(ndp.ep.NICID(), prefix)
+ ndpDisp.OnOnLinkPrefixInvalidated(ndp.ep.nic.ID(), prefix)
}
}
@@ -1164,7 +1164,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
}
state := slaacPrefixState{
- deprecationJob: ndp.ep.stack.NewJob(&ndp.ep.mu, func() {
+ deprecationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
state, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the deprecated SLAAC prefix %s", prefix))
@@ -1172,7 +1172,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
ndp.deprecateSLAACAddress(state.stableAddr.addressEndpoint)
}),
- invalidationJob: ndp.ep.stack.NewJob(&ndp.ep.mu, func() {
+ invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
state, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the invalidated SLAAC prefix %s", prefix))
@@ -1230,7 +1230,7 @@ func (ndp *ndpState) addAndAcquireSLAACAddr(addr tcpip.AddressWithPrefix, config
return nil
}
- if !ndpDisp.OnAutoGenAddress(ndp.ep.NICID(), addr) {
+ if !ndpDisp.OnAutoGenAddress(ndp.ep.nic.ID(), addr) {
// Informed by the integrator not to add the address.
return nil
}
@@ -1276,7 +1276,7 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt
addrBytes = header.AppendOpaqueInterfaceIdentifier(
addrBytes[:header.IIDOffsetInIPv6Address],
prefix,
- oIID.NICNameFromID(ndp.ep.NICID(), ndp.ep.nic.Name()),
+ oIID.NICNameFromID(ndp.ep.nic.ID(), ndp.ep.nic.Name()),
dadCounter,
oIID.SecretKey,
)
@@ -1433,7 +1433,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
}
state := tempSLAACAddrState{
- deprecationJob: ndp.ep.stack.NewJob(&ndp.ep.mu, func() {
+ deprecationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
prefixState, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to deprecate temporary address %s", prefix, generatedAddr))
@@ -1446,7 +1446,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
ndp.deprecateSLAACAddress(tempAddrState.addressEndpoint)
}),
- invalidationJob: ndp.ep.stack.NewJob(&ndp.ep.mu, func() {
+ invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
prefixState, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to invalidate temporary address %s", prefix, generatedAddr))
@@ -1459,7 +1459,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, generatedAddr.Address, tempAddrState)
}),
- regenJob: ndp.ep.stack.NewJob(&ndp.ep.mu, func() {
+ regenJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
prefixState, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to regenerate temporary address after %s", prefix, generatedAddr))
@@ -1677,7 +1677,7 @@ func (ndp *ndpState) deprecateSLAACAddress(addressEndpoint stack.AddressEndpoint
addressEndpoint.SetDeprecated(true)
if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
- ndpDisp.OnAutoGenAddressDeprecated(ndp.ep.NICID(), addressEndpoint.AddressWithPrefix())
+ ndpDisp.OnAutoGenAddressDeprecated(ndp.ep.nic.ID(), addressEndpoint.AddressWithPrefix())
}
}
@@ -1702,7 +1702,7 @@ func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, state slaacPrefi
// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidatePrefix bool) {
if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
- ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.NICID(), addr)
+ ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addr)
}
prefix := addr.Subnet()
@@ -1762,7 +1762,7 @@ func (ndp *ndpState) invalidateTempSLAACAddr(tempAddrs map[tcpip.Address]tempSLA
// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidateAddr bool) {
if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
- ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.NICID(), addr)
+ ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addr)
}
if !invalidateAddr {
@@ -1878,7 +1878,7 @@ func (ndp *ndpState) startSolicitingRouters() {
var done bool
ndp.rtrSolicit.done = &done
- ndp.rtrSolicit.timer = ndp.ep.stack.Clock().AfterFunc(delay, func() {
+ ndp.rtrSolicit.timer = ndp.ep.protocol.stack.Clock().AfterFunc(delay, func() {
ndp.ep.mu.Lock()
if done {
// If we reach this point, it means that the RS timer fired after another
@@ -1904,7 +1904,7 @@ func (ndp *ndpState) startSolicitingRouters() {
ndp.ep.mu.Unlock()
localAddr := addressEndpoint.AddressWithPrefix().Address
- r, err := ndp.ep.stack.FindRoute(ndp.ep.NICID(), localAddr, header.IPv6AllRoutersMulticastAddress, ProtocolNumber, false /* multicastLoop */)
+ r, err := ndp.ep.protocol.stack.FindRoute(ndp.ep.nic.ID(), localAddr, header.IPv6AllRoutersMulticastAddress, ProtocolNumber, false /* multicastLoop */)
addressEndpoint.DecRef()
if err != nil {
return
@@ -1923,9 +1923,9 @@ func (ndp *ndpState) startSolicitingRouters() {
return
}
- panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.NICID(), err))
+ panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.nic.ID(), err))
} else if c != nil {
- panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.NICID()))
+ panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.nic.ID()))
}
// As per RFC 4861 section 4.1, an NDP RS SHOULD include the source
@@ -1961,7 +1961,7 @@ func (ndp *ndpState) startSolicitingRouters() {
}, pkt,
); err != nil {
sent.Dropped.Increment()
- log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.ep.NICID(), err)
+ log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.ep.nic.ID(), err)
// Don't send any more messages if we had an error.
remaining = 0
} else {
@@ -2005,7 +2005,7 @@ func (ndp *ndpState) stopSolicitingRouters() {
// initializeTempAddrState initializes state related to temporary SLAAC
// addresses.
func (ndp *ndpState) initializeTempAddrState() {
- header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.tempIIDSeed, ndp.ep.NICID())
+ header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.tempIIDSeed, ndp.ep.nic.ID())
if MaxDesyncFactor != 0 {
ndp.temporaryAddressDesyncFactor = time.Duration(rand.Int63n(int64(MaxDesyncFactor)))
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 1947468fd..25464a03a 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -66,10 +66,11 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeig
t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
}
- ep := netProto.NewEndpoint(&testInterface{}, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{}, nil, s)
+ ep := netProto.NewEndpoint(&testInterface{}, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{})
if err := ep.Enable(); err != nil {
t.Fatalf("ep.Enable(): %s", err)
}
+ t.Cleanup(ep.Close)
return s, ep
}