diff options
-rw-r--r-- | pkg/tcpip/stack/addressable_endpoint_state.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 18 | ||||
-rw-r--r-- | pkg/tcpip/stack/registration.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/stack/route.go | 20 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 2 |
6 files changed, 22 insertions, 33 deletions
diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go index db8ac1c2b..4d3acab96 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state.go +++ b/pkg/tcpip/stack/addressable_endpoint_state.go @@ -679,11 +679,6 @@ type addressState struct { } } -// NetworkEndpoint implements AddressEndpoint. -func (a *addressState) NetworkEndpoint() NetworkEndpoint { - return a.addressableEndpointState.networkEndpoint -} - // AddressWithPrefix implements AddressEndpoint. func (a *addressState) AddressWithPrefix() tcpip.AddressWithPrefix { return a.addr diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 06824843a..6cf54cc89 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -38,8 +38,11 @@ type NIC struct { linkEP LinkEndpoint context NICContext - stats NICStats - neigh *neighborCache + stats NICStats + neigh *neighborCache + + // The network endpoints themselves may be modified by calling the interface's + // methods, but the map reference and entries must be constant. networkEndpoints map[tcpip.NetworkProtocolNumber]NetworkEndpoint // enabled is set to 1 when the NIC is enabled and 0 when it is disabled. @@ -132,6 +135,10 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC return nic } +func (n *NIC) getNetworkEndpoint(proto tcpip.NetworkProtocolNumber) NetworkEndpoint { + return n.networkEndpoints[proto] +} + // Enabled implements NetworkInterface. func (n *NIC) Enabled() bool { return atomic.LoadUint32(&n.enabled) == 1 @@ -211,7 +218,6 @@ func (n *NIC) remove() *tcpip.Error { for _, ep := range n.networkEndpoints { ep.Close() } - n.networkEndpoints = nil // Detach from link endpoint, so no packet comes in. n.linkEP.Attach(nil) @@ -483,9 +489,9 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool { func (n *NIC) handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, remotelinkAddr tcpip.LinkAddress, addressEndpoint AssignableAddressEndpoint, pkt *PacketBuffer) { r := makeRoute(protocol, dst, src, n, addressEndpoint, false /* handleLocal */, false /* multicastLoop */) + defer r.Release() r.RemoteLinkAddress = remotelinkAddr - addressEndpoint.NetworkEndpoint().HandlePacket(&r, pkt) - addressEndpoint.DecRef() + n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt) } // DeliverNetworkPacket finds the appropriate network protocol endpoint and @@ -603,7 +609,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp r.RemoteLinkAddress = remote r.RemoteAddress = src // TODO(b/123449044): Update the source NIC as well. - addressEndpoint.NetworkEndpoint().HandlePacket(&r, pkt) + n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt) addressEndpoint.DecRef() r.Release() return diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 16f854e1f..be9bd8042 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -326,10 +326,6 @@ const ( // AssignableAddressEndpoint is a reference counted address endpoint that may be // assigned to a NetworkEndpoint. type AssignableAddressEndpoint interface { - // NetworkEndpoint returns the NetworkEndpoint the receiver is associated - // with. - NetworkEndpoint() NetworkEndpoint - // AddressWithPrefix returns the endpoint's address. AddressWithPrefix() tcpip.AddressWithPrefix diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index effe30155..cc39c9a6a 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -100,7 +100,7 @@ func (r *Route) NICID() tcpip.NICID { // MaxHeaderLength forwards the call to the network endpoint's implementation. func (r *Route) MaxHeaderLength() uint16 { - return r.addressEndpoint.NetworkEndpoint().MaxHeaderLength() + return r.nic.getNetworkEndpoint(r.NetProto).MaxHeaderLength() } // Stats returns a mutable copy of current stats. @@ -121,7 +121,7 @@ func (r *Route) Capabilities() LinkEndpointCapabilities { // GSOMaxSize returns the maximum GSO packet size. func (r *Route) GSOMaxSize() uint32 { - if gso, ok := r.addressEndpoint.NetworkEndpoint().(GSOEndpoint); ok { + if gso, ok := r.nic.getNetworkEndpoint(r.NetProto).(GSOEndpoint); ok { return gso.GSOMaxSize() } return 0 @@ -211,7 +211,7 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuf // WritePacket takes ownership of pkt, calculate numBytes first. numBytes := pkt.Size() - if err := r.addressEndpoint.NetworkEndpoint().WritePacket(r, gso, params, pkt); err != nil { + if err := r.nic.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt); err != nil { return err } @@ -227,7 +227,7 @@ func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHead return 0, tcpip.ErrInvalidEndpointState } - n, err := r.addressEndpoint.NetworkEndpoint().WritePackets(r, gso, pkts, params) + n, err := r.nic.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params) r.nic.stats.Tx.Packets.IncrementBy(uint64(n)) writtenBytes := 0 for i, pb := 0, pkts.Front(); i < n && pb != nil; i, pb = i+1, pb.Next() { @@ -248,7 +248,7 @@ func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error { // WriteHeaderIncludedPacket takes ownership of pkt, calculate numBytes first. numBytes := pkt.Data.Size() - if err := r.addressEndpoint.NetworkEndpoint().WriteHeaderIncludedPacket(r, pkt); err != nil { + if err := r.nic.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt); err != nil { return err } r.nic.stats.Tx.Packets.Increment() @@ -258,18 +258,12 @@ func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error { // DefaultTTL returns the default TTL of the underlying network endpoint. func (r *Route) DefaultTTL() uint8 { - return r.addressEndpoint.NetworkEndpoint().DefaultTTL() + return r.nic.getNetworkEndpoint(r.NetProto).DefaultTTL() } // MTU returns the MTU of the underlying network endpoint. func (r *Route) MTU() uint32 { - return r.addressEndpoint.NetworkEndpoint().MTU() -} - -// NetworkProtocolNumber returns the NetworkProtocolNumber of the underlying -// network endpoint. -func (r *Route) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { - return r.addressEndpoint.NetworkEndpoint().NetworkProtocolNumber() + return r.nic.getNetworkEndpoint(r.NetProto).MTU() } // Release frees all resources associated with the route. diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 57d8e79e0..0bf20c0e1 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1796,7 +1796,7 @@ func (s *Stack) GetNetworkEndpoint(nicID tcpip.NICID, proto tcpip.NetworkProtoco return nil, tcpip.ErrUnknownNICID } - return nic.networkEndpoints[proto], nil + return nic.getNetworkEndpoint(proto), nil } // NUDConfigurations gets the per-interface NUD configurations. @@ -1873,10 +1873,8 @@ func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, addres if addressEndpoint == nil { continue } - - ep := addressEndpoint.NetworkEndpoint() addressEndpoint.DecRef() - return ep, nil + return nic.getNetworkEndpoint(netProto), nil } return nil, tcpip.ErrBadAddress } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 6891fd245..189c01c8f 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -804,7 +804,7 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso pkt.Owner = owner pkt.EgressRoute = r pkt.GSOOptions = gso - pkt.NetworkProtocolNumber = r.NetworkProtocolNumber() + pkt.NetworkProtocolNumber = r.NetProto data.ReadToVV(&pkt.Data, packetSize) buildTCPHdr(r, tf, pkt, gso) tf.seq = tf.seq.Add(seqnum.Size(packetSize)) |