summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/network/arp/arp.go20
-rw-r--r--pkg/tcpip/network/ip_test.go155
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go4
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go35
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go10
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go8
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go34
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go2
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go70
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go3
-rw-r--r--pkg/tcpip/stack/forwarder_test.go12
-rw-r--r--pkg/tcpip/stack/neighbor_entry_test.go2
-rw-r--r--pkg/tcpip/stack/nic.go4
-rw-r--r--pkg/tcpip/stack/nic_test.go14
-rw-r--r--pkg/tcpip/stack/registration.go16
-rw-r--r--pkg/tcpip/stack/route.go9
-rw-r--r--pkg/tcpip/stack/stack.go4
-rw-r--r--pkg/tcpip/stack/stack_test.go12
-rw-r--r--pkg/tcpip/stack/transport_test.go25
-rw-r--r--pkg/tcpip/transport/icmp/protocol.go18
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go13
-rw-r--r--pkg/tcpip/transport/udp/protocol.go13
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go9
23 files changed, 244 insertions, 248 deletions
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 4bb7a417c..b47a7be51 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -96,14 +96,6 @@ func (e *endpoint) MTU() uint32 {
return lmtu - uint32(e.MaxHeaderLength())
}
-func (e *endpoint) NICID() tcpip.NICID {
- return e.nic.ID()
-}
-
-func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
- return e.linkEP.Capabilities()
-}
-
func (e *endpoint) MaxHeaderLength() uint16 {
return e.linkEP.MaxHeaderLength() + header.ARPSize
}
@@ -145,15 +137,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
localAddr := tcpip.Address(h.ProtocolAddressTarget())
if e.nud == nil {
- if e.linkAddrCache.CheckLocalAddress(e.NICID(), header.IPv4ProtocolNumber, localAddr) == 0 {
+ if e.linkAddrCache.CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 {
return // we have no useful answer, ignore the request
}
addr := tcpip.Address(h.ProtocolAddressSender())
linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
- e.linkAddrCache.AddLinkAddress(e.NICID(), addr, linkAddr)
+ e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr)
} else {
- if r.Stack().CheckLocalAddress(e.NICID(), header.IPv4ProtocolNumber, localAddr) == 0 {
+ if r.Stack().CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 {
return // we have no useful answer, ignore the request
}
@@ -179,7 +171,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
if e.nud == nil {
- e.linkAddrCache.AddLinkAddress(e.NICID(), addr, linkAddr)
+ e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr)
return
}
@@ -211,11 +203,11 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress
}
-func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint {
+func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
e := &endpoint{
protocol: p,
nic: nic,
- linkEP: sender,
+ linkEP: nic.LinkEndpoint(),
linkAddrCache: linkAddrCache,
nud: nud,
}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 66450f896..56a56362e 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -252,6 +252,8 @@ func buildDummyStack(t *testing.T) *stack.Stack {
var _ stack.NetworkInterface = (*testInterface)(nil)
type testInterface struct {
+ tester testObject
+
mu struct {
sync.RWMutex
disabled bool
@@ -282,6 +284,10 @@ func (t *testInterface) setEnabled(v bool) {
t.mu.disabled = !v
}
+func (t *testInterface) LinkEndpoint() stack.LinkEndpoint {
+ return &t.tester
+}
+
func TestEnableWhenNICDisabled(t *testing.T) {
tests := []struct {
name string
@@ -312,7 +318,7 @@ func TestEnableWhenNICDisabled(t *testing.T) {
// We pass nil for all parameters except the NetworkInterface and Stack
// since Enable only depends on these.
- ep := p.NewEndpoint(&nic, nil, nil, nil, nil, s)
+ ep := p.NewEndpoint(&nic, nil, nil, nil)
// The endpoint should initially be disabled, regardless the NIC's enabled
// status.
@@ -365,10 +371,15 @@ func TestEnableWhenNICDisabled(t *testing.T) {
}
func TestIPv4Send(t *testing.T) {
- o := testObject{t: t, v4: true}
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
- ep := proto.NewEndpoint(&testInterface{}, nil, nil, nil, &o, s)
+ nic := testInterface{
+ tester: testObject{
+ t: t,
+ v4: true,
+ },
+ }
+ ep := proto.NewEndpoint(&nic, nil, nil, nil)
defer ep.Close()
// Allocate and initialize the payload view.
@@ -384,10 +395,10 @@ func TestIPv4Send(t *testing.T) {
})
// Issue the write.
- o.protocol = 123
- o.srcAddr = localIpv4Addr
- o.dstAddr = remoteIpv4Addr
- o.contents = payload
+ nic.tester.protocol = 123
+ nic.tester.srcAddr = localIpv4Addr
+ nic.tester.dstAddr = remoteIpv4Addr
+ nic.tester.contents = payload
r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr)
if err != nil {
@@ -403,10 +414,15 @@ func TestIPv4Send(t *testing.T) {
}
func TestIPv4Receive(t *testing.T) {
- o := testObject{t: t, v4: true}
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
- ep := proto.NewEndpoint(&testInterface{}, nil, nil, &o, nil, s)
+ nic := testInterface{
+ tester: testObject{
+ t: t,
+ v4: true,
+ },
+ }
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester)
defer ep.Close()
if err := ep.Enable(); err != nil {
@@ -431,10 +447,10 @@ func TestIPv4Receive(t *testing.T) {
}
// Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
- o.protocol = 10
- o.srcAddr = remoteIpv4Addr
- o.dstAddr = localIpv4Addr
- o.contents = view[header.IPv4MinimumSize:totalLen]
+ nic.tester.protocol = 10
+ nic.tester.srcAddr = remoteIpv4Addr
+ nic.tester.dstAddr = localIpv4Addr
+ nic.tester.contents = view[header.IPv4MinimumSize:totalLen]
r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr)
if err != nil {
@@ -447,8 +463,8 @@ func TestIPv4Receive(t *testing.T) {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
ep.HandlePacket(&r, pkt)
- if o.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
+ if nic.tester.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls)
}
}
@@ -478,10 +494,14 @@ func TestIPv4ReceiveControl(t *testing.T) {
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
- o := testObject{t: t}
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
- ep := proto.NewEndpoint(&testInterface{}, nil, nil, &o, nil, s)
+ nic := testInterface{
+ tester: testObject{
+ t: t,
+ },
+ }
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester)
defer ep.Close()
if err := ep.Enable(); err != nil {
@@ -528,26 +548,31 @@ func TestIPv4ReceiveControl(t *testing.T) {
// Give packet to IPv4 endpoint, dispatcher will validate that
// it's ok.
- o.protocol = 10
- o.srcAddr = remoteIpv4Addr
- o.dstAddr = localIpv4Addr
- o.contents = view[dataOffset:]
- o.typ = c.expectedTyp
- o.extra = c.expectedExtra
+ nic.tester.protocol = 10
+ nic.tester.srcAddr = remoteIpv4Addr
+ nic.tester.dstAddr = localIpv4Addr
+ nic.tester.contents = view[dataOffset:]
+ nic.tester.typ = c.expectedTyp
+ nic.tester.extra = c.expectedExtra
ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv4MinimumSize))
- if want := c.expectedCount; o.controlCalls != want {
- t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want)
+ if want := c.expectedCount; nic.tester.controlCalls != want {
+ t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.tester.controlCalls, want)
}
})
}
}
func TestIPv4FragmentationReceive(t *testing.T) {
- o := testObject{t: t, v4: true}
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber)
- ep := proto.NewEndpoint(&testInterface{}, nil, nil, &o, nil, s)
+ nic := testInterface{
+ tester: testObject{
+ t: t,
+ v4: true,
+ },
+ }
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester)
defer ep.Close()
if err := ep.Enable(); err != nil {
@@ -590,10 +615,10 @@ func TestIPv4FragmentationReceive(t *testing.T) {
}
// Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
- o.protocol = 10
- o.srcAddr = remoteIpv4Addr
- o.dstAddr = localIpv4Addr
- o.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...)
+ nic.tester.protocol = 10
+ nic.tester.srcAddr = remoteIpv4Addr
+ nic.tester.dstAddr = localIpv4Addr
+ nic.tester.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...)
r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr)
if err != nil {
@@ -608,8 +633,8 @@ func TestIPv4FragmentationReceive(t *testing.T) {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
ep.HandlePacket(&r, pkt)
- if o.dataCalls != 0 {
- t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls)
+ if nic.tester.dataCalls != 0 {
+ t.Fatalf("Bad number of data calls: got %x, want 0", nic.tester.dataCalls)
}
// Send second segment.
@@ -620,16 +645,20 @@ func TestIPv4FragmentationReceive(t *testing.T) {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
ep.HandlePacket(&r, pkt)
- if o.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
+ if nic.tester.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls)
}
}
func TestIPv6Send(t *testing.T) {
- o := testObject{t: t}
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber)
- ep := proto.NewEndpoint(&testInterface{}, nil, nil, &o, channel.New(0, 1280, ""), s)
+ nic := testInterface{
+ tester: testObject{
+ t: t,
+ },
+ }
+ ep := proto.NewEndpoint(&nic, nil, nil, nil)
defer ep.Close()
if err := ep.Enable(); err != nil {
@@ -649,10 +678,10 @@ func TestIPv6Send(t *testing.T) {
})
// Issue the write.
- o.protocol = 123
- o.srcAddr = localIpv6Addr
- o.dstAddr = remoteIpv6Addr
- o.contents = payload
+ nic.tester.protocol = 123
+ nic.tester.srcAddr = localIpv6Addr
+ nic.tester.dstAddr = remoteIpv6Addr
+ nic.tester.contents = payload
r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr)
if err != nil {
@@ -668,10 +697,14 @@ func TestIPv6Send(t *testing.T) {
}
func TestIPv6Receive(t *testing.T) {
- o := testObject{t: t}
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber)
- ep := proto.NewEndpoint(&testInterface{}, nil, nil, &o, nil, s)
+ nic := testInterface{
+ tester: testObject{
+ t: t,
+ },
+ }
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester)
defer ep.Close()
if err := ep.Enable(); err != nil {
@@ -695,10 +728,10 @@ func TestIPv6Receive(t *testing.T) {
}
// Give packet to ipv6 endpoint, dispatcher will validate that it's ok.
- o.protocol = 10
- o.srcAddr = remoteIpv6Addr
- o.dstAddr = localIpv6Addr
- o.contents = view[header.IPv6MinimumSize:totalLen]
+ nic.tester.protocol = 10
+ nic.tester.srcAddr = remoteIpv6Addr
+ nic.tester.dstAddr = localIpv6Addr
+ nic.tester.contents = view[header.IPv6MinimumSize:totalLen]
r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr)
if err != nil {
@@ -712,8 +745,8 @@ func TestIPv6Receive(t *testing.T) {
t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
}
ep.HandlePacket(&r, pkt)
- if o.dataCalls != 1 {
- t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
+ if nic.tester.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %x, want 1", nic.tester.dataCalls)
}
}
@@ -752,10 +785,14 @@ func TestIPv6ReceiveControl(t *testing.T) {
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
- o := testObject{t: t}
s := buildDummyStack(t)
proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber)
- ep := proto.NewEndpoint(&testInterface{}, nil, nil, &o, nil, s)
+ nic := testInterface{
+ tester: testObject{
+ t: t,
+ },
+ }
+ ep := proto.NewEndpoint(&nic, nil, nil, &nic.tester)
defer ep.Close()
if err := ep.Enable(); err != nil {
@@ -814,19 +851,19 @@ func TestIPv6ReceiveControl(t *testing.T) {
// Give packet to IPv6 endpoint, dispatcher will validate that
// it's ok.
- o.protocol = 10
- o.srcAddr = remoteIpv6Addr
- o.dstAddr = localIpv6Addr
- o.contents = view[dataOffset:]
- o.typ = c.expectedTyp
- o.extra = c.expectedExtra
+ nic.tester.protocol = 10
+ nic.tester.srcAddr = remoteIpv6Addr
+ nic.tester.dstAddr = localIpv6Addr
+ nic.tester.contents = view[dataOffset:]
+ nic.tester.typ = c.expectedTyp
+ nic.tester.extra = c.expectedExtra
// Set ICMPv6 checksum.
icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{}))
ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv6MinimumSize))
- if want := c.expectedCount; o.controlCalls != want {
- t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want)
+ if want := c.expectedCount; nic.tester.controlCalls != want {
+ t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.tester.controlCalls, want)
}
})
}
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 3e5cf2ad9..5c4f715d7 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -40,7 +40,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack
// Drop packet if it doesn't have the basic IPv4 header or if the
// original source address doesn't match an address we own.
src := hdr.SourceAddress()
- if e.stack.CheckLocalAddress(e.NICID(), ProtocolNumber, src) == 0 {
+ if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, src) == 0 {
return
}
@@ -110,7 +110,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
localAddr = ""
}
- r, err := r.Stack().FindRoute(e.NICID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
+ r, err := r.Stack().FindRoute(e.nic.ID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
if err != nil {
// If we cannot find a route to the destination, silently drop the packet.
return
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 41f6914b9..7adf0fac3 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -59,7 +59,6 @@ type endpoint struct {
linkEP stack.LinkEndpoint
dispatcher stack.TransportDispatcher
protocol *protocol
- stack *stack.Stack
// enabled is set to 1 when the enpoint is enabled and 0 when it is
// disabled.
@@ -75,13 +74,12 @@ type endpoint struct {
}
// NewEndpoint creates a new ipv4 endpoint.
-func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint {
+func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
e := &endpoint{
nic: nic,
- linkEP: linkEP,
+ linkEP: nic.LinkEndpoint(),
dispatcher: dispatcher,
protocol: p,
- stack: st,
}
e.mu.addressableEndpointState.Init(e)
return e
@@ -173,16 +171,6 @@ func (e *endpoint) MTU() uint32 {
return calculateMTU(e.linkEP.MTU())
}
-// Capabilities implements stack.NetworkEndpoint.
-func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
- return e.linkEP.Capabilities()
-}
-
-// NICID returns the ID of the NIC this endpoint belongs to.
-func (e *endpoint) NICID() tcpip.NICID {
- return e.nic.ID()
-}
-
// MaxHeaderLength returns the maximum length needed by ipv4 headers (and
// underlying protocols).
func (e *endpoint) MaxHeaderLength() uint16 {
@@ -324,8 +312,8 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
// iptables filtering. All packets that reach here are locally
// generated.
- nicName := e.stack.FindNICNameFromID(e.NICID())
- ipt := e.stack.IPTables()
+ nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ ipt := e.protocol.stack.IPTables()
if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok {
// iptables is telling us to drop the packet.
r.Stats().IP.IPTablesOutputDropped.Increment()
@@ -341,7 +329,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
// short circuits broadcasts before they are sent out to other hosts.
if pkt.NatDone {
netHeader := header.IPv4(pkt.NetworkHeader().View())
- ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress())
+ ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress())
if err == nil {
route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
ep.HandlePacket(&route, pkt)
@@ -381,10 +369,10 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
pkt = pkt.Next()
}
- nicName := e.stack.FindNICNameFromID(e.NICID())
+ nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
// iptables filtering. All packets that reach here are locally
// generated.
- ipt := e.stack.IPTables()
+ ipt := e.protocol.stack.IPTables()
dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName)
if len(dropped) == 0 && len(natPkts) == 0 {
// Fast path: If no packets are to be dropped then we can just invoke the
@@ -404,7 +392,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
if _, ok := natPkts[pkt]; ok {
netHeader := header.IPv4(pkt.NetworkHeader().View())
- if ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()); err == nil {
+ if ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()); err == nil {
src := netHeader.SourceAddress()
dst := netHeader.DestinationAddress()
route := r.ReverseRoute(src, dst)
@@ -493,7 +481,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// iptables filtering. All packets that reach here are intended for
// this machine and will not be forwarded.
- ipt := e.stack.IPTables()
+ ipt := e.protocol.stack.IPTables()
if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok {
// iptables is telling us to drop the packet.
r.Stats().IP.IPTablesInputDropped.Increment()
@@ -677,6 +665,8 @@ var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
var _ stack.NetworkProtocol = (*protocol)(nil)
type protocol struct {
+ stack *stack.Stack
+
// defaultTTL is the current default TTL for the protocol. Only the
// uint8 portion of it is meaningful.
//
@@ -799,7 +789,7 @@ func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV ui
}
// NewProtocol returns an IPv4 network protocol.
-func NewProtocol(*stack.Stack) stack.NetworkProtocol {
+func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
ids := make([]uint32, buckets)
// Randomly initialize hashIV and the ids.
@@ -810,6 +800,7 @@ func NewProtocol(*stack.Stack) stack.NetworkProtocol {
hashIV := r[buckets]
return &protocol{
+ stack: s,
ids: ids,
hashIV: hashIV,
defaultTTL: DefaultTTL,
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 270439b5c..4b4b483cc 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -41,7 +41,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack
// Drop packet if it doesn't have the basic IPv6 header or if the
// original source address doesn't match an address we own.
src := hdr.SourceAddress()
- if e.stack.CheckLocalAddress(e.NICID(), ProtocolNumber, src) == 0 {
+ if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, src) == 0 {
return
}
@@ -248,7 +248,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// section 5.4.3.
// Is the NS targeting us?
- if r.Stack().CheckLocalAddress(e.NICID(), ProtocolNumber, targetAddr) == 0 {
+ if r.Stack().CheckLocalAddress(e.nic.ID(), ProtocolNumber, targetAddr) == 0 {
return
}
@@ -283,7 +283,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
} else if e.nud != nil {
e.nud.HandleProbe(r.RemoteAddress, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol)
} else {
- e.linkAddrCache.AddLinkAddress(e.NICID(), r.RemoteAddress, sourceLinkAddr)
+ e.linkAddrCache.AddLinkAddress(e.nic.ID(), r.RemoteAddress, sourceLinkAddr)
}
// ICMPv6 Neighbor Solicit messages are always sent to
@@ -410,7 +410,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// address cache with the link address for the target of the message.
if len(targetLinkAddr) != 0 {
if e.nud == nil {
- e.linkAddrCache.AddLinkAddress(e.NICID(), targetAddr, targetLinkAddr)
+ e.linkAddrCache.AddLinkAddress(e.nic.ID(), targetAddr, targetLinkAddr)
return
}
@@ -438,7 +438,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
localAddr = ""
}
- r, err := r.Stack().FindRoute(e.NICID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
+ r, err := r.Stack().FindRoute(e.nic.ID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
if err != nil {
// If we cannot find a route to the destination, silently drop the packet.
return
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index b4e8a077f..5472ceb46 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -123,6 +123,10 @@ func (*testInterface) Enabled() bool {
return true
}
+func (*testInterface) LinkEndpoint() stack.LinkEndpoint {
+ return nil
+}
+
func TestICMPCounts(t *testing.T) {
tests := []struct {
name string
@@ -170,7 +174,7 @@ func TestICMPCounts(t *testing.T) {
if netProto == nil {
t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
}
- ep := netProto.NewEndpoint(&testInterface{}, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{}, nil, s)
+ ep := netProto.NewEndpoint(&testInterface{}, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{})
defer ep.Close()
if err := ep.Enable(); err != nil {
@@ -312,7 +316,7 @@ func TestICMPCountsWithNeighborCache(t *testing.T) {
if netProto == nil {
t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
}
- ep := netProto.NewEndpoint(&testInterface{}, nil, &stubNUDHandler{}, &stubDispatcher{}, nil, s)
+ ep := netProto.NewEndpoint(&testInterface{}, nil, &stubNUDHandler{}, &stubDispatcher{})
defer ep.Close()
if err := ep.Enable(); err != nil {
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 75b27a4cf..d1ad7acb7 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -351,16 +351,6 @@ func (e *endpoint) MTU() uint32 {
return calculateMTU(e.linkEP.MTU())
}
-// NICID returns the ID of the NIC this endpoint belongs to.
-func (e *endpoint) NICID() tcpip.NICID {
- return e.nic.ID()
-}
-
-// Capabilities implements stack.NetworkEndpoint.
-func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
- return e.linkEP.Capabilities()
-}
-
// MaxHeaderLength returns the maximum length needed by ipv6 headers (and
// underlying protocols).
func (e *endpoint) MaxHeaderLength() uint16 {
@@ -395,8 +385,8 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
// iptables filtering. All packets that reach here are locally
// generated.
- nicName := e.stack.FindNICNameFromID(e.NICID())
- ipt := e.stack.IPTables()
+ nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ ipt := e.protocol.stack.IPTables()
if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok {
// iptables is telling us to drop the packet.
r.Stats().IP.IPTablesOutputDropped.Increment()
@@ -412,7 +402,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
// short circuits broadcasts before they are sent out to other hosts.
if pkt.NatDone {
netHeader := header.IPv6(pkt.NetworkHeader().View())
- if ep, err := e.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil {
+ if ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil {
route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
ep.HandlePacket(&route, pkt)
return nil
@@ -455,8 +445,8 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
// iptables filtering. All packets that reach here are locally
// generated.
- nicName := e.stack.FindNICNameFromID(e.NICID())
- ipt := e.stack.IPTables()
+ nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ ipt := e.protocol.stack.IPTables()
dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName)
if len(dropped) == 0 && len(natPkts) == 0 {
// Fast path: If no packets are to be dropped then we can just invoke the
@@ -476,7 +466,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
if _, ok := natPkts[pkt]; ok {
netHeader := header.IPv6(pkt.NetworkHeader().View())
- if ep, err := e.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil {
+ if ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil {
src := netHeader.SourceAddress()
dst := netHeader.DestinationAddress()
route := r.ReverseRoute(src, dst)
@@ -531,7 +521,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// iptables filtering. All packets that reach here are intended for
// this machine and need not be forwarded.
- ipt := e.stack.IPTables()
+ ipt := e.protocol.stack.IPTables()
if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok {
// iptables is telling us to drop the packet.
r.Stats().IP.IPTablesInputDropped.Increment()
@@ -1084,6 +1074,8 @@ var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
var _ stack.NetworkProtocol = (*protocol)(nil)
type protocol struct {
+ stack *stack.Stack
+
mu struct {
sync.RWMutex
@@ -1147,15 +1139,14 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
}
// NewEndpoint creates a new ipv6 endpoint.
-func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint {
+func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
e := &endpoint{
nic: nic,
- linkEP: linkEP,
+ linkEP: nic.LinkEndpoint(),
linkAddrCache: linkAddrCache,
nud: nud,
dispatcher: dispatcher,
protocol: p,
- stack: st,
}
e.mu.addressableEndpointState.Init(e)
e.mu.ndp = ndpState{
@@ -1312,8 +1303,9 @@ type Options struct {
func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
opts.NDPConfigs.validate()
- return func(*stack.Stack) stack.NetworkProtocol {
+ return func(s *stack.Stack) stack.NetworkProtocol {
p := &protocol{
+ stack: s,
fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
ndpDisp: opts.NDPDisp,
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index 3495a8b19..d85b5c00f 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -1901,7 +1901,7 @@ func TestClearEndpointFromProtocolOnClose(t *testing.T) {
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
})
proto := s.NetworkProtocolInstance(ProtocolNumber).(*protocol)
- ep := proto.NewEndpoint(&testInterface{}, nil, nil, nil, nil, nil).(*endpoint)
+ ep := proto.NewEndpoint(&testInterface{}, nil, nil, nil).(*endpoint)
{
proto.mu.Lock()
_, hasEP := proto.mu.eps[ep]
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index 1b5c61b80..84c082852 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -627,7 +627,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
if addressEndpoint.GetKind() != stack.PermanentTentative {
// The endpoint should be marked as tentative since we are starting DAD.
- panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.ep.NICID()))
+ panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.ep.nic.ID()))
}
// Should not attempt to perform DAD on an address that is currently in the
@@ -639,7 +639,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
// address, or its reference count would have been increased without doing
// the work that would have been done for an address that was brand new.
// See endpoint.addAddressLocked.
- panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.ep.NICID()))
+ panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.ep.nic.ID()))
}
remaining := ndp.configs.DupAddrDetectTransmits
@@ -649,7 +649,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
// Consider DAD to have resolved even if no DAD messages were actually
// transmitted.
if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
- ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.NICID(), addr, true, nil)
+ ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, true, nil)
}
return nil
@@ -661,7 +661,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
// cannot be done while holding the IPv6 endpoint's lock. This is effectively
// the same as starting a goroutine but we use a timer that fires immediately
// so we can reset it for the next DAD iteration.
- timer = ndp.ep.stack.Clock().AfterFunc(0, func() {
+ timer = ndp.ep.protocol.stack.Clock().AfterFunc(0, func() {
ndp.ep.mu.Lock()
defer ndp.ep.mu.Unlock()
@@ -676,7 +676,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
if addressEndpoint.GetKind() != stack.PermanentTentative {
// The endpoint should still be marked as tentative since we are still
// performing DAD on it.
- panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.ep.NICID()))
+ panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.ep.nic.ID()))
}
dadDone := remaining == 0
@@ -721,7 +721,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
delete(ndp.dad, addr)
if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
- ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.NICID(), addr, dadDone, err)
+ ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, dadDone, err)
}
// If DAD resolved for a stable SLAAC address, attempt generation of a
@@ -750,7 +750,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) *tcpip.Error {
snmc := header.SolicitedNodeAddr(addr)
- r, err := ndp.ep.stack.FindRoute(ndp.ep.NICID(), header.IPv6Any, snmc, ProtocolNumber, false /* multicastLoop */)
+ r, err := ndp.ep.protocol.stack.FindRoute(ndp.ep.nic.ID(), header.IPv6Any, snmc, ProtocolNumber, false /* multicastLoop */)
if err != nil {
return err
}
@@ -766,9 +766,9 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.Add
return err
}
- panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.ep.NICID(), err))
+ panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.ep.nic.ID(), err))
} else if c != nil {
- panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.ep.NICID()))
+ panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.ep.nic.ID()))
}
icmpData := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize))
@@ -824,7 +824,7 @@ func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) {
// Let the integrator know DAD did not resolve.
if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
- ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.NICID(), addr, false, nil)
+ ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, false, nil)
}
}
@@ -861,7 +861,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
if ndp.dhcpv6Configuration != configuration {
ndp.dhcpv6Configuration = configuration
- ndpDisp.OnDHCPv6Configuration(ndp.ep.NICID(), configuration)
+ ndpDisp.OnDHCPv6Configuration(ndp.ep.nic.ID(), configuration)
}
}
@@ -908,7 +908,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
}
addrs, _ := opt.Addresses()
- ndp.ep.protocol.ndpDisp.OnRecursiveDNSServerOption(ndp.ep.NICID(), addrs, opt.Lifetime())
+ ndp.ep.protocol.ndpDisp.OnRecursiveDNSServerOption(ndp.ep.nic.ID(), addrs, opt.Lifetime())
case header.NDPDNSSearchList:
if ndp.ep.protocol.ndpDisp == nil {
@@ -916,7 +916,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
}
domainNames, _ := opt.DomainNames()
- ndp.ep.protocol.ndpDisp.OnDNSSearchListOption(ndp.ep.NICID(), domainNames, opt.Lifetime())
+ ndp.ep.protocol.ndpDisp.OnDNSSearchListOption(ndp.ep.nic.ID(), domainNames, opt.Lifetime())
case header.NDPPrefixInformation:
prefix := opt.Subnet()
@@ -965,7 +965,7 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) {
// Let the integrator know a discovered default router is invalidated.
if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
- ndpDisp.OnDefaultRouterInvalidated(ndp.ep.NICID(), ip)
+ ndpDisp.OnDefaultRouterInvalidated(ndp.ep.nic.ID(), ip)
}
}
@@ -982,14 +982,14 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) {
}
// Inform the integrator when we discovered a default router.
- if !ndpDisp.OnDefaultRouterDiscovered(ndp.ep.NICID(), ip) {
+ if !ndpDisp.OnDefaultRouterDiscovered(ndp.ep.nic.ID(), ip) {
// Informed by the integrator to not remember the router, do
// nothing further.
return
}
state := defaultRouterState{
- invalidationJob: ndp.ep.stack.NewJob(&ndp.ep.mu, func() {
+ invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
ndp.invalidateDefaultRouter(ip)
}),
}
@@ -1012,14 +1012,14 @@ func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration)
}
// Inform the integrator when we discovered an on-link prefix.
- if !ndpDisp.OnOnLinkPrefixDiscovered(ndp.ep.NICID(), prefix) {
+ if !ndpDisp.OnOnLinkPrefixDiscovered(ndp.ep.nic.ID(), prefix) {
// Informed by the integrator to not remember the prefix, do
// nothing further.
return
}
state := onLinkPrefixState{
- invalidationJob: ndp.ep.stack.NewJob(&ndp.ep.mu, func() {
+ invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
ndp.invalidateOnLinkPrefix(prefix)
}),
}
@@ -1048,7 +1048,7 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) {
// Let the integrator know a discovered on-link prefix is invalidated.
if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
- ndpDisp.OnOnLinkPrefixInvalidated(ndp.ep.NICID(), prefix)
+ ndpDisp.OnOnLinkPrefixInvalidated(ndp.ep.nic.ID(), prefix)
}
}
@@ -1164,7 +1164,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
}
state := slaacPrefixState{
- deprecationJob: ndp.ep.stack.NewJob(&ndp.ep.mu, func() {
+ deprecationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
state, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the deprecated SLAAC prefix %s", prefix))
@@ -1172,7 +1172,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
ndp.deprecateSLAACAddress(state.stableAddr.addressEndpoint)
}),
- invalidationJob: ndp.ep.stack.NewJob(&ndp.ep.mu, func() {
+ invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
state, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the invalidated SLAAC prefix %s", prefix))
@@ -1230,7 +1230,7 @@ func (ndp *ndpState) addAndAcquireSLAACAddr(addr tcpip.AddressWithPrefix, config
return nil
}
- if !ndpDisp.OnAutoGenAddress(ndp.ep.NICID(), addr) {
+ if !ndpDisp.OnAutoGenAddress(ndp.ep.nic.ID(), addr) {
// Informed by the integrator not to add the address.
return nil
}
@@ -1276,7 +1276,7 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt
addrBytes = header.AppendOpaqueInterfaceIdentifier(
addrBytes[:header.IIDOffsetInIPv6Address],
prefix,
- oIID.NICNameFromID(ndp.ep.NICID(), ndp.ep.nic.Name()),
+ oIID.NICNameFromID(ndp.ep.nic.ID(), ndp.ep.nic.Name()),
dadCounter,
oIID.SecretKey,
)
@@ -1433,7 +1433,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
}
state := tempSLAACAddrState{
- deprecationJob: ndp.ep.stack.NewJob(&ndp.ep.mu, func() {
+ deprecationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
prefixState, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to deprecate temporary address %s", prefix, generatedAddr))
@@ -1446,7 +1446,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
ndp.deprecateSLAACAddress(tempAddrState.addressEndpoint)
}),
- invalidationJob: ndp.ep.stack.NewJob(&ndp.ep.mu, func() {
+ invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
prefixState, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to invalidate temporary address %s", prefix, generatedAddr))
@@ -1459,7 +1459,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, generatedAddr.Address, tempAddrState)
}),
- regenJob: ndp.ep.stack.NewJob(&ndp.ep.mu, func() {
+ regenJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
prefixState, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to regenerate temporary address after %s", prefix, generatedAddr))
@@ -1677,7 +1677,7 @@ func (ndp *ndpState) deprecateSLAACAddress(addressEndpoint stack.AddressEndpoint
addressEndpoint.SetDeprecated(true)
if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
- ndpDisp.OnAutoGenAddressDeprecated(ndp.ep.NICID(), addressEndpoint.AddressWithPrefix())
+ ndpDisp.OnAutoGenAddressDeprecated(ndp.ep.nic.ID(), addressEndpoint.AddressWithPrefix())
}
}
@@ -1702,7 +1702,7 @@ func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, state slaacPrefi
// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidatePrefix bool) {
if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
- ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.NICID(), addr)
+ ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addr)
}
prefix := addr.Subnet()
@@ -1762,7 +1762,7 @@ func (ndp *ndpState) invalidateTempSLAACAddr(tempAddrs map[tcpip.Address]tempSLA
// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidateAddr bool) {
if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
- ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.NICID(), addr)
+ ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addr)
}
if !invalidateAddr {
@@ -1878,7 +1878,7 @@ func (ndp *ndpState) startSolicitingRouters() {
var done bool
ndp.rtrSolicit.done = &done
- ndp.rtrSolicit.timer = ndp.ep.stack.Clock().AfterFunc(delay, func() {
+ ndp.rtrSolicit.timer = ndp.ep.protocol.stack.Clock().AfterFunc(delay, func() {
ndp.ep.mu.Lock()
if done {
// If we reach this point, it means that the RS timer fired after another
@@ -1904,7 +1904,7 @@ func (ndp *ndpState) startSolicitingRouters() {
ndp.ep.mu.Unlock()
localAddr := addressEndpoint.AddressWithPrefix().Address
- r, err := ndp.ep.stack.FindRoute(ndp.ep.NICID(), localAddr, header.IPv6AllRoutersMulticastAddress, ProtocolNumber, false /* multicastLoop */)
+ r, err := ndp.ep.protocol.stack.FindRoute(ndp.ep.nic.ID(), localAddr, header.IPv6AllRoutersMulticastAddress, ProtocolNumber, false /* multicastLoop */)
addressEndpoint.DecRef()
if err != nil {
return
@@ -1923,9 +1923,9 @@ func (ndp *ndpState) startSolicitingRouters() {
return
}
- panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.NICID(), err))
+ panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.nic.ID(), err))
} else if c != nil {
- panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.NICID()))
+ panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.nic.ID()))
}
// As per RFC 4861 section 4.1, an NDP RS SHOULD include the source
@@ -1961,7 +1961,7 @@ func (ndp *ndpState) startSolicitingRouters() {
}, pkt,
); err != nil {
sent.Dropped.Increment()
- log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.ep.NICID(), err)
+ log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.ep.nic.ID(), err)
// Don't send any more messages if we had an error.
remaining = 0
} else {
@@ -2005,7 +2005,7 @@ func (ndp *ndpState) stopSolicitingRouters() {
// initializeTempAddrState initializes state related to temporary SLAAC
// addresses.
func (ndp *ndpState) initializeTempAddrState() {
- header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.tempIIDSeed, ndp.ep.NICID())
+ header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.tempIIDSeed, ndp.ep.nic.ID())
if MaxDesyncFactor != 0 {
ndp.temporaryAddressDesyncFactor = time.Duration(rand.Int63n(int64(MaxDesyncFactor)))
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 1947468fd..25464a03a 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -66,10 +66,11 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeig
t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
}
- ep := netProto.NewEndpoint(&testInterface{}, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{}, nil, s)
+ ep := netProto.NewEndpoint(&testInterface{}, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{})
if err := ep.Enable(); err != nil {
t.Fatalf("ep.Enable(): %s", err)
}
+ t.Cleanup(ep.Close)
return s, ep
}
diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go
index 572a2c3b6..4e4b00a92 100644
--- a/pkg/tcpip/stack/forwarder_test.go
+++ b/pkg/tcpip/stack/forwarder_test.go
@@ -70,10 +70,6 @@ func (f *fwdTestNetworkEndpoint) MTU() uint32 {
return f.ep.MTU() - uint32(f.MaxHeaderLength())
}
-func (f *fwdTestNetworkEndpoint) NICID() tcpip.NICID {
- return f.nicID
-}
-
func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 {
return 123
}
@@ -91,10 +87,6 @@ func (f *fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportPr
return 0
}
-func (f *fwdTestNetworkEndpoint) Capabilities() LinkEndpointCapabilities {
- return f.ep.Capabilities()
-}
-
func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return f.proto.Number()
}
@@ -165,12 +157,12 @@ func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocol
return tcpip.TransportProtocolNumber(netHeader[protocolNumberOffset]), true, true
}
-func (f *fwdTestNetworkProtocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) NetworkEndpoint {
+func (f *fwdTestNetworkProtocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, dispatcher TransportDispatcher) NetworkEndpoint {
e := &fwdTestNetworkEndpoint{
nicID: nic.ID(),
proto: f,
dispatcher: dispatcher,
- ep: ep,
+ ep: nic.LinkEndpoint(),
}
e.AddressableEndpointState.Init(e)
return e
diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go
index dcc8363b2..a265fff0a 100644
--- a/pkg/tcpip/stack/neighbor_entry_test.go
+++ b/pkg/tcpip/stack/neighbor_entry_test.go
@@ -235,7 +235,7 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e
},
}
nic.networkEndpoints = map[tcpip.NetworkProtocolNumber]NetworkEndpoint{
- header.IPv6ProtocolNumber: (&testIPv6Protocol{}).NewEndpoint(&nic, nil, nil, nil, nil, nil),
+ header.IPv6ProtocolNumber: (&testIPv6Protocol{}).NewEndpoint(&nic, nil, nil, nil),
}
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 926ce9cfc..212c6edae 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -124,7 +124,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
for _, netProto := range stack.networkProtocols {
netNum := netProto.Number()
nic.mu.packetEPs[netNum] = nil
- nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic, ep, stack)
+ nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic)
}
nic.linkEP.Attach(nic)
@@ -796,7 +796,7 @@ func (n *NIC) Name() string {
return n.name
}
-// LinkEndpoint returns the link endpoint of n.
+// LinkEndpoint implements NetworkInterface.
func (n *NIC) LinkEndpoint() LinkEndpoint {
return n.linkEP
}
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index df516aad7..fdd49b77f 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -60,11 +60,6 @@ func (e *testIPv6Endpoint) MTU() uint32 {
return e.linkEP.MTU() - header.IPv6MinimumSize
}
-// Capabilities implements NetworkEndpoint.Capabilities.
-func (e *testIPv6Endpoint) Capabilities() LinkEndpointCapabilities {
- return e.linkEP.Capabilities()
-}
-
// MaxHeaderLength implements NetworkEndpoint.MaxHeaderLength.
func (e *testIPv6Endpoint) MaxHeaderLength() uint16 {
return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize
@@ -88,11 +83,6 @@ func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) *tcpip
return tcpip.ErrNotSupported
}
-// NICID implements NetworkEndpoint.NICID.
-func (e *testIPv6Endpoint) NICID() tcpip.NICID {
- return e.nicID
-}
-
// HandlePacket implements NetworkEndpoint.HandlePacket.
func (*testIPv6Endpoint) HandlePacket(*Route, *PacketBuffer) {
}
@@ -142,10 +132,10 @@ func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address)
}
// NewEndpoint implements NetworkProtocol.NewEndpoint.
-func (p *testIPv6Protocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, _ TransportDispatcher, linkEP LinkEndpoint, _ *Stack) NetworkEndpoint {
+func (p *testIPv6Protocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, _ TransportDispatcher) NetworkEndpoint {
e := &testIPv6Endpoint{
nicID: nic.ID(),
- linkEP: linkEP,
+ linkEP: nic.LinkEndpoint(),
protocol: p,
}
e.AddressableEndpointState.Init(e)
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index ef42fd6e1..567e1904e 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -154,10 +154,10 @@ type TransportProtocol interface {
Number() tcpip.TransportProtocolNumber
// NewEndpoint creates a new endpoint of the transport protocol.
- NewEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
+ NewEndpoint(netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
// NewRawEndpoint creates a new raw endpoint of the transport protocol.
- NewRawEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
+ NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
// MinimumPacketSize returns the minimum valid packet size of this
// transport protocol. The stack automatically drops any packets smaller
@@ -485,6 +485,9 @@ type NetworkInterface interface {
// Enabled returns true if the interface is enabled.
Enabled() bool
+
+ // LinkEndpoint returns the link endpoint backing the interface.
+ LinkEndpoint() LinkEndpoint
}
// NetworkEndpoint is the interface that needs to be implemented by endpoints
@@ -515,10 +518,6 @@ type NetworkEndpoint interface {
// minus the network endpoint max header length.
MTU() uint32
- // Capabilities returns the set of capabilities supported by the
- // underlying link-layer endpoint.
- Capabilities() LinkEndpointCapabilities
-
// MaxHeaderLength returns the maximum size the network (and lower
// level layers combined) headers can have. Higher levels use this
// information to reserve space in the front of the packets they're
@@ -539,9 +538,6 @@ type NetworkEndpoint interface {
// header to the given destination address. It takes ownership of pkt.
WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error
- // NICID returns the id of the NIC this endpoint belongs to.
- NICID() tcpip.NICID
-
// HandlePacket is called by the link layer when new packets arrive to
// this network endpoint. It sets pkt.NetworkHeader.
//
@@ -586,7 +582,7 @@ type NetworkProtocol interface {
ParseAddresses(v buffer.View) (src, dst tcpip.Address)
// NewEndpoint creates a new endpoint of this protocol.
- NewEndpoint(nic NetworkInterface, linkAddrCache LinkAddressCache, nud NUDHandler, dispatcher TransportDispatcher, sender LinkEndpoint, st *Stack) NetworkEndpoint
+ NewEndpoint(nic NetworkInterface, linkAddrCache LinkAddressCache, nud NUDHandler, dispatcher TransportDispatcher) NetworkEndpoint
// SetOption allows enabling/disabling protocol specific features.
// SetOption returns an error if the option is not supported or the
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 1b008a067..5ade3c832 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -72,17 +72,18 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip
loop |= PacketLoop
}
+ linkEP := nic.LinkEndpoint()
r := Route{
NetProto: netProto,
LocalAddress: localAddr,
- LocalLinkAddress: nic.linkEP.LinkAddress(),
+ LocalLinkAddress: linkEP.LinkAddress(),
RemoteAddress: remoteAddr,
addressEndpoint: addressEndpoint,
nic: nic,
Loop: loop,
}
- if nic := r.nic; nic.linkEP.Capabilities()&CapabilityResolutionRequired != 0 {
+ if nic := r.nic; linkEP.Capabilities()&CapabilityResolutionRequired != 0 {
if linkRes, ok := nic.stack.linkAddrResolvers[r.NetProto]; ok {
r.linkRes = linkRes
r.linkCache = nic.stack
@@ -94,7 +95,7 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip
// NICID returns the id of the NIC from which this route originates.
func (r *Route) NICID() tcpip.NICID {
- return r.addressEndpoint.NetworkEndpoint().NICID()
+ return r.nic.ID()
}
// MaxHeaderLength forwards the call to the network endpoint's implementation.
@@ -115,7 +116,7 @@ func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, tot
// Capabilities returns the link-layer capabilities of the route.
func (r *Route) Capabilities() LinkEndpointCapabilities {
- return r.addressEndpoint.NetworkEndpoint().Capabilities()
+ return r.nic.LinkEndpoint().Capabilities()
}
// GSOMaxSize returns the maximum GSO packet size.
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index b740aa305..57d8e79e0 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -837,7 +837,7 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp
return nil, tcpip.ErrUnknownProtocol
}
- return t.proto.NewEndpoint(s, network, waiterQueue)
+ return t.proto.NewEndpoint(network, waiterQueue)
}
// NewRawEndpoint creates a new raw transport layer endpoint of the given
@@ -857,7 +857,7 @@ func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network
return nil, tcpip.ErrUnknownProtocol
}
- return t.proto.NewRawEndpoint(s, network, waiterQueue)
+ return t.proto.NewRawEndpoint(network, waiterQueue)
}
// NewPacketEndpoint creates a new packet endpoint listening for the given
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 7589306a4..fda22c550 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -91,10 +91,6 @@ func (f *fakeNetworkEndpoint) MTU() uint32 {
return f.ep.MTU() - uint32(f.MaxHeaderLength())
}
-func (f *fakeNetworkEndpoint) NICID() tcpip.NICID {
- return f.nicID
-}
-
func (*fakeNetworkEndpoint) DefaultTTL() uint8 {
return 123
}
@@ -131,10 +127,6 @@ func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProto
return 0
}
-func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
- return f.ep.Capabilities()
-}
-
func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
return f.proto.Number()
}
@@ -207,12 +199,12 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres
return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1])
}
-func (f *fakeNetworkProtocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint, _ *stack.Stack) stack.NetworkEndpoint {
+func (f *fakeNetworkProtocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
e := &fakeNetworkEndpoint{
nicID: nic.ID(),
proto: f,
dispatcher: dispatcher,
- ep: ep,
+ ep: nic.LinkEndpoint(),
}
e.AddressableEndpointState.Init(e)
return e
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 8aae60740..62ab6d92f 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -39,7 +39,7 @@ const (
// use it.
type fakeTransportEndpoint struct {
stack.TransportEndpointInfo
- stack *stack.Stack
+
proto *fakeTransportProtocol
peerAddr tcpip.Address
route stack.Route
@@ -59,8 +59,8 @@ func (*fakeTransportEndpoint) Stats() tcpip.EndpointStats {
func (*fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {}
-func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint {
- return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
+func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint {
+ return &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
}
func (f *fakeTransportEndpoint) Abort() {
@@ -143,7 +143,7 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
f.peerAddr = addr.Addr
// Find the route.
- r, err := f.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber, false /* multicastLoop */)
+ r, err := f.proto.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber, false /* multicastLoop */)
if err != nil {
return tcpip.ErrNoRoute
}
@@ -151,7 +151,7 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// Try to register so that we can start receiving packets.
f.ID.RemoteAddress = addr.Addr
- err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */)
+ err = f.proto.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */)
if err != nil {
return err
}
@@ -190,7 +190,7 @@ func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *wai
}
func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error {
- if err := f.stack.RegisterTransportEndpoint(
+ if err := f.proto.stack.RegisterTransportEndpoint(
a.NIC,
[]tcpip.NetworkProtocolNumber{fakeNetNumber},
fakeTransNumber,
@@ -218,7 +218,6 @@ func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportE
f.proto.packetCount++
if f.acceptQueue != nil {
f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{
- stack: f.stack,
TransportEndpointInfo: stack.TransportEndpointInfo{
ID: f.ID,
NetProto: f.NetProto,
@@ -262,6 +261,8 @@ type fakeTransportProtocolOptions struct {
// fakeTransportProtocol is a transport-layer protocol descriptor. It
// aggregates the number of packets received via endpoints of this protocol.
type fakeTransportProtocol struct {
+ stack *stack.Stack
+
packetCount int
controlCount int
opts fakeTransportProtocolOptions
@@ -271,11 +272,11 @@ func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber {
return fakeTransNumber
}
-func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return newFakeTransportEndpoint(stack, f, netProto, stack.UniqueID()), nil
+func (f *fakeTransportProtocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return newFakeTransportEndpoint(f, netProto, f.stack.UniqueID()), nil
}
-func (*fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (*fakeTransportProtocol) NewRawEndpoint(tcpip.NetworkProtocolNumber, *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
return nil, tcpip.ErrUnknownProtocol
}
@@ -326,8 +327,8 @@ func (*fakeTransportProtocol) Parse(pkt *stack.PacketBuffer) bool {
return ok
}
-func fakeTransFactory(*stack.Stack) stack.TransportProtocol {
- return &fakeTransportProtocol{}
+func fakeTransFactory(s *stack.Stack) stack.TransportProtocol {
+ return &fakeTransportProtocol{stack: s}
}
func TestTransportReceive(t *testing.T) {
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
index 7484f4ad9..87d510f96 100644
--- a/pkg/tcpip/transport/icmp/protocol.go
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -37,6 +37,8 @@ const (
// protocol implements stack.TransportProtocol.
type protocol struct {
+ stack *stack.Stack
+
number tcpip.TransportProtocolNumber
}
@@ -57,20 +59,20 @@ func (p *protocol) netProto() tcpip.NetworkProtocolNumber {
// NewEndpoint creates a new icmp endpoint. It implements
// stack.TransportProtocol.NewEndpoint.
-func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
if netProto != p.netProto() {
return nil, tcpip.ErrUnknownProtocol
}
- return newEndpoint(stack, netProto, p.number, waiterQueue)
+ return newEndpoint(p.stack, netProto, p.number, waiterQueue)
}
// NewRawEndpoint creates a new raw icmp endpoint. It implements
// stack.TransportProtocol.NewRawEndpoint.
-func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
if netProto != p.netProto() {
return nil, tcpip.ErrUnknownProtocol
}
- return raw.NewEndpoint(stack, netProto, p.number, waiterQueue)
+ return raw.NewEndpoint(p.stack, netProto, p.number, waiterQueue)
}
// MinimumPacketSize returns the minimum valid icmp packet size.
@@ -130,11 +132,11 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
}
// NewProtocol4 returns an ICMPv4 transport protocol.
-func NewProtocol4(*stack.Stack) stack.TransportProtocol {
- return &protocol{ProtocolNumber4}
+func NewProtocol4(s *stack.Stack) stack.TransportProtocol {
+ return &protocol{stack: s, number: ProtocolNumber4}
}
// NewProtocol6 returns an ICMPv6 transport protocol.
-func NewProtocol6(*stack.Stack) stack.TransportProtocol {
- return &protocol{ProtocolNumber6}
+func NewProtocol6(s *stack.Stack) stack.TransportProtocol {
+ return &protocol{stack: s, number: ProtocolNumber6}
}
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 6a3c2c32b..5bce73605 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -133,6 +133,8 @@ func (s *synRcvdCounter) Threshold() uint64 {
}
type protocol struct {
+ stack *stack.Stack
+
mu sync.RWMutex
sackEnabled bool
recovery tcpip.TCPRecovery
@@ -159,14 +161,14 @@ func (*protocol) Number() tcpip.TransportProtocolNumber {
}
// NewEndpoint creates a new tcp endpoint.
-func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return newEndpoint(stack, netProto, waiterQueue), nil
+func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return newEndpoint(p.stack, netProto, waiterQueue), nil
}
// NewRawEndpoint creates a new raw TCP endpoint. Raw TCP sockets are currently
// unsupported. It implements stack.TransportProtocol.NewRawEndpoint.
-func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return raw.NewEndpoint(stack, netProto, header.TCPProtocolNumber, waiterQueue)
+func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return raw.NewEndpoint(p.stack, netProto, header.TCPProtocolNumber, waiterQueue)
}
// MinimumPacketSize returns the minimum valid tcp packet size.
@@ -505,8 +507,9 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
}
// NewProtocol returns a TCP transport protocol.
-func NewProtocol(*stack.Stack) stack.TransportProtocol {
+func NewProtocol(s *stack.Stack) stack.TransportProtocol {
p := protocol{
+ stack: s,
sendBufferSize: tcpip.TCPSendBufferSizeRangeOption{
Min: MinBufferSize,
Default: DefaultSendBufferSize,
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index e6fc23258..da5b1deb2 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -45,6 +45,7 @@ const (
)
type protocol struct {
+ stack *stack.Stack
}
// Number returns the udp protocol number.
@@ -53,14 +54,14 @@ func (*protocol) Number() tcpip.TransportProtocolNumber {
}
// NewEndpoint creates a new udp endpoint.
-func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return newEndpoint(stack, netProto, waiterQueue), nil
+func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return newEndpoint(p.stack, netProto, waiterQueue), nil
}
// NewRawEndpoint creates a new raw UDP endpoint. It implements
// stack.TransportProtocol.NewRawEndpoint.
-func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return raw.NewEndpoint(stack, netProto, header.UDPProtocolNumber, waiterQueue)
+func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return raw.NewEndpoint(p.stack, netProto, header.UDPProtocolNumber, waiterQueue)
}
// MinimumPacketSize returns the minimum valid udp packet size.
@@ -114,6 +115,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
}
// NewProtocol returns a UDP transport protocol.
-func NewProtocol(*stack.Stack) stack.TransportProtocol {
- return &protocol{}
+func NewProtocol(s *stack.Stack) stack.TransportProtocol {
+ return &protocol{stack: s}
}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 22a809efa..7aaedb708 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -1486,6 +1486,10 @@ func (*testInterface) Enabled() bool {
return true
}
+func (*testInterface) LinkEndpoint() stack.LinkEndpoint {
+ return nil
+}
+
func TestTTL(t *testing.T) {
for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
@@ -1509,10 +1513,7 @@ func TestTTL(t *testing.T) {
} else {
p = ipv6.NewProtocol(nil)
}
- ep := p.NewEndpoint(&testInterface{}, nil, nil, nil, nil, stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
- TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
- }))
+ ep := p.NewEndpoint(&testInterface{}, nil, nil, nil)
wantTTL = ep.DefaultTTL()
ep.Close()
}