summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/tcpip/network/arp/arp.go20
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go4
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go35
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go10
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go34
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go70
-rw-r--r--pkg/tcpip/stack/nic.go4
-rw-r--r--pkg/tcpip/stack/registration.go16
-rw-r--r--pkg/tcpip/stack/route.go9
-rw-r--r--pkg/tcpip/stack/stack.go4
-rw-r--r--pkg/tcpip/transport/icmp/protocol.go18
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go13
-rw-r--r--pkg/tcpip/transport/udp/protocol.go13
13 files changed, 114 insertions, 136 deletions
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 4bb7a417c..b47a7be51 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -96,14 +96,6 @@ func (e *endpoint) MTU() uint32 {
return lmtu - uint32(e.MaxHeaderLength())
}
-func (e *endpoint) NICID() tcpip.NICID {
- return e.nic.ID()
-}
-
-func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
- return e.linkEP.Capabilities()
-}
-
func (e *endpoint) MaxHeaderLength() uint16 {
return e.linkEP.MaxHeaderLength() + header.ARPSize
}
@@ -145,15 +137,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
localAddr := tcpip.Address(h.ProtocolAddressTarget())
if e.nud == nil {
- if e.linkAddrCache.CheckLocalAddress(e.NICID(), header.IPv4ProtocolNumber, localAddr) == 0 {
+ if e.linkAddrCache.CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 {
return // we have no useful answer, ignore the request
}
addr := tcpip.Address(h.ProtocolAddressSender())
linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
- e.linkAddrCache.AddLinkAddress(e.NICID(), addr, linkAddr)
+ e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr)
} else {
- if r.Stack().CheckLocalAddress(e.NICID(), header.IPv4ProtocolNumber, localAddr) == 0 {
+ if r.Stack().CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 {
return // we have no useful answer, ignore the request
}
@@ -179,7 +171,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
if e.nud == nil {
- e.linkAddrCache.AddLinkAddress(e.NICID(), addr, linkAddr)
+ e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr)
return
}
@@ -211,11 +203,11 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress
}
-func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint {
+func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
e := &endpoint{
protocol: p,
nic: nic,
- linkEP: sender,
+ linkEP: nic.LinkEndpoint(),
linkAddrCache: linkAddrCache,
nud: nud,
}
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 3e5cf2ad9..5c4f715d7 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -40,7 +40,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack
// Drop packet if it doesn't have the basic IPv4 header or if the
// original source address doesn't match an address we own.
src := hdr.SourceAddress()
- if e.stack.CheckLocalAddress(e.NICID(), ProtocolNumber, src) == 0 {
+ if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, src) == 0 {
return
}
@@ -110,7 +110,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
localAddr = ""
}
- r, err := r.Stack().FindRoute(e.NICID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
+ r, err := r.Stack().FindRoute(e.nic.ID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
if err != nil {
// If we cannot find a route to the destination, silently drop the packet.
return
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 41f6914b9..7adf0fac3 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -59,7 +59,6 @@ type endpoint struct {
linkEP stack.LinkEndpoint
dispatcher stack.TransportDispatcher
protocol *protocol
- stack *stack.Stack
// enabled is set to 1 when the enpoint is enabled and 0 when it is
// disabled.
@@ -75,13 +74,12 @@ type endpoint struct {
}
// NewEndpoint creates a new ipv4 endpoint.
-func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint {
+func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
e := &endpoint{
nic: nic,
- linkEP: linkEP,
+ linkEP: nic.LinkEndpoint(),
dispatcher: dispatcher,
protocol: p,
- stack: st,
}
e.mu.addressableEndpointState.Init(e)
return e
@@ -173,16 +171,6 @@ func (e *endpoint) MTU() uint32 {
return calculateMTU(e.linkEP.MTU())
}
-// Capabilities implements stack.NetworkEndpoint.
-func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
- return e.linkEP.Capabilities()
-}
-
-// NICID returns the ID of the NIC this endpoint belongs to.
-func (e *endpoint) NICID() tcpip.NICID {
- return e.nic.ID()
-}
-
// MaxHeaderLength returns the maximum length needed by ipv4 headers (and
// underlying protocols).
func (e *endpoint) MaxHeaderLength() uint16 {
@@ -324,8 +312,8 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
// iptables filtering. All packets that reach here are locally
// generated.
- nicName := e.stack.FindNICNameFromID(e.NICID())
- ipt := e.stack.IPTables()
+ nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ ipt := e.protocol.stack.IPTables()
if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok {
// iptables is telling us to drop the packet.
r.Stats().IP.IPTablesOutputDropped.Increment()
@@ -341,7 +329,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
// short circuits broadcasts before they are sent out to other hosts.
if pkt.NatDone {
netHeader := header.IPv4(pkt.NetworkHeader().View())
- ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress())
+ ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress())
if err == nil {
route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
ep.HandlePacket(&route, pkt)
@@ -381,10 +369,10 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
pkt = pkt.Next()
}
- nicName := e.stack.FindNICNameFromID(e.NICID())
+ nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
// iptables filtering. All packets that reach here are locally
// generated.
- ipt := e.stack.IPTables()
+ ipt := e.protocol.stack.IPTables()
dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName)
if len(dropped) == 0 && len(natPkts) == 0 {
// Fast path: If no packets are to be dropped then we can just invoke the
@@ -404,7 +392,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
if _, ok := natPkts[pkt]; ok {
netHeader := header.IPv4(pkt.NetworkHeader().View())
- if ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()); err == nil {
+ if ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()); err == nil {
src := netHeader.SourceAddress()
dst := netHeader.DestinationAddress()
route := r.ReverseRoute(src, dst)
@@ -493,7 +481,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// iptables filtering. All packets that reach here are intended for
// this machine and will not be forwarded.
- ipt := e.stack.IPTables()
+ ipt := e.protocol.stack.IPTables()
if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok {
// iptables is telling us to drop the packet.
r.Stats().IP.IPTablesInputDropped.Increment()
@@ -677,6 +665,8 @@ var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
var _ stack.NetworkProtocol = (*protocol)(nil)
type protocol struct {
+ stack *stack.Stack
+
// defaultTTL is the current default TTL for the protocol. Only the
// uint8 portion of it is meaningful.
//
@@ -799,7 +789,7 @@ func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV ui
}
// NewProtocol returns an IPv4 network protocol.
-func NewProtocol(*stack.Stack) stack.NetworkProtocol {
+func NewProtocol(s *stack.Stack) stack.NetworkProtocol {
ids := make([]uint32, buckets)
// Randomly initialize hashIV and the ids.
@@ -810,6 +800,7 @@ func NewProtocol(*stack.Stack) stack.NetworkProtocol {
hashIV := r[buckets]
return &protocol{
+ stack: s,
ids: ids,
hashIV: hashIV,
defaultTTL: DefaultTTL,
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 270439b5c..4b4b483cc 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -41,7 +41,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack
// Drop packet if it doesn't have the basic IPv6 header or if the
// original source address doesn't match an address we own.
src := hdr.SourceAddress()
- if e.stack.CheckLocalAddress(e.NICID(), ProtocolNumber, src) == 0 {
+ if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, src) == 0 {
return
}
@@ -248,7 +248,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// section 5.4.3.
// Is the NS targeting us?
- if r.Stack().CheckLocalAddress(e.NICID(), ProtocolNumber, targetAddr) == 0 {
+ if r.Stack().CheckLocalAddress(e.nic.ID(), ProtocolNumber, targetAddr) == 0 {
return
}
@@ -283,7 +283,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
} else if e.nud != nil {
e.nud.HandleProbe(r.RemoteAddress, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol)
} else {
- e.linkAddrCache.AddLinkAddress(e.NICID(), r.RemoteAddress, sourceLinkAddr)
+ e.linkAddrCache.AddLinkAddress(e.nic.ID(), r.RemoteAddress, sourceLinkAddr)
}
// ICMPv6 Neighbor Solicit messages are always sent to
@@ -410,7 +410,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
// address cache with the link address for the target of the message.
if len(targetLinkAddr) != 0 {
if e.nud == nil {
- e.linkAddrCache.AddLinkAddress(e.NICID(), targetAddr, targetLinkAddr)
+ e.linkAddrCache.AddLinkAddress(e.nic.ID(), targetAddr, targetLinkAddr)
return
}
@@ -438,7 +438,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme
localAddr = ""
}
- r, err := r.Stack().FindRoute(e.NICID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
+ r, err := r.Stack().FindRoute(e.nic.ID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
if err != nil {
// If we cannot find a route to the destination, silently drop the packet.
return
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 75b27a4cf..d1ad7acb7 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -351,16 +351,6 @@ func (e *endpoint) MTU() uint32 {
return calculateMTU(e.linkEP.MTU())
}
-// NICID returns the ID of the NIC this endpoint belongs to.
-func (e *endpoint) NICID() tcpip.NICID {
- return e.nic.ID()
-}
-
-// Capabilities implements stack.NetworkEndpoint.
-func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
- return e.linkEP.Capabilities()
-}
-
// MaxHeaderLength returns the maximum length needed by ipv6 headers (and
// underlying protocols).
func (e *endpoint) MaxHeaderLength() uint16 {
@@ -395,8 +385,8 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
// iptables filtering. All packets that reach here are locally
// generated.
- nicName := e.stack.FindNICNameFromID(e.NICID())
- ipt := e.stack.IPTables()
+ nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ ipt := e.protocol.stack.IPTables()
if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok {
// iptables is telling us to drop the packet.
r.Stats().IP.IPTablesOutputDropped.Increment()
@@ -412,7 +402,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
// short circuits broadcasts before they are sent out to other hosts.
if pkt.NatDone {
netHeader := header.IPv6(pkt.NetworkHeader().View())
- if ep, err := e.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil {
+ if ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil {
route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
ep.HandlePacket(&route, pkt)
return nil
@@ -455,8 +445,8 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
// iptables filtering. All packets that reach here are locally
// generated.
- nicName := e.stack.FindNICNameFromID(e.NICID())
- ipt := e.stack.IPTables()
+ nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
+ ipt := e.protocol.stack.IPTables()
dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName)
if len(dropped) == 0 && len(natPkts) == 0 {
// Fast path: If no packets are to be dropped then we can just invoke the
@@ -476,7 +466,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
if _, ok := natPkts[pkt]; ok {
netHeader := header.IPv6(pkt.NetworkHeader().View())
- if ep, err := e.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil {
+ if ep, err := e.protocol.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil {
src := netHeader.SourceAddress()
dst := netHeader.DestinationAddress()
route := r.ReverseRoute(src, dst)
@@ -531,7 +521,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// iptables filtering. All packets that reach here are intended for
// this machine and need not be forwarded.
- ipt := e.stack.IPTables()
+ ipt := e.protocol.stack.IPTables()
if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok {
// iptables is telling us to drop the packet.
r.Stats().IP.IPTablesInputDropped.Increment()
@@ -1084,6 +1074,8 @@ var _ stack.ForwardingNetworkProtocol = (*protocol)(nil)
var _ stack.NetworkProtocol = (*protocol)(nil)
type protocol struct {
+ stack *stack.Stack
+
mu struct {
sync.RWMutex
@@ -1147,15 +1139,14 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
}
// NewEndpoint creates a new ipv6 endpoint.
-func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint {
+func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
e := &endpoint{
nic: nic,
- linkEP: linkEP,
+ linkEP: nic.LinkEndpoint(),
linkAddrCache: linkAddrCache,
nud: nud,
dispatcher: dispatcher,
protocol: p,
- stack: st,
}
e.mu.addressableEndpointState.Init(e)
e.mu.ndp = ndpState{
@@ -1312,8 +1303,9 @@ type Options struct {
func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory {
opts.NDPConfigs.validate()
- return func(*stack.Stack) stack.NetworkProtocol {
+ return func(s *stack.Stack) stack.NetworkProtocol {
p := &protocol{
+ stack: s,
fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
ndpDisp: opts.NDPDisp,
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index 1b5c61b80..84c082852 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -627,7 +627,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
if addressEndpoint.GetKind() != stack.PermanentTentative {
// The endpoint should be marked as tentative since we are starting DAD.
- panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.ep.NICID()))
+ panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.ep.nic.ID()))
}
// Should not attempt to perform DAD on an address that is currently in the
@@ -639,7 +639,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
// address, or its reference count would have been increased without doing
// the work that would have been done for an address that was brand new.
// See endpoint.addAddressLocked.
- panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.ep.NICID()))
+ panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.ep.nic.ID()))
}
remaining := ndp.configs.DupAddrDetectTransmits
@@ -649,7 +649,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
// Consider DAD to have resolved even if no DAD messages were actually
// transmitted.
if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
- ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.NICID(), addr, true, nil)
+ ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, true, nil)
}
return nil
@@ -661,7 +661,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
// cannot be done while holding the IPv6 endpoint's lock. This is effectively
// the same as starting a goroutine but we use a timer that fires immediately
// so we can reset it for the next DAD iteration.
- timer = ndp.ep.stack.Clock().AfterFunc(0, func() {
+ timer = ndp.ep.protocol.stack.Clock().AfterFunc(0, func() {
ndp.ep.mu.Lock()
defer ndp.ep.mu.Unlock()
@@ -676,7 +676,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
if addressEndpoint.GetKind() != stack.PermanentTentative {
// The endpoint should still be marked as tentative since we are still
// performing DAD on it.
- panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.ep.NICID()))
+ panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.ep.nic.ID()))
}
dadDone := remaining == 0
@@ -721,7 +721,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
delete(ndp.dad, addr)
if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
- ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.NICID(), addr, dadDone, err)
+ ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, dadDone, err)
}
// If DAD resolved for a stable SLAAC address, attempt generation of a
@@ -750,7 +750,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) *tcpip.Error {
snmc := header.SolicitedNodeAddr(addr)
- r, err := ndp.ep.stack.FindRoute(ndp.ep.NICID(), header.IPv6Any, snmc, ProtocolNumber, false /* multicastLoop */)
+ r, err := ndp.ep.protocol.stack.FindRoute(ndp.ep.nic.ID(), header.IPv6Any, snmc, ProtocolNumber, false /* multicastLoop */)
if err != nil {
return err
}
@@ -766,9 +766,9 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.Add
return err
}
- panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.ep.NICID(), err))
+ panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.ep.nic.ID(), err))
} else if c != nil {
- panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.ep.NICID()))
+ panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.ep.nic.ID()))
}
icmpData := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize))
@@ -824,7 +824,7 @@ func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) {
// Let the integrator know DAD did not resolve.
if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
- ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.NICID(), addr, false, nil)
+ ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, false, nil)
}
}
@@ -861,7 +861,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
if ndp.dhcpv6Configuration != configuration {
ndp.dhcpv6Configuration = configuration
- ndpDisp.OnDHCPv6Configuration(ndp.ep.NICID(), configuration)
+ ndpDisp.OnDHCPv6Configuration(ndp.ep.nic.ID(), configuration)
}
}
@@ -908,7 +908,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
}
addrs, _ := opt.Addresses()
- ndp.ep.protocol.ndpDisp.OnRecursiveDNSServerOption(ndp.ep.NICID(), addrs, opt.Lifetime())
+ ndp.ep.protocol.ndpDisp.OnRecursiveDNSServerOption(ndp.ep.nic.ID(), addrs, opt.Lifetime())
case header.NDPDNSSearchList:
if ndp.ep.protocol.ndpDisp == nil {
@@ -916,7 +916,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
}
domainNames, _ := opt.DomainNames()
- ndp.ep.protocol.ndpDisp.OnDNSSearchListOption(ndp.ep.NICID(), domainNames, opt.Lifetime())
+ ndp.ep.protocol.ndpDisp.OnDNSSearchListOption(ndp.ep.nic.ID(), domainNames, opt.Lifetime())
case header.NDPPrefixInformation:
prefix := opt.Subnet()
@@ -965,7 +965,7 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) {
// Let the integrator know a discovered default router is invalidated.
if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
- ndpDisp.OnDefaultRouterInvalidated(ndp.ep.NICID(), ip)
+ ndpDisp.OnDefaultRouterInvalidated(ndp.ep.nic.ID(), ip)
}
}
@@ -982,14 +982,14 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) {
}
// Inform the integrator when we discovered a default router.
- if !ndpDisp.OnDefaultRouterDiscovered(ndp.ep.NICID(), ip) {
+ if !ndpDisp.OnDefaultRouterDiscovered(ndp.ep.nic.ID(), ip) {
// Informed by the integrator to not remember the router, do
// nothing further.
return
}
state := defaultRouterState{
- invalidationJob: ndp.ep.stack.NewJob(&ndp.ep.mu, func() {
+ invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
ndp.invalidateDefaultRouter(ip)
}),
}
@@ -1012,14 +1012,14 @@ func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration)
}
// Inform the integrator when we discovered an on-link prefix.
- if !ndpDisp.OnOnLinkPrefixDiscovered(ndp.ep.NICID(), prefix) {
+ if !ndpDisp.OnOnLinkPrefixDiscovered(ndp.ep.nic.ID(), prefix) {
// Informed by the integrator to not remember the prefix, do
// nothing further.
return
}
state := onLinkPrefixState{
- invalidationJob: ndp.ep.stack.NewJob(&ndp.ep.mu, func() {
+ invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
ndp.invalidateOnLinkPrefix(prefix)
}),
}
@@ -1048,7 +1048,7 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) {
// Let the integrator know a discovered on-link prefix is invalidated.
if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
- ndpDisp.OnOnLinkPrefixInvalidated(ndp.ep.NICID(), prefix)
+ ndpDisp.OnOnLinkPrefixInvalidated(ndp.ep.nic.ID(), prefix)
}
}
@@ -1164,7 +1164,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
}
state := slaacPrefixState{
- deprecationJob: ndp.ep.stack.NewJob(&ndp.ep.mu, func() {
+ deprecationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
state, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the deprecated SLAAC prefix %s", prefix))
@@ -1172,7 +1172,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
ndp.deprecateSLAACAddress(state.stableAddr.addressEndpoint)
}),
- invalidationJob: ndp.ep.stack.NewJob(&ndp.ep.mu, func() {
+ invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
state, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the invalidated SLAAC prefix %s", prefix))
@@ -1230,7 +1230,7 @@ func (ndp *ndpState) addAndAcquireSLAACAddr(addr tcpip.AddressWithPrefix, config
return nil
}
- if !ndpDisp.OnAutoGenAddress(ndp.ep.NICID(), addr) {
+ if !ndpDisp.OnAutoGenAddress(ndp.ep.nic.ID(), addr) {
// Informed by the integrator not to add the address.
return nil
}
@@ -1276,7 +1276,7 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt
addrBytes = header.AppendOpaqueInterfaceIdentifier(
addrBytes[:header.IIDOffsetInIPv6Address],
prefix,
- oIID.NICNameFromID(ndp.ep.NICID(), ndp.ep.nic.Name()),
+ oIID.NICNameFromID(ndp.ep.nic.ID(), ndp.ep.nic.Name()),
dadCounter,
oIID.SecretKey,
)
@@ -1433,7 +1433,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
}
state := tempSLAACAddrState{
- deprecationJob: ndp.ep.stack.NewJob(&ndp.ep.mu, func() {
+ deprecationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
prefixState, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to deprecate temporary address %s", prefix, generatedAddr))
@@ -1446,7 +1446,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
ndp.deprecateSLAACAddress(tempAddrState.addressEndpoint)
}),
- invalidationJob: ndp.ep.stack.NewJob(&ndp.ep.mu, func() {
+ invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
prefixState, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to invalidate temporary address %s", prefix, generatedAddr))
@@ -1459,7 +1459,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, generatedAddr.Address, tempAddrState)
}),
- regenJob: ndp.ep.stack.NewJob(&ndp.ep.mu, func() {
+ regenJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() {
prefixState, ok := ndp.slaacPrefixes[prefix]
if !ok {
panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to regenerate temporary address after %s", prefix, generatedAddr))
@@ -1677,7 +1677,7 @@ func (ndp *ndpState) deprecateSLAACAddress(addressEndpoint stack.AddressEndpoint
addressEndpoint.SetDeprecated(true)
if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
- ndpDisp.OnAutoGenAddressDeprecated(ndp.ep.NICID(), addressEndpoint.AddressWithPrefix())
+ ndpDisp.OnAutoGenAddressDeprecated(ndp.ep.nic.ID(), addressEndpoint.AddressWithPrefix())
}
}
@@ -1702,7 +1702,7 @@ func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, state slaacPrefi
// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidatePrefix bool) {
if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
- ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.NICID(), addr)
+ ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addr)
}
prefix := addr.Subnet()
@@ -1762,7 +1762,7 @@ func (ndp *ndpState) invalidateTempSLAACAddr(tempAddrs map[tcpip.Address]tempSLA
// The IPv6 endpoint that ndp belongs to MUST be locked.
func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidateAddr bool) {
if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil {
- ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.NICID(), addr)
+ ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addr)
}
if !invalidateAddr {
@@ -1878,7 +1878,7 @@ func (ndp *ndpState) startSolicitingRouters() {
var done bool
ndp.rtrSolicit.done = &done
- ndp.rtrSolicit.timer = ndp.ep.stack.Clock().AfterFunc(delay, func() {
+ ndp.rtrSolicit.timer = ndp.ep.protocol.stack.Clock().AfterFunc(delay, func() {
ndp.ep.mu.Lock()
if done {
// If we reach this point, it means that the RS timer fired after another
@@ -1904,7 +1904,7 @@ func (ndp *ndpState) startSolicitingRouters() {
ndp.ep.mu.Unlock()
localAddr := addressEndpoint.AddressWithPrefix().Address
- r, err := ndp.ep.stack.FindRoute(ndp.ep.NICID(), localAddr, header.IPv6AllRoutersMulticastAddress, ProtocolNumber, false /* multicastLoop */)
+ r, err := ndp.ep.protocol.stack.FindRoute(ndp.ep.nic.ID(), localAddr, header.IPv6AllRoutersMulticastAddress, ProtocolNumber, false /* multicastLoop */)
addressEndpoint.DecRef()
if err != nil {
return
@@ -1923,9 +1923,9 @@ func (ndp *ndpState) startSolicitingRouters() {
return
}
- panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.NICID(), err))
+ panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.nic.ID(), err))
} else if c != nil {
- panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.NICID()))
+ panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.nic.ID()))
}
// As per RFC 4861 section 4.1, an NDP RS SHOULD include the source
@@ -1961,7 +1961,7 @@ func (ndp *ndpState) startSolicitingRouters() {
}, pkt,
); err != nil {
sent.Dropped.Increment()
- log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.ep.NICID(), err)
+ log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.ep.nic.ID(), err)
// Don't send any more messages if we had an error.
remaining = 0
} else {
@@ -2005,7 +2005,7 @@ func (ndp *ndpState) stopSolicitingRouters() {
// initializeTempAddrState initializes state related to temporary SLAAC
// addresses.
func (ndp *ndpState) initializeTempAddrState() {
- header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.tempIIDSeed, ndp.ep.NICID())
+ header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.tempIIDSeed, ndp.ep.nic.ID())
if MaxDesyncFactor != 0 {
ndp.temporaryAddressDesyncFactor = time.Duration(rand.Int63n(int64(MaxDesyncFactor)))
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 926ce9cfc..212c6edae 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -124,7 +124,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
for _, netProto := range stack.networkProtocols {
netNum := netProto.Number()
nic.mu.packetEPs[netNum] = nil
- nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic, ep, stack)
+ nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic)
}
nic.linkEP.Attach(nic)
@@ -796,7 +796,7 @@ func (n *NIC) Name() string {
return n.name
}
-// LinkEndpoint returns the link endpoint of n.
+// LinkEndpoint implements NetworkInterface.
func (n *NIC) LinkEndpoint() LinkEndpoint {
return n.linkEP
}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index ef42fd6e1..567e1904e 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -154,10 +154,10 @@ type TransportProtocol interface {
Number() tcpip.TransportProtocolNumber
// NewEndpoint creates a new endpoint of the transport protocol.
- NewEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
+ NewEndpoint(netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
// NewRawEndpoint creates a new raw endpoint of the transport protocol.
- NewRawEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
+ NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
// MinimumPacketSize returns the minimum valid packet size of this
// transport protocol. The stack automatically drops any packets smaller
@@ -485,6 +485,9 @@ type NetworkInterface interface {
// Enabled returns true if the interface is enabled.
Enabled() bool
+
+ // LinkEndpoint returns the link endpoint backing the interface.
+ LinkEndpoint() LinkEndpoint
}
// NetworkEndpoint is the interface that needs to be implemented by endpoints
@@ -515,10 +518,6 @@ type NetworkEndpoint interface {
// minus the network endpoint max header length.
MTU() uint32
- // Capabilities returns the set of capabilities supported by the
- // underlying link-layer endpoint.
- Capabilities() LinkEndpointCapabilities
-
// MaxHeaderLength returns the maximum size the network (and lower
// level layers combined) headers can have. Higher levels use this
// information to reserve space in the front of the packets they're
@@ -539,9 +538,6 @@ type NetworkEndpoint interface {
// header to the given destination address. It takes ownership of pkt.
WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error
- // NICID returns the id of the NIC this endpoint belongs to.
- NICID() tcpip.NICID
-
// HandlePacket is called by the link layer when new packets arrive to
// this network endpoint. It sets pkt.NetworkHeader.
//
@@ -586,7 +582,7 @@ type NetworkProtocol interface {
ParseAddresses(v buffer.View) (src, dst tcpip.Address)
// NewEndpoint creates a new endpoint of this protocol.
- NewEndpoint(nic NetworkInterface, linkAddrCache LinkAddressCache, nud NUDHandler, dispatcher TransportDispatcher, sender LinkEndpoint, st *Stack) NetworkEndpoint
+ NewEndpoint(nic NetworkInterface, linkAddrCache LinkAddressCache, nud NUDHandler, dispatcher TransportDispatcher) NetworkEndpoint
// SetOption allows enabling/disabling protocol specific features.
// SetOption returns an error if the option is not supported or the
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 1b008a067..5ade3c832 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -72,17 +72,18 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip
loop |= PacketLoop
}
+ linkEP := nic.LinkEndpoint()
r := Route{
NetProto: netProto,
LocalAddress: localAddr,
- LocalLinkAddress: nic.linkEP.LinkAddress(),
+ LocalLinkAddress: linkEP.LinkAddress(),
RemoteAddress: remoteAddr,
addressEndpoint: addressEndpoint,
nic: nic,
Loop: loop,
}
- if nic := r.nic; nic.linkEP.Capabilities()&CapabilityResolutionRequired != 0 {
+ if nic := r.nic; linkEP.Capabilities()&CapabilityResolutionRequired != 0 {
if linkRes, ok := nic.stack.linkAddrResolvers[r.NetProto]; ok {
r.linkRes = linkRes
r.linkCache = nic.stack
@@ -94,7 +95,7 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip
// NICID returns the id of the NIC from which this route originates.
func (r *Route) NICID() tcpip.NICID {
- return r.addressEndpoint.NetworkEndpoint().NICID()
+ return r.nic.ID()
}
// MaxHeaderLength forwards the call to the network endpoint's implementation.
@@ -115,7 +116,7 @@ func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, tot
// Capabilities returns the link-layer capabilities of the route.
func (r *Route) Capabilities() LinkEndpointCapabilities {
- return r.addressEndpoint.NetworkEndpoint().Capabilities()
+ return r.nic.LinkEndpoint().Capabilities()
}
// GSOMaxSize returns the maximum GSO packet size.
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index b740aa305..57d8e79e0 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -837,7 +837,7 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp
return nil, tcpip.ErrUnknownProtocol
}
- return t.proto.NewEndpoint(s, network, waiterQueue)
+ return t.proto.NewEndpoint(network, waiterQueue)
}
// NewRawEndpoint creates a new raw transport layer endpoint of the given
@@ -857,7 +857,7 @@ func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network
return nil, tcpip.ErrUnknownProtocol
}
- return t.proto.NewRawEndpoint(s, network, waiterQueue)
+ return t.proto.NewRawEndpoint(network, waiterQueue)
}
// NewPacketEndpoint creates a new packet endpoint listening for the given
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
index 7484f4ad9..87d510f96 100644
--- a/pkg/tcpip/transport/icmp/protocol.go
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -37,6 +37,8 @@ const (
// protocol implements stack.TransportProtocol.
type protocol struct {
+ stack *stack.Stack
+
number tcpip.TransportProtocolNumber
}
@@ -57,20 +59,20 @@ func (p *protocol) netProto() tcpip.NetworkProtocolNumber {
// NewEndpoint creates a new icmp endpoint. It implements
// stack.TransportProtocol.NewEndpoint.
-func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
if netProto != p.netProto() {
return nil, tcpip.ErrUnknownProtocol
}
- return newEndpoint(stack, netProto, p.number, waiterQueue)
+ return newEndpoint(p.stack, netProto, p.number, waiterQueue)
}
// NewRawEndpoint creates a new raw icmp endpoint. It implements
// stack.TransportProtocol.NewRawEndpoint.
-func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
if netProto != p.netProto() {
return nil, tcpip.ErrUnknownProtocol
}
- return raw.NewEndpoint(stack, netProto, p.number, waiterQueue)
+ return raw.NewEndpoint(p.stack, netProto, p.number, waiterQueue)
}
// MinimumPacketSize returns the minimum valid icmp packet size.
@@ -130,11 +132,11 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
}
// NewProtocol4 returns an ICMPv4 transport protocol.
-func NewProtocol4(*stack.Stack) stack.TransportProtocol {
- return &protocol{ProtocolNumber4}
+func NewProtocol4(s *stack.Stack) stack.TransportProtocol {
+ return &protocol{stack: s, number: ProtocolNumber4}
}
// NewProtocol6 returns an ICMPv6 transport protocol.
-func NewProtocol6(*stack.Stack) stack.TransportProtocol {
- return &protocol{ProtocolNumber6}
+func NewProtocol6(s *stack.Stack) stack.TransportProtocol {
+ return &protocol{stack: s, number: ProtocolNumber6}
}
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 6a3c2c32b..5bce73605 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -133,6 +133,8 @@ func (s *synRcvdCounter) Threshold() uint64 {
}
type protocol struct {
+ stack *stack.Stack
+
mu sync.RWMutex
sackEnabled bool
recovery tcpip.TCPRecovery
@@ -159,14 +161,14 @@ func (*protocol) Number() tcpip.TransportProtocolNumber {
}
// NewEndpoint creates a new tcp endpoint.
-func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return newEndpoint(stack, netProto, waiterQueue), nil
+func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return newEndpoint(p.stack, netProto, waiterQueue), nil
}
// NewRawEndpoint creates a new raw TCP endpoint. Raw TCP sockets are currently
// unsupported. It implements stack.TransportProtocol.NewRawEndpoint.
-func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return raw.NewEndpoint(stack, netProto, header.TCPProtocolNumber, waiterQueue)
+func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return raw.NewEndpoint(p.stack, netProto, header.TCPProtocolNumber, waiterQueue)
}
// MinimumPacketSize returns the minimum valid tcp packet size.
@@ -505,8 +507,9 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
}
// NewProtocol returns a TCP transport protocol.
-func NewProtocol(*stack.Stack) stack.TransportProtocol {
+func NewProtocol(s *stack.Stack) stack.TransportProtocol {
p := protocol{
+ stack: s,
sendBufferSize: tcpip.TCPSendBufferSizeRangeOption{
Min: MinBufferSize,
Default: DefaultSendBufferSize,
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index e6fc23258..da5b1deb2 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -45,6 +45,7 @@ const (
)
type protocol struct {
+ stack *stack.Stack
}
// Number returns the udp protocol number.
@@ -53,14 +54,14 @@ func (*protocol) Number() tcpip.TransportProtocolNumber {
}
// NewEndpoint creates a new udp endpoint.
-func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return newEndpoint(stack, netProto, waiterQueue), nil
+func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return newEndpoint(p.stack, netProto, waiterQueue), nil
}
// NewRawEndpoint creates a new raw UDP endpoint. It implements
// stack.TransportProtocol.NewRawEndpoint.
-func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return raw.NewEndpoint(stack, netProto, header.UDPProtocolNumber, waiterQueue)
+func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return raw.NewEndpoint(p.stack, netProto, header.UDPProtocolNumber, waiterQueue)
}
// MinimumPacketSize returns the minimum valid udp packet size.
@@ -114,6 +115,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
}
// NewProtocol returns a UDP transport protocol.
-func NewProtocol(*stack.Stack) stack.TransportProtocol {
- return &protocol{}
+func NewProtocol(s *stack.Stack) stack.TransportProtocol {
+ return &protocol{stack: s}
}