summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r--pkg/tcpip/stack/BUILD9
-rw-r--r--pkg/tcpip/stack/forwarder.go4
-rw-r--r--pkg/tcpip/stack/forwarder_test.go36
-rw-r--r--pkg/tcpip/stack/iptables.go311
-rw-r--r--pkg/tcpip/stack/iptables_targets.go144
-rw-r--r--pkg/tcpip/stack/iptables_types.go180
-rw-r--r--pkg/tcpip/stack/ndp.go300
-rw-r--r--pkg/tcpip/stack/ndp_test.go168
-rw-r--r--pkg/tcpip/stack/nic.go128
-rw-r--r--pkg/tcpip/stack/nic_test.go3
-rw-r--r--pkg/tcpip/stack/packet_buffer.go71
-rw-r--r--pkg/tcpip/stack/packet_buffer_state.go27
-rw-r--r--pkg/tcpip/stack/rand.go40
-rw-r--r--pkg/tcpip/stack/registration.go35
-rw-r--r--pkg/tcpip/stack/route.go6
-rw-r--r--pkg/tcpip/stack/stack.go55
-rw-r--r--pkg/tcpip/stack/stack_test.go546
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go126
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go2
-rw-r--r--pkg/tcpip/stack/transport_test.go27
20 files changed, 1642 insertions, 576 deletions
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 6c029b2fb..8d80e9cee 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -21,10 +21,16 @@ go_library(
"dhcpv6configurationfromndpra_string.go",
"forwarder.go",
"icmp_rate_limit.go",
+ "iptables.go",
+ "iptables_targets.go",
+ "iptables_types.go",
"linkaddrcache.go",
"linkaddrentry_list.go",
"ndp.go",
"nic.go",
+ "packet_buffer.go",
+ "packet_buffer_state.go",
+ "rand.go",
"registration.go",
"route.go",
"stack.go",
@@ -34,6 +40,7 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//pkg/ilist",
+ "//pkg/log",
"//pkg/rand",
"//pkg/sleep",
"//pkg/sync",
@@ -41,7 +48,6 @@ go_library(
"//pkg/tcpip/buffer",
"//pkg/tcpip/hash/jenkins",
"//pkg/tcpip/header",
- "//pkg/tcpip/iptables",
"//pkg/tcpip/ports",
"//pkg/tcpip/seqnum",
"//pkg/waiter",
@@ -65,7 +71,6 @@ go_test(
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
"//pkg/tcpip/header",
- "//pkg/tcpip/iptables",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
"//pkg/tcpip/network/ipv4",
diff --git a/pkg/tcpip/stack/forwarder.go b/pkg/tcpip/stack/forwarder.go
index 631953935..6b64cd37f 100644
--- a/pkg/tcpip/stack/forwarder.go
+++ b/pkg/tcpip/stack/forwarder.go
@@ -32,7 +32,7 @@ type pendingPacket struct {
nic *NIC
route *Route
proto tcpip.NetworkProtocolNumber
- pkt tcpip.PacketBuffer
+ pkt PacketBuffer
}
type forwardQueue struct {
@@ -50,7 +50,7 @@ func newForwardQueue() *forwardQueue {
return &forwardQueue{packets: make(map[<-chan struct{}][]*pendingPacket)}
}
-func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
+func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) {
shouldWait := false
f.Lock()
diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go
index 321b7524d..c45c43d21 100644
--- a/pkg/tcpip/stack/forwarder_test.go
+++ b/pkg/tcpip/stack/forwarder_test.go
@@ -68,7 +68,7 @@ func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID {
return &f.id
}
-func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt tcpip.PacketBuffer) {
+func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt PacketBuffer) {
// Consume the network header.
b := pkt.Data.First()
pkt.Data.TrimFront(fwdTestNetHeaderLen)
@@ -89,7 +89,7 @@ func (f *fwdTestNetworkEndpoint) Capabilities() LinkEndpointCapabilities {
return f.ep.Capabilities()
}
-func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error {
// Add the protocol's header to the packet and send it to the link
// endpoint.
b := pkt.Header.Prepend(fwdTestNetHeaderLen)
@@ -101,11 +101,11 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkH
}
// WritePackets implements LinkEndpoint.WritePackets.
-func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) {
+func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts []PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) {
panic("not implemented")
}
-func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
@@ -183,7 +183,7 @@ func (f *fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumb
type fwdTestPacketInfo struct {
RemoteLinkAddress tcpip.LinkAddress
LocalLinkAddress tcpip.LinkAddress
- Pkt tcpip.PacketBuffer
+ Pkt PacketBuffer
}
type fwdTestLinkEndpoint struct {
@@ -196,12 +196,12 @@ type fwdTestLinkEndpoint struct {
}
// InjectInbound injects an inbound packet.
-func (e *fwdTestLinkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
+func (e *fwdTestLinkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) {
e.InjectLinkAddr(protocol, "", pkt)
}
// InjectLinkAddr injects an inbound packet with a remote link address.
-func (e *fwdTestLinkEndpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt tcpip.PacketBuffer) {
+func (e *fwdTestLinkEndpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt PacketBuffer) {
e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, pkt)
}
@@ -244,7 +244,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress {
return e.linkAddr
}
-func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) *tcpip.Error {
p := fwdTestPacketInfo{
RemoteLinkAddress: r.RemoteLinkAddress,
LocalLinkAddress: r.LocalLinkAddress,
@@ -260,7 +260,7 @@ func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.Netw
}
// WritePackets stores outbound packets into the channel.
-func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts []PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
n := 0
for _, pkt := range pkts {
e.WritePacket(r, gso, protocol, pkt)
@@ -273,7 +273,7 @@ func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts []tcpip.Pack
// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
p := fwdTestPacketInfo{
- Pkt: tcpip.PacketBuffer{Data: vv},
+ Pkt: PacketBuffer{Data: vv},
}
select {
@@ -355,7 +355,7 @@ func TestForwardingWithStaticResolver(t *testing.T) {
// forwarded to NIC 2.
buf := buffer.NewView(30)
buf[0] = 3
- ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{
+ ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
Data: buf.ToVectorisedView(),
})
@@ -392,7 +392,7 @@ func TestForwardingWithFakeResolver(t *testing.T) {
// forwarded to NIC 2.
buf := buffer.NewView(30)
buf[0] = 3
- ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{
+ ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
Data: buf.ToVectorisedView(),
})
@@ -423,7 +423,7 @@ func TestForwardingWithNoResolver(t *testing.T) {
// forwarded to NIC 2.
buf := buffer.NewView(30)
buf[0] = 3
- ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{
+ ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
Data: buf.ToVectorisedView(),
})
@@ -453,7 +453,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
// not be forwarded.
buf := buffer.NewView(30)
buf[0] = 4
- ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{
+ ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
Data: buf.ToVectorisedView(),
})
@@ -461,7 +461,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
// forwarded to NIC 2.
buf = buffer.NewView(30)
buf[0] = 3
- ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{
+ ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
Data: buf.ToVectorisedView(),
})
@@ -503,7 +503,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) {
for i := 0; i < 2; i++ {
buf := buffer.NewView(30)
buf[0] = 3
- ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{
+ ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
Data: buf.ToVectorisedView(),
})
}
@@ -550,7 +550,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) {
buf[0] = 3
// Set the packet sequence number.
binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i))
- ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{
+ ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
Data: buf.ToVectorisedView(),
})
}
@@ -603,7 +603,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) {
// maxPendingResolutions + 7).
buf := buffer.NewView(30)
buf[0] = byte(3 + i)
- ep1.InjectInbound(fwdTestNetNumber, tcpip.PacketBuffer{
+ ep1.InjectInbound(fwdTestNetNumber, PacketBuffer{
Data: buf.ToVectorisedView(),
})
}
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
new file mode 100644
index 000000000..37907ae24
--- /dev/null
+++ b/pkg/tcpip/stack/iptables.go
@@ -0,0 +1,311 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+// Table names.
+const (
+ TablenameNat = "nat"
+ TablenameMangle = "mangle"
+ TablenameFilter = "filter"
+)
+
+// Chain names as defined by net/ipv4/netfilter/ip_tables.c.
+const (
+ ChainNamePrerouting = "PREROUTING"
+ ChainNameInput = "INPUT"
+ ChainNameForward = "FORWARD"
+ ChainNameOutput = "OUTPUT"
+ ChainNamePostrouting = "POSTROUTING"
+)
+
+// HookUnset indicates that there is no hook set for an entrypoint or
+// underflow.
+const HookUnset = -1
+
+// DefaultTables returns a default set of tables. Each chain is set to accept
+// all packets.
+func DefaultTables() IPTables {
+ // TODO(gvisor.dev/issue/170): We may be able to swap out some strings for
+ // iotas.
+ return IPTables{
+ Tables: map[string]Table{
+ TablenameNat: Table{
+ Rules: []Rule{
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: ErrorTarget{}},
+ },
+ BuiltinChains: map[Hook]int{
+ Prerouting: 0,
+ Input: 1,
+ Output: 2,
+ Postrouting: 3,
+ },
+ Underflows: map[Hook]int{
+ Prerouting: 0,
+ Input: 1,
+ Output: 2,
+ Postrouting: 3,
+ },
+ UserChains: map[string]int{},
+ },
+ TablenameMangle: Table{
+ Rules: []Rule{
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: ErrorTarget{}},
+ },
+ BuiltinChains: map[Hook]int{
+ Prerouting: 0,
+ Output: 1,
+ },
+ Underflows: map[Hook]int{
+ Prerouting: 0,
+ Output: 1,
+ },
+ UserChains: map[string]int{},
+ },
+ TablenameFilter: Table{
+ Rules: []Rule{
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: ErrorTarget{}},
+ },
+ BuiltinChains: map[Hook]int{
+ Input: 0,
+ Forward: 1,
+ Output: 2,
+ },
+ Underflows: map[Hook]int{
+ Input: 0,
+ Forward: 1,
+ Output: 2,
+ },
+ UserChains: map[string]int{},
+ },
+ },
+ Priorities: map[Hook][]string{
+ Input: []string{TablenameNat, TablenameFilter},
+ Prerouting: []string{TablenameMangle, TablenameNat},
+ Output: []string{TablenameMangle, TablenameNat, TablenameFilter},
+ },
+ }
+}
+
+// EmptyFilterTable returns a Table with no rules and the filter table chains
+// mapped to HookUnset.
+func EmptyFilterTable() Table {
+ return Table{
+ Rules: []Rule{},
+ BuiltinChains: map[Hook]int{
+ Input: HookUnset,
+ Forward: HookUnset,
+ Output: HookUnset,
+ },
+ Underflows: map[Hook]int{
+ Input: HookUnset,
+ Forward: HookUnset,
+ Output: HookUnset,
+ },
+ UserChains: map[string]int{},
+ }
+}
+
+// EmptyNatTable returns a Table with no rules and the filter table chains
+// mapped to HookUnset.
+func EmptyNatTable() Table {
+ return Table{
+ Rules: []Rule{},
+ BuiltinChains: map[Hook]int{
+ Prerouting: HookUnset,
+ Input: HookUnset,
+ Output: HookUnset,
+ Postrouting: HookUnset,
+ },
+ Underflows: map[Hook]int{
+ Prerouting: HookUnset,
+ Input: HookUnset,
+ Output: HookUnset,
+ Postrouting: HookUnset,
+ },
+ UserChains: map[string]int{},
+ }
+}
+
+// A chainVerdict is what a table decides should be done with a packet.
+type chainVerdict int
+
+const (
+ // chainAccept indicates the packet should continue through netstack.
+ chainAccept chainVerdict = iota
+
+ // chainAccept indicates the packet should be dropped.
+ chainDrop
+
+ // chainReturn indicates the packet should return to the calling chain
+ // or the underflow rule of a builtin chain.
+ chainReturn
+)
+
+// Check runs pkt through the rules for hook. It returns true when the packet
+// should continue traversing the network stack and false when it should be
+// dropped.
+//
+// Precondition: pkt.NetworkHeader is set.
+func (it *IPTables) Check(hook Hook, pkt PacketBuffer) bool {
+ // Go through each table containing the hook.
+ for _, tablename := range it.Priorities[hook] {
+ table := it.Tables[tablename]
+ ruleIdx := table.BuiltinChains[hook]
+ switch verdict := it.checkChain(hook, pkt, table, ruleIdx); verdict {
+ // If the table returns Accept, move on to the next table.
+ case chainAccept:
+ continue
+ // The Drop verdict is final.
+ case chainDrop:
+ return false
+ case chainReturn:
+ // Any Return from a built-in chain means we have to
+ // call the underflow.
+ underflow := table.Rules[table.Underflows[hook]]
+ switch v, _ := underflow.Target.Action(pkt); v {
+ case RuleAccept:
+ continue
+ case RuleDrop:
+ return false
+ case RuleJump, RuleReturn:
+ panic("Underflows should only return RuleAccept or RuleDrop.")
+ default:
+ panic(fmt.Sprintf("Unknown verdict: %d", v))
+ }
+
+ default:
+ panic(fmt.Sprintf("Unknown verdict %v.", verdict))
+ }
+ }
+
+ // Every table returned Accept.
+ return true
+}
+
+// Precondition: pkt.NetworkHeader is set.
+func (it *IPTables) checkChain(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) chainVerdict {
+ // Start from ruleIdx and walk the list of rules until a rule gives us
+ // a verdict.
+ for ruleIdx < len(table.Rules) {
+ switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx); verdict {
+ case RuleAccept:
+ return chainAccept
+
+ case RuleDrop:
+ return chainDrop
+
+ case RuleReturn:
+ return chainReturn
+
+ case RuleJump:
+ // "Jumping" to the next rule just means we're
+ // continuing on down the list.
+ if jumpTo == ruleIdx+1 {
+ ruleIdx++
+ continue
+ }
+ switch verdict := it.checkChain(hook, pkt, table, jumpTo); verdict {
+ case chainAccept:
+ return chainAccept
+ case chainDrop:
+ return chainDrop
+ case chainReturn:
+ ruleIdx++
+ continue
+ default:
+ panic(fmt.Sprintf("Unknown verdict: %d", verdict))
+ }
+
+ default:
+ panic(fmt.Sprintf("Unknown verdict: %d", verdict))
+ }
+
+ }
+
+ // We got through the entire table without a decision. Default to DROP
+ // for safety.
+ return chainDrop
+}
+
+// Precondition: pk.NetworkHeader is set.
+func (it *IPTables) checkRule(hook Hook, pkt PacketBuffer, table Table, ruleIdx int) (RuleVerdict, int) {
+ rule := table.Rules[ruleIdx]
+
+ // If pkt.NetworkHeader hasn't been set yet, it will be contained in
+ // pkt.Data.First().
+ if pkt.NetworkHeader == nil {
+ pkt.NetworkHeader = pkt.Data.First()
+ }
+
+ // Check whether the packet matches the IP header filter.
+ if !filterMatch(rule.Filter, header.IPv4(pkt.NetworkHeader)) {
+ // Continue on to the next rule.
+ return RuleJump, ruleIdx + 1
+ }
+
+ // Go through each rule matcher. If they all match, run
+ // the rule target.
+ for _, matcher := range rule.Matchers {
+ matches, hotdrop := matcher.Match(hook, pkt, "")
+ if hotdrop {
+ return RuleDrop, 0
+ }
+ if !matches {
+ // Continue on to the next rule.
+ return RuleJump, ruleIdx + 1
+ }
+ }
+
+ // All the matchers matched, so run the target.
+ return rule.Target.Action(pkt)
+}
+
+func filterMatch(filter IPHeaderFilter, hdr header.IPv4) bool {
+ // TODO(gvisor.dev/issue/170): Support other fields of the filter.
+ // Check the transport protocol.
+ if filter.Protocol != 0 && filter.Protocol != hdr.TransportProtocol() {
+ return false
+ }
+
+ // Check the destination IP.
+ dest := hdr.DestinationAddress()
+ matches := true
+ for i := range filter.Dst {
+ if dest[i]&filter.DstMask[i] != filter.Dst[i] {
+ matches = false
+ break
+ }
+ }
+ if matches == filter.DstInvert {
+ return false
+ }
+
+ return true
+}
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
new file mode 100644
index 000000000..7b4543caf
--- /dev/null
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -0,0 +1,144 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+// AcceptTarget accepts packets.
+type AcceptTarget struct{}
+
+// Action implements Target.Action.
+func (AcceptTarget) Action(packet PacketBuffer) (RuleVerdict, int) {
+ return RuleAccept, 0
+}
+
+// DropTarget drops packets.
+type DropTarget struct{}
+
+// Action implements Target.Action.
+func (DropTarget) Action(packet PacketBuffer) (RuleVerdict, int) {
+ return RuleDrop, 0
+}
+
+// ErrorTarget logs an error and drops the packet. It represents a target that
+// should be unreachable.
+type ErrorTarget struct{}
+
+// Action implements Target.Action.
+func (ErrorTarget) Action(packet PacketBuffer) (RuleVerdict, int) {
+ log.Debugf("ErrorTarget triggered.")
+ return RuleDrop, 0
+}
+
+// UserChainTarget marks a rule as the beginning of a user chain.
+type UserChainTarget struct {
+ Name string
+}
+
+// Action implements Target.Action.
+func (UserChainTarget) Action(PacketBuffer) (RuleVerdict, int) {
+ panic("UserChainTarget should never be called.")
+}
+
+// ReturnTarget returns from the current chain. If the chain is a built-in, the
+// hook's underflow should be called.
+type ReturnTarget struct{}
+
+// Action implements Target.Action.
+func (ReturnTarget) Action(PacketBuffer) (RuleVerdict, int) {
+ return RuleReturn, 0
+}
+
+// RedirectTarget redirects the packet by modifying the destination port/IP.
+// Min and Max values for IP and Ports in the struct indicate the range of
+// values which can be used to redirect.
+type RedirectTarget struct {
+ // TODO(gvisor.dev/issue/170): Other flags need to be added after
+ // we support them.
+ // RangeProtoSpecified flag indicates single port is specified to
+ // redirect.
+ RangeProtoSpecified bool
+
+ // Min address used to redirect.
+ MinIP tcpip.Address
+
+ // Max address used to redirect.
+ MaxIP tcpip.Address
+
+ // Min port used to redirect.
+ MinPort uint16
+
+ // Max port used to redirect.
+ MaxPort uint16
+}
+
+// Action implements Target.Action.
+// TODO(gvisor.dev/issue/170): Parse headers without copying. The current
+// implementation only works for PREROUTING and calls pkt.Clone(), neither
+// of which should be the case.
+func (rt RedirectTarget) Action(pkt PacketBuffer) (RuleVerdict, int) {
+ newPkt := pkt.Clone()
+
+ // Set network header.
+ headerView := newPkt.Data.First()
+ netHeader := header.IPv4(headerView)
+ newPkt.NetworkHeader = headerView[:header.IPv4MinimumSize]
+
+ hlen := int(netHeader.HeaderLength())
+ tlen := int(netHeader.TotalLength())
+ newPkt.Data.TrimFront(hlen)
+ newPkt.Data.CapLength(tlen - hlen)
+
+ // TODO(gvisor.dev/issue/170): Change destination address to
+ // loopback or interface address on which the packet was
+ // received.
+
+ // TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if
+ // we need to change dest address (for OUTPUT chain) or ports.
+ switch protocol := netHeader.TransportProtocol(); protocol {
+ case header.UDPProtocolNumber:
+ var udpHeader header.UDP
+ if newPkt.TransportHeader != nil {
+ udpHeader = header.UDP(newPkt.TransportHeader)
+ } else {
+ if len(pkt.Data.First()) < header.UDPMinimumSize {
+ return RuleDrop, 0
+ }
+ udpHeader = header.UDP(newPkt.Data.First())
+ }
+ udpHeader.SetDestinationPort(rt.MinPort)
+ case header.TCPProtocolNumber:
+ var tcpHeader header.TCP
+ if newPkt.TransportHeader != nil {
+ tcpHeader = header.TCP(newPkt.TransportHeader)
+ } else {
+ if len(pkt.Data.First()) < header.TCPMinimumSize {
+ return RuleDrop, 0
+ }
+ tcpHeader = header.TCP(newPkt.TransportHeader)
+ }
+ // TODO(gvisor.dev/issue/170): Need to recompute checksum
+ // and implement nat connection tracking to support TCP.
+ tcpHeader.SetDestinationPort(rt.MinPort)
+ default:
+ return RuleDrop, 0
+ }
+
+ return RuleAccept, 0
+}
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
new file mode 100644
index 000000000..2ffb55f2a
--- /dev/null
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -0,0 +1,180 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+// A Hook specifies one of the hooks built into the network stack.
+//
+// Userspace app Userspace app
+// ^ |
+// | v
+// [Input] [Output]
+// ^ |
+// | v
+// | routing
+// | |
+// | v
+// ----->[Prerouting]----->routing----->[Forward]---------[Postrouting]----->
+type Hook uint
+
+// These values correspond to values in include/uapi/linux/netfilter.h.
+const (
+ // Prerouting happens before a packet is routed to applications or to
+ // be forwarded.
+ Prerouting Hook = iota
+
+ // Input happens before a packet reaches an application.
+ Input
+
+ // Forward happens once it's decided that a packet should be forwarded
+ // to another host.
+ Forward
+
+ // Output happens after a packet is written by an application to be
+ // sent out.
+ Output
+
+ // Postrouting happens just before a packet goes out on the wire.
+ Postrouting
+
+ // The total number of hooks.
+ NumHooks
+)
+
+// A RuleVerdict is what a rule decides should be done with a packet.
+type RuleVerdict int
+
+const (
+ // RuleAccept indicates the packet should continue through netstack.
+ RuleAccept RuleVerdict = iota
+
+ // RuleDrop indicates the packet should be dropped.
+ RuleDrop
+
+ // RuleJump indicates the packet should jump to another chain.
+ RuleJump
+
+ // RuleReturn indicates the packet should return to the previous chain.
+ RuleReturn
+)
+
+// IPTables holds all the tables for a netstack.
+type IPTables struct {
+ // Tables maps table names to tables. User tables have arbitrary names.
+ Tables map[string]Table
+
+ // Priorities maps each hook to a list of table names. The order of the
+ // list is the order in which each table should be visited for that
+ // hook.
+ Priorities map[Hook][]string
+}
+
+// A Table defines a set of chains and hooks into the network stack. It is
+// really just a list of rules with some metadata for entrypoints and such.
+type Table struct {
+ // Rules holds the rules that make up the table.
+ Rules []Rule
+
+ // BuiltinChains maps builtin chains to their entrypoint rule in Rules.
+ BuiltinChains map[Hook]int
+
+ // Underflows maps builtin chains to their underflow rule in Rules
+ // (i.e. the rule to execute if the chain returns without a verdict).
+ Underflows map[Hook]int
+
+ // UserChains holds user-defined chains for the keyed by name. Users
+ // can give their chains arbitrary names.
+ UserChains map[string]int
+
+ // Metadata holds information about the Table that is useful to users
+ // of IPTables, but not to the netstack IPTables code itself.
+ metadata interface{}
+}
+
+// ValidHooks returns a bitmap of the builtin hooks for the given table.
+func (table *Table) ValidHooks() uint32 {
+ hooks := uint32(0)
+ for hook := range table.BuiltinChains {
+ hooks |= 1 << hook
+ }
+ return hooks
+}
+
+// Metadata returns the metadata object stored in table.
+func (table *Table) Metadata() interface{} {
+ return table.metadata
+}
+
+// SetMetadata sets the metadata object stored in table.
+func (table *Table) SetMetadata(metadata interface{}) {
+ table.metadata = metadata
+}
+
+// A Rule is a packet processing rule. It consists of two pieces. First it
+// contains zero or more matchers, each of which is a specification of which
+// packets this rule applies to. If there are no matchers in the rule, it
+// applies to any packet.
+type Rule struct {
+ // Filter holds basic IP filtering fields common to every rule.
+ Filter IPHeaderFilter
+
+ // Matchers is the list of matchers for this rule.
+ Matchers []Matcher
+
+ // Target is the action to invoke if all the matchers match the packet.
+ Target Target
+}
+
+// IPHeaderFilter holds basic IP filtering data common to every rule.
+type IPHeaderFilter struct {
+ // Protocol matches the transport protocol.
+ Protocol tcpip.TransportProtocolNumber
+
+ // Dst matches the destination IP address.
+ Dst tcpip.Address
+
+ // DstMask masks bits of the destination IP address when comparing with
+ // Dst.
+ DstMask tcpip.Address
+
+ // DstInvert inverts the meaning of the destination IP check, i.e. when
+ // true the filter will match packets that fail the destination
+ // comparison.
+ DstInvert bool
+}
+
+// A Matcher is the interface for matching packets.
+type Matcher interface {
+ // Name returns the name of the Matcher.
+ Name() string
+
+ // Match returns whether the packet matches and whether the packet
+ // should be "hotdropped", i.e. dropped immediately. This is usually
+ // used for suspicious packets.
+ //
+ // Precondition: packet.NetworkHeader is set.
+ Match(hook Hook, packet PacketBuffer, interfaceName string) (matches bool, hotdrop bool)
+}
+
+// A Target is the interface for taking an action for a packet.
+type Target interface {
+ // Action takes an action on the packet and returns a verdict on how
+ // traversal should (or should not) continue. If the return value is
+ // Jump, it also returns the index of the rule to jump to.
+ Action(packet PacketBuffer) (RuleVerdict, int)
+}
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go
index a9f4d5dad..630fdefc5 100644
--- a/pkg/tcpip/stack/ndp.go
+++ b/pkg/tcpip/stack/ndp.go
@@ -361,16 +361,16 @@ type ndpState struct {
// The default routers discovered through Router Advertisements.
defaultRouters map[tcpip.Address]defaultRouterState
+ // The timer used to send the next router solicitation message.
+ rtrSolicitTimer *time.Timer
+
// The on-link prefixes discovered through Router Advertisements' Prefix
// Information option.
onLinkPrefixes map[tcpip.Subnet]onLinkPrefixState
- // The timer used to send the next router solicitation message.
- // If routers are being solicited, rtrSolicitTimer MUST NOT be nil.
- rtrSolicitTimer *time.Timer
-
- // The addresses generated by SLAAC.
- autoGenAddresses map[tcpip.Address]autoGenAddressState
+ // The SLAAC prefixes discovered through Router Advertisements' Prefix
+ // Information option.
+ slaacPrefixes map[tcpip.Subnet]slaacPrefixState
// The last learned DHCPv6 configuration from an NDP RA.
dhcpv6Configuration DHCPv6ConfigurationFromNDPRA
@@ -402,18 +402,16 @@ type onLinkPrefixState struct {
invalidationTimer tcpip.CancellableTimer
}
-// autoGenAddressState holds data associated with an address generated via
-// SLAAC.
-type autoGenAddressState struct {
- // A reference to the referencedNetworkEndpoint that this autoGenAddressState
- // is holding state for.
- ref *referencedNetworkEndpoint
-
+// slaacPrefixState holds state associated with a SLAAC prefix.
+type slaacPrefixState struct {
deprecationTimer tcpip.CancellableTimer
invalidationTimer tcpip.CancellableTimer
// Nonzero only when the address is not valid forever.
validUntil time.Time
+
+ // The prefix's permanent address endpoint.
+ ref *referencedNetworkEndpoint
}
// startDuplicateAddressDetection performs Duplicate Address Detection.
@@ -566,7 +564,7 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address) *tcpip.Error {
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.NDPHopLimit,
TOS: DefaultTOS,
- }, tcpip.PacketBuffer{Header: hdr},
+ }, PacketBuffer{Header: hdr},
); err != nil {
sent.Dropped.Increment()
return err
@@ -899,23 +897,15 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform
prefix := pi.Subnet()
- // Check if we already have an auto-generated address for prefix.
- for addr, addrState := range ndp.autoGenAddresses {
- refAddrWithPrefix := tcpip.AddressWithPrefix{Address: addr, PrefixLen: addrState.ref.ep.PrefixLen()}
- if refAddrWithPrefix.Subnet() != prefix {
- continue
- }
-
- // At this point, we know we are refreshing a SLAAC generated IPv6 address
- // with the prefix prefix. Do the work as outlined by RFC 4862 section
- // 5.5.3.e.
- ndp.refreshAutoGenAddressLifetimes(addr, pl, vl)
+ // Check if we already maintain SLAAC state for prefix.
+ if _, ok := ndp.slaacPrefixes[prefix]; ok {
+ // As per RFC 4862 section 5.5.3.e, refresh prefix's SLAAC lifetimes.
+ ndp.refreshSLAACPrefixLifetimes(prefix, pl, vl)
return
}
- // We do not already have an address with the prefix prefix. Do the
- // work as outlined by RFC 4862 section 5.5.3.d if n is configured
- // to auto-generate global addresses by SLAAC.
+ // prefix is a new SLAAC prefix. Do the work as outlined by RFC 4862 section
+ // 5.5.3.d if ndp is configured to auto-generate new addresses via SLAAC.
if !ndp.configs.AutoGenGlobalAddresses {
return
}
@@ -927,6 +917,8 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform
// for prefix.
//
// pl is the new preferred lifetime. vl is the new valid lifetime.
+//
+// The NIC that ndp belongs to MUST be locked.
func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
// If we do not already have an address for this prefix and the valid
// lifetime is 0, no need to do anything further, as per RFC 4862
@@ -942,9 +934,59 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
return
}
+ // If the preferred lifetime is zero, then the prefix should be considered
+ // deprecated.
+ deprecated := pl == 0
+ ref := ndp.addSLAACAddr(prefix, deprecated)
+ if ref == nil {
+ // We were unable to generate a permanent address for prefix so do nothing
+ // further as there is no reason to maintain state for a SLAAC prefix we
+ // cannot generate a permanent address for.
+ return
+ }
+
+ state := slaacPrefixState{
+ deprecationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() {
+ prefixState, ok := ndp.slaacPrefixes[prefix]
+ if !ok {
+ log.Fatalf("ndp: must have a slaacPrefixes entry for the SLAAC prefix %s", prefix)
+ }
+
+ ndp.deprecateSLAACAddress(prefixState.ref)
+ }),
+ invalidationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() {
+ ndp.invalidateSLAACPrefix(prefix, true)
+ }),
+ ref: ref,
+ }
+
+ // Setup the initial timers to deprecate and invalidate prefix.
+
+ if !deprecated && pl < header.NDPInfiniteLifetime {
+ state.deprecationTimer.Reset(pl)
+ }
+
+ if vl < header.NDPInfiniteLifetime {
+ state.invalidationTimer.Reset(vl)
+ state.validUntil = time.Now().Add(vl)
+ }
+
+ ndp.slaacPrefixes[prefix] = state
+}
+
+// addSLAACAddr adds a SLAAC address for prefix.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) addSLAACAddr(prefix tcpip.Subnet, deprecated bool) *referencedNetworkEndpoint {
addrBytes := []byte(prefix.ID())
if oIID := ndp.nic.stack.opaqueIIDOpts; oIID.NICNameFromID != nil {
- addrBytes = header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], prefix, oIID.NICNameFromID(ndp.nic.ID(), ndp.nic.name), 0 /* dadCounter */, oIID.SecretKey)
+ addrBytes = header.AppendOpaqueInterfaceIdentifier(
+ addrBytes[:header.IIDOffsetInIPv6Address],
+ prefix,
+ oIID.NICNameFromID(ndp.nic.ID(), ndp.nic.name),
+ 0, /* dadCounter */
+ oIID.SecretKey,
+ )
} else {
// Only attempt to generate an interface-specific IID if we have a valid
// link address.
@@ -953,137 +995,103 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
// LinkEndpoint.LinkAddress) before reaching this point.
linkAddr := ndp.nic.linkEP.LinkAddress()
if !header.IsValidUnicastEthernetAddress(linkAddr) {
- return
+ return nil
}
// Generate an address within prefix from the modified EUI-64 of ndp's NIC's
// Ethernet MAC address.
header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:])
}
- addr := tcpip.Address(addrBytes)
- addrWithPrefix := tcpip.AddressWithPrefix{
- Address: addr,
- PrefixLen: validPrefixLenForAutoGen,
+
+ generatedAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(addrBytes),
+ PrefixLen: validPrefixLenForAutoGen,
+ },
}
// If the nic already has this address, do nothing further.
- if ndp.nic.hasPermanentAddrLocked(addr) {
- return
+ if ndp.nic.hasPermanentAddrLocked(generatedAddr.AddressWithPrefix.Address) {
+ return nil
}
// Inform the integrator that we have a new SLAAC address.
ndpDisp := ndp.nic.stack.ndpDisp
if ndpDisp == nil {
- return
+ return nil
}
- if !ndpDisp.OnAutoGenAddress(ndp.nic.ID(), addrWithPrefix) {
+
+ if !ndpDisp.OnAutoGenAddress(ndp.nic.ID(), generatedAddr.AddressWithPrefix) {
// Informed by the integrator not to add the address.
- return
+ return nil
}
- protocolAddr := tcpip.ProtocolAddress{
- Protocol: header.IPv6ProtocolNumber,
- AddressWithPrefix: addrWithPrefix,
- }
- // If the preferred lifetime is zero, then the address should be considered
- // deprecated.
- deprecated := pl == 0
- ref, err := ndp.nic.addAddressLocked(protocolAddr, FirstPrimaryEndpoint, permanent, slaac, deprecated)
+ ref, err := ndp.nic.addAddressLocked(generatedAddr, FirstPrimaryEndpoint, permanent, slaac, deprecated)
if err != nil {
- log.Fatalf("ndp: error when adding address %s: %s", protocolAddr, err)
- }
-
- state := autoGenAddressState{
- ref: ref,
- deprecationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() {
- addrState, ok := ndp.autoGenAddresses[addr]
- if !ok {
- log.Fatalf("ndp: must have an autoGenAddressess entry for the SLAAC generated IPv6 address %s", addr)
- }
- addrState.ref.deprecated = true
- ndp.notifyAutoGenAddressDeprecated(addr)
- }),
- invalidationTimer: tcpip.MakeCancellableTimer(&ndp.nic.mu, func() {
- ndp.invalidateAutoGenAddress(addr)
- }),
+ log.Fatalf("ndp: error when adding address %+v: %s", generatedAddr, err)
}
- // Setup the initial timers to deprecate and invalidate this newly generated
- // address.
-
- if !deprecated && pl < header.NDPInfiniteLifetime {
- state.deprecationTimer.Reset(pl)
- }
-
- if vl < header.NDPInfiniteLifetime {
- state.invalidationTimer.Reset(vl)
- state.validUntil = time.Now().Add(vl)
- }
-
- ndp.autoGenAddresses[addr] = state
+ return ref
}
-// refreshAutoGenAddressLifetimes refreshes the lifetime of a SLAAC generated
-// address addr.
+// refreshSLAACPrefixLifetimes refreshes the lifetimes of a SLAAC prefix.
//
// pl is the new preferred lifetime. vl is the new valid lifetime.
-func (ndp *ndpState) refreshAutoGenAddressLifetimes(addr tcpip.Address, pl, vl time.Duration) {
- addrState, ok := ndp.autoGenAddresses[addr]
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, pl, vl time.Duration) {
+ prefixState, ok := ndp.slaacPrefixes[prefix]
if !ok {
- log.Fatalf("ndp: SLAAC state not found to refresh lifetimes for %s", addr)
+ log.Fatalf("ndp: SLAAC prefix state not found to refresh lifetimes for %s", prefix)
}
- defer func() { ndp.autoGenAddresses[addr] = addrState }()
+ defer func() { ndp.slaacPrefixes[prefix] = prefixState }()
- // If the preferred lifetime is zero, then the address should be considered
- // deprecated.
+ // If the preferred lifetime is zero, then the prefix should be deprecated.
deprecated := pl == 0
- wasDeprecated := addrState.ref.deprecated
- addrState.ref.deprecated = deprecated
-
- // Only send the deprecation event if the deprecated status for addr just
- // changed from non-deprecated to deprecated.
- if !wasDeprecated && deprecated {
- ndp.notifyAutoGenAddressDeprecated(addr)
+ if deprecated {
+ ndp.deprecateSLAACAddress(prefixState.ref)
+ } else {
+ prefixState.ref.deprecated = false
}
- // If addr was preferred for some finite lifetime before, stop the deprecation
- // timer so it can be reset.
- addrState.deprecationTimer.StopLocked()
+ // If prefix was preferred for some finite lifetime before, stop the
+ // deprecation timer so it can be reset.
+ prefixState.deprecationTimer.StopLocked()
- // Reset the deprecation timer if addr has a finite preferred lifetime.
+ // Reset the deprecation timer if prefix has a finite preferred lifetime.
if !deprecated && pl < header.NDPInfiniteLifetime {
- addrState.deprecationTimer.Reset(pl)
+ prefixState.deprecationTimer.Reset(pl)
}
- // As per RFC 4862 section 5.5.3.e, the valid lifetime of the address
- //
+ // As per RFC 4862 section 5.5.3.e, update the valid lifetime for prefix:
//
// 1) If the received Valid Lifetime is greater than 2 hours or greater than
- // RemainingLifetime, set the valid lifetime of the address to the
+ // RemainingLifetime, set the valid lifetime of the prefix to the
// advertised Valid Lifetime.
//
// 2) If RemainingLifetime is less than or equal to 2 hours, ignore the
// advertised Valid Lifetime.
//
- // 3) Otherwise, reset the valid lifetime of the address to 2 hours.
+ // 3) Otherwise, reset the valid lifetime of the prefix to 2 hours.
// Handle the infinite valid lifetime separately as we do not keep a timer in
// this case.
if vl >= header.NDPInfiniteLifetime {
- addrState.invalidationTimer.StopLocked()
- addrState.validUntil = time.Time{}
+ prefixState.invalidationTimer.StopLocked()
+ prefixState.validUntil = time.Time{}
return
}
var effectiveVl time.Duration
var rl time.Duration
- // If the address was originally set to be valid forever, assume the remaining
+ // If the prefix was originally set to be valid forever, assume the remaining
// time to be the maximum possible value.
- if addrState.validUntil == (time.Time{}) {
+ if prefixState.validUntil == (time.Time{}) {
rl = header.NDPInfiniteLifetime
} else {
- rl = time.Until(addrState.validUntil)
+ rl = time.Until(prefixState.validUntil)
}
if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl {
@@ -1094,58 +1102,66 @@ func (ndp *ndpState) refreshAutoGenAddressLifetimes(addr tcpip.Address, pl, vl t
effectiveVl = MinPrefixInformationValidLifetimeForUpdate
}
- addrState.invalidationTimer.StopLocked()
- addrState.invalidationTimer.Reset(effectiveVl)
- addrState.validUntil = time.Now().Add(effectiveVl)
-}
-
-// notifyAutoGenAddressDeprecated notifies the stack's NDP dispatcher that addr
-// has been deprecated.
-func (ndp *ndpState) notifyAutoGenAddressDeprecated(addr tcpip.Address) {
- if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
- ndpDisp.OnAutoGenAddressDeprecated(ndp.nic.ID(), tcpip.AddressWithPrefix{
- Address: addr,
- PrefixLen: validPrefixLenForAutoGen,
- })
- }
+ prefixState.invalidationTimer.StopLocked()
+ prefixState.invalidationTimer.Reset(effectiveVl)
+ prefixState.validUntil = time.Now().Add(effectiveVl)
}
-// invalidateAutoGenAddress invalidates an auto-generated address.
+// deprecateSLAACAddress marks ref as deprecated and notifies the stack's NDP
+// dispatcher that ref has been deprecated.
+//
+// deprecateSLAACAddress does nothing if ref is already deprecated.
//
// The NIC that ndp belongs to MUST be locked.
-func (ndp *ndpState) invalidateAutoGenAddress(addr tcpip.Address) {
- if !ndp.cleanupAutoGenAddrResourcesAndNotify(addr) {
+func (ndp *ndpState) deprecateSLAACAddress(ref *referencedNetworkEndpoint) {
+ if ref.deprecated {
return
}
- ndp.nic.removePermanentAddressLocked(addr)
+ ref.deprecated = true
+ if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnAutoGenAddressDeprecated(ndp.nic.ID(), tcpip.AddressWithPrefix{
+ Address: ref.ep.ID().LocalAddress,
+ PrefixLen: ref.ep.PrefixLen(),
+ })
+ }
}
-// cleanupAutoGenAddrResourcesAndNotify cleans up an invalidated auto-generated
-// address's resources from ndp. If the stack has an NDP dispatcher, it will
-// be notified that addr has been invalidated.
-//
-// Returns true if ndp had resources for addr to cleanup.
+// invalidateSLAACPrefix invalidates a SLAAC prefix.
//
// The NIC that ndp belongs to MUST be locked.
-func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bool {
- state, ok := ndp.autoGenAddresses[addr]
+func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, removeAddr bool) {
+ state, ok := ndp.slaacPrefixes[prefix]
if !ok {
- return false
+ return
}
state.deprecationTimer.StopLocked()
state.invalidationTimer.StopLocked()
- delete(ndp.autoGenAddresses, addr)
+ delete(ndp.slaacPrefixes, prefix)
+
+ addr := state.ref.ep.ID().LocalAddress
+
+ if removeAddr {
+ if err := ndp.nic.removePermanentAddressLocked(addr); err != nil {
+ log.Fatalf("ndp: removePermanentAddressLocked(%s): %s", addr, err)
+ }
+ }
if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), tcpip.AddressWithPrefix{
Address: addr,
- PrefixLen: validPrefixLenForAutoGen,
+ PrefixLen: state.ref.ep.PrefixLen(),
})
}
+}
- return true
+// cleanupSLAACAddrResourcesAndNotify cleans up an invalidated SLAAC
+// address's resources from ndp.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix) {
+ ndp.invalidateSLAACPrefix(addr.Subnet(), false)
}
// cleanupState cleans up ndp's state.
@@ -1163,21 +1179,21 @@ func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bo
// The NIC that ndp belongs to MUST be locked.
func (ndp *ndpState) cleanupState(hostOnly bool) {
linkLocalSubnet := header.IPv6LinkLocalPrefix.Subnet()
- linkLocalAddrs := 0
- for addr := range ndp.autoGenAddresses {
+ linkLocalPrefixes := 0
+ for prefix := range ndp.slaacPrefixes {
// RFC 4862 section 5 states that routers are also expected to generate a
// link-local address so we do not invalidate them if we are cleaning up
// host-only state.
- if hostOnly && linkLocalSubnet.Contains(addr) {
- linkLocalAddrs++
+ if hostOnly && prefix == linkLocalSubnet {
+ linkLocalPrefixes++
continue
}
- ndp.invalidateAutoGenAddress(addr)
+ ndp.invalidateSLAACPrefix(prefix, true)
}
- if got := len(ndp.autoGenAddresses); got != linkLocalAddrs {
- log.Fatalf("ndp: still have non-linklocal auto-generated addresses after cleaning up; found = %d prefixes, of which %d are link-local", got, linkLocalAddrs)
+ if got := len(ndp.slaacPrefixes); got != linkLocalPrefixes {
+ log.Fatalf("ndp: still have non-linklocal SLAAC prefixes after cleaning up; found = %d prefixes, of which %d are link-local", got, linkLocalPrefixes)
}
for prefix := range ndp.onLinkPrefixes {
@@ -1267,7 +1283,7 @@ func (ndp *ndpState) startSolicitingRouters() {
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.NDPHopLimit,
TOS: DefaultTOS,
- }, tcpip.PacketBuffer{Header: hdr},
+ }, PacketBuffer{Header: hdr},
); err != nil {
sent.Dropped.Increment()
log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.nic.ID(), err)
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 98b1c807c..06edd05b6 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -602,7 +602,7 @@ func TestDADFail(t *testing.T) {
// Receive a packet to simulate multiple nodes owning or
// attempting to own the same address.
hdr := test.makeBuf(addr1)
- e.InjectInbound(header.IPv6ProtocolNumber, tcpip.PacketBuffer{
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.PacketBuffer{
Data: hdr.View().ToVectorisedView(),
})
@@ -639,8 +639,9 @@ func TestDADStop(t *testing.T) {
const nicID = 1
tests := []struct {
- name string
- stopFn func(t *testing.T, s *stack.Stack)
+ name string
+ stopFn func(t *testing.T, s *stack.Stack)
+ skipFinalAddrCheck bool
}{
// Tests to make sure that DAD stops when an address is removed.
{
@@ -661,6 +662,19 @@ func TestDADStop(t *testing.T) {
}
},
},
+
+ // Tests to make sure that DAD stops when the NIC is removed.
+ {
+ name: "Remove NIC",
+ stopFn: func(t *testing.T, s *stack.Stack) {
+ if err := s.RemoveNIC(nicID); err != nil {
+ t.Fatalf("RemoveNIC(%d): %s", nicID, err)
+ }
+ },
+ // The NIC is removed so we can't check its addresses after calling
+ // stopFn.
+ skipFinalAddrCheck: true,
+ },
}
for _, test := range tests {
@@ -710,12 +724,15 @@ func TestDADStop(t *testing.T) {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
}
- addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+
+ if !test.skipFinalAddrCheck {
+ addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ }
}
// Should not have sent more than 1 NS message.
@@ -901,7 +918,7 @@ func TestSetNDPConfigurations(t *testing.T) {
// raBufWithOptsAndDHCPv6 returns a valid NDP Router Advertisement with options
// and DHCPv6 configurations specified.
-func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) tcpip.PacketBuffer {
+func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) stack.PacketBuffer {
icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + int(optSer.Length())
hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
pkt := header.ICMPv6(hdr.Prepend(icmpSize))
@@ -936,14 +953,14 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo
DstAddr: header.IPv6AllNodesMulticastAddress,
})
- return tcpip.PacketBuffer{Data: hdr.View().ToVectorisedView()}
+ return stack.PacketBuffer{Data: hdr.View().ToVectorisedView()}
}
// raBufWithOpts returns a valid NDP Router Advertisement with options.
//
// Note, raBufWithOpts does not populate any of the RA fields other than the
// Router Lifetime.
-func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) tcpip.PacketBuffer {
+func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) stack.PacketBuffer {
return raBufWithOptsAndDHCPv6(ip, rl, false, false, optSer)
}
@@ -952,7 +969,7 @@ func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializ
//
// Note, raBufWithDHCPv6 does not populate any of the RA fields other than the
// DHCPv6 related ones.
-func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) tcpip.PacketBuffer {
+func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) stack.PacketBuffer {
return raBufWithOptsAndDHCPv6(ip, 0, managedAddresses, otherConfiguratiosns, header.NDPOptionsSerializer{})
}
@@ -960,7 +977,7 @@ func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bo
//
// Note, raBuf does not populate any of the RA fields other than the
// Router Lifetime.
-func raBuf(ip tcpip.Address, rl uint16) tcpip.PacketBuffer {
+func raBuf(ip tcpip.Address, rl uint16) stack.PacketBuffer {
return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{})
}
@@ -969,7 +986,7 @@ func raBuf(ip tcpip.Address, rl uint16) tcpip.PacketBuffer {
//
// Note, raBufWithPI does not populate any of the RA fields other than the
// Router Lifetime.
-func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) tcpip.PacketBuffer {
+func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) stack.PacketBuffer {
flags := uint8(0)
if onLink {
// The OnLink flag is the 7th bit in the flags byte.
@@ -2983,11 +3000,12 @@ func TestCleanupNDPState(t *testing.T) {
cleanupFn func(t *testing.T, s *stack.Stack)
keepAutoGenLinkLocal bool
maxAutoGenAddrEvents int
+ skipFinalAddrCheck bool
}{
// A NIC should still keep its auto-generated link-local address when
// becoming a router.
{
- name: "Forwarding Enable",
+ name: "Enable forwarding",
cleanupFn: func(t *testing.T, s *stack.Stack) {
t.Helper()
s.SetForwarding(true)
@@ -2998,7 +3016,7 @@ func TestCleanupNDPState(t *testing.T) {
// A NIC should cleanup all NDP state when it is disabled.
{
- name: "NIC Disable",
+ name: "Disable NIC",
cleanupFn: func(t *testing.T, s *stack.Stack) {
t.Helper()
@@ -3012,6 +3030,26 @@ func TestCleanupNDPState(t *testing.T) {
keepAutoGenLinkLocal: false,
maxAutoGenAddrEvents: 6,
},
+
+ // A NIC should cleanup all NDP state when it is removed.
+ {
+ name: "Remove NIC",
+ cleanupFn: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+
+ if err := s.RemoveNIC(nicID1); err != nil {
+ t.Fatalf("s.RemoveNIC(%d): %s", nicID1, err)
+ }
+ if err := s.RemoveNIC(nicID2); err != nil {
+ t.Fatalf("s.RemoveNIC(%d): %s", nicID2, err)
+ }
+ },
+ keepAutoGenLinkLocal: false,
+ maxAutoGenAddrEvents: 6,
+ // The NICs are removed so we can't check their addresses after calling
+ // stopFn.
+ skipFinalAddrCheck: true,
+ },
}
for _, test := range tests {
@@ -3230,35 +3268,37 @@ func TestCleanupNDPState(t *testing.T) {
t.Errorf("auto-generated address events mismatch (-want +got):\n%s", diff)
}
- // Make sure the auto-generated addresses got removed.
- nicinfo = s.NICInfo()
- nic1Addrs = nicinfo[nicID1].ProtocolAddresses
- nic2Addrs = nicinfo[nicID2].ProtocolAddresses
- if containsV6Addr(nic1Addrs, llAddrWithPrefix1) != test.keepAutoGenLinkLocal {
- if test.keepAutoGenLinkLocal {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
- } else {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
+ if !test.skipFinalAddrCheck {
+ // Make sure the auto-generated addresses got removed.
+ nicinfo = s.NICInfo()
+ nic1Addrs = nicinfo[nicID1].ProtocolAddresses
+ nic2Addrs = nicinfo[nicID2].ProtocolAddresses
+ if containsV6Addr(nic1Addrs, llAddrWithPrefix1) != test.keepAutoGenLinkLocal {
+ if test.keepAutoGenLinkLocal {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
+ } else {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
+ }
}
- }
- if containsV6Addr(nic1Addrs, e1Addr1) {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs)
- }
- if containsV6Addr(nic1Addrs, e1Addr2) {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs)
- }
- if containsV6Addr(nic2Addrs, llAddrWithPrefix2) != test.keepAutoGenLinkLocal {
- if test.keepAutoGenLinkLocal {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
- } else {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
+ if containsV6Addr(nic1Addrs, e1Addr1) {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs)
+ }
+ if containsV6Addr(nic1Addrs, e1Addr2) {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs)
+ }
+ if containsV6Addr(nic2Addrs, llAddrWithPrefix2) != test.keepAutoGenLinkLocal {
+ if test.keepAutoGenLinkLocal {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
+ } else {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
+ }
+ }
+ if containsV6Addr(nic2Addrs, e2Addr1) {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs)
+ }
+ if containsV6Addr(nic2Addrs, e2Addr2) {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs)
}
- }
- if containsV6Addr(nic2Addrs, e2Addr1) {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs)
- }
- if containsV6Addr(nic2Addrs, e2Addr2) {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs)
}
// Should not get any more events (invalidation timers should have been
@@ -3575,17 +3615,19 @@ func TestStopStartSolicitingRouters(t *testing.T) {
tests := []struct {
name string
startFn func(t *testing.T, s *stack.Stack)
- stopFn func(t *testing.T, s *stack.Stack)
+ // first is used to tell stopFn that it is being called for the first time
+ // after router solicitations were last enabled.
+ stopFn func(t *testing.T, s *stack.Stack, first bool)
}{
// Tests that when forwarding is enabled or disabled, router solicitations
// are stopped or started, respectively.
{
- name: "Forwarding enabled and disabled",
+ name: "Enable and disable forwarding",
startFn: func(t *testing.T, s *stack.Stack) {
t.Helper()
s.SetForwarding(false)
},
- stopFn: func(t *testing.T, s *stack.Stack) {
+ stopFn: func(t *testing.T, s *stack.Stack, _ bool) {
t.Helper()
s.SetForwarding(true)
},
@@ -3594,7 +3636,7 @@ func TestStopStartSolicitingRouters(t *testing.T) {
// Tests that when a NIC is enabled or disabled, router solicitations
// are started or stopped, respectively.
{
- name: "NIC disabled and enabled",
+ name: "Enable and disable NIC",
startFn: func(t *testing.T, s *stack.Stack) {
t.Helper()
@@ -3602,7 +3644,7 @@ func TestStopStartSolicitingRouters(t *testing.T) {
t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
}
},
- stopFn: func(t *testing.T, s *stack.Stack) {
+ stopFn: func(t *testing.T, s *stack.Stack, _ bool) {
t.Helper()
if err := s.DisableNIC(nicID); err != nil {
@@ -3610,6 +3652,25 @@ func TestStopStartSolicitingRouters(t *testing.T) {
}
},
},
+
+ // Tests that when a NIC is removed, router solicitations are stopped. We
+ // cannot start router solications on a removed NIC.
+ {
+ name: "Remove NIC",
+ stopFn: func(t *testing.T, s *stack.Stack, first bool) {
+ t.Helper()
+
+ // Only try to remove the NIC the first time stopFn is called since it's
+ // impossible to remove an already removed NIC.
+ if !first {
+ return
+ }
+
+ if err := s.RemoveNIC(nicID); err != nil {
+ t.Fatalf("s.RemoveNIC(%d): %s", nicID, err)
+ }
+ },
+ },
}
for _, test := range tests {
@@ -3648,7 +3709,7 @@ func TestStopStartSolicitingRouters(t *testing.T) {
}
// Stop soliciting routers.
- test.stopFn(t, s)
+ test.stopFn(t, s, true /* first */)
ctx, cancel := context.WithTimeout(context.Background(), delay+defaultTimeout)
defer cancel()
if _, ok := e.ReadContext(ctx); ok {
@@ -3662,13 +3723,18 @@ func TestStopStartSolicitingRouters(t *testing.T) {
// Stopping router solicitations after it has already been stopped should
// do nothing.
- test.stopFn(t, s)
+ test.stopFn(t, s, false /* first */)
ctx, cancel = context.WithTimeout(context.Background(), delay+defaultTimeout)
defer cancel()
if _, ok := e.ReadContext(ctx); ok {
t.Fatal("unexpectedly got a packet after router solicitation has been stopepd")
}
+ // If test.startFn is nil, there is no way to restart router solications.
+ if test.startFn == nil {
+ return
+ }
+
// Start soliciting routers.
test.startFn(t, s)
waitForPkt(delay + defaultAsyncEventTimeout)
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 3cd5fec71..b6fa647ea 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -15,6 +15,7 @@
package stack
import (
+ "fmt"
"log"
"reflect"
"sort"
@@ -25,7 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
)
var ipv4BroadcastAddr = tcpip.ProtocolAddress{
@@ -55,7 +55,7 @@ type NIC struct {
primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint
endpoints map[NetworkEndpointID]*referencedNetworkEndpoint
addressRanges []tcpip.Subnet
- mcastJoins map[NetworkEndpointID]int32
+ mcastJoins map[NetworkEndpointID]uint32
// packetEPs is protected by mu, but the contained PacketEndpoint
// values are not.
packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint
@@ -122,15 +122,15 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
}
nic.mu.primary = make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint)
nic.mu.endpoints = make(map[NetworkEndpointID]*referencedNetworkEndpoint)
- nic.mu.mcastJoins = make(map[NetworkEndpointID]int32)
+ nic.mu.mcastJoins = make(map[NetworkEndpointID]uint32)
nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint)
nic.mu.ndp = ndpState{
- nic: nic,
- configs: stack.ndpConfigs,
- dad: make(map[tcpip.Address]dadState),
- defaultRouters: make(map[tcpip.Address]defaultRouterState),
- onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState),
- autoGenAddresses: make(map[tcpip.Address]autoGenAddressState),
+ nic: nic,
+ configs: stack.ndpConfigs,
+ dad: make(map[tcpip.Address]dadState),
+ defaultRouters: make(map[tcpip.Address]defaultRouterState),
+ onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState),
+ slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState),
}
// Register supported packet endpoint protocols.
@@ -166,8 +166,17 @@ func (n *NIC) disable() *tcpip.Error {
}
n.mu.Lock()
- defer n.mu.Unlock()
+ err := n.disableLocked()
+ n.mu.Unlock()
+ return err
+}
+// disableLocked disables n.
+//
+// It undoes the work done by enable.
+//
+// n MUST be locked.
+func (n *NIC) disableLocked() *tcpip.Error {
if !n.mu.enabled {
return nil
}
@@ -190,7 +199,7 @@ func (n *NIC) disable() *tcpip.Error {
}
// The NIC may have already left the multicast group.
- if err := n.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress {
+ if err := n.leaveGroupLocked(header.IPv6AllNodesMulticastAddress, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress {
return err
}
}
@@ -306,24 +315,33 @@ func (n *NIC) remove() *tcpip.Error {
n.mu.Lock()
defer n.mu.Unlock()
- // Detach from link endpoint, so no packet comes in.
- n.linkEP.Attach(nil)
+ n.disableLocked()
+
+ // TODO(b/151378115): come up with a better way to pick an error than the
+ // first one.
+ var err *tcpip.Error
+
+ // Forcefully leave multicast groups.
+ for nid := range n.mu.mcastJoins {
+ if tempErr := n.leaveGroupLocked(nid.LocalAddress, true /* force */); tempErr != nil && err == nil {
+ err = tempErr
+ }
+ }
// Remove permanent and permanentTentative addresses, so no packet goes out.
- var errs []*tcpip.Error
for nid, ref := range n.mu.endpoints {
switch ref.getKind() {
case permanentTentative, permanent:
- if err := n.removePermanentAddressLocked(nid.LocalAddress); err != nil {
- errs = append(errs, err)
+ if tempErr := n.removePermanentAddressLocked(nid.LocalAddress); tempErr != nil && err == nil {
+ err = tempErr
}
}
}
- if len(errs) > 0 {
- return errs[0]
- }
- return nil
+ // Detach from link endpoint, so no packet comes in.
+ n.linkEP.Attach(nil)
+
+ return err
}
// becomeIPv6Router transitions n into an IPv6 router.
@@ -970,6 +988,7 @@ func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) {
for i, ref := range refs {
if ref == r {
n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
+ refs[len(refs)-1] = nil
break
}
}
@@ -997,8 +1016,7 @@ func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error {
isIPv6Unicast := r.protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(addr)
if isIPv6Unicast {
- // If we are removing a tentative IPv6 unicast address, stop
- // DAD.
+ // If we are removing a tentative IPv6 unicast address, stop DAD.
if kind == permanentTentative {
n.mu.ndp.stopDuplicateAddressDetection(addr)
}
@@ -1006,7 +1024,10 @@ func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error {
// If we are removing an address generated via SLAAC, cleanup
// its SLAAC resources and notify the integrator.
if r.configType == slaac {
- n.mu.ndp.cleanupAutoGenAddrResourcesAndNotify(addr)
+ n.mu.ndp.cleanupSLAACAddrResourcesAndNotify(tcpip.AddressWithPrefix{
+ Address: addr,
+ PrefixLen: r.ep.PrefixLen(),
+ })
}
}
@@ -1020,9 +1041,12 @@ func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error {
// If we are removing an IPv6 unicast address, leave the solicited-node
// multicast address.
+ //
+ // We ignore the tcpip.ErrBadLocalAddress error because the solicited-node
+ // multicast group may be left by user action.
if isIPv6Unicast {
snmc := header.SolicitedNodeAddr(addr)
- if err := n.leaveGroupLocked(snmc); err != nil {
+ if err := n.leaveGroupLocked(snmc, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress {
return err
}
}
@@ -1082,26 +1106,31 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error {
n.mu.Lock()
defer n.mu.Unlock()
- return n.leaveGroupLocked(addr)
+ return n.leaveGroupLocked(addr, false /* force */)
}
// leaveGroupLocked decrements the count for the given multicast address, and
// when it reaches zero removes the endpoint for this address. n MUST be locked
// before leaveGroupLocked is called.
-func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error {
+//
+// If force is true, then the count for the multicast addres is ignored and the
+// endpoint will be removed immediately.
+func (n *NIC) leaveGroupLocked(addr tcpip.Address, force bool) *tcpip.Error {
id := NetworkEndpointID{addr}
- joins := n.mu.mcastJoins[id]
- switch joins {
- case 0:
+ joins, ok := n.mu.mcastJoins[id]
+ if !ok {
// There are no joins with this address on this NIC.
return tcpip.ErrBadLocalAddress
- case 1:
- // This is the last one, clean up.
- if err := n.removePermanentAddressLocked(addr); err != nil {
- return err
- }
}
- n.mu.mcastJoins[id] = joins - 1
+
+ joins--
+ if force || joins == 0 {
+ // There are no outstanding joins or we are forced to leave, clean up.
+ delete(n.mu.mcastJoins, id)
+ return n.removePermanentAddressLocked(addr)
+ }
+
+ n.mu.mcastJoins[id] = joins
return nil
}
@@ -1114,7 +1143,7 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool {
return joins != 0
}
-func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt tcpip.PacketBuffer) {
+func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt PacketBuffer) {
r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */)
r.RemoteLinkAddress = remotelinkAddr
@@ -1128,7 +1157,7 @@ func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address,
// Note that the ownership of the slice backing vv is retained by the caller.
// This rule applies only to the slice itself, not to the items of the slice;
// the ownership of the items is not retained by the caller.
-func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
+func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) {
n.mu.RLock()
enabled := n.mu.enabled
// If the NIC is not yet enabled, don't receive any packets.
@@ -1192,7 +1221,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
// TODO(gvisor.dev/issue/170): Not supporting iptables for IPv6 yet.
if protocol == header.IPv4ProtocolNumber {
ipt := n.stack.IPTables()
- if ok := ipt.Check(iptables.Prerouting, pkt); !ok {
+ if ok := ipt.Check(Prerouting, pkt); !ok {
// iptables is telling us to drop the packet.
return
}
@@ -1257,11 +1286,26 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
}
}
-func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
+func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) {
// TODO(b/143425874) Decrease the TTL field in forwarded packets.
- pkt.Header = buffer.NewPrependableFromView(pkt.Data.First())
+
+ firstData := pkt.Data.First()
pkt.Data.RemoveFirst()
+ if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen == 0 {
+ pkt.Header = buffer.NewPrependableFromView(firstData)
+ } else {
+ firstDataLen := len(firstData)
+
+ // pkt.Header should have enough capacity to hold n.linkEP's headers.
+ pkt.Header = buffer.NewPrependable(firstDataLen + linkHeaderLen)
+
+ // TODO(b/151227689): avoid copying the packet when forwarding
+ if n := copy(pkt.Header.Prepend(firstDataLen), firstData); n != firstDataLen {
+ panic(fmt.Sprintf("copied %d bytes, expected %d", n, firstDataLen))
+ }
+ }
+
if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
return
@@ -1273,7 +1317,7 @@ func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt
// DeliverTransportPacket delivers the packets to the appropriate transport
// protocol endpoint.
-func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) {
+func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer) {
state, ok := n.stack.transportProtocols[protocol]
if !ok {
n.stack.stats.UnknownProtocolRcvdPackets.Increment()
@@ -1319,7 +1363,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
// DeliverTransportControlPacket delivers control packets to the appropriate
// transport protocol endpoint.
-func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) {
+func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer) {
state, ok := n.stack.transportProtocols[trans]
if !ok {
return
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index edaee3b86..d672fc157 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -17,7 +17,6 @@ package stack
import (
"testing"
- "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
)
@@ -45,7 +44,7 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) {
t.FailNow()
}
- nic.DeliverNetworkPacket(nil, "", "", 0, tcpip.PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()})
+ nic.DeliverNetworkPacket(nil, "", "", 0, PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()})
if got := nic.stats.DisabledRx.Packets.Value(); got != 1 {
t.Errorf("got DisabledRx.Packets = %d, want = 1", got)
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
new file mode 100644
index 000000000..9505a4e92
--- /dev/null
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -0,0 +1,71 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at //
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import "gvisor.dev/gvisor/pkg/tcpip/buffer"
+
+// A PacketBuffer contains all the data of a network packet.
+//
+// As a PacketBuffer traverses up the stack, it may be necessary to pass it to
+// multiple endpoints. Clone() should be called in such cases so that
+// modifications to the Data field do not affect other copies.
+//
+// +stateify savable
+type PacketBuffer struct {
+ // Data holds the payload of the packet. For inbound packets, it also
+ // holds the headers, which are consumed as the packet moves up the
+ // stack. Headers are guaranteed not to be split across views.
+ //
+ // The bytes backing Data are immutable, but Data itself may be trimmed
+ // or otherwise modified.
+ Data buffer.VectorisedView
+
+ // DataOffset is used for GSO output. It is the offset into the Data
+ // field where the payload of this packet starts.
+ DataOffset int
+
+ // DataSize is used for GSO output. It is the size of this packet's
+ // payload.
+ DataSize int
+
+ // Header holds the headers of outbound packets. As a packet is passed
+ // down the stack, each layer adds to Header.
+ Header buffer.Prependable
+
+ // These fields are used by both inbound and outbound packets. They
+ // typically overlap with the Data and Header fields.
+ //
+ // The bytes backing these views are immutable. Each field may be nil
+ // if either it has not been set yet or no such header exists (e.g.
+ // packets sent via loopback may not have a link header).
+ //
+ // These fields may be Views into other slices (either Data or Header).
+ // SR dosen't support this, so deep copies are necessary in some cases.
+ LinkHeader buffer.View
+ NetworkHeader buffer.View
+ TransportHeader buffer.View
+
+ // Hash is the transport layer hash of this packet. A value of zero
+ // indicates no valid hash has been set.
+ Hash uint32
+}
+
+// Clone makes a copy of pk. It clones the Data field, which creates a new
+// VectorisedView but does not deep copy the underlying bytes.
+//
+// Clone also does not deep copy any of its other fields.
+func (pk PacketBuffer) Clone() PacketBuffer {
+ pk.Data = pk.Data.Clone(nil)
+ return pk
+}
diff --git a/pkg/tcpip/stack/packet_buffer_state.go b/pkg/tcpip/stack/packet_buffer_state.go
new file mode 100644
index 000000000..0c6b7924c
--- /dev/null
+++ b/pkg/tcpip/stack/packet_buffer_state.go
@@ -0,0 +1,27 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import "gvisor.dev/gvisor/pkg/tcpip/buffer"
+
+// beforeSave is invoked by stateify.
+func (pk *PacketBuffer) beforeSave() {
+ // Non-Data fields may be slices of the Data field. This causes
+ // problems for SR, so during save we make each header independent.
+ pk.Header = pk.Header.DeepCopy()
+ pk.LinkHeader = append(buffer.View(nil), pk.LinkHeader...)
+ pk.NetworkHeader = append(buffer.View(nil), pk.NetworkHeader...)
+ pk.TransportHeader = append(buffer.View(nil), pk.TransportHeader...)
+}
diff --git a/pkg/tcpip/stack/rand.go b/pkg/tcpip/stack/rand.go
new file mode 100644
index 000000000..421fb5c15
--- /dev/null
+++ b/pkg/tcpip/stack/rand.go
@@ -0,0 +1,40 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ mathrand "math/rand"
+
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// lockedRandomSource provides a threadsafe rand.Source.
+type lockedRandomSource struct {
+ mu sync.Mutex
+ src mathrand.Source
+}
+
+func (r *lockedRandomSource) Int63() (n int64) {
+ r.mu.Lock()
+ n = r.src.Int63()
+ r.mu.Unlock()
+ return n
+}
+
+func (r *lockedRandomSource) Seed(seed int64) {
+ r.mu.Lock()
+ r.src.Seed(seed)
+ r.mu.Unlock()
+}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index f9fd8f18f..ac043b722 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -67,12 +67,12 @@ type TransportEndpoint interface {
// this transport endpoint. It sets pkt.TransportHeader.
//
// HandlePacket takes ownership of pkt.
- HandlePacket(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer)
+ HandlePacket(r *Route, id TransportEndpointID, pkt PacketBuffer)
// HandleControlPacket is called by the stack when new control (e.g.
// ICMP) packets arrive to this transport endpoint.
// HandleControlPacket takes ownership of pkt.
- HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt tcpip.PacketBuffer)
+ HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt PacketBuffer)
// Abort initiates an expedited endpoint teardown. It puts the endpoint
// in a closed state and frees all resources associated with it. This
@@ -100,7 +100,7 @@ type RawTransportEndpoint interface {
// layer up.
//
// HandlePacket takes ownership of pkt.
- HandlePacket(r *Route, pkt tcpip.PacketBuffer)
+ HandlePacket(r *Route, pkt PacketBuffer)
}
// PacketEndpoint is the interface that needs to be implemented by packet
@@ -118,7 +118,7 @@ type PacketEndpoint interface {
// should construct its own ethernet header for applications.
//
// HandlePacket takes ownership of pkt.
- HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer)
+ HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt PacketBuffer)
}
// TransportProtocol is the interface that needs to be implemented by transport
@@ -150,7 +150,7 @@ type TransportProtocol interface {
// stats purposes only).
//
// HandleUnknownDestinationPacket takes ownership of pkt.
- HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) bool
+ HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt PacketBuffer) bool
// SetOption allows enabling/disabling protocol specific features.
// SetOption returns an error if the option is not supported or the
@@ -180,7 +180,7 @@ type TransportDispatcher interface {
// pkt.NetworkHeader must be set before calling DeliverTransportPacket.
//
// DeliverTransportPacket takes ownership of pkt.
- DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer)
+ DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer)
// DeliverTransportControlPacket delivers control packets to the
// appropriate transport protocol endpoint.
@@ -189,7 +189,7 @@ type TransportDispatcher interface {
// DeliverTransportControlPacket.
//
// DeliverTransportControlPacket takes ownership of pkt.
- DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt tcpip.PacketBuffer)
+ DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer)
}
// PacketLooping specifies where an outbound packet should be sent.
@@ -242,15 +242,15 @@ type NetworkEndpoint interface {
// WritePacket writes a packet to the given destination address and
// protocol. It sets pkt.NetworkHeader. pkt.TransportHeader must have
// already been set.
- WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error
+ WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error
// WritePackets writes packets to the given destination address and
// protocol. pkts must not be zero length.
- WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error)
+ WritePackets(r *Route, gso *GSO, pkts []PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error)
// WriteHeaderIncludedPacket writes a packet that includes a network
// header to the given destination address.
- WriteHeaderIncludedPacket(r *Route, pkt tcpip.PacketBuffer) *tcpip.Error
+ WriteHeaderIncludedPacket(r *Route, pkt PacketBuffer) *tcpip.Error
// ID returns the network protocol endpoint ID.
ID() *NetworkEndpointID
@@ -265,7 +265,7 @@ type NetworkEndpoint interface {
// this network endpoint. It sets pkt.NetworkHeader.
//
// HandlePacket takes ownership of pkt.
- HandlePacket(r *Route, pkt tcpip.PacketBuffer)
+ HandlePacket(r *Route, pkt PacketBuffer)
// Close is called when the endpoint is reomved from a stack.
Close()
@@ -322,7 +322,7 @@ type NetworkDispatcher interface {
// packets sent via loopback), and won't have the field set.
//
// DeliverNetworkPacket takes ownership of pkt.
- DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer)
+ DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer)
}
// LinkEndpointCapabilities is the type associated with the capabilities
@@ -354,7 +354,7 @@ const (
// LinkEndpoint is the interface implemented by data link layer protocols (e.g.,
// ethernet, loopback, raw) and used by network layer protocols to send packets
// out through the implementer's data link endpoint. When a link header exists,
-// it sets each tcpip.PacketBuffer's LinkHeader field before passing it up the
+// it sets each PacketBuffer's LinkHeader field before passing it up the
// stack.
type LinkEndpoint interface {
// MTU is the maximum transmission unit for this endpoint. This is
@@ -385,7 +385,7 @@ type LinkEndpoint interface {
// To participate in transparent bridging, a LinkEndpoint implementation
// should call eth.Encode with header.EthernetFields.SrcAddr set to
// r.LocalLinkAddress if it is provided.
- WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error
+ WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer) *tcpip.Error
// WritePackets writes packets with the given protocol through the
// given route. pkts must not be zero length.
@@ -393,7 +393,7 @@ type LinkEndpoint interface {
// Right now, WritePackets is used only when the software segmentation
// offload is enabled. If it will be used for something else, it may
// require to change syscall filters.
- WritePackets(r *Route, gso *GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error)
+ WritePackets(r *Route, gso *GSO, pkts []PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error)
// WriteRawPacket writes a packet directly to the link. The packet
// should already have an ethernet header.
@@ -401,6 +401,9 @@ type LinkEndpoint interface {
// Attach attaches the data link layer endpoint to the network-layer
// dispatcher of the stack.
+ //
+ // Attach will be called with a nil dispatcher if the receiver's associated
+ // NIC is being removed.
Attach(dispatcher NetworkDispatcher)
// IsAttached returns whether a NetworkDispatcher is attached to the
@@ -423,7 +426,7 @@ type InjectableLinkEndpoint interface {
LinkEndpoint
// InjectInbound injects an inbound packet.
- InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer)
+ InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt PacketBuffer)
// InjectOutbound writes a fully formed outbound packet directly to the
// link.
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index f565aafb2..9fbe8a411 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -153,7 +153,7 @@ func (r *Route) IsResolutionRequired() bool {
}
// WritePacket writes the packet through the given route.
-func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt PacketBuffer) *tcpip.Error {
if !r.ref.isValidForOutgoing() {
return tcpip.ErrInvalidEndpointState
}
@@ -169,7 +169,7 @@ func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt tcpip.Pack
}
// WritePackets writes the set of packets through the given route.
-func (r *Route) WritePackets(gso *GSO, pkts []tcpip.PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) {
+func (r *Route) WritePackets(gso *GSO, pkts []PacketBuffer, params NetworkHeaderParams) (int, *tcpip.Error) {
if !r.ref.isValidForOutgoing() {
return 0, tcpip.ErrInvalidEndpointState
}
@@ -190,7 +190,7 @@ func (r *Route) WritePackets(gso *GSO, pkts []tcpip.PacketBuffer, params Network
// WriteHeaderIncludedPacket writes a packet already containing a network
// header through the given route.
-func (r *Route) WriteHeaderIncludedPacket(pkt tcpip.PacketBuffer) *tcpip.Error {
+func (r *Route) WriteHeaderIncludedPacket(pkt PacketBuffer) *tcpip.Error {
if !r.ref.isValidForOutgoing() {
return tcpip.ErrInvalidEndpointState
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 6f423874a..41398a1b6 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -20,7 +20,9 @@
package stack
import (
+ "bytes"
"encoding/binary"
+ mathrand "math/rand"
"sync/atomic"
"time"
@@ -31,7 +33,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/waiter"
@@ -51,7 +52,7 @@ const (
type transportProtocolState struct {
proto TransportProtocol
- defaultHandler func(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) bool
+ defaultHandler func(r *Route, id TransportEndpointID, pkt PacketBuffer) bool
}
// TCPProbeFunc is the expected function type for a TCP probe function to be
@@ -428,7 +429,7 @@ type Stack struct {
// tables are the iptables packet filtering and manipulation rules. The are
// protected by tablesMu.`
- tables iptables.IPTables
+ tables IPTables
// resumableEndpoints is a list of endpoints that need to be resumed if the
// stack is being restored.
@@ -466,6 +467,10 @@ type Stack struct {
// forwarder holds the packets that wait for their link-address resolutions
// to complete, and forwards them when each resolution is done.
forwarder *forwardQueue
+
+ // randomGenerator is an injectable pseudo random generator that can be
+ // used when a random number is required.
+ randomGenerator *mathrand.Rand
}
// UniqueID is an abstract generator of unique identifiers.
@@ -526,9 +531,16 @@ type Options struct {
// this is non-nil.
RawFactory RawFactory
- // OpaqueIIDOpts hold the options for generating opaque interface identifiers
- // (IIDs) as outlined by RFC 7217.
+ // OpaqueIIDOpts hold the options for generating opaque interface
+ // identifiers (IIDs) as outlined by RFC 7217.
OpaqueIIDOpts OpaqueInterfaceIdentifierOptions
+
+ // RandSource is an optional source to use to generate random
+ // numbers. If omitted it defaults to a Source seeded by the data
+ // returned by rand.Read().
+ //
+ // RandSource must be thread-safe.
+ RandSource mathrand.Source
}
// TransportEndpointInfo holds useful information about a transport endpoint
@@ -624,6 +636,13 @@ func New(opts Options) *Stack {
opts.UniqueID = new(uniqueIDGenerator)
}
+ randSrc := opts.RandSource
+ if randSrc == nil {
+ // Source provided by mathrand.NewSource is not thread-safe so
+ // we wrap it in a simple thread-safe version.
+ randSrc = &lockedRandomSource{src: mathrand.NewSource(generateRandInt64())}
+ }
+
// Make sure opts.NDPConfigs contains valid values only.
opts.NDPConfigs.validate()
@@ -646,6 +665,7 @@ func New(opts Options) *Stack {
ndpDisp: opts.NDPDisp,
opaqueIIDOpts: opts.OpaqueIIDOpts,
forwarder: newForwardQueue(),
+ randomGenerator: mathrand.New(randSrc),
}
// Add specified network protocols.
@@ -738,7 +758,7 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber,
//
// It must be called only during initialization of the stack. Changing it as the
// stack is operating is not supported.
-func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, tcpip.PacketBuffer) bool) {
+func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, PacketBuffer) bool) {
state := s.transportProtocols[p]
if state != nil {
state.defaultHandler = h
@@ -1701,7 +1721,7 @@ func (s *Stack) IsInGroup(nicID tcpip.NICID, multicastAddr tcpip.Address) (bool,
}
// IPTables returns the stack's iptables.
-func (s *Stack) IPTables() iptables.IPTables {
+func (s *Stack) IPTables() IPTables {
s.tablesMu.RLock()
t := s.tables
s.tablesMu.RUnlock()
@@ -1709,7 +1729,7 @@ func (s *Stack) IPTables() iptables.IPTables {
}
// SetIPTables sets the stack's iptables.
-func (s *Stack) SetIPTables(ipt iptables.IPTables) {
+func (s *Stack) SetIPTables(ipt IPTables) {
s.tablesMu.Lock()
s.tables = ipt
s.tablesMu.Unlock()
@@ -1819,6 +1839,12 @@ func (s *Stack) Seed() uint32 {
return s.seed
}
+// Rand returns a reference to a pseudo random generator that can be used
+// to generate random numbers as required.
+func (s *Stack) Rand() *mathrand.Rand {
+ return s.randomGenerator
+}
+
func generateRandUint32() uint32 {
b := make([]byte, 4)
if _, err := rand.Read(b); err != nil {
@@ -1826,3 +1852,16 @@ func generateRandUint32() uint32 {
}
return binary.LittleEndian.Uint32(b)
}
+
+func generateRandInt64() int64 {
+ b := make([]byte, 8)
+ if _, err := rand.Read(b); err != nil {
+ panic(err)
+ }
+ buf := bytes.NewReader(b)
+ var v int64
+ if err := binary.Read(buf, binary.LittleEndian, &v); err != nil {
+ panic(err)
+ }
+ return v
+}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index e15db40fb..555fcd92f 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -90,7 +90,7 @@ func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID {
return &f.id
}
-func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) {
+func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) {
// Increment the received packet count in the protocol descriptor.
f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++
@@ -126,7 +126,7 @@ func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
return f.ep.Capabilities()
}
-func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt stack.PacketBuffer) *tcpip.Error {
// Increment the sent packet count in the protocol descriptor.
f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++
@@ -141,7 +141,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params
views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
views[0] = pkt.Header.View()
views = append(views, pkt.Data.Views()...)
- f.HandlePacket(r, tcpip.PacketBuffer{
+ f.HandlePacket(r, stack.PacketBuffer{
Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views),
})
}
@@ -153,11 +153,11 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
+func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []stack.PacketBuffer, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
panic("not implemented")
}
-func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt tcpip.PacketBuffer) *tcpip.Error {
+func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt stack.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
@@ -255,7 +255,7 @@ type linkEPWithMockedAttach struct {
// Attach implements stack.LinkEndpoint.Attach.
func (l *linkEPWithMockedAttach) Attach(d stack.NetworkDispatcher) {
l.LinkEndpoint.Attach(d)
- l.attached = true
+ l.attached = d != nil
}
func (l *linkEPWithMockedAttach) isAttached() bool {
@@ -287,7 +287,7 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet with wrong address is not delivered.
buf[0] = 3
- ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeNet.packetCount[1] != 0 {
@@ -299,7 +299,7 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet is delivered to first endpoint.
buf[0] = 1
- ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeNet.packetCount[1] != 1 {
@@ -311,7 +311,7 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet is delivered to second endpoint.
buf[0] = 2
- ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeNet.packetCount[1] != 1 {
@@ -322,7 +322,7 @@ func TestNetworkReceive(t *testing.T) {
}
// Make sure packet is not delivered if protocol number is wrong.
- ep.InjectInbound(fakeNetNumber-1, tcpip.PacketBuffer{
+ ep.InjectInbound(fakeNetNumber-1, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeNet.packetCount[1] != 1 {
@@ -334,7 +334,7 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet that is too small is dropped.
buf.CapLength(2)
- ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeNet.packetCount[1] != 1 {
@@ -356,7 +356,7 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro
func send(r stack.Route, payload buffer.View) *tcpip.Error {
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()))
- return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{
Header: hdr,
Data: payload.ToVectorisedView(),
})
@@ -414,7 +414,7 @@ func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte b
func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) {
t.Helper()
- ep.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ ep.InjectInbound(fakeNetNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if got := fakeNet.PacketCount(localAddrByte); got != want {
@@ -566,7 +566,7 @@ func TestAttachToLinkEndpointImmediately(t *testing.T) {
t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, test.nicOpts, err)
}
if !e.isAttached() {
- t.Fatalf("link endpoint not attached to a network disatcher")
+ t.Fatal("link endpoint not attached to a network dispatcher")
}
})
}
@@ -631,196 +631,240 @@ func TestDisabledNICsNICInfoAndCheckNIC(t *testing.T) {
checkNIC(false)
}
-func TestRoutesWithDisabledNIC(t *testing.T) {
- const unspecifiedNIC = 0
- const nicID1 = 1
- const nicID2 = 2
-
+func TestRemoveUnknownNIC(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
})
- ep1 := channel.New(0, defaultMTU, "")
- if err := s.CreateNIC(nicID1, ep1); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ if err := s.RemoveNIC(1); err != tcpip.ErrUnknownNICID {
+ t.Fatalf("got s.RemoveNIC(1) = %v, want = %s", err, tcpip.ErrUnknownNICID)
}
+}
- addr1 := tcpip.Address("\x01")
- if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err)
- }
+func TestRemoveNIC(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
- ep2 := channel.New(0, defaultMTU, "")
- if err := s.CreateNIC(nicID2, ep2); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
+ e := linkEPWithMockedAttach{
+ LinkEndpoint: loopback.New(),
+ }
+ if err := s.CreateNIC(nicID, &e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- addr2 := tcpip.Address("\x02")
- if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err)
+ // NIC should be present in NICInfo and attached to a NetworkDispatcher.
+ allNICInfo := s.NICInfo()
+ if _, ok := allNICInfo[nicID]; !ok {
+ t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo)
+ }
+ if !e.isAttached() {
+ t.Fatal("link endpoint not attached to a network dispatcher")
}
- // Set a route table that sends all packets with odd destination
- // addresses through the first NIC, and all even destination address
- // through the second one.
- {
- subnet0, err := tcpip.NewSubnet("\x00", "\x01")
- if err != nil {
- t.Fatal(err)
- }
- subnet1, err := tcpip.NewSubnet("\x01", "\x01")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{
- {Destination: subnet1, Gateway: "\x00", NIC: nicID1},
- {Destination: subnet0, Gateway: "\x00", NIC: nicID2},
- })
+ // Removing a NIC should remove it from NICInfo and e should be detached from
+ // the NetworkDispatcher.
+ if err := s.RemoveNIC(nicID); err != nil {
+ t.Fatalf("s.RemoveNIC(%d): %s", nicID, err)
+ }
+ if nicInfo, ok := s.NICInfo()[nicID]; ok {
+ t.Errorf("got unexpected NICInfo entry for deleted NIC %d = %+v", nicID, nicInfo)
}
+ if e.isAttached() {
+ t.Error("link endpoint for removed NIC still attached to a network dispatcher")
+ }
+}
- // Test routes to odd address.
- testRoute(t, s, unspecifiedNIC, "", "\x05", addr1)
- testRoute(t, s, unspecifiedNIC, addr1, "\x05", addr1)
- testRoute(t, s, nicID1, addr1, "\x05", addr1)
+func TestRouteWithDownNIC(t *testing.T) {
+ tests := []struct {
+ name string
+ downFn func(s *stack.Stack, nicID tcpip.NICID) *tcpip.Error
+ upFn func(s *stack.Stack, nicID tcpip.NICID) *tcpip.Error
+ }{
+ {
+ name: "Disabled NIC",
+ downFn: (*stack.Stack).DisableNIC,
+ upFn: (*stack.Stack).EnableNIC,
+ },
+
+ // Once a NIC is removed, it cannot be brought up.
+ {
+ name: "Removed NIC",
+ downFn: (*stack.Stack).RemoveNIC,
+ },
+ }
- // Test routes to even address.
- testRoute(t, s, unspecifiedNIC, "", "\x06", addr2)
- testRoute(t, s, unspecifiedNIC, addr2, "\x06", addr2)
- testRoute(t, s, nicID2, addr2, "\x06", addr2)
-
- // Disabling NIC1 should result in no routes to odd addresses. Routes to even
- // addresses should continue to be available as NIC2 is still enabled.
- if err := s.DisableNIC(nicID1); err != nil {
- t.Fatalf("s.DisableNIC(%d): %s", nicID1, err)
- }
- nic1Dst := tcpip.Address("\x05")
- testNoRoute(t, s, unspecifiedNIC, "", nic1Dst)
- testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst)
- testNoRoute(t, s, nicID1, addr1, nic1Dst)
- nic2Dst := tcpip.Address("\x06")
- testRoute(t, s, unspecifiedNIC, "", nic2Dst, addr2)
- testRoute(t, s, unspecifiedNIC, addr2, nic2Dst, addr2)
- testRoute(t, s, nicID2, addr2, nic2Dst, addr2)
-
- // Disabling NIC2 should result in no routes to even addresses. No route
- // should be available to any address as routes to odd addresses were made
- // unavailable by disabling NIC1 above.
- if err := s.DisableNIC(nicID2); err != nil {
- t.Fatalf("s.DisableNIC(%d): %s", nicID2, err)
- }
- testNoRoute(t, s, unspecifiedNIC, "", nic1Dst)
- testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst)
- testNoRoute(t, s, nicID1, addr1, nic1Dst)
- testNoRoute(t, s, unspecifiedNIC, "", nic2Dst)
- testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst)
- testNoRoute(t, s, nicID2, addr2, nic2Dst)
-
- // Enabling NIC1 should make routes to odd addresses available again. Routes
- // to even addresses should continue to be unavailable as NIC2 is still
- // disabled.
- if err := s.EnableNIC(nicID1); err != nil {
- t.Fatalf("s.EnableNIC(%d): %s", nicID1, err)
- }
- testRoute(t, s, unspecifiedNIC, "", nic1Dst, addr1)
- testRoute(t, s, unspecifiedNIC, addr1, nic1Dst, addr1)
- testRoute(t, s, nicID1, addr1, nic1Dst, addr1)
- testNoRoute(t, s, unspecifiedNIC, "", nic2Dst)
- testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst)
- testNoRoute(t, s, nicID2, addr2, nic2Dst)
-}
-
-func TestRouteWritePacketWithDisabledNIC(t *testing.T) {
const unspecifiedNIC = 0
const nicID1 = 1
const nicID2 = 2
+ const addr1 = tcpip.Address("\x01")
+ const addr2 = tcpip.Address("\x02")
+ const nic1Dst = tcpip.Address("\x05")
+ const nic2Dst = tcpip.Address("\x06")
+
+ setup := func(t *testing.T) (*stack.Stack, *channel.Endpoint, *channel.Endpoint) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
+ ep1 := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicID1, ep1); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
- ep1 := channel.New(1, defaultMTU, "")
- if err := s.CreateNIC(nicID1, ep1); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
- }
+ if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err)
+ }
- addr1 := tcpip.Address("\x01")
- if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err)
- }
+ ep2 := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicID2, ep2); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
+ }
- ep2 := channel.New(1, defaultMTU, "")
- if err := s.CreateNIC(nicID2, ep2); err != nil {
- t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
- }
+ if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err)
+ }
- addr2 := tcpip.Address("\x02")
- if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err)
+ // Set a route table that sends all packets with odd destination
+ // addresses through the first NIC, and all even destination address
+ // through the second one.
+ {
+ subnet0, err := tcpip.NewSubnet("\x00", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ subnet1, err := tcpip.NewSubnet("\x01", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {Destination: subnet1, Gateway: "\x00", NIC: nicID1},
+ {Destination: subnet0, Gateway: "\x00", NIC: nicID2},
+ })
+ }
+
+ return s, ep1, ep2
}
- // Set a route table that sends all packets with odd destination
- // addresses through the first NIC, and all even destination address
- // through the second one.
- {
- subnet0, err := tcpip.NewSubnet("\x00", "\x01")
- if err != nil {
- t.Fatal(err)
- }
- subnet1, err := tcpip.NewSubnet("\x01", "\x01")
- if err != nil {
- t.Fatal(err)
+ // Tests that routes through a down NIC are not used when looking up a route
+ // for a destination.
+ t.Run("Find", func(t *testing.T) {
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s, _, _ := setup(t)
+
+ // Test routes to odd address.
+ testRoute(t, s, unspecifiedNIC, "", "\x05", addr1)
+ testRoute(t, s, unspecifiedNIC, addr1, "\x05", addr1)
+ testRoute(t, s, nicID1, addr1, "\x05", addr1)
+
+ // Test routes to even address.
+ testRoute(t, s, unspecifiedNIC, "", "\x06", addr2)
+ testRoute(t, s, unspecifiedNIC, addr2, "\x06", addr2)
+ testRoute(t, s, nicID2, addr2, "\x06", addr2)
+
+ // Bringing NIC1 down should result in no routes to odd addresses. Routes to
+ // even addresses should continue to be available as NIC2 is still up.
+ if err := test.downFn(s, nicID1); err != nil {
+ t.Fatalf("test.downFn(_, %d): %s", nicID1, err)
+ }
+ testNoRoute(t, s, unspecifiedNIC, "", nic1Dst)
+ testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst)
+ testNoRoute(t, s, nicID1, addr1, nic1Dst)
+ testRoute(t, s, unspecifiedNIC, "", nic2Dst, addr2)
+ testRoute(t, s, unspecifiedNIC, addr2, nic2Dst, addr2)
+ testRoute(t, s, nicID2, addr2, nic2Dst, addr2)
+
+ // Bringing NIC2 down should result in no routes to even addresses. No
+ // route should be available to any address as routes to odd addresses
+ // were made unavailable by bringing NIC1 down above.
+ if err := test.downFn(s, nicID2); err != nil {
+ t.Fatalf("test.downFn(_, %d): %s", nicID2, err)
+ }
+ testNoRoute(t, s, unspecifiedNIC, "", nic1Dst)
+ testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst)
+ testNoRoute(t, s, nicID1, addr1, nic1Dst)
+ testNoRoute(t, s, unspecifiedNIC, "", nic2Dst)
+ testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst)
+ testNoRoute(t, s, nicID2, addr2, nic2Dst)
+
+ if upFn := test.upFn; upFn != nil {
+ // Bringing NIC1 up should make routes to odd addresses available
+ // again. Routes to even addresses should continue to be unavailable
+ // as NIC2 is still down.
+ if err := upFn(s, nicID1); err != nil {
+ t.Fatalf("test.upFn(_, %d): %s", nicID1, err)
+ }
+ testRoute(t, s, unspecifiedNIC, "", nic1Dst, addr1)
+ testRoute(t, s, unspecifiedNIC, addr1, nic1Dst, addr1)
+ testRoute(t, s, nicID1, addr1, nic1Dst, addr1)
+ testNoRoute(t, s, unspecifiedNIC, "", nic2Dst)
+ testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst)
+ testNoRoute(t, s, nicID2, addr2, nic2Dst)
+ }
+ })
}
- s.SetRouteTable([]tcpip.Route{
- {Destination: subnet1, Gateway: "\x00", NIC: nicID1},
- {Destination: subnet0, Gateway: "\x00", NIC: nicID2},
- })
- }
+ })
- nic1Dst := tcpip.Address("\x05")
- r1, err := s.FindRoute(nicID1, addr1, nic1Dst, fakeNetNumber, false /* multicastLoop */)
- if err != nil {
- t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID1, addr1, nic1Dst, fakeNetNumber, err)
- }
- defer r1.Release()
+ // Tests that writing a packet using a Route through a down NIC fails.
+ t.Run("WritePacket", func(t *testing.T) {
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s, ep1, ep2 := setup(t)
- nic2Dst := tcpip.Address("\x06")
- r2, err := s.FindRoute(nicID2, addr2, nic2Dst, fakeNetNumber, false /* multicastLoop */)
- if err != nil {
- t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID2, addr2, nic2Dst, fakeNetNumber, err)
- }
- defer r2.Release()
+ r1, err := s.FindRoute(nicID1, addr1, nic1Dst, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID1, addr1, nic1Dst, fakeNetNumber, err)
+ }
+ defer r1.Release()
- // If we failed to get routes r1 or r2, we cannot proceed with the test.
- if t.Failed() {
- t.FailNow()
- }
+ r2, err := s.FindRoute(nicID2, addr2, nic2Dst, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID2, addr2, nic2Dst, fakeNetNumber, err)
+ }
+ defer r2.Release()
- buf := buffer.View([]byte{1})
- testSend(t, r1, ep1, buf)
- testSend(t, r2, ep2, buf)
+ // If we failed to get routes r1 or r2, we cannot proceed with the test.
+ if t.Failed() {
+ t.FailNow()
+ }
- // Writes with Routes that use the disabled NIC1 should fail.
- if err := s.DisableNIC(nicID1); err != nil {
- t.Fatalf("s.DisableNIC(%d): %s", nicID1, err)
- }
- testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState)
- testSend(t, r2, ep2, buf)
+ buf := buffer.View([]byte{1})
+ testSend(t, r1, ep1, buf)
+ testSend(t, r2, ep2, buf)
- // Writes with Routes that use the disabled NIC2 should fail.
- if err := s.DisableNIC(nicID2); err != nil {
- t.Fatalf("s.DisableNIC(%d): %s", nicID2, err)
- }
- testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState)
- testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState)
+ // Writes with Routes that use NIC1 after being brought down should fail.
+ if err := test.downFn(s, nicID1); err != nil {
+ t.Fatalf("test.downFn(_, %d): %s", nicID1, err)
+ }
+ testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState)
+ testSend(t, r2, ep2, buf)
- // Writes with Routes that use the re-enabled NIC1 should succeed.
- // TODO(b/147015577): Should we instead completely invalidate all Routes that
- // were bound to a disabled NIC at some point?
- if err := s.EnableNIC(nicID1); err != nil {
- t.Fatalf("s.EnableNIC(%d): %s", nicID1, err)
- }
- testSend(t, r1, ep1, buf)
- testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState)
+ // Writes with Routes that use NIC2 after being brought down should fail.
+ if err := test.downFn(s, nicID2); err != nil {
+ t.Fatalf("test.downFn(_, %d): %s", nicID2, err)
+ }
+ testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState)
+ testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState)
+
+ if upFn := test.upFn; upFn != nil {
+ // Writes with Routes that use NIC1 after being brought up should
+ // succeed.
+ //
+ // TODO(b/147015577): Should we instead completely invalidate all
+ // Routes that were bound to a NIC that was brought down at some
+ // point?
+ if err := upFn(s, nicID1); err != nil {
+ t.Fatalf("test.upFn(_, %d): %s", nicID1, err)
+ }
+ testSend(t, r1, ep1, buf)
+ testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState)
+ }
+ })
+ }
+ })
}
func TestRoutes(t *testing.T) {
@@ -2213,7 +2257,7 @@ func TestNICStats(t *testing.T) {
// Send a packet to address 1.
buf := buffer.NewView(30)
- ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ ep1.InjectInbound(fakeNetNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want {
@@ -2240,56 +2284,84 @@ func TestNICStats(t *testing.T) {
}
func TestNICForwarding(t *testing.T) {
- // Create a stack with the fake network protocol, two NICs, each with
- // an address.
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
- s.SetForwarding(true)
+ const nicID1 = 1
+ const nicID2 = 2
+ const dstAddr = tcpip.Address("\x03")
- ep1 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep1); err != nil {
- t.Fatal("CreateNIC #1 failed:", err)
- }
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress #1 failed:", err)
+ tests := []struct {
+ name string
+ headerLen uint16
+ }{
+ {
+ name: "Zero header length",
+ },
+ {
+ name: "Non-zero header length",
+ headerLen: 16,
+ },
}
- ep2 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(2, ep2); err != nil {
- t.Fatal("CreateNIC #2 failed:", err)
- }
- if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
- t.Fatal("AddAddress #2 failed:", err)
- }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ s.SetForwarding(true)
- // Route all packets to address 3 to NIC 2.
- {
- subnet, err := tcpip.NewSubnet("\x03", "\xff")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 2}})
- }
+ ep1 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicID1, ep1); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
+ if err := s.AddAddress(nicID1, fakeNetNumber, "\x01"); err != nil {
+ t.Fatalf("AddAddress(%d, %d, 0x01): %s", nicID1, fakeNetNumber, err)
+ }
- // Send a packet to address 3.
- buf := buffer.NewView(30)
- buf[0] = 3
- ep1.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
- Data: buf.ToVectorisedView(),
- })
+ ep2 := channelLinkWithHeaderLength{
+ Endpoint: channel.New(10, defaultMTU, ""),
+ headerLength: test.headerLen,
+ }
+ if err := s.CreateNIC(nicID2, &ep2); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
+ }
+ if err := s.AddAddress(nicID2, fakeNetNumber, "\x02"); err != nil {
+ t.Fatalf("AddAddress(%d, %d, 0x02): %s", nicID2, fakeNetNumber, err)
+ }
- if _, ok := ep2.Read(); !ok {
- t.Fatal("Packet not forwarded")
- }
+ // Route all packets to dstAddr to NIC 2.
+ {
+ subnet, err := tcpip.NewSubnet(dstAddr, "\xff")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicID2}})
+ }
- // Test that forwarding increments Tx stats correctly.
- if got, want := s.NICInfo()[2].Stats.Tx.Packets.Value(), uint64(1); got != want {
- t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want)
- }
+ // Send a packet to dstAddr.
+ buf := buffer.NewView(30)
+ buf[0] = dstAddr[0]
+ ep1.InjectInbound(fakeNetNumber, stack.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
- if got, want := s.NICInfo()[2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want {
- t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
+ pkt, ok := ep2.Read()
+ if !ok {
+ t.Fatal("packet not forwarded")
+ }
+
+ // Test that the link's MaxHeaderLength is honoured.
+ if capacity, want := pkt.Pkt.Header.AvailableLength(), int(test.headerLen); capacity != want {
+ t.Errorf("got Header.AvailableLength() = %d, want = %d", capacity, want)
+ }
+
+ // Test that forwarding increments Tx stats correctly.
+ if got, want := s.NICInfo()[nicID2].Stats.Tx.Packets.Value(), uint64(1); got != want {
+ t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want)
+ }
+
+ if got, want := s.NICInfo()[nicID2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want {
+ t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
+ }
+ })
}
}
@@ -3010,6 +3082,50 @@ func TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable(t *testing.T) {
}
}
+// TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval tests that removing an IPv6
+// address after leaving its solicited node multicast address does not result in
+// an error.
+func TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ })
+ e := channel.New(10, 1280, linkAddr1)
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ if err := s.AddAddress(nicID, ipv6.ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, addr1, err)
+ }
+
+ // The NIC should have joined addr1's solicited node multicast address.
+ snmc := header.SolicitedNodeAddr(addr1)
+ in, err := s.IsInGroup(nicID, snmc)
+ if err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, snmc, err)
+ }
+ if !in {
+ t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, snmc)
+ }
+
+ if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, snmc); err != nil {
+ t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, snmc, err)
+ }
+ in, err = s.IsInGroup(nicID, snmc)
+ if err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, snmc, err)
+ }
+ if in {
+ t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, snmc)
+ }
+
+ if err := s.RemoveAddress(nicID, addr1); err != nil {
+ t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr1, err)
+ }
+}
+
func TestJoinLeaveAllNodesMulticastOnNICEnableDisable(t *testing.T) {
const nicID = 1
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index ff1845bfb..c55e3e8bc 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -85,7 +85,7 @@ func (epsByNic *endpointsByNic) transportEndpoints() []TransportEndpoint {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) {
+func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, pkt PacketBuffer) {
epsByNic.mu.RLock()
mpep, ok := epsByNic.endpoints[r.ref.nic.ID()]
@@ -116,7 +116,7 @@ func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, p
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) {
+func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt PacketBuffer) {
epsByNic.mu.RLock()
defer epsByNic.mu.RUnlock()
@@ -184,7 +184,7 @@ type transportDemuxer struct {
// the dispatcher to delivery packets to the QueuePacket method instead of
// calling HandlePacket directly on the endpoint.
type queuedTransportProtocol interface {
- QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt tcpip.PacketBuffer)
+ QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt PacketBuffer)
}
func newTransportDemuxer(stack *Stack) *transportDemuxer {
@@ -312,7 +312,7 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32
return mpep.endpoints[idx]
}
-func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt tcpip.PacketBuffer) {
+func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt PacketBuffer) {
ep.mu.RLock()
queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}]
// HandlePacket takes ownership of pkt, so each endpoint needs
@@ -403,73 +403,57 @@ func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolN
// deliverPacket attempts to find one or more matching transport endpoints, and
// then, if matches are found, delivers the packet to them. Returns true if
// the packet no longer needs to be handled.
-func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer, id TransportEndpointID) bool {
+func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer, id TransportEndpointID) bool {
eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
if !ok {
return false
}
- eps.mu.RLock()
-
- // Determine which transport endpoint or endpoints to deliver this packet to.
// If the packet is a UDP broadcast or multicast, then find all matching
- // transport endpoints. If the packet is a TCP packet with a non-unicast
- // source or destination address, then do nothing further and instruct
- // the caller to do the same.
- var destEps []*endpointsByNic
- switch protocol {
- case header.UDPProtocolNumber:
- if isMulticastOrBroadcast(id.LocalAddress) {
- destEps = d.findAllEndpointsLocked(eps, id)
- break
- }
-
- if ep := d.findEndpointLocked(eps, id); ep != nil {
- destEps = append(destEps, ep)
+ // transport endpoints.
+ if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) {
+ eps.mu.RLock()
+ destEPs := d.findAllEndpointsLocked(eps, id)
+ eps.mu.RUnlock()
+ // Fail if we didn't find at least one matching transport endpoint.
+ if len(destEPs) == 0 {
+ r.Stats().UDP.UnknownPortErrors.Increment()
+ return false
}
-
- case header.TCPProtocolNumber:
- if !(isUnicast(r.LocalAddress) && isUnicast(r.RemoteAddress)) {
- // TCP can only be used to communicate between a single
- // source and a single destination; the addresses must
- // be unicast.
- eps.mu.RUnlock()
- r.Stats().TCP.InvalidSegmentsReceived.Increment()
- return true
+ // handlePacket takes ownership of pkt, so each endpoint needs its own
+ // copy except for the final one.
+ for _, ep := range destEPs[:len(destEPs)-1] {
+ ep.handlePacket(r, id, pkt.Clone())
}
+ destEPs[len(destEPs)-1].handlePacket(r, id, pkt)
+ return true
+ }
- fallthrough
-
- default:
- if ep := d.findEndpointLocked(eps, id); ep != nil {
- destEps = append(destEps, ep)
- }
+ // If the packet is a TCP packet with a non-unicast source or destination
+ // address, then do nothing further and instruct the caller to do the same.
+ if protocol == header.TCPProtocolNumber && (!isUnicast(r.LocalAddress) || !isUnicast(r.RemoteAddress)) {
+ // TCP can only be used to communicate between a single source and a
+ // single destination; the addresses must be unicast.
+ r.Stats().TCP.InvalidSegmentsReceived.Increment()
+ return true
}
+ eps.mu.RLock()
+ ep := d.findEndpointLocked(eps, id)
eps.mu.RUnlock()
-
- // Fail if we didn't find at least one matching transport endpoint.
- if len(destEps) == 0 {
- // UDP packet could not be delivered to an unknown destination port.
+ if ep == nil {
if protocol == header.UDPProtocolNumber {
r.Stats().UDP.UnknownPortErrors.Increment()
}
return false
}
-
- // HandlePacket takes ownership of pkt, so each endpoint needs its own
- // copy except for the final one.
- for _, ep := range destEps[:len(destEps)-1] {
- ep.handlePacket(r, id, pkt.Clone())
- }
- destEps[len(destEps)-1].handlePacket(r, id, pkt)
-
+ ep.handlePacket(r, id, pkt)
return true
}
// deliverRawPacket attempts to deliver the given packet and returns whether it
// was delivered successfully.
-func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer) bool {
+func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt PacketBuffer) bool {
eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
if !ok {
return false
@@ -493,7 +477,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr
// deliverControlPacket attempts to deliver the given control packet. Returns
// true if it found an endpoint, false otherwise.
-func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt tcpip.PacketBuffer, id TransportEndpointID) bool {
+func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt PacketBuffer, id TransportEndpointID) bool {
eps, ok := d.protocol[protocolIDs{net, trans}]
if !ok {
return false
@@ -515,11 +499,17 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco
return true
}
-func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id TransportEndpointID) []*endpointsByNic {
- var matchedEPs []*endpointsByNic
+// iterEndpointsLocked yields all endpointsByNic in eps that match id, in
+// descending order of match quality. If a call to yield returns false,
+// iterEndpointsLocked stops iteration and returns immediately.
+//
+// Preconditions: eps.mu must be locked.
+func (d *transportDemuxer) iterEndpointsLocked(eps *transportEndpoints, id TransportEndpointID, yield func(*endpointsByNic) bool) {
// Try to find a match with the id as provided.
if ep, ok := eps.endpoints[id]; ok {
- matchedEPs = append(matchedEPs, ep)
+ if !yield(ep) {
+ return
+ }
}
// Try to find a match with the id minus the local address.
@@ -527,7 +517,9 @@ func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id Tr
nid.LocalAddress = ""
if ep, ok := eps.endpoints[nid]; ok {
- matchedEPs = append(matchedEPs, ep)
+ if !yield(ep) {
+ return
+ }
}
// Try to find a match with the id minus the remote part.
@@ -535,14 +527,26 @@ func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id Tr
nid.RemoteAddress = ""
nid.RemotePort = 0
if ep, ok := eps.endpoints[nid]; ok {
- matchedEPs = append(matchedEPs, ep)
+ if !yield(ep) {
+ return
+ }
}
// Try to find a match with only the local port.
nid.LocalAddress = ""
if ep, ok := eps.endpoints[nid]; ok {
- matchedEPs = append(matchedEPs, ep)
+ if !yield(ep) {
+ return
+ }
}
+}
+
+func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, id TransportEndpointID) []*endpointsByNic {
+ var matchedEPs []*endpointsByNic
+ d.iterEndpointsLocked(eps, id, func(ep *endpointsByNic) bool {
+ matchedEPs = append(matchedEPs, ep)
+ return true
+ })
return matchedEPs
}
@@ -580,10 +584,12 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN
// findEndpointLocked returns the endpoint that most closely matches the given
// id.
func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, id TransportEndpointID) *endpointsByNic {
- if matchedEPs := d.findAllEndpointsLocked(eps, id); len(matchedEPs) > 0 {
- return matchedEPs[0]
- }
- return nil
+ var matchedEP *endpointsByNic
+ d.iterEndpointsLocked(eps, id, func(ep *endpointsByNic) bool {
+ matchedEP = ep
+ return false
+ })
+ return matchedEP
}
// registerRawEndpoint registers the given endpoint with the dispatcher such
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index 0e3e239c5..84311bcc8 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -150,7 +150,7 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI
u.SetChecksum(^u.CalculateChecksum(xsum))
// Inject packet.
- c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{
+ c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 5d1da2f8b..8ca9ac3cf 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -19,7 +19,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -87,7 +86,7 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions
if err != nil {
return 0, nil, err
}
- if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{
+ if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketBuffer{
Header: hdr,
Data: buffer.View(v).ToVectorisedView(),
}); err != nil {
@@ -214,7 +213,7 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro
return tcpip.FullAddress{}, nil
}
-func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ tcpip.PacketBuffer) {
+func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ stack.PacketBuffer) {
// Increment the number of received packets.
f.proto.packetCount++
if f.acceptQueue != nil {
@@ -231,7 +230,7 @@ func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportE
}
}
-func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, tcpip.PacketBuffer) {
+func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, stack.PacketBuffer) {
// Increment the number of received control packets.
f.proto.controlCount++
}
@@ -242,8 +241,8 @@ func (f *fakeTransportEndpoint) State() uint32 {
func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) {}
-func (f *fakeTransportEndpoint) IPTables() (iptables.IPTables, error) {
- return iptables.IPTables{}, nil
+func (f *fakeTransportEndpoint) IPTables() (stack.IPTables, error) {
+ return stack.IPTables{}, nil
}
func (f *fakeTransportEndpoint) Resume(*stack.Stack) {}
@@ -288,7 +287,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp
return 0, 0, nil
}
-func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool {
+func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, stack.PacketBuffer) bool {
return true
}
@@ -368,7 +367,7 @@ func TestTransportReceive(t *testing.T) {
// Make sure packet with wrong protocol is not delivered.
buf[0] = 1
buf[2] = 0
- linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeTrans.packetCount != 0 {
@@ -379,7 +378,7 @@ func TestTransportReceive(t *testing.T) {
buf[0] = 1
buf[1] = 3
buf[2] = byte(fakeTransNumber)
- linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeTrans.packetCount != 0 {
@@ -390,7 +389,7 @@ func TestTransportReceive(t *testing.T) {
buf[0] = 1
buf[1] = 2
buf[2] = byte(fakeTransNumber)
- linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeTrans.packetCount != 1 {
@@ -445,7 +444,7 @@ func TestTransportControlReceive(t *testing.T) {
buf[fakeNetHeaderLen+0] = 0
buf[fakeNetHeaderLen+1] = 1
buf[fakeNetHeaderLen+2] = 0
- linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeTrans.controlCount != 0 {
@@ -456,7 +455,7 @@ func TestTransportControlReceive(t *testing.T) {
buf[fakeNetHeaderLen+0] = 3
buf[fakeNetHeaderLen+1] = 1
buf[fakeNetHeaderLen+2] = byte(fakeTransNumber)
- linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeTrans.controlCount != 0 {
@@ -467,7 +466,7 @@ func TestTransportControlReceive(t *testing.T) {
buf[fakeNetHeaderLen+0] = 2
buf[fakeNetHeaderLen+1] = 1
buf[fakeNetHeaderLen+2] = byte(fakeTransNumber)
- linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ linkEP.InjectInbound(fakeNetNumber, stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
if fakeTrans.controlCount != 1 {
@@ -622,7 +621,7 @@ func TestTransportForwarding(t *testing.T) {
req[0] = 1
req[1] = 3
req[2] = byte(fakeTransNumber)
- ep2.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{
+ ep2.InjectInbound(fakeNetNumber, stack.PacketBuffer{
Data: req.ToVectorisedView(),
})