summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r--pkg/tcpip/stack/BUILD5
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state.go116
-rw-r--r--pkg/tcpip/stack/addressable_endpoint_state_test.go11
-rw-r--r--pkg/tcpip/stack/conntrack.go46
-rw-r--r--pkg/tcpip/stack/forwarding_test.go (renamed from pkg/tcpip/stack/forwarder_test.go)12
-rw-r--r--pkg/tcpip/stack/iptables.go4
-rw-r--r--pkg/tcpip/stack/iptables_targets.go70
-rw-r--r--pkg/tcpip/stack/neighbor_cache.go15
-rw-r--r--pkg/tcpip/stack/neighbor_cache_test.go66
-rw-r--r--pkg/tcpip/stack/neighbor_entry.go4
-rw-r--r--pkg/tcpip/stack/neighbor_entry_test.go5
-rw-r--r--pkg/tcpip/stack/nic.go169
-rw-r--r--pkg/tcpip/stack/nic_test.go10
-rw-r--r--pkg/tcpip/stack/packet_buffer.go15
-rw-r--r--pkg/tcpip/stack/pending_packets.go (renamed from pkg/tcpip/stack/forwarder.go)60
-rw-r--r--pkg/tcpip/stack/registration.go54
-rw-r--r--pkg/tcpip/stack/route.go71
-rw-r--r--pkg/tcpip/stack/stack.go48
-rw-r--r--pkg/tcpip/stack/stack_test.go125
19 files changed, 485 insertions, 421 deletions
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 2eaeab779..eba97334e 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -56,7 +56,6 @@ go_library(
srcs = [
"addressable_endpoint_state.go",
"conntrack.go",
- "forwarder.go",
"headertype_string.go",
"icmp_rate_limit.go",
"iptables.go",
@@ -73,6 +72,7 @@ go_library(
"nud.go",
"packet_buffer.go",
"packet_buffer_list.go",
+ "pending_packets.go",
"rand.go",
"registration.go",
"route.go",
@@ -123,7 +123,6 @@ go_test(
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
- "//pkg/tcpip/network/arp",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/ports",
@@ -139,7 +138,7 @@ go_test(
name = "stack_test",
size = "small",
srcs = [
- "forwarder_test.go",
+ "forwarding_test.go",
"linkaddrcache_test.go",
"neighbor_cache_test.go",
"neighbor_entry_test.go",
diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go
index 270ac4977..4d3acab96 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state.go
@@ -412,6 +412,60 @@ func (a *AddressableEndpointState) decAddressRefLocked(addrState *addressState)
a.releaseAddressStateLocked(addrState)
}
+// MainAddress implements AddressableEndpoint.
+func (a *AddressableEndpointState) MainAddress() tcpip.AddressWithPrefix {
+ a.mu.RLock()
+ defer a.mu.RUnlock()
+
+ ep := a.acquirePrimaryAddressRLocked(func(ep *addressState) bool {
+ return ep.GetKind() == Permanent
+ })
+ if ep == nil {
+ return tcpip.AddressWithPrefix{}
+ }
+
+ addr := ep.AddressWithPrefix()
+ a.decAddressRefLocked(ep)
+ return addr
+}
+
+// acquirePrimaryAddressRLocked returns an acquired primary address that is
+// valid according to isValid.
+//
+// Precondition: a.mu must be read locked.
+func (a *AddressableEndpointState) acquirePrimaryAddressRLocked(isValid func(*addressState) bool) *addressState {
+ var deprecatedEndpoint *addressState
+ for _, ep := range a.mu.primary {
+ if !isValid(ep) {
+ continue
+ }
+
+ if !ep.Deprecated() {
+ if ep.IncRef() {
+ // ep is not deprecated, so return it immediately.
+ //
+ // If we kept track of a deprecated endpoint, decrement its reference
+ // count since it was incremented when we decided to keep track of it.
+ if deprecatedEndpoint != nil {
+ a.decAddressRefLocked(deprecatedEndpoint)
+ deprecatedEndpoint = nil
+ }
+
+ return ep
+ }
+ } else if deprecatedEndpoint == nil && ep.IncRef() {
+ // We prefer an endpoint that is not deprecated, but we keep track of
+ // ep in case a doesn't have any non-deprecated endpoints.
+ //
+ // If we end up finding a more preferred endpoint, ep's reference count
+ // will be decremented.
+ deprecatedEndpoint = ep
+ }
+ }
+
+ return deprecatedEndpoint
+}
+
// AcquireAssignedAddress implements AddressableEndpoint.
func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint {
a.mu.Lock()
@@ -461,47 +515,34 @@ func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Addres
return ep
}
-// AcquirePrimaryAddress implements AddressableEndpoint.
-func (a *AddressableEndpointState) AcquirePrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) AddressEndpoint {
+// AcquireOutgoingPrimaryAddress implements AddressableEndpoint.
+func (a *AddressableEndpointState) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) AddressEndpoint {
a.mu.RLock()
defer a.mu.RUnlock()
- var deprecatedEndpoint *addressState
- for _, ep := range a.mu.primary {
- if !ep.IsAssigned(allowExpired) {
- continue
- }
-
- if !ep.Deprecated() {
- if ep.IncRef() {
- // ep is not deprecated, so return it immediately.
- //
- // If we kept track of a deprecated endpoint, decrement its reference
- // count since it was incremented when we decided to keep track of it.
- if deprecatedEndpoint != nil {
- a.decAddressRefLocked(deprecatedEndpoint)
- deprecatedEndpoint = nil
- }
+ ep := a.acquirePrimaryAddressRLocked(func(ep *addressState) bool {
+ return ep.IsAssigned(allowExpired)
+ })
- return ep
- }
- } else if deprecatedEndpoint == nil && ep.IncRef() {
- // We prefer an endpoint that is not deprecated, but we keep track of
- // ep in case a doesn't have any non-deprecated endpoints.
- //
- // If we end up finding a more preferred endpoint, ep's reference count
- // will be decremented.
- deprecatedEndpoint = ep
- }
- }
-
- // a doesn't have any valid non-deprecated endpoints, so return
- // deprecatedEndpoint (which may be nil if a doesn't have any valid deprecated
- // endpoints either).
- if deprecatedEndpoint == nil {
+ // From https://golang.org/doc/faq#nil_error:
+ //
+ // Under the covers, interfaces are implemented as two elements, a type T and
+ // a value V.
+ //
+ // An interface value is nil only if the V and T are both unset, (T=nil, V is
+ // not set), In particular, a nil interface will always hold a nil type. If we
+ // store a nil pointer of type *int inside an interface value, the inner type
+ // will be *int regardless of the value of the pointer: (T=*int, V=nil). Such
+ // an interface value will therefore be non-nil even when the pointer value V
+ // inside is nil.
+ //
+ // Since acquirePrimaryAddressRLocked returns a nil value with a non-nil type,
+ // we need to explicitly return nil below if ep is (a typed) nil.
+ if ep == nil {
return nil
}
- return deprecatedEndpoint
+
+ return ep
}
// PrimaryAddresses implements AddressableEndpoint.
@@ -638,11 +679,6 @@ type addressState struct {
}
}
-// NetworkEndpoint implements AddressEndpoint.
-func (a *addressState) NetworkEndpoint() NetworkEndpoint {
- return a.addressableEndpointState.networkEndpoint
-}
-
// AddressWithPrefix implements AddressEndpoint.
func (a *addressState) AddressWithPrefix() tcpip.AddressWithPrefix {
return a.addr
diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go
index de4e0d7b1..26787d0a3 100644
--- a/pkg/tcpip/stack/addressable_endpoint_state_test.go
+++ b/pkg/tcpip/stack/addressable_endpoint_state_test.go
@@ -24,8 +24,13 @@ import (
// TestAddressableEndpointStateCleanup tests that cleaning up an addressable
// endpoint state removes permanent addresses and leaves groups.
func TestAddressableEndpointStateCleanup(t *testing.T) {
+ var ep fakeNetworkEndpoint
+ if err := ep.Enable(); err != nil {
+ t.Fatalf("ep.Enable(): %s", err)
+ }
+
var s stack.AddressableEndpointState
- s.Init(&fakeNetworkEndpoint{})
+ s.Init(&ep)
addr := tcpip.AddressWithPrefix{
Address: "\x01",
@@ -43,7 +48,7 @@ func TestAddressableEndpointStateCleanup(t *testing.T) {
{
ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint)
if ep == nil {
- t.Fatalf("got s.AcquireAssignedAddress(%s) = nil, want = non-nil", addr.Address)
+ t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = nil, want = non-nil", addr.Address)
}
ep.DecRef()
}
@@ -63,7 +68,7 @@ func TestAddressableEndpointStateCleanup(t *testing.T) {
ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint)
if ep != nil {
ep.DecRef()
- t.Fatalf("got s.AcquireAssignedAddress(%s) = %s, want = nil", addr.Address, ep.AddressWithPrefix())
+ t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix())
}
}
if s.IsInGroup(group) {
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index 836682ea0..0cd1da11f 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -196,13 +196,14 @@ type bucket struct {
// packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid
// TCP header.
+//
+// Preconditions: pkt.NetworkHeader() is valid.
func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) {
- // TODO(gvisor.dev/issue/170): Need to support for other
- // protocols as well.
- netHeader := header.IPv4(pkt.NetworkHeader().View())
- if len(netHeader) < header.IPv4MinimumSize || netHeader.TransportProtocol() != header.TCPProtocolNumber {
+ netHeader := pkt.Network()
+ if netHeader.TransportProtocol() != header.TCPProtocolNumber {
return tupleID{}, tcpip.ErrUnknownProtocol
}
+
tcpHeader := header.TCP(pkt.TransportHeader().View())
if len(tcpHeader) < header.TCPMinimumSize {
return tupleID{}, tcpip.ErrUnknownProtocol
@@ -214,7 +215,7 @@ func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) {
dstAddr: netHeader.DestinationAddress(),
dstPort: tcpHeader.DestinationPort(),
transProto: netHeader.TransportProtocol(),
- netProto: header.IPv4ProtocolNumber,
+ netProto: pkt.NetworkProtocolNumber,
}, nil
}
@@ -268,7 +269,7 @@ func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) {
return nil, dirOriginal
}
-func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt RedirectTarget) *conn {
+func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt *RedirectTarget) *conn {
tid, err := packetToTupleID(pkt)
if err != nil {
return nil
@@ -281,8 +282,8 @@ func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt Redirec
// rule. This tuple will be used to manipulate the packet in
// handlePacket.
replyTID := tid.reply()
- replyTID.srcAddr = rt.MinIP
- replyTID.srcPort = rt.MinPort
+ replyTID.srcAddr = rt.Addr
+ replyTID.srcPort = rt.Port
var manip manipType
switch hook {
case Prerouting:
@@ -344,7 +345,7 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) {
return
}
- netHeader := header.IPv4(pkt.NetworkHeader().View())
+ netHeader := pkt.Network()
tcpHeader := header.TCP(pkt.TransportHeader().View())
// For prerouting redirection, packets going in the original direction
@@ -366,8 +367,12 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) {
// support cases when they are validated, e.g. when we can't offload
// receive checksumming.
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ // After modification, IPv4 packets need a valid checksum.
+ if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ netHeader.SetChecksum(0)
+ netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ }
}
// handlePacketOutput manipulates ports for packets in Output hook.
@@ -377,7 +382,7 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d
return
}
- netHeader := header.IPv4(pkt.NetworkHeader().View())
+ netHeader := pkt.Network()
tcpHeader := header.TCP(pkt.TransportHeader().View())
// For output redirection, packets going in the original direction
@@ -396,7 +401,7 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d
// Calculate the TCP checksum and set it.
tcpHeader.SetChecksum(0)
- length := uint16(pkt.Size()) - uint16(netHeader.HeaderLength())
+ length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View()))
xsum := r.PseudoHeaderChecksum(header.TCPProtocolNumber, length)
if gso != nil && gso.NeedsCsum {
tcpHeader.SetChecksum(xsum)
@@ -405,8 +410,11 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d
tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum))
}
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ netHeader.SetChecksum(0)
+ netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ }
}
// handlePacket will manipulate the port and address of the packet if the
@@ -422,7 +430,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou
}
// TODO(gvisor.dev/issue/170): Support other transport protocols.
- if nh := pkt.NetworkHeader().View(); nh.IsEmpty() || header.IPv4(nh).TransportProtocol() != header.TCPProtocolNumber {
+ if pkt.Network().TransportProtocol() != header.TCPProtocolNumber {
return false
}
@@ -473,7 +481,7 @@ func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) {
}
// We only track TCP connections.
- if nh := pkt.NetworkHeader().View(); nh.IsEmpty() || header.IPv4(nh).TransportProtocol() != header.TCPProtocolNumber {
+ if pkt.Network().TransportProtocol() != header.TCPProtocolNumber {
return
}
@@ -609,7 +617,7 @@ func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bo
return true
}
-func (ct *ConnTrack) originalDst(epID TransportEndpointID) (tcpip.Address, uint16, *tcpip.Error) {
+func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, *tcpip.Error) {
// Lookup the connection. The reply's original destination
// describes the original address.
tid := tupleID{
@@ -618,7 +626,7 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID) (tcpip.Address, uint1
dstAddr: epID.RemoteAddress,
dstPort: epID.RemotePort,
transProto: header.TCPProtocolNumber,
- netProto: header.IPv4ProtocolNumber,
+ netProto: netProto,
}
conn, _ := ct.connForTID(tid)
if conn == nil {
diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarding_test.go
index 4e4b00a92..cf042309e 100644
--- a/pkg/tcpip/stack/forwarder_test.go
+++ b/pkg/tcpip/stack/forwarding_test.go
@@ -48,10 +48,9 @@ const (
type fwdTestNetworkEndpoint struct {
AddressableEndpointState
- nicID tcpip.NICID
+ nic NetworkInterface
proto *fwdTestNetworkProtocol
dispatcher TransportDispatcher
- ep LinkEndpoint
}
var _ NetworkEndpoint = (*fwdTestNetworkEndpoint)(nil)
@@ -67,7 +66,7 @@ func (*fwdTestNetworkEndpoint) Enabled() bool {
func (*fwdTestNetworkEndpoint) Disable() {}
func (f *fwdTestNetworkEndpoint) MTU() uint32 {
- return f.ep.MTU() - uint32(f.MaxHeaderLength())
+ return f.nic.MTU() - uint32(f.MaxHeaderLength())
}
func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 {
@@ -80,7 +79,7 @@ func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) {
}
func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 {
- return f.ep.MaxHeaderLength() + fwdTestNetHeaderLen
+ return f.nic.MaxHeaderLength() + fwdTestNetHeaderLen
}
func (f *fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 {
@@ -99,7 +98,7 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkH
b[srcAddrOffset] = r.LocalAddress[0]
b[protocolNumberOffset] = byte(params.Protocol)
- return f.ep.WritePacket(r, gso, fwdTestNetNumber, pkt)
+ return f.nic.WritePacket(r, gso, fwdTestNetNumber, pkt)
}
// WritePackets implements LinkEndpoint.WritePackets.
@@ -159,10 +158,9 @@ func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocol
func (f *fwdTestNetworkProtocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, dispatcher TransportDispatcher) NetworkEndpoint {
e := &fwdTestNetworkEndpoint{
- nicID: nic.ID(),
+ nic: nic,
proto: f,
dispatcher: dispatcher,
- ep: nic.LinkEndpoint(),
}
e.AddressableEndpointState.Init(e)
return e
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index faa503b00..8d6d9a7f1 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -502,11 +502,11 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx
// OriginalDst returns the original destination of redirected connections. It
// returns an error if the connection doesn't exist or isn't redirected.
-func (it *IPTables) OriginalDst(epID TransportEndpointID) (tcpip.Address, uint16, *tcpip.Error) {
+func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, *tcpip.Error) {
it.mu.RLock()
defer it.mu.RUnlock()
if !it.modified {
return "", 0, tcpip.ErrNotConnected
}
- return it.connections.originalDst(epID)
+ return it.connections.originalDst(epID, netProto)
}
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index 8581dd5e8..538c4625d 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -34,7 +34,7 @@ func (at *AcceptTarget) ID() TargetID {
}
// Action implements Target.Action.
-func (AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleAccept, 0
}
@@ -52,7 +52,7 @@ func (dt *DropTarget) ID() TargetID {
}
// Action implements Target.Action.
-func (DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleDrop, 0
}
@@ -76,7 +76,7 @@ func (et *ErrorTarget) ID() TargetID {
}
// Action implements Target.Action.
-func (ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
log.Debugf("ErrorTarget triggered.")
return RuleDrop, 0
}
@@ -99,7 +99,7 @@ func (uc *UserChainTarget) ID() TargetID {
}
// Action implements Target.Action.
-func (UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
panic("UserChainTarget should never be called.")
}
@@ -118,7 +118,7 @@ func (rt *ReturnTarget) ID() TargetID {
}
// Action implements Target.Action.
-func (ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
return RuleReturn, 0
}
@@ -128,26 +128,14 @@ func (ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.
const RedirectTargetName = "REDIRECT"
// RedirectTarget redirects the packet by modifying the destination port/IP.
-// Min and Max values for IP and Ports in the struct indicate the range of
-// values which can be used to redirect.
+// TODO(gvisor.dev/issue/170): Other flags need to be added after we support
+// them.
type RedirectTarget struct {
- // TODO(gvisor.dev/issue/170): Other flags need to be added after
- // we support them.
- // RangeProtoSpecified flag indicates single port is specified to
- // redirect.
- RangeProtoSpecified bool
+ // Addr indicates address used to redirect.
+ Addr tcpip.Address
- // MinIP indicates address used to redirect.
- MinIP tcpip.Address
-
- // MaxIP indicates address used to redirect.
- MaxIP tcpip.Address
-
- // MinPort indicates port used to redirect.
- MinPort uint16
-
- // MaxPort indicates port used to redirect.
- MaxPort uint16
+ // Port indicates port used to redirect.
+ Port uint16
// NetworkProtocol is the network protocol the target is used with.
NetworkProtocol tcpip.NetworkProtocolNumber
@@ -165,7 +153,7 @@ func (rt *RedirectTarget) ID() TargetID {
// TODO(gvisor.dev/issue/170): Parse headers without copying. The current
// implementation only works for PREROUTING and calls pkt.Clone(), neither
// of which should be the case.
-func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) {
+func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) {
// Packet is already manipulated.
if pkt.NatDone {
return RuleAccept, 0
@@ -176,34 +164,35 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso
return RuleDrop, 0
}
- // Change the address to localhost (127.0.0.1) in Output and
- // to primary address of the incoming interface in Prerouting.
+ // Change the address to localhost (127.0.0.1 or ::1) in Output and to
+ // the primary address of the incoming interface in Prerouting.
switch hook {
case Output:
- rt.MinIP = tcpip.Address([]byte{127, 0, 0, 1})
- rt.MaxIP = tcpip.Address([]byte{127, 0, 0, 1})
+ if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
+ rt.Addr = tcpip.Address([]byte{127, 0, 0, 1})
+ } else {
+ rt.Addr = header.IPv6Loopback
+ }
case Prerouting:
- rt.MinIP = address
- rt.MaxIP = address
+ rt.Addr = address
default:
panic("redirect target is supported only on output and prerouting hooks")
}
// TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if
// we need to change dest address (for OUTPUT chain) or ports.
- netHeader := header.IPv4(pkt.NetworkHeader().View())
- switch protocol := netHeader.TransportProtocol(); protocol {
+ switch protocol := pkt.TransportProtocolNumber; protocol {
case header.UDPProtocolNumber:
udpHeader := header.UDP(pkt.TransportHeader().View())
- udpHeader.SetDestinationPort(rt.MinPort)
+ udpHeader.SetDestinationPort(rt.Port)
// Calculate UDP checksum and set it.
if hook == Output {
udpHeader.SetChecksum(0)
- length := uint16(pkt.Size()) - uint16(netHeader.HeaderLength())
// Only calculate the checksum if offloading isn't supported.
if r.Capabilities()&CapabilityTXChecksumOffload == 0 {
+ length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View()))
xsum := r.PseudoHeaderChecksum(protocol, length)
for _, v := range pkt.Data.Views() {
xsum = header.Checksum(v, xsum)
@@ -212,10 +201,15 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso
udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum))
}
}
- // Change destination address.
- netHeader.SetDestinationAddress(rt.MinIP)
- netHeader.SetChecksum(0)
- netHeader.SetChecksum(^netHeader.CalculateChecksum())
+
+ pkt.Network().SetDestinationAddress(rt.Addr)
+
+ // After modification, IPv4 packets need a valid checksum.
+ if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ netHeader.SetChecksum(0)
+ netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ }
pkt.NatDone = true
case header.TCPProtocolNumber:
if ct == nil {
diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go
index 27e1feec0..4df288798 100644
--- a/pkg/tcpip/stack/neighbor_cache.go
+++ b/pkg/tcpip/stack/neighbor_cache.go
@@ -131,10 +131,17 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA
defer entry.mu.Unlock()
switch s := entry.neigh.State; s {
- case Reachable, Static:
+ case Stale:
+ entry.handlePacketQueuedLocked()
+ fallthrough
+ case Reachable, Static, Delay, Probe:
+ // As per RFC 4861 section 7.3.3:
+ // "Neighbor Unreachability Detection operates in parallel with the sending
+ // of packets to a neighbor. While reasserting a neighbor's reachability,
+ // a node continues sending packets to that neighbor using the cached
+ // link-layer address."
return entry.neigh, nil, nil
-
- case Unknown, Incomplete, Stale, Delay, Probe:
+ case Unknown, Incomplete:
entry.addWakerLocked(w)
if entry.done == nil {
@@ -147,10 +154,8 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA
entry.handlePacketQueuedLocked()
return entry.neigh, entry.done, tcpip.ErrWouldBlock
-
case Failed:
return entry.neigh, nil, tcpip.ErrNoLinkAddress
-
default:
panic(fmt.Sprintf("Invalid cache entry state: %s", s))
}
diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go
index a0b7da5cd..fcd54ed83 100644
--- a/pkg/tcpip/stack/neighbor_cache_test.go
+++ b/pkg/tcpip/stack/neighbor_cache_test.go
@@ -1500,24 +1500,26 @@ func TestNeighborCacheReplace(t *testing.T) {
}
// Verify the entry exists
- e, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
- if err != nil {
- t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
- }
- if doneCh != nil {
- t.Errorf("unexpected done channel from neigh.entry(%s, %s, _, nil): %v", entry.Addr, entry.LocalAddr, doneCh)
- }
- if t.Failed() {
- t.FailNow()
- }
- want := NeighborEntry{
- Addr: entry.Addr,
- LocalAddr: entry.LocalAddr,
- LinkAddr: entry.LinkAddr,
- State: Reachable,
- }
- if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
- t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LinkAddr, diff)
+ {
+ e, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != nil {
+ t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
+ }
+ if doneCh != nil {
+ t.Errorf("unexpected done channel from neigh.entry(%s, %s, _, nil): %v", entry.Addr, entry.LocalAddr, doneCh)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ want := NeighborEntry{
+ Addr: entry.Addr,
+ LocalAddr: entry.LocalAddr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ }
+ if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
+ t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LinkAddr, diff)
+ }
}
// Notify of a link address change
@@ -1536,28 +1538,34 @@ func TestNeighborCacheReplace(t *testing.T) {
IsRouter: false,
})
- // Requesting the entry again should start address resolution
+ // Requesting the entry again should start neighbor reachability confirmation.
+ //
+ // Verify the entry's new link address and the new state.
{
- _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
- if err != tcpip.ErrWouldBlock {
- t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != nil {
+ t.Fatalf("neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
}
- clock.Advance(config.DelayFirstProbeTime + typicalLatency)
- select {
- case <-doneCh:
- default:
- t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr)
+ want := NeighborEntry{
+ Addr: entry.Addr,
+ LocalAddr: entry.LocalAddr,
+ LinkAddr: updatedLinkAddr,
+ State: Delay,
}
+ if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
+ t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff)
+ }
+ clock.Advance(config.DelayFirstProbeTime + typicalLatency)
}
- // Verify the entry's new link address
+ // Verify that the neighbor is now reachable.
{
e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
clock.Advance(typicalLatency)
if err != nil {
t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
}
- want = NeighborEntry{
+ want := NeighborEntry{
Addr: entry.Addr,
LocalAddr: entry.LocalAddr,
LinkAddr: updatedLinkAddr,
diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go
index 9a72bec79..4d69a4de1 100644
--- a/pkg/tcpip/stack/neighbor_entry.go
+++ b/pkg/tcpip/stack/neighbor_entry.go
@@ -236,7 +236,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
return
}
- if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, "", e.nic.linkEP); err != nil {
+ if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, "", e.nic.LinkEndpoint); err != nil {
// There is no need to log the error here; the NUD implementation may
// assume a working link. A valid link should be the responsibility of
// the NIC/stack.LinkEndpoint.
@@ -277,7 +277,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) {
return
}
- if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, e.neigh.LinkAddr, e.nic.linkEP); err != nil {
+ if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, e.neigh.LinkAddr, e.nic.LinkEndpoint); err != nil {
e.dispatchRemoveEventLocked()
e.setStateLocked(Failed)
return
diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go
index a265fff0a..e79abebca 100644
--- a/pkg/tcpip/stack/neighbor_entry_test.go
+++ b/pkg/tcpip/stack/neighbor_entry_test.go
@@ -227,8 +227,9 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e
clock := faketime.NewManualClock()
disp := testNUDDispatcher{}
nic := NIC{
- id: entryTestNICID,
- linkEP: nil, // entryTestLinkResolver doesn't use a LinkEndpoint
+ LinkEndpoint: nil, // entryTestLinkResolver doesn't use a LinkEndpoint
+
+ id: entryTestNICID,
stack: &Stack{
clock: clock,
nudDisp: &disp,
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 212c6edae..8828cc5fe 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -32,14 +32,18 @@ var _ NetworkInterface = (*NIC)(nil)
// NIC represents a "network interface card" to which the networking stack is
// attached.
type NIC struct {
+ LinkEndpoint
+
stack *Stack
id tcpip.NICID
name string
- linkEP LinkEndpoint
context NICContext
- stats NICStats
- neigh *neighborCache
+ stats NICStats
+ neigh *neighborCache
+
+ // The network endpoints themselves may be modified by calling the interface's
+ // methods, but the map reference and entries must be constant.
networkEndpoints map[tcpip.NetworkProtocolNumber]NetworkEndpoint
// enabled is set to 1 when the NIC is enabled and 0 when it is disabled.
@@ -88,10 +92,11 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
// of IPv6 is supported on this endpoint's LinkEndpoint.
nic := &NIC{
+ LinkEndpoint: ep,
+
stack: stack,
id: id,
name: name,
- linkEP: ep,
context: ctx,
stats: makeNICStats(),
networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint),
@@ -127,11 +132,15 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic)
}
- nic.linkEP.Attach(nic)
+ nic.LinkEndpoint.Attach(nic)
return nic
}
+func (n *NIC) getNetworkEndpoint(proto tcpip.NetworkProtocolNumber) NetworkEndpoint {
+ return n.networkEndpoints[proto]
+}
+
// Enabled implements NetworkInterface.
func (n *NIC) Enabled() bool {
return atomic.LoadUint32(&n.enabled) == 1
@@ -211,10 +220,9 @@ func (n *NIC) remove() *tcpip.Error {
for _, ep := range n.networkEndpoints {
ep.Close()
}
- n.networkEndpoints = nil
// Detach from link endpoint, so no packet comes in.
- n.linkEP.Attach(nil)
+ n.LinkEndpoint.Attach(nil)
return nil
}
@@ -234,7 +242,64 @@ func (n *NIC) isPromiscuousMode() bool {
// IsLoopback implements NetworkInterface.
func (n *NIC) IsLoopback() bool {
- return n.linkEP.Capabilities()&CapabilityLoopback != 0
+ return n.LinkEndpoint.Capabilities()&CapabilityLoopback != 0
+}
+
+// WritePacket implements NetworkLinkEndpoint.
+func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
+ // As per relevant RFCs, we should queue packets while we wait for link
+ // resolution to complete.
+ //
+ // RFC 1122 section 2.3.2.2 (for IPv4):
+ // The link layer SHOULD save (rather than discard) at least
+ // one (the latest) packet of each set of packets destined to
+ // the same unresolved IP address, and transmit the saved
+ // packet when the address has been resolved.
+ //
+ // RFC 4861 section 5.2 (for IPv6):
+ // Once the IP address of the next-hop node is known, the sender
+ // examines the Neighbor Cache for link-layer information about that
+ // neighbor. If no entry exists, the sender creates one, sets its state
+ // to INCOMPLETE, initiates Address Resolution, and then queues the data
+ // packet pending completion of address resolution.
+ if ch, err := r.Resolve(nil); err != nil {
+ if err == tcpip.ErrWouldBlock {
+ r := r.Clone()
+ n.stack.linkResQueue.enqueue(ch, &r, protocol, pkt)
+ return nil
+ }
+ return err
+ }
+
+ return n.writePacket(r, gso, protocol, pkt)
+}
+
+func (n *NIC) writePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
+ // WritePacket takes ownership of pkt, calculate numBytes first.
+ numBytes := pkt.Size()
+
+ if err := n.LinkEndpoint.WritePacket(r, gso, protocol, pkt); err != nil {
+ return err
+ }
+
+ n.stats.Tx.Packets.Increment()
+ n.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
+ return nil
+}
+
+// WritePackets implements NetworkLinkEndpoint.
+func (n *NIC) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ // TODO(gvisor.dev/issue/4458): Queue packets while link address resolution
+ // is being performed like WritePacket.
+ writtenPackets, err := n.LinkEndpoint.WritePackets(r, gso, pkts, protocol)
+ n.stats.Tx.Packets.IncrementBy(uint64(writtenPackets))
+ writtenBytes := 0
+ for i, pb := 0, pkts.Front(); i < writtenPackets && pb != nil; i, pb = i+1, pb.Next() {
+ writtenBytes += pb.Size()
+ }
+
+ n.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes))
+ return writtenPackets, err
}
// setSpoofing enables or disables address spoofing.
@@ -244,22 +309,19 @@ func (n *NIC) setSpoofing(enable bool) {
n.mu.Unlock()
}
-// primaryEndpoint will return the first non-deprecated endpoint if such an
-// endpoint exists for the given protocol and remoteAddr. If no non-deprecated
-// endpoint exists, the first deprecated endpoint will be returned.
-//
-// If an IPv6 primary endpoint is requested, Source Address Selection (as
-// defined by RFC 6724 section 5) will be performed.
+// primaryEndpoint returns an address endpoint that can be used to communicate
+// with remoteAddr.
func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) AssignableAddressEndpoint {
n.mu.RLock()
- defer n.mu.RUnlock()
+ spoofing := n.mu.spoofing
+ n.mu.RUnlock()
ep, ok := n.networkEndpoints[protocol]
if !ok {
return nil
}
- return ep.AcquirePrimaryAddress(remoteAddr, n.mu.spoofing)
+ return ep.AcquireOutgoingPrimaryAddress(remoteAddr, spoofing)
}
type getAddressBehaviour int
@@ -360,13 +422,12 @@ func (n *NIC) primaryAddresses() []tcpip.ProtocolAddress {
// address exists. If no non-deprecated address exists, the first deprecated
// address will be returned.
func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWithPrefix {
- addressEndpoint := n.primaryEndpoint(proto, "")
- if addressEndpoint == nil {
+ ep, ok := n.networkEndpoints[proto]
+ if !ok {
return tcpip.AddressWithPrefix{}
}
- addr := addressEndpoint.AddressWithPrefix()
- addressEndpoint.DecRef()
- return addr
+
+ return ep.MainAddress()
}
// removeAddress removes an address from n.
@@ -487,9 +548,9 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool {
func (n *NIC) handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, remotelinkAddr tcpip.LinkAddress, addressEndpoint AssignableAddressEndpoint, pkt *PacketBuffer) {
r := makeRoute(protocol, dst, src, n, addressEndpoint, false /* handleLocal */, false /* multicastLoop */)
+ defer r.Release()
r.RemoteLinkAddress = remotelinkAddr
- addressEndpoint.NetworkEndpoint().HandlePacket(&r, pkt)
- addressEndpoint.DecRef()
+ n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt)
}
// DeliverNetworkPacket finds the appropriate network protocol endpoint and
@@ -523,7 +584,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
// If no local link layer address is provided, assume it was sent
// directly to this NIC.
if local == "" {
- local = n.linkEP.LinkAddress()
+ local = n.LinkEndpoint.LinkAddress()
}
// Are any packet type sockets listening for this network protocol?
@@ -603,11 +664,11 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
n := r.nic
if addressEndpoint := n.getAddressOrCreateTempInner(protocol, dst, false, NeverPrimaryEndpoint); addressEndpoint != nil {
if n.isValidForOutgoing(addressEndpoint) {
- r.LocalLinkAddress = n.linkEP.LinkAddress()
+ r.LocalLinkAddress = n.LinkEndpoint.LinkAddress()
r.RemoteLinkAddress = remote
r.RemoteAddress = src
// TODO(b/123449044): Update the source NIC as well.
- addressEndpoint.NetworkEndpoint().HandlePacket(&r, pkt)
+ n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt)
addressEndpoint.DecRef()
r.Release()
return
@@ -618,21 +679,21 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
// n doesn't have a destination endpoint.
// Send the packet out of n.
- // TODO(b/128629022): move this logic to route.WritePacket.
// TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease the TTL field for ipv4/ipv6.
- if ch, err := r.Resolve(nil); err != nil {
- if err == tcpip.ErrWouldBlock {
- n.stack.forwarder.enqueue(ch, n, &r, protocol, pkt)
- // forwarder will release route.
- return
- }
+
+ // pkt may have set its header and may not have enough headroom for
+ // link-layer header for the other link to prepend. Here we create a new
+ // packet to forward.
+ fwdPkt := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: int(n.LinkEndpoint.MaxHeaderLength()),
+ Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
+ })
+
+ // TODO(b/143425874) Decrease the TTL field in forwarded packets.
+ if err := n.WritePacket(&r, nil, protocol, fwdPkt); err != nil {
n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment()
- r.Release()
- return
}
- // The link-address resolution finished immediately.
- n.forwardPacket(&r, protocol, pkt)
r.Release()
return
}
@@ -656,43 +717,18 @@ func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc
p.PktType = tcpip.PacketOutgoing
// Add the link layer header as outgoing packets are intercepted
// before the link layer header is created.
- n.linkEP.AddHeader(local, remote, protocol, p)
+ n.LinkEndpoint.AddHeader(local, remote, protocol, p)
ep.HandlePacket(n.id, local, protocol, p)
}
}
-func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
- // TODO(b/143425874) Decrease the TTL field in forwarded packets.
-
- // pkt may have set its header and may not have enough headroom for link-layer
- // header for the other link to prepend. Here we create a new packet to
- // forward.
- fwdPkt := NewPacketBuffer(PacketBufferOptions{
- ReserveHeaderBytes: int(n.linkEP.MaxHeaderLength()),
- Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
- })
-
- // WritePacket takes ownership of fwdPkt, calculate numBytes first.
- numBytes := fwdPkt.Size()
-
- if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, fwdPkt); err != nil {
- r.Stats().IP.OutgoingPacketErrors.Increment()
- return
- }
-
- n.stats.Tx.Packets.Increment()
- n.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
-}
-
// DeliverTransportPacket delivers the packets to the appropriate transport
// protocol endpoint.
func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition {
state, ok := n.stack.transportProtocols[protocol]
if !ok {
- // TODO(gvisor.dev/issue/4365): Let the caller know that the transport
- // protocol is unrecognized.
n.stack.stats.UnknownProtocolRcvdPackets.Increment()
- return TransportPacketHandled
+ return TransportPacketProtocolUnreachable
}
transProto := state.proto
@@ -796,11 +832,6 @@ func (n *NIC) Name() string {
return n.name
}
-// LinkEndpoint implements NetworkInterface.
-func (n *NIC) LinkEndpoint() LinkEndpoint {
- return n.linkEP
-}
-
// nudConfigs gets the NUD configurations for n.
func (n *NIC) nudConfigs() (NUDConfigurations, *tcpip.Error) {
if n.neigh == nil {
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index fdd49b77f..97a96af62 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -33,8 +33,7 @@ var _ NDPEndpoint = (*testIPv6Endpoint)(nil)
type testIPv6Endpoint struct {
AddressableEndpointState
- nicID tcpip.NICID
- linkEP LinkEndpoint
+ nic NetworkInterface
protocol *testIPv6Protocol
invalidatedRtr tcpip.Address
@@ -57,12 +56,12 @@ func (*testIPv6Endpoint) DefaultTTL() uint8 {
// MTU implements NetworkEndpoint.MTU.
func (e *testIPv6Endpoint) MTU() uint32 {
- return e.linkEP.MTU() - header.IPv6MinimumSize
+ return e.nic.MTU() - header.IPv6MinimumSize
}
// MaxHeaderLength implements NetworkEndpoint.MaxHeaderLength.
func (e *testIPv6Endpoint) MaxHeaderLength() uint16 {
- return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize
+ return e.nic.MaxHeaderLength() + header.IPv6MinimumSize
}
// WritePacket implements NetworkEndpoint.WritePacket.
@@ -134,8 +133,7 @@ func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address)
// NewEndpoint implements NetworkProtocol.NewEndpoint.
func (p *testIPv6Protocol) NewEndpoint(nic NetworkInterface, _ LinkAddressCache, _ NUDHandler, _ TransportDispatcher) NetworkEndpoint {
e := &testIPv6Endpoint{
- nicID: nic.ID(),
- linkEP: nic.LinkEndpoint(),
+ nic: nic,
protocol: p,
}
e.AddressableEndpointState.Init(e)
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index a7d9d59fa..105583c49 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
)
type headerType int
@@ -255,6 +256,20 @@ func (pk *PacketBuffer) Clone() *PacketBuffer {
return newPk
}
+// Network returns the network header as a header.Network.
+//
+// Network should only be called when NetworkHeader has been set.
+func (pk *PacketBuffer) Network() header.Network {
+ switch netProto := pk.NetworkProtocolNumber; netProto {
+ case header.IPv4ProtocolNumber:
+ return header.IPv4(pk.NetworkHeader().View())
+ case header.IPv6ProtocolNumber:
+ return header.IPv6(pk.NetworkHeader().View())
+ default:
+ panic(fmt.Sprintf("unknown network protocol number %d", netProto))
+ }
+}
+
// headerInfo stores metadata about a header in a packet.
type headerInfo struct {
// buf is the memorized slice for both prepended and consumed header.
diff --git a/pkg/tcpip/stack/forwarder.go b/pkg/tcpip/stack/pending_packets.go
index 3eff141e6..f838eda8d 100644
--- a/pkg/tcpip/stack/forwarder.go
+++ b/pkg/tcpip/stack/pending_packets.go
@@ -29,60 +29,60 @@ const (
)
type pendingPacket struct {
- nic *NIC
route *Route
proto tcpip.NetworkProtocolNumber
pkt *PacketBuffer
}
-type forwardQueue struct {
+// packetsPendingLinkResolution is a queue of packets pending link resolution.
+//
+// Once link resolution completes successfully, the packets will be written.
+type packetsPendingLinkResolution struct {
sync.Mutex
// The packets to send once the resolver completes.
- packets map[<-chan struct{}][]*pendingPacket
+ packets map[<-chan struct{}][]pendingPacket
// FIFO of channels used to cancel the oldest goroutine waiting for
// link-address resolution.
cancelChans []chan struct{}
}
-func newForwardQueue() *forwardQueue {
- return &forwardQueue{packets: make(map[<-chan struct{}][]*pendingPacket)}
+func (f *packetsPendingLinkResolution) init() {
+ f.Lock()
+ defer f.Unlock()
+ f.packets = make(map[<-chan struct{}][]pendingPacket)
}
-func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
- shouldWait := false
-
+func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, proto tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
f.Lock()
+ defer f.Unlock()
+
packets, ok := f.packets[ch]
- if !ok {
- shouldWait = true
- }
- for len(packets) == maxPendingPacketsPerResolution {
+ if len(packets) == maxPendingPacketsPerResolution {
p := packets[0]
+ packets[0] = pendingPacket{}
packets = packets[1:]
- p.nic.stack.stats.IP.OutgoingPacketErrors.Increment()
+ p.route.Stats().IP.OutgoingPacketErrors.Increment()
p.route.Release()
}
+
if l := len(packets); l >= maxPendingPacketsPerResolution {
panic(fmt.Sprintf("max pending packets for resolution reached; got %d packets, max = %d", l, maxPendingPacketsPerResolution))
}
- f.packets[ch] = append(packets, &pendingPacket{
- nic: n,
+
+ f.packets[ch] = append(packets, pendingPacket{
route: r,
- proto: protocol,
+ proto: proto,
pkt: pkt,
})
- f.Unlock()
- if !shouldWait {
+ if ok {
return
}
// Wait for the link-address resolution to complete.
- // Start a goroutine with a forwarding-cancel channel so that we can
- // limit the maximum number of goroutines running concurrently.
- cancel := f.newCancelChannel()
+ cancel := f.newCancelChannelLocked()
go func() {
cancelled := false
select {
@@ -92,17 +92,21 @@ func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tc
}
f.Lock()
- packets := f.packets[ch]
+ packets, ok := f.packets[ch]
delete(f.packets, ch)
f.Unlock()
+ if !ok {
+ panic(fmt.Sprintf("link-resolution goroutine woke up but no entry exists in the queue of packets"))
+ }
+
for _, p := range packets {
if cancelled {
- p.nic.stack.stats.IP.OutgoingPacketErrors.Increment()
+ p.route.Stats().IP.OutgoingPacketErrors.Increment()
} else if _, err := p.route.Resolve(nil); err != nil {
- p.nic.stack.stats.IP.OutgoingPacketErrors.Increment()
+ p.route.Stats().IP.OutgoingPacketErrors.Increment()
} else {
- p.nic.forwardPacket(p.route, p.proto, p.pkt)
+ p.route.nic.writePacket(p.route, nil /* gso */, p.proto, p.pkt)
}
p.route.Release()
}
@@ -112,12 +116,10 @@ func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tc
// newCancelChannel creates a channel that can cancel a pending forwarding
// activity. The oldest channel is closed if the number of open channels would
// exceed maxPendingResolutions.
-func (f *forwardQueue) newCancelChannel() chan struct{} {
- f.Lock()
- defer f.Unlock()
-
+func (f *packetsPendingLinkResolution) newCancelChannelLocked() chan struct{} {
if len(f.cancelChans) == maxPendingResolutions {
ch := f.cancelChans[0]
+ f.cancelChans[0] = nil
f.cancelChans = f.cancelChans[1:]
close(ch)
}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 567e1904e..defb9129b 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -208,6 +208,10 @@ const (
// transport layer and callers need not take any further action.
TransportPacketHandled TransportPacketDisposition = iota
+ // TransportPacketProtocolUnreachable indicates that the transport
+ // protocol requested in the packet is not supported.
+ TransportPacketProtocolUnreachable
+
// TransportPacketDestinationPortUnreachable indicates that there weren't any
// listeners interested in the packet and the transport protocol has no means
// to notify the sender.
@@ -322,10 +326,6 @@ const (
// AssignableAddressEndpoint is a reference counted address endpoint that may be
// assigned to a NetworkEndpoint.
type AssignableAddressEndpoint interface {
- // NetworkEndpoint returns the NetworkEndpoint the receiver is associated
- // with.
- NetworkEndpoint() NetworkEndpoint
-
// AddressWithPrefix returns the endpoint's address.
AddressWithPrefix() tcpip.AddressWithPrefix
@@ -435,7 +435,10 @@ type AddressableEndpoint interface {
// permanent address.
RemovePermanentAddress(addr tcpip.Address) *tcpip.Error
- // AcquireAssignedAddress returns an AddressEndpoint for the passed address
+ // MainAddress returns the endpoint's primary permanent address.
+ MainAddress() tcpip.AddressWithPrefix
+
+ // AcquireAssignedAddress returns an address endpoint for the passed address
// that is considered bound to the endpoint, optionally creating a temporary
// endpoint if requested and no existing address exists.
//
@@ -444,15 +447,15 @@ type AddressableEndpoint interface {
// Returns nil if the specified address is not local to this endpoint.
AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint
- // AcquirePrimaryAddress returns a primary endpoint to use when communicating
- // with the passed remote address.
+ // AcquireOutgoingPrimaryAddress returns a primary address that may be used as
+ // a source address when sending packets to the passed remote address.
//
// If allowExpired is true, expired addresses may be returned.
//
// The returned endpoint's reference count is incremented.
//
- // Returns nil if a primary endpoint is not available.
- AcquirePrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) AddressEndpoint
+ // Returns nil if a primary address is not available.
+ AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) AddressEndpoint
// PrimaryAddresses returns the primary addresses.
PrimaryAddresses() []tcpip.AddressWithPrefix
@@ -472,6 +475,8 @@ type NDPEndpoint interface {
// NetworkInterface is a network interface.
type NetworkInterface interface {
+ NetworkLinkEndpoint
+
// ID returns the interface's ID.
ID() tcpip.NICID
@@ -485,9 +490,6 @@ type NetworkInterface interface {
// Enabled returns true if the interface is enabled.
Enabled() bool
-
- // LinkEndpoint returns the link endpoint backing the interface.
- LinkEndpoint() LinkEndpoint
}
// NetworkEndpoint is the interface that needs to be implemented by endpoints
@@ -660,22 +662,15 @@ const (
CapabilitySoftwareGSO
)
-// LinkEndpoint is the interface implemented by data link layer protocols (e.g.,
-// ethernet, loopback, raw) and used by network layer protocols to send packets
-// out through the implementer's data link endpoint. When a link header exists,
-// it sets each PacketBuffer's LinkHeader field before passing it up the
-// stack.
-type LinkEndpoint interface {
+// NetworkLinkEndpoint is a data-link layer that supports sending network
+// layer packets.
+type NetworkLinkEndpoint interface {
// MTU is the maximum transmission unit for this endpoint. This is
// usually dictated by the backing physical network; when such a
// physical network doesn't exist, the limit is generally 64k, which
// includes the maximum size of an IP packet.
MTU() uint32
- // Capabilities returns the set of capabilities supported by the
- // endpoint.
- Capabilities() LinkEndpointCapabilities
-
// MaxHeaderLength returns the maximum size the data link (and
// lower level layers combined) headers can have. Higher levels use this
// information to reserve space in the front of the packets they're
@@ -683,7 +678,7 @@ type LinkEndpoint interface {
MaxHeaderLength() uint16
// LinkAddress returns the link address (typically a MAC) of the
- // link endpoint.
+ // endpoint.
LinkAddress() tcpip.LinkAddress
// WritePacket writes a packet with the given protocol through the
@@ -703,6 +698,19 @@ type LinkEndpoint interface {
// offload is enabled. If it will be used for something else, it may
// require to change syscall filters.
WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error)
+}
+
+// LinkEndpoint is the interface implemented by data link layer protocols (e.g.,
+// ethernet, loopback, raw) and used by network layer protocols to send packets
+// out through the implementer's data link endpoint. When a link header exists,
+// it sets each PacketBuffer's LinkHeader field before passing it up the
+// stack.
+type LinkEndpoint interface {
+ NetworkLinkEndpoint
+
+ // Capabilities returns the set of capabilities supported by the
+ // endpoint.
+ Capabilities() LinkEndpointCapabilities
// WriteRawPacket writes a packet directly to the link. The packet
// should already have an ethernet header. It takes ownership of vv.
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 5ade3c832..25f80c1f8 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -72,21 +72,20 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip
loop |= PacketLoop
}
- linkEP := nic.LinkEndpoint()
r := Route{
NetProto: netProto,
LocalAddress: localAddr,
- LocalLinkAddress: linkEP.LinkAddress(),
+ LocalLinkAddress: nic.LinkEndpoint.LinkAddress(),
RemoteAddress: remoteAddr,
addressEndpoint: addressEndpoint,
nic: nic,
Loop: loop,
}
- if nic := r.nic; linkEP.Capabilities()&CapabilityResolutionRequired != 0 {
- if linkRes, ok := nic.stack.linkAddrResolvers[r.NetProto]; ok {
+ if r.nic.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 {
+ if linkRes, ok := r.nic.stack.linkAddrResolvers[r.NetProto]; ok {
r.linkRes = linkRes
- r.linkCache = nic.stack
+ r.linkCache = r.nic.stack
}
}
@@ -100,7 +99,7 @@ func (r *Route) NICID() tcpip.NICID {
// MaxHeaderLength forwards the call to the network endpoint's implementation.
func (r *Route) MaxHeaderLength() uint16 {
- return r.addressEndpoint.NetworkEndpoint().MaxHeaderLength()
+ return r.nic.getNetworkEndpoint(r.NetProto).MaxHeaderLength()
}
// Stats returns a mutable copy of current stats.
@@ -116,23 +115,17 @@ func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, tot
// Capabilities returns the link-layer capabilities of the route.
func (r *Route) Capabilities() LinkEndpointCapabilities {
- return r.nic.LinkEndpoint().Capabilities()
+ return r.nic.LinkEndpoint.Capabilities()
}
// GSOMaxSize returns the maximum GSO packet size.
func (r *Route) GSOMaxSize() uint32 {
- if gso, ok := r.addressEndpoint.NetworkEndpoint().(GSOEndpoint); ok {
+ if gso, ok := r.nic.LinkEndpoint.(GSOEndpoint); ok {
return gso.GSOMaxSize()
}
return 0
}
-// ResolveWith immediately resolves a route with the specified remote link
-// address.
-func (r *Route) ResolveWith(addr tcpip.LinkAddress) {
- r.RemoteLinkAddress = addr
-}
-
// Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in
// case address resolution requires blocking, e.g. wait for ARP reply. Waker is
// notified when address resolution is complete (success or not).
@@ -208,17 +201,7 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuf
return tcpip.ErrInvalidEndpointState
}
- // WritePacket takes ownership of pkt, calculate numBytes first.
- numBytes := pkt.Size()
-
- err := r.addressEndpoint.NetworkEndpoint().WritePacket(r, gso, params, pkt)
- if err != nil {
- r.Stats().IP.OutgoingPacketErrors.Increment()
- } else {
- r.nic.stats.Tx.Packets.Increment()
- r.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
- }
- return err
+ return r.nic.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt)
}
// WritePackets writes a list of n packets through the given route and returns
@@ -228,22 +211,7 @@ func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHead
return 0, tcpip.ErrInvalidEndpointState
}
- // WritePackets takes ownership of pkt, calculate length first.
- numPkts := pkts.Len()
-
- n, err := r.addressEndpoint.NetworkEndpoint().WritePackets(r, gso, pkts, params)
- if err != nil {
- r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(numPkts - n))
- }
- r.nic.stats.Tx.Packets.IncrementBy(uint64(n))
-
- writtenBytes := 0
- for i, pb := 0, pkts.Front(); i < n && pb != nil; i, pb = i+1, pb.Next() {
- writtenBytes += pb.Size()
- }
-
- r.nic.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes))
- return n, err
+ return r.nic.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params)
}
// WriteHeaderIncludedPacket writes a packet already containing a network
@@ -253,32 +221,17 @@ func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- // WriteHeaderIncludedPacket takes ownership of pkt, calculate numBytes first.
- numBytes := pkt.Data.Size()
-
- if err := r.addressEndpoint.NetworkEndpoint().WriteHeaderIncludedPacket(r, pkt); err != nil {
- r.Stats().IP.OutgoingPacketErrors.Increment()
- return err
- }
- r.nic.stats.Tx.Packets.Increment()
- r.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
- return nil
+ return r.nic.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt)
}
// DefaultTTL returns the default TTL of the underlying network endpoint.
func (r *Route) DefaultTTL() uint8 {
- return r.addressEndpoint.NetworkEndpoint().DefaultTTL()
+ return r.nic.getNetworkEndpoint(r.NetProto).DefaultTTL()
}
// MTU returns the MTU of the underlying network endpoint.
func (r *Route) MTU() uint32 {
- return r.addressEndpoint.NetworkEndpoint().MTU()
-}
-
-// NetworkProtocolNumber returns the NetworkProtocolNumber of the underlying
-// network endpoint.
-func (r *Route) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
- return r.addressEndpoint.NetworkEndpoint().NetworkProtocolNumber()
+ return r.nic.getNetworkEndpoint(r.NetProto).MTU()
}
// Release frees all resources associated with the route.
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 57d8e79e0..3a07577c8 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -436,9 +436,9 @@ type Stack struct {
// uniqueIDGenerator is a generator of unique identifiers.
uniqueIDGenerator UniqueID
- // forwarder holds the packets that wait for their link-address resolutions
- // to complete, and forwards them when each resolution is done.
- forwarder *forwardQueue
+ // linkResQueue holds packets that are waiting for link resolution to
+ // complete.
+ linkResQueue packetsPendingLinkResolution
// randomGenerator is an injectable pseudo random generator that can be
// used when a random number is required.
@@ -550,8 +550,8 @@ type TransportEndpointInfo struct {
// incompatible with the receiver.
//
// Preconditon: the parent endpoint mu must be held while calling this method.
-func (e *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
- netProto := e.NetProto
+func (t *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ netProto := t.NetProto
switch len(addr.Addr) {
case header.IPv4AddressSize:
netProto = header.IPv4ProtocolNumber
@@ -565,7 +565,7 @@ func (e *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6onl
}
}
- switch len(e.ID.LocalAddress) {
+ switch len(t.ID.LocalAddress) {
case header.IPv4AddressSize:
if len(addr.Addr) == header.IPv6AddressSize {
return tcpip.FullAddress{}, 0, tcpip.ErrInvalidEndpointState
@@ -577,8 +577,8 @@ func (e *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6onl
}
switch {
- case netProto == e.NetProto:
- case netProto == header.IPv4ProtocolNumber && e.NetProto == header.IPv6ProtocolNumber:
+ case netProto == t.NetProto:
+ case netProto == header.IPv4ProtocolNumber && t.NetProto == header.IPv6ProtocolNumber:
if v6only {
return tcpip.FullAddress{}, 0, tcpip.ErrNoRoute
}
@@ -640,7 +640,6 @@ func New(opts Options) *Stack {
useNeighborCache: opts.UseNeighborCache,
uniqueIDGenerator: opts.UniqueID,
nudDisp: opts.NUDDisp,
- forwarder: newForwardQueue(),
randomGenerator: mathrand.New(randSrc),
sendBufferSize: SendBufferSizeOption{
Min: MinBufferSize,
@@ -653,6 +652,7 @@ func New(opts Options) *Stack {
Max: DefaultMaxBufferSize,
},
}
+ s.linkResQueue.init()
// Add specified network protocols.
for _, netProtoFactory := range opts.NetworkProtocols {
@@ -928,16 +928,16 @@ func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error {
return s.CreateNICWithOptions(id, ep, NICOptions{})
}
-// GetNICByName gets the NIC specified by name.
-func (s *Stack) GetNICByName(name string) (*NIC, bool) {
+// GetLinkEndpointByName gets the link endpoint specified by name.
+func (s *Stack) GetLinkEndpointByName(name string) LinkEndpoint {
s.mu.RLock()
defer s.mu.RUnlock()
for _, nic := range s.nics {
if nic.Name() == name {
- return nic, true
+ return nic.LinkEndpoint
}
}
- return nil, false
+ return nil
}
// EnableNIC enables the given NIC so that the link-layer endpoint can start
@@ -1062,13 +1062,13 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
}
nics[id] = NICInfo{
Name: nic.name,
- LinkAddress: nic.linkEP.LinkAddress(),
+ LinkAddress: nic.LinkEndpoint.LinkAddress(),
ProtocolAddresses: nic.primaryAddresses(),
Flags: flags,
- MTU: nic.linkEP.MTU(),
+ MTU: nic.LinkEndpoint.MTU(),
Stats: nic.stats,
Context: nic.context,
- ARPHardwareType: nic.linkEP.ARPHardwareType(),
+ ARPHardwareType: nic.LinkEndpoint.ARPHardwareType(),
}
}
return nics
@@ -1323,7 +1323,7 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address,
fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr}
linkRes := s.linkAddrResolvers[protocol]
- return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.linkEP, waker)
+ return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.LinkEndpoint, waker)
}
// Neighbors returns all IP to MAC address associations.
@@ -1539,7 +1539,7 @@ func (s *Stack) Wait() {
s.mu.RLock()
defer s.mu.RUnlock()
for _, n := range s.nics {
- n.linkEP.Wait()
+ n.LinkEndpoint.Wait()
}
}
@@ -1627,7 +1627,7 @@ func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto t
// Add our own fake ethernet header.
ethFields := header.EthernetFields{
- SrcAddr: nic.linkEP.LinkAddress(),
+ SrcAddr: nic.LinkEndpoint.LinkAddress(),
DstAddr: dst,
Type: netProto,
}
@@ -1636,7 +1636,7 @@ func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto t
vv := buffer.View(fakeHeader).ToVectorisedView()
vv.Append(payload)
- if err := nic.linkEP.WriteRawPacket(vv); err != nil {
+ if err := nic.LinkEndpoint.WriteRawPacket(vv); err != nil {
return err
}
@@ -1653,7 +1653,7 @@ func (s *Stack) WriteRawPacket(nicID tcpip.NICID, payload buffer.VectorisedView)
return tcpip.ErrUnknownDevice
}
- if err := nic.linkEP.WriteRawPacket(payload); err != nil {
+ if err := nic.LinkEndpoint.WriteRawPacket(payload); err != nil {
return err
}
@@ -1796,7 +1796,7 @@ func (s *Stack) GetNetworkEndpoint(nicID tcpip.NICID, proto tcpip.NetworkProtoco
return nil, tcpip.ErrUnknownNICID
}
- return nic.networkEndpoints[proto], nil
+ return nic.getNetworkEndpoint(proto), nil
}
// NUDConfigurations gets the per-interface NUD configurations.
@@ -1873,10 +1873,8 @@ func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, addres
if addressEndpoint == nil {
continue
}
-
- ep := addressEndpoint.NetworkEndpoint()
addressEndpoint.DecRef()
- return ep, nil
+ return nic.getNetworkEndpoint(netProto), nil
}
return nil, tcpip.ErrBadAddress
}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index fda22c550..38994cca1 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -21,7 +21,6 @@ import (
"bytes"
"fmt"
"math"
- "net"
"sort"
"testing"
"time"
@@ -35,7 +34,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
- "gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -71,24 +69,38 @@ const (
type fakeNetworkEndpoint struct {
stack.AddressableEndpointState
- nicID tcpip.NICID
+ mu struct {
+ sync.RWMutex
+
+ enabled bool
+ }
+
+ nic stack.NetworkInterface
proto *fakeNetworkProtocol
dispatcher stack.TransportDispatcher
- ep stack.LinkEndpoint
}
-func (*fakeNetworkEndpoint) Enable() *tcpip.Error {
+func (f *fakeNetworkEndpoint) Enable() *tcpip.Error {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ f.mu.enabled = true
return nil
}
-func (*fakeNetworkEndpoint) Enabled() bool {
- return true
+func (f *fakeNetworkEndpoint) Enabled() bool {
+ f.mu.RLock()
+ defer f.mu.RUnlock()
+ return f.mu.enabled
}
-func (*fakeNetworkEndpoint) Disable() {}
+func (f *fakeNetworkEndpoint) Disable() {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ f.mu.enabled = false
+}
func (f *fakeNetworkEndpoint) MTU() uint32 {
- return f.ep.MTU() - uint32(f.MaxHeaderLength())
+ return f.nic.MTU() - uint32(f.MaxHeaderLength())
}
func (*fakeNetworkEndpoint) DefaultTTL() uint8 {
@@ -120,7 +132,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuff
}
func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 {
- return f.ep.MaxHeaderLength() + fakeNetHeaderLen
+ return f.nic.MaxHeaderLength() + fakeNetHeaderLen
}
func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 {
@@ -149,7 +161,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params
return nil
}
- return f.ep.WritePacket(r, gso, fakeNetNumber, pkt)
+ return f.nic.WritePacket(r, gso, fakeNetNumber, pkt)
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
@@ -201,10 +213,9 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres
func (f *fakeNetworkProtocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint {
e := &fakeNetworkEndpoint{
- nicID: nic.ID(),
+ nic: nic,
proto: f,
dispatcher: dispatcher,
- ep: nic.LinkEndpoint(),
}
e.AddressableEndpointState.Init(e)
return e
@@ -2091,7 +2102,7 @@ func TestNICStats(t *testing.T) {
t.Errorf("got Tx.Packets.Value() = %d, ep1.Drain() = %d", got, want)
}
- if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)); got != want {
+ if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)+fakeNetHeaderLen); got != want {
t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
}
}
@@ -3487,52 +3498,6 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
}
}
-func TestResolveWith(t *testing.T) {
- const (
- unspecifiedNICID = 0
- nicID = 1
- )
-
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol},
- })
- ep := channel.New(0, defaultMTU, "")
- ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired
- if err := s.CreateNIC(nicID, ep); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
- }
- addr := tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: tcpip.Address(net.ParseIP("192.168.1.58").To4()),
- PrefixLen: 24,
- },
- }
- if err := s.AddProtocolAddress(nicID, addr); err != nil {
- t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, addr, err)
- }
-
- s.SetRouteTable([]tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}})
-
- remoteAddr := tcpip.Address(net.ParseIP("192.168.1.59").To4())
- r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, remoteAddr, header.IPv4ProtocolNumber, false /* multicastLoop */)
- if err != nil {
- t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, remoteAddr, header.IPv4ProtocolNumber, err)
- }
- defer r.Release()
-
- // Should initially require resolution.
- if !r.IsResolutionRequired() {
- t.Fatal("got r.IsResolutionRequired() = false, want = true")
- }
-
- // Manually resolving the route should no longer require resolution.
- r.ResolveWith("\x01")
- if r.IsResolutionRequired() {
- t.Fatal("got r.IsResolutionRequired() = true, want = false")
- }
-}
-
// TestRouteReleaseAfterAddrRemoval tests that releasing a Route after its
// associated address is removed should not cause a panic.
func TestRouteReleaseAfterAddrRemoval(t *testing.T) {
@@ -3620,3 +3585,43 @@ func TestGetNetworkEndpoint(t *testing.T) {
})
}
}
+
+func TestGetMainNICAddressWhenNICDisabled(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
+ })
+
+ if err := s.CreateNIC(nicID, channel.New(0, defaultMTU, "")); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ protocolAddress := tcpip.ProtocolAddress{
+ Protocol: fakeNetNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: "\x01",
+ PrefixLen: 8,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protocolAddress, err)
+ }
+
+ // Check that we get the right initial address and prefix length.
+ if gotAddr, err := s.GetMainNICAddress(nicID, fakeNetNumber); err != nil {
+ t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err)
+ } else if gotAddr != protocolAddress.AddressWithPrefix {
+ t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix)
+ }
+
+ // Should still get the address when the NIC is disabled.
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("DisableNIC(%d): %s", nicID, err)
+ }
+ if gotAddr, err := s.GetMainNICAddress(nicID, fakeNetNumber); err != nil {
+ t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err)
+ } else if gotAddr != protocolAddress.AddressWithPrefix {
+ t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix)
+ }
+}