summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack
diff options
context:
space:
mode:
authorBhasker Hariharan <bhaskerh@google.com>2020-10-09 09:08:57 -0700
committergVisor bot <gvisor-bot@google.com>2020-10-09 09:11:18 -0700
commit8566decab094008d5f873cb679c972d5d60cc49a (patch)
tree0ba87479d8d36d3057366b7f26bddb1925b4d767 /pkg/tcpip/stack
parent07b1d7413e8b648b85fa9276a516732dd93276b4 (diff)
Automated rollback of changelist 336185457
PiperOrigin-RevId: 336304024
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r--pkg/tcpip/stack/BUILD5
-rw-r--r--pkg/tcpip/stack/forwarder.go (renamed from pkg/tcpip/stack/pending_packets.go)60
-rw-r--r--pkg/tcpip/stack/forwarder_test.go (renamed from pkg/tcpip/stack/forwarding_test.go)12
-rw-r--r--pkg/tcpip/stack/neighbor_entry.go4
-rw-r--r--pkg/tcpip/stack/neighbor_entry_test.go5
-rw-r--r--pkg/tcpip/stack/nic.go125
-rw-r--r--pkg/tcpip/stack/nic_test.go10
-rw-r--r--pkg/tcpip/stack/registration.go33
-rw-r--r--pkg/tcpip/stack/route.go48
-rw-r--r--pkg/tcpip/stack/stack.go42
-rw-r--r--pkg/tcpip/stack/stack_test.go62
11 files changed, 227 insertions, 179 deletions
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index eba97334e..2eaeab779 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -56,6 +56,7 @@ go_library(
srcs = [
"addressable_endpoint_state.go",
"conntrack.go",
+ "forwarder.go",
"headertype_string.go",
"icmp_rate_limit.go",
"iptables.go",
@@ -72,7 +73,6 @@ go_library(
"nud.go",
"packet_buffer.go",
"packet_buffer_list.go",
- "pending_packets.go",
"rand.go",
"registration.go",
"route.go",
@@ -123,6 +123,7 @@ go_test(
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
+ "//pkg/tcpip/network/arp",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/ports",
@@ -138,7 +139,7 @@ go_test(
name = "stack_test",
size = "small",
srcs = [
- "forwarding_test.go",
+ "forwarder_test.go",
"linkaddrcache_test.go",
"neighbor_cache_test.go",
"neighbor_entry_test.go",
diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/forwarder.go
index f838eda8d..3eff141e6 100644
--- a/pkg/tcpip/stack/pending_packets.go
+++ b/pkg/tcpip/stack/forwarder.go
@@ -29,60 +29,60 @@ const (
)
type pendingPacket struct {
+ nic *NIC
route *Route
proto tcpip.NetworkProtocolNumber
pkt *PacketBuffer
}
-// packetsPendingLinkResolution is a queue of packets pending link resolution.
-//
-// Once link resolution completes successfully, the packets will be written.
-type packetsPendingLinkResolution struct {
+type forwardQueue struct {
sync.Mutex
// The packets to send once the resolver completes.
- packets map[<-chan struct{}][]pendingPacket
+ packets map[<-chan struct{}][]*pendingPacket
// FIFO of channels used to cancel the oldest goroutine waiting for
// link-address resolution.
cancelChans []chan struct{}
}
-func (f *packetsPendingLinkResolution) init() {
- f.Lock()
- defer f.Unlock()
- f.packets = make(map[<-chan struct{}][]pendingPacket)
+func newForwardQueue() *forwardQueue {
+ return &forwardQueue{packets: make(map[<-chan struct{}][]*pendingPacket)}
}
-func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, proto tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
- f.Lock()
- defer f.Unlock()
+func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
+ shouldWait := false
+ f.Lock()
packets, ok := f.packets[ch]
- if len(packets) == maxPendingPacketsPerResolution {
+ if !ok {
+ shouldWait = true
+ }
+ for len(packets) == maxPendingPacketsPerResolution {
p := packets[0]
- packets[0] = pendingPacket{}
packets = packets[1:]
- p.route.Stats().IP.OutgoingPacketErrors.Increment()
+ p.nic.stack.stats.IP.OutgoingPacketErrors.Increment()
p.route.Release()
}
-
if l := len(packets); l >= maxPendingPacketsPerResolution {
panic(fmt.Sprintf("max pending packets for resolution reached; got %d packets, max = %d", l, maxPendingPacketsPerResolution))
}
-
- f.packets[ch] = append(packets, pendingPacket{
+ f.packets[ch] = append(packets, &pendingPacket{
+ nic: n,
route: r,
- proto: proto,
+ proto: protocol,
pkt: pkt,
})
+ f.Unlock()
- if ok {
+ if !shouldWait {
return
}
// Wait for the link-address resolution to complete.
- cancel := f.newCancelChannelLocked()
+ // Start a goroutine with a forwarding-cancel channel so that we can
+ // limit the maximum number of goroutines running concurrently.
+ cancel := f.newCancelChannel()
go func() {
cancelled := false
select {
@@ -92,21 +92,17 @@ func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, pro
}
f.Lock()
- packets, ok := f.packets[ch]
+ packets := f.packets[ch]
delete(f.packets, ch)
f.Unlock()
- if !ok {
- panic(fmt.Sprintf("link-resolution goroutine woke up but no entry exists in the queue of packets"))
- }
-
for _, p := range packets {
if cancelled {
- p.route.Stats().IP.OutgoingPacketErrors.Increment()
+ p.nic.stack.stats.IP.OutgoingPacketErrors.Increment()
} else if _, err := p.route.Resolve(nil); err != nil {
- p.route.Stats().IP.OutgoingPacketErrors.Increment()
+ p.nic.stack.stats.IP.OutgoingPacketErrors.Increment()
} else {
- p.route.nic.writePacket(p.route, nil /* gso */, p.proto, p.pkt)
+ p.nic.forwardPacket(p.route, p.proto, p.pkt)
}
p.route.Release()
}
@@ -116,10 +112,12 @@ func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, pro
// newCancelChannel creates a channel that can cancel a pending forwarding
// activity. The oldest channel is closed if the number of open channels would
// exceed maxPendingResolutions.
-func (f *packetsPendingLinkResolution) newCancelChannelLocked() chan struct{} {
+func (f *forwardQueue) newCancelChannel() chan struct{} {
+ f.Lock()
+ defer f.Unlock()
+
if len(f.cancelChans) == maxPendingResolutions {
ch := f.cancelChans[0]
- f.cancelChans[0] = nil
f.cancelChans = f.cancelChans[1:]
close(ch)
}
diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarder_test.go
index cf042309e..4e4b00a92 100644
--- a/pkg/tcpip/stack/forwarding_test.go
+++ b/pkg/tcpip/stack/forwarder_test.go
@@ -48,9 +48,10 @@ const (
type fwdTestNetworkEndpoint struct {
AddressableEndpointState
- nic NetworkInterface
+ nicID tcpip.NICID
proto *fwdTestNetworkProtocol
dispatcher TransportDispatcher
+ ep LinkEndpoint
}
var _ NetworkEndpoint = (*fwdTestNetworkEndpoint)(nil)
@@ -66,7 +67,7 @@ func (*fwdTestNetworkEndpoint) Enabled() bool {
func (*fwdTestNetworkEndpoint) Disable() {}
func (f *fwdTestNetworkEndpoint) MTU() uint32 {
- return f.nic.MTU() - uint32(f.MaxHeaderLength())
+ return f.ep.MTU() - uint32(f.MaxHeaderLength())
}
func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 {
@@ -79,7 +80,7 @@ func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) {
}
func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 {
- return f.nic.MaxHeaderLength() + fwdTestNetHeaderLen
+ return f.ep.MaxHeaderLength() + fwdTestNetHeaderLen
}
func (f *fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 {
@@ -98,7 +99,7 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkH
b[srcAddrOffset] = r.LocalAddress[0]
b[protocolNumberOffset] = byte(params.Protocol)
- return f.nic.WritePacket(r, gso, fwdTestNetNumber, pkt)
+ return f.ep.WritePacket(r, gso, fwdTestNetNumber, pkt)
}
// WritePackets implements LinkEndpoint.WritePackets.
@@ -158,9 +159,10 @@ func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocol
func (f *fwdTestNetworkProtocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, dispatcher TransportDispatcher) NetworkEndpoint {
e := &fwdTestNetworkEndpoint{
- nic: nic,
+ nicID: nic.ID(),
proto: f,
dispatcher: dispatcher,
+ ep: nic.LinkEndpoint(),
}
e.AddressableEndpointState.Init(e)
return e
diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go
index 4d69a4de1..9a72bec79 100644
--- a/pkg/tcpip/stack/neighbor_entry.go
+++ b/pkg/tcpip/stack/neighbor_entry.go
@@ -236,7 +236,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
return
}
- if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, "", e.nic.LinkEndpoint); err != nil {
+ if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, "", e.nic.linkEP); err != nil {
// There is no need to log the error here; the NUD implementation may
// assume a working link. A valid link should be the responsibility of
// the NIC/stack.LinkEndpoint.
@@ -277,7 +277,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
return
}
- if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, e.neigh.LinkAddr, e.nic.LinkEndpoint); err != nil {
+ if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, e.neigh.LinkAddr, e.nic.linkEP); err != nil {
e.dispatchRemoveEventLocked()
e.setStateLocked(Failed)
return
diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go
index e79abebca..a265fff0a 100644
--- a/pkg/tcpip/stack/neighbor_entry_test.go
+++ b/pkg/tcpip/stack/neighbor_entry_test.go
@@ -227,9 +227,8 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e
clock := faketime.NewManualClock()
disp := testNUDDispatcher{}
nic := NIC{
- LinkEndpoint: nil, // entryTestLinkResolver doesn't use a LinkEndpoint
-
- id: entryTestNICID,
+ id: entryTestNICID,
+ linkEP: nil, // entryTestLinkResolver doesn't use a LinkEndpoint
stack: &Stack{
clock: clock,
nudDisp: &disp,
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 8828cc5fe..6cf54cc89 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -32,11 +32,10 @@ var _ NetworkInterface = (*NIC)(nil)
// NIC represents a "network interface card" to which the networking stack is
// attached.
type NIC struct {
- LinkEndpoint
-
stack *Stack
id tcpip.NICID
name string
+ linkEP LinkEndpoint
context NICContext
stats NICStats
@@ -92,11 +91,10 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
// of IPv6 is supported on this endpoint's LinkEndpoint.
nic := &NIC{
- LinkEndpoint: ep,
-
stack: stack,
id: id,
name: name,
+ linkEP: ep,
context: ctx,
stats: makeNICStats(),
networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint),
@@ -132,7 +130,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic)
}
- nic.LinkEndpoint.Attach(nic)
+ nic.linkEP.Attach(nic)
return nic
}
@@ -222,7 +220,7 @@ func (n *NIC) remove() *tcpip.Error {
}
// Detach from link endpoint, so no packet comes in.
- n.LinkEndpoint.Attach(nil)
+ n.linkEP.Attach(nil)
return nil
}
@@ -242,64 +240,7 @@ func (n *NIC) isPromiscuousMode() bool {
// IsLoopback implements NetworkInterface.
func (n *NIC) IsLoopback() bool {
- return n.LinkEndpoint.Capabilities()&CapabilityLoopback != 0
-}
-
-// WritePacket implements NetworkLinkEndpoint.
-func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
- // As per relevant RFCs, we should queue packets while we wait for link
- // resolution to complete.
- //
- // RFC 1122 section 2.3.2.2 (for IPv4):
- // The link layer SHOULD save (rather than discard) at least
- // one (the latest) packet of each set of packets destined to
- // the same unresolved IP address, and transmit the saved
- // packet when the address has been resolved.
- //
- // RFC 4861 section 5.2 (for IPv6):
- // Once the IP address of the next-hop node is known, the sender
- // examines the Neighbor Cache for link-layer information about that
- // neighbor. If no entry exists, the sender creates one, sets its state
- // to INCOMPLETE, initiates Address Resolution, and then queues the data
- // packet pending completion of address resolution.
- if ch, err := r.Resolve(nil); err != nil {
- if err == tcpip.ErrWouldBlock {
- r := r.Clone()
- n.stack.linkResQueue.enqueue(ch, &r, protocol, pkt)
- return nil
- }
- return err
- }
-
- return n.writePacket(r, gso, protocol, pkt)
-}
-
-func (n *NIC) writePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
- // WritePacket takes ownership of pkt, calculate numBytes first.
- numBytes := pkt.Size()
-
- if err := n.LinkEndpoint.WritePacket(r, gso, protocol, pkt); err != nil {
- return err
- }
-
- n.stats.Tx.Packets.Increment()
- n.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
- return nil
-}
-
-// WritePackets implements NetworkLinkEndpoint.
-func (n *NIC) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
- // TODO(gvisor.dev/issue/4458): Queue packets whie link address resolution
- // is being peformed like WritePacket.
- writtenPackets, err := n.LinkEndpoint.WritePackets(r, gso, pkts, protocol)
- n.stats.Tx.Packets.IncrementBy(uint64(writtenPackets))
- writtenBytes := 0
- for i, pb := 0, pkts.Front(); i < writtenPackets && pb != nil; i, pb = i+1, pb.Next() {
- writtenBytes += pb.Size()
- }
-
- n.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes))
- return writtenPackets, err
+ return n.linkEP.Capabilities()&CapabilityLoopback != 0
}
// setSpoofing enables or disables address spoofing.
@@ -584,7 +525,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
// If no local link layer address is provided, assume it was sent
// directly to this NIC.
if local == "" {
- local = n.LinkEndpoint.LinkAddress()
+ local = n.linkEP.LinkAddress()
}
// Are any packet type sockets listening for this network protocol?
@@ -664,7 +605,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
n := r.nic
if addressEndpoint := n.getAddressOrCreateTempInner(protocol, dst, false, NeverPrimaryEndpoint); addressEndpoint != nil {
if n.isValidForOutgoing(addressEndpoint) {
- r.LocalLinkAddress = n.LinkEndpoint.LinkAddress()
+ r.LocalLinkAddress = n.linkEP.LinkAddress()
r.RemoteLinkAddress = remote
r.RemoteAddress = src
// TODO(b/123449044): Update the source NIC as well.
@@ -679,21 +620,21 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
// n doesn't have a destination endpoint.
// Send the packet out of n.
+ // TODO(b/128629022): move this logic to route.WritePacket.
// TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease the TTL field for ipv4/ipv6.
-
- // pkt may have set its header and may not have enough headroom for
- // link-layer header for the other link to prepend. Here we create a new
- // packet to forward.
- fwdPkt := NewPacketBuffer(PacketBufferOptions{
- ReserveHeaderBytes: int(n.LinkEndpoint.MaxHeaderLength()),
- Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
- })
-
- // TODO(b/143425874) Decrease the TTL field in forwarded packets.
- if err := n.WritePacket(&r, nil, protocol, fwdPkt); err != nil {
+ if ch, err := r.Resolve(nil); err != nil {
+ if err == tcpip.ErrWouldBlock {
+ n.stack.forwarder.enqueue(ch, n, &r, protocol, pkt)
+ // forwarder will release route.
+ return
+ }
n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment()
+ r.Release()
+ return
}
+ // The link-address resolution finished immediately.
+ n.forwardPacket(&r, protocol, pkt)
r.Release()
return
}
@@ -717,11 +658,34 @@ func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc
p.PktType = tcpip.PacketOutgoing
// Add the link layer header as outgoing packets are intercepted
// before the link layer header is created.
- n.LinkEndpoint.AddHeader(local, remote, protocol, p)
+ n.linkEP.AddHeader(local, remote, protocol, p)
ep.HandlePacket(n.id, local, protocol, p)
}
}
+func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
+ // TODO(b/143425874) Decrease the TTL field in forwarded packets.
+
+ // pkt may have set its header and may not have enough headroom for link-layer
+ // header for the other link to prepend. Here we create a new packet to
+ // forward.
+ fwdPkt := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: int(n.linkEP.MaxHeaderLength()),
+ Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
+ })
+
+ // WritePacket takes ownership of fwdPkt, calculate numBytes first.
+ numBytes := fwdPkt.Size()
+
+ if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, fwdPkt); err != nil {
+ r.Stats().IP.OutgoingPacketErrors.Increment()
+ return
+ }
+
+ n.stats.Tx.Packets.Increment()
+ n.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
+}
+
// DeliverTransportPacket delivers the packets to the appropriate transport
// protocol endpoint.
func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition {
@@ -832,6 +796,11 @@ func (n *NIC) Name() string {
return n.name
}
+// LinkEndpoint implements NetworkInterface.
+func (n *NIC) LinkEndpoint() LinkEndpoint {
+ return n.linkEP
+}
+
// nudConfigs gets the NUD configurations for n.
func (n *NIC) nudConfigs() (NUDConfigurations, *tcpip.Error) {
if n.neigh == nil {
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index 97a96af62..fdd49b77f 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -33,7 +33,8 @@ var _ NDPEndpoint = (*testIPv6Endpoint)(nil)
type testIPv6Endpoint struct {
AddressableEndpointState
- nic NetworkInterface
+ nicID tcpip.NICID
+ linkEP LinkEndpoint
protocol *testIPv6Protocol
invalidatedRtr tcpip.Address
@@ -56,12 +57,12 @@ func (*testIPv6Endpoint) DefaultTTL() uint8 {
// MTU implements NetworkEndpoint.MTU.
func (e *testIPv6Endpoint) MTU() uint32 {
- return e.nic.MTU() - header.IPv6MinimumSize
+ return e.linkEP.MTU() - header.IPv6MinimumSize
}
// MaxHeaderLength implements NetworkEndpoint.MaxHeaderLength.
func (e *testIPv6Endpoint) MaxHeaderLength() uint16 {
- return e.nic.MaxHeaderLength() + header.IPv6MinimumSize
+ return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize
}
// WritePacket implements NetworkEndpoint.WritePacket.
@@ -133,7 +134,8 @@ func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address)
// NewEndpoint implements NetworkProtocol.NewEndpoint.
func (p *testIPv6Protocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, _ TransportDispatcher) NetworkEndpoint {
e := &testIPv6Endpoint{
- nic: nic,
+ nicID: nic.ID(),
+ linkEP: nic.LinkEndpoint(),
protocol: p,
}
e.AddressableEndpointState.Init(e)
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index defb9129b..be9bd8042 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -475,8 +475,6 @@ type NDPEndpoint interface {
// NetworkInterface is a network interface.
type NetworkInterface interface {
- NetworkLinkEndpoint
-
// ID returns the interface's ID.
ID() tcpip.NICID
@@ -490,6 +488,9 @@ type NetworkInterface interface {
// Enabled returns true if the interface is enabled.
Enabled() bool
+
+ // LinkEndpoint returns the link endpoint backing the interface.
+ LinkEndpoint() LinkEndpoint
}
// NetworkEndpoint is the interface that needs to be implemented by endpoints
@@ -662,15 +663,22 @@ const (
CapabilitySoftwareGSO
)
-// NetworkLinkEndpoint is a data-link layer that supports sending network
-// layer packets.
-type NetworkLinkEndpoint interface {
+// LinkEndpoint is the interface implemented by data link layer protocols (e.g.,
+// ethernet, loopback, raw) and used by network layer protocols to send packets
+// out through the implementer's data link endpoint. When a link header exists,
+// it sets each PacketBuffer's LinkHeader field before passing it up the
+// stack.
+type LinkEndpoint interface {
// MTU is the maximum transmission unit for this endpoint. This is
// usually dictated by the backing physical network; when such a
// physical network doesn't exist, the limit is generally 64k, which
// includes the maximum size of an IP packet.
MTU() uint32
+ // Capabilities returns the set of capabilities supported by the
+ // endpoint.
+ Capabilities() LinkEndpointCapabilities
+
// MaxHeaderLength returns the maximum size the data link (and
// lower level layers combined) headers can have. Higher levels use this
// information to reserve space in the front of the packets they're
@@ -678,7 +686,7 @@ type NetworkLinkEndpoint interface {
MaxHeaderLength() uint16
// LinkAddress returns the link address (typically a MAC) of the
- // endpoint.
+ // link endpoint.
LinkAddress() tcpip.LinkAddress
// WritePacket writes a packet with the given protocol through the
@@ -698,19 +706,6 @@ type NetworkLinkEndpoint interface {
// offload is enabled. If it will be used for something else, it may
// require to change syscall filters.
WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error)
-}
-
-// LinkEndpoint is the interface implemented by data link layer protocols (e.g.,
-// ethernet, loopback, raw) and used by network layer protocols to send packets
-// out through the implementer's data link endpoint. When a link header exists,
-// it sets each PacketBuffer's LinkHeader field before passing it up the
-// stack.
-type LinkEndpoint interface {
- NetworkLinkEndpoint
-
- // Capabilities returns the set of capabilities supported by the
- // endpoint.
- Capabilities() LinkEndpointCapabilities
// WriteRawPacket writes a packet directly to the link. The packet
// should already have an ethernet header. It takes ownership of vv.
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 710a02ac3..cc39c9a6a 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -72,20 +72,21 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip
loop |= PacketLoop
}
+ linkEP := nic.LinkEndpoint()
r := Route{
NetProto: netProto,
LocalAddress: localAddr,
- LocalLinkAddress: nic.LinkEndpoint.LinkAddress(),
+ LocalLinkAddress: linkEP.LinkAddress(),
RemoteAddress: remoteAddr,
addressEndpoint: addressEndpoint,
nic: nic,
Loop: loop,
}
- if r.nic.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 {
- if linkRes, ok := r.nic.stack.linkAddrResolvers[r.NetProto]; ok {
+ if nic := r.nic; linkEP.Capabilities()&CapabilityResolutionRequired != 0 {
+ if linkRes, ok := nic.stack.linkAddrResolvers[r.NetProto]; ok {
r.linkRes = linkRes
- r.linkCache = r.nic.stack
+ r.linkCache = nic.stack
}
}
@@ -115,7 +116,7 @@ func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, tot
// Capabilities returns the link-layer capabilities of the route.
func (r *Route) Capabilities() LinkEndpointCapabilities {
- return r.nic.LinkEndpoint.Capabilities()
+ return r.nic.LinkEndpoint().Capabilities()
}
// GSOMaxSize returns the maximum GSO packet size.
@@ -126,6 +127,12 @@ func (r *Route) GSOMaxSize() uint32 {
return 0
}
+// ResolveWith immediately resolves a route with the specified remote link
+// address.
+func (r *Route) ResolveWith(addr tcpip.LinkAddress) {
+ r.RemoteLinkAddress = addr
+}
+
// Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in
// case address resolution requires blocking, e.g. wait for ARP reply. Waker is
// notified when address resolution is complete (success or not).
@@ -201,7 +208,16 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuf
return tcpip.ErrInvalidEndpointState
}
- return r.nic.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt)
+ // WritePacket takes ownership of pkt, calculate numBytes first.
+ numBytes := pkt.Size()
+
+ if err := r.nic.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt); err != nil {
+ return err
+ }
+
+ r.nic.stats.Tx.Packets.Increment()
+ r.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
+ return nil
}
// WritePackets writes a list of n packets through the given route and returns
@@ -211,7 +227,15 @@ func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHead
return 0, tcpip.ErrInvalidEndpointState
}
- return r.nic.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params)
+ n, err := r.nic.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params)
+ r.nic.stats.Tx.Packets.IncrementBy(uint64(n))
+ writtenBytes := 0
+ for i, pb := 0, pkts.Front(); i < n && pb != nil; i, pb = i+1, pb.Next() {
+ writtenBytes += pb.Size()
+ }
+
+ r.nic.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes))
+ return n, err
}
// WriteHeaderIncludedPacket writes a packet already containing a network
@@ -221,7 +245,15 @@ func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- return r.nic.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt)
+ // WriteHeaderIncludedPacket takes ownership of pkt, calculate numBytes first.
+ numBytes := pkt.Data.Size()
+
+ if err := r.nic.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt); err != nil {
+ return err
+ }
+ r.nic.stats.Tx.Packets.Increment()
+ r.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
+ return nil
}
// DefaultTTL returns the default TTL of the underlying network endpoint.
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 3a07577c8..0bf20c0e1 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -436,9 +436,9 @@ type Stack struct {
// uniqueIDGenerator is a generator of unique identifiers.
uniqueIDGenerator UniqueID
- // linkResQueue holds packets that are waiting for link resolution to
- // complete.
- linkResQueue packetsPendingLinkResolution
+ // forwarder holds the packets that wait for their link-address resolutions
+ // to complete, and forwards them when each resolution is done.
+ forwarder *forwardQueue
// randomGenerator is an injectable pseudo random generator that can be
// used when a random number is required.
@@ -550,8 +550,8 @@ type TransportEndpointInfo struct {
// incompatible with the receiver.
//
// Preconditon: the parent endpoint mu must be held while calling this method.
-func (t *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
- netProto := t.NetProto
+func (e *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ netProto := e.NetProto
switch len(addr.Addr) {
case header.IPv4AddressSize:
netProto = header.IPv4ProtocolNumber
@@ -565,7 +565,7 @@ func (t *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6onl
}
}
- switch len(t.ID.LocalAddress) {
+ switch len(e.ID.LocalAddress) {
case header.IPv4AddressSize:
if len(addr.Addr) == header.IPv6AddressSize {
return tcpip.FullAddress{}, 0, tcpip.ErrInvalidEndpointState
@@ -577,8 +577,8 @@ func (t *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6onl
}
switch {
- case netProto == t.NetProto:
- case netProto == header.IPv4ProtocolNumber && t.NetProto == header.IPv6ProtocolNumber:
+ case netProto == e.NetProto:
+ case netProto == header.IPv4ProtocolNumber && e.NetProto == header.IPv6ProtocolNumber:
if v6only {
return tcpip.FullAddress{}, 0, tcpip.ErrNoRoute
}
@@ -640,6 +640,7 @@ func New(opts Options) *Stack {
useNeighborCache: opts.UseNeighborCache,
uniqueIDGenerator: opts.UniqueID,
nudDisp: opts.NUDDisp,
+ forwarder: newForwardQueue(),
randomGenerator: mathrand.New(randSrc),
sendBufferSize: SendBufferSizeOption{
Min: MinBufferSize,
@@ -652,7 +653,6 @@ func New(opts Options) *Stack {
Max: DefaultMaxBufferSize,
},
}
- s.linkResQueue.init()
// Add specified network protocols.
for _, netProtoFactory := range opts.NetworkProtocols {
@@ -928,16 +928,16 @@ func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error {
return s.CreateNICWithOptions(id, ep, NICOptions{})
}
-// GetLinkEndpointByName gets the link endpoint specified by name.
-func (s *Stack) GetLinkEndpointByName(name string) LinkEndpoint {
+// GetNICByName gets the NIC specified by name.
+func (s *Stack) GetNICByName(name string) (*NIC, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
for _, nic := range s.nics {
if nic.Name() == name {
- return nic.LinkEndpoint
+ return nic, true
}
}
- return nil
+ return nil, false
}
// EnableNIC enables the given NIC so that the link-layer endpoint can start
@@ -1062,13 +1062,13 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
}
nics[id] = NICInfo{
Name: nic.name,
- LinkAddress: nic.LinkEndpoint.LinkAddress(),
+ LinkAddress: nic.linkEP.LinkAddress(),
ProtocolAddresses: nic.primaryAddresses(),
Flags: flags,
- MTU: nic.LinkEndpoint.MTU(),
+ MTU: nic.linkEP.MTU(),
Stats: nic.stats,
Context: nic.context,
- ARPHardwareType: nic.LinkEndpoint.ARPHardwareType(),
+ ARPHardwareType: nic.linkEP.ARPHardwareType(),
}
}
return nics
@@ -1323,7 +1323,7 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address,
fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr}
linkRes := s.linkAddrResolvers[protocol]
- return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.LinkEndpoint, waker)
+ return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.linkEP, waker)
}
// Neighbors returns all IP to MAC address associations.
@@ -1539,7 +1539,7 @@ func (s *Stack) Wait() {
s.mu.RLock()
defer s.mu.RUnlock()
for _, n := range s.nics {
- n.LinkEndpoint.Wait()
+ n.linkEP.Wait()
}
}
@@ -1627,7 +1627,7 @@ func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto t
// Add our own fake ethernet header.
ethFields := header.EthernetFields{
- SrcAddr: nic.LinkEndpoint.LinkAddress(),
+ SrcAddr: nic.linkEP.LinkAddress(),
DstAddr: dst,
Type: netProto,
}
@@ -1636,7 +1636,7 @@ func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto t
vv := buffer.View(fakeHeader).ToVectorisedView()
vv.Append(payload)
- if err := nic.LinkEndpoint.WriteRawPacket(vv); err != nil {
+ if err := nic.linkEP.WriteRawPacket(vv); err != nil {
return err
}
@@ -1653,7 +1653,7 @@ func (s *Stack) WriteRawPacket(nicID tcpip.NICID, payload buffer.VectorisedView)
return tcpip.ErrUnknownDevice
}
- if err := nic.LinkEndpoint.WriteRawPacket(payload); err != nil {
+ if err := nic.linkEP.WriteRawPacket(payload); err != nil {
return err
}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 38994cca1..aa20f750b 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -21,6 +21,7 @@ import (
"bytes"
"fmt"
"math"
+ "net"
"sort"
"testing"
"time"
@@ -34,6 +35,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
+ "gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -75,9 +77,10 @@ type fakeNetworkEndpoint struct {
enabled bool
}
- nic stack.NetworkInterface
+ nicID tcpip.NICID
proto *fakeNetworkProtocol
dispatcher stack.TransportDispatcher
+ ep stack.LinkEndpoint
}
func (f *fakeNetworkEndpoint) Enable() *tcpip.Error {
@@ -100,7 +103,7 @@ func (f *fakeNetworkEndpoint) Disable() {
}
func (f *fakeNetworkEndpoint) MTU() uint32 {
- return f.nic.MTU() - uint32(f.MaxHeaderLength())
+ return f.ep.MTU() - uint32(f.MaxHeaderLength())
}
func (*fakeNetworkEndpoint) DefaultTTL() uint8 {
@@ -132,7 +135,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuff
}
func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 {
- return f.nic.MaxHeaderLength() + fakeNetHeaderLen
+ return f.ep.MaxHeaderLength() + fakeNetHeaderLen
}
func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 {
@@ -161,7 +164,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params
return nil
}
- return f.nic.WritePacket(r, gso, fakeNetNumber, pkt)
+ return f.ep.WritePacket(r, gso, fakeNetNumber, pkt)
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
@@ -213,9 +216,10 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres
func (f *fakeNetworkProtocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
e := &fakeNetworkEndpoint{
- nic: nic,
+ nicID: nic.ID(),
proto: f,
dispatcher: dispatcher,
+ ep: nic.LinkEndpoint(),
}
e.AddressableEndpointState.Init(e)
return e
@@ -2102,7 +2106,7 @@ func TestNICStats(t *testing.T) {
t.Errorf("got Tx.Packets.Value() = %d, ep1.Drain() = %d", got, want)
}
- if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)+fakeNetHeaderLen); got != want {
+ if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)); got != want {
t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
}
}
@@ -3498,6 +3502,52 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
}
}
+func TestResolveWith(t *testing.T) {
+ const (
+ unspecifiedNICID = 0
+ nicID = 1
+ )
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol},
+ })
+ ep := channel.New(0, defaultMTU, "")
+ ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ if err := s.CreateNIC(nicID, ep); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+ addr := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.1.58").To4()),
+ PrefixLen: 24,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, addr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, addr, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}})
+
+ remoteAddr := tcpip.Address(net.ParseIP("192.168.1.59").To4())
+ r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, remoteAddr, header.IPv4ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, remoteAddr, header.IPv4ProtocolNumber, err)
+ }
+ defer r.Release()
+
+ // Should initially require resolution.
+ if !r.IsResolutionRequired() {
+ t.Fatal("got r.IsResolutionRequired() = false, want = true")
+ }
+
+ // Manually resolving the route should no longer require resolution.
+ r.ResolveWith("\x01")
+ if r.IsResolutionRequired() {
+ t.Fatal("got r.IsResolutionRequired() = true, want = false")
+ }
+}
+
// TestRouteReleaseAfterAddrRemoval tests that releasing a Route after its
// associated address is removed should not cause a panic.
func TestRouteReleaseAfterAddrRemoval(t *testing.T) {