summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
authorGhanan Gowripalan <ghanan@google.com>2020-08-14 17:27:23 -0700
committergVisor bot <gvisor-bot@google.com>2020-08-14 17:30:01 -0700
commit1736b2208f7eeec56531a9877ca53dc784fed544 (patch)
tree27335ba42c862cbe6015036f4b55b2562c08a275 /pkg/tcpip
parent3f523b3bbcf5ef7f37bb247bd4c5727711c70ba9 (diff)
Use a single NetworkEndpoint per NIC per protocol
The NetworkEndpoint does not need to be created for each address. Most of the work the NetworkEndpoint does is address agnostic. PiperOrigin-RevId: 326759605
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/network/BUILD1
-rw-r--r--pkg/tcpip/network/arp/arp.go15
-rw-r--r--pkg/tcpip/network/ip_test.go77
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go7
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go20
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go7
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go6
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go20
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go5
-rw-r--r--pkg/tcpip/stack/forwarder_test.go18
-rw-r--r--pkg/tcpip/stack/ndp.go8
-rw-r--r--pkg/tcpip/stack/nic.go86
-rw-r--r--pkg/tcpip/stack/nic_test.go30
-rw-r--r--pkg/tcpip/stack/registration.go8
-rw-r--r--pkg/tcpip/stack/stack.go6
-rw-r--r--pkg/tcpip/stack/stack_test.go20
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go10
17 files changed, 123 insertions, 221 deletions
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
index 6a4839fb8..46083925c 100644
--- a/pkg/tcpip/network/BUILD
+++ b/pkg/tcpip/network/BUILD
@@ -12,6 +12,7 @@ go_test(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 1ad788a17..920872c3f 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -66,14 +66,6 @@ func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
return e.linkEP.Capabilities()
}
-func (e *endpoint) ID() *stack.NetworkEndpointID {
- return &stack.NetworkEndpointID{ProtocolAddress}
-}
-
-func (e *endpoint) PrefixLen() int {
- return 0
-}
-
func (e *endpoint) MaxHeaderLength() uint16 {
return e.linkEP.MaxHeaderLength() + header.ARPSize
}
@@ -142,16 +134,13 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress
}
-func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint, st *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) {
- if addrWithPrefix.Address != ProtocolAddress {
- return nil, tcpip.ErrBadLocalAddress
- }
+func (p *protocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint {
return &endpoint{
protocol: p,
nicID: nicID,
linkEP: sender,
linkAddrCache: linkAddrCache,
- }, nil
+ }
}
// LinkAddressProtocol implements stack.LinkAddressResolver.LinkAddressProtocol.
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 491d936a1..9007346fe 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
@@ -41,6 +42,7 @@ const (
ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00"
ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03"
+ nicID = 1
)
// testObject implements two interfaces: LinkEndpoint and TransportDispatcher.
@@ -195,15 +197,15 @@ func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()},
})
- s.CreateNIC(1, loopback.New())
- s.AddAddress(1, ipv4.ProtocolNumber, local)
+ s.CreateNIC(nicID, loopback.New())
+ s.AddAddress(nicID, ipv4.ProtocolNumber, local)
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv4EmptySubnet,
Gateway: ipv4Gateway,
NIC: 1,
}})
- return s.FindRoute(1, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */)
+ return s.FindRoute(nicID, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */)
}
func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
@@ -211,31 +213,45 @@ func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()},
})
- s.CreateNIC(1, loopback.New())
- s.AddAddress(1, ipv6.ProtocolNumber, local)
+ s.CreateNIC(nicID, loopback.New())
+ s.AddAddress(nicID, ipv6.ProtocolNumber, local)
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv6EmptySubnet,
Gateway: ipv6Gateway,
NIC: 1,
}})
- return s.FindRoute(1, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */)
+ return s.FindRoute(nicID, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */)
}
-func buildDummyStack() *stack.Stack {
- return stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+func buildDummyStack(t *testing.T) *stack.Stack {
+ t.Helper()
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()},
})
+ e := channel.New(0, 1280, "")
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, localIpv4Addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, localIpv4Addr, err)
+ }
+
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, localIpv6Addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, localIpv6Addr, err)
+ }
+
+ return s
}
func TestIPv4Send(t *testing.T) {
o := testObject{t: t, v4: true}
proto := ipv4.NewProtocol()
- ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, nil, &o, buildDummyStack())
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
+ ep := proto.NewEndpoint(nicID, nil, nil, &o, buildDummyStack(t))
+ defer ep.Close()
// Allocate and initialize the payload view.
payload := buffer.NewView(100)
@@ -271,10 +287,8 @@ func TestIPv4Send(t *testing.T) {
func TestIPv4Receive(t *testing.T) {
o := testObject{t: t, v4: true}
proto := ipv4.NewProtocol()
- ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil, buildDummyStack())
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
+ ep := proto.NewEndpoint(nicID, nil, &o, nil, buildDummyStack(t))
+ defer ep.Close()
totalLen := header.IPv4MinimumSize + 30
view := buffer.NewView(totalLen)
@@ -343,10 +357,7 @@ func TestIPv4ReceiveControl(t *testing.T) {
t.Run(c.name, func(t *testing.T) {
o := testObject{t: t}
proto := ipv4.NewProtocol()
- ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil, buildDummyStack())
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
+ ep := proto.NewEndpoint(nicID, nil, &o, nil, buildDummyStack(t))
defer ep.Close()
const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize
@@ -407,10 +418,8 @@ func TestIPv4ReceiveControl(t *testing.T) {
func TestIPv4FragmentationReceive(t *testing.T) {
o := testObject{t: t, v4: true}
proto := ipv4.NewProtocol()
- ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil, buildDummyStack())
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
+ ep := proto.NewEndpoint(nicID, nil, &o, nil, buildDummyStack(t))
+ defer ep.Close()
totalLen := header.IPv4MinimumSize + 24
@@ -486,10 +495,8 @@ func TestIPv4FragmentationReceive(t *testing.T) {
func TestIPv6Send(t *testing.T) {
o := testObject{t: t}
proto := ipv6.NewProtocol()
- ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, nil, &o, buildDummyStack())
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
+ ep := proto.NewEndpoint(nicID, nil, &o, channel.New(0, 1280, ""), buildDummyStack(t))
+ defer ep.Close()
// Allocate and initialize the payload view.
payload := buffer.NewView(100)
@@ -525,10 +532,8 @@ func TestIPv6Send(t *testing.T) {
func TestIPv6Receive(t *testing.T) {
o := testObject{t: t}
proto := ipv6.NewProtocol()
- ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, &o, nil, buildDummyStack())
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
+ ep := proto.NewEndpoint(nicID, nil, &o, nil, buildDummyStack(t))
+ defer ep.Close()
totalLen := header.IPv6MinimumSize + 30
view := buffer.NewView(totalLen)
@@ -606,11 +611,7 @@ func TestIPv6ReceiveControl(t *testing.T) {
t.Run(c.name, func(t *testing.T) {
o := testObject{t: t}
proto := ipv6.NewProtocol()
- ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, &o, nil, buildDummyStack())
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
-
+ ep := proto.NewEndpoint(nicID, nil, &o, nil, buildDummyStack(t))
defer ep.Close()
dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 067d770f3..b5659a36b 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -37,8 +37,9 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack
// false.
//
// Drop packet if it doesn't have the basic IPv4 header or if the
- // original source address doesn't match the endpoint's address.
- if hdr.SourceAddress() != e.id.LocalAddress {
+ // original source address doesn't match an address we own.
+ src := hdr.SourceAddress()
+ if e.stack.CheckLocalAddress(e.NICID(), ProtocolNumber, src) == 0 {
return
}
@@ -53,7 +54,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack
// Skip the ip header, then deliver control message.
pkt.Data.TrimFront(hlen)
p := hdr.TransportProtocol()
- e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
+ e.dispatcher.DeliverTransportControlPacket(src, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
}
func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 3cd48ceb3..79872ec9a 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -52,8 +52,6 @@ const (
type endpoint struct {
nicID tcpip.NICID
- id stack.NetworkEndpointID
- prefixLen int
linkEP stack.LinkEndpoint
dispatcher stack.TransportDispatcher
protocol *protocol
@@ -61,18 +59,14 @@ type endpoint struct {
}
// NewEndpoint creates a new ipv4 endpoint.
-func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) {
- e := &endpoint{
+func (p *protocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint {
+ return &endpoint{
nicID: nicID,
- id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
- prefixLen: addrWithPrefix.PrefixLen,
linkEP: linkEP,
dispatcher: dispatcher,
protocol: p,
stack: st,
}
-
- return e, nil
}
// DefaultTTL is the default time-to-live value for this endpoint.
@@ -96,16 +90,6 @@ func (e *endpoint) NICID() tcpip.NICID {
return e.nicID
}
-// ID returns the ipv4 endpoint ID.
-func (e *endpoint) ID() *stack.NetworkEndpointID {
- return &e.id
-}
-
-// PrefixLen returns the ipv4 endpoint subnet prefix length in bits.
-func (e *endpoint) PrefixLen() int {
- return e.prefixLen
-}
-
// MaxHeaderLength returns the maximum length needed by ipv4 headers (and
// underlying protocols).
func (e *endpoint) MaxHeaderLength() uint16 {
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 39ae19295..66d3a953a 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -39,8 +39,9 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack
// is truncated, which would cause IsValid to return false.
//
// Drop packet if it doesn't have the basic IPv6 header or if the
- // original source address doesn't match the endpoint's address.
- if hdr.SourceAddress() != e.id.LocalAddress {
+ // original source address doesn't match an address we own.
+ src := hdr.SourceAddress()
+ if e.stack.CheckLocalAddress(e.NICID(), ProtocolNumber, src) == 0 {
return
}
@@ -67,7 +68,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack
}
// Deliver the control packet to the transport endpoint.
- e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
+ e.dispatcher.DeliverTransportControlPacket(src, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
}
func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragmentHeader bool) {
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 2a2f7de01..9e4eeea77 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -114,10 +114,8 @@ func TestICMPCounts(t *testing.T) {
if netProto == nil {
t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
}
- ep, err := netProto.NewEndpoint(0, tcpip.AddressWithPrefix{lladdr1, netProto.DefaultPrefixLen()}, &stubLinkAddressCache{}, &stubDispatcher{}, nil, s)
- if err != nil {
- t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err)
- }
+ ep := netProto.NewEndpoint(0, &stubLinkAddressCache{}, &stubDispatcher{}, nil, s)
+ defer ep.Close()
r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
if err != nil {
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 0ade655b2..0eafe9790 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -46,12 +46,11 @@ const (
type endpoint struct {
nicID tcpip.NICID
- id stack.NetworkEndpointID
- prefixLen int
linkEP stack.LinkEndpoint
linkAddrCache stack.LinkAddressCache
dispatcher stack.TransportDispatcher
protocol *protocol
+ stack *stack.Stack
}
// DefaultTTL is the default hop limit for this endpoint.
@@ -70,16 +69,6 @@ func (e *endpoint) NICID() tcpip.NICID {
return e.nicID
}
-// ID returns the ipv6 endpoint ID.
-func (e *endpoint) ID() *stack.NetworkEndpointID {
- return &e.id
-}
-
-// PrefixLen returns the ipv6 endpoint subnet prefix length in bits.
-func (e *endpoint) PrefixLen() int {
- return e.prefixLen
-}
-
// Capabilities implements stack.NetworkEndpoint.Capabilities.
func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
return e.linkEP.Capabilities()
@@ -464,16 +453,15 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
}
// NewEndpoint creates a new ipv6 endpoint.
-func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) {
+func (p *protocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint {
return &endpoint{
nicID: nicID,
- id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
- prefixLen: addrWithPrefix.PrefixLen,
linkEP: linkEP,
linkAddrCache: linkAddrCache,
dispatcher: dispatcher,
protocol: p,
- }, nil
+ stack: st,
+ }
}
// SetOption implements NetworkProtocol.SetOption.
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 2efa82e60..af71a7d6b 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -63,10 +63,7 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack
t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
}
- ep, err := netProto.NewEndpoint(0, tcpip.AddressWithPrefix{rlladdr, netProto.DefaultPrefixLen()}, &stubLinkAddressCache{}, &stubDispatcher{}, nil, s)
- if err != nil {
- t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err)
- }
+ ep := netProto.NewEndpoint(0, &stubLinkAddressCache{}, &stubDispatcher{}, nil, s)
return s, ep
}
diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go
index 944f622fd..5a684eb9d 100644
--- a/pkg/tcpip/stack/forwarder_test.go
+++ b/pkg/tcpip/stack/forwarder_test.go
@@ -46,8 +46,6 @@ const (
// protocol. They're all one byte fields to simplify parsing.
type fwdTestNetworkEndpoint struct {
nicID tcpip.NICID
- id NetworkEndpointID
- prefixLen int
proto *fwdTestNetworkProtocol
dispatcher TransportDispatcher
ep LinkEndpoint
@@ -61,18 +59,10 @@ func (f *fwdTestNetworkEndpoint) NICID() tcpip.NICID {
return f.nicID
}
-func (f *fwdTestNetworkEndpoint) PrefixLen() int {
- return f.prefixLen
-}
-
func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 {
return 123
}
-func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID {
- return &f.id
-}
-
func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) {
// Dispatch the packet to the transport protocol.
f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt)
@@ -99,7 +89,7 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkH
// endpoint.
b := pkt.NetworkHeader().Push(fwdTestNetHeaderLen)
b[dstAddrOffset] = r.RemoteAddress[0]
- b[srcAddrOffset] = f.id.LocalAddress[0]
+ b[srcAddrOffset] = r.LocalAddress[0]
b[protocolNumberOffset] = byte(params.Protocol)
return f.ep.WritePacket(r, gso, fwdTestNetNumber, pkt)
@@ -151,15 +141,13 @@ func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocol
return tcpip.TransportProtocolNumber(netHeader[protocolNumberOffset]), true, true
}
-func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) (NetworkEndpoint, *tcpip.Error) {
+func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) NetworkEndpoint {
return &fwdTestNetworkEndpoint{
nicID: nicID,
- id: NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
- prefixLen: addrWithPrefix.PrefixLen,
proto: f,
dispatcher: dispatcher,
ep: ep,
- }, nil
+ }
}
func (f *fwdTestNetworkProtocol) SetOption(option interface{}) *tcpip.Error {
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go
index 93567806b..b0873d1af 100644
--- a/pkg/tcpip/stack/ndp.go
+++ b/pkg/tcpip/stack/ndp.go
@@ -728,7 +728,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
func (ndp *ndpState) sendDADPacket(addr tcpip.Address, ref *referencedNetworkEndpoint) *tcpip.Error {
snmc := header.SolicitedNodeAddr(addr)
- r := makeRoute(header.IPv6ProtocolNumber, ref.ep.ID().LocalAddress, snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false)
+ r := makeRoute(header.IPv6ProtocolNumber, ref.address(), snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false)
defer r.Release()
// Route should resolve immediately since snmc is a multicast address so a
@@ -1353,7 +1353,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla
return false
}
- stableAddr := prefixState.stableAddr.ref.ep.ID().LocalAddress
+ stableAddr := prefixState.stableAddr.ref.address()
now := time.Now()
// As per RFC 4941 section 3.3 step 4, the valid lifetime of a temporary
@@ -1690,7 +1690,7 @@ func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPr
prefix := addr.Subnet()
state, ok := ndp.slaacPrefixes[prefix]
- if !ok || state.stableAddr.ref == nil || addr.Address != state.stableAddr.ref.ep.ID().LocalAddress {
+ if !ok || state.stableAddr.ref == nil || addr.Address != state.stableAddr.ref.address() {
return
}
@@ -1867,7 +1867,7 @@ func (ndp *ndpState) startSolicitingRouters() {
}
ndp.nic.mu.Unlock()
- localAddr := ref.ep.ID().LocalAddress
+ localAddr := ref.address()
r := makeRoute(header.IPv6ProtocolNumber, localAddr, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false)
defer r.Release()
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 2315ea5b9..10d2b7964 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -45,8 +45,9 @@ type NIC struct {
linkEP LinkEndpoint
context NICContext
- stats NICStats
- neigh *neighborCache
+ stats NICStats
+ neigh *neighborCache
+ networkEndpoints map[tcpip.NetworkProtocolNumber]NetworkEndpoint
mu struct {
sync.RWMutex
@@ -114,12 +115,13 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
// of IPv6 is supported on this endpoint's LinkEndpoint.
nic := &NIC{
- stack: stack,
- id: id,
- name: name,
- linkEP: ep,
- context: ctx,
- stats: makeNICStats(),
+ stack: stack,
+ id: id,
+ name: name,
+ linkEP: ep,
+ context: ctx,
+ stats: makeNICStats(),
+ networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint),
}
nic.mu.primary = make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint)
nic.mu.endpoints = make(map[NetworkEndpointID]*referencedNetworkEndpoint)
@@ -140,7 +142,9 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
nic.mu.packetEPs[netProto] = []PacketEndpoint{}
}
for _, netProto := range stack.networkProtocols {
- nic.mu.packetEPs[netProto.Number()] = []PacketEndpoint{}
+ netNum := netProto.Number()
+ nic.mu.packetEPs[netNum] = nil
+ nic.networkEndpoints[netNum] = netProto.NewEndpoint(id, stack, nic, ep, stack)
}
// Check for Neighbor Unreachability Detection support.
@@ -205,7 +209,7 @@ func (n *NIC) disableLocked() *tcpip.Error {
// Stop DAD for all the unicast IPv6 endpoints that are in the
// permanentTentative state.
for _, r := range n.mu.endpoints {
- if addr := r.ep.ID().LocalAddress; r.getKind() == permanentTentative && header.IsV6UnicastAddress(addr) {
+ if addr := r.address(); r.getKind() == permanentTentative && header.IsV6UnicastAddress(addr) {
n.mu.ndp.stopDuplicateAddressDetection(addr)
}
}
@@ -300,7 +304,7 @@ func (n *NIC) enable() *tcpip.Error {
// Addresses may have aleady completed DAD but in the time since the NIC was
// last enabled, other devices may have acquired the same addresses.
for _, r := range n.mu.endpoints {
- addr := r.ep.ID().LocalAddress
+ addr := r.address()
if k := r.getKind(); (k != permanent && k != permanentTentative) || !header.IsV6UnicastAddress(addr) {
continue
}
@@ -362,6 +366,11 @@ func (n *NIC) remove() *tcpip.Error {
}
}
+ // Release any resources the network endpoint may hold.
+ for _, ep := range n.networkEndpoints {
+ ep.Close()
+ }
+
// Detach from link endpoint, so no packet comes in.
n.linkEP.Attach(nil)
@@ -510,7 +519,7 @@ func (n *NIC) primaryIPv6EndpointRLocked(remoteAddr tcpip.Address) *referencedNe
continue
}
- addr := r.ep.ID().LocalAddress
+ addr := r.address()
scope, err := header.ScopeForIPv6Address(addr)
if err != nil {
// Should never happen as we got r from the primary IPv6 endpoint list and
@@ -539,10 +548,10 @@ func (n *NIC) primaryIPv6EndpointRLocked(remoteAddr tcpip.Address) *referencedNe
sb := cs[j]
// Prefer same address as per RFC 6724 section 5 rule 1.
- if sa.ref.ep.ID().LocalAddress == remoteAddr {
+ if sa.ref.address() == remoteAddr {
return true
}
- if sb.ref.ep.ID().LocalAddress == remoteAddr {
+ if sb.ref.address() == remoteAddr {
return false
}
@@ -819,17 +828,11 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
}
}
- netProto, ok := n.stack.networkProtocols[protocolAddress.Protocol]
+ ep, ok := n.networkEndpoints[protocolAddress.Protocol]
if !ok {
return nil, tcpip.ErrUnknownProtocol
}
- // Create the new network endpoint.
- ep, err := netProto.NewEndpoint(n.id, protocolAddress.AddressWithPrefix, n.stack, n, n.linkEP, n.stack)
- if err != nil {
- return nil, err
- }
-
isIPv6Unicast := protocolAddress.Protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(protocolAddress.AddressWithPrefix.Address)
// If the address is an IPv6 address and it is a permanent address,
@@ -842,6 +845,7 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
ref := &referencedNetworkEndpoint{
refs: 1,
+ addr: protocolAddress.AddressWithPrefix,
ep: ep,
nic: n,
protocol: protocolAddress.Protocol,
@@ -898,7 +902,7 @@ func (n *NIC) AllAddresses() []tcpip.ProtocolAddress {
defer n.mu.RUnlock()
addrs := make([]tcpip.ProtocolAddress, 0, len(n.mu.endpoints))
- for nid, ref := range n.mu.endpoints {
+ for _, ref := range n.mu.endpoints {
// Don't include tentative, expired or temporary endpoints to
// avoid confusion and prevent the caller from using those.
switch ref.getKind() {
@@ -907,11 +911,8 @@ func (n *NIC) AllAddresses() []tcpip.ProtocolAddress {
}
addrs = append(addrs, tcpip.ProtocolAddress{
- Protocol: ref.protocol,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: nid.LocalAddress,
- PrefixLen: ref.ep.PrefixLen(),
- },
+ Protocol: ref.protocol,
+ AddressWithPrefix: ref.addrWithPrefix(),
})
}
return addrs
@@ -934,11 +935,8 @@ func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress {
}
addrs = append(addrs, tcpip.ProtocolAddress{
- Protocol: proto,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: ref.ep.ID().LocalAddress,
- PrefixLen: ref.ep.PrefixLen(),
- },
+ Protocol: proto,
+ AddressWithPrefix: ref.addrWithPrefix(),
})
}
}
@@ -969,10 +967,7 @@ func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWit
}
if !ref.deprecated {
- return tcpip.AddressWithPrefix{
- Address: ref.ep.ID().LocalAddress,
- PrefixLen: ref.ep.PrefixLen(),
- }
+ return ref.addrWithPrefix()
}
if deprecatedEndpoint == nil {
@@ -981,10 +976,7 @@ func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWit
}
if deprecatedEndpoint != nil {
- return tcpip.AddressWithPrefix{
- Address: deprecatedEndpoint.ep.ID().LocalAddress,
- PrefixLen: deprecatedEndpoint.ep.PrefixLen(),
- }
+ return deprecatedEndpoint.addrWithPrefix()
}
return tcpip.AddressWithPrefix{}
@@ -1048,7 +1040,7 @@ func (n *NIC) insertPrimaryEndpointLocked(r *referencedNetworkEndpoint, peb Prim
}
func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) {
- id := *r.ep.ID()
+ id := NetworkEndpointID{LocalAddress: r.address()}
// Nothing to do if the reference has already been replaced with a different
// one. This happens in the case where 1) this endpoint's ref count hit zero
@@ -1072,8 +1064,6 @@ func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) {
break
}
}
-
- r.ep.Close()
}
func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) {
@@ -1718,6 +1708,7 @@ const (
type referencedNetworkEndpoint struct {
ep NetworkEndpoint
+ addr tcpip.AddressWithPrefix
nic *NIC
protocol tcpip.NetworkProtocolNumber
@@ -1743,11 +1734,12 @@ type referencedNetworkEndpoint struct {
deprecated bool
}
+func (r *referencedNetworkEndpoint) address() tcpip.Address {
+ return r.addr.Address
+}
+
func (r *referencedNetworkEndpoint) addrWithPrefix() tcpip.AddressWithPrefix {
- return tcpip.AddressWithPrefix{
- Address: r.ep.ID().LocalAddress,
- PrefixLen: r.ep.PrefixLen(),
- }
+ return r.addr
}
func (r *referencedNetworkEndpoint) getKind() networkEndpointKind {
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index 0870c8d9c..d312a79eb 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -101,11 +101,9 @@ var _ NetworkEndpoint = (*testIPv6Endpoint)(nil)
// We use this instead of ipv6.endpoint because the ipv6 package depends on
// the stack package which this test lives in, causing a cyclic dependency.
type testIPv6Endpoint struct {
- nicID tcpip.NICID
- id NetworkEndpointID
- prefixLen int
- linkEP LinkEndpoint
- protocol *testIPv6Protocol
+ nicID tcpip.NICID
+ linkEP LinkEndpoint
+ protocol *testIPv6Protocol
}
// DefaultTTL implements NetworkEndpoint.DefaultTTL.
@@ -146,16 +144,6 @@ func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) *tcpip
return tcpip.ErrNotSupported
}
-// ID implements NetworkEndpoint.ID.
-func (e *testIPv6Endpoint) ID() *NetworkEndpointID {
- return &e.id
-}
-
-// PrefixLen implements NetworkEndpoint.PrefixLen.
-func (e *testIPv6Endpoint) PrefixLen() int {
- return e.prefixLen
-}
-
// NICID implements NetworkEndpoint.NICID.
func (e *testIPv6Endpoint) NICID() tcpip.NICID {
return e.nicID
@@ -204,14 +192,12 @@ func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address)
}
// NewEndpoint implements NetworkProtocol.NewEndpoint.
-func (p *testIPv6Protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, _ LinkAddressCache, _ TransportDispatcher, linkEP LinkEndpoint, _ *Stack) (NetworkEndpoint, *tcpip.Error) {
+func (p *testIPv6Protocol) NewEndpoint(nicID tcpip.NICID, _ LinkAddressCache, _ TransportDispatcher, linkEP LinkEndpoint, _ *Stack) NetworkEndpoint {
return &testIPv6Endpoint{
- nicID: nicID,
- id: NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
- prefixLen: addrWithPrefix.PrefixLen,
- linkEP: linkEP,
- protocol: p,
- }, nil
+ nicID: nicID,
+ linkEP: linkEP,
+ protocol: p,
+ }
}
// SetOption implements NetworkProtocol.SetOption.
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 4570e8969..aca2f77f8 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -262,12 +262,6 @@ type NetworkEndpoint interface {
// header to the given destination address. It takes ownership of pkt.
WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error
- // ID returns the network protocol endpoint ID.
- ID() *NetworkEndpointID
-
- // PrefixLen returns the network endpoint's subnet prefix length in bits.
- PrefixLen() int
-
// NICID returns the id of the NIC this endpoint belongs to.
NICID() tcpip.NICID
@@ -304,7 +298,7 @@ type NetworkProtocol interface {
ParseAddresses(v buffer.View) (src, dst tcpip.Address)
// NewEndpoint creates a new endpoint of this protocol.
- NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint, st *Stack) (NetworkEndpoint, *tcpip.Error)
+ NewEndpoint(nicID tcpip.NICID, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint, st *Stack) NetworkEndpoint
// SetOption allows enabling/disabling protocol specific features.
// SetOption returns an error if the option is not supported or the
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 5b19c5d59..9a1c8e409 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -1321,7 +1321,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
if id != 0 && !needRoute {
if nic, ok := s.nics[id]; ok && nic.enabled() {
if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil {
- return makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()), nil
+ return makeRoute(netProto, ref.address(), remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()), nil
}
}
} else {
@@ -1334,10 +1334,10 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
if len(remoteAddr) == 0 {
// If no remote address was provided, then the route
// provided will refer to the link local address.
- remoteAddr = ref.ep.ID().LocalAddress
+ remoteAddr = ref.address()
}
- r := makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback())
+ r := makeRoute(netProto, ref.address(), remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback())
r.directedBroadcast = route.Destination.IsBroadcast(remoteAddr)
if len(route.Gateway) > 0 {
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 0273b3c63..b5a603098 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -70,8 +70,6 @@ const (
// protocol. They're all one byte fields to simplify parsing.
type fakeNetworkEndpoint struct {
nicID tcpip.NICID
- id stack.NetworkEndpointID
- prefixLen int
proto *fakeNetworkProtocol
dispatcher stack.TransportDispatcher
ep stack.LinkEndpoint
@@ -85,21 +83,13 @@ func (f *fakeNetworkEndpoint) NICID() tcpip.NICID {
return f.nicID
}
-func (f *fakeNetworkEndpoint) PrefixLen() int {
- return f.prefixLen
-}
-
func (*fakeNetworkEndpoint) DefaultTTL() uint8 {
return 123
}
-func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID {
- return &f.id
-}
-
func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// Increment the received packet count in the protocol descriptor.
- f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++
+ f.proto.packetCount[int(r.LocalAddress[0])%len(f.proto.packetCount)]++
// Handle control packets.
if pkt.NetworkHeader().View()[protocolNumberOffset] == uint8(fakeControlProtocol) {
@@ -145,7 +135,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params
// endpoint.
hdr := pkt.NetworkHeader().Push(fakeNetHeaderLen)
hdr[dstAddrOffset] = r.RemoteAddress[0]
- hdr[srcAddrOffset] = f.id.LocalAddress[0]
+ hdr[srcAddrOffset] = r.LocalAddress[0]
hdr[protocolNumberOffset] = byte(params.Protocol)
if r.Loop&stack.PacketLoop != 0 {
@@ -208,15 +198,13 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres
return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1])
}
-func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint, _ *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) {
+func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint, _ *stack.Stack) stack.NetworkEndpoint {
return &fakeNetworkEndpoint{
nicID: nicID,
- id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
- prefixLen: addrWithPrefix.PrefixLen,
proto: f,
dispatcher: dispatcher,
ep: ep,
- }, nil
+ }
}
func (f *fakeNetworkProtocol) SetOption(option interface{}) *tcpip.Error {
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 71776d6db..f87d99d5a 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -1469,13 +1469,10 @@ func TestTTL(t *testing.T) {
} else {
p = ipv6.NewProtocol()
}
- ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil, stack.New(stack.Options{
+ ep := p.NewEndpoint(0, nil, nil, nil, stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
}))
- if err != nil {
- t.Fatal(err)
- }
wantTTL = ep.DefaultTTL()
ep.Close()
}
@@ -1505,13 +1502,10 @@ func TestSetTTL(t *testing.T) {
} else {
p = ipv6.NewProtocol()
}
- ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil, stack.New(stack.Options{
+ ep := p.NewEndpoint(0, nil, nil, nil, stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
}))
- if err != nil {
- t.Fatal(err)
- }
ep.Close()
testWrite(c, flow, checker.TTL(wantTTL))