From 9143fcd7fd38243dd40f927dafaeb75f6ef8ef49 Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Tue, 21 Jan 2020 14:47:17 -0800 Subject: Add UDP matchers. --- pkg/tcpip/iptables/BUILD | 2 + pkg/tcpip/iptables/iptables.go | 3 + pkg/tcpip/iptables/tcp_matcher.go | 122 ++++++++++++++++++++++++++++++++++++ pkg/tcpip/iptables/types.go | 17 +++++ pkg/tcpip/iptables/udp_matcher.go | 127 ++++++++++++++++++++++++++++++++++++++ pkg/tcpip/network/ipv4/ipv4.go | 10 +-- 6 files changed, 276 insertions(+), 5 deletions(-) create mode 100644 pkg/tcpip/iptables/tcp_matcher.go create mode 100644 pkg/tcpip/iptables/udp_matcher.go (limited to 'pkg/tcpip') diff --git a/pkg/tcpip/iptables/BUILD b/pkg/tcpip/iptables/BUILD index 297eaccaf..ff4e3c932 100644 --- a/pkg/tcpip/iptables/BUILD +++ b/pkg/tcpip/iptables/BUILD @@ -7,7 +7,9 @@ go_library( srcs = [ "iptables.go", "targets.go", + "tcp_matcher.go", "types.go", + "udp_matcher.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/iptables", visibility = ["//visibility:public"], diff --git a/pkg/tcpip/iptables/iptables.go b/pkg/tcpip/iptables/iptables.go index fc06b5b87..accedba1e 100644 --- a/pkg/tcpip/iptables/iptables.go +++ b/pkg/tcpip/iptables/iptables.go @@ -138,6 +138,8 @@ func EmptyFilterTable() Table { // Check runs pkt through the rules for hook. It returns true when the packet // should continue traversing the network stack and false when it should be // dropped. +// +// Precondition: pkt.NetworkHeader is set. func (it *IPTables) Check(hook Hook, pkt tcpip.PacketBuffer) bool { // TODO(gvisor.dev/issue/170): A lot of this is uncomplicated because // we're missing features. Jumps, the call stack, etc. aren't checked @@ -163,6 +165,7 @@ func (it *IPTables) Check(hook Hook, pkt tcpip.PacketBuffer) bool { return true } +// Precondition: pkt.NetworkHeader is set. func (it *IPTables) checkTable(hook Hook, pkt tcpip.PacketBuffer, tablename string) Verdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. diff --git a/pkg/tcpip/iptables/tcp_matcher.go b/pkg/tcpip/iptables/tcp_matcher.go new file mode 100644 index 000000000..6acbd6eb9 --- /dev/null +++ b/pkg/tcpip/iptables/tcp_matcher.go @@ -0,0 +1,122 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package iptables + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +type TCPMatcher struct { + data TCPMatcherData + + // tablename string + // unsigned int matchsize; + // unsigned int usersize; + // #ifdef CONFIG_COMPAT + // unsigned int compatsize; + // #endif + // unsigned int hooks; + // unsigned short proto; + // unsigned short family; +} + +// TODO: Delete? +// MatchCheckEntryParams + +type TCPMatcherData struct { + // Filter IPHeaderFilter + + SourcePortStart uint16 + SourcePortEnd uint16 + DestinationPortStart uint16 + DestinationPortEnd uint16 + Option uint8 + FlagMask uint8 + FlagCompare uint8 + InverseFlags uint8 +} + +func NewTCPMatcher(filter IPHeaderFilter, data TCPMatcherData) (Matcher, error) { + // TODO: We currently only support source port and destination port. + log.Infof("Adding rule with TCPMatcherData: %+v", data) + + if data.Option != 0 || + data.FlagMask != 0 || + data.FlagCompare != 0 || + data.InverseFlags != 0 { + return nil, fmt.Errorf("unsupported TCP matcher flags set") + } + + if filter.Protocol != header.TCPProtocolNumber { + log.Warningf("TCP matching is only valid for protocol %d.", header.TCPProtocolNumber) + } + + return &TCPMatcher{data: data}, nil +} + +// TODO: Check xt_tcpudp.c. Need to check for same things (e.g. fragments). +func (tm *TCPMatcher) Match(hook Hook, pkt tcpip.PacketBuffer, interfaceName string) (bool, bool) { + netHeader := header.IPv4(pkt.NetworkHeader) + + // TODO: Do we check proto here or elsewhere? I think elsewhere (check + // codesearch). + if netHeader.TransportProtocol() != header.TCPProtocolNumber { + return false, false + } + + // We dont't match fragments. + if frag := netHeader.FragmentOffset(); frag != 0 { + if frag == 1 { + log.Warningf("Dropping TCP packet: malicious packet with fragment with fragment offest of 1.") + return false, true + } + return false, false + } + + // Now we need the transport header. However, this may not have been set + // yet. + // TODO + var tcpHeader header.TCP + if pkt.TransportHeader != nil { + tcpHeader = header.TCP(pkt.TransportHeader) + } else { + // The TCP header hasn't been parsed yet. We have to do it here. + if len(pkt.Data.First()) < header.TCPMinimumSize { + // There's no valid TCP header here, so we hotdrop the + // packet. + // TODO: Stats. + log.Warningf("Dropping TCP packet: size to small.") + return false, true + } + tcpHeader = header.TCP(pkt.Data.First()) + } + + // Check whether the source and destination ports are within the + // matching range. + sourcePort := tcpHeader.SourcePort() + destinationPort := tcpHeader.DestinationPort() + if sourcePort < tm.data.SourcePortStart || tm.data.SourcePortEnd < sourcePort { + return false, false + } + if destinationPort < tm.data.DestinationPortStart || tm.data.DestinationPortEnd < destinationPort { + return false, false + } + + return true, false +} diff --git a/pkg/tcpip/iptables/types.go b/pkg/tcpip/iptables/types.go index a0bfc8b41..54e66f09a 100644 --- a/pkg/tcpip/iptables/types.go +++ b/pkg/tcpip/iptables/types.go @@ -169,12 +169,29 @@ type IPHeaderFilter struct { Protocol tcpip.TransportProtocolNumber } +// TODO: Should these be able to marshal/unmarshal themselves? +// TODO: Something has to map the name to the matcher. // A Matcher is the interface for matching packets. type Matcher interface { // Match returns whether the packet matches and whether the packet // should be "hotdropped", i.e. dropped immediately. This is usually // used for suspicious packets. + // + // Precondition: packet.NetworkHeader is set. Match(hook Hook, packet tcpip.PacketBuffer, interfaceName string) (matches bool, hotdrop bool) + + // TODO: Make this typesafe by having each Matcher have their own, typed CheckEntry? + // CheckEntry(params MatchCheckEntryParams) bool +} + +// TODO: Unused? +type MatchCheckEntryParams struct { + Table string // TODO: Tables should be an enum... + Filter IPHeaderFilter + Info interface{} // TODO: Type unsafe. + // HookMask uint8 + // Family uint8 + // NFTCompat bool } // A Target is the interface for taking an action for a packet. diff --git a/pkg/tcpip/iptables/udp_matcher.go b/pkg/tcpip/iptables/udp_matcher.go new file mode 100644 index 000000000..ce4368a3d --- /dev/null +++ b/pkg/tcpip/iptables/udp_matcher.go @@ -0,0 +1,127 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package iptables + +import ( + "fmt" + "runtime/debug" + + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +type UDPMatcher struct { + data UDPMatcherData + + // tablename string + // unsigned int matchsize; + // unsigned int usersize; + // #ifdef CONFIG_COMPAT + // unsigned int compatsize; + // #endif + // unsigned int hooks; + // unsigned short proto; + // unsigned short family; +} + +// TODO: Delete? +// MatchCheckEntryParams + +type UDPMatcherData struct { + // Filter IPHeaderFilter + + SourcePortStart uint16 + SourcePortEnd uint16 + DestinationPortStart uint16 + DestinationPortEnd uint16 + InverseFlags uint8 +} + +func NewUDPMatcher(filter IPHeaderFilter, data UDPMatcherData) (Matcher, error) { + // TODO: We currently only support source port and destination port. + log.Infof("Adding rule with UDPMatcherData: %+v", data) + + if data.InverseFlags != 0 { + return nil, fmt.Errorf("unsupported UDP matcher flags set") + } + + if filter.Protocol != header.UDPProtocolNumber { + log.Warningf("UDP matching is only valid for protocol %d.", header.UDPProtocolNumber) + } + + return &UDPMatcher{data: data}, nil +} + +// TODO: Check xt_tcpudp.c. Need to check for same things (e.g. fragments). +func (tm *UDPMatcher) Match(hook Hook, pkt tcpip.PacketBuffer, interfaceName string) (bool, bool) { + log.Infof("UDPMatcher called from: %s", string(debug.Stack())) + netHeader := header.IPv4(pkt.NetworkHeader) + + // TODO: Do we check proto here or elsewhere? I think elsewhere (check + // codesearch). + if netHeader.TransportProtocol() != header.UDPProtocolNumber { + log.Infof("UDPMatcher: wrong protocol number") + return false, false + } + + // We dont't match fragments. + if frag := netHeader.FragmentOffset(); frag != 0 { + log.Infof("UDPMatcher: it's a fragment") + if frag == 1 { + return false, true + } + log.Warningf("Dropping UDP packet: malicious fragmented packet.") + return false, false + } + + // Now we need the transport header. However, this may not have been set + // yet. + // TODO + var udpHeader header.UDP + if pkt.TransportHeader != nil { + log.Infof("UDPMatcher: transport header is not nil") + udpHeader = header.UDP(pkt.TransportHeader) + } else { + log.Infof("UDPMatcher: transport header is nil") + log.Infof("UDPMatcher: is network header nil: %t", pkt.NetworkHeader == nil) + // The UDP header hasn't been parsed yet. We have to do it here. + if len(pkt.Data.First()) < header.UDPMinimumSize { + // There's no valid UDP header here, so we hotdrop the + // packet. + // TODO: Stats. + log.Warningf("Dropping UDP packet: size to small.") + return false, true + } + udpHeader = header.UDP(pkt.Data.First()) + } + + // Check whether the source and destination ports are within the + // matching range. + sourcePort := udpHeader.SourcePort() + destinationPort := udpHeader.DestinationPort() + log.Infof("UDPMatcher: sport and dport are %d and %d. sports and dport start and end are (%d, %d) and (%d, %d)", + udpHeader.SourcePort(), udpHeader.DestinationPort(), + tm.data.SourcePortStart, tm.data.SourcePortEnd, + tm.data.DestinationPortStart, tm.data.DestinationPortEnd) + if sourcePort < tm.data.SourcePortStart || tm.data.SourcePortEnd < sourcePort { + return false, false + } + if destinationPort < tm.data.DestinationPortStart || tm.data.DestinationPortEnd < destinationPort { + return false, false + } + + return true, false +} diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 85512f9b2..6597e6781 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -353,6 +353,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { } pkt.NetworkHeader = headerView[:h.HeaderLength()] + hlen := int(h.HeaderLength()) + tlen := int(h.TotalLength()) + pkt.Data.TrimFront(hlen) + pkt.Data.CapLength(tlen - hlen) + // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. ipt := e.stack.IPTables() @@ -361,11 +366,6 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { return } - hlen := int(h.HeaderLength()) - tlen := int(h.TotalLength()) - pkt.Data.TrimFront(hlen) - pkt.Data.CapLength(tlen - hlen) - more := (h.Flags() & header.IPv4FlagMoreFragments) != 0 if more || h.FragmentOffset() != 0 { if pkt.Data.Size() == 0 { -- cgit v1.2.3 From 2661101ad470548cb15dce0afc694296668d780a Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Tue, 21 Jan 2020 14:51:28 -0800 Subject: Removed TCP work (saved in ipt-tcp-match). --- pkg/abi/linux/netfilter.go | 52 ------------- pkg/sentry/socket/netfilter/netfilter.go | 26 ------- pkg/tcpip/iptables/tcp_matcher.go | 122 ------------------------------- 3 files changed, 200 deletions(-) delete mode 100644 pkg/tcpip/iptables/tcp_matcher.go (limited to 'pkg/tcpip') diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go index fb4588272..f0e544f9c 100644 --- a/pkg/abi/linux/netfilter.go +++ b/pkg/abi/linux/netfilter.go @@ -341,58 +341,6 @@ func goString(cstring []byte) string { return string(cstring) } -// XTTCP holds data for matching TCP packets. It corresponds to struct xt_tcp -// in include/uapi/linux/netfilter/xt_tcpudp.h. -type XTTCP struct { - // SourcePortStart specifies the inclusive start of the range of source - // ports to which the matcher applies. - SourcePortStart uint16 - - // SourcePortEnd specifies the inclusive end of the range of source ports - // to which the matcher applies. - SourcePortEnd uint16 - - // DestinationPortStart specifies the start of the destination port - // range to which the matcher applies. - DestinationPortStart uint16 - - // DestinationPortEnd specifies the start of the destination port - // range to which the matcher applies. - DestinationPortEnd uint16 - - // Option specifies that a particular TCP option must be set. - Option uint8 - - // FlagMask masks the FlagCompare byte when comparing to the TCP flag - // fields. - FlagMask uint8 - - // FlagCompare is binary and-ed with the TCP flag fields. - FlagCompare uint8 - - // InverseFlags flips the meaning of certain fields. See the - // TX_TCP_INV_* flags. - InverseFlags uint8 -} - -// SizeOfXTTCP is the size of an XTTCP. -const SizeOfXTTCP = 12 - -// Flags in XTTCP.InverseFlags. Corresponding constants are in -// include/uapi/linux/netfilter/xt_tcpudp.h. -const ( - // Invert the meaning of SourcePortStart/End. - XT_TCP_INV_SRCPT = 0x01 - // Invert the meaning of DestinationPortStart/End. - XT_TCP_INV_DSTPT = 0x02 - // Invert the meaning of FlagCompare. - XT_TCP_INV_FLAGS = 0x04 - // Invert the meaning of Option. - XT_TCP_INV_OPTION = 0x08 - // Enable all flags. - XT_TCP_INV_MASK = 0x0F -) - // XTUDP holds data for matching UDP packets. It corresponds to struct xt_udp // in include/uapi/linux/netfilter/xt_tcpudp.h. type XTUDP struct { diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 45296b339..f8ed1acbc 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -131,7 +131,6 @@ func FillDefaultIPTables(stack *stack.Stack) { stack.SetIPTables(ipt) } -// TODO: Return proto. // convertNetstackToBinary converts the iptables as stored in netstack to the // format expected by the iptables tool. Linux stores each table as a binary // blob that can only be traversed by parsing a bit, reading some offsets, @@ -456,31 +455,6 @@ func parseMatchers(filter iptables.IPHeaderFilter, optVal []byte) ([]iptables.Ma var matcher iptables.Matcher var err error switch match.Name.String() { - case "tcp": - if len(buf) < linux.SizeOfXTTCP { - log.Warningf("netfilter: optVal has insufficient size for TCP match: %d", len(optVal)) - return nil, syserr.ErrInvalidArgument - } - var matchData linux.XTTCP - // For alignment reasons, the match's total size may exceed what's - // strictly necessary to hold matchData. - binary.Unmarshal(buf[:linux.SizeOfXTUDP], usermem.ByteOrder, &matchData) - log.Infof("parseMatchers: parsed XTTCP: %+v", matchData) - matcher, err = iptables.NewTCPMatcher(filter, iptables.TCPMatcherData{ - SourcePortStart: matchData.SourcePortStart, - SourcePortEnd: matchData.SourcePortEnd, - DestinationPortStart: matchData.DestinationPortStart, - DestinationPortEnd: matchData.DestinationPortEnd, - Option: matchData.Option, - FlagMask: matchData.FlagMask, - FlagCompare: matchData.FlagCompare, - InverseFlags: matchData.InverseFlags, - }) - if err != nil { - log.Warningf("netfilter: failed to create TCP matcher: %v", err) - return nil, syserr.ErrInvalidArgument - } - case "udp": if len(buf) < linux.SizeOfXTUDP { log.Warningf("netfilter: optVal has insufficient size for UDP match: %d", len(optVal)) diff --git a/pkg/tcpip/iptables/tcp_matcher.go b/pkg/tcpip/iptables/tcp_matcher.go deleted file mode 100644 index 6acbd6eb9..000000000 --- a/pkg/tcpip/iptables/tcp_matcher.go +++ /dev/null @@ -1,122 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package iptables - -import ( - "fmt" - - "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" -) - -type TCPMatcher struct { - data TCPMatcherData - - // tablename string - // unsigned int matchsize; - // unsigned int usersize; - // #ifdef CONFIG_COMPAT - // unsigned int compatsize; - // #endif - // unsigned int hooks; - // unsigned short proto; - // unsigned short family; -} - -// TODO: Delete? -// MatchCheckEntryParams - -type TCPMatcherData struct { - // Filter IPHeaderFilter - - SourcePortStart uint16 - SourcePortEnd uint16 - DestinationPortStart uint16 - DestinationPortEnd uint16 - Option uint8 - FlagMask uint8 - FlagCompare uint8 - InverseFlags uint8 -} - -func NewTCPMatcher(filter IPHeaderFilter, data TCPMatcherData) (Matcher, error) { - // TODO: We currently only support source port and destination port. - log.Infof("Adding rule with TCPMatcherData: %+v", data) - - if data.Option != 0 || - data.FlagMask != 0 || - data.FlagCompare != 0 || - data.InverseFlags != 0 { - return nil, fmt.Errorf("unsupported TCP matcher flags set") - } - - if filter.Protocol != header.TCPProtocolNumber { - log.Warningf("TCP matching is only valid for protocol %d.", header.TCPProtocolNumber) - } - - return &TCPMatcher{data: data}, nil -} - -// TODO: Check xt_tcpudp.c. Need to check for same things (e.g. fragments). -func (tm *TCPMatcher) Match(hook Hook, pkt tcpip.PacketBuffer, interfaceName string) (bool, bool) { - netHeader := header.IPv4(pkt.NetworkHeader) - - // TODO: Do we check proto here or elsewhere? I think elsewhere (check - // codesearch). - if netHeader.TransportProtocol() != header.TCPProtocolNumber { - return false, false - } - - // We dont't match fragments. - if frag := netHeader.FragmentOffset(); frag != 0 { - if frag == 1 { - log.Warningf("Dropping TCP packet: malicious packet with fragment with fragment offest of 1.") - return false, true - } - return false, false - } - - // Now we need the transport header. However, this may not have been set - // yet. - // TODO - var tcpHeader header.TCP - if pkt.TransportHeader != nil { - tcpHeader = header.TCP(pkt.TransportHeader) - } else { - // The TCP header hasn't been parsed yet. We have to do it here. - if len(pkt.Data.First()) < header.TCPMinimumSize { - // There's no valid TCP header here, so we hotdrop the - // packet. - // TODO: Stats. - log.Warningf("Dropping TCP packet: size to small.") - return false, true - } - tcpHeader = header.TCP(pkt.Data.First()) - } - - // Check whether the source and destination ports are within the - // matching range. - sourcePort := tcpHeader.SourcePort() - destinationPort := tcpHeader.DestinationPort() - if sourcePort < tm.data.SourcePortStart || tm.data.SourcePortEnd < sourcePort { - return false, false - } - if destinationPort < tm.data.DestinationPortStart || tm.data.DestinationPortEnd < destinationPort { - return false, false - } - - return true, false -} -- cgit v1.2.3 From 421b6ff18154f80ea8cbbfd8340042ab458bf813 Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Tue, 21 Jan 2020 14:54:39 -0800 Subject: Passes all filter table UDP tests. --- pkg/tcpip/iptables/BUILD | 1 - 1 file changed, 1 deletion(-) (limited to 'pkg/tcpip') diff --git a/pkg/tcpip/iptables/BUILD b/pkg/tcpip/iptables/BUILD index ff4e3c932..e41c645ed 100644 --- a/pkg/tcpip/iptables/BUILD +++ b/pkg/tcpip/iptables/BUILD @@ -7,7 +7,6 @@ go_library( srcs = [ "iptables.go", "targets.go", - "tcp_matcher.go", "types.go", "udp_matcher.go", ], -- cgit v1.2.3 From 538053538dfb378aa8bc512d484ea305177e617b Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Tue, 21 Jan 2020 16:51:17 -0800 Subject: Adding serialization. --- pkg/sentry/socket/netfilter/netfilter.go | 29 ++++++++++++++++++++++++++++- pkg/tcpip/iptables/udp_matcher.go | 14 +++++++------- 2 files changed, 35 insertions(+), 8 deletions(-) (limited to 'pkg/tcpip') diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index f8ed1acbc..3caabca9a 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -196,7 +196,9 @@ func convertNetstackToBinary(tablename string, table iptables.Table) (linux.Kern } func marshalMatcher(matcher iptables.Matcher) []byte { - switch matcher.(type) { + switch m := matcher.(type) { + case *iptables.UDPMatcher: + return marshalUDPMatcher(m) default: // TODO(gvisor.dev/issue/170): We don't support any matchers // yet, so any call to marshalMatcher will panic. @@ -204,6 +206,31 @@ func marshalMatcher(matcher iptables.Matcher) []byte { } } +func marshalUDPMatcher(matcher *iptables.UDPMatcher) []byte { + type udpMatch struct { + linux.XTEntryMatch + linux.XTUDP + } + linuxMatcher := udpMatch{ + XTEntryMatch: linux.XTEntryMatch{ + MatchSize: linux.SizeOfXTEntryMatch + linux.SizeOfXTUDP, + // Name: "udp", + }, + XTUDP: linux.XTUDP{ + SourcePortStart: matcher.Data.SourcePortStart, + SourcePortEnd: matcher.Data.SourcePortEnd, + DestinationPortStart: matcher.Data.DestinationPortStart, + DestinationPortEnd: matcher.Data.DestinationPortEnd, + InverseFlags: matcher.Data.InverseFlags, + }, + } + copy(linuxMatcher.Name[:], "udp") + + var buf [linux.SizeOfXTEntryMatch + linux.SizeOfXTUDP]byte + binary.Marshal(buf[:], usermem.ByteOrder, linuxMatcher) + return buf[:] +} + func marshalTarget(target iptables.Target) []byte { switch target.(type) { case iptables.UnconditionalAcceptTarget: diff --git a/pkg/tcpip/iptables/udp_matcher.go b/pkg/tcpip/iptables/udp_matcher.go index ce4368a3d..fca457199 100644 --- a/pkg/tcpip/iptables/udp_matcher.go +++ b/pkg/tcpip/iptables/udp_matcher.go @@ -24,7 +24,7 @@ import ( ) type UDPMatcher struct { - data UDPMatcherData + Data UDPMatcherData // tablename string // unsigned int matchsize; @@ -62,11 +62,11 @@ func NewUDPMatcher(filter IPHeaderFilter, data UDPMatcherData) (Matcher, error) log.Warningf("UDP matching is only valid for protocol %d.", header.UDPProtocolNumber) } - return &UDPMatcher{data: data}, nil + return &UDPMatcher{Data: data}, nil } // TODO: Check xt_tcpudp.c. Need to check for same things (e.g. fragments). -func (tm *UDPMatcher) Match(hook Hook, pkt tcpip.PacketBuffer, interfaceName string) (bool, bool) { +func (um *UDPMatcher) Match(hook Hook, pkt tcpip.PacketBuffer, interfaceName string) (bool, bool) { log.Infof("UDPMatcher called from: %s", string(debug.Stack())) netHeader := header.IPv4(pkt.NetworkHeader) @@ -114,12 +114,12 @@ func (tm *UDPMatcher) Match(hook Hook, pkt tcpip.PacketBuffer, interfaceName str destinationPort := udpHeader.DestinationPort() log.Infof("UDPMatcher: sport and dport are %d and %d. sports and dport start and end are (%d, %d) and (%d, %d)", udpHeader.SourcePort(), udpHeader.DestinationPort(), - tm.data.SourcePortStart, tm.data.SourcePortEnd, - tm.data.DestinationPortStart, tm.data.DestinationPortEnd) - if sourcePort < tm.data.SourcePortStart || tm.data.SourcePortEnd < sourcePort { + um.Data.SourcePortStart, um.Data.SourcePortEnd, + um.Data.DestinationPortStart, um.Data.DestinationPortEnd) + if sourcePort < um.Data.SourcePortStart || um.Data.SourcePortEnd < sourcePort { return false, false } - if destinationPort < tm.data.DestinationPortStart || tm.data.DestinationPortEnd < destinationPort { + if destinationPort < um.Data.DestinationPortStart || um.Data.DestinationPortEnd < destinationPort { return false, false } -- cgit v1.2.3 From b7853f688b4bcd3465c0c3087fcbd8d53bdf26ae Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Wed, 22 Jan 2020 14:46:15 -0800 Subject: Error marshalling the matcher. The iptables binary is looking for libxt_.so when it should be looking for libxt_udp.so, so it's having an issue reading the data in xt_match_entry. I think it may be an alignment issue. Trying to fix this is leading to me fighting with the metadata struct, so I'm gonna go kill that. --- pkg/abi/linux/netfilter.go | 5 +++++ pkg/sentry/socket/netfilter/netfilter.go | 35 ++++++++++++++++++++------------ pkg/tcpip/iptables/udp_matcher.go | 2 +- 3 files changed, 28 insertions(+), 14 deletions(-) (limited to 'pkg/tcpip') diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go index f0e544f9c..effed7976 100644 --- a/pkg/abi/linux/netfilter.go +++ b/pkg/abi/linux/netfilter.go @@ -198,6 +198,11 @@ type XTEntryMatch struct { // SizeOfXTEntryMatch is the size of an XTEntryMatch. const SizeOfXTEntryMatch = 32 +type KernelXTEntryMatch struct { + XTEntryMatch + Data []byte +} + // XTEntryTarget holds a target for a rule. For example, it can specify that // packets matching the rule should DROP, ACCEPT, or use an extension target. // iptables-extension(8) has a list of possible targets. diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 3caabca9a..b49fe5b3e 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -207,26 +207,34 @@ func marshalMatcher(matcher iptables.Matcher) []byte { } func marshalUDPMatcher(matcher *iptables.UDPMatcher) []byte { - type udpMatch struct { - linux.XTEntryMatch - linux.XTUDP - } - linuxMatcher := udpMatch{ + linuxMatcher := linux.KernelXTEntryMatch{ XTEntryMatch: linux.XTEntryMatch{ MatchSize: linux.SizeOfXTEntryMatch + linux.SizeOfXTUDP, // Name: "udp", }, - XTUDP: linux.XTUDP{ - SourcePortStart: matcher.Data.SourcePortStart, - SourcePortEnd: matcher.Data.SourcePortEnd, - DestinationPortStart: matcher.Data.DestinationPortStart, - DestinationPortEnd: matcher.Data.DestinationPortEnd, - InverseFlags: matcher.Data.InverseFlags, - }, + Data: make([]byte, linux.SizeOfXTUDP+22), } + // copy(linuxMatcher.Name[:], "udp") copy(linuxMatcher.Name[:], "udp") - var buf [linux.SizeOfXTEntryMatch + linux.SizeOfXTUDP]byte + // TODO: Must be aligned. + xtudp := linux.XTUDP{ + SourcePortStart: matcher.Data.SourcePortStart, + SourcePortEnd: matcher.Data.SourcePortEnd, + DestinationPortStart: matcher.Data.DestinationPortStart, + DestinationPortEnd: matcher.Data.DestinationPortEnd, + InverseFlags: matcher.Data.InverseFlags, + } + binary.Marshal(linuxMatcher.Data[:linux.SizeOfXTUDP], usermem.ByteOrder, xtudp) + + if binary.Size(linuxMatcher)%64 != 0 { + panic(fmt.Sprintf("size is actually: %d", binary.Size(linuxMatcher))) + } + + var buf [linux.SizeOfXTEntryMatch + linux.SizeOfXTUDP + 22]byte + if len(buf)%64 != 0 { + panic(fmt.Sprintf("len is actually: %d", len(buf))) + } binary.Marshal(buf[:], usermem.ByteOrder, linuxMatcher) return buf[:] } @@ -245,6 +253,7 @@ func marshalTarget(target iptables.Target) []byte { } func marshalStandardTarget(verdict iptables.Verdict) []byte { + // TODO: Must be aligned. // The target's name will be the empty string. target := linux.XTStandardTarget{ Target: linux.XTEntryTarget{ diff --git a/pkg/tcpip/iptables/udp_matcher.go b/pkg/tcpip/iptables/udp_matcher.go index fca457199..65ae7f9e0 100644 --- a/pkg/tcpip/iptables/udp_matcher.go +++ b/pkg/tcpip/iptables/udp_matcher.go @@ -59,7 +59,7 @@ func NewUDPMatcher(filter IPHeaderFilter, data UDPMatcherData) (Matcher, error) } if filter.Protocol != header.UDPProtocolNumber { - log.Warningf("UDP matching is only valid for protocol %d.", header.UDPProtocolNumber) + return nil, fmt.Errorf("UDP matching is only valid for protocol %d.", header.UDPProtocolNumber) } return &UDPMatcher{Data: data}, nil -- cgit v1.2.3 From 29316e66adfc49c158425554761e34c12338f1d9 Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Mon, 27 Jan 2020 12:27:04 -0800 Subject: Cleanup for GH review. --- pkg/abi/linux/netfilter.go | 14 +-- pkg/sentry/socket/netfilter/netfilter.go | 144 ++++++++++++------------------- pkg/tcpip/iptables/types.go | 15 ---- pkg/tcpip/iptables/udp_matcher.go | 62 ++++++------- test/iptables/filter_input.go | 6 +- 5 files changed, 88 insertions(+), 153 deletions(-) (limited to 'pkg/tcpip') diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go index effed7976..8e40bcc62 100644 --- a/pkg/abi/linux/netfilter.go +++ b/pkg/abi/linux/netfilter.go @@ -198,6 +198,8 @@ type XTEntryMatch struct { // SizeOfXTEntryMatch is the size of an XTEntryMatch. const SizeOfXTEntryMatch = 32 +// KernelXTEntryMatch is identical to XTEntryMatch, but contains +// variable-length Data field. type KernelXTEntryMatch struct { XTEntryMatch Data []byte @@ -349,19 +351,19 @@ func goString(cstring []byte) string { // XTUDP holds data for matching UDP packets. It corresponds to struct xt_udp // in include/uapi/linux/netfilter/xt_tcpudp.h. type XTUDP struct { - // SourcePortStart specifies the inclusive start of the range of source - // ports to which the matcher applies. + // SourcePortStart is the inclusive start of the range of source ports + // to which the matcher applies. SourcePortStart uint16 - // SourcePortEnd specifies the inclusive end of the range of source ports - // to which the matcher applies. + // SourcePortEnd is the inclusive end of the range of source ports to + // which the matcher applies. SourcePortEnd uint16 - // DestinationPortStart specifies the start of the destination port + // DestinationPortStart is the inclusive start of the destination port // range to which the matcher applies. DestinationPortStart uint16 - // DestinationPortEnd specifies the start of the destination port + // DestinationPortEnd is the inclusive end of the destination port // range to which the matcher applies. DestinationPortEnd uint16 diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 6c88a50a6..b8848f08a 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -34,9 +34,16 @@ import ( // shouldn't be reached - an error has occurred if we fall through to one. const errorTargetName = "ERROR" -// metadata is opaque to netstack. It holds data that we need to translate -// between Linux's and netstack's iptables representations. -// TODO(gvisor.dev/issue/170): Use metadata to check correctness. +const ( + matcherNameUDP = "udp" +) + +// Metadata is used to verify that we are correctly serializing and +// deserializing iptables into structs consumable by the iptables tool. We save +// a metadata struct when the tables are written, and when they are read out we +// verify that certain fields are the same. +// +// metadata is opaque to netstack. type metadata struct { HookEntry [linux.NF_INET_NUMHOOKS]uint32 Underflow [linux.NF_INET_NUMHOOKS]uint32 @@ -44,10 +51,12 @@ type metadata struct { Size uint32 } -const enableDebugLog = true +const enableDebug = false +// nflog logs messages related to the writing and reading of iptables, but only +// when enableDebug is true. func nflog(format string, args ...interface{}) { - if enableDebugLog { + if enableDebug { log.Infof("netfilter: "+format, args...) } } @@ -80,7 +89,7 @@ func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr) (linux.IPT info.NumEntries = metadata.NumEntries info.Size = metadata.Size - nflog("GetInfo returning info: %+v", info) + nflog("returning info: %+v", info) return info, nil } @@ -163,19 +172,19 @@ func convertNetstackToBinary(tablename string, table iptables.Table) (linux.Kern copy(entries.Name[:], tablename) for ruleIdx, rule := range table.Rules { - nflog("Current offset: %d", entries.Size) + nflog("convert to binary: current offset: %d", entries.Size) // Is this a chain entry point? for hook, hookRuleIdx := range table.BuiltinChains { if hookRuleIdx == ruleIdx { - nflog("Found hook %d at offset %d", hook, entries.Size) + nflog("convert to binary: found hook %d at offset %d", hook, entries.Size) meta.HookEntry[hook] = entries.Size } } // Is this a chain underflow point? for underflow, underflowRuleIdx := range table.Underflows { if underflowRuleIdx == ruleIdx { - nflog("Found underflow %d at offset %d", underflow, entries.Size) + nflog("convert to binary: found underflow %d at offset %d", underflow, entries.Size) meta.Underflow[underflow] = entries.Size } } @@ -195,7 +204,7 @@ func convertNetstackToBinary(tablename string, table iptables.Table) (linux.Kern // Serialize the matcher and add it to the // entry. serialized := marshalMatcher(matcher) - nflog("matcher serialized as: %v", serialized) + nflog("convert to binary: matcher serialized as: %v", serialized) if len(serialized)%8 != 0 { panic(fmt.Sprintf("matcher %T is not 64-bit aligned", matcher)) } @@ -212,14 +221,14 @@ func convertNetstackToBinary(tablename string, table iptables.Table) (linux.Kern entry.Elems = append(entry.Elems, serialized...) entry.NextOffset += uint16(len(serialized)) - nflog("Adding entry: %+v", entry) + nflog("convert to binary: adding entry: %+v", entry) entries.Size += uint32(entry.NextOffset) entries.Entrytable = append(entries.Entrytable, entry) meta.NumEntries++ } - nflog("Finished with an marshalled size of %d", meta.Size) + nflog("convert to binary: finished with an marshalled size of %d", meta.Size) meta.Size = entries.Size return entries, meta, nil } @@ -237,16 +246,18 @@ func marshalMatcher(matcher iptables.Matcher) []byte { } func marshalUDPMatcher(matcher *iptables.UDPMatcher) []byte { - nflog("Marshalling UDP matcher: %+v", matcher) + nflog("convert to binary: marshalling UDP matcher: %+v", matcher) + + // We have to pad this struct size to a multiple of 8 bytes. + const size = linux.SizeOfXTEntryMatch + linux.SizeOfXTUDP + 6 linuxMatcher := linux.KernelXTEntryMatch{ XTEntryMatch: linux.XTEntryMatch{ - MatchSize: linux.SizeOfXTEntryMatch + linux.SizeOfXTUDP + 6, - // Name: "udp", + MatchSize: size, }, Data: make([]byte, 0, linux.SizeOfXTUDP), } - copy(linuxMatcher.Name[:], "udp") + copy(linuxMatcher.Name[:], matcherNameUDP) xtudp := linux.XTUDP{ SourcePortStart: matcher.Data.SourcePortStart, @@ -255,17 +266,12 @@ func marshalUDPMatcher(matcher *iptables.UDPMatcher) []byte { DestinationPortEnd: matcher.Data.DestinationPortEnd, InverseFlags: matcher.Data.InverseFlags, } - nflog("marshalUDPMatcher: xtudp: %+v", xtudp) linuxMatcher.Data = binary.Marshal(linuxMatcher.Data, usermem.ByteOrder, xtudp) - nflog("marshalUDPMatcher: linuxMatcher: %+v", linuxMatcher) - // We have to pad this struct size to a multiple of 8 bytes, so we make - // this a little longer than it needs to be. - buf := make([]byte, 0, linux.SizeOfXTEntryMatch+linux.SizeOfXTUDP+6) + buf := make([]byte, 0, size) buf = binary.Marshal(buf, usermem.ByteOrder, linuxMatcher) buf = append(buf, []byte{0, 0, 0, 0, 0, 0}...) - nflog("Marshalled into matcher of size %d", len(buf)) - nflog("marshalUDPMatcher: buf is: %v", buf) + nflog("convert to binary: marshalled UDP matcher into %v", buf) return buf[:] } @@ -283,9 +289,8 @@ func marshalTarget(target iptables.Target) []byte { } func marshalStandardTarget(verdict iptables.Verdict) []byte { - nflog("Marshalling standard target with size %d", linux.SizeOfXTStandardTarget) + nflog("convert to binary: marshalling standard target with size %d", linux.SizeOfXTStandardTarget) - // TODO: Must be aligned. // The target's name will be the empty string. target := linux.XTStandardTarget{ Target: linux.XTEntryTarget{ @@ -353,8 +358,6 @@ func translateToStandardVerdict(val int32) (iptables.Verdict, *syserr.Error) { // SetEntries sets iptables rules for a single table. See // net/ipv4/netfilter/ip_tables.c:translate_table for reference. func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { - // printReplace(optVal) - // Get the basic rules data (struct ipt_replace). if len(optVal) < linux.SizeOfIPTReplace { log.Warningf("netfilter.SetEntries: optVal has insufficient size for replace %d", len(optVal)) @@ -375,13 +378,13 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { return syserr.ErrInvalidArgument } - nflog("Setting entries in table %q", replace.Name.String()) + nflog("set entries: setting entries in table %q", replace.Name.String()) // Convert input into a list of rules and their offsets. var offset uint32 var offsets []uint32 for entryIdx := uint32(0); entryIdx < replace.NumEntries; entryIdx++ { - nflog("Processing entry at offset %d", offset) + nflog("set entries: processing entry at offset %d", offset) // Get the struct ipt_entry. if len(optVal) < linux.SizeOfIPTEntry { @@ -406,11 +409,13 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { return err } - // TODO: Matchers (and maybe targets) can specify that they only work for certiain protocols, hooks, tables. + // TODO(gvisor.dev/issue/170): Matchers and targets can specify + // that they only work for certiain protocols, hooks, tables. // Get matchers. matchersSize := entry.TargetOffset - linux.SizeOfIPTEntry if len(optVal) < int(matchersSize) { log.Warningf("netfilter: entry doesn't have enough room for its matchers (only %d bytes remain)", len(optVal)) + return syserr.ErrInvalidArgument } matchers, err := parseMatchers(filter, optVal[:matchersSize]) if err != nil { @@ -423,6 +428,7 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { targetSize := entry.NextOffset - entry.TargetOffset if len(optVal) < int(targetSize) { log.Warningf("netfilter: entry doesn't have enough room for its target (only %d bytes remain)", len(optVal)) + return syserr.ErrInvalidArgument } target, err := parseTarget(optVal[:targetSize]) if err != nil { @@ -500,10 +506,11 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { // parseMatchers parses 0 or more matchers from optVal. optVal should contain // only the matchers. func parseMatchers(filter iptables.IPHeaderFilter, optVal []byte) ([]iptables.Matcher, *syserr.Error) { - nflog("Parsing matchers of size %d", len(optVal)) + nflog("set entries: parsing matchers of size %d", len(optVal)) var matchers []iptables.Matcher for len(optVal) > 0 { - nflog("parseMatchers: optVal has len %d", len(optVal)) + nflog("set entries: optVal has len %d", len(optVal)) + // Get the XTEntryMatch. if len(optVal) < linux.SizeOfXTEntryMatch { log.Warningf("netfilter: optVal has insufficient size for entry match: %d", len(optVal)) @@ -512,7 +519,7 @@ func parseMatchers(filter iptables.IPHeaderFilter, optVal []byte) ([]iptables.Ma var match linux.XTEntryMatch buf := optVal[:linux.SizeOfXTEntryMatch] binary.Unmarshal(buf, usermem.ByteOrder, &match) - nflog("parseMatchers: parsed entry match %q: %+v", match.Name.String(), match) + nflog("set entries: parsed entry match %q: %+v", match.Name.String(), match) // Check some invariants. if match.MatchSize < linux.SizeOfXTEntryMatch { @@ -528,17 +535,17 @@ func parseMatchers(filter iptables.IPHeaderFilter, optVal []byte) ([]iptables.Ma var matcher iptables.Matcher var err error switch match.Name.String() { - case "udp": + case matcherNameUDP: if len(buf) < linux.SizeOfXTUDP { log.Warningf("netfilter: optVal has insufficient size for UDP match: %d", len(optVal)) return nil, syserr.ErrInvalidArgument } + // For alignment reasons, the match's total size may + // exceed what's strictly necessary to hold matchData. var matchData linux.XTUDP - // For alignment reasons, the match's total size may exceed what's - // strictly necessary to hold matchData. binary.Unmarshal(buf[:linux.SizeOfXTUDP], usermem.ByteOrder, &matchData) log.Infof("parseMatchers: parsed XTUDP: %+v", matchData) - matcher, err = iptables.NewUDPMatcher(filter, iptables.UDPMatcherData{ + matcher, err = iptables.NewUDPMatcher(filter, iptables.UDPMatcherParams{ SourcePortStart: matchData.SourcePortStart, SourcePortEnd: matchData.SourcePortEnd, DestinationPortStart: matchData.DestinationPortStart, @@ -557,19 +564,22 @@ func parseMatchers(filter iptables.IPHeaderFilter, optVal []byte) ([]iptables.Ma matchers = append(matchers, matcher) - // TODO: Support revision. - // TODO: Support proto -- matchers usually specify which proto(s) they work with. + // TODO(gvisor.dev/issue/170): Check the revision field. optVal = optVal[match.MatchSize:] } - // TODO: Check that optVal is exhausted. + if len(optVal) != 0 { + log.Warningf("netfilter: optVal should be exhausted after parsing matchers") + return nil, syserr.ErrInvalidArgument + } + return matchers, nil } // parseTarget parses a target from optVal. optVal should contain only the // target. func parseTarget(optVal []byte) (iptables.Target, *syserr.Error) { - nflog("Parsing target of size %d", len(optVal)) + nflog("set entries: parsing target of size %d", len(optVal)) if len(optVal) < linux.SizeOfXTEntryTarget { log.Warningf("netfilter: optVal has insufficient size for entry target %d", len(optVal)) return nil, syserr.ErrInvalidArgument @@ -598,7 +608,8 @@ func parseTarget(optVal []byte) (iptables.Target, *syserr.Error) { case iptables.Drop: return iptables.UnconditionalDropTarget{}, nil default: - panic(fmt.Sprintf("Unknown verdict: %v", verdict)) + log.Warningf("Unknown verdict: %v", verdict) + return nil, syserr.ErrInvalidArgument } case errorTargetName: @@ -673,52 +684,3 @@ func hookFromLinux(hook int) iptables.Hook { } panic(fmt.Sprintf("Unknown hook %d does not correspond to a builtin chain", hook)) } - -// printReplace prints information about the struct ipt_replace in optVal. It -// is only for debugging. -func printReplace(optVal []byte) { - // Basic replace info. - var replace linux.IPTReplace - replaceBuf := optVal[:linux.SizeOfIPTReplace] - optVal = optVal[linux.SizeOfIPTReplace:] - binary.Unmarshal(replaceBuf, usermem.ByteOrder, &replace) - log.Infof("Replacing table %q: %+v", replace.Name.String(), replace) - - // Read in the list of entries at the end of replace. - var totalOffset uint16 - for entryIdx := uint32(0); entryIdx < replace.NumEntries; entryIdx++ { - var entry linux.IPTEntry - entryBuf := optVal[:linux.SizeOfIPTEntry] - binary.Unmarshal(entryBuf, usermem.ByteOrder, &entry) - log.Infof("Entry %d (total offset %d): %+v", entryIdx, totalOffset, entry) - - totalOffset += entry.NextOffset - if entry.TargetOffset == linux.SizeOfIPTEntry { - log.Infof("Entry has no matches.") - } else { - log.Infof("Entry has matches.") - } - - var target linux.XTEntryTarget - targetBuf := optVal[entry.TargetOffset : entry.TargetOffset+linux.SizeOfXTEntryTarget] - binary.Unmarshal(targetBuf, usermem.ByteOrder, &target) - log.Infof("Target named %q: %+v", target.Name.String(), target) - - switch target.Name.String() { - case "": - var standardTarget linux.XTStandardTarget - stBuf := optVal[entry.TargetOffset : entry.TargetOffset+linux.SizeOfXTStandardTarget] - binary.Unmarshal(stBuf, usermem.ByteOrder, &standardTarget) - log.Infof("Standard target with verdict %q (%d).", linux.VerdictStrings[standardTarget.Verdict], standardTarget.Verdict) - case errorTargetName: - var errorTarget linux.XTErrorTarget - etBuf := optVal[entry.TargetOffset : entry.TargetOffset+linux.SizeOfXTErrorTarget] - binary.Unmarshal(etBuf, usermem.ByteOrder, &errorTarget) - log.Infof("Error target with name %q.", errorTarget.Name.String()) - default: - log.Infof("Unknown target type.") - } - - optVal = optVal[entry.NextOffset:] - } -} diff --git a/pkg/tcpip/iptables/types.go b/pkg/tcpip/iptables/types.go index d47447d40..ba5ed75b4 100644 --- a/pkg/tcpip/iptables/types.go +++ b/pkg/tcpip/iptables/types.go @@ -169,8 +169,6 @@ type IPHeaderFilter struct { Protocol tcpip.TransportProtocolNumber } -// TODO: Should these be able to marshal/unmarshal themselves? -// TODO: Something has to map the name to the matcher. // A Matcher is the interface for matching packets. type Matcher interface { // Match returns whether the packet matches and whether the packet @@ -179,19 +177,6 @@ type Matcher interface { // // Precondition: packet.NetworkHeader is set. Match(hook Hook, packet tcpip.PacketBuffer, interfaceName string) (matches bool, hotdrop bool) - - // TODO: Make this typesafe by having each Matcher have their own, typed CheckEntry? - // CheckEntry(params MatchCheckEntryParams) bool -} - -// TODO: Unused? -type MatchCheckEntryParams struct { - Table string // TODO: Tables should be an enum... - Filter IPHeaderFilter - Info interface{} // TODO: Type unsafe. - // HookMask uint8 - // Family uint8 - // NFTCompat bool } // A Target is the interface for taking an action for a packet. diff --git a/pkg/tcpip/iptables/udp_matcher.go b/pkg/tcpip/iptables/udp_matcher.go index 65ae7f9e0..f59ca2027 100644 --- a/pkg/tcpip/iptables/udp_matcher.go +++ b/pkg/tcpip/iptables/udp_matcher.go @@ -16,33 +16,28 @@ package iptables import ( "fmt" - "runtime/debug" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" ) +// TODO(gvisor.dev/issue/170): The following per-matcher params should be +// supported: +// - Table name +// - Match size +// - User size +// - Hooks +// - Proto +// - Family + +// UDPMatcher matches UDP packets and their headers. It implements Matcher. type UDPMatcher struct { - Data UDPMatcherData - - // tablename string - // unsigned int matchsize; - // unsigned int usersize; - // #ifdef CONFIG_COMPAT - // unsigned int compatsize; - // #endif - // unsigned int hooks; - // unsigned short proto; - // unsigned short family; + Data UDPMatcherParams } -// TODO: Delete? -// MatchCheckEntryParams - -type UDPMatcherData struct { - // Filter IPHeaderFilter - +// UDPMatcherParams are the parameters used to create a UDPMatcher. +type UDPMatcherParams struct { SourcePortStart uint16 SourcePortEnd uint16 DestinationPortStart uint16 @@ -50,12 +45,12 @@ type UDPMatcherData struct { InverseFlags uint8 } -func NewUDPMatcher(filter IPHeaderFilter, data UDPMatcherData) (Matcher, error) { - // TODO: We currently only support source port and destination port. - log.Infof("Adding rule with UDPMatcherData: %+v", data) +// NewUDPMatcher returns a new instance of UDPMatcher. +func NewUDPMatcher(filter IPHeaderFilter, data UDPMatcherParams) (Matcher, error) { + log.Infof("Adding rule with UDPMatcherParams: %+v", data) if data.InverseFlags != 0 { - return nil, fmt.Errorf("unsupported UDP matcher flags set") + return nil, fmt.Errorf("unsupported UDP matcher inverse flags set") } if filter.Protocol != header.UDPProtocolNumber { @@ -65,21 +60,18 @@ func NewUDPMatcher(filter IPHeaderFilter, data UDPMatcherData) (Matcher, error) return &UDPMatcher{Data: data}, nil } -// TODO: Check xt_tcpudp.c. Need to check for same things (e.g. fragments). +// Match implements Matcher.Match. func (um *UDPMatcher) Match(hook Hook, pkt tcpip.PacketBuffer, interfaceName string) (bool, bool) { - log.Infof("UDPMatcher called from: %s", string(debug.Stack())) netHeader := header.IPv4(pkt.NetworkHeader) - // TODO: Do we check proto here or elsewhere? I think elsewhere (check - // codesearch). + // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved + // into the iptables.Check codepath as matchers are added. if netHeader.TransportProtocol() != header.UDPProtocolNumber { - log.Infof("UDPMatcher: wrong protocol number") return false, false } // We dont't match fragments. if frag := netHeader.FragmentOffset(); frag != 0 { - log.Infof("UDPMatcher: it's a fragment") if frag == 1 { return false, true } @@ -89,20 +81,18 @@ func (um *UDPMatcher) Match(hook Hook, pkt tcpip.PacketBuffer, interfaceName str // Now we need the transport header. However, this may not have been set // yet. - // TODO + // TODO(gvisor.dev/issue/170): Parsing the transport header should + // ultimately be moved into the iptables.Check codepath as matchers are + // added. var udpHeader header.UDP if pkt.TransportHeader != nil { - log.Infof("UDPMatcher: transport header is not nil") udpHeader = header.UDP(pkt.TransportHeader) } else { - log.Infof("UDPMatcher: transport header is nil") - log.Infof("UDPMatcher: is network header nil: %t", pkt.NetworkHeader == nil) // The UDP header hasn't been parsed yet. We have to do it here. if len(pkt.Data.First()) < header.UDPMinimumSize { // There's no valid UDP header here, so we hotdrop the // packet. - // TODO: Stats. - log.Warningf("Dropping UDP packet: size to small.") + log.Warningf("Dropping UDP packet: size too small.") return false, true } udpHeader = header.UDP(pkt.Data.First()) @@ -112,10 +102,6 @@ func (um *UDPMatcher) Match(hook Hook, pkt tcpip.PacketBuffer, interfaceName str // matching range. sourcePort := udpHeader.SourcePort() destinationPort := udpHeader.DestinationPort() - log.Infof("UDPMatcher: sport and dport are %d and %d. sports and dport start and end are (%d, %d) and (%d, %d)", - udpHeader.SourcePort(), udpHeader.DestinationPort(), - um.Data.SourcePortStart, um.Data.SourcePortEnd, - um.Data.DestinationPortStart, um.Data.DestinationPortEnd) if sourcePort < um.Data.SourcePortStart || um.Data.SourcePortEnd < sourcePort { return false, false } diff --git a/test/iptables/filter_input.go b/test/iptables/filter_input.go index bc963d40e..e9f0978eb 100644 --- a/test/iptables/filter_input.go +++ b/test/iptables/filter_input.go @@ -264,9 +264,9 @@ func (FilterInputMultiUDPRules) ContainerAction(ip net.IP) error { if err := filterTable("-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil { return err } - // if err := filterTable("-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", acceptPort), "-j", "ACCEPT"); err != nil { - // return err - // } + if err := filterTable("-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", acceptPort), "-j", "ACCEPT"); err != nil { + return err + } return filterTable("-L") } -- cgit v1.2.3 From e889f95671a9b7b1c6f65cb6fbc1b865a896e827 Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Mon, 27 Jan 2020 12:30:06 -0800 Subject: More cleanup. --- pkg/tcpip/iptables/udp_matcher.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'pkg/tcpip') diff --git a/pkg/tcpip/iptables/udp_matcher.go b/pkg/tcpip/iptables/udp_matcher.go index f59ca2027..3bb076f9c 100644 --- a/pkg/tcpip/iptables/udp_matcher.go +++ b/pkg/tcpip/iptables/udp_matcher.go @@ -73,9 +73,9 @@ func (um *UDPMatcher) Match(hook Hook, pkt tcpip.PacketBuffer, interfaceName str // We dont't match fragments. if frag := netHeader.FragmentOffset(); frag != 0 { if frag == 1 { + log.Warningf("Dropping UDP packet: malicious fragmented packet.") return false, true } - log.Warningf("Dropping UDP packet: malicious fragmented packet.") return false, false } -- cgit v1.2.3 From 51b783505b1ec164b02b48a0fd234509fba01a73 Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Wed, 29 Jan 2020 15:41:51 -0800 Subject: Add support for TCP_DEFER_ACCEPT. PiperOrigin-RevId: 292233574 --- pkg/sentry/socket/netstack/netstack.go | 22 ++++ pkg/tcpip/tcpip.go | 6 ++ pkg/tcpip/transport/tcp/BUILD | 1 + pkg/tcpip/transport/tcp/accept.go | 25 ++--- pkg/tcpip/transport/tcp/connect.go | 53 +++++++++- pkg/tcpip/transport/tcp/endpoint.go | 26 ++++- pkg/tcpip/transport/tcp/forwarder.go | 4 +- pkg/tcpip/transport/tcp/tcp_test.go | 126 ++++++++++++++++++++++ test/syscalls/linux/socket_inet_loopback.cc | 158 ++++++++++++++++++++++++++++ test/syscalls/linux/tcp_socket.cc | 53 ++++++++++ 10 files changed, 451 insertions(+), 23 deletions(-) (limited to 'pkg/tcpip') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 8619cc506..049d04bf2 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1260,6 +1260,18 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa return int32(time.Duration(v) / time.Second), nil + case linux.TCP_DEFER_ACCEPT: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.TCPDeferAcceptOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + return int32(time.Duration(v) / time.Second), nil + default: emitUnimplementedEventTCP(t, name) } @@ -1713,6 +1725,16 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * v := usermem.ByteOrder.Uint32(optVal) return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPLingerTimeoutOption(time.Second * time.Duration(v)))) + case linux.TCP_DEFER_ACCEPT: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + v := int32(usermem.ByteOrder.Uint32(optVal)) + if v < 0 { + v = 0 + } + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPDeferAcceptOption(time.Second * time.Duration(v)))) + case linux.TCP_REPAIR_OPTIONS: t.Kernel().EmitUnimplementedEvent(t) diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 59c9b3fb0..0fa141d58 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -626,6 +626,12 @@ type TCPLingerTimeoutOption time.Duration // before being marked closed. type TCPTimeWaitTimeoutOption time.Duration +// TCPDeferAcceptOption is used by SetSockOpt/GetSockOpt to allow a +// accept to return a completed connection only when there is data to be +// read. This usually means the listening socket will drop the final ACK +// for a handshake till the specified timeout until a segment with data arrives. +type TCPDeferAcceptOption time.Duration + // MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default // TTL value for multicast messages. The default is 1. type MulticastTTLOption uint8 diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 4acd9fb9a..7b4a87a2d 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -57,6 +57,7 @@ go_library( imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], visibility = ["//visibility:public"], deps = [ + "//pkg/log", "//pkg/rand", "//pkg/sleep", "//pkg/sync", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index d469758eb..6101f2945 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -222,13 +222,13 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu // createConnectingEndpoint creates a new endpoint in a connecting state, with // the connection parameters given by the arguments. -func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions) (*endpoint, *tcpip.Error) { +func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) { // Create a new endpoint. netProto := l.netProto if netProto == 0 { netProto = s.route.NetProto } - n := newEndpoint(l.stack, netProto, nil) + n := newEndpoint(l.stack, netProto, queue) n.v6only = l.v6only n.ID = s.id n.boundNICID = s.route.NICID() @@ -273,16 +273,17 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i // createEndpoint creates a new endpoint in connected state and then performs // the TCP 3-way handshake. -func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions) (*endpoint, *tcpip.Error) { +func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) { // Create new endpoint. irs := s.sequenceNumber isn := generateSecureISN(s.id, l.stack.Seed()) - ep, err := l.createConnectingEndpoint(s, isn, irs, opts) + ep, err := l.createConnectingEndpoint(s, isn, irs, opts, queue) if err != nil { return nil, err } // listenEP is nil when listenContext is used by tcp.Forwarder. + deferAccept := time.Duration(0) if l.listenEP != nil { l.listenEP.mu.Lock() if l.listenEP.EndpointState() != StateListen { @@ -290,13 +291,12 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head return nil, tcpip.ErrConnectionAborted } l.addPendingEndpoint(ep) + deferAccept = l.listenEP.deferAccept l.listenEP.mu.Unlock() } // Perform the 3-way handshake. - h := newHandshake(ep, seqnum.Size(ep.initialReceiveWindow())) - - h.resetToSynRcvd(isn, irs, opts) + h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept) if err := h.execute(); err != nil { ep.Close() if l.listenEP != nil { @@ -377,16 +377,14 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header defer e.decSynRcvdCount() defer s.decRef() - n, err := ctx.createEndpointAndPerformHandshake(s, opts) + n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{}) if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() return } ctx.removePendingEndpoint(n) - // Start the protocol goroutine. - wq := &waiter.Queue{} - n.startAcceptedLoop(wq) + n.startAcceptedLoop() e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() e.deliverAccepted(n) @@ -546,7 +544,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr } - n, err := ctx.createConnectingEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions) + n, err := ctx.createConnectingEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions, &waiter.Queue{}) if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() @@ -576,8 +574,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // space available in the backlog. // Start the protocol goroutine. - wq := &waiter.Queue{} - n.startAcceptedLoop(wq) + n.startAcceptedLoop() e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() go e.deliverAccepted(n) } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 4e3c5419c..9ff7ac261 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -86,6 +86,19 @@ type handshake struct { // rcvWndScale is the receive window scale, as defined in RFC 1323. rcvWndScale int + + // startTime is the time at which the first SYN/SYN-ACK was sent. + startTime time.Time + + // deferAccept if non-zero will drop the final ACK for a passive + // handshake till an ACK segment with data is received or the timeout is + // hit. + deferAccept time.Duration + + // acked is true if the the final ACK for a 3-way handshake has + // been received. This is required to stop retransmitting the + // original SYN-ACK when deferAccept is enabled. + acked bool } func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake { @@ -112,6 +125,12 @@ func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake { return h } +func newPassiveHandshake(ep *endpoint, rcvWnd seqnum.Size, isn, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) handshake { + h := newHandshake(ep, rcvWnd) + h.resetToSynRcvd(isn, irs, opts, deferAccept) + return h +} + // FindWndScale determines the window scale to use for the given maximum window // size. func FindWndScale(wnd seqnum.Size) int { @@ -181,7 +200,7 @@ func (h *handshake) effectiveRcvWndScale() uint8 { // resetToSynRcvd resets the state of the handshake object to the SYN-RCVD // state. -func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions) { +func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) { h.active = false h.state = handshakeSynRcvd h.flags = header.TCPFlagSyn | header.TCPFlagAck @@ -189,6 +208,7 @@ func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *hea h.ackNum = irs + 1 h.mss = opts.MSS h.sndWndScale = opts.WS + h.deferAccept = deferAccept h.ep.mu.Lock() h.ep.setEndpointState(StateSynRecv) h.ep.mu.Unlock() @@ -352,6 +372,14 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { // We have previously received (and acknowledged) the peer's SYN. If the // peer acknowledges our SYN, the handshake is completed. if s.flagIsSet(header.TCPFlagAck) { + // If deferAccept is not zero and this is a bare ACK and the + // timeout is not hit then drop the ACK. + if h.deferAccept != 0 && s.data.Size() == 0 && time.Since(h.startTime) < h.deferAccept { + h.acked = true + h.ep.stack.Stats().DroppedPackets.Increment() + return nil + } + // If the timestamp option is negotiated and the segment does // not carry a timestamp option then the segment must be dropped // as per https://tools.ietf.org/html/rfc7323#section-3.2. @@ -365,10 +393,16 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber) } h.state = handshakeCompleted + h.ep.mu.Lock() h.ep.transitionToStateEstablishedLocked(h) + // If the segment has data then requeue it for the receiver + // to process it again once main loop is started. + if s.data.Size() > 0 { + s.incRef() + h.ep.enqueueSegment(s) + } h.ep.mu.Unlock() - return nil } @@ -471,6 +505,7 @@ func (h *handshake) execute() *tcpip.Error { } } + h.startTime = time.Now() // Initialize the resend timer. resendWaker := sleep.Waker{} timeOut := time.Duration(time.Second) @@ -524,11 +559,21 @@ func (h *handshake) execute() *tcpip.Error { switch index, _ := s.Fetch(true); index { case wakerForResend: timeOut *= 2 - if timeOut > 60*time.Second { + if timeOut > MaxRTO { return tcpip.ErrTimeout } rt.Reset(timeOut) - h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + // Resend the SYN/SYN-ACK only if the following conditions hold. + // - It's an active handshake (deferAccept does not apply) + // - It's a passive handshake and we have not yet got the final-ACK. + // - It's a passive handshake and we got an ACK but deferAccept is + // enabled and we are now past the deferAccept duration. + // The last is required to provide a way for the peer to complete + // the connection with another ACK or data (as ACKs are never + // retransmitted on their own). + if h.active || !h.acked || h.deferAccept != 0 && time.Since(h.startTime) > h.deferAccept { + h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + } case wakerForNotification: n := h.ep.fetchNotifications() diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 13718ff55..8d52414b7 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -498,6 +498,13 @@ type endpoint struct { // without any data being acked. userTimeout time.Duration + // deferAccept if non-zero specifies a user specified time during + // which the final ACK of a handshake will be dropped provided the + // ACK is a bare ACK and carries no data. If the timeout is crossed then + // the bare ACK is accepted and the connection is delivered to the + // listener. + deferAccept time.Duration + // pendingAccepted is a synchronization primitive used to track number // of connections that are queued up to be delivered to the accepted // channel. We use this to ensure that all goroutines blocked on writing @@ -1574,6 +1581,15 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.mu.Unlock() return nil + case tcpip.TCPDeferAcceptOption: + e.mu.Lock() + if time.Duration(v) > MaxRTO { + v = tcpip.TCPDeferAcceptOption(MaxRTO) + } + e.deferAccept = time.Duration(v) + e.mu.Unlock() + return nil + default: return nil } @@ -1798,6 +1814,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.mu.Unlock() return nil + case *tcpip.TCPDeferAcceptOption: + e.mu.Lock() + *o = tcpip.TCPDeferAcceptOption(e.deferAccept) + e.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -2149,9 +2171,8 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { // startAcceptedLoop sets up required state and starts a goroutine with the // main loop for accepted connections. -func (e *endpoint) startAcceptedLoop(waiterQueue *waiter.Queue) { +func (e *endpoint) startAcceptedLoop() { e.mu.Lock() - e.waiterQueue = waiterQueue e.workerRunning = true e.mu.Unlock() wakerInitDone := make(chan struct{}) @@ -2177,7 +2198,6 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { default: return nil, nil, tcpip.ErrWouldBlock } - return n, n.waiterQueue, nil } diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 7eb613be5..c9ee5bf06 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -157,13 +157,13 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, TSVal: r.synOptions.TSVal, TSEcr: r.synOptions.TSEcr, SACKPermitted: r.synOptions.SACKPermitted, - }) + }, queue) if err != nil { return nil, err } // Start the protocol goroutine. - ep.startAcceptedLoop(queue) + ep.startAcceptedLoop() return ep, nil } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index df2fb1071..a12336d47 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -6787,3 +6787,129 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) { ), ) } + +func TestTCPDeferAccept(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10); err != nil { + t.Fatal("Listen failed:", err) + } + + const tcpDeferAccept = 1 * time.Second + if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil { + t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %v", tcpDeferAccept, err) + } + + irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) + + if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock { + t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: %s", err, tcpip.ErrWouldBlock) + } + + // Send data. This should result in an acceptable endpoint. + c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + }) + + // Receive ACK for the data we sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(iss+1)), + checker.AckNum(uint32(irs+5)))) + + // Give a bit of time for the socket to be delivered to the accept queue. + time.Sleep(50 * time.Millisecond) + aep, _, err := c.EP.Accept() + if err != nil { + t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: nil", err) + } + + aep.Close() + // Closing aep without reading the data should trigger a RST. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), + checker.SeqNum(uint32(iss+1)), + checker.AckNum(uint32(irs+5)))) +} + +func TestTCPDeferAcceptTimeout(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10); err != nil { + t.Fatal("Listen failed:", err) + } + + const tcpDeferAccept = 1 * time.Second + if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil { + t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %v", tcpDeferAccept, err) + } + + irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) + + if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock { + t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: %s", err, tcpip.ErrWouldBlock) + } + + // Sleep for a little of the tcpDeferAccept timeout. + time.Sleep(tcpDeferAccept + 100*time.Millisecond) + + // On timeout expiry we should get a SYN-ACK retransmission. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), + checker.AckNum(uint32(irs)+1))) + + // Send data. This should result in an acceptable endpoint. + c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + }) + + // Receive ACK for the data we sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(iss+1)), + checker.AckNum(uint32(irs+5)))) + + // Give sometime for the endpoint to be delivered to the accept queue. + time.Sleep(50 * time.Millisecond) + aep, _, err := c.EP.Accept() + if err != nil { + t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: nil", err) + } + + aep.Close() + // Closing aep without reading the data should trigger a RST. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), + checker.SeqNum(uint32(iss+1)), + checker.AckNum(uint32(irs+5)))) +} diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index 2f9821555..3bf7081b9 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -828,6 +828,164 @@ TEST_P(SocketInetLoopbackTest, AcceptedInheritsTCPUserTimeout) { EXPECT_EQ(get, kUserTimeout); } +// TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not +// saved. Enable S/R once issue is fixed. +TEST_P(SocketInetLoopbackTest, TCPDeferAccept_NoRandomSave) { + // TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not + // saved. Enable S/R issue is fixed. + DisableSave ds; + + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast(&listen_addr), &addrlen), + SyscallSucceeds()); + + const uint16_t port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Set the TCP_DEFER_ACCEPT on the listening socket. + constexpr int kTCPDeferAccept = 3; + ASSERT_THAT(setsockopt(listen_fd.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT, + &kTCPDeferAccept, sizeof(kTCPDeferAccept)), + SyscallSucceeds()); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), + reinterpret_cast(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + // Set the listening socket to nonblock so that we can verify that there is no + // connection in queue despite the connect above succeeding since the peer has + // sent no data and TCP_DEFER_ACCEPT is set on the listening socket. Set the + // FD to O_NONBLOCK. + int opts; + ASSERT_THAT(opts = fcntl(listen_fd.get(), F_GETFL), SyscallSucceeds()); + opts |= O_NONBLOCK; + ASSERT_THAT(fcntl(listen_fd.get(), F_SETFL, opts), SyscallSucceeds()); + + ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr), + SyscallFailsWithErrno(EWOULDBLOCK)); + + // Set FD back to blocking. + opts &= ~O_NONBLOCK; + ASSERT_THAT(fcntl(listen_fd.get(), F_SETFL, opts), SyscallSucceeds()); + + // Now write some data to the socket. + int data = 0; + ASSERT_THAT(RetryEINTR(write)(conn_fd.get(), &data, sizeof(data)), + SyscallSucceedsWithValue(sizeof(data))); + + // This should now cause the connection to complete and be delivered to the + // accept socket. + + // Accept the connection. + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + + // Verify that the accepted socket returns the data written. + int get = -1; + ASSERT_THAT(RetryEINTR(recv)(accepted.get(), &get, sizeof(get), 0), + SyscallSucceedsWithValue(sizeof(get))); + + EXPECT_EQ(get, data); +} + +// TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not +// saved. Enable S/R once issue is fixed. +TEST_P(SocketInetLoopbackTest, TCPDeferAcceptTimeout_NoRandomSave) { + // TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not + // saved. Enable S/R once issue is fixed. + DisableSave ds; + + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast(&listen_addr), &addrlen), + SyscallSucceeds()); + + const uint16_t port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Set the TCP_DEFER_ACCEPT on the listening socket. + constexpr int kTCPDeferAccept = 3; + ASSERT_THAT(setsockopt(listen_fd.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT, + &kTCPDeferAccept, sizeof(kTCPDeferAccept)), + SyscallSucceeds()); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), + reinterpret_cast(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + // Set the listening socket to nonblock so that we can verify that there is no + // connection in queue despite the connect above succeeding since the peer has + // sent no data and TCP_DEFER_ACCEPT is set on the listening socket. Set the + // FD to O_NONBLOCK. + int opts; + ASSERT_THAT(opts = fcntl(listen_fd.get(), F_GETFL), SyscallSucceeds()); + opts |= O_NONBLOCK; + ASSERT_THAT(fcntl(listen_fd.get(), F_SETFL, opts), SyscallSucceeds()); + + // Verify that there is no acceptable connection before TCP_DEFER_ACCEPT + // timeout is hit. + absl::SleepFor(absl::Seconds(kTCPDeferAccept - 1)); + ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr), + SyscallFailsWithErrno(EWOULDBLOCK)); + + // Set FD back to blocking. + opts &= ~O_NONBLOCK; + ASSERT_THAT(fcntl(listen_fd.get(), F_SETFL, opts), SyscallSucceeds()); + + // Now sleep for a little over the TCP_DEFER_ACCEPT duration. When the timeout + // is hit a SYN-ACK should be retransmitted by the listener as a last ditch + // attempt to complete the connection with or without data. + absl::SleepFor(absl::Seconds(2)); + + // Verify that we have a connection that can be accepted even though no + // data was written. + auto accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); +} + INSTANTIATE_TEST_SUITE_P( All, SocketInetLoopbackTest, ::testing::Values( diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index 33a5ac66c..525ccbd88 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -1286,6 +1286,59 @@ TEST_P(SimpleTcpSocketTest, SetTCPUserTimeout) { EXPECT_EQ(get, kTCPUserTimeout); } +TEST_P(SimpleTcpSocketTest, SetTCPDeferAcceptNeg) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + // -ve TCP_DEFER_ACCEPT is same as setting it to zero. + constexpr int kNeg = -1; + EXPECT_THAT( + setsockopt(s.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT, &kNeg, sizeof(kNeg)), + SyscallSucceeds()); + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(s.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, 0); +} + +TEST_P(SimpleTcpSocketTest, GetTCPDeferAcceptDefault) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(s.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, 0); +} + +TEST_P(SimpleTcpSocketTest, SetTCPDeferAcceptGreaterThanZero) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + // kTCPDeferAccept is in seconds. + // NOTE: linux translates seconds to # of retries and back from + // #of retries to seconds. Which means only certain values + // translate back exactly. That's why we use 3 here, a value of + // 5 will result in us getting back 7 instead of 5 in the + // getsockopt. + constexpr int kTCPDeferAccept = 3; + ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT, + &kTCPDeferAccept, sizeof(kTCPDeferAccept)), + SyscallSucceeds()); + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(s.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT, &get, &get_len), + SyscallSucceeds()); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kTCPDeferAccept); +} + INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest, ::testing::Values(AF_INET, AF_INET6)); -- cgit v1.2.3 From 6f841c304d7bd9af6167d7d049bd5c594358a1b9 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Wed, 29 Jan 2020 19:54:11 -0800 Subject: Do not spawn a goroutine when calling stack.NDPDispatcher's methods Do not start a new goroutine when calling stack.NDPDispatcher.OnDuplicateAddressDetectionStatus. PiperOrigin-RevId: 292268574 --- pkg/tcpip/stack/ndp.go | 8 ++++---- pkg/tcpip/stack/ndp_test.go | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) (limited to 'pkg/tcpip') diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 245694118..281ae786d 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -167,8 +167,8 @@ type NDPDispatcher interface { // reason, such as the address being removed). If an error occured // during DAD, err will be set and resolved must be ignored. // - // This function is permitted to block indefinitely without interfering - // with the stack's operation. + // This function is not permitted to block indefinitely. This function + // is also not permitted to call into the stack. OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) // OnDefaultRouterDiscovered will be called when a new default router is @@ -607,8 +607,8 @@ func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) { delete(ndp.dad, addr) // Let the integrator know DAD did not resolve. - if ndp.nic.stack.ndpDisp != nil { - go ndp.nic.stack.ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, false, nil) + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, false, nil) } } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 726468e41..8c76e80f2 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -497,7 +497,7 @@ func TestDADFail(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent), + dadC: make(chan ndpDADEvent, 1), } ndpConfigs := stack.DefaultNDPConfigurations() opts := stack.Options{ @@ -576,7 +576,7 @@ func TestDADFail(t *testing.T) { // removed. func TestDADStop(t *testing.T) { ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent), + dadC: make(chan ndpDADEvent, 1), } ndpConfigs := stack.NDPConfigurations{ RetransmitTimer: time.Second, -- cgit v1.2.3 From ec0679737e8f9ab31ef6c7c3adb5a0005586b5a7 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Thu, 30 Jan 2020 07:12:04 -0800 Subject: Do not include the Source Link Layer option with an unspecified source address When sending NDP messages with an unspecified source address, the Source Link Layer address must not be included. Test: stack_test.TestDADResolve PiperOrigin-RevId: 292341334 --- pkg/tcpip/stack/ndp.go | 22 ++-------------------- pkg/tcpip/stack/ndp_test.go | 12 ++++++++---- 2 files changed, 10 insertions(+), 24 deletions(-) (limited to 'pkg/tcpip') diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 281ae786d..31294345d 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -538,29 +538,11 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address) *tcpip.Error { r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false) defer r.Release() - linkAddr := ndp.nic.linkEP.LinkAddress() - isValidLinkAddr := header.IsValidUnicastEthernetAddress(linkAddr) - ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize - if isValidLinkAddr { - // Only include a Source Link Layer Address option if the NIC has a valid - // link layer address. - // - // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by - // LinkEndpoint.LinkAddress) before reaching this point. - ndpNSSize += header.NDPLinkLayerAddressSize - } - - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + ndpNSSize) - pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) + hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborSolicitMinimumSize) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize)) pkt.SetType(header.ICMPv6NeighborSolicit) ns := header.NDPNeighborSolicit(pkt.NDPPayload()) ns.SetTargetAddress(addr) - - if isValidLinkAddr { - ns.Options().Serialize(header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(linkAddr), - }) - } pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) sent := r.Stats().ICMP.V6PacketsSent diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 8c76e80f2..bc7cfbcb4 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -413,14 +413,18 @@ func TestDADResolve(t *testing.T) { t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) } - // Check NDP packet. + // Check NDP NS packet. + // + // As per RFC 4861 section 4.3, a possible option is the Source Link + // Layer option, but this option MUST NOT be included when the source + // address of the packet is the unspecified address. checker.IPv6(t, p.Pkt.Header.View().ToVectorisedView().First(), + checker.SrcAddr(header.IPv6Any), + checker.DstAddr(header.SolicitedNodeAddr(addr1)), checker.TTL(header.NDPHopLimit), checker.NDPNS( checker.NDPNSTargetAddress(addr1), - checker.NDPNSOptions([]header.NDPOption{ - header.NDPSourceLinkLayerAddressOption(linkAddr1), - }), + checker.NDPNSOptions(nil), )) } }) -- cgit v1.2.3 From 4ee64a248ec16fcc9e526a457a66648546611bfb Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Thu, 30 Jan 2020 11:48:36 -0800 Subject: Fix for panic in endpoint.Close(). When sending a RST on shutdown we need to double check the state after acquiring the work mutex as the endpoint could have transitioned out of a connected state from the time we checked it and we acquired the workMutex. I added two tests but sadly neither reproduce the panic. I am going to leave the tests in as they are good to have anyway. PiperOrigin-RevId: 292393800 --- pkg/tcpip/transport/tcp/BUILD | 1 + pkg/tcpip/transport/tcp/endpoint.go | 10 ++++- pkg/tcpip/transport/tcp/tcp_test.go | 55 ++++++++++++++++++++++++++++ test/syscalls/linux/BUILD | 1 + test/syscalls/linux/socket_ip_tcp_generic.cc | 33 +++++++++++++++++ 5 files changed, 98 insertions(+), 2 deletions(-) (limited to 'pkg/tcpip') diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 7b4a87a2d..272e8f570 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -91,6 +91,7 @@ go_test( tags = ["flaky"], deps = [ ":tcp", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 8d52414b7..b5a8e15ee 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -2047,8 +2047,14 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { // work mutex is available. if e.workMu.TryLock() { e.mu.Lock() - e.resetConnectionLocked(tcpip.ErrConnectionAborted) - e.notifyProtocolGoroutine(notifyTickleWorker) + // We need to double check here to make + // sure worker has not transitioned the + // endpoint out of a connected state + // before trying to send a reset. + if e.EndpointState().connected() { + e.resetConnectionLocked(tcpip.ErrConnectionAborted) + e.notifyProtocolGoroutine(notifyTickleWorker) + } e.mu.Unlock() e.workMu.Unlock() } else { diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index a12336d47..2c1505067 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -21,6 +21,7 @@ import ( "testing" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" @@ -6913,3 +6914,57 @@ func TestTCPDeferAcceptTimeout(t *testing.T) { checker.SeqNum(uint32(iss+1)), checker.AckNum(uint32(irs+5)))) } + +func TestResetDuringClose(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + iss := seqnum.Value(789) + c.CreateConnected(iss, 30000, -1 /* epRecvBuf */) + // Send some data to make sure there is some unread + // data to trigger a reset on c.Close. + irs := c.IRS + c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss.Add(1), + AckNum: irs.Add(1), + RcvWnd: 30000, + }) + + // Receive ACK for the data we sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(irs.Add(1))), + checker.AckNum(uint32(iss.Add(5))))) + + // Close in a separate goroutine so that we can trigger + // a race with the RST we send below. This should not + // panic due to the route being released depeding on + // whether Close() sends an active RST or the RST sent + // below is processed by the worker first. + var wg sync.WaitGroup + + wg.Add(1) + go func() { + defer wg.Done() + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + SeqNum: iss.Add(5), + AckNum: c.IRS.Add(5), + RcvWnd: 30000, + Flags: header.TCPFlagRst, + }) + }() + + wg.Add(1) + go func() { + defer wg.Done() + c.EP.Close() + }() + + wg.Wait() +} diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 74bf068ec..7958fd0d7 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -2173,6 +2173,7 @@ cc_library( ":socket_test_util", "//test/util:test_util", "//test/util:thread_util", + "@com_google_absl//absl/memory", "@com_google_absl//absl/time", "@com_google_googletest//:gtest", ], diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc index 57ce8e169..27779e47c 100644 --- a/test/syscalls/linux/socket_ip_tcp_generic.cc +++ b/test/syscalls/linux/socket_ip_tcp_generic.cc @@ -24,6 +24,7 @@ #include #include "gtest/gtest.h" +#include "absl/memory/memory.h" #include "absl/time/clock.h" #include "absl/time/time.h" #include "test/syscalls/linux/socket_test_util.h" @@ -875,5 +876,37 @@ TEST_P(TCPSocketPairTest, SetTCPUserTimeoutAboveZero) { EXPECT_EQ(get, kAbove); } +TEST_P(TCPSocketPairTest, TCPResetDuringClose_NoRandomSave) { + DisableSave ds; // Too many syscalls. + constexpr int kThreadCount = 1000; + std::unique_ptr instances[kThreadCount]; + for (int i = 0; i < kThreadCount; i++) { + instances[i] = absl::make_unique([&]() { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + ScopedThread t([&]() { + // Close one end to trigger sending of a FIN. + struct pollfd poll_fd = {sockets->second_fd(), POLLIN | POLLHUP, 0}; + // Wait up to 20 seconds for the data. + constexpr int kPollTimeoutMs = 20000; + ASSERT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs), + SyscallSucceedsWithValue(1)); + ASSERT_THAT(close(sockets->release_second_fd()), SyscallSucceeds()); + }); + + // Send some data then close. + constexpr char kStr[] = "abc"; + ASSERT_THAT(write(sockets->first_fd(), kStr, 3), + SyscallSucceedsWithValue(3)); + absl::SleepFor(absl::Milliseconds(10)); + ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds()); + t.Join(); + }); + } + for (int i = 0; i < kThreadCount; i++) { + instances[i]->Join(); + } +} + } // namespace testing } // namespace gvisor -- cgit v1.2.3 From 528dd1ec72fee1dd63c734fe92d1b972b5735b8f Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Fri, 31 Jan 2020 13:24:48 -0800 Subject: Extract multicast IP to Ethernet address mapping Test: header.TestEthernetAddressFromMulticastIPAddress PiperOrigin-RevId: 292604649 --- pkg/tcpip/header/eth.go | 41 +++++++++++++++++++++++++++++++++++++++++ pkg/tcpip/header/eth_test.go | 34 ++++++++++++++++++++++++++++++++++ pkg/tcpip/network/arp/arp.go | 19 ++----------------- pkg/tcpip/network/ipv6/icmp.go | 20 ++------------------ 4 files changed, 79 insertions(+), 35 deletions(-) (limited to 'pkg/tcpip') diff --git a/pkg/tcpip/header/eth.go b/pkg/tcpip/header/eth.go index f5d2c127f..b1e92d2d7 100644 --- a/pkg/tcpip/header/eth.go +++ b/pkg/tcpip/header/eth.go @@ -134,3 +134,44 @@ func IsValidUnicastEthernetAddress(addr tcpip.LinkAddress) bool { // addr is a valid unicast ethernet address. return true } + +// EthernetAddressFromMulticastIPv4Address returns a multicast Ethernet address +// for a multicast IPv4 address. +// +// addr MUST be a multicast IPv4 address. +func EthernetAddressFromMulticastIPv4Address(addr tcpip.Address) tcpip.LinkAddress { + var linkAddrBytes [EthernetAddressSize]byte + // RFC 1112 Host Extensions for IP Multicasting + // + // 6.4. Extensions to an Ethernet Local Network Module: + // + // An IP host group address is mapped to an Ethernet multicast + // address by placing the low-order 23-bits of the IP address + // into the low-order 23 bits of the Ethernet multicast address + // 01-00-5E-00-00-00 (hex). + linkAddrBytes[0] = 0x1 + linkAddrBytes[2] = 0x5e + linkAddrBytes[3] = addr[1] & 0x7F + copy(linkAddrBytes[4:], addr[IPv4AddressSize-2:]) + return tcpip.LinkAddress(linkAddrBytes[:]) +} + +// EthernetAddressFromMulticastIPv6Address returns a multicast Ethernet address +// for a multicast IPv6 address. +// +// addr MUST be a multicast IPv6 address. +func EthernetAddressFromMulticastIPv6Address(addr tcpip.Address) tcpip.LinkAddress { + // RFC 2464 Transmission of IPv6 Packets over Ethernet Networks + // + // 7. Address Mapping -- Multicast + // + // An IPv6 packet with a multicast destination address DST, + // consisting of the sixteen octets DST[1] through DST[16], is + // transmitted to the Ethernet multicast address whose first + // two octets are the value 3333 hexadecimal and whose last + // four octets are the last four octets of DST. + linkAddrBytes := []byte(addr[IPv6AddressSize-EthernetAddressSize:]) + linkAddrBytes[0] = 0x33 + linkAddrBytes[1] = 0x33 + return tcpip.LinkAddress(linkAddrBytes[:]) +} diff --git a/pkg/tcpip/header/eth_test.go b/pkg/tcpip/header/eth_test.go index 6634c90f5..7a0014ad9 100644 --- a/pkg/tcpip/header/eth_test.go +++ b/pkg/tcpip/header/eth_test.go @@ -66,3 +66,37 @@ func TestIsValidUnicastEthernetAddress(t *testing.T) { }) } } + +func TestEthernetAddressFromMulticastIPv4Address(t *testing.T) { + tests := []struct { + name string + addr tcpip.Address + expectedLinkAddr tcpip.LinkAddress + }{ + { + name: "IPv4 Multicast without 24th bit set", + addr: "\xe0\x7e\xdc\xba", + expectedLinkAddr: "\x01\x00\x5e\x7e\xdc\xba", + }, + { + name: "IPv4 Multicast with 24th bit set", + addr: "\xe0\xfe\xdc\xba", + expectedLinkAddr: "\x01\x00\x5e\x7e\xdc\xba", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if got := EthernetAddressFromMulticastIPv4Address(test.addr); got != test.expectedLinkAddr { + t.Fatalf("got EthernetAddressFromMulticastIPv4Address(%s) = %s, want = %s", got, test.expectedLinkAddr) + } + }) + } +} + +func TestEthernetAddressFromMulticastIPv6Address(t *testing.T) { + addr := tcpip.Address("\xff\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x1a") + if got, want := EthernetAddressFromMulticastIPv6Address(addr), tcpip.LinkAddress("\x33\x33\x0d\x0e\x0f\x1a"); got != want { + t.Fatalf("got EthernetAddressFromMulticastIPv6Address(%s) = %s, want = %s", addr, got, want) + } +} diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 1ceaebfbd..4da13c5df 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -178,24 +178,9 @@ func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo return broadcastMAC, true } if header.IsV4MulticastAddress(addr) { - // RFC 1112 Host Extensions for IP Multicasting - // - // 6.4. Extensions to an Ethernet Local Network Module: - // - // An IP host group address is mapped to an Ethernet multicast - // address by placing the low-order 23-bits of the IP address - // into the low-order 23 bits of the Ethernet multicast address - // 01-00-5E-00-00-00 (hex). - return tcpip.LinkAddress([]byte{ - 0x01, - 0x00, - 0x5e, - addr[header.IPv4AddressSize-3] & 0x7f, - addr[header.IPv4AddressSize-2], - addr[header.IPv4AddressSize-1], - }), true + return header.EthernetAddressFromMulticastIPv4Address(addr), true } - return "", false + return tcpip.LinkAddress([]byte(nil)), false } // SetOption implements NetworkProtocol. diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index dc20c0fd7..7491cfc41 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -441,23 +441,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. // ResolveStaticAddress implements stack.LinkAddressResolver. func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { if header.IsV6MulticastAddress(addr) { - // RFC 2464 Transmission of IPv6 Packets over Ethernet Networks - // - // 7. Address Mapping -- Multicast - // - // An IPv6 packet with a multicast destination address DST, - // consisting of the sixteen octets DST[1] through DST[16], is - // transmitted to the Ethernet multicast address whose first - // two octets are the value 3333 hexadecimal and whose last - // four octets are the last four octets of DST. - return tcpip.LinkAddress([]byte{ - 0x33, - 0x33, - addr[header.IPv6AddressSize-4], - addr[header.IPv6AddressSize-3], - addr[header.IPv6AddressSize-2], - addr[header.IPv6AddressSize-1], - }), true + return header.EthernetAddressFromMulticastIPv6Address(addr), true } - return "", false + return tcpip.LinkAddress([]byte(nil)), false } -- cgit v1.2.3 From eba7bdc24d31388ca81eeab251ed2db108f785dc Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Fri, 31 Jan 2020 13:46:13 -0800 Subject: iptables: enable TCP matching with "-m tcp". A couple other things that changed: - There's a proper extension registration system for matchers. Anyone adding another matcher can use tcp_matcher.go or udp_matcher.go as a template. - All logging and use of syserr.Error in the netfilter package happens at the highest possible level (public functions). Lower-level functions just return normal, descriptive golang errors. --- pkg/abi/linux/netfilter.go | 52 ++++++++ pkg/sentry/socket/netfilter/BUILD | 4 + pkg/sentry/socket/netfilter/extensions.go | 98 +++++++++++++++ pkg/sentry/socket/netfilter/netfilter.go | 187 ++++++++--------------------- pkg/sentry/socket/netfilter/tcp_matcher.go | 143 ++++++++++++++++++++++ pkg/sentry/socket/netfilter/udp_matcher.go | 142 ++++++++++++++++++++++ pkg/tcpip/iptables/BUILD | 1 - pkg/tcpip/iptables/types.go | 3 + pkg/tcpip/iptables/udp_matcher.go | 113 ----------------- 9 files changed, 495 insertions(+), 248 deletions(-) create mode 100644 pkg/sentry/socket/netfilter/extensions.go create mode 100644 pkg/sentry/socket/netfilter/tcp_matcher.go create mode 100644 pkg/sentry/socket/netfilter/udp_matcher.go delete mode 100644 pkg/tcpip/iptables/udp_matcher.go (limited to 'pkg/tcpip') diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go index 8e40bcc62..e4aabb6bb 100644 --- a/pkg/abi/linux/netfilter.go +++ b/pkg/abi/linux/netfilter.go @@ -348,6 +348,58 @@ func goString(cstring []byte) string { return string(cstring) } +// XTTCP holds data for matching TCP packets. It corresponds to struct xt_tcp +// in include/uapi/linux/netfilter/xt_tcpudp.h. +type XTTCP struct { + // SourcePortStart specifies the inclusive start of the range of source + // ports to which the matcher applies. + SourcePortStart uint16 + + // SourcePortEnd specifies the inclusive end of the range of source ports + // to which the matcher applies. + SourcePortEnd uint16 + + // DestinationPortStart specifies the start of the destination port + // range to which the matcher applies. + DestinationPortStart uint16 + + // DestinationPortEnd specifies the start of the destination port + // range to which the matcher applies. + DestinationPortEnd uint16 + + // Option specifies that a particular TCP option must be set. + Option uint8 + + // FlagMask masks the FlagCompare byte when comparing to the TCP flag + // fields. + FlagMask uint8 + + // FlagCompare is binary and-ed with the TCP flag fields. + FlagCompare uint8 + + // InverseFlags flips the meaning of certain fields. See the + // TX_TCP_INV_* flags. + InverseFlags uint8 +} + +// SizeOfXTTCP is the size of an XTTCP. +const SizeOfXTTCP = 12 + +// Flags in XTTCP.InverseFlags. Corresponding constants are in +// include/uapi/linux/netfilter/xt_tcpudp.h. +const ( + // Invert the meaning of SourcePortStart/End. + XT_TCP_INV_SRCPT = 0x01 + // Invert the meaning of DestinationPortStart/End. + XT_TCP_INV_DSTPT = 0x02 + // Invert the meaning of FlagCompare. + XT_TCP_INV_FLAGS = 0x04 + // Invert the meaning of Option. + XT_TCP_INV_OPTION = 0x08 + // Enable all flags. + XT_TCP_INV_MASK = 0x0F +) + // XTUDP holds data for matching UDP packets. It corresponds to struct xt_udp // in include/uapi/linux/netfilter/xt_tcpudp.h. type XTUDP struct { diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD index fa2a2cb66..c91ec7494 100644 --- a/pkg/sentry/socket/netfilter/BUILD +++ b/pkg/sentry/socket/netfilter/BUILD @@ -5,7 +5,10 @@ package(licenses = ["notice"]) go_library( name = "netfilter", srcs = [ + "extensions.go", "netfilter.go", + "tcp_matcher.go", + "udp_matcher.go", ], # This target depends on netstack and should only be used by epsocket, # which is allowed to depend on netstack. @@ -17,6 +20,7 @@ go_library( "//pkg/sentry/kernel", "//pkg/syserr", "//pkg/tcpip", + "//pkg/tcpip/header", "//pkg/tcpip/iptables", "//pkg/tcpip/stack", "//pkg/usermem", diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go new file mode 100644 index 000000000..5a4cac84c --- /dev/null +++ b/pkg/sentry/socket/netfilter/extensions.go @@ -0,0 +1,98 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package netfilter + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/tcpip/iptables" + "gvisor.dev/gvisor/pkg/usermem" +) + +// TODO(gvisor.dev/issue/170): The following per-matcher params should be +// supported: +// - Table name +// - Match size +// - User size +// - Hooks +// - Proto +// - Family + +// matchMarshaler knows how to (un)marshal the matcher named name(). +type matchMarshaler interface { + // name is the matcher name as stored in the xt_entry_match struct. + name() string + + // marshal converts from an iptables.Matcher to an ABI struct. + marshal(matcher iptables.Matcher) []byte + + // unmarshal converts from the ABI matcher struct to an + // iptables.Matcher. + unmarshal(buf []byte, filter iptables.IPHeaderFilter) (iptables.Matcher, error) +} + +var matchMarshalers = map[string]matchMarshaler{} + +// registerMatchMarshaler should be called by match extensions to register them +// with the netfilter package. +func registerMatchMarshaler(mm matchMarshaler) { + if _, ok := matchMarshalers[mm.name()]; ok { + panic(fmt.Sprintf("Multiple matches registered with name %q.", mm.name())) + } + matchMarshalers[mm.name()] = mm +} + +func marshalMatcher(matcher iptables.Matcher) []byte { + matchMaker, ok := matchMarshalers[matcher.Name()] + if !ok { + panic(fmt.Errorf("Unknown matcher of type %T.", matcher)) + } + return matchMaker.marshal(matcher) +} + +// marshalEntryMatch creates a marshalled XTEntryMatch with the given name and +// data appended at the end. +func marshalEntryMatch(name string, data []byte) []byte { + nflog("marshaling matcher %q", name) + + // We have to pad this struct size to a multiple of 8 bytes. + size := alignUp(linux.SizeOfXTEntryMatch+len(data), 8) + matcher := linux.KernelXTEntryMatch{ + XTEntryMatch: linux.XTEntryMatch{ + MatchSize: uint16(size), + }, + Data: data, + } + copy(matcher.Name[:], name) + + buf := make([]byte, 0, size) + buf = binary.Marshal(buf, usermem.ByteOrder, matcher) + return append(buf, make([]byte, size-len(buf))...) +} + +func unmarshalMatcher(match linux.XTEntryMatch, filter iptables.IPHeaderFilter, buf []byte) (iptables.Matcher, error) { + matchMaker, ok := matchMarshalers[match.Name.String()] + if !ok { + return nil, fmt.Errorf("unsupported matcher with name %q", match.Name.String()) + } + return matchMaker.unmarshal(buf, filter) +} + +// alignUp rounds a length up to an alignment. align must be a power of 2. +func alignUp(length int, align uint) int { + return (length + int(align) - 1) & ^(int(align) - 1) +} diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 3dda6c7a1..8f14643b0 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -17,6 +17,7 @@ package netfilter import ( + "errors" "fmt" "gvisor.dev/gvisor/pkg/abi/linux" @@ -34,10 +35,6 @@ import ( // shouldn't be reached - an error has occurred if we fall through to one. const errorTargetName = "ERROR" -const ( - matcherNameUDP = "udp" -) - // Metadata is used to verify that we are correctly serializing and // deserializing iptables into structs consumable by the iptables tool. We save // a metadata struct when the tables are written, and when they are read out we @@ -68,7 +65,8 @@ func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr) (linux.IPT // Find the appropriate table. table, err := findTable(stack, info.Name) if err != nil { - return linux.IPTGetinfo{}, err + nflog("%v", err) + return linux.IPTGetinfo{}, syserr.ErrInvalidArgument } // Get the hooks that apply to this table. @@ -95,39 +93,40 @@ func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen // Read in the struct and table name. var userEntries linux.IPTGetEntries if _, err := t.CopyIn(outPtr, &userEntries); err != nil { - log.Warningf("netfilter: couldn't copy in entries %q", userEntries.Name) + nflog("couldn't copy in entries %q", userEntries.Name) return linux.KernelIPTGetEntries{}, syserr.FromError(err) } // Find the appropriate table. table, err := findTable(stack, userEntries.Name) if err != nil { - log.Warningf("netfilter: couldn't find table %q", userEntries.Name) - return linux.KernelIPTGetEntries{}, err + nflog("%v", err) + return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument } // Convert netstack's iptables rules to something that the iptables // tool can understand. entries, meta, err := convertNetstackToBinary(userEntries.Name.String(), table) if err != nil { - return linux.KernelIPTGetEntries{}, err + nflog("couldn't read entries: %v", err) + return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument } if meta != table.Metadata().(metadata) { panic(fmt.Sprintf("Table %q metadata changed between writing and reading. Was saved as %+v, but is now %+v", userEntries.Name.String(), table.Metadata().(metadata), meta)) } if binary.Size(entries) > uintptr(outLen) { - log.Warningf("Insufficient GetEntries output size: %d", uintptr(outLen)) + nflog("insufficient GetEntries output size: %d", uintptr(outLen)) return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument } return entries, nil } -func findTable(stack *stack.Stack, tablename linux.TableName) (iptables.Table, *syserr.Error) { +func findTable(stack *stack.Stack, tablename linux.TableName) (iptables.Table, error) { ipt := stack.IPTables() table, ok := ipt.Tables[tablename.String()] if !ok { - return iptables.Table{}, syserr.ErrInvalidArgument + return iptables.Table{}, fmt.Errorf("couldn't find table %q", tablename) } return table, nil } @@ -151,19 +150,19 @@ func FillDefaultIPTables(stack *stack.Stack) { stack.SetIPTables(ipt) } +// TODO: Return proto. // convertNetstackToBinary converts the iptables as stored in netstack to the // format expected by the iptables tool. Linux stores each table as a binary // blob that can only be traversed by parsing a bit, reading some offsets, // jumping to those offsets, parsing again, etc. -func convertNetstackToBinary(tablename string, table iptables.Table) (linux.KernelIPTGetEntries, metadata, *syserr.Error) { +func convertNetstackToBinary(tablename string, table iptables.Table) (linux.KernelIPTGetEntries, metadata, error) { // Return values. var entries linux.KernelIPTGetEntries var meta metadata // The table name has to fit in the struct. if linux.XT_TABLE_MAXNAMELEN < len(tablename) { - log.Warningf("Table name %q too long.", tablename) - return linux.KernelIPTGetEntries{}, metadata{}, syserr.ErrInvalidArgument + return linux.KernelIPTGetEntries{}, metadata{}, fmt.Errorf("Table name %q too long.", tablename) } copy(entries.Name[:], tablename) @@ -229,46 +228,6 @@ func convertNetstackToBinary(tablename string, table iptables.Table) (linux.Kern return entries, meta, nil } -func marshalMatcher(matcher iptables.Matcher) []byte { - switch m := matcher.(type) { - case *iptables.UDPMatcher: - return marshalUDPMatcher(m) - default: - // TODO(gvisor.dev/issue/170): Support other matchers. - panic(fmt.Errorf("unknown matcher of type %T", matcher)) - } -} - -func marshalUDPMatcher(matcher *iptables.UDPMatcher) []byte { - nflog("convert to binary: marshalling UDP matcher: %+v", matcher) - - // We have to pad this struct size to a multiple of 8 bytes. - size := alignUp(linux.SizeOfXTEntryMatch+linux.SizeOfXTUDP, 8) - - linuxMatcher := linux.KernelXTEntryMatch{ - XTEntryMatch: linux.XTEntryMatch{ - MatchSize: uint16(size), - }, - Data: make([]byte, 0, linux.SizeOfXTUDP), - } - copy(linuxMatcher.Name[:], matcherNameUDP) - - xtudp := linux.XTUDP{ - SourcePortStart: matcher.Data.SourcePortStart, - SourcePortEnd: matcher.Data.SourcePortEnd, - DestinationPortStart: matcher.Data.DestinationPortStart, - DestinationPortEnd: matcher.Data.DestinationPortEnd, - InverseFlags: matcher.Data.InverseFlags, - } - linuxMatcher.Data = binary.Marshal(linuxMatcher.Data, usermem.ByteOrder, xtudp) - - buf := make([]byte, 0, size) - buf = binary.Marshal(buf, usermem.ByteOrder, linuxMatcher) - buf = append(buf, make([]byte, size-len(buf))...) - nflog("convert to binary: marshalled UDP matcher into %v", buf) - return buf[:] -} - func marshalTarget(target iptables.Target) []byte { switch target.(type) { case iptables.UnconditionalAcceptTarget: @@ -332,7 +291,7 @@ func translateFromStandardVerdict(verdict iptables.Verdict) int32 { // translateToStandardVerdict translates from the value in a // linux.XTStandardTarget to an iptables.Verdict. -func translateToStandardVerdict(val int32) (iptables.Verdict, *syserr.Error) { +func translateToStandardVerdict(val int32) (iptables.Verdict, error) { // TODO(gvisor.dev/issue/170): Support other verdicts. switch val { case -linux.NF_ACCEPT - 1: @@ -340,13 +299,12 @@ func translateToStandardVerdict(val int32) (iptables.Verdict, *syserr.Error) { case -linux.NF_DROP - 1: return iptables.Drop, nil case -linux.NF_QUEUE - 1: - log.Warningf("Unsupported iptables verdict QUEUE.") + return iptables.Invalid, errors.New("unsupported iptables verdict QUEUE") case linux.NF_RETURN: - log.Warningf("Unsupported iptables verdict RETURN.") + return iptables.Invalid, errors.New("unsupported iptables verdict RETURN") default: - log.Warningf("Unknown iptables verdict %d.", val) + return iptables.Invalid, fmt.Errorf("unknown iptables verdict %d.", val) } - return iptables.Invalid, syserr.ErrInvalidArgument } // SetEntries sets iptables rules for a single table. See @@ -354,7 +312,7 @@ func translateToStandardVerdict(val int32) (iptables.Verdict, *syserr.Error) { func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { // Get the basic rules data (struct ipt_replace). if len(optVal) < linux.SizeOfIPTReplace { - log.Warningf("netfilter.SetEntries: optVal has insufficient size for replace %d", len(optVal)) + nflog("optVal has insufficient size for replace %d", len(optVal)) return syserr.ErrInvalidArgument } var replace linux.IPTReplace @@ -368,7 +326,7 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { case iptables.TablenameFilter: table = iptables.EmptyFilterTable() default: - log.Warningf("We don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String()) + nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String()) return syserr.ErrInvalidArgument } @@ -382,7 +340,7 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { // Get the struct ipt_entry. if len(optVal) < linux.SizeOfIPTEntry { - log.Warningf("netfilter: optVal has insufficient size for entry %d", len(optVal)) + nflog("optVal has insufficient size for entry %d", len(optVal)) return syserr.ErrInvalidArgument } var entry linux.IPTEntry @@ -392,7 +350,7 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { optVal = optVal[linux.SizeOfIPTEntry:] if entry.TargetOffset < linux.SizeOfIPTEntry { - log.Warningf("netfilter: entry has too-small target offset %d", entry.TargetOffset) + nflog("entry has too-small target offset %d", entry.TargetOffset) return syserr.ErrInvalidArgument } @@ -400,7 +358,8 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { // filtering fields. filter, err := filterFromIPTIP(entry.IP) if err != nil { - return err + nflog("bad iptip: %v", err) + return syserr.ErrInvalidArgument } // TODO(gvisor.dev/issue/170): Matchers and targets can specify @@ -408,25 +367,26 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { // Get matchers. matchersSize := entry.TargetOffset - linux.SizeOfIPTEntry if len(optVal) < int(matchersSize) { - log.Warningf("netfilter: entry doesn't have enough room for its matchers (only %d bytes remain)", len(optVal)) + nflog("entry doesn't have enough room for its matchers (only %d bytes remain)", len(optVal)) return syserr.ErrInvalidArgument } matchers, err := parseMatchers(filter, optVal[:matchersSize]) if err != nil { - log.Warningf("netfilter: failed to parse matchers: %v", err) - return err + nflog("failed to parse matchers: %v", err) + return syserr.ErrInvalidArgument } optVal = optVal[matchersSize:] // Get the target of the rule. targetSize := entry.NextOffset - entry.TargetOffset if len(optVal) < int(targetSize) { - log.Warningf("netfilter: entry doesn't have enough room for its target (only %d bytes remain)", len(optVal)) + nflog("entry doesn't have enough room for its target (only %d bytes remain)", len(optVal)) return syserr.ErrInvalidArgument } target, err := parseTarget(optVal[:targetSize]) if err != nil { - return err + nflog("failed to parse target: %v", err) + return syserr.ErrInvalidArgument } optVal = optVal[targetSize:] @@ -439,7 +399,7 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { offset += uint32(entry.NextOffset) if initialOptValLen-len(optVal) != int(entry.NextOffset) { - log.Warningf("netfilter: entry NextOffset is %d, but entry took up %d bytes", entry.NextOffset, initialOptValLen-len(optVal)) + nflog("entry NextOffset is %d, but entry took up %d bytes", entry.NextOffset, initialOptValLen-len(optVal)) } } @@ -457,11 +417,11 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { } } if ruleIdx := table.BuiltinChains[hk]; ruleIdx == iptables.HookUnset { - log.Warningf("Hook %v is unset.", hk) + nflog("hook %v is unset.", hk) return syserr.ErrInvalidArgument } if ruleIdx := table.Underflows[hk]; ruleIdx == iptables.HookUnset { - log.Warningf("Underflow %v is unset.", hk) + nflog("underflow %v is unset.", hk) return syserr.ErrInvalidArgument } } @@ -473,7 +433,7 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { for hook, ruleIdx := range table.BuiltinChains { if hook != iptables.Input { if _, ok := table.Rules[ruleIdx].Target.(iptables.UnconditionalAcceptTarget); !ok { - log.Warningf("Hook %d is unsupported.", hook) + nflog("hook %d is unsupported.", hook) return syserr.ErrInvalidArgument } } @@ -499,7 +459,7 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { // parseMatchers parses 0 or more matchers from optVal. optVal should contain // only the matchers. -func parseMatchers(filter iptables.IPHeaderFilter, optVal []byte) ([]iptables.Matcher, *syserr.Error) { +func parseMatchers(filter iptables.IPHeaderFilter, optVal []byte) ([]iptables.Matcher, error) { nflog("set entries: parsing matchers of size %d", len(optVal)) var matchers []iptables.Matcher for len(optVal) > 0 { @@ -507,8 +467,7 @@ func parseMatchers(filter iptables.IPHeaderFilter, optVal []byte) ([]iptables.Ma // Get the XTEntryMatch. if len(optVal) < linux.SizeOfXTEntryMatch { - log.Warningf("netfilter: optVal has insufficient size for entry match: %d", len(optVal)) - return nil, syserr.ErrInvalidArgument + return nil, fmt.Errorf("optVal has insufficient size for entry match: %d", len(optVal)) } var match linux.XTEntryMatch buf := optVal[:linux.SizeOfXTEntryMatch] @@ -517,45 +476,18 @@ func parseMatchers(filter iptables.IPHeaderFilter, optVal []byte) ([]iptables.Ma // Check some invariants. if match.MatchSize < linux.SizeOfXTEntryMatch { - log.Warningf("netfilter: match size is too small, must be at least %d", linux.SizeOfXTEntryMatch) - return nil, syserr.ErrInvalidArgument + + return nil, fmt.Errorf("match size is too small, must be at least %d", linux.SizeOfXTEntryMatch) } if len(optVal) < int(match.MatchSize) { - log.Warningf("netfilter: optVal has insufficient size for match: %d", len(optVal)) - return nil, syserr.ErrInvalidArgument + return nil, fmt.Errorf("optVal has insufficient size for match: %d", len(optVal)) } - buf = optVal[linux.SizeOfXTEntryMatch:match.MatchSize] - var matcher iptables.Matcher - var err error - switch match.Name.String() { - case matcherNameUDP: - if len(buf) < linux.SizeOfXTUDP { - log.Warningf("netfilter: optVal has insufficient size for UDP match: %d", len(optVal)) - return nil, syserr.ErrInvalidArgument - } - // For alignment reasons, the match's total size may - // exceed what's strictly necessary to hold matchData. - var matchData linux.XTUDP - binary.Unmarshal(buf[:linux.SizeOfXTUDP], usermem.ByteOrder, &matchData) - log.Infof("parseMatchers: parsed XTUDP: %+v", matchData) - matcher, err = iptables.NewUDPMatcher(filter, iptables.UDPMatcherParams{ - SourcePortStart: matchData.SourcePortStart, - SourcePortEnd: matchData.SourcePortEnd, - DestinationPortStart: matchData.DestinationPortStart, - DestinationPortEnd: matchData.DestinationPortEnd, - InverseFlags: matchData.InverseFlags, - }) - if err != nil { - log.Warningf("netfilter: failed to create UDP matcher: %v", err) - return nil, syserr.ErrInvalidArgument - } - - default: - log.Warningf("netfilter: unsupported matcher with name %q", match.Name.String()) - return nil, syserr.ErrInvalidArgument + // Parse the specific matcher. + matcher, err := unmarshalMatcher(match, filter, optVal[linux.SizeOfXTEntryMatch:match.MatchSize]) + if err != nil { + return nil, fmt.Errorf("failed to create matcher: %v", err) } - matchers = append(matchers, matcher) // TODO(gvisor.dev/issue/170): Check the revision field. @@ -563,8 +495,7 @@ func parseMatchers(filter iptables.IPHeaderFilter, optVal []byte) ([]iptables.Ma } if len(optVal) != 0 { - log.Warningf("netfilter: optVal should be exhausted after parsing matchers") - return nil, syserr.ErrInvalidArgument + return nil, errors.New("optVal should be exhausted after parsing matchers") } return matchers, nil @@ -572,11 +503,10 @@ func parseMatchers(filter iptables.IPHeaderFilter, optVal []byte) ([]iptables.Ma // parseTarget parses a target from optVal. optVal should contain only the // target. -func parseTarget(optVal []byte) (iptables.Target, *syserr.Error) { +func parseTarget(optVal []byte) (iptables.Target, error) { nflog("set entries: parsing target of size %d", len(optVal)) if len(optVal) < linux.SizeOfXTEntryTarget { - log.Warningf("netfilter: optVal has insufficient size for entry target %d", len(optVal)) - return nil, syserr.ErrInvalidArgument + return nil, fmt.Errorf("optVal has insufficient size for entry target %d", len(optVal)) } var target linux.XTEntryTarget buf := optVal[:linux.SizeOfXTEntryTarget] @@ -585,8 +515,7 @@ func parseTarget(optVal []byte) (iptables.Target, *syserr.Error) { case "": // Standard target. if len(optVal) != linux.SizeOfXTStandardTarget { - log.Warningf("netfilter.SetEntries: optVal has wrong size for standard target %d", len(optVal)) - return nil, syserr.ErrInvalidArgument + return nil, fmt.Errorf("optVal has wrong size for standard target %d", len(optVal)) } var standardTarget linux.XTStandardTarget buf = optVal[:linux.SizeOfXTStandardTarget] @@ -602,15 +531,13 @@ func parseTarget(optVal []byte) (iptables.Target, *syserr.Error) { case iptables.Drop: return iptables.UnconditionalDropTarget{}, nil default: - log.Warningf("Unknown verdict: %v", verdict) - return nil, syserr.ErrInvalidArgument + return nil, fmt.Errorf("Unknown verdict: %v", verdict) } case errorTargetName: // Error target. if len(optVal) != linux.SizeOfXTErrorTarget { - log.Infof("netfilter.SetEntries: optVal has insufficient size for error target %d", len(optVal)) - return nil, syserr.ErrInvalidArgument + return nil, fmt.Errorf("optVal has insufficient size for error target %d", len(optVal)) } var errorTarget linux.XTErrorTarget buf = optVal[:linux.SizeOfXTErrorTarget] @@ -627,20 +554,17 @@ func parseTarget(optVal []byte) (iptables.Target, *syserr.Error) { case errorTargetName: return iptables.ErrorTarget{}, nil default: - log.Infof("Unknown error target %q doesn't exist or isn't supported yet.", errorTarget.Name.String()) - return nil, syserr.ErrInvalidArgument + return nil, fmt.Errorf("Unknown error target %q doesn't exist or isn't supported yet.", errorTarget.Name.String()) } } // Unknown target. - log.Infof("Unknown target %q doesn't exist or isn't supported yet.", target.Name.String()) - return nil, syserr.ErrInvalidArgument + return nil, fmt.Errorf("Unknown target %q doesn't exist or isn't supported yet.", target.Name.String()) } -func filterFromIPTIP(iptip linux.IPTIP) (iptables.IPHeaderFilter, *syserr.Error) { +func filterFromIPTIP(iptip linux.IPTIP) (iptables.IPHeaderFilter, error) { if containsUnsupportedFields(iptip) { - log.Warningf("netfilter: unsupported fields in struct iptip: %+v", iptip) - return iptables.IPHeaderFilter{}, syserr.ErrInvalidArgument + return iptables.IPHeaderFilter{}, fmt.Errorf("unsupported fields in struct iptip: %+v", iptip) } return iptables.IPHeaderFilter{ Protocol: tcpip.TransportProtocolNumber(iptip.Protocol), @@ -678,8 +602,3 @@ func hookFromLinux(hook int) iptables.Hook { } panic(fmt.Sprintf("Unknown hook %d does not correspond to a builtin chain", hook)) } - -// alignUp rounds a length up to an alignment. align must be a power of 2. -func alignUp(length int, align uint) int { - return (length + int(align) - 1) & ^(int(align) - 1) -} diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go new file mode 100644 index 000000000..1646d22f7 --- /dev/null +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -0,0 +1,143 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package netfilter + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/iptables" + "gvisor.dev/gvisor/pkg/usermem" +) + +const matcherNameTCP = "tcp" + +func init() { + registerMatchMarshaler(tcpMarshaler{}) +} + +// tcpMarshaler implements matchMarshaler for TCP matching. +type tcpMarshaler struct{} + +// name implements matchMarshaler.name. +func (tcpMarshaler) name() string { + return matcherNameTCP +} + +// marshal implements matchMarshaler.marshal. +func (tcpMarshaler) marshal(mr iptables.Matcher) []byte { + matcher := mr.(*TCPMatcher) + xttcp := linux.XTTCP{ + SourcePortStart: matcher.sourcePortStart, + SourcePortEnd: matcher.sourcePortEnd, + DestinationPortStart: matcher.destinationPortStart, + DestinationPortEnd: matcher.destinationPortEnd, + } + buf := make([]byte, 0, linux.SizeOfXTUDP) + return marshalEntryMatch(matcherNameTCP, binary.Marshal(buf, usermem.ByteOrder, xttcp)) +} + +// unmarshal implements matchMarshaler.unmarshal. +func (tcpMarshaler) unmarshal(buf []byte, filter iptables.IPHeaderFilter) (iptables.Matcher, error) { + if len(buf) < linux.SizeOfXTTCP { + return nil, fmt.Errorf("buf has insufficient size for TCP match: %d", len(buf)) + } + + // For alignment reasons, the match's total size may + // exceed what's strictly necessary to hold matchData. + var matchData linux.XTTCP + binary.Unmarshal(buf[:linux.SizeOfXTTCP], usermem.ByteOrder, &matchData) + nflog("parseMatchers: parsed XTTCP: %+v", matchData) + + if matchData.Option != 0 || + matchData.FlagMask != 0 || + matchData.FlagCompare != 0 || + matchData.InverseFlags != 0 { + return nil, fmt.Errorf("unsupported TCP matcher flags set") + } + + if filter.Protocol != header.TCPProtocolNumber { + return nil, fmt.Errorf("TCP matching is only valid for protocol %d.", header.TCPProtocolNumber) + } + + return &TCPMatcher{ + sourcePortStart: matchData.SourcePortStart, + sourcePortEnd: matchData.SourcePortEnd, + destinationPortStart: matchData.DestinationPortStart, + destinationPortEnd: matchData.DestinationPortEnd, + }, nil +} + +// TCPMatcher matches TCP packets and their headers. It implements Matcher. +type TCPMatcher struct { + sourcePortStart uint16 + sourcePortEnd uint16 + destinationPortStart uint16 + destinationPortEnd uint16 +} + +// Name implements Matcher.Name. +func (*TCPMatcher) Name() string { + return matcherNameTCP +} + +// Match implements Matcher.Match. +func (tm *TCPMatcher) Match(hook iptables.Hook, pkt tcpip.PacketBuffer, interfaceName string) (bool, bool) { + netHeader := header.IPv4(pkt.NetworkHeader) + + if netHeader.TransportProtocol() != header.TCPProtocolNumber { + return false, false + } + + // We dont't match fragments. + if frag := netHeader.FragmentOffset(); frag != 0 { + if frag == 1 { + return false, true + } + return false, false + } + + // Now we need the transport header. However, this may not have been set + // yet. + // TODO(gvisor.dev/issue/170): Parsing the transport header should + // ultimately be moved into the iptables.Check codepath as matchers are + // added. + var tcpHeader header.TCP + if pkt.TransportHeader != nil { + tcpHeader = header.TCP(pkt.TransportHeader) + } else { + // The TCP header hasn't been parsed yet. We have to do it here. + if len(pkt.Data.First()) < header.TCPMinimumSize { + // There's no valid TCP header here, so we hotdrop the + // packet. + return false, true + } + tcpHeader = header.TCP(pkt.Data.First()) + } + + // Check whether the source and destination ports are within the + // matching range. + if sourcePort := tcpHeader.SourcePort(); sourcePort < tm.sourcePortStart || tm.sourcePortEnd < sourcePort { + return false, false + } + if destinationPort := tcpHeader.DestinationPort(); destinationPort < tm.destinationPortStart || tm.destinationPortEnd < destinationPort { + return false, false + } + + return true, false +} diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go new file mode 100644 index 000000000..b6e95bbc5 --- /dev/null +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -0,0 +1,142 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package netfilter + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/iptables" + "gvisor.dev/gvisor/pkg/usermem" +) + +const matcherNameUDP = "udp" + +func init() { + registerMatchMarshaler(udpMarshaler{}) +} + +// udpMarshaler implements matchMarshaler for UDP matching. +type udpMarshaler struct{} + +// name implements matchMarshaler.name. +func (udpMarshaler) name() string { + return matcherNameUDP +} + +// marshal implements matchMarshaler.marshal. +func (udpMarshaler) marshal(mr iptables.Matcher) []byte { + matcher := mr.(*UDPMatcher) + xtudp := linux.XTUDP{ + SourcePortStart: matcher.sourcePortStart, + SourcePortEnd: matcher.sourcePortEnd, + DestinationPortStart: matcher.destinationPortStart, + DestinationPortEnd: matcher.destinationPortEnd, + } + buf := make([]byte, 0, linux.SizeOfXTUDP) + return marshalEntryMatch(matcherNameUDP, binary.Marshal(buf, usermem.ByteOrder, xtudp)) +} + +// unmarshal implements matchMarshaler.unmarshal. +func (udpMarshaler) unmarshal(buf []byte, filter iptables.IPHeaderFilter) (iptables.Matcher, error) { + if len(buf) < linux.SizeOfXTUDP { + return nil, fmt.Errorf("buf has insufficient size for UDP match: %d", len(buf)) + } + + // For alignment reasons, the match's total size may exceed what's + // strictly necessary to hold matchData. + var matchData linux.XTUDP + binary.Unmarshal(buf[:linux.SizeOfXTUDP], usermem.ByteOrder, &matchData) + nflog("parseMatchers: parsed XTUDP: %+v", matchData) + + if matchData.InverseFlags != 0 { + return nil, fmt.Errorf("unsupported UDP matcher inverse flags set") + } + + if filter.Protocol != header.UDPProtocolNumber { + return nil, fmt.Errorf("UDP matching is only valid for protocol %d.", header.UDPProtocolNumber) + } + + return &UDPMatcher{ + sourcePortStart: matchData.SourcePortStart, + sourcePortEnd: matchData.SourcePortEnd, + destinationPortStart: matchData.DestinationPortStart, + destinationPortEnd: matchData.DestinationPortEnd, + }, nil +} + +// UDPMatcher matches UDP packets and their headers. It implements Matcher. +type UDPMatcher struct { + sourcePortStart uint16 + sourcePortEnd uint16 + destinationPortStart uint16 + destinationPortEnd uint16 +} + +// Name implements Matcher.Name. +func (*UDPMatcher) Name() string { + return matcherNameUDP +} + +// Match implements Matcher.Match. +func (um *UDPMatcher) Match(hook iptables.Hook, pkt tcpip.PacketBuffer, interfaceName string) (bool, bool) { + netHeader := header.IPv4(pkt.NetworkHeader) + + // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved + // into the iptables.Check codepath as matchers are added. + if netHeader.TransportProtocol() != header.UDPProtocolNumber { + return false, false + } + + // We dont't match fragments. + if frag := netHeader.FragmentOffset(); frag != 0 { + if frag == 1 { + return false, true + } + return false, false + } + + // Now we need the transport header. However, this may not have been set + // yet. + // TODO(gvisor.dev/issue/170): Parsing the transport header should + // ultimately be moved into the iptables.Check codepath as matchers are + // added. + var udpHeader header.UDP + if pkt.TransportHeader != nil { + udpHeader = header.UDP(pkt.TransportHeader) + } else { + // The UDP header hasn't been parsed yet. We have to do it here. + if len(pkt.Data.First()) < header.UDPMinimumSize { + // There's no valid UDP header here, so we hotdrop the + // packet. + return false, true + } + udpHeader = header.UDP(pkt.Data.First()) + } + + // Check whether the source and destination ports are within the + // matching range. + if sourcePort := udpHeader.SourcePort(); sourcePort < um.sourcePortStart || um.sourcePortEnd < sourcePort { + return false, false + } + if destinationPort := udpHeader.DestinationPort(); destinationPort < um.destinationPortStart || um.destinationPortEnd < destinationPort { + return false, false + } + + return true, false +} diff --git a/pkg/tcpip/iptables/BUILD b/pkg/tcpip/iptables/BUILD index bab26580b..d1b73cfdf 100644 --- a/pkg/tcpip/iptables/BUILD +++ b/pkg/tcpip/iptables/BUILD @@ -8,7 +8,6 @@ go_library( "iptables.go", "targets.go", "types.go", - "udp_matcher.go", ], visibility = ["//visibility:public"], deps = [ diff --git a/pkg/tcpip/iptables/types.go b/pkg/tcpip/iptables/types.go index 7f77802a0..d660aab04 100644 --- a/pkg/tcpip/iptables/types.go +++ b/pkg/tcpip/iptables/types.go @@ -171,6 +171,9 @@ type IPHeaderFilter struct { // A Matcher is the interface for matching packets. type Matcher interface { + // Name returns the name of the Matcher. + Name() string + // Match returns whether the packet matches and whether the packet // should be "hotdropped", i.e. dropped immediately. This is usually // used for suspicious packets. diff --git a/pkg/tcpip/iptables/udp_matcher.go b/pkg/tcpip/iptables/udp_matcher.go deleted file mode 100644 index 3bb076f9c..000000000 --- a/pkg/tcpip/iptables/udp_matcher.go +++ /dev/null @@ -1,113 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package iptables - -import ( - "fmt" - - "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" -) - -// TODO(gvisor.dev/issue/170): The following per-matcher params should be -// supported: -// - Table name -// - Match size -// - User size -// - Hooks -// - Proto -// - Family - -// UDPMatcher matches UDP packets and their headers. It implements Matcher. -type UDPMatcher struct { - Data UDPMatcherParams -} - -// UDPMatcherParams are the parameters used to create a UDPMatcher. -type UDPMatcherParams struct { - SourcePortStart uint16 - SourcePortEnd uint16 - DestinationPortStart uint16 - DestinationPortEnd uint16 - InverseFlags uint8 -} - -// NewUDPMatcher returns a new instance of UDPMatcher. -func NewUDPMatcher(filter IPHeaderFilter, data UDPMatcherParams) (Matcher, error) { - log.Infof("Adding rule with UDPMatcherParams: %+v", data) - - if data.InverseFlags != 0 { - return nil, fmt.Errorf("unsupported UDP matcher inverse flags set") - } - - if filter.Protocol != header.UDPProtocolNumber { - return nil, fmt.Errorf("UDP matching is only valid for protocol %d.", header.UDPProtocolNumber) - } - - return &UDPMatcher{Data: data}, nil -} - -// Match implements Matcher.Match. -func (um *UDPMatcher) Match(hook Hook, pkt tcpip.PacketBuffer, interfaceName string) (bool, bool) { - netHeader := header.IPv4(pkt.NetworkHeader) - - // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved - // into the iptables.Check codepath as matchers are added. - if netHeader.TransportProtocol() != header.UDPProtocolNumber { - return false, false - } - - // We dont't match fragments. - if frag := netHeader.FragmentOffset(); frag != 0 { - if frag == 1 { - log.Warningf("Dropping UDP packet: malicious fragmented packet.") - return false, true - } - return false, false - } - - // Now we need the transport header. However, this may not have been set - // yet. - // TODO(gvisor.dev/issue/170): Parsing the transport header should - // ultimately be moved into the iptables.Check codepath as matchers are - // added. - var udpHeader header.UDP - if pkt.TransportHeader != nil { - udpHeader = header.UDP(pkt.TransportHeader) - } else { - // The UDP header hasn't been parsed yet. We have to do it here. - if len(pkt.Data.First()) < header.UDPMinimumSize { - // There's no valid UDP header here, so we hotdrop the - // packet. - log.Warningf("Dropping UDP packet: size too small.") - return false, true - } - udpHeader = header.UDP(pkt.Data.First()) - } - - // Check whether the source and destination ports are within the - // matching range. - sourcePort := udpHeader.SourcePort() - destinationPort := udpHeader.DestinationPort() - if sourcePort < um.Data.SourcePortStart || um.Data.SourcePortEnd < sourcePort { - return false, false - } - if destinationPort < um.Data.DestinationPortStart || um.Data.DestinationPortEnd < destinationPort { - return false, false - } - - return true, false -} -- cgit v1.2.3 From 77bf586db75b3dbd9dcb14c349bde8372d26425c Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Fri, 31 Jan 2020 13:54:57 -0800 Subject: Use multicast Ethernet address for multicast NDP As per RFC 2464 section 7, an IPv6 packet with a multicast destination address is transmitted to the mapped Ethernet multicast address. Test: - ipv6.TestLinkResolution - stack_test.TestDADResolve - stack_test.TestRouterSolicitation PiperOrigin-RevId: 292610529 --- pkg/tcpip/header/ipv6_test.go | 29 ++++++++++++++++++++++ pkg/tcpip/link/channel/channel.go | 29 ++++++++++++++-------- pkg/tcpip/network/ipv6/icmp.go | 6 ++++- pkg/tcpip/network/ipv6/icmp_test.go | 12 ++++++--- pkg/tcpip/stack/ndp.go | 17 +++++++++++++ pkg/tcpip/stack/ndp_test.go | 16 +++++++++++- pkg/tcpip/stack/route.go | 4 ++- pkg/tcpip/transport/tcp/testing/context/context.go | 6 ++++- 8 files changed, 101 insertions(+), 18 deletions(-) (limited to 'pkg/tcpip') diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go index 29f54bc57..c3ad503aa 100644 --- a/pkg/tcpip/header/ipv6_test.go +++ b/pkg/tcpip/header/ipv6_test.go @@ -17,6 +17,7 @@ package header_test import ( "bytes" "crypto/sha256" + "fmt" "testing" "github.com/google/go-cmp/cmp" @@ -300,3 +301,31 @@ func TestScopeForIPv6Address(t *testing.T) { }) } } + +func TestSolicitedNodeAddr(t *testing.T) { + tests := []struct { + addr tcpip.Address + want tcpip.Address + }{ + { + addr: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\xa0", + want: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff\x0e\x0f\xa0", + }, + { + addr: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\xdd\x0e\x0f\xa0", + want: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff\x0e\x0f\xa0", + }, + { + addr: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\xdd\x01\x02\x03", + want: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff\x01\x02\x03", + }, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("%s", test.addr), func(t *testing.T) { + if got := header.SolicitedNodeAddr(test.addr); got != test.want { + t.Fatalf("got header.SolicitedNodeAddr(%s) = %s, want = %s", test.addr, got, test.want) + } + }) + } +} diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index 71b9da797..78d447acd 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -30,15 +30,16 @@ type PacketInfo struct { Pkt tcpip.PacketBuffer Proto tcpip.NetworkProtocolNumber GSO *stack.GSO + Route stack.Route } // Endpoint is link layer endpoint that stores outbound packets in a channel // and allows injection of inbound packets. type Endpoint struct { - dispatcher stack.NetworkDispatcher - mtu uint32 - linkAddr tcpip.LinkAddress - GSO bool + dispatcher stack.NetworkDispatcher + mtu uint32 + linkAddr tcpip.LinkAddress + LinkEPCapabilities stack.LinkEndpointCapabilities // c is where outbound packets are queued. c chan PacketInfo @@ -122,11 +123,7 @@ func (e *Endpoint) MTU() uint32 { // Capabilities implements stack.LinkEndpoint.Capabilities. func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities { - caps := stack.LinkEndpointCapabilities(0) - if e.GSO { - caps |= stack.CapabilityHardwareGSO - } - return caps + return e.LinkEPCapabilities } // GSOMaxSize returns the maximum GSO packet size. @@ -146,11 +143,16 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { } // WritePacket stores outbound packets into the channel. -func (e *Endpoint) WritePacket(_ *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { +func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) *tcpip.Error { + // Clone r then release its resource so we only get the relevant fields from + // stack.Route without holding a reference to a NIC's endpoint. + route := r.Clone() + route.Release() p := PacketInfo{ Pkt: pkt, Proto: protocol, GSO: gso, + Route: route, } select { @@ -162,7 +164,11 @@ func (e *Endpoint) WritePacket(_ *stack.Route, gso *stack.GSO, protocol tcpip.Ne } // WritePackets stores outbound packets into the channel. -func (e *Endpoint) WritePackets(_ *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.PacketBuffer, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + // Clone r then release its resource so we only get the relevant fields from + // stack.Route without holding a reference to a NIC's endpoint. + route := r.Clone() + route.Release() payloadView := pkts[0].Data.ToView() n := 0 packetLoop: @@ -176,6 +182,7 @@ packetLoop: }, Proto: protocol, GSO: gso, + Route: route, } select { diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 7491cfc41..60817d36d 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -408,10 +408,14 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { // LinkAddressRequest implements stack.LinkAddressResolver. func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error { snaddr := header.SolicitedNodeAddr(addr) + + // TODO(b/148672031): Use stack.FindRoute instead of manually creating the + // route here. Note, we would need the nicID to do this properly so the right + // NIC (associated to linkEP) is used to send the NDP NS message. r := &stack.Route{ LocalAddress: localAddr, RemoteAddress: snaddr, - RemoteLinkAddress: broadcastMAC, + RemoteLinkAddress: header.EthernetAddressFromMulticastIPv6Address(snaddr), } hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize) pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 7a6820643..d0e930e20 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -270,8 +270,9 @@ func (c *testContext) cleanup() { } type routeArgs struct { - src, dst *channel.Endpoint - typ header.ICMPv6Type + src, dst *channel.Endpoint + typ header.ICMPv6Type + remoteLinkAddr tcpip.LinkAddress } func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.ICMPv6)) { @@ -292,6 +293,11 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header. t.Errorf("unexpected protocol number %d", pi.Proto) return } + + if len(args.remoteLinkAddr) != 0 && args.remoteLinkAddr != pi.Route.RemoteLinkAddress { + t.Errorf("got remote link address = %s, want = %s", pi.Route.RemoteLinkAddress, args.remoteLinkAddr) + } + ipv6 := header.IPv6(pi.Pkt.Header.View()) transProto := tcpip.TransportProtocolNumber(ipv6.NextHeader()) if transProto != header.ICMPv6ProtocolNumber { @@ -339,7 +345,7 @@ func TestLinkResolution(t *testing.T) { t.Fatalf("ep.Write(_) = _, , %s, want = _, , tcpip.ErrNoLinkAddress", err) } for _, args := range []routeArgs{ - {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit}, + {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit, remoteLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.SolicitedNodeAddr(lladdr1))}, {src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6NeighborAdvert}, } { routeICMPv6Packet(t, args, func(t *testing.T, icmpv6 header.ICMPv6) { diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 31294345d..6123fda33 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -538,6 +538,14 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address) *tcpip.Error { r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false) defer r.Release() + // Route should resolve immediately since snmc is a multicast address so a + // remote link address can be calculated without a resolution process. + if c, err := r.Resolve(nil); err != nil { + log.Fatalf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.nic.ID(), err) + } else if c != nil { + log.Fatalf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.nic.ID()) + } + hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborSolicitMinimumSize) pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize)) pkt.SetType(header.ICMPv6NeighborSolicit) @@ -1197,6 +1205,15 @@ func (ndp *ndpState) startSolicitingRouters() { r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false) defer r.Release() + // Route should resolve immediately since + // header.IPv6AllRoutersMulticastAddress is a multicast address so a + // remote link address can be calculated without a resolution process. + if c, err := r.Resolve(nil); err != nil { + log.Fatalf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID(), err) + } else if c != nil { + log.Fatalf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID()) + } + payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadSize) pkt := header.ICMPv6(hdr.Prepend(payloadSize)) diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index bc7cfbcb4..8af8565f7 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -336,6 +336,7 @@ func TestDADResolve(t *testing.T) { opts.NDPConfigs.DupAddrDetectTransmits = test.dupAddrDetectTransmits e := channel.New(int(test.dupAddrDetectTransmits), 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired s := stack.New(opts) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -413,6 +414,12 @@ func TestDADResolve(t *testing.T) { t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) } + // Make sure the right remote link address is used. + snmc := header.SolicitedNodeAddr(addr1) + if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want { + t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) + } + // Check NDP NS packet. // // As per RFC 4861 section 4.3, a possible option is the Source Link @@ -420,7 +427,7 @@ func TestDADResolve(t *testing.T) { // address of the packet is the unspecified address. checker.IPv6(t, p.Pkt.Header.View().ToVectorisedView().First(), checker.SrcAddr(header.IPv6Any), - checker.DstAddr(header.SolicitedNodeAddr(addr1)), + checker.DstAddr(snmc), checker.TTL(header.NDPHopLimit), checker.NDPNS( checker.NDPNSTargetAddress(addr1), @@ -3292,6 +3299,7 @@ func TestRouterSolicitation(t *testing.T) { t.Run(test.name, func(t *testing.T) { t.Parallel() e := channel.New(int(test.maxRtrSolicit), 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired waitForPkt := func(timeout time.Duration) { t.Helper() ctx, _ := context.WithTimeout(context.Background(), timeout) @@ -3304,6 +3312,12 @@ func TestRouterSolicitation(t *testing.T) { if p.Proto != header.IPv6ProtocolNumber { t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) } + + // Make sure the right remote link address is used. + if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); p.Route.RemoteLinkAddress != want { + t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) + } + checker.IPv6(t, p.Pkt.Header.View(), checker.SrcAddr(header.IPv6Any), diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 517f4b941..f565aafb2 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -225,7 +225,9 @@ func (r *Route) Release() { // Clone Clone a route such that the original one can be released and the new // one will remain valid. func (r *Route) Clone() Route { - r.ref.incRef() + if r.ref != nil { + r.ref.incRef() + } return *r } diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 730ac4292..1e9a0dea3 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -1082,7 +1082,11 @@ func (c *Context) SACKEnabled() bool { // SetGSOEnabled enables or disables generic segmentation offload. func (c *Context) SetGSOEnabled(enable bool) { - c.linkEP.GSO = enable + if enable { + c.linkEP.LinkEPCapabilities |= stack.CapabilityHardwareGSO + } else { + c.linkEP.LinkEPCapabilities &^= stack.CapabilityHardwareGSO + } } // MSSWithoutOptions returns the value for the MSS used by the stack when no -- cgit v1.2.3 From 02997af5abd62d778fca4d01b047a6bdebab2090 Mon Sep 17 00:00:00 2001 From: Ian Gudger Date: Fri, 31 Jan 2020 15:08:17 -0800 Subject: Fix method comment to match method name. PiperOrigin-RevId: 292624867 --- pkg/tcpip/transport/tcp/accept.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'pkg/tcpip') diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 6101f2945..08afb7c17 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -271,8 +271,8 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i return n, nil } -// createEndpoint creates a new endpoint in connected state and then performs -// the TCP 3-way handshake. +// createEndpointAndPerformHandshake creates a new endpoint in connected state +// and then performs the TCP 3-way handshake. func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) { // Create new endpoint. irs := s.sequenceNumber -- cgit v1.2.3 From a26a954946ad2e7910d3ad7578960a93b73a1f9b Mon Sep 17 00:00:00 2001 From: Ian Gudger Date: Tue, 4 Feb 2020 15:20:30 -0800 Subject: Add socket connection stress test. Tests 65k connection attempts on common types of sockets to check for port leaks. Also fixes a bug where dual-stack sockets wouldn't properly re-queue segments received while closing. PiperOrigin-RevId: 293241166 --- pkg/tcpip/transport/tcp/connect.go | 4 ++ test/syscalls/BUILD | 9 +++ test/syscalls/linux/BUILD | 17 +++++ test/syscalls/linux/ip_socket_test_util.cc | 27 +++++++ test/syscalls/linux/ip_socket_test_util.h | 15 ++++ test/syscalls/linux/socket_generic_stress.cc | 83 ++++++++++++++++++++++ test/syscalls/linux/socket_test_util.cc | 101 ++++++++++++++++++++++++--- test/syscalls/linux/socket_test_util.h | 6 ++ 8 files changed, 251 insertions(+), 11 deletions(-) create mode 100644 test/syscalls/linux/socket_generic_stress.cc (limited to 'pkg/tcpip') diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 9ff7ac261..5c5397823 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -989,6 +989,10 @@ func (e *endpoint) transitionToStateCloseLocked() { // to any other listening endpoint. We reply with RST if we cannot find one. func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) { ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, &s.route) + if ep == nil && e.NetProto == header.IPv6ProtocolNumber && e.EndpointInfo.TransportEndpointInfo.ID.LocalAddress.To4() != "" { + // Dual-stack socket, try IPv4. + ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, &s.route) + } if ep == nil { replyWithReset(s) s.decRef() diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index 8f2b75a1c..31d239c0e 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -45,6 +45,15 @@ syscall_test(test = "//test/syscalls/linux:brk_test") syscall_test(test = "//test/syscalls/linux:socket_test") +syscall_test( + size = "large", + shard_count = 50, + # Takes too long for TSAN. Since this is kind of a stress test that doesn't + # involve much concurrency, TSAN's usefulness here is limited anyway. + tags = ["nogotsan"], + test = "//test/syscalls/linux:socket_stress_test", +) + syscall_test( add_overlay = True, test = "//test/syscalls/linux:chdir_test", diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 737e2329f..273b014d6 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -136,6 +136,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/time", + "@com_google_absl//absl/types:optional", "//test/util:file_descriptor", "//test/util:posix_error", "//test/util:temp_path", @@ -2151,6 +2152,22 @@ cc_library( alwayslink = 1, ) +cc_binary( + name = "socket_stress_test", + testonly = 1, + srcs = [ + "socket_generic_stress.cc", + ], + linkstatic = 1, + deps = [ + ":ip_socket_test_util", + ":socket_test_util", + gtest, + "//test/util:test_main", + "//test/util:test_util", + ], +) + cc_library( name = "socket_unix_dgram_test_cases", testonly = 1, diff --git a/test/syscalls/linux/ip_socket_test_util.cc b/test/syscalls/linux/ip_socket_test_util.cc index 6b472eb2f..bba022a41 100644 --- a/test/syscalls/linux/ip_socket_test_util.cc +++ b/test/syscalls/linux/ip_socket_test_util.cc @@ -79,6 +79,33 @@ SocketPairKind DualStackTCPAcceptBindSocketPair(int type) { /* dual_stack = */ true)}; } +SocketPairKind IPv6TCPAcceptBindPersistentListenerSocketPair(int type) { + std::string description = + absl::StrCat(DescribeSocketType(type), "connected IPv6 TCP socket"); + return SocketPairKind{description, AF_INET6, type | SOCK_STREAM, IPPROTO_TCP, + TCPAcceptBindPersistentListenerSocketPairCreator( + AF_INET6, type | SOCK_STREAM, 0, + /* dual_stack = */ false)}; +} + +SocketPairKind IPv4TCPAcceptBindPersistentListenerSocketPair(int type) { + std::string description = + absl::StrCat(DescribeSocketType(type), "connected IPv4 TCP socket"); + return SocketPairKind{description, AF_INET, type | SOCK_STREAM, IPPROTO_TCP, + TCPAcceptBindPersistentListenerSocketPairCreator( + AF_INET, type | SOCK_STREAM, 0, + /* dual_stack = */ false)}; +} + +SocketPairKind DualStackTCPAcceptBindPersistentListenerSocketPair(int type) { + std::string description = + absl::StrCat(DescribeSocketType(type), "connected dual stack TCP socket"); + return SocketPairKind{description, AF_INET6, type | SOCK_STREAM, IPPROTO_TCP, + TCPAcceptBindPersistentListenerSocketPairCreator( + AF_INET6, type | SOCK_STREAM, 0, + /* dual_stack = */ true)}; +} + SocketPairKind IPv6UDPBidirectionalBindSocketPair(int type) { std::string description = absl::StrCat(DescribeSocketType(type), "connected IPv6 UDP socket"); diff --git a/test/syscalls/linux/ip_socket_test_util.h b/test/syscalls/linux/ip_socket_test_util.h index 0f58e0f77..083ebbcf0 100644 --- a/test/syscalls/linux/ip_socket_test_util.h +++ b/test/syscalls/linux/ip_socket_test_util.h @@ -50,6 +50,21 @@ SocketPairKind IPv4TCPAcceptBindSocketPair(int type); // given type bound to the IPv4 loopback. SocketPairKind DualStackTCPAcceptBindSocketPair(int type); +// IPv6TCPAcceptBindPersistentListenerSocketPair is like +// IPv6TCPAcceptBindSocketPair except it uses a persistent listening socket to +// create all socket pairs. +SocketPairKind IPv6TCPAcceptBindPersistentListenerSocketPair(int type); + +// IPv4TCPAcceptBindPersistentListenerSocketPair is like +// IPv4TCPAcceptBindSocketPair except it uses a persistent listening socket to +// create all socket pairs. +SocketPairKind IPv4TCPAcceptBindPersistentListenerSocketPair(int type); + +// DualStackTCPAcceptBindPersistentListenerSocketPair is like +// DualStackTCPAcceptBindSocketPair except it uses a persistent listening socket +// to create all socket pairs. +SocketPairKind DualStackTCPAcceptBindPersistentListenerSocketPair(int type); + // IPv6UDPBidirectionalBindSocketPair returns a SocketPairKind that represents // SocketPairs created with bind() and connect() syscalls with AF_INET6 and the // given type bound to the IPv6 loopback. diff --git a/test/syscalls/linux/socket_generic_stress.cc b/test/syscalls/linux/socket_generic_stress.cc new file mode 100644 index 000000000..6a232238d --- /dev/null +++ b/test/syscalls/linux/socket_generic_stress.cc @@ -0,0 +1,83 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "gtest/gtest.h" +#include "test/syscalls/linux/ip_socket_test_util.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +// Test fixture for tests that apply to pairs of connected sockets. +using ConnectStressTest = SocketPairTest; + +TEST_P(ConnectStressTest, Reset65kTimes) { + for (int i = 0; i < 1 << 16; ++i) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Send some data to ensure that the connection gets reset and the port gets + // released immediately. This avoids either end entering TIME-WAIT. + char sent_data[100] = {}; + ASSERT_THAT(write(sockets->first_fd(), sent_data, sizeof(sent_data)), + SyscallSucceedsWithValue(sizeof(sent_data))); + } +} + +INSTANTIATE_TEST_SUITE_P( + AllConnectedSockets, ConnectStressTest, + ::testing::Values(IPv6UDPBidirectionalBindSocketPair(0), + IPv4UDPBidirectionalBindSocketPair(0), + DualStackUDPBidirectionalBindSocketPair(0), + + // Without REUSEADDR, we get port exhaustion on Linux. + SetSockOpt(SOL_SOCKET, SO_REUSEADDR, + &kSockOptOn)(IPv6TCPAcceptBindSocketPair(0)), + SetSockOpt(SOL_SOCKET, SO_REUSEADDR, + &kSockOptOn)(IPv4TCPAcceptBindSocketPair(0)), + SetSockOpt(SOL_SOCKET, SO_REUSEADDR, &kSockOptOn)( + DualStackTCPAcceptBindSocketPair(0)))); + +// Test fixture for tests that apply to pairs of connected sockets created with +// a persistent listener (if applicable). +using PersistentListenerConnectStressTest = SocketPairTest; + +TEST_P(PersistentListenerConnectStressTest, 65kTimes) { + for (int i = 0; i < 1 << 16; ++i) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + } +} + +INSTANTIATE_TEST_SUITE_P( + AllConnectedSockets, PersistentListenerConnectStressTest, + ::testing::Values( + IPv6UDPBidirectionalBindSocketPair(0), + IPv4UDPBidirectionalBindSocketPair(0), + DualStackUDPBidirectionalBindSocketPair(0), + + // Without REUSEADDR, we get port exhaustion on Linux. + SetSockOpt(SOL_SOCKET, SO_REUSEADDR, &kSockOptOn)( + IPv6TCPAcceptBindPersistentListenerSocketPair(0)), + SetSockOpt(SOL_SOCKET, SO_REUSEADDR, &kSockOptOn)( + IPv4TCPAcceptBindPersistentListenerSocketPair(0)), + SetSockOpt(SOL_SOCKET, SO_REUSEADDR, &kSockOptOn)( + DualStackTCPAcceptBindPersistentListenerSocketPair(0)))); + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_test_util.cc b/test/syscalls/linux/socket_test_util.cc index eff7d577e..c0c5ab3fe 100644 --- a/test/syscalls/linux/socket_test_util.cc +++ b/test/syscalls/linux/socket_test_util.cc @@ -18,10 +18,13 @@ #include #include +#include + #include "gtest/gtest.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/time/clock.h" +#include "absl/types/optional.h" #include "test/util/file_descriptor.h" #include "test/util/posix_error.h" #include "test/util/temp_path.h" @@ -109,7 +112,10 @@ Creator AcceptBindSocketPairCreator(bool abstract, int domain, MaybeSave(); // Unlinked path. } - return absl::make_unique(connected, accepted, bind_addr, + // accepted is before connected to destruct connected before accepted. + // Destructors for nonstatic member objects are called in the reverse order + // in which they appear in the class declaration. + return absl::make_unique(accepted, connected, bind_addr, extra_addr); }; } @@ -311,11 +317,16 @@ PosixErrorOr BindIP(int fd, bool dual_stack) { } template -PosixErrorOr> CreateTCPAcceptBindSocketPair( - int bound, int connected, int type, bool dual_stack) { - ASSIGN_OR_RETURN_ERRNO(T bind_addr, BindIP(bound, dual_stack)); - RETURN_ERROR_IF_SYSCALL_FAIL(listen(bound, /* backlog = */ 5)); +PosixErrorOr TCPBindAndListen(int fd, bool dual_stack) { + ASSIGN_OR_RETURN_ERRNO(T addr, BindIP(fd, dual_stack)); + RETURN_ERROR_IF_SYSCALL_FAIL(listen(fd, /* backlog = */ 5)); + return addr; +} +template +PosixErrorOr> +CreateTCPConnectAcceptSocketPair(int bound, int connected, int type, + bool dual_stack, T bind_addr) { int connect_result = 0; RETURN_ERROR_IF_SYSCALL_FAIL( (connect_result = RetryEINTR(connect)( @@ -358,16 +369,27 @@ PosixErrorOr> CreateTCPAcceptBindSocketPair( absl::SleepFor(absl::Seconds(1)); } - // Cleanup no longer needed resources. - RETURN_ERROR_IF_SYSCALL_FAIL(close(bound)); - MaybeSave(); // Successful close. - T extra_addr = {}; LocalhostAddr(&extra_addr, dual_stack); return absl::make_unique(connected, accepted, bind_addr, extra_addr); } +template +PosixErrorOr> CreateTCPAcceptBindSocketPair( + int bound, int connected, int type, bool dual_stack) { + ASSIGN_OR_RETURN_ERRNO(T bind_addr, TCPBindAndListen(bound, dual_stack)); + + auto result = CreateTCPConnectAcceptSocketPair(bound, connected, type, + dual_stack, bind_addr); + + // Cleanup no longer needed resources. + RETURN_ERROR_IF_SYSCALL_FAIL(close(bound)); + MaybeSave(); // Successful close. + + return result; +} + Creator TCPAcceptBindSocketPairCreator(int domain, int type, int protocol, bool dual_stack) { @@ -389,6 +411,63 @@ Creator TCPAcceptBindSocketPairCreator(int domain, int type, }; } +Creator TCPAcceptBindPersistentListenerSocketPairCreator( + int domain, int type, int protocol, bool dual_stack) { + // These are lazily initialized below, on the first call to the returned + // lambda. These values are private to each returned lambda, but shared across + // invocations of a specific lambda. + // + // The sharing allows pairs created with the same parameters to share a + // listener. This prevents future connects from failing if the connecting + // socket selects a port which had previously been used by a listening socket + // that still has some connections in TIME-WAIT. + // + // The lazy initialization is to avoid creating sockets during parameter + // enumeration. This is important because parameters are enumerated during the + // build process where networking may not be available. + auto listener = std::make_shared>(absl::optional()); + auto addr4 = std::make_shared>( + absl::optional()); + auto addr6 = std::make_shared>( + absl::optional()); + + return [=]() -> PosixErrorOr> { + int connected; + RETURN_ERROR_IF_SYSCALL_FAIL(connected = socket(domain, type, protocol)); + MaybeSave(); // Successful socket creation. + + // Share the listener across invocations. + if (!listener->has_value()) { + int fd = socket(domain, type, protocol); + if (fd < 0) { + return PosixError(errno, absl::StrCat("socket(", domain, ", ", type, + ", ", protocol, ")")); + } + listener->emplace(fd); + MaybeSave(); // Successful socket creation. + } + + // Bind the listener once, but create a new connect/accept pair each + // time. + if (domain == AF_INET) { + if (!addr4->has_value()) { + addr4->emplace( + TCPBindAndListen(listener->value(), dual_stack) + .ValueOrDie()); + } + return CreateTCPConnectAcceptSocketPair(listener->value(), connected, + type, dual_stack, addr4->value()); + } + if (!addr6->has_value()) { + addr6->emplace( + TCPBindAndListen(listener->value(), dual_stack) + .ValueOrDie()); + } + return CreateTCPConnectAcceptSocketPair(listener->value(), connected, type, + dual_stack, addr6->value()); + }; +} + template PosixErrorOr> CreateUDPBoundSocketPair( int sock1, int sock2, int type, bool dual_stack) { @@ -518,8 +597,8 @@ size_t CalculateUnixSockAddrLen(const char* sun_path) { if (sun_path[0] == 0) { return sizeof(sockaddr_un); } - // Filesystem addresses use the address length plus the 2 byte sun_family and - // null terminator. + // Filesystem addresses use the address length plus the 2 byte sun_family + // and null terminator. return strlen(sun_path) + 3; } diff --git a/test/syscalls/linux/socket_test_util.h b/test/syscalls/linux/socket_test_util.h index 2dbb8bed3..bfaa6e397 100644 --- a/test/syscalls/linux/socket_test_util.h +++ b/test/syscalls/linux/socket_test_util.h @@ -273,6 +273,12 @@ Creator TCPAcceptBindSocketPairCreator(int domain, int type, int protocol, bool dual_stack); +// TCPAcceptBindPersistentListenerSocketPairCreator is like +// TCPAcceptBindSocketPairCreator, except it uses the same listening socket to +// create all SocketPairs. +Creator TCPAcceptBindPersistentListenerSocketPairCreator( + int domain, int type, int protocol, bool dual_stack); + // UDPBidirectionalBindSocketPairCreator returns a Creator that // obtains file descriptors by invoking the bind() and connect() syscalls on UDP // sockets. -- cgit v1.2.3 From 665b614e4a6e715bac25bea15c5c29184016e549 Mon Sep 17 00:00:00 2001 From: Ting-Yu Wang Date: Tue, 4 Feb 2020 18:04:26 -0800 Subject: Support RTM_NEWADDR and RTM_GETLINK in (rt)netlink. PiperOrigin-RevId: 293271055 --- pkg/sentry/inet/inet.go | 4 + pkg/sentry/inet/test_stack.go | 6 + pkg/sentry/socket/hostinet/stack.go | 5 + pkg/sentry/socket/netlink/BUILD | 14 +- pkg/sentry/socket/netlink/message.go | 129 +++++++++++ pkg/sentry/socket/netlink/message_test.go | 312 +++++++++++++++++++++++++++ pkg/sentry/socket/netlink/provider.go | 2 +- pkg/sentry/socket/netlink/route/BUILD | 2 - pkg/sentry/socket/netlink/route/protocol.go | 238 ++++++++++++++------ pkg/sentry/socket/netlink/socket.go | 54 ++--- pkg/sentry/socket/netlink/uevent/protocol.go | 2 +- pkg/sentry/socket/netstack/stack.go | 55 +++++ pkg/tcpip/stack/stack.go | 9 + test/syscalls/linux/BUILD | 2 + test/syscalls/linux/socket_netlink_route.cc | 296 ++++++++++++++++++++----- test/syscalls/linux/socket_netlink_util.cc | 45 +++- test/syscalls/linux/socket_netlink_util.h | 9 + 17 files changed, 1022 insertions(+), 162 deletions(-) create mode 100644 pkg/sentry/socket/netlink/message_test.go (limited to 'pkg/tcpip') diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go index a7dfb78a7..2916a0644 100644 --- a/pkg/sentry/inet/inet.go +++ b/pkg/sentry/inet/inet.go @@ -28,6 +28,10 @@ type Stack interface { // interface indexes to a slice of associated interface address properties. InterfaceAddrs() map[int32][]InterfaceAddr + // AddInterfaceAddr adds an address to the network interface identified by + // index. + AddInterfaceAddr(idx int32, addr InterfaceAddr) error + // SupportsIPv6 returns true if the stack supports IPv6 connectivity. SupportsIPv6() bool diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go index dcfcbd97e..d8961fc94 100644 --- a/pkg/sentry/inet/test_stack.go +++ b/pkg/sentry/inet/test_stack.go @@ -47,6 +47,12 @@ func (s *TestStack) InterfaceAddrs() map[int32][]InterfaceAddr { return s.InterfaceAddrsMap } +// AddInterfaceAddr implements Stack.AddInterfaceAddr. +func (s *TestStack) AddInterfaceAddr(idx int32, addr InterfaceAddr) error { + s.InterfaceAddrsMap[idx] = append(s.InterfaceAddrsMap[idx], addr) + return nil +} + // SupportsIPv6 implements Stack.SupportsIPv6. func (s *TestStack) SupportsIPv6() bool { return s.SupportsIPv6Flag diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index 034eca676..a48082631 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -310,6 +310,11 @@ func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr { return addrs } +// AddInterfaceAddr implements inet.Stack.AddInterfaceAddr. +func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { + return syserror.EACCES +} + // SupportsIPv6 implements inet.Stack.SupportsIPv6. func (s *Stack) SupportsIPv6() bool { return s.supportsIPv6 diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index f8b8e467d..1911cd9b8 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -33,3 +33,15 @@ go_library( "//pkg/waiter", ], ) + +go_test( + name = "netlink_test", + size = "small", + srcs = [ + "message_test.go", + ], + deps = [ + ":netlink", + "//pkg/abi/linux", + ], +) diff --git a/pkg/sentry/socket/netlink/message.go b/pkg/sentry/socket/netlink/message.go index b21e0ca4b..4ea252ccb 100644 --- a/pkg/sentry/socket/netlink/message.go +++ b/pkg/sentry/socket/netlink/message.go @@ -30,8 +30,16 @@ func alignUp(length int, align uint) int { return (length + int(align) - 1) &^ (int(align) - 1) } +// alignPad returns the length of padding required for alignment. +// +// Preconditions: align is a power of two. +func alignPad(length int, align uint) int { + return alignUp(length, align) - length +} + // Message contains a complete serialized netlink message. type Message struct { + hdr linux.NetlinkMessageHeader buf []byte } @@ -40,10 +48,86 @@ type Message struct { // The header length will be updated by Finalize. func NewMessage(hdr linux.NetlinkMessageHeader) *Message { return &Message{ + hdr: hdr, buf: binary.Marshal(nil, usermem.ByteOrder, hdr), } } +// ParseMessage parses the first message seen at buf, returning the rest of the +// buffer. If message is malformed, ok of false is returned. For last message, +// padding check is loose, if there isn't enought padding, whole buf is consumed +// and ok is set to true. +func ParseMessage(buf []byte) (msg *Message, rest []byte, ok bool) { + b := BytesView(buf) + + hdrBytes, ok := b.Extract(linux.NetlinkMessageHeaderSize) + if !ok { + return + } + var hdr linux.NetlinkMessageHeader + binary.Unmarshal(hdrBytes, usermem.ByteOrder, &hdr) + + // Msg portion. + totalMsgLen := int(hdr.Length) + _, ok = b.Extract(totalMsgLen - linux.NetlinkMessageHeaderSize) + if !ok { + return + } + + // Padding. + numPad := alignPad(totalMsgLen, linux.NLMSG_ALIGNTO) + // Linux permits the last message not being aligned, just consume all of it. + // Ref: net/netlink/af_netlink.c:netlink_rcv_skb + if numPad > len(b) { + numPad = len(b) + } + _, ok = b.Extract(numPad) + if !ok { + return + } + + return &Message{ + hdr: hdr, + buf: buf[:totalMsgLen], + }, []byte(b), true +} + +// Header returns the header of this message. +func (m *Message) Header() linux.NetlinkMessageHeader { + return m.hdr +} + +// GetData unmarshals the payload message header from this netlink message, and +// returns the attributes portion. +func (m *Message) GetData(msg interface{}) (AttrsView, bool) { + b := BytesView(m.buf) + + _, ok := b.Extract(linux.NetlinkMessageHeaderSize) + if !ok { + return nil, false + } + + size := int(binary.Size(msg)) + msgBytes, ok := b.Extract(size) + if !ok { + return nil, false + } + binary.Unmarshal(msgBytes, usermem.ByteOrder, msg) + + numPad := alignPad(linux.NetlinkMessageHeaderSize+size, linux.NLMSG_ALIGNTO) + // Linux permits the last message not being aligned, just consume all of it. + // Ref: net/netlink/af_netlink.c:netlink_rcv_skb + if numPad > len(b) { + numPad = len(b) + } + _, ok = b.Extract(numPad) + if !ok { + return nil, false + } + + return AttrsView(b), true +} + // Finalize returns the []byte containing the entire message, with the total // length set in the message header. The Message must not be modified after // calling Finalize. @@ -157,3 +241,48 @@ func (ms *MessageSet) AddMessage(hdr linux.NetlinkMessageHeader) *Message { ms.Messages = append(ms.Messages, m) return m } + +// AttrsView is a view into the attributes portion of a netlink message. +type AttrsView []byte + +// Empty returns whether there is no attribute left in v. +func (v AttrsView) Empty() bool { + return len(v) == 0 +} + +// ParseFirst parses first netlink attribute at the beginning of v. +func (v AttrsView) ParseFirst() (hdr linux.NetlinkAttrHeader, value []byte, rest AttrsView, ok bool) { + b := BytesView(v) + + hdrBytes, ok := b.Extract(linux.NetlinkAttrHeaderSize) + if !ok { + return + } + binary.Unmarshal(hdrBytes, usermem.ByteOrder, &hdr) + + value, ok = b.Extract(int(hdr.Length) - linux.NetlinkAttrHeaderSize) + if !ok { + return + } + + _, ok = b.Extract(alignPad(int(hdr.Length), linux.NLA_ALIGNTO)) + if !ok { + return + } + + return hdr, value, AttrsView(b), ok +} + +// BytesView supports extracting data from a byte slice with bounds checking. +type BytesView []byte + +// Extract removes the first n bytes from v and returns it. If n is out of +// bounds, it returns false. +func (v *BytesView) Extract(n int) ([]byte, bool) { + if n < 0 || n > len(*v) { + return nil, false + } + extracted := (*v)[:n] + *v = (*v)[n:] + return extracted, true +} diff --git a/pkg/sentry/socket/netlink/message_test.go b/pkg/sentry/socket/netlink/message_test.go new file mode 100644 index 000000000..ef13d9386 --- /dev/null +++ b/pkg/sentry/socket/netlink/message_test.go @@ -0,0 +1,312 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package message_test + +import ( + "bytes" + "reflect" + "testing" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/socket/netlink" +) + +type dummyNetlinkMsg struct { + Foo uint16 +} + +func TestParseMessage(t *testing.T) { + tests := []struct { + desc string + input []byte + + header linux.NetlinkMessageHeader + dataMsg *dummyNetlinkMsg + restLen int + ok bool + }{ + { + desc: "valid", + input: []byte{ + 0x14, 0x00, 0x00, 0x00, // Length + 0x01, 0x00, // Type + 0x02, 0x00, // Flags + 0x03, 0x00, 0x00, 0x00, // Seq + 0x04, 0x00, 0x00, 0x00, // PortID + 0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding + }, + header: linux.NetlinkMessageHeader{ + Length: 20, + Type: 1, + Flags: 2, + Seq: 3, + PortID: 4, + }, + dataMsg: &dummyNetlinkMsg{ + Foo: 0x3130, + }, + restLen: 0, + ok: true, + }, + { + desc: "valid with next message", + input: []byte{ + 0x14, 0x00, 0x00, 0x00, // Length + 0x01, 0x00, // Type + 0x02, 0x00, // Flags + 0x03, 0x00, 0x00, 0x00, // Seq + 0x04, 0x00, 0x00, 0x00, // PortID + 0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding + 0xFF, // Next message (rest) + }, + header: linux.NetlinkMessageHeader{ + Length: 20, + Type: 1, + Flags: 2, + Seq: 3, + PortID: 4, + }, + dataMsg: &dummyNetlinkMsg{ + Foo: 0x3130, + }, + restLen: 1, + ok: true, + }, + { + desc: "valid for last message without padding", + input: []byte{ + 0x12, 0x00, 0x00, 0x00, // Length + 0x01, 0x00, // Type + 0x02, 0x00, // Flags + 0x03, 0x00, 0x00, 0x00, // Seq + 0x04, 0x00, 0x00, 0x00, // PortID + 0x30, 0x31, // Data message + }, + header: linux.NetlinkMessageHeader{ + Length: 18, + Type: 1, + Flags: 2, + Seq: 3, + PortID: 4, + }, + dataMsg: &dummyNetlinkMsg{ + Foo: 0x3130, + }, + restLen: 0, + ok: true, + }, + { + desc: "valid for last message not to be aligned", + input: []byte{ + 0x13, 0x00, 0x00, 0x00, // Length + 0x01, 0x00, // Type + 0x02, 0x00, // Flags + 0x03, 0x00, 0x00, 0x00, // Seq + 0x04, 0x00, 0x00, 0x00, // PortID + 0x30, 0x31, // Data message + 0x00, // Excessive 1 byte permitted at end + }, + header: linux.NetlinkMessageHeader{ + Length: 19, + Type: 1, + Flags: 2, + Seq: 3, + PortID: 4, + }, + dataMsg: &dummyNetlinkMsg{ + Foo: 0x3130, + }, + restLen: 0, + ok: true, + }, + { + desc: "header.Length too short", + input: []byte{ + 0x04, 0x00, 0x00, 0x00, // Length + 0x01, 0x00, // Type + 0x02, 0x00, // Flags + 0x03, 0x00, 0x00, 0x00, // Seq + 0x04, 0x00, 0x00, 0x00, // PortID + 0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding + }, + ok: false, + }, + { + desc: "header.Length too long", + input: []byte{ + 0xFF, 0xFF, 0x00, 0x00, // Length + 0x01, 0x00, // Type + 0x02, 0x00, // Flags + 0x03, 0x00, 0x00, 0x00, // Seq + 0x04, 0x00, 0x00, 0x00, // PortID + 0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding + }, + ok: false, + }, + { + desc: "header incomplete", + input: []byte{ + 0x04, 0x00, 0x00, 0x00, // Length + }, + ok: false, + }, + { + desc: "empty message", + input: []byte{}, + ok: false, + }, + } + for _, test := range tests { + msg, rest, ok := netlink.ParseMessage(test.input) + if ok != test.ok { + t.Errorf("%v: got ok = %v, want = %v", test.desc, ok, test.ok) + continue + } + if !test.ok { + continue + } + if !reflect.DeepEqual(msg.Header(), test.header) { + t.Errorf("%v: got hdr = %+v, want = %+v", test.desc, msg.Header(), test.header) + } + + dataMsg := &dummyNetlinkMsg{} + _, dataOk := msg.GetData(dataMsg) + if !dataOk { + t.Errorf("%v: GetData.ok = %v, want = true", test.desc, dataOk) + } else if !reflect.DeepEqual(dataMsg, test.dataMsg) { + t.Errorf("%v: GetData.msg = %+v, want = %+v", test.desc, dataMsg, test.dataMsg) + } + + if got, want := rest, test.input[len(test.input)-test.restLen:]; !bytes.Equal(got, want) { + t.Errorf("%v: got rest = %v, want = %v", test.desc, got, want) + } + } +} + +func TestAttrView(t *testing.T) { + tests := []struct { + desc string + input []byte + + // Outputs for ParseFirst. + hdr linux.NetlinkAttrHeader + value []byte + restLen int + ok bool + + // Outputs for Empty. + isEmpty bool + }{ + { + desc: "valid", + input: []byte{ + 0x06, 0x00, // Length + 0x01, 0x00, // Type + 0x30, 0x31, 0x00, 0x00, // Data with 2 bytes padding + }, + hdr: linux.NetlinkAttrHeader{ + Length: 6, + Type: 1, + }, + value: []byte{0x30, 0x31}, + restLen: 0, + ok: true, + isEmpty: false, + }, + { + desc: "at alignment", + input: []byte{ + 0x08, 0x00, // Length + 0x01, 0x00, // Type + 0x30, 0x31, 0x32, 0x33, // Data + }, + hdr: linux.NetlinkAttrHeader{ + Length: 8, + Type: 1, + }, + value: []byte{0x30, 0x31, 0x32, 0x33}, + restLen: 0, + ok: true, + isEmpty: false, + }, + { + desc: "at alignment with rest data", + input: []byte{ + 0x08, 0x00, // Length + 0x01, 0x00, // Type + 0x30, 0x31, 0x32, 0x33, // Data + 0xFF, 0xFE, // Rest data + }, + hdr: linux.NetlinkAttrHeader{ + Length: 8, + Type: 1, + }, + value: []byte{0x30, 0x31, 0x32, 0x33}, + restLen: 2, + ok: true, + isEmpty: false, + }, + { + desc: "hdr.Length too long", + input: []byte{ + 0xFF, 0x00, // Length + 0x01, 0x00, // Type + 0x30, 0x31, 0x32, 0x33, // Data + }, + ok: false, + isEmpty: false, + }, + { + desc: "hdr.Length too short", + input: []byte{ + 0x01, 0x00, // Length + 0x01, 0x00, // Type + 0x30, 0x31, 0x32, 0x33, // Data + }, + ok: false, + isEmpty: false, + }, + { + desc: "empty", + input: []byte{}, + ok: false, + isEmpty: true, + }, + } + for _, test := range tests { + attrs := netlink.AttrsView(test.input) + + // Test ParseFirst(). + hdr, value, rest, ok := attrs.ParseFirst() + if ok != test.ok { + t.Errorf("%v: got ok = %v, want = %v", test.desc, ok, test.ok) + } else if test.ok { + if !reflect.DeepEqual(hdr, test.hdr) { + t.Errorf("%v: got hdr = %+v, want = %+v", test.desc, hdr, test.hdr) + } + if !bytes.Equal(value, test.value) { + t.Errorf("%v: got value = %v, want = %v", test.desc, value, test.value) + } + if wantRest := test.input[len(test.input)-test.restLen:]; !bytes.Equal(rest, wantRest) { + t.Errorf("%v: got rest = %v, want = %v", test.desc, rest, wantRest) + } + } + + // Test Empty(). + if got, want := attrs.Empty(), test.isEmpty; got != want { + t.Errorf("%v: got empty = %v, want = %v", test.desc, got, want) + } + } +} diff --git a/pkg/sentry/socket/netlink/provider.go b/pkg/sentry/socket/netlink/provider.go index 07f860a49..b0dc70e5c 100644 --- a/pkg/sentry/socket/netlink/provider.go +++ b/pkg/sentry/socket/netlink/provider.go @@ -42,7 +42,7 @@ type Protocol interface { // If err == nil, any messages added to ms will be sent back to the // other end of the socket. Setting ms.Multi will cause an NLMSG_DONE // message to be sent even if ms contains no messages. - ProcessMessage(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *MessageSet) *syserr.Error + ProcessMessage(ctx context.Context, msg *Message, ms *MessageSet) *syserr.Error } // Provider is a function that creates a new Protocol for a specific netlink diff --git a/pkg/sentry/socket/netlink/route/BUILD b/pkg/sentry/socket/netlink/route/BUILD index 622a1eafc..93127398d 100644 --- a/pkg/sentry/socket/netlink/route/BUILD +++ b/pkg/sentry/socket/netlink/route/BUILD @@ -10,13 +10,11 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/binary", "//pkg/context", "//pkg/sentry/inet", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/socket/netlink", "//pkg/syserr", - "//pkg/usermem", ], ) diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go index 2b3c7f5b3..c84d8bd7c 100644 --- a/pkg/sentry/socket/netlink/route/protocol.go +++ b/pkg/sentry/socket/netlink/route/protocol.go @@ -17,16 +17,15 @@ package route import ( "bytes" + "syscall" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/socket/netlink" "gvisor.dev/gvisor/pkg/syserr" - "gvisor.dev/gvisor/pkg/usermem" ) // commandKind describes the operational class of a message type. @@ -69,13 +68,7 @@ func (p *Protocol) CanSend() bool { } // dumpLinks handles RTM_GETLINK dump requests. -func (p *Protocol) dumpLinks(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error { - // TODO(b/68878065): Only the dump variant of the types below are - // supported. - if hdr.Flags&linux.NLM_F_DUMP != linux.NLM_F_DUMP { - return syserr.ErrNotSupported - } - +func (p *Protocol) dumpLinks(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error { // NLM_F_DUMP + RTM_GETLINK messages are supposed to include an // ifinfomsg. However, Linux <3.9 only checked for rtgenmsg, and some // userspace applications (including glibc) still include rtgenmsg. @@ -99,44 +92,105 @@ func (p *Protocol) dumpLinks(ctx context.Context, hdr linux.NetlinkMessageHeader return nil } - for id, i := range stack.Interfaces() { - m := ms.AddMessage(linux.NetlinkMessageHeader{ - Type: linux.RTM_NEWLINK, - }) + for idx, i := range stack.Interfaces() { + addNewLinkMessage(ms, idx, i) + } - m.Put(linux.InterfaceInfoMessage{ - Family: linux.AF_UNSPEC, - Type: i.DeviceType, - Index: id, - Flags: i.Flags, - }) + return nil +} - m.PutAttrString(linux.IFLA_IFNAME, i.Name) - m.PutAttr(linux.IFLA_MTU, i.MTU) +// getLinks handles RTM_GETLINK requests. +func (p *Protocol) getLink(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error { + stack := inet.StackFromContext(ctx) + if stack == nil { + // No network devices. + return nil + } - mac := make([]byte, 6) - brd := mac - if len(i.Addr) > 0 { - mac = i.Addr - brd = bytes.Repeat([]byte{0xff}, len(i.Addr)) + // Parse message. + var ifi linux.InterfaceInfoMessage + attrs, ok := msg.GetData(&ifi) + if !ok { + return syserr.ErrInvalidArgument + } + + // Parse attributes. + var byName []byte + for !attrs.Empty() { + ahdr, value, rest, ok := attrs.ParseFirst() + if !ok { + return syserr.ErrInvalidArgument } - m.PutAttr(linux.IFLA_ADDRESS, mac) - m.PutAttr(linux.IFLA_BROADCAST, brd) + attrs = rest - // TODO(gvisor.dev/issue/578): There are many more attributes. + switch ahdr.Type { + case linux.IFLA_IFNAME: + if len(value) < 1 { + return syserr.ErrInvalidArgument + } + byName = value[:len(value)-1] + + // TODO(gvisor.dev/issue/578): Support IFLA_EXT_MASK. + } } + found := false + for idx, i := range stack.Interfaces() { + switch { + case ifi.Index > 0: + if idx != ifi.Index { + continue + } + case byName != nil: + if string(byName) != i.Name { + continue + } + default: + // Criteria not specified. + return syserr.ErrInvalidArgument + } + + addNewLinkMessage(ms, idx, i) + found = true + break + } + if !found { + return syserr.ErrNoDevice + } return nil } -// dumpAddrs handles RTM_GETADDR dump requests. -func (p *Protocol) dumpAddrs(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error { - // TODO(b/68878065): Only the dump variant of the types below are - // supported. - if hdr.Flags&linux.NLM_F_DUMP != linux.NLM_F_DUMP { - return syserr.ErrNotSupported +// addNewLinkMessage appends RTM_NEWLINK message for the given interface into +// the message set. +func addNewLinkMessage(ms *netlink.MessageSet, idx int32, i inet.Interface) { + m := ms.AddMessage(linux.NetlinkMessageHeader{ + Type: linux.RTM_NEWLINK, + }) + + m.Put(linux.InterfaceInfoMessage{ + Family: linux.AF_UNSPEC, + Type: i.DeviceType, + Index: idx, + Flags: i.Flags, + }) + + m.PutAttrString(linux.IFLA_IFNAME, i.Name) + m.PutAttr(linux.IFLA_MTU, i.MTU) + + mac := make([]byte, 6) + brd := mac + if len(i.Addr) > 0 { + mac = i.Addr + brd = bytes.Repeat([]byte{0xff}, len(i.Addr)) } + m.PutAttr(linux.IFLA_ADDRESS, mac) + m.PutAttr(linux.IFLA_BROADCAST, brd) + + // TODO(gvisor.dev/issue/578): There are many more attributes. +} +// dumpAddrs handles RTM_GETADDR dump requests. +func (p *Protocol) dumpAddrs(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error { // RTM_GETADDR dump requests need not contain anything more than the // netlink header and 1 byte protocol family common to all // NETLINK_ROUTE requests. @@ -168,6 +222,7 @@ func (p *Protocol) dumpAddrs(ctx context.Context, hdr linux.NetlinkMessageHeader Index: uint32(id), }) + m.PutAttr(linux.IFA_LOCAL, []byte(a.Addr)) m.PutAttr(linux.IFA_ADDRESS, []byte(a.Addr)) // TODO(gvisor.dev/issue/578): There are many more attributes. @@ -252,12 +307,12 @@ func fillRoute(routes []inet.Route, addr []byte) (inet.Route, *syserr.Error) { } // parseForDestination parses a message as format of RouteMessage-RtAttr-dst. -func parseForDestination(data []byte) ([]byte, *syserr.Error) { +func parseForDestination(msg *netlink.Message) ([]byte, *syserr.Error) { var rtMsg linux.RouteMessage - if len(data) < linux.SizeOfRouteMessage { + attrs, ok := msg.GetData(&rtMsg) + if !ok { return nil, syserr.ErrInvalidArgument } - binary.Unmarshal(data[:linux.SizeOfRouteMessage], usermem.ByteOrder, &rtMsg) // iproute2 added the RTM_F_LOOKUP_TABLE flag in version v4.4.0. See // commit bc234301af12. Note we don't check this flag for backward // compatibility. @@ -265,26 +320,15 @@ func parseForDestination(data []byte) ([]byte, *syserr.Error) { return nil, syserr.ErrNotSupported } - data = data[linux.SizeOfRouteMessage:] - - // TODO(gvisor.dev/issue/1611): Add generic attribute parsing. - var rtAttr linux.RtAttr - if len(data) < linux.SizeOfRtAttr { - return nil, syserr.ErrInvalidArgument + // Expect first attribute is RTA_DST. + if hdr, value, _, ok := attrs.ParseFirst(); ok && hdr.Type == linux.RTA_DST { + return value, nil } - binary.Unmarshal(data[:linux.SizeOfRtAttr], usermem.ByteOrder, &rtAttr) - if rtAttr.Type != linux.RTA_DST { - return nil, syserr.ErrInvalidArgument - } - - if len(data) < int(rtAttr.Len) { - return nil, syserr.ErrInvalidArgument - } - return data[linux.SizeOfRtAttr:rtAttr.Len], nil + return nil, syserr.ErrInvalidArgument } // dumpRoutes handles RTM_GETROUTE requests. -func (p *Protocol) dumpRoutes(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error { +func (p *Protocol) dumpRoutes(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error { // RTM_GETROUTE dump requests need not contain anything more than the // netlink header and 1 byte protocol family common to all // NETLINK_ROUTE requests. @@ -295,10 +339,11 @@ func (p *Protocol) dumpRoutes(ctx context.Context, hdr linux.NetlinkMessageHeade return nil } + hdr := msg.Header() routeTables := stack.RouteTable() if hdr.Flags == linux.NLM_F_REQUEST { - dst, err := parseForDestination(data) + dst, err := parseForDestination(msg) if err != nil { return err } @@ -357,10 +402,55 @@ func (p *Protocol) dumpRoutes(ctx context.Context, hdr linux.NetlinkMessageHeade return nil } +// newAddr handles RTM_NEWADDR requests. +func (p *Protocol) newAddr(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error { + stack := inet.StackFromContext(ctx) + if stack == nil { + // No network stack. + return syserr.ErrProtocolNotSupported + } + + var ifa linux.InterfaceAddrMessage + attrs, ok := msg.GetData(&ifa) + if !ok { + return syserr.ErrInvalidArgument + } + + for !attrs.Empty() { + ahdr, value, rest, ok := attrs.ParseFirst() + if !ok { + return syserr.ErrInvalidArgument + } + attrs = rest + + switch ahdr.Type { + case linux.IFA_LOCAL: + err := stack.AddInterfaceAddr(int32(ifa.Index), inet.InterfaceAddr{ + Family: ifa.Family, + PrefixLen: ifa.PrefixLen, + Flags: ifa.Flags, + Addr: value, + }) + if err == syscall.EEXIST { + flags := msg.Header().Flags + if flags&linux.NLM_F_EXCL != 0 { + return syserr.ErrExists + } + } else if err != nil { + return syserr.ErrInvalidArgument + } + } + } + return nil +} + // ProcessMessage implements netlink.Protocol.ProcessMessage. -func (p *Protocol) ProcessMessage(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error { +func (p *Protocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error { + hdr := msg.Header() + // All messages start with a 1 byte protocol family. - if len(data) < 1 { + var family uint8 + if _, ok := msg.GetData(&family); !ok { // Linux ignores messages missing the protocol family. See // net/core/rtnetlink.c:rtnetlink_rcv_msg. return nil @@ -374,16 +464,32 @@ func (p *Protocol) ProcessMessage(ctx context.Context, hdr linux.NetlinkMessageH } } - switch hdr.Type { - case linux.RTM_GETLINK: - return p.dumpLinks(ctx, hdr, data, ms) - case linux.RTM_GETADDR: - return p.dumpAddrs(ctx, hdr, data, ms) - case linux.RTM_GETROUTE: - return p.dumpRoutes(ctx, hdr, data, ms) - default: - return syserr.ErrNotSupported + if hdr.Flags&linux.NLM_F_DUMP == linux.NLM_F_DUMP { + // TODO(b/68878065): Only the dump variant of the types below are + // supported. + switch hdr.Type { + case linux.RTM_GETLINK: + return p.dumpLinks(ctx, msg, ms) + case linux.RTM_GETADDR: + return p.dumpAddrs(ctx, msg, ms) + case linux.RTM_GETROUTE: + return p.dumpRoutes(ctx, msg, ms) + default: + return syserr.ErrNotSupported + } + } else if hdr.Flags&linux.NLM_F_REQUEST == linux.NLM_F_REQUEST { + switch hdr.Type { + case linux.RTM_GETLINK: + return p.getLink(ctx, msg, ms) + case linux.RTM_GETROUTE: + return p.dumpRoutes(ctx, msg, ms) + case linux.RTM_NEWADDR: + return p.newAddr(ctx, msg, ms) + default: + return syserr.ErrNotSupported + } } + return syserr.ErrNotSupported } // init registers the NETLINK_ROUTE provider. diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index c4b95debb..2ca02567d 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -644,47 +644,38 @@ func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error return nil } -func (s *Socket) dumpErrorMesage(ctx context.Context, hdr linux.NetlinkMessageHeader, ms *MessageSet, err *syserr.Error) *syserr.Error { +func dumpErrorMesage(hdr linux.NetlinkMessageHeader, ms *MessageSet, err *syserr.Error) { m := ms.AddMessage(linux.NetlinkMessageHeader{ Type: linux.NLMSG_ERROR, }) - m.Put(linux.NetlinkErrorMessage{ Error: int32(-err.ToLinux().Number()), Header: hdr, }) - return nil +} +func dumpAckMesage(hdr linux.NetlinkMessageHeader, ms *MessageSet) { + m := ms.AddMessage(linux.NetlinkMessageHeader{ + Type: linux.NLMSG_ERROR, + }) + m.Put(linux.NetlinkErrorMessage{ + Error: 0, + Header: hdr, + }) } // processMessages handles each message in buf, passing it to the protocol // handler for final handling. func (s *Socket) processMessages(ctx context.Context, buf []byte) *syserr.Error { for len(buf) > 0 { - if len(buf) < linux.NetlinkMessageHeaderSize { + msg, rest, ok := ParseMessage(buf) + if !ok { // Linux ignores messages that are too short. See // net/netlink/af_netlink.c:netlink_rcv_skb. break } - - var hdr linux.NetlinkMessageHeader - binary.Unmarshal(buf[:linux.NetlinkMessageHeaderSize], usermem.ByteOrder, &hdr) - - if hdr.Length < linux.NetlinkMessageHeaderSize || uint64(hdr.Length) > uint64(len(buf)) { - // Linux ignores malformed messages. See - // net/netlink/af_netlink.c:netlink_rcv_skb. - break - } - - // Data from this message. - data := buf[linux.NetlinkMessageHeaderSize:hdr.Length] - - // Advance to the next message. - next := alignUp(int(hdr.Length), linux.NLMSG_ALIGNTO) - if next >= len(buf)-1 { - next = len(buf) - 1 - } - buf = buf[next:] + buf = rest + hdr := msg.Header() // Ignore control messages. if hdr.Type < linux.NLMSG_MIN_TYPE { @@ -692,19 +683,10 @@ func (s *Socket) processMessages(ctx context.Context, buf []byte) *syserr.Error } ms := NewMessageSet(s.portID, hdr.Seq) - var err *syserr.Error - // TODO(b/68877377): ACKs not supported yet. - if hdr.Flags&linux.NLM_F_ACK == linux.NLM_F_ACK { - err = syserr.ErrNotSupported - } else { - - err = s.protocol.ProcessMessage(ctx, hdr, data, ms) - } - if err != nil { - ms = NewMessageSet(s.portID, hdr.Seq) - if err := s.dumpErrorMesage(ctx, hdr, ms, err); err != nil { - return err - } + if err := s.protocol.ProcessMessage(ctx, msg, ms); err != nil { + dumpErrorMesage(hdr, ms, err) + } else if hdr.Flags&linux.NLM_F_ACK == linux.NLM_F_ACK { + dumpAckMesage(hdr, ms) } if err := s.sendResponse(ctx, ms); err != nil { diff --git a/pkg/sentry/socket/netlink/uevent/protocol.go b/pkg/sentry/socket/netlink/uevent/protocol.go index 1ee4296bc..029ba21b5 100644 --- a/pkg/sentry/socket/netlink/uevent/protocol.go +++ b/pkg/sentry/socket/netlink/uevent/protocol.go @@ -49,7 +49,7 @@ func (p *Protocol) CanSend() bool { } // ProcessMessage implements netlink.Protocol.ProcessMessage. -func (p *Protocol) ProcessMessage(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error { +func (p *Protocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error { // Silently ignore all messages. return nil } diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index 31ea66eca..0692482e9 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -20,6 +20,8 @@ import ( "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/socket/netfilter" "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" @@ -88,6 +90,59 @@ func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr { return nicAddrs } +// AddInterfaceAddr implements inet.Stack.AddInterfaceAddr. +func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { + var ( + protocol tcpip.NetworkProtocolNumber + address tcpip.Address + ) + switch addr.Family { + case linux.AF_INET: + if len(addr.Addr) < header.IPv4AddressSize { + return syserror.EINVAL + } + if addr.PrefixLen > header.IPv4AddressSize*8 { + return syserror.EINVAL + } + protocol = ipv4.ProtocolNumber + address = tcpip.Address(addr.Addr[:header.IPv4AddressSize]) + + case linux.AF_INET6: + if len(addr.Addr) < header.IPv6AddressSize { + return syserror.EINVAL + } + if addr.PrefixLen > header.IPv6AddressSize*8 { + return syserror.EINVAL + } + protocol = ipv6.ProtocolNumber + address = tcpip.Address(addr.Addr[:header.IPv6AddressSize]) + + default: + return syserror.ENOTSUP + } + + protocolAddress := tcpip.ProtocolAddress{ + Protocol: protocol, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: address, + PrefixLen: int(addr.PrefixLen), + }, + } + + // Attach address to interface. + if err := s.Stack.AddProtocolAddressWithOptions(tcpip.NICID(idx), protocolAddress, stack.CanBePrimaryEndpoint); err != nil { + return syserr.TranslateNetstackError(err).ToError() + } + + // Add route for local network. + s.Stack.AddRoute(tcpip.Route{ + Destination: protocolAddress.AddressWithPrefix.Subnet(), + Gateway: "", // No gateway for local network. + NIC: tcpip.NICID(idx), + }) + return nil +} + // TCPReceiveBufferSize implements inet.Stack.TCPReceiveBufferSize. func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) { var rs tcp.ReceiveBufferSizeOption diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 7057b110e..b793f1d74 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -795,6 +795,8 @@ func (s *Stack) Forwarding() bool { // SetRouteTable assigns the route table to be used by this stack. It // specifies which NIC to use for given destination address ranges. +// +// This method takes ownership of the table. func (s *Stack) SetRouteTable(table []tcpip.Route) { s.mu.Lock() defer s.mu.Unlock() @@ -809,6 +811,13 @@ func (s *Stack) GetRouteTable() []tcpip.Route { return append([]tcpip.Route(nil), s.routeTable...) } +// AddRoute appends a route to the route table. +func (s *Stack) AddRoute(route tcpip.Route) { + s.mu.Lock() + defer s.mu.Unlock() + s.routeTable = append(s.routeTable, route) +} + // NewEndpoint creates a new transport layer endpoint of the given protocol. func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { t, ok := s.transportProtocols[transport] diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 273b014d6..f2e3c7072 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -2769,9 +2769,11 @@ cc_binary( deps = [ ":socket_netlink_util", ":socket_test_util", + "//test/util:capability_util", "//test/util:cleanup", "//test/util:file_descriptor", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:optional", gtest, "//test/util:test_main", "//test/util:test_util", diff --git a/test/syscalls/linux/socket_netlink_route.cc b/test/syscalls/linux/socket_netlink_route.cc index 1e28e658d..e5aed1eec 100644 --- a/test/syscalls/linux/socket_netlink_route.cc +++ b/test/syscalls/linux/socket_netlink_route.cc @@ -14,6 +14,7 @@ #include #include +#include #include #include #include @@ -25,8 +26,10 @@ #include "gtest/gtest.h" #include "absl/strings/str_format.h" +#include "absl/types/optional.h" #include "test/syscalls/linux/socket_netlink_util.h" #include "test/syscalls/linux/socket_test_util.h" +#include "test/util/capability_util.h" #include "test/util/cleanup.h" #include "test/util/file_descriptor.h" #include "test/util/test_util.h" @@ -38,6 +41,8 @@ namespace testing { namespace { +constexpr uint32_t kSeq = 12345; + using ::testing::AnyOf; using ::testing::Eq; @@ -113,58 +118,224 @@ void CheckGetLinkResponse(const struct nlmsghdr* hdr, int seq, int port) { // TODO(mpratt): Check ifinfomsg contents and following attrs. } +PosixError DumpLinks( + const FileDescriptor& fd, uint32_t seq, + const std::function& fn) { + struct request { + struct nlmsghdr hdr; + struct ifinfomsg ifm; + }; + + struct request req = {}; + req.hdr.nlmsg_len = sizeof(req); + req.hdr.nlmsg_type = RTM_GETLINK; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; + req.hdr.nlmsg_seq = seq; + req.ifm.ifi_family = AF_UNSPEC; + + return NetlinkRequestResponse(fd, &req, sizeof(req), fn, false); +} + TEST(NetlinkRouteTest, GetLinkDump) { FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); uint32_t port = ASSERT_NO_ERRNO_AND_VALUE(NetlinkPortID(fd.get())); + // Loopback is common among all tests, check that it's found. + bool loopbackFound = false; + ASSERT_NO_ERRNO(DumpLinks(fd, kSeq, [&](const struct nlmsghdr* hdr) { + CheckGetLinkResponse(hdr, kSeq, port); + if (hdr->nlmsg_type != RTM_NEWLINK) { + return; + } + ASSERT_GE(hdr->nlmsg_len, NLMSG_SPACE(sizeof(struct ifinfomsg))); + const struct ifinfomsg* msg = + reinterpret_cast(NLMSG_DATA(hdr)); + std::cout << "Found interface idx=" << msg->ifi_index + << ", type=" << std::hex << msg->ifi_type; + if (msg->ifi_type == ARPHRD_LOOPBACK) { + loopbackFound = true; + EXPECT_NE(msg->ifi_flags & IFF_LOOPBACK, 0); + } + })); + EXPECT_TRUE(loopbackFound); +} + +struct Link { + int index; + std::string name; +}; + +PosixErrorOr> FindLoopbackLink() { + ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE)); + + absl::optional link; + RETURN_IF_ERRNO(DumpLinks(fd, kSeq, [&](const struct nlmsghdr* hdr) { + if (hdr->nlmsg_type != RTM_NEWLINK || + hdr->nlmsg_len < NLMSG_SPACE(sizeof(struct ifinfomsg))) { + return; + } + const struct ifinfomsg* msg = + reinterpret_cast(NLMSG_DATA(hdr)); + if (msg->ifi_type == ARPHRD_LOOPBACK) { + const auto* rta = FindRtAttr(hdr, msg, IFLA_IFNAME); + if (rta == nullptr) { + // Ignore links that do not have a name. + return; + } + + link = Link(); + link->index = msg->ifi_index; + link->name = std::string(reinterpret_cast(RTA_DATA(rta))); + } + })); + return link; +} + +// CheckLinkMsg checks a netlink message against an expected link. +void CheckLinkMsg(const struct nlmsghdr* hdr, const Link& link) { + ASSERT_THAT(hdr->nlmsg_type, Eq(RTM_NEWLINK)); + ASSERT_GE(hdr->nlmsg_len, NLMSG_SPACE(sizeof(struct ifinfomsg))); + const struct ifinfomsg* msg = + reinterpret_cast(NLMSG_DATA(hdr)); + EXPECT_EQ(msg->ifi_index, link.index); + + const struct rtattr* rta = FindRtAttr(hdr, msg, IFLA_IFNAME); + EXPECT_NE(nullptr, rta) << "IFLA_IFNAME not found in message."; + if (rta != nullptr) { + std::string name(reinterpret_cast(RTA_DATA(rta))); + EXPECT_EQ(name, link.name); + } +} + +TEST(NetlinkRouteTest, GetLinkByIndex) { + absl::optional loopback_link = + ASSERT_NO_ERRNO_AND_VALUE(FindLoopbackLink()); + ASSERT_TRUE(loopback_link.has_value()); + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + struct request { struct nlmsghdr hdr; struct ifinfomsg ifm; }; - constexpr uint32_t kSeq = 12345; - struct request req = {}; req.hdr.nlmsg_len = sizeof(req); req.hdr.nlmsg_type = RTM_GETLINK; - req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; + req.hdr.nlmsg_flags = NLM_F_REQUEST; req.hdr.nlmsg_seq = kSeq; req.ifm.ifi_family = AF_UNSPEC; + req.ifm.ifi_index = loopback_link->index; - // Loopback is common among all tests, check that it's found. - bool loopbackFound = false; + bool found = false; ASSERT_NO_ERRNO(NetlinkRequestResponse( fd, &req, sizeof(req), [&](const struct nlmsghdr* hdr) { - CheckGetLinkResponse(hdr, kSeq, port); - if (hdr->nlmsg_type != RTM_NEWLINK) { - return; - } - ASSERT_GE(hdr->nlmsg_len, NLMSG_SPACE(sizeof(struct ifinfomsg))); - const struct ifinfomsg* msg = - reinterpret_cast(NLMSG_DATA(hdr)); - std::cout << "Found interface idx=" << msg->ifi_index - << ", type=" << std::hex << msg->ifi_type; - if (msg->ifi_type == ARPHRD_LOOPBACK) { - loopbackFound = true; - EXPECT_NE(msg->ifi_flags & IFF_LOOPBACK, 0); - } + CheckLinkMsg(hdr, *loopback_link); + found = true; }, false)); - EXPECT_TRUE(loopbackFound); + EXPECT_TRUE(found) << "Netlink response does not contain any links."; } -TEST(NetlinkRouteTest, MsgHdrMsgUnsuppType) { +TEST(NetlinkRouteTest, GetLinkByName) { + absl::optional loopback_link = + ASSERT_NO_ERRNO_AND_VALUE(FindLoopbackLink()); + ASSERT_TRUE(loopback_link.has_value()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); struct request { struct nlmsghdr hdr; struct ifinfomsg ifm; + struct rtattr rtattr; + char ifname[IFNAMSIZ]; + char pad[NLMSG_ALIGNTO + RTA_ALIGNTO]; }; - constexpr uint32_t kSeq = 12345; + struct request req = {}; + req.hdr.nlmsg_type = RTM_GETLINK; + req.hdr.nlmsg_flags = NLM_F_REQUEST; + req.hdr.nlmsg_seq = kSeq; + req.ifm.ifi_family = AF_UNSPEC; + req.rtattr.rta_type = IFLA_IFNAME; + req.rtattr.rta_len = RTA_LENGTH(loopback_link->name.size() + 1); + strncpy(req.ifname, loopback_link->name.c_str(), sizeof(req.ifname)); + req.hdr.nlmsg_len = + NLMSG_LENGTH(sizeof(req.ifm)) + NLMSG_ALIGN(req.rtattr.rta_len); + + bool found = false; + ASSERT_NO_ERRNO(NetlinkRequestResponse( + fd, &req, sizeof(req), + [&](const struct nlmsghdr* hdr) { + CheckLinkMsg(hdr, *loopback_link); + found = true; + }, + false)); + EXPECT_TRUE(found) << "Netlink response does not contain any links."; +} + +TEST(NetlinkRouteTest, GetLinkByIndexNotFound) { + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + + struct request { + struct nlmsghdr hdr; + struct ifinfomsg ifm; + }; + + struct request req = {}; + req.hdr.nlmsg_len = sizeof(req); + req.hdr.nlmsg_type = RTM_GETLINK; + req.hdr.nlmsg_flags = NLM_F_REQUEST; + req.hdr.nlmsg_seq = kSeq; + req.ifm.ifi_family = AF_UNSPEC; + req.ifm.ifi_index = 1234590; + + EXPECT_THAT(NetlinkRequestAckOrError(fd, kSeq, &req, sizeof(req)), + PosixErrorIs(ENODEV, ::testing::_)); +} + +TEST(NetlinkRouteTest, GetLinkByNameNotFound) { + const std::string name = "nodevice?!"; + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + + struct request { + struct nlmsghdr hdr; + struct ifinfomsg ifm; + struct rtattr rtattr; + char ifname[IFNAMSIZ]; + char pad[NLMSG_ALIGNTO + RTA_ALIGNTO]; + }; + + struct request req = {}; + req.hdr.nlmsg_type = RTM_GETLINK; + req.hdr.nlmsg_flags = NLM_F_REQUEST; + req.hdr.nlmsg_seq = kSeq; + req.ifm.ifi_family = AF_UNSPEC; + req.rtattr.rta_type = IFLA_IFNAME; + req.rtattr.rta_len = RTA_LENGTH(name.size() + 1); + strncpy(req.ifname, name.c_str(), sizeof(req.ifname)); + req.hdr.nlmsg_len = + NLMSG_LENGTH(sizeof(req.ifm)) + NLMSG_ALIGN(req.rtattr.rta_len); + + EXPECT_THAT(NetlinkRequestAckOrError(fd, kSeq, &req, sizeof(req)), + PosixErrorIs(ENODEV, ::testing::_)); +} + +TEST(NetlinkRouteTest, MsgHdrMsgUnsuppType) { + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + + struct request { + struct nlmsghdr hdr; + struct ifinfomsg ifm; + }; struct request req = {}; req.hdr.nlmsg_len = sizeof(req); @@ -175,18 +346,8 @@ TEST(NetlinkRouteTest, MsgHdrMsgUnsuppType) { req.hdr.nlmsg_seq = kSeq; req.ifm.ifi_family = AF_UNSPEC; - ASSERT_NO_ERRNO(NetlinkRequestResponse( - fd, &req, sizeof(req), - [&](const struct nlmsghdr* hdr) { - EXPECT_THAT(hdr->nlmsg_type, Eq(NLMSG_ERROR)); - EXPECT_EQ(hdr->nlmsg_seq, kSeq); - EXPECT_GE(hdr->nlmsg_len, sizeof(*hdr) + sizeof(struct nlmsgerr)); - - const struct nlmsgerr* msg = - reinterpret_cast(NLMSG_DATA(hdr)); - EXPECT_EQ(msg->error, -EOPNOTSUPP); - }, - true)); + EXPECT_THAT(NetlinkRequestAckOrError(fd, kSeq, &req, sizeof(req)), + PosixErrorIs(EOPNOTSUPP, ::testing::_)); } TEST(NetlinkRouteTest, MsgHdrMsgTrunc) { @@ -198,8 +359,6 @@ TEST(NetlinkRouteTest, MsgHdrMsgTrunc) { struct ifinfomsg ifm; }; - constexpr uint32_t kSeq = 12345; - struct request req = {}; req.hdr.nlmsg_len = sizeof(req); req.hdr.nlmsg_type = RTM_GETLINK; @@ -238,8 +397,6 @@ TEST(NetlinkRouteTest, MsgTruncMsgHdrMsgTrunc) { struct ifinfomsg ifm; }; - constexpr uint32_t kSeq = 12345; - struct request req = {}; req.hdr.nlmsg_len = sizeof(req); req.hdr.nlmsg_type = RTM_GETLINK; @@ -282,8 +439,6 @@ TEST(NetlinkRouteTest, ControlMessageIgnored) { struct ifinfomsg ifm; }; - constexpr uint32_t kSeq = 12345; - struct request req = {}; // This control message is ignored. We still receive a response for the @@ -317,8 +472,6 @@ TEST(NetlinkRouteTest, GetAddrDump) { struct rtgenmsg rgm; }; - constexpr uint32_t kSeq = 12345; - struct request req; req.hdr.nlmsg_len = sizeof(req); req.hdr.nlmsg_type = RTM_GETADDR; @@ -367,6 +520,57 @@ TEST(NetlinkRouteTest, LookupAll) { ASSERT_GT(count, 0); } +TEST(NetlinkRouteTest, AddAddr) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + absl::optional loopback_link = + ASSERT_NO_ERRNO_AND_VALUE(FindLoopbackLink()); + ASSERT_TRUE(loopback_link.has_value()); + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + + struct request { + struct nlmsghdr hdr; + struct ifaddrmsg ifa; + struct rtattr rtattr; + struct in_addr addr; + char pad[NLMSG_ALIGNTO + RTA_ALIGNTO]; + }; + + struct request req = {}; + req.hdr.nlmsg_type = RTM_NEWADDR; + req.hdr.nlmsg_seq = kSeq; + req.ifa.ifa_family = AF_INET; + req.ifa.ifa_prefixlen = 24; + req.ifa.ifa_flags = 0; + req.ifa.ifa_scope = 0; + req.ifa.ifa_index = loopback_link->index; + req.rtattr.rta_type = IFA_LOCAL; + req.rtattr.rta_len = RTA_LENGTH(sizeof(req.addr)); + inet_pton(AF_INET, "10.0.0.1", &req.addr); + req.hdr.nlmsg_len = + NLMSG_LENGTH(sizeof(req.ifa)) + NLMSG_ALIGN(req.rtattr.rta_len); + + // Create should succeed, as no such address in kernel. + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_CREATE | NLM_F_ACK; + EXPECT_NO_ERRNO( + NetlinkRequestAckOrError(fd, req.hdr.nlmsg_seq, &req, req.hdr.nlmsg_len)); + + // Replace an existing address should succeed. + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_REPLACE | NLM_F_ACK; + req.hdr.nlmsg_seq++; + EXPECT_NO_ERRNO( + NetlinkRequestAckOrError(fd, req.hdr.nlmsg_seq, &req, req.hdr.nlmsg_len)); + + // Create exclusive should fail, as we created the address above. + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_CREATE | NLM_F_EXCL | NLM_F_ACK; + req.hdr.nlmsg_seq++; + EXPECT_THAT( + NetlinkRequestAckOrError(fd, req.hdr.nlmsg_seq, &req, req.hdr.nlmsg_len), + PosixErrorIs(EEXIST, ::testing::_)); +} + // GetRouteDump tests a RTM_GETROUTE + NLM_F_DUMP request. TEST(NetlinkRouteTest, GetRouteDump) { FileDescriptor fd = @@ -378,8 +582,6 @@ TEST(NetlinkRouteTest, GetRouteDump) { struct rtmsg rtm; }; - constexpr uint32_t kSeq = 12345; - struct request req = {}; req.hdr.nlmsg_len = sizeof(req); req.hdr.nlmsg_type = RTM_GETROUTE; @@ -538,8 +740,6 @@ TEST(NetlinkRouteTest, RecvmsgTrunc) { struct rtgenmsg rgm; }; - constexpr uint32_t kSeq = 12345; - struct request req; req.hdr.nlmsg_len = sizeof(req); req.hdr.nlmsg_type = RTM_GETADDR; @@ -615,8 +815,6 @@ TEST(NetlinkRouteTest, RecvmsgTruncPeek) { struct rtgenmsg rgm; }; - constexpr uint32_t kSeq = 12345; - struct request req; req.hdr.nlmsg_len = sizeof(req); req.hdr.nlmsg_type = RTM_GETADDR; @@ -695,8 +893,6 @@ TEST(NetlinkRouteTest, NoPasscredNoCreds) { struct rtgenmsg rgm; }; - constexpr uint32_t kSeq = 12345; - struct request req; req.hdr.nlmsg_len = sizeof(req); req.hdr.nlmsg_type = RTM_GETADDR; @@ -743,8 +939,6 @@ TEST(NetlinkRouteTest, PasscredCreds) { struct rtgenmsg rgm; }; - constexpr uint32_t kSeq = 12345; - struct request req; req.hdr.nlmsg_len = sizeof(req); req.hdr.nlmsg_type = RTM_GETADDR; diff --git a/test/syscalls/linux/socket_netlink_util.cc b/test/syscalls/linux/socket_netlink_util.cc index cd2212a1a..952eecfe8 100644 --- a/test/syscalls/linux/socket_netlink_util.cc +++ b/test/syscalls/linux/socket_netlink_util.cc @@ -16,6 +16,7 @@ #include #include +#include #include #include @@ -71,9 +72,10 @@ PosixError NetlinkRequestResponse( iov.iov_base = buf.data(); iov.iov_len = buf.size(); - // Response is a series of NLM_F_MULTI messages, ending with a NLMSG_DONE - // message. + // If NLM_F_MULTI is set, response is a series of messages that ends with a + // NLMSG_DONE message. int type = -1; + int flags = 0; do { int len; RETURN_ERROR_IF_SYSCALL_FAIL(len = RetryEINTR(recvmsg)(fd.get(), &msg, 0)); @@ -89,6 +91,7 @@ PosixError NetlinkRequestResponse( for (struct nlmsghdr* hdr = reinterpret_cast(buf.data()); NLMSG_OK(hdr, len); hdr = NLMSG_NEXT(hdr, len)) { fn(hdr); + flags = hdr->nlmsg_flags; type = hdr->nlmsg_type; // Done should include an integer payload for dump_done_errno. // See net/netlink/af_netlink.c:netlink_dump @@ -98,11 +101,11 @@ PosixError NetlinkRequestResponse( EXPECT_GE(hdr->nlmsg_len, NLMSG_LENGTH(sizeof(int))); } } - } while (type != NLMSG_DONE && type != NLMSG_ERROR); + } while ((flags & NLM_F_MULTI) && type != NLMSG_DONE && type != NLMSG_ERROR); if (expect_nlmsgerr) { EXPECT_EQ(type, NLMSG_ERROR); - } else { + } else if (flags & NLM_F_MULTI) { EXPECT_EQ(type, NLMSG_DONE); } return NoError(); @@ -146,5 +149,39 @@ PosixError NetlinkRequestResponseSingle( return NoError(); } +PosixError NetlinkRequestAckOrError(const FileDescriptor& fd, uint32_t seq, + void* request, size_t len) { + // Dummy negative number for no error message received. + // We won't get a negative error number so there will be no confusion. + int err = -42; + RETURN_IF_ERRNO(NetlinkRequestResponse( + fd, request, len, + [&](const struct nlmsghdr* hdr) { + EXPECT_EQ(NLMSG_ERROR, hdr->nlmsg_type); + EXPECT_EQ(hdr->nlmsg_seq, seq); + EXPECT_GE(hdr->nlmsg_len, sizeof(*hdr) + sizeof(struct nlmsgerr)); + + const struct nlmsgerr* msg = + reinterpret_cast(NLMSG_DATA(hdr)); + err = -msg->error; + }, + true)); + return PosixError(err); +} + +const struct rtattr* FindRtAttr(const struct nlmsghdr* hdr, + const struct ifinfomsg* msg, int16_t attr) { + const int ifi_space = NLMSG_SPACE(sizeof(*msg)); + int attrlen = hdr->nlmsg_len - ifi_space; + const struct rtattr* rta = reinterpret_cast( + reinterpret_cast(hdr) + NLMSG_ALIGN(ifi_space)); + for (; RTA_OK(rta, attrlen); rta = RTA_NEXT(rta, attrlen)) { + if (rta->rta_type == attr) { + return rta; + } + } + return nullptr; +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_netlink_util.h b/test/syscalls/linux/socket_netlink_util.h index 3678c0599..e13ead406 100644 --- a/test/syscalls/linux/socket_netlink_util.h +++ b/test/syscalls/linux/socket_netlink_util.h @@ -19,6 +19,7 @@ // socket.h has to be included before if_arp.h. #include #include +#include #include "test/util/file_descriptor.h" #include "test/util/posix_error.h" @@ -47,6 +48,14 @@ PosixError NetlinkRequestResponseSingle( const FileDescriptor& fd, void* request, size_t len, const std::function& fn); +// Send the passed request then expect and return an ack or error. +PosixError NetlinkRequestAckOrError(const FileDescriptor& fd, uint32_t seq, + void* request, size_t len); + +// Find rtnetlink attribute in message. +const struct rtattr* FindRtAttr(const struct nlmsghdr* hdr, + const struct ifinfomsg* msg, int16_t attr); + } // namespace testing } // namespace gvisor -- cgit v1.2.3 From f3d95607036b8a502c65aa7b3e8145227274dbbc Mon Sep 17 00:00:00 2001 From: Eyal Soha Date: Wed, 5 Feb 2020 17:56:00 -0800 Subject: recv() on a closed TCP socket returns ENOTCONN From RFC 793 s3.9 p58 Event Processing: If RECEIVE Call arrives in CLOSED state and the user has access to such a connection, the return should be "error: connection does not exist" Fixes #1598 PiperOrigin-RevId: 293494287 --- pkg/sentry/socket/netstack/netstack.go | 7 ++++++- pkg/tcpip/tcpip.go | 4 ++++ pkg/tcpip/transport/tcp/endpoint.go | 4 ++-- pkg/tcpip/transport/tcp/tcp_test.go | 9 ++++----- test/syscalls/linux/tcp_socket.cc | 9 +++++++++ 5 files changed, 25 insertions(+), 8 deletions(-) (limited to 'pkg/tcpip') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 049d04bf2..ed2fbcceb 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -2229,11 +2229,16 @@ func (s *SocketOperations) coalescingRead(ctx context.Context, dst usermem.IOSeq var copied int // Copy as many views as possible into the user-provided buffer. - for dst.NumBytes() != 0 { + for { + // Always do at least one fetchReadView, even if the number of bytes to + // read is 0. err = s.fetchReadView() if err != nil { break } + if dst.NumBytes() == 0 { + break + } var n int var e error diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 0fa141d58..d29d9a704 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -1124,6 +1124,10 @@ type ReadErrors struct { // InvalidEndpointState is the number of times we found the endpoint state // to be unexpected. InvalidEndpointState StatCounter + + // NotConnected is the number of times we tried to read but found that the + // endpoint was not connected. + NotConnected StatCounter } // WriteErrors collects packet write errors from an endpoint write call. diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index b5a8e15ee..e4a6b1b8b 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1003,8 +1003,8 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, if s == StateError { return buffer.View{}, tcpip.ControlMessages{}, he } - e.stats.ReadErrors.InvalidEndpointState.Increment() - return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState + e.stats.ReadErrors.NotConnected.Increment() + return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrNotConnected } v, err := e.readLocked() diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 2c1505067..cc118c993 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -5405,12 +5405,11 @@ func TestEndpointBindListenAcceptState(t *testing.T) { t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) } - // Expect InvalidEndpointState errors on a read at this point. - if _, _, err := ep.Read(nil); err != tcpip.ErrInvalidEndpointState { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrInvalidEndpointState) + if _, _, err := ep.Read(nil); err != tcpip.ErrNotConnected { + t.Errorf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrNotConnected) } - if got := ep.Stats().(*tcp.Stats).ReadErrors.InvalidEndpointState.Value(); got != 1 { - t.Fatalf("got EP stats Stats.ReadErrors.InvalidEndpointState got %v want %v", got, 1) + if got := ep.Stats().(*tcp.Stats).ReadErrors.NotConnected.Value(); got != 1 { + t.Errorf("got EP stats Stats.ReadErrors.NotConnected got %v want %v", got, 1) } if err := ep.Listen(10); err != nil { diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index 525ccbd88..8a8b68e75 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -1339,6 +1339,15 @@ TEST_P(SimpleTcpSocketTest, SetTCPDeferAcceptGreaterThanZero) { EXPECT_EQ(get, kTCPDeferAccept); } +TEST_P(SimpleTcpSocketTest, RecvOnClosedSocket) { + auto s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + char buf[1]; + EXPECT_THAT(recv(s.get(), buf, 0, 0), SyscallFailsWithErrno(ENOTCONN)); + EXPECT_THAT(recv(s.get(), buf, sizeof(buf), 0), + SyscallFailsWithErrno(ENOTCONN)); +} + INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest, ::testing::Values(AF_INET, AF_INET6)); -- cgit v1.2.3 From 1b6a12a768216a99a5e0428c42ea4faf79cf3b50 Mon Sep 17 00:00:00 2001 From: Adin Scannell Date: Wed, 5 Feb 2020 22:45:44 -0800 Subject: Add notes to relevant tests. These were out-of-band notes that can help provide additional context and simplify automated imports. PiperOrigin-RevId: 293525915 --- pkg/metric/metric.go | 1 - pkg/sentry/arch/arch_x86.go | 4 ++ pkg/sentry/arch/signal_amd64.go | 2 +- pkg/sentry/fs/file_overlay_test.go | 1 + pkg/sentry/fs/proc/README.md | 4 ++ pkg/sentry/kernel/BUILD | 1 + pkg/sentry/kernel/kernel.go | 3 ++ pkg/sentry/kernel/kernel_opts.go | 20 +++++++ pkg/sentry/socket/hostinet/BUILD | 1 + pkg/sentry/socket/hostinet/socket.go | 5 +- pkg/sentry/socket/hostinet/sockopt_impl.go | 27 ++++++++++ pkg/tcpip/transport/tcp/endpoint.go | 3 ++ runsc/boot/filter/BUILD | 1 + runsc/boot/filter/config.go | 13 ----- runsc/boot/filter/config_profile.go | 34 ++++++++++++ runsc/container/console_test.go | 5 +- runsc/dockerutil/dockerutil.go | 11 ++-- runsc/testutil/BUILD | 5 +- runsc/testutil/testutil.go | 54 ------------------- runsc/testutil/testutil_runfiles.go | 75 +++++++++++++++++++++++++++ test/image/image_test.go | 8 +-- test/syscalls/build_defs.bzl | 35 +++++++++++-- test/syscalls/linux/chroot.cc | 2 +- test/syscalls/linux/concurrency.cc | 3 +- test/syscalls/linux/exec_proc_exe_workload.cc | 6 +++ test/syscalls/linux/fork.cc | 5 +- test/syscalls/linux/mmap.cc | 8 +-- test/syscalls/linux/open_create.cc | 1 + test/syscalls/linux/preadv.cc | 1 + test/syscalls/linux/proc.cc | 46 +++++++++++++--- test/syscalls/linux/readv.cc | 4 +- test/syscalls/linux/rseq.cc | 2 +- test/syscalls/linux/select.cc | 2 +- test/syscalls/linux/shm.cc | 2 +- test/syscalls/linux/sigprocmask.cc | 2 +- test/syscalls/linux/socket_unix_non_stream.cc | 4 +- test/syscalls/linux/symlink.cc | 2 +- test/syscalls/linux/tcp_socket.cc | 3 +- test/syscalls/linux/time.cc | 1 + test/syscalls/linux/tkill.cc | 2 +- test/util/temp_path.cc | 1 + tools/build/tags.bzl | 4 ++ tools/defs.bzl | 17 +++++- 43 files changed, 318 insertions(+), 113 deletions(-) create mode 100644 pkg/sentry/kernel/kernel_opts.go create mode 100644 pkg/sentry/socket/hostinet/sockopt_impl.go create mode 100644 runsc/boot/filter/config_profile.go create mode 100644 runsc/testutil/testutil_runfiles.go (limited to 'pkg/tcpip') diff --git a/pkg/metric/metric.go b/pkg/metric/metric.go index 93d4f2b8c..006fcd9ab 100644 --- a/pkg/metric/metric.go +++ b/pkg/metric/metric.go @@ -46,7 +46,6 @@ var ( // // TODO(b/67298402): Support non-cumulative metrics. // TODO(b/67298427): Support metric fields. -// type Uint64Metric struct { // value is the actual value of the metric. It must be accessed // atomically. diff --git a/pkg/sentry/arch/arch_x86.go b/pkg/sentry/arch/arch_x86.go index a18093155..3db8bd34b 100644 --- a/pkg/sentry/arch/arch_x86.go +++ b/pkg/sentry/arch/arch_x86.go @@ -114,6 +114,10 @@ func newX86FPStateSlice() []byte { size, align := cpuid.HostFeatureSet().ExtendedStateSize() capacity := size // Always use at least 4096 bytes. + // + // For the KVM platform, this state is a fixed 4096 bytes, so make sure + // that the underlying array is at _least_ that size otherwise we will + // corrupt random memory. This is not a pleasant thing to debug. if capacity < 4096 { capacity = 4096 } diff --git a/pkg/sentry/arch/signal_amd64.go b/pkg/sentry/arch/signal_amd64.go index 81b92bb43..6fb756f0e 100644 --- a/pkg/sentry/arch/signal_amd64.go +++ b/pkg/sentry/arch/signal_amd64.go @@ -55,7 +55,7 @@ type SignalContext64 struct { Trapno uint64 Oldmask linux.SignalSet Cr2 uint64 - // Pointer to a struct _fpstate. + // Pointer to a struct _fpstate. See b/33003106#comment8. Fpstate uint64 Reserved [8]uint64 } diff --git a/pkg/sentry/fs/file_overlay_test.go b/pkg/sentry/fs/file_overlay_test.go index 02538bb4f..a76d87e3a 100644 --- a/pkg/sentry/fs/file_overlay_test.go +++ b/pkg/sentry/fs/file_overlay_test.go @@ -177,6 +177,7 @@ func TestReaddirRevalidation(t *testing.T) { // TestReaddirOverlayFrozen tests that calling Readdir on an overlay file with // a frozen dirent tree does not make Readdir calls to the underlying files. +// This is a regression test for b/114808269. func TestReaddirOverlayFrozen(t *testing.T) { ctx := contexttest.Context(t) diff --git a/pkg/sentry/fs/proc/README.md b/pkg/sentry/fs/proc/README.md index 5d4ec6c7b..6667a0916 100644 --- a/pkg/sentry/fs/proc/README.md +++ b/pkg/sentry/fs/proc/README.md @@ -11,6 +11,8 @@ inconsistency, please file a bug. The following files are implemented: + + | File /proc/ | Content | | :------------------------ | :---------------------------------------------------- | | [cpuinfo](#cpuinfo) | Info about the CPU | @@ -22,6 +24,8 @@ The following files are implemented: | [uptime](#uptime) | Wall clock since boot, combined idle time of all cpus | | [version](#version) | Kernel version | + + ### cpuinfo ```bash diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index a27628c0a..2231d6973 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -91,6 +91,7 @@ go_library( "fs_context.go", "ipc_namespace.go", "kernel.go", + "kernel_opts.go", "kernel_state.go", "pending_signals.go", "pending_signals_list.go", diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index dcd6e91c4..3ee760ba2 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -235,6 +235,9 @@ type Kernel struct { // events. This is initialized lazily on the first unimplemented // syscall. unimplementedSyscallEmitter eventchannel.Emitter `state:"nosave"` + + // SpecialOpts contains special kernel options. + SpecialOpts } // InitKernelArgs holds arguments to Init. diff --git a/pkg/sentry/kernel/kernel_opts.go b/pkg/sentry/kernel/kernel_opts.go new file mode 100644 index 000000000..2e66ec587 --- /dev/null +++ b/pkg/sentry/kernel/kernel_opts.go @@ -0,0 +1,20 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kernel + +// SpecialOpts contains non-standard options for the kernel. +// +// +stateify savable +type SpecialOpts struct{} diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index 5a07d5d0e..023bad156 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -10,6 +10,7 @@ go_library( "save_restore.go", "socket.go", "socket_unsafe.go", + "sockopt_impl.go", "stack.go", ], visibility = ["//pkg/sentry:internal"], diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 34f63986f..de76388ac 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -285,7 +285,7 @@ func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outPt } // Whitelist options and constrain option length. - var optlen int + optlen := getSockOptLen(t, level, name) switch level { case linux.SOL_IP: switch name { @@ -330,7 +330,7 @@ func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outPt // SetSockOpt implements socket.Socket.SetSockOpt. func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error { // Whitelist options and constrain option length. - var optlen int + optlen := setSockOptLen(t, level, name) switch level { case linux.SOL_IP: switch name { @@ -353,6 +353,7 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [ optlen = sizeofInt32 } } + if optlen == 0 { // Pretend to accept socket options we don't understand. This seems // dangerous, but it's what netstack does... diff --git a/pkg/sentry/socket/hostinet/sockopt_impl.go b/pkg/sentry/socket/hostinet/sockopt_impl.go new file mode 100644 index 000000000..8a783712e --- /dev/null +++ b/pkg/sentry/socket/hostinet/sockopt_impl.go @@ -0,0 +1,27 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package hostinet + +import ( + "gvisor.dev/gvisor/pkg/sentry/kernel" +) + +func getSockOptLen(t *kernel.Task, level, name int) int { + return 0 // No custom options. +} + +func setSockOptLen(t *kernel.Task, level, name int) int { + return 0 // No custom options. +} diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index e4a6b1b8b..f2be0e651 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -2166,6 +2166,9 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { e.isRegistered = true e.setEndpointState(StateListen) + // The channel may be non-nil when we're restoring the endpoint, and it + // may be pre-populated with some previously accepted (but not Accepted) + // endpoints. if e.acceptedChan == nil { e.acceptedChan = make(chan *endpoint, backlog) } diff --git a/runsc/boot/filter/BUILD b/runsc/boot/filter/BUILD index ce30f6c53..ed18f0047 100644 --- a/runsc/boot/filter/BUILD +++ b/runsc/boot/filter/BUILD @@ -8,6 +8,7 @@ go_library( "config.go", "config_amd64.go", "config_arm64.go", + "config_profile.go", "extra_filters.go", "extra_filters_msan.go", "extra_filters_race.go", diff --git a/runsc/boot/filter/config.go b/runsc/boot/filter/config.go index f8d351c7b..c69f4c602 100644 --- a/runsc/boot/filter/config.go +++ b/runsc/boot/filter/config.go @@ -536,16 +536,3 @@ func controlServerFilters(fd int) seccomp.SyscallRules { }, } } - -// profileFilters returns extra syscalls made by runtime/pprof package. -func profileFilters() seccomp.SyscallRules { - return seccomp.SyscallRules{ - syscall.SYS_OPENAT: []seccomp.Rule{ - { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.O_RDONLY | syscall.O_LARGEFILE | syscall.O_CLOEXEC), - }, - }, - } -} diff --git a/runsc/boot/filter/config_profile.go b/runsc/boot/filter/config_profile.go new file mode 100644 index 000000000..194952a7b --- /dev/null +++ b/runsc/boot/filter/config_profile.go @@ -0,0 +1,34 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package filter + +import ( + "syscall" + + "gvisor.dev/gvisor/pkg/seccomp" +) + +// profileFilters returns extra syscalls made by runtime/pprof package. +func profileFilters() seccomp.SyscallRules { + return seccomp.SyscallRules{ + syscall.SYS_OPENAT: []seccomp.Rule{ + { + seccomp.AllowAny{}, + seccomp.AllowAny{}, + seccomp.AllowValue(syscall.O_RDONLY | syscall.O_LARGEFILE | syscall.O_CLOEXEC), + }, + }, + } +} diff --git a/runsc/container/console_test.go b/runsc/container/console_test.go index 060b63bf3..c2518d52b 100644 --- a/runsc/container/console_test.go +++ b/runsc/container/console_test.go @@ -196,7 +196,10 @@ func TestJobControlSignalExec(t *testing.T) { defer ptyMaster.Close() defer ptySlave.Close() - // Exec bash and attach a terminal. + // Exec bash and attach a terminal. Note that occasionally /bin/sh + // may be a different shell or have a different configuration (such + // as disabling interactive mode and job control). Since we want to + // explicitly test interactive mode, use /bin/bash. See b/116981926. execArgs := &control.ExecArgs{ Filename: "/bin/bash", // Don't let bash execute from profile or rc files, otherwise diff --git a/runsc/dockerutil/dockerutil.go b/runsc/dockerutil/dockerutil.go index 9b6346ca2..1ff5e8cc3 100644 --- a/runsc/dockerutil/dockerutil.go +++ b/runsc/dockerutil/dockerutil.go @@ -143,8 +143,11 @@ func PrepareFiles(names ...string) (string, error) { return "", fmt.Errorf("os.Chmod(%q, 0777) failed: %v", dir, err) } for _, name := range names { - src := getLocalPath(name) - dst := path.Join(dir, name) + src, err := testutil.FindFile(name) + if err != nil { + return "", fmt.Errorf("testutil.Preparefiles(%q) failed: %v", name, err) + } + dst := path.Join(dir, path.Base(name)) if err := testutil.Copy(src, dst); err != nil { return "", fmt.Errorf("testutil.Copy(%q, %q) failed: %v", src, dst, err) } @@ -152,10 +155,6 @@ func PrepareFiles(names ...string) (string, error) { return dir, nil } -func getLocalPath(file string) string { - return path.Join(".", file) -} - // do executes docker command. func do(args ...string) (string, error) { log.Printf("Running: docker %s\n", args) diff --git a/runsc/testutil/BUILD b/runsc/testutil/BUILD index f845120b0..945405303 100644 --- a/runsc/testutil/BUILD +++ b/runsc/testutil/BUILD @@ -5,7 +5,10 @@ package(licenses = ["notice"]) go_library( name = "testutil", testonly = 1, - srcs = ["testutil.go"], + srcs = [ + "testutil.go", + "testutil_runfiles.go", + ], visibility = ["//:sandbox"], deps = [ "//pkg/log", diff --git a/runsc/testutil/testutil.go b/runsc/testutil/testutil.go index edf2e809a..80c2c9680 100644 --- a/runsc/testutil/testutil.go +++ b/runsc/testutil/testutil.go @@ -79,60 +79,6 @@ func ConfigureExePath() error { return nil } -// FindFile searchs for a file inside the test run environment. It returns the -// full path to the file. It fails if none or more than one file is found. -func FindFile(path string) (string, error) { - wd, err := os.Getwd() - if err != nil { - return "", err - } - - // The test root is demarcated by a path element called "__main__". Search for - // it backwards from the working directory. - root := wd - for { - dir, name := filepath.Split(root) - if name == "__main__" { - break - } - if len(dir) == 0 { - return "", fmt.Errorf("directory __main__ not found in %q", wd) - } - // Remove ending slash to loop around. - root = dir[:len(dir)-1] - } - - // Annoyingly, bazel adds the build type to the directory path for go - // binaries, but not for c++ binaries. We use two different patterns to - // to find our file. - patterns := []string{ - // Try the obvious path first. - filepath.Join(root, path), - // If it was a go binary, use a wildcard to match the build - // type. The pattern is: /test-path/__main__/directories/*/file. - filepath.Join(root, filepath.Dir(path), "*", filepath.Base(path)), - } - - for _, p := range patterns { - matches, err := filepath.Glob(p) - if err != nil { - // "The only possible returned error is ErrBadPattern, - // when pattern is malformed." -godoc - return "", fmt.Errorf("error globbing %q: %v", p, err) - } - switch len(matches) { - case 0: - // Try the next pattern. - case 1: - // We found it. - return matches[0], nil - default: - return "", fmt.Errorf("more than one match found for %q: %s", path, matches) - } - } - return "", fmt.Errorf("file %q not found", path) -} - // TestConfig returns the default configuration to use in tests. Note that // 'RootDir' must be set by caller if required. func TestConfig() *boot.Config { diff --git a/runsc/testutil/testutil_runfiles.go b/runsc/testutil/testutil_runfiles.go new file mode 100644 index 000000000..ece9ea9a1 --- /dev/null +++ b/runsc/testutil/testutil_runfiles.go @@ -0,0 +1,75 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testutil + +import ( + "fmt" + "os" + "path/filepath" +) + +// FindFile searchs for a file inside the test run environment. It returns the +// full path to the file. It fails if none or more than one file is found. +func FindFile(path string) (string, error) { + wd, err := os.Getwd() + if err != nil { + return "", err + } + + // The test root is demarcated by a path element called "__main__". Search for + // it backwards from the working directory. + root := wd + for { + dir, name := filepath.Split(root) + if name == "__main__" { + break + } + if len(dir) == 0 { + return "", fmt.Errorf("directory __main__ not found in %q", wd) + } + // Remove ending slash to loop around. + root = dir[:len(dir)-1] + } + + // Annoyingly, bazel adds the build type to the directory path for go + // binaries, but not for c++ binaries. We use two different patterns to + // to find our file. + patterns := []string{ + // Try the obvious path first. + filepath.Join(root, path), + // If it was a go binary, use a wildcard to match the build + // type. The pattern is: /test-path/__main__/directories/*/file. + filepath.Join(root, filepath.Dir(path), "*", filepath.Base(path)), + } + + for _, p := range patterns { + matches, err := filepath.Glob(p) + if err != nil { + // "The only possible returned error is ErrBadPattern, + // when pattern is malformed." -godoc + return "", fmt.Errorf("error globbing %q: %v", p, err) + } + switch len(matches) { + case 0: + // Try the next pattern. + case 1: + // We found it. + return matches[0], nil + default: + return "", fmt.Errorf("more than one match found for %q: %s", path, matches) + } + } + return "", fmt.Errorf("file %q not found", path) +} diff --git a/test/image/image_test.go b/test/image/image_test.go index d0dcb1861..0a1e19d6f 100644 --- a/test/image/image_test.go +++ b/test/image/image_test.go @@ -107,7 +107,7 @@ func TestHttpd(t *testing.T) { } d := dockerutil.MakeDocker("http-test") - dir, err := dockerutil.PrepareFiles("latin10k.txt") + dir, err := dockerutil.PrepareFiles("test/image/latin10k.txt") if err != nil { t.Fatalf("PrepareFiles() failed: %v", err) } @@ -139,7 +139,7 @@ func TestNginx(t *testing.T) { } d := dockerutil.MakeDocker("net-test") - dir, err := dockerutil.PrepareFiles("latin10k.txt") + dir, err := dockerutil.PrepareFiles("test/image/latin10k.txt") if err != nil { t.Fatalf("PrepareFiles() failed: %v", err) } @@ -183,7 +183,7 @@ func TestMysql(t *testing.T) { } client := dockerutil.MakeDocker("mysql-client-test") - dir, err := dockerutil.PrepareFiles("mysql.sql") + dir, err := dockerutil.PrepareFiles("test/image/mysql.sql") if err != nil { t.Fatalf("PrepareFiles() failed: %v", err) } @@ -283,7 +283,7 @@ func TestRuby(t *testing.T) { } d := dockerutil.MakeDocker("ruby-test") - dir, err := dockerutil.PrepareFiles("ruby.rb", "ruby.sh") + dir, err := dockerutil.PrepareFiles("test/image/ruby.rb", "test/image/ruby.sh") if err != nil { t.Fatalf("PrepareFiles() failed: %v", err) } diff --git a/test/syscalls/build_defs.bzl b/test/syscalls/build_defs.bzl index 1df761dd0..cbab85ef7 100644 --- a/test/syscalls/build_defs.bzl +++ b/test/syscalls/build_defs.bzl @@ -2,8 +2,6 @@ load("//tools:defs.bzl", "loopback") -# syscall_test is a macro that will create targets to run the given test target -# on the host (native) and runsc. def syscall_test( test, shard_count = 5, @@ -13,6 +11,19 @@ def syscall_test( add_uds_tree = False, add_hostinet = False, tags = None): + """syscall_test is a macro that will create targets for all platforms. + + Args: + test: the test target. + shard_count: shards for defined tests. + size: the defined test size. + use_tmpfs: use tmpfs in the defined tests. + add_overlay: add an overlay test. + add_uds_tree: add a UDS test. + add_hostinet: add a hostinet test. + tags: starting test tags. + """ + _syscall_test( test = test, shard_count = shard_count, @@ -111,6 +122,19 @@ def _syscall_test( # all the tests on a specific flavor. Use --test_tag_filters=ptrace,file_shared. tags += [full_platform, "file_" + file_access] + # Hash this target into one of 15 buckets. This can be used to + # randomly split targets between different workflows. + hash15 = hash(native.package_name() + name) % 15 + tags.append("hash15:" + str(hash15)) + + # TODO(b/139838000): Tests using hostinet must be disabled on Guitar until + # we figure out how to request ipv4 sockets on Guitar machines. + if network == "host": + tags.append("noguitar") + + # Disable off-host networking. + tags.append("requires-net:loopback") + # Add tag to prevent the tests from running in a Bazel sandbox. # TODO(b/120560048): Make the tests run without this tag. tags.append("no-sandbox") @@ -118,8 +142,11 @@ def _syscall_test( # TODO(b/112165693): KVM tests are tagged "manual" to until the platform is # more stable. if platform == "kvm": - tags += ["manual"] - tags += ["requires-kvm"] + tags.append("manual") + tags.append("requires-kvm") + + # TODO(b/112165693): Remove when tests pass reliably. + tags.append("notap") args = [ # Arguments are passed directly to syscall_test_runner binary. diff --git a/test/syscalls/linux/chroot.cc b/test/syscalls/linux/chroot.cc index 0a2d44a2c..85ec013d5 100644 --- a/test/syscalls/linux/chroot.cc +++ b/test/syscalls/linux/chroot.cc @@ -167,7 +167,7 @@ TEST(ChrootTest, DotDotFromOpenFD) { } // Test that link resolution in a chroot can escape the root by following an -// open proc fd. +// open proc fd. Regression test for b/32316719. TEST(ChrootTest, ProcFdLinkResolutionInChroot) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_CHROOT))); diff --git a/test/syscalls/linux/concurrency.cc b/test/syscalls/linux/concurrency.cc index f41f99900..7cd6a75bd 100644 --- a/test/syscalls/linux/concurrency.cc +++ b/test/syscalls/linux/concurrency.cc @@ -46,7 +46,8 @@ TEST(ConcurrencyTest, SingleProcessMultithreaded) { } // Test that multiple threads in this process continue to execute in parallel, -// even if an unrelated second process is spawned. +// even if an unrelated second process is spawned. Regression test for +// b/32119508. TEST(ConcurrencyTest, MultiProcessMultithreaded) { // In PID 1, start TIDs 1 and 2, and put both to sleep. // diff --git a/test/syscalls/linux/exec_proc_exe_workload.cc b/test/syscalls/linux/exec_proc_exe_workload.cc index b790fe5be..2989379b7 100644 --- a/test/syscalls/linux/exec_proc_exe_workload.cc +++ b/test/syscalls/linux/exec_proc_exe_workload.cc @@ -21,6 +21,12 @@ #include "test/util/posix_error.h" int main(int argc, char** argv, char** envp) { + // This is annoying. Because remote build systems may put these binaries + // in a content-addressable-store, you may wind up with /proc/self/exe + // pointing to some random path (but with a sensible argv[0]). + // + // Therefore, this test simply checks that the /proc/self/exe + // is absolute and *doesn't* match argv[1]. std::string exe = gvisor::testing::ProcessExePath(getpid()).ValueOrDie(); if (exe[0] != '/') { diff --git a/test/syscalls/linux/fork.cc b/test/syscalls/linux/fork.cc index 906f3358d..ff8bdfeb0 100644 --- a/test/syscalls/linux/fork.cc +++ b/test/syscalls/linux/fork.cc @@ -271,7 +271,7 @@ TEST_F(ForkTest, Alarm) { EXPECT_EQ(0, alarmed); } -// Child cannot affect parent private memory. +// Child cannot affect parent private memory. Regression test for b/24137240. TEST_F(ForkTest, PrivateMemory) { std::atomic local(0); @@ -298,6 +298,9 @@ TEST_F(ForkTest, PrivateMemory) { } // Kernel-accessed buffers should remain coherent across COW. +// +// The buffer must be >= usermem.ZeroCopyMinBytes, as UnsafeAccess operates +// differently. Regression test for b/33811887. TEST_F(ForkTest, COWSegment) { constexpr int kBufSize = 1024; char* read_buf = private_; diff --git a/test/syscalls/linux/mmap.cc b/test/syscalls/linux/mmap.cc index 1c4d9f1c7..11fb1b457 100644 --- a/test/syscalls/linux/mmap.cc +++ b/test/syscalls/linux/mmap.cc @@ -1418,7 +1418,7 @@ TEST_P(MMapFileParamTest, NoSigBusOnPageContainingEOF) { // // On most platforms this is trivial, but when the file is mapped via the sentry // page cache (which does not yet support writing to shared mappings), a bug -// caused reads to fail unnecessarily on such mappings. +// caused reads to fail unnecessarily on such mappings. See b/28913513. TEST_F(MMapFileTest, ReadingWritableSharedFilePageSucceeds) { uintptr_t addr; size_t len = strlen(kFileContents); @@ -1435,7 +1435,7 @@ TEST_F(MMapFileTest, ReadingWritableSharedFilePageSucceeds) { // Tests that EFAULT is returned when invoking a syscall that requires the OS to // read past end of file (resulting in a fault in sentry context in the gVisor -// case). +// case). See b/28913513. TEST_F(MMapFileTest, InternalSigBus) { uintptr_t addr; ASSERT_THAT(addr = Map(0, 2 * kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE, @@ -1578,7 +1578,7 @@ TEST_F(MMapFileTest, Bug38498194) { } // Tests that reading from a file to a memory mapping of the same file does not -// deadlock. +// deadlock. See b/34813270. TEST_F(MMapFileTest, SelfRead) { uintptr_t addr; ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, @@ -1590,7 +1590,7 @@ TEST_F(MMapFileTest, SelfRead) { } // Tests that writing to a file from a memory mapping of the same file does not -// deadlock. +// deadlock. Regression test for b/34813270. TEST_F(MMapFileTest, SelfWrite) { uintptr_t addr; ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ, MAP_SHARED, fd_.get(), 0), diff --git a/test/syscalls/linux/open_create.cc b/test/syscalls/linux/open_create.cc index 431733dbe..902d0a0dc 100644 --- a/test/syscalls/linux/open_create.cc +++ b/test/syscalls/linux/open_create.cc @@ -132,6 +132,7 @@ TEST(CreateTest, CreateFailsOnDirWithoutWritePerms) { } // A file originally created RW, but opened RO can later be opened RW. +// Regression test for b/65385065. TEST(CreateTest, OpenCreateROThenRW) { TempPath file(NewTempAbsPath()); diff --git a/test/syscalls/linux/preadv.cc b/test/syscalls/linux/preadv.cc index f7ea44054..5b0743fe9 100644 --- a/test/syscalls/linux/preadv.cc +++ b/test/syscalls/linux/preadv.cc @@ -37,6 +37,7 @@ namespace testing { namespace { +// Stress copy-on-write. Attempts to reproduce b/38430174. TEST(PreadvTest, MMConcurrencyStress) { // Fill a one-page file with zeroes (the contents don't really matter). const auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc index 169b723eb..a23fdb58d 100644 --- a/test/syscalls/linux/proc.cc +++ b/test/syscalls/linux/proc.cc @@ -1352,13 +1352,19 @@ TEST(ProcPidSymlink, SubprocessZombied) { // FIXME(gvisor.dev/issue/164): Inconsistent behavior between gVisor and linux // on proc files. - // 4.17 & gVisor: Syscall succeeds and returns 1 + // + // ~4.3: Syscall fails with EACCES. + // 4.17 & gVisor: Syscall succeeds and returns 1. + // // EXPECT_THAT(ReadlinkWhileZombied("ns/pid", buf, sizeof(buf)), // SyscallFailsWithErrno(EACCES)); // FIXME(gvisor.dev/issue/164): Inconsistent behavior between gVisor and linux // on proc files. - // 4.17 & gVisor: Syscall succeeds and returns 1. + // + // ~4.3: Syscall fails with EACCES. + // 4.17 & gVisor: Syscall succeeds and returns 1. + // // EXPECT_THAT(ReadlinkWhileZombied("ns/user", buf, sizeof(buf)), // SyscallFailsWithErrno(EACCES)); } @@ -1431,8 +1437,12 @@ TEST(ProcPidFile, SubprocessRunning) { TEST(ProcPidFile, SubprocessZombie) { char buf[1]; - // 4.17: Succeeds and returns 1 - // gVisor: Succeeds and returns 0 + // FIXME(gvisor.dev/issue/164): Loosen requirement due to inconsistent + // behavior on different kernels. + // + // ~4.3: Succeds and returns 0. + // 4.17: Succeeds and returns 1. + // gVisor: Succeeds and returns 0. EXPECT_THAT(ReadWhileZombied("auxv", buf, sizeof(buf)), SyscallSucceeds()); EXPECT_THAT(ReadWhileZombied("cmdline", buf, sizeof(buf)), @@ -1458,7 +1468,10 @@ TEST(ProcPidFile, SubprocessZombie) { // FIXME(gvisor.dev/issue/164): Inconsistent behavior between gVisor and linux // on proc files. + // + // ~4.3: Fails and returns EACCES. // gVisor & 4.17: Succeeds and returns 1. + // // EXPECT_THAT(ReadWhileZombied("io", buf, sizeof(buf)), // SyscallFailsWithErrno(EACCES)); } @@ -1467,9 +1480,12 @@ TEST(ProcPidFile, SubprocessZombie) { TEST(ProcPidFile, SubprocessExited) { char buf[1]; - // FIXME(gvisor.dev/issue/164): Inconsistent behavior between kernels + // FIXME(gvisor.dev/issue/164): Inconsistent behavior between kernels. + // + // ~4.3: Fails and returns ESRCH. // gVisor: Fails with ESRCH. // 4.17: Succeeds and returns 1. + // // EXPECT_THAT(ReadWhileExited("auxv", buf, sizeof(buf)), // SyscallFailsWithErrno(ESRCH)); @@ -1641,7 +1657,7 @@ TEST(ProcTask, KilledThreadsDisappear) { EXPECT_NO_ERRNO(DirContainsExactly("/proc/self/task", TaskFiles(initial, {child1.Tid()}))); - // Stat child1's task file. + // Stat child1's task file. Regression test for b/32097707. struct stat statbuf; const std::string child1_task_file = absl::StrCat("/proc/self/task/", child1.Tid()); @@ -1669,7 +1685,7 @@ TEST(ProcTask, KilledThreadsDisappear) { EXPECT_NO_ERRNO(EventuallyDirContainsExactly( "/proc/self/task", TaskFiles(initial, {child3.Tid(), child5.Tid()}))); - // Stat child1's task file again. This time it should fail. + // Stat child1's task file again. This time it should fail. See b/32097707. EXPECT_THAT(stat(child1_task_file.c_str(), &statbuf), SyscallFailsWithErrno(ENOENT)); @@ -1824,7 +1840,7 @@ TEST(ProcSysVmOvercommitMemory, HasNumericValue) { } // Check that link for proc fd entries point the target node, not the -// symlink itself. +// symlink itself. Regression test for b/31155070. TEST(ProcTaskFd, FstatatFollowsSymlink) { const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); const FileDescriptor fd = @@ -1883,6 +1899,20 @@ TEST(ProcMounts, IsSymlink) { EXPECT_EQ(link, "self/mounts"); } +TEST(ProcSelfMountinfo, RequiredFieldsArePresent) { + auto mountinfo = + ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/mountinfo")); + EXPECT_THAT( + mountinfo, + AllOf( + // Root mount. + ContainsRegex( + R"([0-9]+ [0-9]+ [0-9]+:[0-9]+ / / (rw|ro).*- \S+ \S+ (rw|ro)\S*)"), + // Proc mount - always rw. + ContainsRegex( + R"([0-9]+ [0-9]+ [0-9]+:[0-9]+ / /proc rw.*- \S+ \S+ rw\S*)"))); +} + // Check that /proc/self/mounts looks something like a real mounts file. TEST(ProcSelfMounts, RequiredFieldsArePresent) { auto mounts = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/mounts")); diff --git a/test/syscalls/linux/readv.cc b/test/syscalls/linux/readv.cc index 4069cbc7e..baaf9f757 100644 --- a/test/syscalls/linux/readv.cc +++ b/test/syscalls/linux/readv.cc @@ -254,7 +254,9 @@ TEST_F(ReadvTest, IovecOutsideTaskAddressRangeInNonemptyArray) { // This test depends on the maximum extent of a single readv() syscall, so // we can't tolerate interruption from saving. TEST(ReadvTestNoFixture, TruncatedAtMax_NoRandomSave) { - // Ensure that we won't be interrupted by ITIMER_PROF. + // Ensure that we won't be interrupted by ITIMER_PROF. This is particularly + // important in environments where automated profiling tools may start + // ITIMER_PROF automatically. struct itimerval itv = {}; auto const cleanup_itimer = ASSERT_NO_ERRNO_AND_VALUE(ScopedItimer(ITIMER_PROF, itv)); diff --git a/test/syscalls/linux/rseq.cc b/test/syscalls/linux/rseq.cc index 106c045e3..4bfb1ff56 100644 --- a/test/syscalls/linux/rseq.cc +++ b/test/syscalls/linux/rseq.cc @@ -36,7 +36,7 @@ namespace { // We must be very careful about how these tests are written. Each thread may // only have one struct rseq registration, which may be done automatically at // thread start (as of 2019-11-13, glibc does *not* support rseq and thus does -// not do so). +// not do so, but other libraries do). // // Testing of rseq is thus done primarily in a child process with no // registration. This means exec'ing a nostdlib binary, as rseq registration can diff --git a/test/syscalls/linux/select.cc b/test/syscalls/linux/select.cc index 424e2a67f..be2364fb8 100644 --- a/test/syscalls/linux/select.cc +++ b/test/syscalls/linux/select.cc @@ -146,7 +146,7 @@ TEST_F(SelectTest, IgnoreBitsAboveNfds) { // This test illustrates Linux's behavior of 'select' calls passing after // setrlimit RLIMIT_NOFILE is called. In particular, versions of sshd rely on -// this behavior. +// this behavior. See b/122318458. TEST_F(SelectTest, SetrlimitCallNOFILE) { fd_set read_set; FD_ZERO(&read_set); diff --git a/test/syscalls/linux/shm.cc b/test/syscalls/linux/shm.cc index 7ba752599..c7fdbb924 100644 --- a/test/syscalls/linux/shm.cc +++ b/test/syscalls/linux/shm.cc @@ -473,7 +473,7 @@ TEST(ShmTest, PartialUnmap) { } // Check that sentry does not panic when asked for a zero-length private shm -// segment. +// segment. Regression test for b/110694797. TEST(ShmTest, GracefullyFailOnZeroLenSegmentCreation) { EXPECT_THAT(Shmget(IPC_PRIVATE, 0, 0), PosixErrorIs(EINVAL, _)); } diff --git a/test/syscalls/linux/sigprocmask.cc b/test/syscalls/linux/sigprocmask.cc index 654c6a47f..a603fc1d1 100644 --- a/test/syscalls/linux/sigprocmask.cc +++ b/test/syscalls/linux/sigprocmask.cc @@ -237,7 +237,7 @@ TEST_F(SigProcMaskTest, SignalHandler) { } // Check that sigprocmask correctly handles aliasing of the set and oldset -// pointers. +// pointers. Regression test for b/30502311. TEST_F(SigProcMaskTest, AliasedSets) { sigset_t mask; diff --git a/test/syscalls/linux/socket_unix_non_stream.cc b/test/syscalls/linux/socket_unix_non_stream.cc index 276a94eb8..884319e1d 100644 --- a/test/syscalls/linux/socket_unix_non_stream.cc +++ b/test/syscalls/linux/socket_unix_non_stream.cc @@ -109,7 +109,7 @@ PosixErrorOr> CreateFragmentedRegion(const int size, } // A contiguous iov that is heavily fragmented in FileMem can still be sent -// successfully. +// successfully. See b/115833655. TEST_P(UnixNonStreamSocketPairTest, FragmentedSendMsg) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); @@ -165,7 +165,7 @@ TEST_P(UnixNonStreamSocketPairTest, FragmentedSendMsg) { } // A contiguous iov that is heavily fragmented in FileMem can still be received -// into successfully. +// into successfully. Regression test for b/115833655. TEST_P(UnixNonStreamSocketPairTest, FragmentedRecvMsg) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); diff --git a/test/syscalls/linux/symlink.cc b/test/syscalls/linux/symlink.cc index b249ff91f..03ee1250d 100644 --- a/test/syscalls/linux/symlink.cc +++ b/test/syscalls/linux/symlink.cc @@ -38,7 +38,7 @@ mode_t FilePermission(const std::string& path) { } // Test that name collisions are checked on the new link path, not the source -// path. +// path. Regression test for b/31782115. TEST(SymlinkTest, CanCreateSymlinkWithCachedSourceDirent) { const std::string srcname = NewTempAbsPath(); const std::string newname = NewTempAbsPath(); diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index 8a8b68e75..c4591a3b9 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -244,7 +244,8 @@ TEST_P(TcpSocketTest, ZeroWriteAllowed) { } // Test that a non-blocking write with a buffer that is larger than the send -// buffer size will not actually write the whole thing at once. +// buffer size will not actually write the whole thing at once. Regression test +// for b/64438887. TEST_P(TcpSocketTest, NonblockingLargeWrite) { // Set the FD to O_NONBLOCK. int opts; diff --git a/test/syscalls/linux/time.cc b/test/syscalls/linux/time.cc index c7eead17e..1ccb95733 100644 --- a/test/syscalls/linux/time.cc +++ b/test/syscalls/linux/time.cc @@ -62,6 +62,7 @@ TEST(TimeTest, VsyscallTime_InvalidAddressSIGSEGV) { ::testing::KilledBySignal(SIGSEGV), ""); } +// Mimics the gettimeofday(2) wrapper from the Go runtime <= 1.2. int vsyscall_gettimeofday(struct timeval* tv, struct timezone* tz) { constexpr uint64_t kVsyscallGettimeofdayEntry = 0xffffffffff600000; return reinterpret_cast( diff --git a/test/syscalls/linux/tkill.cc b/test/syscalls/linux/tkill.cc index bae377c69..8d8ebbb24 100644 --- a/test/syscalls/linux/tkill.cc +++ b/test/syscalls/linux/tkill.cc @@ -54,7 +54,7 @@ void SigHandler(int sig, siginfo_t* info, void* context) { TEST_CHECK(info->si_code == SI_TKILL); } -// Test with a real signal. +// Test with a real signal. Regression test for b/24790092. TEST(TkillTest, ValidTIDAndRealSignal) { struct sigaction sa; sa.sa_sigaction = SigHandler; diff --git a/test/util/temp_path.cc b/test/util/temp_path.cc index 35aacb172..9c10b6674 100644 --- a/test/util/temp_path.cc +++ b/test/util/temp_path.cc @@ -77,6 +77,7 @@ std::string NewTempAbsPath() { std::string NewTempRelPath() { return NextTempBasename(); } std::string GetAbsoluteTestTmpdir() { + // Note that TEST_TMPDIR is guaranteed to be set. char* env_tmpdir = getenv("TEST_TMPDIR"); std::string tmp_dir = env_tmpdir != nullptr ? std::string(env_tmpdir) : "/tmp"; diff --git a/tools/build/tags.bzl b/tools/build/tags.bzl index e99c87f81..a6db44e47 100644 --- a/tools/build/tags.bzl +++ b/tools/build/tags.bzl @@ -33,4 +33,8 @@ go_suffixes = [ "_wasm_unsafe", "_linux", "_linux_unsafe", + "_opts", + "_opts_unsafe", + "_impl", + "_impl_unsafe", ] diff --git a/tools/defs.bzl b/tools/defs.bzl index 5d5fa134a..c03b557ae 100644 --- a/tools/defs.bzl +++ b/tools/defs.bzl @@ -73,6 +73,16 @@ def calculate_sets(srcs): result[target].append(file) return result +def go_imports(name, src, out): + """Simplify a single Go source file by eliminating unused imports.""" + native.genrule( + name = name, + srcs = [src], + outs = [out], + tools = ["@org_golang_x_tools//cmd/goimports:goimports"], + cmd = ("$(location @org_golang_x_tools//cmd/goimports:goimports) $(SRCS) > $@"), + ) + def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = False, **kwargs): """Wraps the standard go_library and does stateification and marshalling. @@ -107,10 +117,15 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F state_sets = calculate_sets(srcs) for (suffix, srcs) in state_sets.items(): go_stateify( - name = name + suffix + "_state_autogen", + name = name + suffix + "_state_autogen_with_imports", srcs = srcs, imports = imports, package = name, + out = name + suffix + "_state_autogen_with_imports.go", + ) + go_imports( + name = name + suffix + "_state_autogen", + src = name + suffix + "_state_autogen_with_imports.go", out = name + suffix + "_state_autogen.go", ) all_srcs = all_srcs + [ -- cgit v1.2.3 From 6bd59b4e08893281468e8af5aebb5fab0f7a8c0d Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Thu, 6 Feb 2020 11:12:41 -0800 Subject: Update link address for targets of Neighbor Adverts Get the link address for the target of an NDP Neighbor Advertisement from the NDP Target Link Layer Address option. Tests: - ipv6.TestNeighorAdvertisementWithTargetLinkLayerOption - ipv6.TestNeighorAdvertisementWithInvalidTargetLinkLayerOption PiperOrigin-RevId: 293632609 --- pkg/tcpip/network/ipv6/icmp.go | 44 ++++-- pkg/tcpip/network/ipv6/icmp_test.go | 186 +++++++++++++++--------- pkg/tcpip/network/ipv6/ndp_test.go | 278 +++++++++++++++++++++++++----------- pkg/tcpip/stack/ndp_test.go | 8 +- 4 files changed, 352 insertions(+), 164 deletions(-) (limited to 'pkg/tcpip') diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 60817d36d..45dc757c7 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -15,6 +15,8 @@ package ipv6 import ( + "log" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -194,7 +196,11 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P // TODO(b/148429853): Properly process the NS message and do Neighbor // Unreachability Detection. for { - opt, done, _ := it.Next() + opt, done, err := it.Next() + if err != nil { + // This should never happen as Iter(true) above did not return an error. + log.Fatalf("unexpected error when iterating over NDP options: %s", err) + } if done { break } @@ -253,21 +259,25 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P } na := header.NDPNeighborAdvert(h.NDPPayload()) + it, err := na.Options().Iter(true) + if err != nil { + // If we have a malformed NDP NA option, drop the packet. + received.Invalid.Increment() + return + } + targetAddr := na.TargetAddress() stack := r.Stack() rxNICID := r.NICID() - isTentative, err := stack.IsAddrTentative(rxNICID, targetAddr) - if err != nil { + if isTentative, err := stack.IsAddrTentative(rxNICID, targetAddr); err != nil { // We will only get an error if rxNICID is unrecognized, // which should not happen. For now short-circuit this // packet. // // TODO(b/141002840): Handle this better? return - } - - if isTentative { + } else if isTentative { // We just got an NA from a node that owns an address we // are performing DAD on, implying the address is not // unique. In this case we let the stack know so it can @@ -283,13 +293,29 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P // scenario is beyond the scope of RFC 4862. As such, we simply // ignore such a scenario for now and proceed as normal. // + // If the NA message has the target link layer option, update the link + // address cache with the link address for the target of the message. + // // TODO(b/143147598): Handle the scenario described above. Also // inform the netstack integration that a duplicate address was // detected outside of DAD. + // + // TODO(b/148429853): Properly process the NA message and do Neighbor + // Unreachability Detection. + for { + opt, done, err := it.Next() + if err != nil { + // This should never happen as Iter(true) above did not return an error. + log.Fatalf("unexpected error when iterating over NDP options: %s", err) + } + if done { + break + } - e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, r.RemoteLinkAddress) - if targetAddr != r.RemoteAddress { - e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, r.RemoteLinkAddress) + switch opt := opt.(type) { + case header.NDPTargetLinkLayerAddressOption: + e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, opt.EthernetAddress()) + } } case header.ICMPv6EchoRequest: diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index d0e930e20..50c4b6474 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -121,21 +121,60 @@ func TestICMPCounts(t *testing.T) { } defer r.Release() + var tllData [header.NDPLinkLayerAddressSize]byte + header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(linkAddr1), + }) + types := []struct { - typ header.ICMPv6Type - size int + typ header.ICMPv6Type + size int + extraData []byte }{ - {header.ICMPv6DstUnreachable, header.ICMPv6DstUnreachableMinimumSize}, - {header.ICMPv6PacketTooBig, header.ICMPv6PacketTooBigMinimumSize}, - {header.ICMPv6TimeExceeded, header.ICMPv6MinimumSize}, - {header.ICMPv6ParamProblem, header.ICMPv6MinimumSize}, - {header.ICMPv6EchoRequest, header.ICMPv6EchoMinimumSize}, - {header.ICMPv6EchoReply, header.ICMPv6EchoMinimumSize}, - {header.ICMPv6RouterSolicit, header.ICMPv6MinimumSize}, - {header.ICMPv6RouterAdvert, header.ICMPv6HeaderSize + header.NDPRAMinimumSize}, - {header.ICMPv6NeighborSolicit, header.ICMPv6NeighborSolicitMinimumSize}, - {header.ICMPv6NeighborAdvert, header.ICMPv6NeighborAdvertSize}, - {header.ICMPv6RedirectMsg, header.ICMPv6MinimumSize}, + { + typ: header.ICMPv6DstUnreachable, + size: header.ICMPv6DstUnreachableMinimumSize, + }, + { + typ: header.ICMPv6PacketTooBig, + size: header.ICMPv6PacketTooBigMinimumSize, + }, + { + typ: header.ICMPv6TimeExceeded, + size: header.ICMPv6MinimumSize, + }, + { + typ: header.ICMPv6ParamProblem, + size: header.ICMPv6MinimumSize, + }, + { + typ: header.ICMPv6EchoRequest, + size: header.ICMPv6EchoMinimumSize, + }, + { + typ: header.ICMPv6EchoReply, + size: header.ICMPv6EchoMinimumSize, + }, + { + typ: header.ICMPv6RouterSolicit, + size: header.ICMPv6MinimumSize, + }, + { + typ: header.ICMPv6RouterAdvert, + size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize, + }, + { + typ: header.ICMPv6NeighborSolicit, + size: header.ICMPv6NeighborSolicitMinimumSize}, + { + typ: header.ICMPv6NeighborAdvert, + size: header.ICMPv6NeighborAdvertMinimumSize, + extraData: tllData[:], + }, + { + typ: header.ICMPv6RedirectMsg, + size: header.ICMPv6MinimumSize, + }, } handleIPv6Payload := func(hdr buffer.Prependable) { @@ -154,10 +193,13 @@ func TestICMPCounts(t *testing.T) { } for _, typ := range types { - hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size) + extraDataLen := len(typ.extraData) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen) + extraData := buffer.View(hdr.Prepend(extraDataLen)) + copy(extraData, typ.extraData) pkt := header.ICMPv6(hdr.Prepend(typ.size)) pkt.SetType(typ.typ) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, extraData.ToVectorisedView())) handleIPv6Payload(hdr) } @@ -372,97 +414,104 @@ func TestLinkResolution(t *testing.T) { } func TestICMPChecksumValidationSimple(t *testing.T) { + var tllData [header.NDPLinkLayerAddressSize]byte + header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(linkAddr1), + }) + types := []struct { name string typ header.ICMPv6Type size int + extraData []byte statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter }{ { - "DstUnreachable", - header.ICMPv6DstUnreachable, - header.ICMPv6DstUnreachableMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + name: "DstUnreachable", + typ: header.ICMPv6DstUnreachable, + size: header.ICMPv6DstUnreachableMinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.DstUnreachable }, }, { - "PacketTooBig", - header.ICMPv6PacketTooBig, - header.ICMPv6PacketTooBigMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + name: "PacketTooBig", + typ: header.ICMPv6PacketTooBig, + size: header.ICMPv6PacketTooBigMinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.PacketTooBig }, }, { - "TimeExceeded", - header.ICMPv6TimeExceeded, - header.ICMPv6MinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + name: "TimeExceeded", + typ: header.ICMPv6TimeExceeded, + size: header.ICMPv6MinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.TimeExceeded }, }, { - "ParamProblem", - header.ICMPv6ParamProblem, - header.ICMPv6MinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + name: "ParamProblem", + typ: header.ICMPv6ParamProblem, + size: header.ICMPv6MinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.ParamProblem }, }, { - "EchoRequest", - header.ICMPv6EchoRequest, - header.ICMPv6EchoMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + name: "EchoRequest", + typ: header.ICMPv6EchoRequest, + size: header.ICMPv6EchoMinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.EchoRequest }, }, { - "EchoReply", - header.ICMPv6EchoReply, - header.ICMPv6EchoMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + name: "EchoReply", + typ: header.ICMPv6EchoReply, + size: header.ICMPv6EchoMinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.EchoReply }, }, { - "RouterSolicit", - header.ICMPv6RouterSolicit, - header.ICMPv6MinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + name: "RouterSolicit", + typ: header.ICMPv6RouterSolicit, + size: header.ICMPv6MinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.RouterSolicit }, }, { - "RouterAdvert", - header.ICMPv6RouterAdvert, - header.ICMPv6HeaderSize + header.NDPRAMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + name: "RouterAdvert", + typ: header.ICMPv6RouterAdvert, + size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.RouterAdvert }, }, { - "NeighborSolicit", - header.ICMPv6NeighborSolicit, - header.ICMPv6NeighborSolicitMinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + name: "NeighborSolicit", + typ: header.ICMPv6NeighborSolicit, + size: header.ICMPv6NeighborSolicitMinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.NeighborSolicit }, }, { - "NeighborAdvert", - header.ICMPv6NeighborAdvert, - header.ICMPv6NeighborAdvertSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + name: "NeighborAdvert", + typ: header.ICMPv6NeighborAdvert, + size: header.ICMPv6NeighborAdvertMinimumSize, + extraData: tllData[:], + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.NeighborAdvert }, }, { - "RedirectMsg", - header.ICMPv6RedirectMsg, - header.ICMPv6MinimumSize, - func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + name: "RedirectMsg", + typ: header.ICMPv6RedirectMsg, + size: header.ICMPv6MinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.RedirectMsg }, }, @@ -494,16 +543,19 @@ func TestICMPChecksumValidationSimple(t *testing.T) { ) } - handleIPv6Payload := func(typ header.ICMPv6Type, size int, checksum bool) { - hdr := buffer.NewPrependable(header.IPv6MinimumSize + size) - pkt := header.ICMPv6(hdr.Prepend(size)) - pkt.SetType(typ) + handleIPv6Payload := func(checksum bool) { + extraDataLen := len(typ.extraData) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen) + extraData := buffer.View(hdr.Prepend(extraDataLen)) + copy(extraData, typ.extraData) + pkt := header.ICMPv6(hdr.Prepend(typ.size)) + pkt.SetType(typ.typ) if checksum { - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, extraData.ToVectorisedView())) } ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(size), + PayloadLength: uint16(typ.size + extraDataLen), NextHeader: uint8(header.ICMPv6ProtocolNumber), HopLimit: header.NDPHopLimit, SrcAddr: lladdr1, @@ -528,7 +580,7 @@ func TestICMPChecksumValidationSimple(t *testing.T) { // Without setting checksum, the incoming packet should // be invalid. - handleIPv6Payload(typ.typ, typ.size, false) + handleIPv6Payload(false) if got := invalid.Value(); got != 1 { t.Fatalf("got invalid = %d, want = 1", got) } @@ -538,7 +590,7 @@ func TestICMPChecksumValidationSimple(t *testing.T) { } // When checksum is set, it should be received. - handleIPv6Payload(typ.typ, typ.size, true) + handleIPv6Payload(true) if got := typStat.Value(); got != 1 { t.Fatalf("got %s = %d, want = 1", typ.name, got) } diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index bd732f93f..c9395de52 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -70,76 +70,29 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack return s, ep } -// TestNeighorSolicitationWithSourceLinkLayerOption tests that receiving an -// NDP NS message with the Source Link Layer Address option results in a +// TestNeighorSolicitationWithSourceLinkLayerOption tests that receiving a +// valid NDP NS message with the Source Link Layer Address option results in a // new entry in the link address cache for the sender of the message. func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { const nicID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - }) - e := channel.New(0, 1280, linkAddr0) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) - } - - ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize - hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) - pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) - pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.NDPPayload()) - ns.SetTargetAddress(lladdr0) - ns.Options().Serialize(header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(linkAddr1), - }) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 255, - SrcAddr: lladdr1, - DstAddr: lladdr0, - }) - e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ - Data: hdr.View().ToVectorisedView(), - }) - - linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil) - if err != nil { - t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err) - } - if c != nil { - t.Errorf("got unexpected channel") - } - if linkAddr != linkAddr1 { - t.Errorf("got link address = %s, want = %s", linkAddr, linkAddr1) - } -} - -// TestNeighorSolicitationWithInvalidSourceLinkLayerOption tests that receiving -// an NDP NS message with an invalid Source Link Layer Address option does not -// result in a new entry in the link address cache for the sender of the -// message. -func TestNeighorSolicitationWithInvalidSourceLinkLayerOption(t *testing.T) { - const nicID = 1 - tests := []struct { - name string - optsBuf []byte + name string + optsBuf []byte + expectedLinkAddr tcpip.LinkAddress }{ + { + name: "Valid", + optsBuf: []byte{1, 1, 2, 3, 4, 5, 6, 7}, + expectedLinkAddr: "\x02\x03\x04\x05\x06\x07", + }, { name: "Too Small", - optsBuf: []byte{1, 1, 1, 2, 3, 4, 5}, + optsBuf: []byte{1, 1, 2, 3, 4, 5, 6}, }, { name: "Invalid Length", - optsBuf: []byte{1, 2, 1, 2, 3, 4, 5, 6}, + optsBuf: []byte{1, 2, 2, 3, 4, 5, 6, 7}, }, } @@ -186,20 +139,138 @@ func TestNeighorSolicitationWithInvalidSourceLinkLayerOption(t *testing.T) { Data: hdr.View().ToVectorisedView(), }) - // Invalid count should have increased. - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) + linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil) + if linkAddr != test.expectedLinkAddr { + t.Errorf("got link address = %s, want = %s", linkAddr, test.expectedLinkAddr) } - linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil) - if err != tcpip.ErrWouldBlock { - t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock) + if test.expectedLinkAddr != "" { + if err != nil { + t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err) + } + if c != nil { + t.Errorf("got unexpected channel") + } + + // Invalid count should not have increased. + if got := invalid.Value(); got != 0 { + t.Errorf("got invalid = %d, want = 0", got) + } + } else { + if err != tcpip.ErrWouldBlock { + t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock) + } + if c == nil { + t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber) + } + + // Invalid count should have increased. + if got := invalid.Value(); got != 1 { + t.Errorf("got invalid = %d, want = 1", got) + } + } + }) + } +} + +// TestNeighorAdvertisementWithTargetLinkLayerOption tests that receiving a +// valid NDP NA message with the Target Link Layer Address option results in a +// new entry in the link address cache for the target of the message. +func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { + const nicID = 1 + + tests := []struct { + name string + optsBuf []byte + expectedLinkAddr tcpip.LinkAddress + }{ + { + name: "Valid", + optsBuf: []byte{2, 1, 2, 3, 4, 5, 6, 7}, + expectedLinkAddr: "\x02\x03\x04\x05\x06\x07", + }, + { + name: "Too Small", + optsBuf: []byte{2, 1, 2, 3, 4, 5, 6}, + }, + { + name: "Invalid Length", + optsBuf: []byte{2, 2, 2, 3, 4, 5, 6, 7}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + e := channel.New(0, 1280, linkAddr0) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) + } + + ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + len(test.optsBuf) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) + pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) + pkt.SetType(header.ICMPv6NeighborAdvert) + ns := header.NDPNeighborAdvert(pkt.NDPPayload()) + ns.SetTargetAddress(lladdr1) + opts := ns.Options() + copy(opts, test.optsBuf) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: 255, + SrcAddr: lladdr1, + DstAddr: lladdr0, + }) + + invalid := s.Stats().ICMP.V6PacketsReceived.Invalid + + // Invalid count should initially be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) } - if c == nil { - t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber) + + e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + + linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil) + if linkAddr != test.expectedLinkAddr { + t.Errorf("got link address = %s, want = %s", linkAddr, test.expectedLinkAddr) } - if linkAddr != "" { - t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (%s, _, ), want = ('', _, _)", nicID, lladdr1, lladdr0, ProtocolNumber, linkAddr) + + if test.expectedLinkAddr != "" { + if err != nil { + t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err) + } + if c != nil { + t.Errorf("got unexpected channel") + } + + // Invalid count should not have increased. + if got := invalid.Value(); got != 0 { + t.Errorf("got invalid = %d, want = 0", got) + } + } else { + if err != tcpip.ErrWouldBlock { + t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock) + } + if c == nil { + t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber) + } + + // Invalid count should have increased. + if got := invalid.Value(); got != 1 { + t.Errorf("got invalid = %d, want = 1", got) + } } }) } @@ -238,27 +309,59 @@ func TestHopLimitValidation(t *testing.T) { }) } + var tllData [header.NDPLinkLayerAddressSize]byte + header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(linkAddr1), + }) + types := []struct { name string typ header.ICMPv6Type size int + extraData []byte statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter }{ - {"RouterSolicit", header.ICMPv6RouterSolicit, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RouterSolicit - }}, - {"RouterAdvert", header.ICMPv6RouterAdvert, header.ICMPv6HeaderSize + header.NDPRAMinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RouterAdvert - }}, - {"NeighborSolicit", header.ICMPv6NeighborSolicit, header.ICMPv6NeighborSolicitMinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.NeighborSolicit - }}, - {"NeighborAdvert", header.ICMPv6NeighborAdvert, header.ICMPv6NeighborAdvertSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.NeighborAdvert - }}, - {"RedirectMsg", header.ICMPv6RedirectMsg, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RedirectMsg - }}, + { + name: "RouterSolicit", + typ: header.ICMPv6RouterSolicit, + size: header.ICMPv6MinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RouterSolicit + }, + }, + { + name: "RouterAdvert", + typ: header.ICMPv6RouterAdvert, + size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RouterAdvert + }, + }, + { + name: "NeighborSolicit", + typ: header.ICMPv6NeighborSolicit, + size: header.ICMPv6NeighborSolicitMinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.NeighborSolicit + }, + }, + { + name: "NeighborAdvert", + typ: header.ICMPv6NeighborAdvert, + size: header.ICMPv6NeighborAdvertMinimumSize, + extraData: tllData[:], + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.NeighborAdvert + }, + }, + { + name: "RedirectMsg", + typ: header.ICMPv6RedirectMsg, + size: header.ICMPv6MinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RedirectMsg + }, + }, } for _, typ := range types { @@ -270,10 +373,13 @@ func TestHopLimitValidation(t *testing.T) { invalid := stats.Invalid typStat := typ.statCounter(stats) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size) + extraDataLen := len(typ.extraData) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size + extraDataLen) + extraData := buffer.View(hdr.Prepend(extraDataLen)) + copy(extraData, typ.extraData) pkt := header.ICMPv6(hdr.Prepend(typ.size)) pkt.SetType(typ.typ) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, extraData.ToVectorisedView())) // Invalid count should initially be 0. if got := invalid.Value(); got != 0 { diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 8af8565f7..9a4607dcb 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -478,13 +478,17 @@ func TestDADFail(t *testing.T) { { "RxAdvert", func(tgt tcpip.Address) buffer.Prependable { - hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize) - pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) + naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize + hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize) + pkt := header.ICMPv6(hdr.Prepend(naSize)) pkt.SetType(header.ICMPv6NeighborAdvert) na := header.NDPNeighborAdvert(pkt.NDPPayload()) na.SetSolicitedFlag(true) na.SetOverrideFlag(true) na.SetTargetAddress(tgt) + na.Options().Serialize(header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(linkAddr1), + }) pkt.SetChecksum(header.ICMPv6Checksum(pkt, tgt, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) -- cgit v1.2.3 From 736775e0ac592c266acac0a7415f21d54715a54c Mon Sep 17 00:00:00 2001 From: Ian Gudger Date: Thu, 6 Feb 2020 14:03:49 -0800 Subject: Make gonet consistent both internally and with the net package. The types gonet.Conn and gonet.PacketConn were confusingly named as both implemented net.Conn. Further, gonet.Conn was perhaps unexpectedly TCP-specific (net.Conn is not). This change renames them to gonet.TCPConn and gonet.UDPConn. Renames gonet.NewListener to gonet.ListenTCP and adds a new gonet.NewTCPListner function to be consistent with both the gonet.DialXxx and gonet.NewXxxConn functions as well as net.ListenTCP. Updates #1632 PiperOrigin-RevId: 293671303 --- benchmarks/tcp/tcp_proxy.go | 2 +- pkg/tcpip/adapters/gonet/gonet.go | 111 +++++++++++++++++---------------- pkg/tcpip/adapters/gonet/gonet_test.go | 22 +++---- 3 files changed, 70 insertions(+), 65 deletions(-) (limited to 'pkg/tcpip') diff --git a/benchmarks/tcp/tcp_proxy.go b/benchmarks/tcp/tcp_proxy.go index 72ada5700..73b7c4f5b 100644 --- a/benchmarks/tcp/tcp_proxy.go +++ b/benchmarks/tcp/tcp_proxy.go @@ -274,7 +274,7 @@ func (n netstackImpl) listen(port int) (net.Listener, error) { NIC: nicID, Port: uint16(port), } - listener, err := gonet.NewListener(n.s, addr, ipv4.ProtocolNumber) + listener, err := gonet.ListenTCP(n.s, addr, ipv4.ProtocolNumber) if err != nil { return nil, err } diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index 711969b9b..6e0db2741 100644 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -43,18 +43,28 @@ func (e *timeoutError) Error() string { return "i/o timeout" } func (e *timeoutError) Timeout() bool { return true } func (e *timeoutError) Temporary() bool { return true } -// A Listener is a wrapper around a tcpip endpoint that implements +// A TCPListener is a wrapper around a TCP tcpip.Endpoint that implements // net.Listener. -type Listener struct { +type TCPListener struct { stack *stack.Stack ep tcpip.Endpoint wq *waiter.Queue cancel chan struct{} } -// NewListener creates a new Listener. -func NewListener(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Listener, error) { - // Create TCP endpoint, bind it, then start listening. +// NewTCPListener creates a new TCPListener from a listening tcpip.Endpoint. +func NewTCPListener(s *stack.Stack, wq *waiter.Queue, ep tcpip.Endpoint) *TCPListener { + return &TCPListener{ + stack: s, + ep: ep, + wq: wq, + cancel: make(chan struct{}), + } +} + +// ListenTCP creates a new TCPListener. +func ListenTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPListener, error) { + // Create a TCP endpoint, bind it, then start listening. var wq waiter.Queue ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq) if err != nil { @@ -81,28 +91,23 @@ func NewListener(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkPr } } - return &Listener{ - stack: s, - ep: ep, - wq: &wq, - cancel: make(chan struct{}), - }, nil + return NewTCPListener(s, &wq, ep), nil } // Close implements net.Listener.Close. -func (l *Listener) Close() error { +func (l *TCPListener) Close() error { l.ep.Close() return nil } // Shutdown stops the HTTP server. -func (l *Listener) Shutdown() { +func (l *TCPListener) Shutdown() { l.ep.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead) close(l.cancel) // broadcast cancellation } // Addr implements net.Listener.Addr. -func (l *Listener) Addr() net.Addr { +func (l *TCPListener) Addr() net.Addr { a, err := l.ep.GetLocalAddress() if err != nil { return nil @@ -208,9 +213,9 @@ func (d *deadlineTimer) SetDeadline(t time.Time) error { return nil } -// A Conn is a wrapper around a tcpip.Endpoint that implements the net.Conn +// A TCPConn is a wrapper around a TCP tcpip.Endpoint that implements the net.Conn // interface. -type Conn struct { +type TCPConn struct { deadlineTimer wq *waiter.Queue @@ -228,9 +233,9 @@ type Conn struct { read buffer.View } -// NewConn creates a new Conn. -func NewConn(wq *waiter.Queue, ep tcpip.Endpoint) *Conn { - c := &Conn{ +// NewTCPConn creates a new TCPConn. +func NewTCPConn(wq *waiter.Queue, ep tcpip.Endpoint) *TCPConn { + c := &TCPConn{ wq: wq, ep: ep, } @@ -239,7 +244,7 @@ func NewConn(wq *waiter.Queue, ep tcpip.Endpoint) *Conn { } // Accept implements net.Conn.Accept. -func (l *Listener) Accept() (net.Conn, error) { +func (l *TCPListener) Accept() (net.Conn, error) { n, wq, err := l.ep.Accept() if err == tcpip.ErrWouldBlock { @@ -272,7 +277,7 @@ func (l *Listener) Accept() (net.Conn, error) { } } - return NewConn(wq, n), nil + return NewTCPConn(wq, n), nil } type opErrorer interface { @@ -323,7 +328,7 @@ func commonRead(ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan struct{}, a } // Read implements net.Conn.Read. -func (c *Conn) Read(b []byte) (int, error) { +func (c *TCPConn) Read(b []byte) (int, error) { c.readMu.Lock() defer c.readMu.Unlock() @@ -352,7 +357,7 @@ func (c *Conn) Read(b []byte) (int, error) { } // Write implements net.Conn.Write. -func (c *Conn) Write(b []byte) (int, error) { +func (c *TCPConn) Write(b []byte) (int, error) { deadline := c.writeCancel() // Check if deadlineTimer has already expired. @@ -431,7 +436,7 @@ func (c *Conn) Write(b []byte) (int, error) { } // Close implements net.Conn.Close. -func (c *Conn) Close() error { +func (c *TCPConn) Close() error { c.ep.Close() return nil } @@ -440,7 +445,7 @@ func (c *Conn) Close() error { // should just use Close. // // A TCP Half-Close is performed the same as CloseRead for *net.TCPConn. -func (c *Conn) CloseRead() error { +func (c *TCPConn) CloseRead() error { if terr := c.ep.Shutdown(tcpip.ShutdownRead); terr != nil { return c.newOpError("close", errors.New(terr.String())) } @@ -451,7 +456,7 @@ func (c *Conn) CloseRead() error { // should just use Close. // // A TCP Half-Close is performed the same as CloseWrite for *net.TCPConn. -func (c *Conn) CloseWrite() error { +func (c *TCPConn) CloseWrite() error { if terr := c.ep.Shutdown(tcpip.ShutdownWrite); terr != nil { return c.newOpError("close", errors.New(terr.String())) } @@ -459,7 +464,7 @@ func (c *Conn) CloseWrite() error { } // LocalAddr implements net.Conn.LocalAddr. -func (c *Conn) LocalAddr() net.Addr { +func (c *TCPConn) LocalAddr() net.Addr { a, err := c.ep.GetLocalAddress() if err != nil { return nil @@ -468,7 +473,7 @@ func (c *Conn) LocalAddr() net.Addr { } // RemoteAddr implements net.Conn.RemoteAddr. -func (c *Conn) RemoteAddr() net.Addr { +func (c *TCPConn) RemoteAddr() net.Addr { a, err := c.ep.GetRemoteAddress() if err != nil { return nil @@ -476,7 +481,7 @@ func (c *Conn) RemoteAddr() net.Addr { return fullToTCPAddr(a) } -func (c *Conn) newOpError(op string, err error) *net.OpError { +func (c *TCPConn) newOpError(op string, err error) *net.OpError { return &net.OpError{ Op: op, Net: "tcp", @@ -494,14 +499,14 @@ func fullToUDPAddr(addr tcpip.FullAddress) *net.UDPAddr { return &net.UDPAddr{IP: net.IP(addr.Addr), Port: int(addr.Port)} } -// DialTCP creates a new TCP Conn connected to the specified address. -func DialTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Conn, error) { +// DialTCP creates a new TCPConn connected to the specified address. +func DialTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) { return DialContextTCP(context.Background(), s, addr, network) } -// DialContextTCP creates a new TCP Conn connected to the specified address +// DialContextTCP creates a new TCPConn connected to the specified address // with the option of adding cancellation and timeouts. -func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Conn, error) { +func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) { // Create TCP endpoint, then connect. var wq waiter.Queue ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq) @@ -543,12 +548,12 @@ func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, } } - return NewConn(&wq, ep), nil + return NewTCPConn(&wq, ep), nil } -// A PacketConn is a wrapper around a tcpip endpoint that implements -// net.PacketConn. -type PacketConn struct { +// A UDPConn is a wrapper around a UDP tcpip.Endpoint that implements +// net.Conn and net.PacketConn. +type UDPConn struct { deadlineTimer stack *stack.Stack @@ -556,9 +561,9 @@ type PacketConn struct { wq *waiter.Queue } -// NewPacketConn creates a new PacketConn. -func NewPacketConn(s *stack.Stack, wq *waiter.Queue, ep tcpip.Endpoint) *PacketConn { - c := &PacketConn{ +// NewUDPConn creates a new UDPConn. +func NewUDPConn(s *stack.Stack, wq *waiter.Queue, ep tcpip.Endpoint) *UDPConn { + c := &UDPConn{ stack: s, ep: ep, wq: wq, @@ -567,12 +572,12 @@ func NewPacketConn(s *stack.Stack, wq *waiter.Queue, ep tcpip.Endpoint) *PacketC return c } -// DialUDP creates a new PacketConn. +// DialUDP creates a new UDPConn. // // If laddr is nil, a local address is automatically chosen. // -// If raddr is nil, the PacketConn is left unconnected. -func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*PacketConn, error) { +// If raddr is nil, the UDPConn is left unconnected. +func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*UDPConn, error) { var wq waiter.Queue ep, err := s.NewEndpoint(udp.ProtocolNumber, network, &wq) if err != nil { @@ -591,7 +596,7 @@ func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.Netw } } - c := NewPacketConn(s, &wq, ep) + c := NewUDPConn(s, &wq, ep) if raddr != nil { if err := c.ep.Connect(*raddr); err != nil { @@ -608,11 +613,11 @@ func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.Netw return c, nil } -func (c *PacketConn) newOpError(op string, err error) *net.OpError { +func (c *UDPConn) newOpError(op string, err error) *net.OpError { return c.newRemoteOpError(op, nil, err) } -func (c *PacketConn) newRemoteOpError(op string, remote net.Addr, err error) *net.OpError { +func (c *UDPConn) newRemoteOpError(op string, remote net.Addr, err error) *net.OpError { return &net.OpError{ Op: op, Net: "udp", @@ -623,7 +628,7 @@ func (c *PacketConn) newRemoteOpError(op string, remote net.Addr, err error) *ne } // RemoteAddr implements net.Conn.RemoteAddr. -func (c *PacketConn) RemoteAddr() net.Addr { +func (c *UDPConn) RemoteAddr() net.Addr { a, err := c.ep.GetRemoteAddress() if err != nil { return nil @@ -632,13 +637,13 @@ func (c *PacketConn) RemoteAddr() net.Addr { } // Read implements net.Conn.Read -func (c *PacketConn) Read(b []byte) (int, error) { +func (c *UDPConn) Read(b []byte) (int, error) { bytesRead, _, err := c.ReadFrom(b) return bytesRead, err } // ReadFrom implements net.PacketConn.ReadFrom. -func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) { +func (c *UDPConn) ReadFrom(b []byte) (int, net.Addr, error) { deadline := c.readCancel() var addr tcpip.FullAddress @@ -650,12 +655,12 @@ func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) { return copy(b, read), fullToUDPAddr(addr), nil } -func (c *PacketConn) Write(b []byte) (int, error) { +func (c *UDPConn) Write(b []byte) (int, error) { return c.WriteTo(b, nil) } // WriteTo implements net.PacketConn.WriteTo. -func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { +func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { deadline := c.writeCancel() // Check if deadline has already expired. @@ -713,13 +718,13 @@ func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { } // Close implements net.PacketConn.Close. -func (c *PacketConn) Close() error { +func (c *UDPConn) Close() error { c.ep.Close() return nil } // LocalAddr implements net.PacketConn.LocalAddr. -func (c *PacketConn) LocalAddr() net.Addr { +func (c *UDPConn) LocalAddr() net.Addr { a, err := c.ep.GetLocalAddress() if err != nil { return nil diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index ee077ae83..ea0a0409a 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -41,7 +41,7 @@ const ( ) func TestTimeouts(t *testing.T) { - nc := NewConn(nil, nil) + nc := NewTCPConn(nil, nil) dlfs := []struct { name string f func(time.Time) error @@ -132,7 +132,7 @@ func TestCloseReader(t *testing.T) { s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) - l, e := NewListener(s, addr, ipv4.ProtocolNumber) + l, e := ListenTCP(s, addr, ipv4.ProtocolNumber) if e != nil { t.Fatalf("NewListener() = %v", e) } @@ -168,7 +168,7 @@ func TestCloseReader(t *testing.T) { sender.close() } -// TestCloseReaderWithForwarder tests that Conn.Close() wakes Conn.Read() when +// TestCloseReaderWithForwarder tests that TCPConn.Close wakes TCPConn.Read when // using tcp.Forwarder. func TestCloseReaderWithForwarder(t *testing.T) { s, err := newLoopbackStack() @@ -192,7 +192,7 @@ func TestCloseReaderWithForwarder(t *testing.T) { defer ep.Close() r.Complete(false) - c := NewConn(&wq, ep) + c := NewTCPConn(&wq, ep) // Give c.Read() a chance to block before closing the connection. time.AfterFunc(time.Millisecond*50, func() { @@ -238,7 +238,7 @@ func TestCloseRead(t *testing.T) { defer ep.Close() r.Complete(false) - c := NewConn(&wq, ep) + c := NewTCPConn(&wq, ep) buf := make([]byte, 256) n, e := c.Read(buf) @@ -257,7 +257,7 @@ func TestCloseRead(t *testing.T) { if terr != nil { t.Fatalf("connect() = %v", terr) } - c := NewConn(tc.wq, tc.ep) + c := NewTCPConn(tc.wq, tc.ep) if err := c.CloseRead(); err != nil { t.Errorf("c.CloseRead() = %v", err) @@ -291,7 +291,7 @@ func TestCloseWrite(t *testing.T) { defer ep.Close() r.Complete(false) - c := NewConn(&wq, ep) + c := NewTCPConn(&wq, ep) n, e := c.Read(make([]byte, 256)) if n != 0 || e != io.EOF { @@ -309,7 +309,7 @@ func TestCloseWrite(t *testing.T) { if terr != nil { t.Fatalf("connect() = %v", terr) } - c := NewConn(tc.wq, tc.ep) + c := NewTCPConn(tc.wq, tc.ep) if err := c.CloseWrite(); err != nil { t.Errorf("c.CloseWrite() = %v", err) @@ -353,7 +353,7 @@ func TestUDPForwarder(t *testing.T) { } defer ep.Close() - c := NewConn(&wq, ep) + c := NewTCPConn(&wq, ep) buf := make([]byte, 256) n, e := c.Read(buf) @@ -396,7 +396,7 @@ func TestDeadlineChange(t *testing.T) { s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) - l, e := NewListener(s, addr, ipv4.ProtocolNumber) + l, e := ListenTCP(s, addr, ipv4.ProtocolNumber) if e != nil { t.Fatalf("NewListener() = %v", e) } @@ -541,7 +541,7 @@ func makePipe() (c1, c2 net.Conn, stop func(), err error) { addr := tcpip.FullAddress{NICID, ip, 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, ip) - l, err := NewListener(s, addr, ipv4.ProtocolNumber) + l, err := ListenTCP(s, addr, ipv4.ProtocolNumber) if err != nil { return nil, nil, nil, fmt.Errorf("NewListener: %v", err) } -- cgit v1.2.3 From 940d255971c38af9f91ceed1345fd973f8fdb41d Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Thu, 6 Feb 2020 15:57:34 -0800 Subject: Perform DAD on IPv6 addresses when enabling a NIC Addresses may be added before a NIC is enabled. Make sure DAD is performed on the permanent IPv6 addresses when they get enabled. Test: - stack_test.TestDoDADWhenNICEnabled - stack.TestDisabledRxStatsWhenNICDisabled PiperOrigin-RevId: 293697429 --- pkg/tcpip/stack/BUILD | 6 ++- pkg/tcpip/stack/ndp_test.go | 74 +++++++++++++-------------- pkg/tcpip/stack/nic.go | 84 ++++++++++++++++++++++-------- pkg/tcpip/stack/nic_test.go | 62 +++++++++++++++++++++++ pkg/tcpip/stack/stack_test.go | 115 ++++++++++++++++++++++++++++++++++++++++++ pkg/tcpip/tcpip.go | 8 +-- 6 files changed, 287 insertions(+), 62 deletions(-) create mode 100644 pkg/tcpip/stack/nic_test.go (limited to 'pkg/tcpip') diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index f5b750046..705cf01ee 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -78,11 +78,15 @@ go_test( go_test( name = "stack_test", size = "small", - srcs = ["linkaddrcache_test.go"], + srcs = [ + "linkaddrcache_test.go", + "nic_test.go", + ], library = ":stack", deps = [ "//pkg/sleep", "//pkg/sync", "//pkg/tcpip", + "//pkg/tcpip/buffer", ], ) diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 9a4607dcb..1e575bdaf 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -1539,7 +1539,7 @@ func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { } // Checks to see if list contains an IPv6 address, item. -func contains(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) bool { +func containsV6Addr(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) bool { protocolAddress := tcpip.ProtocolAddress{ Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: item, @@ -1665,7 +1665,7 @@ func TestAutoGenAddr(t *testing.T) { // with non-zero lifetime. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) expectAutoGenAddrEvent(addr1, newAddr) - if !contains(s.NICInfo()[1].ProtocolAddresses, addr1) { + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { t.Fatalf("Should have %s in the list of addresses", addr1) } @@ -1681,10 +1681,10 @@ func TestAutoGenAddr(t *testing.T) { // Receive an RA with prefix2 in a PI. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) expectAutoGenAddrEvent(addr2, newAddr) - if !contains(s.NICInfo()[1].ProtocolAddresses, addr1) { + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { t.Fatalf("Should have %s in the list of addresses", addr1) } - if !contains(s.NICInfo()[1].ProtocolAddresses, addr2) { + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) { t.Fatalf("Should have %s in the list of addresses", addr2) } @@ -1705,10 +1705,10 @@ func TestAutoGenAddr(t *testing.T) { case <-time.After(newMinVLDuration + defaultAsyncEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } - if contains(s.NICInfo()[1].ProtocolAddresses, addr1) { + if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { t.Fatalf("Should not have %s in the list of addresses", addr1) } - if !contains(s.NICInfo()[1].ProtocolAddresses, addr2) { + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) { t.Fatalf("Should have %s in the list of addresses", addr2) } } @@ -1853,7 +1853,7 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) { // Receive PI for prefix1. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) expectAutoGenAddrEvent(addr1, newAddr) - if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should have %s in the list of addresses", addr1) } expectPrimaryAddr(addr1) @@ -1861,7 +1861,7 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) { // Deprecate addr for prefix1 immedaitely. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) expectAutoGenAddrEvent(addr1, deprecatedAddr) - if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should have %s in the list of addresses", addr1) } // addr should still be the primary endpoint as there are no other addresses. @@ -1879,7 +1879,7 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) { // Receive PI for prefix2. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) expectAutoGenAddrEvent(addr2, newAddr) - if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { t.Fatalf("should have %s in the list of addresses", addr2) } expectPrimaryAddr(addr2) @@ -1887,7 +1887,7 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) { // Deprecate addr for prefix2 immedaitely. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) expectAutoGenAddrEvent(addr2, deprecatedAddr) - if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { t.Fatalf("should have %s in the list of addresses", addr2) } // addr1 should be the primary endpoint now since addr2 is deprecated but @@ -1982,7 +1982,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { // Receive PI for prefix2. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) expectAutoGenAddrEvent(addr2, newAddr) - if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { t.Fatalf("should have %s in the list of addresses", addr2) } expectPrimaryAddr(addr2) @@ -1990,10 +1990,10 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { // Receive a PI for prefix1. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 90)) expectAutoGenAddrEvent(addr1, newAddr) - if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should have %s in the list of addresses", addr1) } - if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { t.Fatalf("should have %s in the list of addresses", addr2) } expectPrimaryAddr(addr1) @@ -2009,10 +2009,10 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { // Wait for addr of prefix1 to be deprecated. expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncEventTimeout) - if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should not have %s in the list of addresses", addr1) } - if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { t.Fatalf("should have %s in the list of addresses", addr2) } // addr2 should be the primary endpoint now since addr1 is deprecated but @@ -2049,10 +2049,10 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { // Wait for addr of prefix1 to be deprecated. expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncEventTimeout) - if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should not have %s in the list of addresses", addr1) } - if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { t.Fatalf("should have %s in the list of addresses", addr2) } // addr2 should be the primary endpoint now since it is not deprecated. @@ -2063,10 +2063,10 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { // Wait for addr of prefix1 to be invalidated. expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncEventTimeout) - if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should not have %s in the list of addresses", addr1) } - if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { t.Fatalf("should have %s in the list of addresses", addr2) } expectPrimaryAddr(addr2) @@ -2112,10 +2112,10 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { case <-time.After(newMinVLDuration + defaultAsyncEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } - if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should not have %s in the list of addresses", addr1) } - if contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { t.Fatalf("should not have %s in the list of addresses", addr2) } // Should not have any primary endpoints. @@ -2600,7 +2600,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { if err := s.AddProtocolAddress(1, tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr}); err != nil { t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr.Address, err) } - if !contains(s.NICInfo()[1].ProtocolAddresses, addr) { + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) { t.Fatalf("Should have %s in the list of addresses", addr1) } @@ -2613,7 +2613,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { t.Fatal("unexpectedly received an auto gen addr event for an address we already have statically") default: } - if !contains(s.NICInfo()[1].ProtocolAddresses, addr) { + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) { t.Fatalf("Should have %s in the list of addresses", addr1) } @@ -2624,7 +2624,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { t.Fatal("unexpectedly received an auto gen addr event") case <-time.After(lifetimeSeconds*time.Second + defaultTimeout): } - if !contains(s.NICInfo()[1].ProtocolAddresses, addr) { + if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) { t.Fatalf("Should have %s in the list of addresses", addr1) } } @@ -2702,17 +2702,17 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { const validLifetimeSecondPrefix1 = 1 e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, validLifetimeSecondPrefix1, 0)) expectAutoGenAddrEvent(addr1, newAddr) - if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should have %s in the list of addresses", addr1) } // Receive an RA with prefix2 in a PI with a large valid lifetime. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) expectAutoGenAddrEvent(addr2, newAddr) - if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should have %s in the list of addresses", addr1) } - if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { t.Fatalf("should have %s in the list of addresses", addr2) } @@ -2725,10 +2725,10 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultAsyncEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } - if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should not have %s in the list of addresses", addr1) } - if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { t.Fatalf("should have %s in the list of addresses", addr2) } } @@ -3014,16 +3014,16 @@ func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) { nicinfo := s.NICInfo() nic1Addrs := nicinfo[nicID1].ProtocolAddresses nic2Addrs := nicinfo[nicID2].ProtocolAddresses - if !contains(nic1Addrs, e1Addr1) { + if !containsV6Addr(nic1Addrs, e1Addr1) { t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) } - if !contains(nic1Addrs, e1Addr2) { + if !containsV6Addr(nic1Addrs, e1Addr2) { t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) } - if !contains(nic2Addrs, e2Addr1) { + if !containsV6Addr(nic2Addrs, e2Addr1) { t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) } - if !contains(nic2Addrs, e2Addr2) { + if !containsV6Addr(nic2Addrs, e2Addr2) { t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs) } @@ -3102,16 +3102,16 @@ func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) { nicinfo = s.NICInfo() nic1Addrs = nicinfo[nicID1].ProtocolAddresses nic2Addrs = nicinfo[nicID2].ProtocolAddresses - if contains(nic1Addrs, e1Addr1) { + if containsV6Addr(nic1Addrs, e1Addr1) { t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) } - if contains(nic1Addrs, e1Addr2) { + if containsV6Addr(nic1Addrs, e1Addr2) { t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) } - if contains(nic2Addrs, e2Addr1) { + if containsV6Addr(nic2Addrs, e2Addr1) { t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) } - if contains(nic2Addrs, e2Addr2) { + if containsV6Addr(nic2Addrs, e2Addr2) { t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs) } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 7dad9a8cb..682e9c416 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -16,6 +16,7 @@ package stack import ( "log" + "reflect" "sort" "strings" "sync/atomic" @@ -39,6 +40,7 @@ type NIC struct { mu struct { sync.RWMutex + enabled bool spoofing bool promiscuous bool primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint @@ -56,6 +58,14 @@ type NIC struct { type NICStats struct { Tx DirectionStats Rx DirectionStats + + DisabledRx DirectionStats +} + +func makeNICStats() NICStats { + var s NICStats + tcpip.InitStatCounters(reflect.ValueOf(&s).Elem()) + return s } // DirectionStats includes packet and byte counts. @@ -99,16 +109,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC name: name, linkEP: ep, context: ctx, - stats: NICStats{ - Tx: DirectionStats{ - Packets: &tcpip.StatCounter{}, - Bytes: &tcpip.StatCounter{}, - }, - Rx: DirectionStats{ - Packets: &tcpip.StatCounter{}, - Bytes: &tcpip.StatCounter{}, - }, - }, + stats: makeNICStats(), } nic.mu.primary = make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint) nic.mu.endpoints = make(map[NetworkEndpointID]*referencedNetworkEndpoint) @@ -137,14 +138,30 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC // enable enables the NIC. enable will attach the link to its LinkEndpoint and // join the IPv6 All-Nodes Multicast address (ff02::1). func (n *NIC) enable() *tcpip.Error { + n.mu.RLock() + enabled := n.mu.enabled + n.mu.RUnlock() + if enabled { + return nil + } + + n.mu.Lock() + defer n.mu.Unlock() + + if n.mu.enabled { + return nil + } + + n.mu.enabled = true + n.attachLinkEndpoint() // Create an endpoint to receive broadcast packets on this interface. if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok { - if err := n.AddAddress(tcpip.ProtocolAddress{ + if _, err := n.addAddressLocked(tcpip.ProtocolAddress{ Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize}, - }, NeverPrimaryEndpoint); err != nil { + }, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil { return err } } @@ -166,8 +183,22 @@ func (n *NIC) enable() *tcpip.Error { return nil } - n.mu.Lock() - defer n.mu.Unlock() + // Perform DAD on the all the unicast IPv6 endpoints that are in the permanent + // state. + // + // Addresses may have aleady completed DAD but in the time since the NIC was + // last enabled, other devices may have acquired the same addresses. + for _, r := range n.mu.endpoints { + addr := r.ep.ID().LocalAddress + if k := r.getKind(); (k != permanent && k != permanentTentative) || !header.IsV6UnicastAddress(addr) { + continue + } + + r.setKind(permanentTentative) + if err := n.mu.ndp.startDuplicateAddressDetection(addr, r); err != nil { + return err + } + } if err := n.joinGroupLocked(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress); err != nil { return err @@ -633,7 +664,9 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar isIPv6Unicast := protocolAddress.Protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(protocolAddress.AddressWithPrefix.Address) // If the address is an IPv6 address and it is a permanent address, - // mark it as tentative so it goes through the DAD process. + // mark it as tentative so it goes through the DAD process if the NIC is + // enabled. If the NIC is not enabled, DAD will be started when the NIC is + // enabled. if isIPv6Unicast && kind == permanent { kind = permanentTentative } @@ -668,8 +701,8 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar n.insertPrimaryEndpointLocked(ref, peb) - // If we are adding a tentative IPv6 address, start DAD. - if isIPv6Unicast && kind == permanentTentative { + // If we are adding a tentative IPv6 address, start DAD if the NIC is enabled. + if isIPv6Unicast && kind == permanentTentative && n.mu.enabled { if err := n.mu.ndp.startDuplicateAddressDetection(protocolAddress.AddressWithPrefix.Address, ref); err != nil { return nil, err } @@ -700,9 +733,7 @@ func (n *NIC) AllAddresses() []tcpip.ProtocolAddress { // Don't include tentative, expired or temporary endpoints to // avoid confusion and prevent the caller from using those. switch ref.getKind() { - case permanentTentative, permanentExpired, temporary: - // TODO(b/140898488): Should tentative addresses be - // returned? + case permanentExpired, temporary: continue } addrs = append(addrs, tcpip.ProtocolAddress{ @@ -1016,11 +1047,23 @@ func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, // This rule applies only to the slice itself, not to the items of the slice; // the ownership of the items is not retained by the caller. func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { + n.mu.RLock() + enabled := n.mu.enabled + // If the NIC is not yet enabled, don't receive any packets. + if !enabled { + n.mu.RUnlock() + + n.stats.DisabledRx.Packets.Increment() + n.stats.DisabledRx.Bytes.IncrementBy(uint64(pkt.Data.Size())) + return + } + n.stats.Rx.Packets.Increment() n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data.Size())) netProto, ok := n.stack.networkProtocols[protocol] if !ok { + n.mu.RUnlock() n.stack.stats.UnknownProtocolRcvdPackets.Increment() return } @@ -1032,7 +1075,6 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link } // Are any packet sockets listening for this network protocol? - n.mu.RLock() packetEPs := n.mu.packetEPs[protocol] // Check whether there are packet sockets listening for every protocol. // If we received a packet with protocol EthernetProtocolAll, then the diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go new file mode 100644 index 000000000..edaee3b86 --- /dev/null +++ b/pkg/tcpip/stack/nic_test.go @@ -0,0 +1,62 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { + // When the NIC is disabled, the only field that matters is the stats field. + // This test is limited to stats counter checks. + nic := NIC{ + stats: makeNICStats(), + } + + if got := nic.stats.DisabledRx.Packets.Value(); got != 0 { + t.Errorf("got DisabledRx.Packets = %d, want = 0", got) + } + if got := nic.stats.DisabledRx.Bytes.Value(); got != 0 { + t.Errorf("got DisabledRx.Bytes = %d, want = 0", got) + } + if got := nic.stats.Rx.Packets.Value(); got != 0 { + t.Errorf("got Rx.Packets = %d, want = 0", got) + } + if got := nic.stats.Rx.Bytes.Value(); got != 0 { + t.Errorf("got Rx.Bytes = %d, want = 0", got) + } + + if t.Failed() { + t.FailNow() + } + + nic.DeliverNetworkPacket(nil, "", "", 0, tcpip.PacketBuffer{Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView()}) + + if got := nic.stats.DisabledRx.Packets.Value(); got != 1 { + t.Errorf("got DisabledRx.Packets = %d, want = 1", got) + } + if got := nic.stats.DisabledRx.Bytes.Value(); got != 4 { + t.Errorf("got DisabledRx.Bytes = %d, want = 4", got) + } + if got := nic.stats.Rx.Packets.Value(); got != 0 { + t.Errorf("got Rx.Packets = %d, want = 0", got) + } + if got := nic.stats.Rx.Bytes.Value(); got != 0 { + t.Errorf("got Rx.Bytes = %d, want = 0", got) + } +} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 834fe9487..243868f3a 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -2561,3 +2561,118 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { }) } } + +// TestDoDADWhenNICEnabled tests that IPv6 endpoints that were added while a NIC +// was disabled have DAD performed on them when the NIC is enabled. +func TestDoDADWhenNICEnabled(t *testing.T) { + t.Parallel() + + const dadTransmits = 1 + const retransmitTimer = time.Second + const nicID = 1 + + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent), + } + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + DupAddrDetectTransmits: dadTransmits, + RetransmitTimer: retransmitTimer, + }, + NDPDisp: &ndpDisp, + } + + e := channel.New(dadTransmits, 1280, linkAddr1) + s := stack.New(opts) + nicOpts := stack.NICOptions{Disabled: true} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err) + } + + addr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: llAddr1, + PrefixLen: 128, + }, + } + if err := s.AddProtocolAddress(nicID, addr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, addr, err) + } + + // Address should be in the list of all addresses. + if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { + t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) + } + + // Address should be tentative so it should not be a main address. + got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) + } + if want := (tcpip.AddressWithPrefix{}); got != want { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, want) + } + + // Enabling the NIC should start DAD for the address. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { + t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) + } + + // Address should not be considered bound to the NIC yet (DAD ongoing). + got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) + } + if want := (tcpip.AddressWithPrefix{}); got != want { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, want) + } + + // Wait for DAD to resolve. + select { + case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout): + t.Fatal("timed out waiting for DAD resolution") + case e := <-ndpDisp.dadC: + if e.err != nil { + t.Fatal("got DAD error: ", e.err) + } + if e.nicID != nicID { + t.Fatalf("got DAD event w/ nicID = %d, want = %d", e.nicID, nicID) + } + if e.addr != addr.AddressWithPrefix.Address { + t.Fatalf("got DAD event w/ addr = %s, want = %s", e.addr, addr.AddressWithPrefix.Address) + } + if !e.resolved { + t.Fatal("got DAD event w/ resolved = false, want = true") + } + } + if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { + t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) + } + got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) + } + if got != addr.AddressWithPrefix { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr.AddressWithPrefix) + } + + // Enabling the NIC again should be a no-op. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { + t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) + } + got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) + } + if got != addr.AddressWithPrefix { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, addr.AddressWithPrefix) + } +} diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index d29d9a704..0e944712f 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -1170,7 +1170,9 @@ type TransportEndpointStats struct { // marker interface. func (*TransportEndpointStats) IsEndpointStats() {} -func fillIn(v reflect.Value) { +// InitStatCounters initializes v's fields with nil StatCounter fields to new +// StatCounters. +func InitStatCounters(v reflect.Value) { for i := 0; i < v.NumField(); i++ { v := v.Field(i) if s, ok := v.Addr().Interface().(**StatCounter); ok { @@ -1178,14 +1180,14 @@ func fillIn(v reflect.Value) { *s = new(StatCounter) } } else { - fillIn(v) + InitStatCounters(v) } } } // FillIn returns a copy of s with nil fields initialized to new StatCounters. func (s Stats) FillIn() Stats { - fillIn(reflect.ValueOf(&s).Elem()) + InitStatCounters(reflect.ValueOf(&s).Elem()) return s } -- cgit v1.2.3 From 3700221b1f3ff0779a0f4479fd2bafa3312d5a23 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Thu, 6 Feb 2020 16:42:37 -0800 Subject: Auto-generate link-local address as a SLAAC address Auto-generated link-local addresses should have the same lifecycle hooks as global SLAAC addresses. The Stack's NDP dispatcher should be notified when link-local addresses are auto-generated and invalidated. They should also be removed when a NIC is disabled (which will be supported in a later change). Tests: - stack_test.TestNICAutoGenAddrWithOpaque - stack_test.TestNICAutoGenAddr PiperOrigin-RevId: 293706760 --- pkg/tcpip/stack/ndp.go | 36 +++-- pkg/tcpip/stack/ndp_test.go | 85 +++++++----- pkg/tcpip/stack/nic.go | 30 +---- pkg/tcpip/stack/stack_test.go | 307 +++++++++++++++++++----------------------- 4 files changed, 218 insertions(+), 240 deletions(-) (limited to 'pkg/tcpip') diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 6123fda33..fae5f5014 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -906,22 +906,21 @@ func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInform return } - // We do not already have an address within the prefix, prefix. Do the + // We do not already have an address with the prefix prefix. Do the // work as outlined by RFC 4862 section 5.5.3.d if n is configured - // to auto-generated global addresses by SLAAC. - ndp.newAutoGenAddress(prefix, pl, vl) + // to auto-generate global addresses by SLAAC. + if !ndp.configs.AutoGenGlobalAddresses { + return + } + + ndp.doSLAAC(prefix, pl, vl) } -// newAutoGenAddress generates a new SLAAC address with the provided lifetimes +// doSLAAC generates a new SLAAC address with the provided lifetimes // for prefix. // // pl is the new preferred lifetime. vl is the new valid lifetime. -func (ndp *ndpState) newAutoGenAddress(prefix tcpip.Subnet, pl, vl time.Duration) { - // Are we configured to auto-generate new global addresses? - if !ndp.configs.AutoGenGlobalAddresses { - return - } - +func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { // If we do not already have an address for this prefix and the valid // lifetime is 0, no need to do anything further, as per RFC 4862 // section 5.5.3.d. @@ -1152,12 +1151,21 @@ func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bo // // The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) cleanupHostOnlyState() { + linkLocalSubnet := header.IPv6LinkLocalPrefix.Subnet() + linkLocalAddrs := 0 for addr := range ndp.autoGenAddresses { + // RFC 4862 section 5 states that routers are also expected to generate a + // link-local address so we do not invalidate them. + if linkLocalSubnet.Contains(addr) { + linkLocalAddrs++ + continue + } + ndp.invalidateAutoGenAddress(addr) } - if got := len(ndp.autoGenAddresses); got != 0 { - log.Fatalf("ndp: still have auto-generated addresses after cleaning up, found = %d", got) + if got := len(ndp.autoGenAddresses); got != linkLocalAddrs { + log.Fatalf("ndp: still have non-linklocal auto-generated addresses after cleaning up; found = %d prefixes, of which %d are link-local", got, linkLocalAddrs) } for prefix := range ndp.onLinkPrefixes { @@ -1165,7 +1173,7 @@ func (ndp *ndpState) cleanupHostOnlyState() { } if got := len(ndp.onLinkPrefixes); got != 0 { - log.Fatalf("ndp: still have discovered on-link prefixes after cleaning up, found = %d", got) + log.Fatalf("ndp: still have discovered on-link prefixes after cleaning up; found = %d", got) } for router := range ndp.defaultRouters { @@ -1173,7 +1181,7 @@ func (ndp *ndpState) cleanupHostOnlyState() { } if got := len(ndp.defaultRouters); got != 0 { - log.Fatalf("ndp: still have discovered default routers after cleaning up, found = %d", got) + log.Fatalf("ndp: still have discovered default routers after cleaning up; found = %d", got) } } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 1e575bdaf..e13509fbd 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -42,6 +42,7 @@ const ( linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07") linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08") + linkAddr4 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x09") defaultTimeout = 100 * time.Millisecond defaultAsyncEventTimeout = time.Second ) @@ -50,6 +51,7 @@ var ( llAddr1 = header.LinkLocalAddr(linkAddr1) llAddr2 = header.LinkLocalAddr(linkAddr2) llAddr3 = header.LinkLocalAddr(linkAddr3) + llAddr4 = header.LinkLocalAddr(linkAddr4) dstAddr = tcpip.FullAddress{ Addr: "\x0a\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", Port: 25, @@ -2882,8 +2884,8 @@ func TestNDPRecursiveDNSServerDispatch(t *testing.T) { } // TestCleanupHostOnlyStateOnBecomingRouter tests that all discovered routers -// and prefixes, and auto-generated addresses get invalidated when a NIC -// becomes a router. +// and prefixes, and non-linklocal auto-generated addresses are invalidated when +// a NIC becomes a router. func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) { t.Parallel() @@ -2898,6 +2900,14 @@ func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) { prefix2, subnet2, e1Addr2 := prefixSubnetAddr(1, linkAddr1) e2Addr1 := addrForSubnet(subnet1, linkAddr2) e2Addr2 := addrForSubnet(subnet2, linkAddr2) + llAddrWithPrefix1 := tcpip.AddressWithPrefix{ + Address: llAddr1, + PrefixLen: 64, + } + llAddrWithPrefix2 := tcpip.AddressWithPrefix{ + Address: llAddr2, + PrefixLen: 64, + } ndpDisp := ndpDispatcher{ routerC: make(chan ndpRouterEvent, maxEvents), @@ -2907,7 +2917,8 @@ func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, maxEvents), } s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + AutoGenIPv6LinkLocal: true, NDPConfigs: stack.NDPConfigurations{ HandleRAs: true, DiscoverDefaultRouters: true, @@ -2917,16 +2928,6 @@ func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) { NDPDisp: &ndpDisp, }) - e1 := channel.New(0, 1280, linkAddr1) - if err := s.CreateNIC(nicID1, e1); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err) - } - - e2 := channel.New(0, 1280, linkAddr2) - if err := s.CreateNIC(nicID2, e2); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err) - } - expectRouterEvent := func() (bool, ndpRouterEvent) { select { case e := <-ndpDisp.routerC: @@ -2957,18 +2958,30 @@ func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) { return false, ndpAutoGenAddrEvent{} } - // Receive RAs on NIC(1) and NIC(2) from default routers (llAddr1 and - // llAddr2) w/ PI (for prefix1 in RA from llAddr1 and prefix2 in RA from - // llAddr2) to discover multiple routers and prefixes, and auto-gen - // multiple addresses. - - e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr1, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) + e1 := channel.New(0, 1280, linkAddr1) + if err := s.CreateNIC(nicID1, e1); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err) + } // We have other tests that make sure we receive the *correct* events // on normal discovery of routers/prefixes, and auto-generated // addresses. Here we just make sure we get an event and let other tests // handle the correctness check. + expectAutoGenAddrEvent() + + e2 := channel.New(0, 1280, linkAddr2) + if err := s.CreateNIC(nicID2, e2); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err) + } + expectAutoGenAddrEvent() + + // Receive RAs on NIC(1) and NIC(2) from default routers (llAddr3 and + // llAddr4) w/ PI (for prefix1 in RA from llAddr3 and prefix2 in RA from + // llAddr4) to discover multiple routers and prefixes, and auto-gen + // multiple addresses. + + e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr1, nicID1) + t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID1) } if ok, _ := expectPrefixEvent(); !ok { t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID1) @@ -2977,9 +2990,9 @@ func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) { t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr1, nicID1) } - e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) + e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr2, nicID1) + t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID1) } if ok, _ := expectPrefixEvent(); !ok { t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID1) @@ -2988,9 +3001,9 @@ func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) { t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID1) } - e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr1, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) + e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr1, nicID2) + t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID2) } if ok, _ := expectPrefixEvent(); !ok { t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID2) @@ -2999,9 +3012,9 @@ func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) { t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID2) } - e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) + e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr2, nicID2) + t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID2) } if ok, _ := expectPrefixEvent(); !ok { t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID2) @@ -3014,12 +3027,18 @@ func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) { nicinfo := s.NICInfo() nic1Addrs := nicinfo[nicID1].ProtocolAddresses nic2Addrs := nicinfo[nicID2].ProtocolAddresses + if !containsV6Addr(nic1Addrs, llAddrWithPrefix1) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) + } if !containsV6Addr(nic1Addrs, e1Addr1) { t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) } if !containsV6Addr(nic1Addrs, e1Addr2) { t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) } + if !containsV6Addr(nic2Addrs, llAddrWithPrefix2) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) + } if !containsV6Addr(nic2Addrs, e2Addr1) { t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) } @@ -3071,10 +3090,10 @@ func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) { } expectedRouterEvents := map[ndpRouterEvent]int{ - {nicID: nicID1, addr: llAddr1, discovered: false}: 1, - {nicID: nicID1, addr: llAddr2, discovered: false}: 1, - {nicID: nicID2, addr: llAddr1, discovered: false}: 1, - {nicID: nicID2, addr: llAddr2, discovered: false}: 1, + {nicID: nicID1, addr: llAddr3, discovered: false}: 1, + {nicID: nicID1, addr: llAddr4, discovered: false}: 1, + {nicID: nicID2, addr: llAddr3, discovered: false}: 1, + {nicID: nicID2, addr: llAddr4, discovered: false}: 1, } if diff := cmp.Diff(expectedRouterEvents, gotRouterEvents); diff != "" { t.Errorf("router events mismatch (-want +got):\n%s", diff) @@ -3102,12 +3121,18 @@ func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) { nicinfo = s.NICInfo() nic1Addrs = nicinfo[nicID1].ProtocolAddresses nic2Addrs = nicinfo[nicID2].ProtocolAddresses + if !containsV6Addr(nic1Addrs, llAddrWithPrefix1) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) + } if containsV6Addr(nic1Addrs, e1Addr1) { t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) } if containsV6Addr(nic1Addrs, e1Addr2) { t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) } + if !containsV6Addr(nic2Addrs, llAddrWithPrefix2) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) + } if containsV6Addr(nic2Addrs, e2Addr1) { t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 682e9c416..78d451cca 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -206,33 +206,9 @@ func (n *NIC) enable() *tcpip.Error { // Do not auto-generate an IPv6 link-local address for loopback devices. if n.stack.autoGenIPv6LinkLocal && !n.isLoopback() { - var addr tcpip.Address - if oIID := n.stack.opaqueIIDOpts; oIID.NICNameFromID != nil { - addr = header.LinkLocalAddrWithOpaqueIID(oIID.NICNameFromID(n.ID(), n.name), 0, oIID.SecretKey) - } else { - l2addr := n.linkEP.LinkAddress() - - // Only attempt to generate the link-local address if we have a valid MAC - // address. - // - // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by - // LinkEndpoint.LinkAddress) before reaching this point. - if !header.IsValidUnicastEthernetAddress(l2addr) { - return nil - } - - addr = header.LinkLocalAddr(l2addr) - } - - if _, err := n.addAddressLocked(tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: addr, - PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen, - }, - }, CanBePrimaryEndpoint, permanent, static, false /* deprecated */); err != nil { - return err - } + // The valid and preferred lifetime is infinite for the auto-generated + // link-local address. + n.mu.ndp.doSLAAC(header.IPv6LinkLocalPrefix.Subnet(), header.NDPInfiniteLifetime, header.NDPInfiniteLifetime) } // If we are operating as a router, then do not solicit routers since we diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 243868f3a..b2c1763bf 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -1894,112 +1894,6 @@ func TestNICForwarding(t *testing.T) { } } -// TestNICAutoGenAddr tests the auto-generation of IPv6 link-local addresses -// using the modified EUI-64 of the NIC's MAC address (or lack there-of if -// disabled (default)). Note, DAD will be disabled in these tests. -func TestNICAutoGenAddr(t *testing.T) { - tests := []struct { - name string - autoGen bool - linkAddr tcpip.LinkAddress - iidOpts stack.OpaqueInterfaceIdentifierOptions - shouldGen bool - }{ - { - "Disabled", - false, - linkAddr1, - stack.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: func(nicID tcpip.NICID, _ string) string { - return fmt.Sprintf("nic%d", nicID) - }, - }, - false, - }, - { - "Enabled", - true, - linkAddr1, - stack.OpaqueInterfaceIdentifierOptions{}, - true, - }, - { - "Nil MAC", - true, - tcpip.LinkAddress([]byte(nil)), - stack.OpaqueInterfaceIdentifierOptions{}, - false, - }, - { - "Empty MAC", - true, - tcpip.LinkAddress(""), - stack.OpaqueInterfaceIdentifierOptions{}, - false, - }, - { - "Invalid MAC", - true, - tcpip.LinkAddress("\x01\x02\x03"), - stack.OpaqueInterfaceIdentifierOptions{}, - false, - }, - { - "Multicast MAC", - true, - tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), - stack.OpaqueInterfaceIdentifierOptions{}, - false, - }, - { - "Unspecified MAC", - true, - tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"), - stack.OpaqueInterfaceIdentifierOptions{}, - false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - OpaqueIIDOpts: test.iidOpts, - } - - if test.autoGen { - // Only set opts.AutoGenIPv6LinkLocal when test.autoGen is true because - // opts.AutoGenIPv6LinkLocal should be false by default. - opts.AutoGenIPv6LinkLocal = true - } - - e := channel.New(10, 1280, test.linkAddr) - s := stack.New(opts) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - - addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) - } - - if test.shouldGen { - // Should have auto-generated an address and resolved immediately (DAD - // is disabled). - if want := (tcpip.AddressWithPrefix{Address: header.LinkLocalAddr(test.linkAddr), PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want) - } - } else { - // Should not have auto-generated an address. - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) - } - } - }) - } -} - // TestNICContextPreservation tests that you can read out via stack.NICInfo the // Context data you pass via NICContext.Context in stack.CreateNICWithOptions. func TestNICContextPreservation(t *testing.T) { @@ -2040,11 +1934,9 @@ func TestNICContextPreservation(t *testing.T) { } } -// TestNICAutoGenAddrWithOpaque tests the auto-generation of IPv6 link-local -// addresses with opaque interface identifiers. Link Local addresses should -// always be generated with opaque IIDs if configured to use them, even if the -// NIC has an invalid MAC address. -func TestNICAutoGenAddrWithOpaque(t *testing.T) { +// TestNICAutoGenLinkLocalAddr tests the auto-generation of IPv6 link-local +// addresses. +func TestNICAutoGenLinkLocalAddr(t *testing.T) { const nicID = 1 var secretKey [header.OpaqueIIDSecretKeyMinBytes]byte @@ -2056,108 +1948,185 @@ func TestNICAutoGenAddrWithOpaque(t *testing.T) { t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", header.OpaqueIIDSecretKeyMinBytes, n) } + nicNameFunc := func(_ tcpip.NICID, name string) string { + return name + } + tests := []struct { - name string - nicName string - autoGen bool - linkAddr tcpip.LinkAddress - secretKey []byte + name string + nicName string + autoGen bool + linkAddr tcpip.LinkAddress + iidOpts stack.OpaqueInterfaceIdentifierOptions + shouldGen bool + expectedAddr tcpip.Address }{ { name: "Disabled", nicName: "nic1", autoGen: false, linkAddr: linkAddr1, - secretKey: secretKey[:], + shouldGen: false, }, { - name: "Enabled", - nicName: "nic1", - autoGen: true, - linkAddr: linkAddr1, - secretKey: secretKey[:], + name: "Disabled without OIID options", + nicName: "nic1", + autoGen: false, + linkAddr: linkAddr1, + iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: nicNameFunc, + SecretKey: secretKey[:], + }, + shouldGen: false, }, - // These are all cases where we would not have generated a - // link-local address if opaque IIDs were disabled. + + // Tests for EUI64 based addresses. { - name: "Nil MAC and empty nicName", - nicName: "", + name: "EUI64 Enabled", + autoGen: true, + linkAddr: linkAddr1, + shouldGen: true, + expectedAddr: header.LinkLocalAddr(linkAddr1), + }, + { + name: "EUI64 Empty MAC", autoGen: true, - linkAddr: tcpip.LinkAddress([]byte(nil)), - secretKey: secretKey[:1], + shouldGen: false, }, { - name: "Empty MAC and empty nicName", + name: "EUI64 Invalid MAC", autoGen: true, - linkAddr: tcpip.LinkAddress(""), - secretKey: secretKey[:2], + linkAddr: "\x01\x02\x03", + shouldGen: false, }, { - name: "Invalid MAC", - nicName: "test", + name: "EUI64 Multicast MAC", autoGen: true, - linkAddr: tcpip.LinkAddress("\x01\x02\x03"), - secretKey: secretKey[:3], + linkAddr: "\x01\x02\x03\x04\x05\x06", + shouldGen: false, }, { - name: "Multicast MAC", - nicName: "test2", + name: "EUI64 Unspecified MAC", autoGen: true, - linkAddr: tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), - secretKey: secretKey[:4], + linkAddr: "\x00\x00\x00\x00\x00\x00", + shouldGen: false, }, + + // Tests for Opaque IID based addresses. { - name: "Unspecified MAC and nil SecretKey", + name: "OIID Enabled", + nicName: "nic1", + autoGen: true, + linkAddr: linkAddr1, + iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: nicNameFunc, + SecretKey: secretKey[:], + }, + shouldGen: true, + expectedAddr: header.LinkLocalAddrWithOpaqueIID("nic1", 0, secretKey[:]), + }, + // These are all cases where we would not have generated a + // link-local address if opaque IIDs were disabled. + { + name: "OIID Empty MAC and empty nicName", + autoGen: true, + iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: nicNameFunc, + SecretKey: secretKey[:1], + }, + shouldGen: true, + expectedAddr: header.LinkLocalAddrWithOpaqueIID("", 0, secretKey[:1]), + }, + { + name: "OIID Invalid MAC", + nicName: "test", + autoGen: true, + linkAddr: "\x01\x02\x03", + iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: nicNameFunc, + SecretKey: secretKey[:2], + }, + shouldGen: true, + expectedAddr: header.LinkLocalAddrWithOpaqueIID("test", 0, secretKey[:2]), + }, + { + name: "OIID Multicast MAC", + nicName: "test2", + autoGen: true, + linkAddr: "\x01\x02\x03\x04\x05\x06", + iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: nicNameFunc, + SecretKey: secretKey[:3], + }, + shouldGen: true, + expectedAddr: header.LinkLocalAddrWithOpaqueIID("test2", 0, secretKey[:3]), + }, + { + name: "OIID Unspecified MAC and nil SecretKey", nicName: "test3", autoGen: true, - linkAddr: tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"), + linkAddr: "\x00\x00\x00\x00\x00\x00", + iidOpts: stack.OpaqueInterfaceIdentifierOptions{ + NICNameFromID: nicNameFunc, + }, + shouldGen: true, + expectedAddr: header.LinkLocalAddrWithOpaqueIID("test3", 0, nil), }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{ - NICNameFromID: func(_ tcpip.NICID, nicName string) string { - return nicName - }, - SecretKey: test.secretKey, - }, + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } - - if test.autoGen { - // Only set opts.AutoGenIPv6LinkLocal when - // test.autoGen is true because - // opts.AutoGenIPv6LinkLocal should be false by - // default. - opts.AutoGenIPv6LinkLocal = true + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + AutoGenIPv6LinkLocal: test.autoGen, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: test.iidOpts, } - e := channel.New(10, 1280, test.linkAddr) + e := channel.New(0, 1280, test.linkAddr) s := stack.New(opts) nicOpts := stack.NICOptions{Name: test.nicName} if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err) } - addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err) - } + var expectedMainAddr tcpip.AddressWithPrefix + + if test.shouldGen { + expectedMainAddr = tcpip.AddressWithPrefix{ + Address: test.expectedAddr, + PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen, + } - if test.autoGen { - // Should have auto-generated an address and - // resolved immediately (DAD is disabled). - if want := (tcpip.AddressWithPrefix{Address: header.LinkLocalAddrWithOpaqueIID(test.nicName, 0, test.secretKey), PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want) + // Should have auto-generated an address and resolved immediately (DAD + // is disabled). + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, expectedMainAddr, newAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") } } else { // Should not have auto-generated an address. - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly auto-generated an address") + default: } } + + gotMainAddr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) + } + if gotMainAddr != expectedMainAddr { + t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", gotMainAddr, expectedMainAddr) + } }) } } @@ -2226,7 +2195,7 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { NDPDisp: &ndpDisp, } - e := channel.New(10, 1280, linkAddr1) + e := channel.New(int(ndpConfigs.DupAddrDetectTransmits), 1280, linkAddr1) s := stack.New(opts) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(_) = %s", err) -- cgit v1.2.3 From ca30dfa065f5458228b06dcf5379ed4edf29c165 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Thu, 6 Feb 2020 19:49:30 -0800 Subject: Send DAD event when DAD resolves immediately Previously, a DAD event would not be sent if DAD was disabled. This allows integrators to do some work when an IPv6 address is bound to a NIC without special logic that checks if DAD is enabled. Without this change, integrators would need to check if a NIC has DAD enabled when an address is auto-generated. If DAD is enabled, it would need to delay the work until the DAD completion event; otherwise, it would need to do the work in the address auto-generated event handler. Test: stack_test.TestDADDisabled PiperOrigin-RevId: 293732914 --- pkg/tcpip/stack/ndp.go | 7 ++ pkg/tcpip/stack/ndp_test.go | 263 +++++++++++++++++++++--------------------- pkg/tcpip/stack/stack_test.go | 44 +++---- 3 files changed, 154 insertions(+), 160 deletions(-) (limited to 'pkg/tcpip') diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index fae5f5014..045409bda 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -448,6 +448,13 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref remaining := ndp.configs.DupAddrDetectTransmits if remaining == 0 { ref.setKind(permanent) + + // Consider DAD to have resolved even if no DAD messages were actually + // transmitted. + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, true, nil) + } + return nil } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index e13509fbd..1f6f77439 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -86,39 +86,6 @@ func prefixSubnetAddr(offset uint8, linkAddr tcpip.LinkAddress) (tcpip.AddressWi return prefix, subnet, addrForSubnet(subnet, linkAddr) } -// TestDADDisabled tests that an address successfully resolves immediately -// when DAD is not enabled (the default for an empty stack.Options). -func TestDADDisabled(t *testing.T) { - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - } - - e := channel.New(0, 1280, linkAddr1) - s := stack.New(opts) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - - if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err) - } - - // Should get the address immediately since we should not have performed - // DAD on it. - addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) - } - if addr.Address != addr1 { - t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, addr1) - } - - // We should not have sent any NDP NS messages. - if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != 0 { - t.Fatalf("got NeighborSolicit = %d, want = 0", got) - } -} - // ndpDADEvent is a set of parameters that was passed to // ndpDispatcher.OnDuplicateAddressDetectionStatus. type ndpDADEvent struct { @@ -300,6 +267,58 @@ func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration s } } +// Check e to make sure that the event is for addr on nic with ID 1, and the +// resolved flag set to resolved with the specified err. +func checkDADEvent(e ndpDADEvent, nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) string { + return cmp.Diff(ndpDADEvent{nicID: nicID, addr: addr, resolved: resolved, err: err}, e, cmp.AllowUnexported(e)) +} + +// TestDADDisabled tests that an address successfully resolves immediately +// when DAD is not enabled (the default for an empty stack.Options). +func TestDADDisabled(t *testing.T) { + const nicID = 1 + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 1), + } + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPDisp: &ndpDisp, + } + + e := channel.New(0, 1280, linkAddr1) + s := stack.New(opts) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) + } + + // Should get the address immediately since we should not have performed + // DAD on it. + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr1, true, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected DAD event") + } + addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("stack.GetMainNICAddress(%d, %d) err = %s", nicID, header.IPv6ProtocolNumber, err) + } + if addr.Address != addr1 { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, addr, addr1) + } + + // We should not have sent any NDP NS messages. + if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != 0 { + t.Fatalf("got NeighborSolicit = %d, want = 0", got) + } +} + // TestDADResolve tests that an address successfully resolves after performing // DAD for various values of DupAddrDetectTransmits and RetransmitTimer. // Included in the subtests is a test to make sure that an invalid @@ -381,17 +400,8 @@ func TestDADResolve(t *testing.T) { // means something is wrong. t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: - if e.err != nil { - t.Fatal("got DAD error: ", e.err) - } - if e.nicID != nicID { - t.Fatalf("got DAD event w/ nicID = %d, want = %d", e.nicID, nicID) - } - if e.addr != addr1 { - t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1) - } - if !e.resolved { - t.Fatal("got DAD event w/ resolved = false, want = true") + if diff := checkDADEvent(e, nicID, addr1, true, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) } } addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) @@ -445,6 +455,8 @@ func TestDADResolve(t *testing.T) { // a node doing DAD for the same address), or if another node is detected to own // the address already (receive an NA message for the tentative address). func TestDADFail(t *testing.T) { + const nicID = 1 + tests := []struct { name string makeBuf func(tgt tcpip.Address) buffer.Prependable @@ -526,22 +538,22 @@ func TestDADFail(t *testing.T) { e := channel.New(0, 1280, linkAddr1) s := stack.New(opts) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err) + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) } // Address should not be considered bound to the NIC yet // (DAD ongoing). - addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) if err != nil { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) } if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) } // Receive a packet to simulate multiple nodes owning or @@ -565,25 +577,16 @@ func TestDADFail(t *testing.T) { // something is wrong. t.Fatal("timed out waiting for DAD failure") case e := <-ndpDisp.dadC: - if e.err != nil { - t.Fatal("got DAD error: ", e.err) - } - if e.nicID != 1 { - t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID) - } - if e.addr != addr1 { - t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1) - } - if e.resolved { - t.Fatal("got DAD event w/ resolved = true, want = false") + if diff := checkDADEvent(e, nicID, addr1, false, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) } } - addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) if err != nil { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) } if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) } }) } @@ -592,6 +595,8 @@ func TestDADFail(t *testing.T) { // TestDADStop tests to make sure that the DAD process stops when an address is // removed. func TestDADStop(t *testing.T) { + const nicID = 1 + ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent, 1), } @@ -607,26 +612,26 @@ func TestDADStop(t *testing.T) { e := channel.New(0, 1280, linkAddr1) s := stack.New(opts) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err) + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) } // Address should not be considered bound to the NIC yet (DAD ongoing). - addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) if err != nil { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) } if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) } // Remove the address. This should stop DAD. - if err := s.RemoveAddress(1, addr1); err != nil { - t.Fatalf("RemoveAddress(_, %s) = %s", addr1, err) + if err := s.RemoveAddress(nicID, addr1); err != nil { + t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr1, err) } // Wait for DAD to fail (since the address was removed during DAD). @@ -636,26 +641,16 @@ func TestDADStop(t *testing.T) { // time + extra 1s buffer, something is wrong. t.Fatal("timed out waiting for DAD failure") case e := <-ndpDisp.dadC: - if e.err != nil { - t.Fatal("got DAD error: ", e.err) - } - if e.nicID != 1 { - t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID) - } - if e.addr != addr1 { - t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1) + if diff := checkDADEvent(e, nicID, addr1, false, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) } - if e.resolved { - t.Fatal("got DAD event w/ resolved = true, want = false") - } - } - addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) if err != nil { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) } if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) } // Should not have sent more than 1 NS message. @@ -681,6 +676,10 @@ func TestSetNDPConfigurationFailsForBadNICID(t *testing.T) { // configurations without affecting the default NDP configurations or other // interfaces' configurations. func TestSetNDPConfigurations(t *testing.T) { + const nicID1 = 1 + const nicID2 = 2 + const nicID3 = 3 + tests := []struct { name string dupAddrDetectTransmits uint8 @@ -704,7 +703,7 @@ func TestSetNDPConfigurations(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent), + dadC: make(chan ndpDADEvent, 1), } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ @@ -712,17 +711,28 @@ func TestSetNDPConfigurations(t *testing.T) { NDPDisp: &ndpDisp, }) + expectDADEvent := func(nicID tcpip.NICID, addr tcpip.Address) { + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatalf("expected DAD event for %s", addr) + } + } + // This NIC(1)'s NDP configurations will be updated to // be different from the default. - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) + if err := s.CreateNIC(nicID1, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err) } // Created before updating NIC(1)'s NDP configurations // but updating NIC(1)'s NDP configurations should not // affect other existing NICs. - if err := s.CreateNIC(2, e); err != nil { - t.Fatalf("CreateNIC(2) = %s", err) + if err := s.CreateNIC(nicID2, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err) } // Update the NDP configurations on NIC(1) to use DAD. @@ -730,36 +740,38 @@ func TestSetNDPConfigurations(t *testing.T) { DupAddrDetectTransmits: test.dupAddrDetectTransmits, RetransmitTimer: test.retransmitTimer, } - if err := s.SetNDPConfigurations(1, configs); err != nil { - t.Fatalf("got SetNDPConfigurations(1, _) = %s", err) + if err := s.SetNDPConfigurations(nicID1, configs); err != nil { + t.Fatalf("got SetNDPConfigurations(%d, _) = %s", nicID1, err) } // Created after updating NIC(1)'s NDP configurations // but the stack's default NDP configurations should not // have been updated. - if err := s.CreateNIC(3, e); err != nil { - t.Fatalf("CreateNIC(3) = %s", err) + if err := s.CreateNIC(nicID3, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID3, err) } // Add addresses for each NIC. - if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(1, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err) + if err := s.AddAddress(nicID1, header.IPv6ProtocolNumber, addr1); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID1, header.IPv6ProtocolNumber, addr1, err) } - if err := s.AddAddress(2, header.IPv6ProtocolNumber, addr2); err != nil { - t.Fatalf("AddAddress(2, %d, %s) = %s", header.IPv6ProtocolNumber, addr2, err) + if err := s.AddAddress(nicID2, header.IPv6ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID2, header.IPv6ProtocolNumber, addr2, err) } - if err := s.AddAddress(3, header.IPv6ProtocolNumber, addr3); err != nil { - t.Fatalf("AddAddress(3, %d, %s) = %s", header.IPv6ProtocolNumber, addr3, err) + expectDADEvent(nicID2, addr2) + if err := s.AddAddress(nicID3, header.IPv6ProtocolNumber, addr3); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID3, header.IPv6ProtocolNumber, addr3, err) } + expectDADEvent(nicID3, addr3) // Address should not be considered bound to NIC(1) yet // (DAD ongoing). - addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + addr, err := s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber) if err != nil { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err) } if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID1, header.IPv6ProtocolNumber, addr, want) } // Should get the address on NIC(2) and NIC(3) @@ -767,31 +779,31 @@ func TestSetNDPConfigurations(t *testing.T) { // it as the stack was configured to not do DAD by // default and we only updated the NDP configurations on // NIC(1). - addr, err = s.GetMainNICAddress(2, header.IPv6ProtocolNumber) + addr, err = s.GetMainNICAddress(nicID2, header.IPv6ProtocolNumber) if err != nil { - t.Fatalf("stack.GetMainNICAddress(2, _) err = %s", err) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID2, header.IPv6ProtocolNumber, err) } if addr.Address != addr2 { - t.Fatalf("got stack.GetMainNICAddress(2, _) = %s, want = %s", addr, addr2) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID2, header.IPv6ProtocolNumber, addr, addr2) } - addr, err = s.GetMainNICAddress(3, header.IPv6ProtocolNumber) + addr, err = s.GetMainNICAddress(nicID3, header.IPv6ProtocolNumber) if err != nil { - t.Fatalf("stack.GetMainNICAddress(3, _) err = %s", err) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID3, header.IPv6ProtocolNumber, err) } if addr.Address != addr3 { - t.Fatalf("got stack.GetMainNICAddress(3, _) = %s, want = %s", addr, addr3) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID3, header.IPv6ProtocolNumber, addr, addr3) } // Sleep until right (500ms before) before resolution to // make sure the address didn't resolve on NIC(1) yet. const delta = 500 * time.Millisecond time.Sleep(time.Duration(test.dupAddrDetectTransmits)*test.expectedRetransmitTimer - delta) - addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + addr, err = s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber) if err != nil { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err) } if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID1, header.IPv6ProtocolNumber, addr, want) } // Wait for DAD to resolve. @@ -805,25 +817,16 @@ func TestSetNDPConfigurations(t *testing.T) { // means something is wrong. t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: - if e.err != nil { - t.Fatal("got DAD error: ", e.err) - } - if e.nicID != 1 { - t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID) - } - if e.addr != addr1 { - t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1) - } - if !e.resolved { - t.Fatal("got DAD event w/ resolved = false, want = true") + if diff := checkDADEvent(e, nicID1, addr1, true, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) } } - addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + addr, err = s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber) if err != nil { - t.Fatalf("stack.GetMainNICAddress(1, _) err = %s", err) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err) } if addr.Address != addr1 { - t.Fatalf("got stack.GetMainNICAddress(1, _) = %s, want = %s", addr, addr1) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID1, header.IPv6ProtocolNumber, addr, addr1) } }) } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index b2c1763bf..24133e6f2 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -2184,6 +2184,8 @@ func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) { // TestNICAutoGenAddrDoesDAD tests that the successful auto-generation of IPv6 // link-local addresses will only be assigned after the DAD process resolves. func TestNICAutoGenAddrDoesDAD(t *testing.T) { + const nicID = 1 + ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent), } @@ -2197,18 +2199,18 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { e := channel.New(int(ndpConfigs.DupAddrDetectTransmits), 1280, linkAddr1) s := stack.New(opts) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } // Address should not be considered bound to the // NIC yet (DAD ongoing). - addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) if err != nil { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) } if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) } linkLocalAddr := header.LinkLocalAddr(linkAddr1) @@ -2222,25 +2224,16 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { // means something is wrong. t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: - if e.err != nil { - t.Fatal("got DAD error: ", e.err) - } - if e.nicID != 1 { - t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID) - } - if e.addr != linkLocalAddr { - t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, linkLocalAddr) - } - if !e.resolved { - t.Fatal("got DAD event w/ resolved = false, want = true") + if diff := checkDADEvent(e, nicID, linkLocalAddr, true, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) } } - addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) if err != nil { - t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) } if want := (tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) } } @@ -2606,17 +2599,8 @@ func TestDoDADWhenNICEnabled(t *testing.T) { case <-time.After(dadTransmits*retransmitTimer + defaultAsyncEventTimeout): t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: - if e.err != nil { - t.Fatal("got DAD error: ", e.err) - } - if e.nicID != nicID { - t.Fatalf("got DAD event w/ nicID = %d, want = %d", e.nicID, nicID) - } - if e.addr != addr.AddressWithPrefix.Address { - t.Fatalf("got DAD event w/ addr = %s, want = %s", e.addr, addr.AddressWithPrefix.Address) - } - if !e.resolved { - t.Fatal("got DAD event w/ resolved = false, want = true") + if diff := checkDADEvent(e, nicID, addr.AddressWithPrefix.Address, true, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) } } if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { -- cgit v1.2.3 From b8e22e241cab625d2809034c74d0ff808b948b4c Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Tue, 11 Feb 2020 12:58:13 -0800 Subject: Disallow duplicate NIC names. PiperOrigin-RevId: 294500858 --- pkg/tcpip/stack/nic.go | 5 +++ pkg/tcpip/stack/stack.go | 9 +++++ pkg/tcpip/stack/stack_test.go | 85 +++++++++++++++++++++++++++++++++++++++++++ runsc/sandbox/network.go | 30 +++++++-------- 4 files changed, 114 insertions(+), 15 deletions(-) (limited to 'pkg/tcpip') diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 78d451cca..ca3a7a07e 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1215,6 +1215,11 @@ func (n *NIC) ID() tcpip.NICID { return n.id } +// Name returns the name of n. +func (n *NIC) Name() string { + return n.name +} + // Stack returns the instance of the Stack that owns this NIC. func (n *NIC) Stack() *Stack { return n.stack diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index b793f1d74..6eac16e16 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -890,6 +890,15 @@ func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOp return tcpip.ErrDuplicateNICID } + // Make sure name is unique, unless unnamed. + if opts.Name != "" { + for _, n := range s.nics { + if n.Name() == opts.Name { + return tcpip.ErrDuplicateNICID + } + } + } + n := newNIC(s, id, opts.Name, ep, opts.Context) s.nics[id] = n diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 24133e6f2..7ba604442 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -1792,6 +1792,91 @@ func TestAddProtocolAddressWithOptions(t *testing.T) { verifyAddresses(t, expectedAddresses, gotAddresses) } +func TestCreateNICWithOptions(t *testing.T) { + type callArgsAndExpect struct { + nicID tcpip.NICID + opts stack.NICOptions + err *tcpip.Error + } + + tests := []struct { + desc string + calls []callArgsAndExpect + }{ + { + desc: "DuplicateNICID", + calls: []callArgsAndExpect{ + { + nicID: tcpip.NICID(1), + opts: stack.NICOptions{Name: "eth1"}, + err: nil, + }, + { + nicID: tcpip.NICID(1), + opts: stack.NICOptions{Name: "eth2"}, + err: tcpip.ErrDuplicateNICID, + }, + }, + }, + { + desc: "DuplicateName", + calls: []callArgsAndExpect{ + { + nicID: tcpip.NICID(1), + opts: stack.NICOptions{Name: "lo"}, + err: nil, + }, + { + nicID: tcpip.NICID(2), + opts: stack.NICOptions{Name: "lo"}, + err: tcpip.ErrDuplicateNICID, + }, + }, + }, + { + desc: "Unnamed", + calls: []callArgsAndExpect{ + { + nicID: tcpip.NICID(1), + opts: stack.NICOptions{}, + err: nil, + }, + { + nicID: tcpip.NICID(2), + opts: stack.NICOptions{}, + err: nil, + }, + }, + }, + { + desc: "UnnamedDuplicateNICID", + calls: []callArgsAndExpect{ + { + nicID: tcpip.NICID(1), + opts: stack.NICOptions{}, + err: nil, + }, + { + nicID: tcpip.NICID(1), + opts: stack.NICOptions{}, + err: tcpip.ErrDuplicateNICID, + }, + }, + }, + } + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + s := stack.New(stack.Options{}) + ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")) + for _, call := range test.calls { + if got, want := s.CreateNICWithOptions(call.nicID, ep, call.opts), call.err; got != want { + t.Fatalf("CreateNICWithOptions(%v, _, %+v) = %v, want %v", call.nicID, call.opts, got, want) + } + } + }) + } +} + func TestNICStats(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, diff --git a/runsc/sandbox/network.go b/runsc/sandbox/network.go index ff48f5646..99e143696 100644 --- a/runsc/sandbox/network.go +++ b/runsc/sandbox/network.go @@ -174,13 +174,13 @@ func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, hardwareG return fmt.Errorf("fetching interface addresses for %q: %v", iface.Name, err) } - // We build our own loopback devices. + // We build our own loopback device. if iface.Flags&net.FlagLoopback != 0 { - links, err := loopbackLinks(iface, allAddrs) + link, err := loopbackLink(iface, allAddrs) if err != nil { - return fmt.Errorf("getting loopback routes and links for iface %q: %v", iface.Name, err) + return fmt.Errorf("getting loopback link for iface %q: %v", iface.Name, err) } - args.LoopbackLinks = append(args.LoopbackLinks, links...) + args.LoopbackLinks = append(args.LoopbackLinks, link) continue } @@ -339,25 +339,25 @@ func createSocket(iface net.Interface, ifaceLink netlink.Link, enableGSO bool) ( return &socketEntry{deviceFile, gsoMaxSize}, nil } -// loopbackLinks collects the links for a loopback interface. -func loopbackLinks(iface net.Interface, addrs []net.Addr) ([]boot.LoopbackLink, error) { - var links []boot.LoopbackLink +// loopbackLink returns the link with addresses and routes for a loopback +// interface. +func loopbackLink(iface net.Interface, addrs []net.Addr) (boot.LoopbackLink, error) { + link := boot.LoopbackLink{ + Name: iface.Name, + } for _, addr := range addrs { ipNet, ok := addr.(*net.IPNet) if !ok { - return nil, fmt.Errorf("address is not IPNet: %+v", addr) + return boot.LoopbackLink{}, fmt.Errorf("address is not IPNet: %+v", addr) } dst := *ipNet dst.IP = dst.IP.Mask(dst.Mask) - links = append(links, boot.LoopbackLink{ - Name: iface.Name, - Addresses: []net.IP{ipNet.IP}, - Routes: []boot.Route{{ - Destination: dst, - }}, + link.Addresses = append(link.Addresses, ipNet.IP) + link.Routes = append(link.Routes, boot.Route{ + Destination: dst, }) } - return links, nil + return link, nil } // routesForIface iterates over all routes for the given interface and converts -- cgit v1.2.3 From 6fdf2c53a1d084b70602170b660242036fd8fe4f Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Fri, 7 Feb 2020 11:21:07 -0800 Subject: iptables: User chains - Adds creation of user chains via `-N ` - Adds `-j RETURN` support for built-in chains, which triggers the chain's underflow rule (usually the default policy). - Adds tests for chain creation, default policies, and `-j RETURN' from built-in chains. --- pkg/sentry/socket/netfilter/netfilter.go | 115 +++++++++++++++++++----------- pkg/tcpip/iptables/iptables.go | 74 ++++++++++++------- pkg/tcpip/iptables/targets.go | 41 ++++++++--- pkg/tcpip/iptables/types.go | 50 +++++-------- test/iptables/filter_input.go | 117 ++++++++++++++++++++++++++++++- test/iptables/iptables_test.go | 24 +++++++ 6 files changed, 310 insertions(+), 111 deletions(-) (limited to 'pkg/tcpip') diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index ea02627de..3fc80e0de 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -50,7 +50,9 @@ type metadata struct { // nflog logs messages related to the writing and reading of iptables. func nflog(format string, args ...interface{}) { - log.Infof("netfilter: "+format, args...) + if log.IsLogging(log.Debug) { + log.Debugf("netfilter: "+format, args...) + } } // GetInfo returns information about iptables. @@ -227,19 +229,23 @@ func convertNetstackToBinary(tablename string, table iptables.Table) (linux.Kern } func marshalTarget(target iptables.Target) []byte { - switch target.(type) { - case iptables.UnconditionalAcceptTarget: - return marshalStandardTarget(iptables.Accept) - case iptables.UnconditionalDropTarget: - return marshalStandardTarget(iptables.Drop) + switch tg := target.(type) { + case iptables.AcceptTarget: + return marshalStandardTarget(iptables.RuleAccept) + case iptables.DropTarget: + return marshalStandardTarget(iptables.RuleDrop) case iptables.ErrorTarget: - return marshalErrorTarget() + return marshalErrorTarget(errorTargetName) + case iptables.UserChainTarget: + return marshalErrorTarget(tg.Name) + case iptables.ReturnTarget: + return marshalStandardTarget(iptables.RuleReturn) default: panic(fmt.Errorf("unknown target of type %T", target)) } } -func marshalStandardTarget(verdict iptables.Verdict) []byte { +func marshalStandardTarget(verdict iptables.RuleVerdict) []byte { nflog("convert to binary: marshalling standard target with size %d", linux.SizeOfXTStandardTarget) // The target's name will be the empty string. @@ -254,14 +260,14 @@ func marshalStandardTarget(verdict iptables.Verdict) []byte { return binary.Marshal(ret, usermem.ByteOrder, target) } -func marshalErrorTarget() []byte { +func marshalErrorTarget(errorName string) []byte { // This is an error target named error target := linux.XTErrorTarget{ Target: linux.XTEntryTarget{ TargetSize: linux.SizeOfXTErrorTarget, }, } - copy(target.Name[:], errorTargetName) + copy(target.Name[:], errorName) copy(target.Target.Name[:], errorTargetName) ret := make([]byte, 0, linux.SizeOfXTErrorTarget) @@ -270,38 +276,35 @@ func marshalErrorTarget() []byte { // translateFromStandardVerdict translates verdicts the same way as the iptables // tool. -func translateFromStandardVerdict(verdict iptables.Verdict) int32 { +func translateFromStandardVerdict(verdict iptables.RuleVerdict) int32 { switch verdict { - case iptables.Accept: + case iptables.RuleAccept: return -linux.NF_ACCEPT - 1 - case iptables.Drop: + case iptables.RuleDrop: return -linux.NF_DROP - 1 - case iptables.Queue: - return -linux.NF_QUEUE - 1 - case iptables.Return: + case iptables.RuleReturn: return linux.NF_RETURN - case iptables.Jump: + default: // TODO(gvisor.dev/issue/170): Support Jump. - panic("Jump isn't supported yet") + panic(fmt.Sprintf("unknown standard verdict: %d", verdict)) } - panic(fmt.Sprintf("unknown standard verdict: %d", verdict)) } -// translateToStandardVerdict translates from the value in a +// translateToStandardTarget translates from the value in a // linux.XTStandardTarget to an iptables.Verdict. -func translateToStandardVerdict(val int32) (iptables.Verdict, error) { +func translateToStandardTarget(val int32) (iptables.Target, error) { // TODO(gvisor.dev/issue/170): Support other verdicts. switch val { case -linux.NF_ACCEPT - 1: - return iptables.Accept, nil + return iptables.AcceptTarget{}, nil case -linux.NF_DROP - 1: - return iptables.Drop, nil + return iptables.DropTarget{}, nil case -linux.NF_QUEUE - 1: - return iptables.Invalid, errors.New("unsupported iptables verdict QUEUE") + return nil, errors.New("unsupported iptables verdict QUEUE") case linux.NF_RETURN: - return iptables.Invalid, errors.New("unsupported iptables verdict RETURN") + return iptables.ReturnTarget{}, nil default: - return iptables.Invalid, fmt.Errorf("unknown iptables verdict %d", val) + return nil, fmt.Errorf("unknown iptables verdict %d", val) } } @@ -411,6 +414,10 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { table.BuiltinChains[hk] = ruleIdx } if offset == replace.Underflow[hook] { + if !validUnderflow(table.Rules[ruleIdx]) { + nflog("underflow for hook %d isn't an unconditional ACCEPT or DROP.") + return syserr.ErrInvalidArgument + } table.Underflows[hk] = ruleIdx } } @@ -425,12 +432,34 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { } } + // Add the user chains. + for ruleIdx, rule := range table.Rules { + target, ok := rule.Target.(iptables.UserChainTarget) + if !ok { + continue + } + + // We found a user chain. Before inserting it into the table, + // check that: + // - There's some other rule after it. + // - There are no matchers. + if ruleIdx == len(table.Rules)-1 { + nflog("user chain must have a rule or default policy.") + return syserr.ErrInvalidArgument + } + if len(table.Rules[ruleIdx].Matchers) != 0 { + nflog("user chain's first node must have no matcheres.") + return syserr.ErrInvalidArgument + } + table.UserChains[target.Name] = ruleIdx + 1 + } + // TODO(gvisor.dev/issue/170): Support other chains. // Since we only support modifying the INPUT chain right now, make sure // all other chains point to ACCEPT rules. for hook, ruleIdx := range table.BuiltinChains { if hook != iptables.Input { - if _, ok := table.Rules[ruleIdx].Target.(iptables.UnconditionalAcceptTarget); !ok { + if _, ok := table.Rules[ruleIdx].Target.(iptables.AcceptTarget); !ok { nflog("hook %d is unsupported.", hook) return syserr.ErrInvalidArgument } @@ -519,18 +548,7 @@ func parseTarget(optVal []byte) (iptables.Target, error) { buf = optVal[:linux.SizeOfXTStandardTarget] binary.Unmarshal(buf, usermem.ByteOrder, &standardTarget) - verdict, err := translateToStandardVerdict(standardTarget.Verdict) - if err != nil { - return nil, err - } - switch verdict { - case iptables.Accept: - return iptables.UnconditionalAcceptTarget{}, nil - case iptables.Drop: - return iptables.UnconditionalDropTarget{}, nil - default: - return nil, fmt.Errorf("Unknown verdict: %v", verdict) - } + return translateToStandardTarget(standardTarget.Verdict) case errorTargetName: // Error target. @@ -548,11 +566,14 @@ func parseTarget(optVal []byte) (iptables.Target, error) { // somehow fall through every rule. // * To mark the start of a user defined chain. These // rules have an error with the name of the chain. - switch errorTarget.Name.String() { + switch name := errorTarget.Name.String(); name { case errorTargetName: + nflog("set entries: error target") return iptables.ErrorTarget{}, nil default: - return nil, fmt.Errorf("unknown error target %q doesn't exist or isn't supported yet.", errorTarget.Name.String()) + // User defined chain. + nflog("set entries: user-defined target %q", name) + return iptables.UserChainTarget{Name: name}, nil } } @@ -585,6 +606,18 @@ func containsUnsupportedFields(iptip linux.IPTIP) bool { iptip.InverseFlags != 0 } +func validUnderflow(rule iptables.Rule) bool { + if len(rule.Matchers) != 0 { + return false + } + switch rule.Target.(type) { + case iptables.AcceptTarget, iptables.DropTarget: + return true + default: + return false + } +} + func hookFromLinux(hook int) iptables.Hook { switch hook { case linux.NF_INET_PRE_ROUTING: diff --git a/pkg/tcpip/iptables/iptables.go b/pkg/tcpip/iptables/iptables.go index 1b9485bbd..75a433a3b 100644 --- a/pkg/tcpip/iptables/iptables.go +++ b/pkg/tcpip/iptables/iptables.go @@ -52,10 +52,10 @@ func DefaultTables() IPTables { Tables: map[string]Table{ TablenameNat: Table{ Rules: []Rule{ - Rule{Target: UnconditionalAcceptTarget{}}, - Rule{Target: UnconditionalAcceptTarget{}}, - Rule{Target: UnconditionalAcceptTarget{}}, - Rule{Target: UnconditionalAcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, Rule{Target: ErrorTarget{}}, }, BuiltinChains: map[Hook]int{ @@ -74,8 +74,8 @@ func DefaultTables() IPTables { }, TablenameMangle: Table{ Rules: []Rule{ - Rule{Target: UnconditionalAcceptTarget{}}, - Rule{Target: UnconditionalAcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, Rule{Target: ErrorTarget{}}, }, BuiltinChains: map[Hook]int{ @@ -90,9 +90,9 @@ func DefaultTables() IPTables { }, TablenameFilter: Table{ Rules: []Rule{ - Rule{Target: UnconditionalAcceptTarget{}}, - Rule{Target: UnconditionalAcceptTarget{}}, - Rule{Target: UnconditionalAcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, Rule{Target: ErrorTarget{}}, }, BuiltinChains: map[Hook]int{ @@ -149,13 +149,11 @@ func (it *IPTables) Check(hook Hook, pkt tcpip.PacketBuffer) bool { for _, tablename := range it.Priorities[hook] { switch verdict := it.checkTable(hook, pkt, tablename); verdict { // If the table returns Accept, move on to the next table. - case Accept: + case TableAccept: continue // The Drop verdict is final. - case Drop: + case TableDrop: return false - case Stolen, Queue, Repeat, None, Jump, Return, Continue: - panic(fmt.Sprintf("Unimplemented verdict %v.", verdict)) default: panic(fmt.Sprintf("Unknown verdict %v.", verdict)) } @@ -166,36 +164,58 @@ func (it *IPTables) Check(hook Hook, pkt tcpip.PacketBuffer) bool { } // Precondition: pkt.NetworkHeader is set. -func (it *IPTables) checkTable(hook Hook, pkt tcpip.PacketBuffer, tablename string) Verdict { +func (it *IPTables) checkTable(hook Hook, pkt tcpip.PacketBuffer, tablename string) TableVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. table := it.Tables[tablename] for ruleIdx := table.BuiltinChains[hook]; ruleIdx < len(table.Rules); ruleIdx++ { switch verdict := it.checkRule(hook, pkt, table, ruleIdx); verdict { - // In either of these cases, this table is done with the packet. - case Accept, Drop: - return verdict - // Continue traversing the rules of the table. - case Continue: + case RuleAccept: + return TableAccept + + case RuleDrop: + return TableDrop + + case RuleContinue: continue - case Stolen, Queue, Repeat, None, Jump, Return: - panic(fmt.Sprintf("Unimplemented verdict %v.", verdict)) + + case RuleReturn: + // TODO(gvisor.dev/issue/170): We don't implement jump + // yet, so any Return is from a built-in chain. That + // means we have to to call the underflow. + underflow := table.Rules[table.Underflows[hook]] + // Underflow is guaranteed to be an unconditional + // ACCEPT or DROP. + switch v, _ := underflow.Target.Action(pkt); v { + case RuleAccept: + return TableAccept + case RuleDrop: + return TableDrop + case RuleContinue, RuleReturn: + panic("Underflows should only return RuleAccept or RuleDrop.") + default: + panic(fmt.Sprintf("Unknown verdict: %d", v)) + } + default: - panic(fmt.Sprintf("Unknown verdict %v.", verdict)) + panic(fmt.Sprintf("Unknown verdict: %d", verdict)) } + } - panic(fmt.Sprintf("Traversed past the entire list of iptables rules in table %q.", tablename)) + // We got through the entire table without a decision. Default to DROP + // for safety. + return TableDrop } // Precondition: pk.NetworkHeader is set. -func (it *IPTables) checkRule(hook Hook, pkt tcpip.PacketBuffer, table Table, ruleIdx int) Verdict { +func (it *IPTables) checkRule(hook Hook, pkt tcpip.PacketBuffer, table Table, ruleIdx int) RuleVerdict { rule := table.Rules[ruleIdx] // First check whether the packet matches the IP header filter. // TODO(gvisor.dev/issue/170): Support other fields of the filter. if rule.Filter.Protocol != 0 && rule.Filter.Protocol != header.IPv4(pkt.NetworkHeader).TransportProtocol() { - return Continue + return RuleContinue } // Go through each rule matcher. If they all match, run @@ -203,10 +223,10 @@ func (it *IPTables) checkRule(hook Hook, pkt tcpip.PacketBuffer, table Table, ru for _, matcher := range rule.Matchers { matches, hotdrop := matcher.Match(hook, pkt, "") if hotdrop { - return Drop + return RuleDrop } if !matches { - return Continue + return RuleContinue } } diff --git a/pkg/tcpip/iptables/targets.go b/pkg/tcpip/iptables/targets.go index 4dd281371..9fc60cfad 100644 --- a/pkg/tcpip/iptables/targets.go +++ b/pkg/tcpip/iptables/targets.go @@ -21,20 +21,20 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" ) -// UnconditionalAcceptTarget accepts all packets. -type UnconditionalAcceptTarget struct{} +// AcceptTarget accepts packets. +type AcceptTarget struct{} // Action implements Target.Action. -func (UnconditionalAcceptTarget) Action(packet tcpip.PacketBuffer) (Verdict, string) { - return Accept, "" +func (AcceptTarget) Action(packet tcpip.PacketBuffer) (RuleVerdict, string) { + return RuleAccept, "" } -// UnconditionalDropTarget denies all packets. -type UnconditionalDropTarget struct{} +// DropTarget drops packets. +type DropTarget struct{} // Action implements Target.Action. -func (UnconditionalDropTarget) Action(packet tcpip.PacketBuffer) (Verdict, string) { - return Drop, "" +func (DropTarget) Action(packet tcpip.PacketBuffer) (RuleVerdict, string) { + return RuleDrop, "" } // ErrorTarget logs an error and drops the packet. It represents a target that @@ -42,7 +42,26 @@ func (UnconditionalDropTarget) Action(packet tcpip.PacketBuffer) (Verdict, strin type ErrorTarget struct{} // Action implements Target.Action. -func (ErrorTarget) Action(packet tcpip.PacketBuffer) (Verdict, string) { - log.Warningf("ErrorTarget triggered.") - return Drop, "" +func (ErrorTarget) Action(packet tcpip.PacketBuffer) (RuleVerdict, string) { + log.Debugf("ErrorTarget triggered.") + return RuleDrop, "" +} + +// UserChainTarget marks a rule as the beginning of a user chain. +type UserChainTarget struct { + Name string +} + +// Action implements Target.Action. +func (UserChainTarget) Action(tcpip.PacketBuffer) (RuleVerdict, string) { + panic("UserChainTarget should never be called.") +} + +// ReturnTarget returns from the current chain. If the chain is a built-in, the +// hook's underflow should be called. +type ReturnTarget struct{} + +// Action implements Target.Action. +func (ReturnTarget) Action(tcpip.PacketBuffer) (RuleVerdict, string) { + return RuleReturn, "" } diff --git a/pkg/tcpip/iptables/types.go b/pkg/tcpip/iptables/types.go index 7d593c35c..5735d001b 100644 --- a/pkg/tcpip/iptables/types.go +++ b/pkg/tcpip/iptables/types.go @@ -56,44 +56,32 @@ const ( NumHooks ) -// A Verdict is returned by a rule's target to indicate how traversal of rules -// should (or should not) continue. -type Verdict int +// A TableVerdict is what a table decides should be done with a packet. +type TableVerdict int const ( - // Invalid indicates an unkonwn or erroneous verdict. - Invalid Verdict = iota + // TableAccept indicates the packet should continue through netstack. + TableAccept TableVerdict = iota - // Accept indicates the packet should continue traversing netstack as - // normal. - Accept - - // Drop inicates the packet should be dropped, stopping traversing - // netstack. - Drop - - // Stolen indicates the packet was co-opted by the target and should - // stop traversing netstack. - Stolen - - // Queue indicates the packet should be queued for userspace processing. - Queue + // TableAccept indicates the packet should be dropped. + TableDrop +) - // Repeat indicates the packet should re-traverse the chains for the - // current hook. - Repeat +// A RuleVerdict is what a rule decides should be done with a packet. +type RuleVerdict int - // None indicates no verdict was reached. - None +const ( + // RuleAccept indicates the packet should continue through netstack. + RuleAccept RuleVerdict = iota - // Jump indicates a jump to another chain. - Jump + // RuleContinue indicates the packet should continue to the next rule. + RuleContinue - // Continue indicates that traversal should continue at the next rule. - Continue + // RuleDrop indicates the packet should be dropped. + RuleDrop - // Return indicates that traversal should return to the calling chain. - Return + // RuleReturn indicates the packet should return to the previous chain. + RuleReturn ) // IPTables holds all the tables for a netstack. @@ -187,5 +175,5 @@ type Target interface { // Action takes an action on the packet and returns a verdict on how // traversal should (or should not) continue. If the return value is // Jump, it also returns the name of the chain to jump to. - Action(packet tcpip.PacketBuffer) (Verdict, string) + Action(packet tcpip.PacketBuffer) (RuleVerdict, string) } diff --git a/test/iptables/filter_input.go b/test/iptables/filter_input.go index bd6059921..e26d6a7d2 100644 --- a/test/iptables/filter_input.go +++ b/test/iptables/filter_input.go @@ -36,6 +36,10 @@ func init() { RegisterTestCase(FilterInputDropTCPSrcPort{}) RegisterTestCase(FilterInputDropUDPPort{}) RegisterTestCase(FilterInputDropUDP{}) + RegisterTestCase(FilterInputCreateUserChain{}) + RegisterTestCase(FilterInputDefaultPolicyAccept{}) + RegisterTestCase(FilterInputDefaultPolicyDrop{}) + RegisterTestCase(FilterInputReturnUnderflow{}) } // FilterInputDropUDP tests that we can drop UDP traffic. @@ -295,8 +299,119 @@ func (FilterInputRequireProtocolUDP) ContainerAction(ip net.IP) error { return nil } -// LocalAction implements TestCase.LocalAction. func (FilterInputRequireProtocolUDP) LocalAction(ip net.IP) error { // No-op. return nil } + +// FilterInputCreateUserChain tests chain creation. +type FilterInputCreateUserChain struct{} + +// Name implements TestCase.Name. +func (FilterInputCreateUserChain) Name() string { + return "FilterInputCreateUserChain" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputCreateUserChain) ContainerAction(ip net.IP) error { + // Create a chain. + const chainName = "foochain" + if err := filterTable("-N", chainName); err != nil { + return err + } + + // Add a simple rule to the chain. + return filterTable("-A", chainName, "-j", "DROP") +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputCreateUserChain) LocalAction(ip net.IP) error { + // No-op. + return nil +} + +// FilterInputDefaultPolicyAccept tests the default ACCEPT policy. +type FilterInputDefaultPolicyAccept struct{} + +// Name implements TestCase.Name. +func (FilterInputDefaultPolicyAccept) Name() string { + return "FilterInputDefaultPolicyAccept" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputDefaultPolicyAccept) ContainerAction(ip net.IP) error { + // Set the default policy to accept, then receive a packet. + if err := filterTable("-P", "INPUT", "ACCEPT"); err != nil { + return err + } + return listenUDP(acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputDefaultPolicyAccept) LocalAction(ip net.IP) error { + return sendUDPLoop(ip, acceptPort, sendloopDuration) +} + +// FilterInputDefaultPolicyDrop tests the default DROP policy. +type FilterInputDefaultPolicyDrop struct{} + +// Name implements TestCase.Name. +func (FilterInputDefaultPolicyDrop) Name() string { + return "FilterInputDefaultPolicyDrop" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputDefaultPolicyDrop) ContainerAction(ip net.IP) error { + if err := filterTable("-P", "INPUT", "DROP"); err != nil { + return err + } + + // Listen for UDP packets on dropPort. + if err := listenUDP(dropPort, sendloopDuration); err == nil { + return fmt.Errorf("packets on port %d should have been dropped, but got a packet", dropPort) + } else if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() { + return fmt.Errorf("error reading: %v", err) + } + + // At this point we know that reading timed out and never received a + // packet. + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputDefaultPolicyDrop) LocalAction(ip net.IP) error { + return sendUDPLoop(ip, acceptPort, sendloopDuration) +} + +// FilterInputReturnUnderflow tests that -j RETURN in a built-in chain causes +// the underflow rule (i.e. default policy) to be executed. +type FilterInputReturnUnderflow struct{} + +// Name implements TestCase.Name. +func (FilterInputReturnUnderflow) Name() string { + return "FilterInputReturnUnderflow" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputReturnUnderflow) ContainerAction(ip net.IP) error { + // Add a RETURN rule followed by an unconditional accept, and set the + // default policy to DROP. + if err := filterTable("-A", "INPUT", "-j", "RETURN"); err != nil { + return err + } + if err := filterTable("-A", "INPUT", "-j", "DROP"); err != nil { + return err + } + if err := filterTable("-P", "INPUT", "ACCEPT"); err != nil { + return err + } + + // We should receive packets, as the RETURN rule will trigger the default + // ACCEPT policy. + return listenUDP(acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputReturnUnderflow) LocalAction(ip net.IP) error { + return sendUDPLoop(ip, acceptPort, sendloopDuration) +} diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go index 41909582a..46a7c99b0 100644 --- a/test/iptables/iptables_test.go +++ b/test/iptables/iptables_test.go @@ -214,6 +214,30 @@ func TestFilterInputDropTCPSrcPort(t *testing.T) { } } +func TestFilterInputCreateUserChain(t *testing.T) { + if err := singleTest(FilterInputCreateUserChain{}); err != nil { + t.Fatal(err) + } +} + +func TestFilterInputDefaultPolicyAccept(t *testing.T) { + if err := singleTest(FilterInputDefaultPolicyAccept{}); err != nil { + t.Fatal(err) + } +} + +func TestFilterInputDefaultPolicyDrop(t *testing.T) { + if err := singleTest(FilterInputDefaultPolicyDrop{}); err != nil { + t.Fatal(err) + } +} + +func TestFilterInputReturnUnderflow(t *testing.T) { + if err := singleTest(FilterInputReturnUnderflow{}); err != nil { + t.Fatal(err) + } +} + func TestFilterOutputDropTCPDestPort(t *testing.T) { if err := singleTest(FilterOutputDropTCPDestPort{}); err != nil { t.Fatal(err) -- cgit v1.2.3 From 69bf39e8a47d3b4dcbbd04d2e8df476cdfab5e74 Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Thu, 13 Feb 2020 10:58:47 -0800 Subject: Internal change. PiperOrigin-RevId: 294952610 --- pkg/abi/linux/socket.go | 13 ++++ pkg/sentry/socket/control/BUILD | 1 + pkg/sentry/socket/control/control.go | 43 +++++++++++++ pkg/sentry/socket/hostinet/socket.go | 11 +++- pkg/sentry/socket/netstack/netstack.go | 37 ++++++++++-- pkg/tcpip/tcpip.go | 25 ++++++++ pkg/tcpip/transport/udp/endpoint.go | 26 ++++++++ test/syscalls/linux/socket_ip_udp_generic.cc | 44 ++++++++++++++ test/syscalls/linux/socket_ipv4_udp_unbound.cc | 84 ++++++++++++++++++++++++++ test/syscalls/linux/udp_socket_test_cases.cc | 1 - 10 files changed, 278 insertions(+), 7 deletions(-) (limited to 'pkg/tcpip') diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index 766ee4014..4a14ef691 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -411,6 +411,15 @@ type ControlMessageCredentials struct { GID uint32 } +// A ControlMessageIPPacketInfo is IP_PKTINFO socket control message. +// +// ControlMessageIPPacketInfo represents struct in_pktinfo from linux/in.h. +type ControlMessageIPPacketInfo struct { + NIC int32 + LocalAddr InetAddr + DestinationAddr InetAddr +} + // SizeOfControlMessageCredentials is the binary size of a // ControlMessageCredentials struct. var SizeOfControlMessageCredentials = int(binary.Size(ControlMessageCredentials{})) @@ -431,6 +440,10 @@ const SizeOfControlMessageTOS = 1 // SizeOfControlMessageTClass is the size of an IPV6_TCLASS control message. const SizeOfControlMessageTClass = 4 +// SizeOfControlMessageIPPacketInfo is the size of an IP_PKTINFO +// control message. +const SizeOfControlMessageIPPacketInfo = 12 + // SCM_MAX_FD is the maximum number of FDs accepted in a single sendmsg call. // From net/scm.h. const SCM_MAX_FD = 253 diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD index 79e16d6e8..4d42d29cb 100644 --- a/pkg/sentry/socket/control/BUILD +++ b/pkg/sentry/socket/control/BUILD @@ -19,6 +19,7 @@ go_library( "//pkg/sentry/socket", "//pkg/sentry/socket/unix/transport", "//pkg/syserror", + "//pkg/tcpip", "//pkg/usermem", ], ) diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index 6145a7fc3..4667373d2 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -26,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" ) @@ -338,6 +339,22 @@ func PackTClass(t *kernel.Task, tClass int32, buf []byte) []byte { ) } +// PackIPPacketInfo packs an IP_PKTINFO socket control message. +func PackIPPacketInfo(t *kernel.Task, packetInfo tcpip.IPPacketInfo, buf []byte) []byte { + var p linux.ControlMessageIPPacketInfo + p.NIC = int32(packetInfo.NIC) + copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr)) + copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr)) + + return putCmsgStruct( + buf, + linux.SOL_IP, + linux.IP_PKTINFO, + t.Arch().Width(), + p, + ) +} + // PackControlMessages packs control messages into the given buffer. // // We skip control messages specific to Unix domain sockets. @@ -362,6 +379,10 @@ func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byt buf = PackTClass(t, cmsgs.IP.TClass, buf) } + if cmsgs.IP.HasIPPacketInfo { + buf = PackIPPacketInfo(t, cmsgs.IP.PacketInfo, buf) + } + return buf } @@ -394,6 +415,16 @@ func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int { return space } +// NewIPPacketInfo returns the IPPacketInfo struct. +func NewIPPacketInfo(packetInfo linux.ControlMessageIPPacketInfo) tcpip.IPPacketInfo { + var p tcpip.IPPacketInfo + p.NIC = tcpip.NICID(packetInfo.NIC) + copy([]byte(p.LocalAddr), packetInfo.LocalAddr[:]) + copy([]byte(p.DestinationAddr), packetInfo.DestinationAddr[:]) + + return p +} + // Parse parses a raw socket control message into portable objects. func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.ControlMessages, error) { var ( @@ -468,6 +499,18 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTOS], usermem.ByteOrder, &cmsgs.IP.TOS) i += binary.AlignUp(length, width) + case linux.IP_PKTINFO: + if length < linux.SizeOfControlMessageIPPacketInfo { + return socket.ControlMessages{}, syserror.EINVAL + } + + cmsgs.IP.HasIPPacketInfo = true + var packetInfo linux.ControlMessageIPPacketInfo + binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo) + + cmsgs.IP.PacketInfo = NewIPPacketInfo(packetInfo) + i += binary.AlignUp(length, width) + default: return socket.ControlMessages{}, syserror.EINVAL } diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index de76388ac..22f78d2e2 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -289,7 +289,7 @@ func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outPt switch level { case linux.SOL_IP: switch name { - case linux.IP_TOS, linux.IP_RECVTOS: + case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO: optlen = sizeofInt32 } case linux.SOL_IPV6: @@ -336,6 +336,8 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [ switch name { case linux.IP_TOS, linux.IP_RECVTOS: optlen = sizeofInt32 + case linux.IP_PKTINFO: + optlen = linux.SizeOfControlMessageIPPacketInfo } case linux.SOL_IPV6: switch name { @@ -473,7 +475,14 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags case syscall.IP_TOS: controlMessages.IP.HasTOS = true binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTOS], usermem.ByteOrder, &controlMessages.IP.TOS) + + case syscall.IP_PKTINFO: + controlMessages.IP.HasIPPacketInfo = true + var packetInfo linux.ControlMessageIPPacketInfo + binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo) + controlMessages.IP.PacketInfo = control.NewIPPacketInfo(packetInfo) } + case syscall.SOL_IPV6: switch unixCmsg.Header.Type { case syscall.IPV6_TCLASS: diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index ed2fbcceb..9757fbfba 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1414,6 +1414,21 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in } return o, nil + case linux.IP_PKTINFO: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v, err := ep.GetSockOptBool(tcpip.ReceiveIPPacketInfoOption) + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + var o int32 + if v { + o = 1 + } + return o, nil + default: emitUnimplementedEventIP(t, name) } @@ -1762,6 +1777,7 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) linux.IPV6_IPSEC_POLICY, linux.IPV6_JOIN_ANYCAST, linux.IPV6_LEAVE_ANYCAST, + // TODO(b/148887420): Add support for IPV6_PKTINFO. linux.IPV6_PKTINFO, linux.IPV6_ROUTER_ALERT, linux.IPV6_XFRM_POLICY, @@ -1949,6 +1965,16 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s } return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTOSOption, v != 0)) + case linux.IP_PKTINFO: + if len(optVal) == 0 { + return nil + } + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, v != 0)) + case linux.IP_ADD_SOURCE_MEMBERSHIP, linux.IP_BIND_ADDRESS_NO_PORT, linux.IP_BLOCK_SOURCE, @@ -1964,7 +1990,6 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s linux.IP_NODEFRAG, linux.IP_OPTIONS, linux.IP_PASSSEC, - linux.IP_PKTINFO, linux.IP_RECVERR, linux.IP_RECVFRAGSIZE, linux.IP_RECVOPTS, @@ -2395,10 +2420,12 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe func (s *SocketOperations) controlMessages() socket.ControlMessages { return socket.ControlMessages{ IP: tcpip.ControlMessages{ - HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp, - Timestamp: s.readCM.Timestamp, - HasTOS: s.readCM.HasTOS, - TOS: s.readCM.TOS, + HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp, + Timestamp: s.readCM.Timestamp, + HasTOS: s.readCM.HasTOS, + TOS: s.readCM.TOS, + HasIPPacketInfo: s.readCM.HasIPPacketInfo, + PacketInfo: s.readCM.PacketInfo, }, } } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 0e944712f..9ca39ce40 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -328,6 +328,12 @@ type ControlMessages struct { // Tclass is the IPv6 traffic class of the associated packet. TClass int32 + + // HasIPPacketInfo indicates whether PacketInfo is set. + HasIPPacketInfo bool + + // PacketInfo holds interface and address data on an incoming packet. + PacketInfo IPPacketInfo } // Endpoint is the interface implemented by transport protocols (e.g., tcp, udp) @@ -503,6 +509,11 @@ const ( // V6OnlyOption is used by {G,S}etSockOptBool to specify whether an IPv6 // socket is to be restricted to sending and receiving IPv6 packets only. V6OnlyOption + + // ReceiveIPPacketInfoOption is used by {G,S}etSockOptBool to specify + // if more inforamtion is provided with incoming packets such + // as interface index and address. + ReceiveIPPacketInfoOption ) // SockOptInt represents socket options which values have the int type. @@ -685,6 +696,20 @@ type IPv4TOSOption uint8 // for all subsequent outgoing IPv6 packets from the endpoint. type IPv6TrafficClassOption uint8 +// IPPacketInfo is the message struture for IP_PKTINFO. +// +// +stateify savable +type IPPacketInfo struct { + // NIC is the ID of the NIC to be used. + NIC NICID + + // LocalAddr is the local address. + LocalAddr Address + + // DestinationAddr is the destination address. + DestinationAddr Address +} + // Route is a row in the routing table. It specifies through which NIC (and // gateway) sets of packets should be routed. A row is considered viable if the // masked target address matches the destination address in the row. diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index c9cbed8f4..3fe91cac2 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -29,6 +29,7 @@ import ( type udpPacket struct { udpPacketEntry senderAddress tcpip.FullAddress + packetInfo tcpip.IPPacketInfo data buffer.VectorisedView `state:".(buffer.VectorisedView)"` timestamp int64 tos uint8 @@ -118,6 +119,9 @@ type endpoint struct { // as ancillary data to ControlMessages on Read. receiveTOS bool + // receiveIPPacketInfo determines if the packet info is returned by Read. + receiveIPPacketInfo bool + // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags @@ -254,11 +258,17 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess } e.mu.RLock() receiveTOS := e.receiveTOS + receiveIPPacketInfo := e.receiveIPPacketInfo e.mu.RUnlock() if receiveTOS { cm.HasTOS = true cm.TOS = p.tos } + + if receiveIPPacketInfo { + cm.HasIPPacketInfo = true + cm.PacketInfo = p.packetInfo + } return p.data.ToView(), cm, nil } @@ -495,6 +505,13 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { } e.v6only = v + return nil + + case tcpip.ReceiveIPPacketInfoOption: + e.mu.Lock() + e.receiveIPPacketInfo = v + e.mu.Unlock() + return nil } return nil @@ -703,6 +720,12 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { e.mu.RUnlock() return v, nil + + case tcpip.ReceiveIPPacketInfoOption: + e.mu.RLock() + v := e.receiveIPPacketInfo + e.mu.RUnlock() + return v, nil } return false, tcpip.ErrUnknownProtocolOption @@ -1247,6 +1270,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk switch r.NetProto { case header.IPv4ProtocolNumber: packet.tos, _ = header.IPv4(pkt.NetworkHeader).TOS() + packet.packetInfo.LocalAddr = r.LocalAddress + packet.packetInfo.DestinationAddr = r.RemoteAddress + packet.packetInfo.NIC = r.NICID() } packet.timestamp = e.stack.NowNanoseconds() diff --git a/test/syscalls/linux/socket_ip_udp_generic.cc b/test/syscalls/linux/socket_ip_udp_generic.cc index 53290bed7..db5663ecd 100644 --- a/test/syscalls/linux/socket_ip_udp_generic.cc +++ b/test/syscalls/linux/socket_ip_udp_generic.cc @@ -357,5 +357,49 @@ TEST_P(UDPSocketPairTest, SetReuseAddrReusePort) { EXPECT_EQ(get, kSockOptOn); } +// Test getsockopt for a socket which is not set with IP_PKTINFO option. +TEST_P(UDPSocketPairTest, IPPKTINFODefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int get = -1; + socklen_t get_len = sizeof(get); + + ASSERT_THAT( + getsockopt(sockets->first_fd(), SOL_IP, IP_PKTINFO, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOff); +} + +// Test setsockopt and getsockopt for a socket with IP_PKTINFO option. +TEST_P(UDPSocketPairTest, SetAndGetIPPKTINFO) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int level = SOL_IP; + int type = IP_PKTINFO; + + // Check getsockopt before IP_PKTINFO is set. + int get = -1; + socklen_t get_len = sizeof(get); + + ASSERT_THAT(setsockopt(sockets->first_fd(), level, type, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceedsWithValue(0)); + + ASSERT_THAT(getsockopt(sockets->first_fd(), level, type, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get, kSockOptOn); + EXPECT_EQ(get_len, sizeof(get)); + + ASSERT_THAT(setsockopt(sockets->first_fd(), level, type, &kSockOptOff, + sizeof(kSockOptOff)), + SyscallSucceedsWithValue(0)); + + ASSERT_THAT(getsockopt(sockets->first_fd(), level, type, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get, kSockOptOff); + EXPECT_EQ(get_len, sizeof(get)); +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.cc b/test/syscalls/linux/socket_ipv4_udp_unbound.cc index 990ccf23c..bc4b07a62 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound.cc @@ -15,6 +15,7 @@ #include "test/syscalls/linux/socket_ipv4_udp_unbound.h" #include +#include #include #include #include @@ -2128,5 +2129,88 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrReusePortDistribution) { SyscallSucceedsWithValue(kMessageSize)); } +// Test that socket will receive packet info control message. +TEST_P(IPv4UDPUnboundSocketTest, SetAndReceiveIPPKTINFO) { + // TODO(gvisor.dev/issue/1202): ioctl() is not supported by hostinet. + SKIP_IF((IsRunningWithHostinet())); + + auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto sender_addr = V4Loopback(); + int level = SOL_IP; + int type = IP_PKTINFO; + + ASSERT_THAT( + bind(receiver->get(), reinterpret_cast(&sender_addr.addr), + sender_addr.addr_len), + SyscallSucceeds()); + socklen_t sender_addr_len = sender_addr.addr_len; + ASSERT_THAT(getsockname(receiver->get(), + reinterpret_cast(&sender_addr.addr), + &sender_addr_len), + SyscallSucceeds()); + EXPECT_EQ(sender_addr_len, sender_addr.addr_len); + + auto receiver_addr = V4Loopback(); + reinterpret_cast(&receiver_addr.addr)->sin_port = + reinterpret_cast(&sender_addr.addr)->sin_port; + ASSERT_THAT( + connect(sender->get(), reinterpret_cast(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + + // Allow socket to receive control message. + ASSERT_THAT( + setsockopt(receiver->get(), level, type, &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + + // Prepare message to send. + constexpr size_t kDataLength = 1024; + msghdr sent_msg = {}; + iovec sent_iov = {}; + char sent_data[kDataLength]; + sent_iov.iov_base = sent_data; + sent_iov.iov_len = kDataLength; + sent_msg.msg_iov = &sent_iov; + sent_msg.msg_iovlen = 1; + sent_msg.msg_flags = 0; + + ASSERT_THAT(RetryEINTR(sendmsg)(sender->get(), &sent_msg, 0), + SyscallSucceedsWithValue(kDataLength)); + + msghdr received_msg = {}; + iovec received_iov = {}; + char received_data[kDataLength]; + char received_cmsg_buf[CMSG_SPACE(sizeof(in_pktinfo))] = {}; + size_t cmsg_data_len = sizeof(in_pktinfo); + received_iov.iov_base = received_data; + received_iov.iov_len = kDataLength; + received_msg.msg_iov = &received_iov; + received_msg.msg_iovlen = 1; + received_msg.msg_controllen = CMSG_LEN(cmsg_data_len); + received_msg.msg_control = received_cmsg_buf; + + ASSERT_THAT(RetryEINTR(recvmsg)(receiver->get(), &received_msg, 0), + SyscallSucceedsWithValue(kDataLength)); + + cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg); + ASSERT_NE(cmsg, nullptr); + EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(cmsg_data_len)); + EXPECT_EQ(cmsg->cmsg_level, level); + EXPECT_EQ(cmsg->cmsg_type, type); + + // Get loopback index. + ifreq ifr = {}; + absl::SNPrintF(ifr.ifr_name, IFNAMSIZ, "lo"); + ASSERT_THAT(ioctl(sender->get(), SIOCGIFINDEX, &ifr), SyscallSucceeds()); + ASSERT_NE(ifr.ifr_ifindex, 0); + + // Check the data + in_pktinfo received_pktinfo = {}; + memcpy(&received_pktinfo, CMSG_DATA(cmsg), sizeof(in_pktinfo)); + EXPECT_EQ(received_pktinfo.ipi_ifindex, ifr.ifr_ifindex); + EXPECT_EQ(received_pktinfo.ipi_spec_dst.s_addr, htonl(INADDR_LOOPBACK)); + EXPECT_EQ(received_pktinfo.ipi_addr.s_addr, htonl(INADDR_LOOPBACK)); +} } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/udp_socket_test_cases.cc b/test/syscalls/linux/udp_socket_test_cases.cc index a2f6ef8cc..9f8de6b48 100644 --- a/test/syscalls/linux/udp_socket_test_cases.cc +++ b/test/syscalls/linux/udp_socket_test_cases.cc @@ -1495,6 +1495,5 @@ TEST_P(UdpSocketTest, SendAndReceiveTOS) { memcpy(&received_tos, CMSG_DATA(cmsg), sizeof(received_tos)); EXPECT_EQ(received_tos, sent_tos); } - } // namespace testing } // namespace gvisor -- cgit v1.2.3 From 6ef63cd7da107d487fda7c48af50fa9802913cd9 Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Wed, 12 Feb 2020 16:19:06 -0800 Subject: We can now create and jump in iptables. For example: $ iptables -N foochain $ iptables -A INPUT -j foochain --- pkg/abi/linux/netfilter.go | 9 +- pkg/sentry/socket/netfilter/BUILD | 1 + pkg/sentry/socket/netfilter/netfilter.go | 62 +++++++-- pkg/sentry/socket/netfilter/targets.go | 35 +++++ pkg/tcpip/iptables/iptables.go | 103 +++++++++------ pkg/tcpip/iptables/targets.go | 20 ++- pkg/tcpip/iptables/types.go | 21 +-- test/iptables/filter_input.go | 217 ++++++++++++++++++++++++++++--- test/iptables/iptables_test.go | 36 +++++ test/iptables/iptables_util.go | 10 ++ 10 files changed, 420 insertions(+), 94 deletions(-) create mode 100644 pkg/sentry/socket/netfilter/targets.go (limited to 'pkg/tcpip') diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go index bbc4df74c..bd2e13ba1 100644 --- a/pkg/abi/linux/netfilter.go +++ b/pkg/abi/linux/netfilter.go @@ -225,11 +225,14 @@ type XTEntryTarget struct { // SizeOfXTEntryTarget is the size of an XTEntryTarget. const SizeOfXTEntryTarget = 32 -// XTStandardTarget is a builtin target, one of ACCEPT, DROP, JUMP, QUEUE, or -// RETURN. It corresponds to struct xt_standard_target in +// XTStandardTarget is a built-in target, one of ACCEPT, DROP, JUMP, QUEUE, +// RETURN, or jump. It corresponds to struct xt_standard_target in // include/uapi/linux/netfilter/x_tables.h. type XTStandardTarget struct { - Target XTEntryTarget + Target XTEntryTarget + // A positive verdict indicates a jump, and is the offset from the + // start of the table to jump to. A negative value means one of the + // other built-in targets. Verdict int32 _ [4]byte } diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD index c91ec7494..7cd2ce55b 100644 --- a/pkg/sentry/socket/netfilter/BUILD +++ b/pkg/sentry/socket/netfilter/BUILD @@ -7,6 +7,7 @@ go_library( srcs = [ "extensions.go", "netfilter.go", + "targets.go", "tcp_matcher.go", "udp_matcher.go", ], diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 3fc80e0de..d322e4144 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -240,13 +240,15 @@ func marshalTarget(target iptables.Target) []byte { return marshalErrorTarget(tg.Name) case iptables.ReturnTarget: return marshalStandardTarget(iptables.RuleReturn) + case JumpTarget: + return marshalJumpTarget(tg) default: panic(fmt.Errorf("unknown target of type %T", target)) } } func marshalStandardTarget(verdict iptables.RuleVerdict) []byte { - nflog("convert to binary: marshalling standard target with size %d", linux.SizeOfXTStandardTarget) + nflog("convert to binary: marshalling standard target") // The target's name will be the empty string. target := linux.XTStandardTarget{ @@ -274,6 +276,23 @@ func marshalErrorTarget(errorName string) []byte { return binary.Marshal(ret, usermem.ByteOrder, target) } +func marshalJumpTarget(jt JumpTarget) []byte { + nflog("convert to binary: marshalling jump target") + + // The target's name will be the empty string. + target := linux.XTStandardTarget{ + Target: linux.XTEntryTarget{ + TargetSize: linux.SizeOfXTStandardTarget, + }, + // Verdict is overloaded by the ABI. When positive, it holds + // the jump offset from the start of the table. + Verdict: int32(jt.Offset), + } + + ret := make([]byte, 0, linux.SizeOfXTStandardTarget) + return binary.Marshal(ret, usermem.ByteOrder, target) +} + // translateFromStandardVerdict translates verdicts the same way as the iptables // tool. func translateFromStandardVerdict(verdict iptables.RuleVerdict) int32 { @@ -335,7 +354,8 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { // Convert input into a list of rules and their offsets. var offset uint32 - var offsets []uint32 + // offsets maps rule byte offsets to their position in table.Rules. + offsets := map[uint32]int{} for entryIdx := uint32(0); entryIdx < replace.NumEntries; entryIdx++ { nflog("set entries: processing entry at offset %d", offset) @@ -396,11 +416,12 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { Target: target, Matchers: matchers, }) - offsets = append(offsets, offset) + offsets[offset] = int(entryIdx) offset += uint32(entry.NextOffset) if initialOptValLen-len(optVal) != int(entry.NextOffset) { nflog("entry NextOffset is %d, but entry took up %d bytes", entry.NextOffset, initialOptValLen-len(optVal)) + return syserr.ErrInvalidArgument } } @@ -409,13 +430,13 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error { for hook, _ := range replace.HookEntry { if table.ValidHooks()&(1<= 0 indicates a jump. + return JumpTarget{Offset: uint32(standardTarget.Verdict)}, nil + } case errorTargetName: // Error target. diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go new file mode 100644 index 000000000..c421b87cf --- /dev/null +++ b/pkg/sentry/socket/netfilter/targets.go @@ -0,0 +1,35 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package netfilter + +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/iptables" +) + +// JumpTarget implements iptables.Target. +type JumpTarget struct { + // Offset is the byte offset of the rule to jump to. It is used for + // marshaling and unmarshaling. + Offset uint32 + + // RuleNum is the rule to jump to. + RuleNum int +} + +// Action implements iptables.Target.Action. +func (jt JumpTarget) Action(tcpip.PacketBuffer) (iptables.RuleVerdict, int) { + return iptables.RuleJump, jt.RuleNum +} diff --git a/pkg/tcpip/iptables/iptables.go b/pkg/tcpip/iptables/iptables.go index 75a433a3b..dbaccbb36 100644 --- a/pkg/tcpip/iptables/iptables.go +++ b/pkg/tcpip/iptables/iptables.go @@ -135,25 +135,53 @@ func EmptyFilterTable() Table { } } +// A chainVerdict is what a table decides should be done with a packet. +type chainVerdict int + +const ( + // chainAccept indicates the packet should continue through netstack. + chainAccept chainVerdict = iota + + // chainAccept indicates the packet should be dropped. + chainDrop + + // chainReturn indicates the packet should return to the calling chain + // or the underflow rule of a builtin chain. + chainReturn +) + // Check runs pkt through the rules for hook. It returns true when the packet // should continue traversing the network stack and false when it should be // dropped. // // Precondition: pkt.NetworkHeader is set. func (it *IPTables) Check(hook Hook, pkt tcpip.PacketBuffer) bool { - // TODO(gvisor.dev/issue/170): A lot of this is uncomplicated because - // we're missing features. Jumps, the call stack, etc. aren't checked - // for yet because we're yet to support them. - // Go through each table containing the hook. for _, tablename := range it.Priorities[hook] { - switch verdict := it.checkTable(hook, pkt, tablename); verdict { + table := it.Tables[tablename] + ruleIdx := table.BuiltinChains[hook] + switch verdict := it.checkChain(hook, pkt, table, ruleIdx); verdict { // If the table returns Accept, move on to the next table. - case TableAccept: + case chainAccept: continue // The Drop verdict is final. - case TableDrop: + case chainDrop: return false + case chainReturn: + // Any Return from a built-in chain means we have to + // call the underflow. + underflow := table.Rules[table.Underflows[hook]] + switch v, _ := underflow.Target.Action(pkt); v { + case RuleAccept: + continue + case RuleDrop: + return false + case RuleJump, RuleReturn: + panic("Underflows should only return RuleAccept or RuleDrop.") + default: + panic(fmt.Sprintf("Unknown verdict: %d", v)) + } + default: panic(fmt.Sprintf("Unknown verdict %v.", verdict)) } @@ -164,37 +192,37 @@ func (it *IPTables) Check(hook Hook, pkt tcpip.PacketBuffer) bool { } // Precondition: pkt.NetworkHeader is set. -func (it *IPTables) checkTable(hook Hook, pkt tcpip.PacketBuffer, tablename string) TableVerdict { +func (it *IPTables) checkChain(hook Hook, pkt tcpip.PacketBuffer, table Table, ruleIdx int) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. - table := it.Tables[tablename] - for ruleIdx := table.BuiltinChains[hook]; ruleIdx < len(table.Rules); ruleIdx++ { - switch verdict := it.checkRule(hook, pkt, table, ruleIdx); verdict { + for ruleIdx < len(table.Rules) { + switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx); verdict { case RuleAccept: - return TableAccept + return chainAccept case RuleDrop: - return TableDrop - - case RuleContinue: - continue + return chainDrop case RuleReturn: - // TODO(gvisor.dev/issue/170): We don't implement jump - // yet, so any Return is from a built-in chain. That - // means we have to to call the underflow. - underflow := table.Rules[table.Underflows[hook]] - // Underflow is guaranteed to be an unconditional - // ACCEPT or DROP. - switch v, _ := underflow.Target.Action(pkt); v { - case RuleAccept: - return TableAccept - case RuleDrop: - return TableDrop - case RuleContinue, RuleReturn: - panic("Underflows should only return RuleAccept or RuleDrop.") + return chainReturn + + case RuleJump: + // "Jumping" to the next rule just means we're + // continuing on down the list. + if jumpTo == ruleIdx+1 { + ruleIdx++ + continue + } + switch verdict := it.checkChain(hook, pkt, table, jumpTo); verdict { + case chainAccept: + return chainAccept + case chainDrop: + return chainDrop + case chainReturn: + ruleIdx++ + continue default: - panic(fmt.Sprintf("Unknown verdict: %d", v)) + panic(fmt.Sprintf("Unknown verdict: %d", verdict)) } default: @@ -205,17 +233,18 @@ func (it *IPTables) checkTable(hook Hook, pkt tcpip.PacketBuffer, tablename stri // We got through the entire table without a decision. Default to DROP // for safety. - return TableDrop + return chainDrop } // Precondition: pk.NetworkHeader is set. -func (it *IPTables) checkRule(hook Hook, pkt tcpip.PacketBuffer, table Table, ruleIdx int) RuleVerdict { +func (it *IPTables) checkRule(hook Hook, pkt tcpip.PacketBuffer, table Table, ruleIdx int) (RuleVerdict, int) { rule := table.Rules[ruleIdx] // First check whether the packet matches the IP header filter. // TODO(gvisor.dev/issue/170): Support other fields of the filter. if rule.Filter.Protocol != 0 && rule.Filter.Protocol != header.IPv4(pkt.NetworkHeader).TransportProtocol() { - return RuleContinue + // Continue on to the next rule. + return RuleJump, ruleIdx + 1 } // Go through each rule matcher. If they all match, run @@ -223,14 +252,14 @@ func (it *IPTables) checkRule(hook Hook, pkt tcpip.PacketBuffer, table Table, ru for _, matcher := range rule.Matchers { matches, hotdrop := matcher.Match(hook, pkt, "") if hotdrop { - return RuleDrop + return RuleDrop, 0 } if !matches { - return RuleContinue + // Continue on to the next rule. + return RuleJump, ruleIdx + 1 } } // All the matchers matched, so run the target. - verdict, _ := rule.Target.Action(pkt) - return verdict + return rule.Target.Action(pkt) } diff --git a/pkg/tcpip/iptables/targets.go b/pkg/tcpip/iptables/targets.go index 9fc60cfad..81a2e39a2 100644 --- a/pkg/tcpip/iptables/targets.go +++ b/pkg/tcpip/iptables/targets.go @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -// This file contains various Targets. - package iptables import ( @@ -25,16 +23,16 @@ import ( type AcceptTarget struct{} // Action implements Target.Action. -func (AcceptTarget) Action(packet tcpip.PacketBuffer) (RuleVerdict, string) { - return RuleAccept, "" +func (AcceptTarget) Action(packet tcpip.PacketBuffer) (RuleVerdict, int) { + return RuleAccept, 0 } // DropTarget drops packets. type DropTarget struct{} // Action implements Target.Action. -func (DropTarget) Action(packet tcpip.PacketBuffer) (RuleVerdict, string) { - return RuleDrop, "" +func (DropTarget) Action(packet tcpip.PacketBuffer) (RuleVerdict, int) { + return RuleDrop, 0 } // ErrorTarget logs an error and drops the packet. It represents a target that @@ -42,9 +40,9 @@ func (DropTarget) Action(packet tcpip.PacketBuffer) (RuleVerdict, string) { type ErrorTarget struct{} // Action implements Target.Action. -func (ErrorTarget) Action(packet tcpip.PacketBuffer) (RuleVerdict, string) { +func (ErrorTarget) Action(packet tcpip.PacketBuffer) (RuleVerdict, int) { log.Debugf("ErrorTarget triggered.") - return RuleDrop, "" + return RuleDrop, 0 } // UserChainTarget marks a rule as the beginning of a user chain. @@ -53,7 +51,7 @@ type UserChainTarget struct { } // Action implements Target.Action. -func (UserChainTarget) Action(tcpip.PacketBuffer) (RuleVerdict, string) { +func (UserChainTarget) Action(tcpip.PacketBuffer) (RuleVerdict, int) { panic("UserChainTarget should never be called.") } @@ -62,6 +60,6 @@ func (UserChainTarget) Action(tcpip.PacketBuffer) (RuleVerdict, string) { type ReturnTarget struct{} // Action implements Target.Action. -func (ReturnTarget) Action(tcpip.PacketBuffer) (RuleVerdict, string) { - return RuleReturn, "" +func (ReturnTarget) Action(tcpip.PacketBuffer) (RuleVerdict, int) { + return RuleReturn, 0 } diff --git a/pkg/tcpip/iptables/types.go b/pkg/tcpip/iptables/types.go index 5735d001b..7d032fd23 100644 --- a/pkg/tcpip/iptables/types.go +++ b/pkg/tcpip/iptables/types.go @@ -56,17 +56,6 @@ const ( NumHooks ) -// A TableVerdict is what a table decides should be done with a packet. -type TableVerdict int - -const ( - // TableAccept indicates the packet should continue through netstack. - TableAccept TableVerdict = iota - - // TableAccept indicates the packet should be dropped. - TableDrop -) - // A RuleVerdict is what a rule decides should be done with a packet. type RuleVerdict int @@ -74,12 +63,12 @@ const ( // RuleAccept indicates the packet should continue through netstack. RuleAccept RuleVerdict = iota - // RuleContinue indicates the packet should continue to the next rule. - RuleContinue - // RuleDrop indicates the packet should be dropped. RuleDrop + // RuleJump indicates the packet should jump to another chain. + RuleJump + // RuleReturn indicates the packet should return to the previous chain. RuleReturn ) @@ -174,6 +163,6 @@ type Matcher interface { type Target interface { // Action takes an action on the packet and returns a verdict on how // traversal should (or should not) continue. If the return value is - // Jump, it also returns the name of the chain to jump to. - Action(packet tcpip.PacketBuffer) (RuleVerdict, string) + // Jump, it also returns the index of the rule to jump to. + Action(packet tcpip.PacketBuffer) (RuleVerdict, int) } diff --git a/test/iptables/filter_input.go b/test/iptables/filter_input.go index e26d6a7d2..706c09cea 100644 --- a/test/iptables/filter_input.go +++ b/test/iptables/filter_input.go @@ -26,6 +26,7 @@ const ( acceptPort = 2402 sendloopDuration = 2 * time.Second network = "udp4" + chainName = "foochain" ) func init() { @@ -40,6 +41,12 @@ func init() { RegisterTestCase(FilterInputDefaultPolicyAccept{}) RegisterTestCase(FilterInputDefaultPolicyDrop{}) RegisterTestCase(FilterInputReturnUnderflow{}) + RegisterTestCase(FilterInputSerializeJump{}) + RegisterTestCase(FilterInputJumpBasic{}) + RegisterTestCase(FilterInputJumpReturn{}) + RegisterTestCase(FilterInputJumpReturnDrop{}) + RegisterTestCase(FilterInputJumpBuiltin{}) + RegisterTestCase(FilterInputJumpTwice{}) } // FilterInputDropUDP tests that we can drop UDP traffic. @@ -267,13 +274,12 @@ func (FilterInputMultiUDPRules) Name() string { // ContainerAction implements TestCase.ContainerAction. func (FilterInputMultiUDPRules) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil { - return err - } - if err := filterTable("-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", acceptPort), "-j", "ACCEPT"); err != nil { - return err + rules := [][]string{ + {"-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"}, + {"-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", acceptPort), "-j", "ACCEPT"}, + {"-L"}, } - return filterTable("-L") + return filterTableRules(rules) } // LocalAction implements TestCase.LocalAction. @@ -314,14 +320,13 @@ func (FilterInputCreateUserChain) Name() string { // ContainerAction implements TestCase.ContainerAction. func (FilterInputCreateUserChain) ContainerAction(ip net.IP) error { - // Create a chain. - const chainName = "foochain" - if err := filterTable("-N", chainName); err != nil { - return err + rules := [][]string{ + // Create a chain. + {"-N", chainName}, + // Add a simple rule to the chain. + {"-A", chainName, "-j", "DROP"}, } - - // Add a simple rule to the chain. - return filterTable("-A", chainName, "-j", "DROP") + return filterTableRules(rules) } // LocalAction implements TestCase.LocalAction. @@ -396,13 +401,12 @@ func (FilterInputReturnUnderflow) Name() string { func (FilterInputReturnUnderflow) ContainerAction(ip net.IP) error { // Add a RETURN rule followed by an unconditional accept, and set the // default policy to DROP. - if err := filterTable("-A", "INPUT", "-j", "RETURN"); err != nil { - return err + rules := [][]string{ + {"-A", "INPUT", "-j", "RETURN"}, + {"-A", "INPUT", "-j", "DROP"}, + {"-P", "INPUT", "ACCEPT"}, } - if err := filterTable("-A", "INPUT", "-j", "DROP"); err != nil { - return err - } - if err := filterTable("-P", "INPUT", "ACCEPT"); err != nil { + if err := filterTableRules(rules); err != nil { return err } @@ -415,3 +419,178 @@ func (FilterInputReturnUnderflow) ContainerAction(ip net.IP) error { func (FilterInputReturnUnderflow) LocalAction(ip net.IP) error { return sendUDPLoop(ip, acceptPort, sendloopDuration) } + +// Verify that we can serialize jumps. +type FilterInputSerializeJump struct{} + +// Name implements TestCase.Name. +func (FilterInputSerializeJump) Name() string { + return "FilterInputSerializeJump" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputSerializeJump) ContainerAction(ip net.IP) error { + // Write a JUMP rule, the serialize it with `-L`. + rules := [][]string{ + {"-N", chainName}, + {"-A", "INPUT", "-j", chainName}, + {"-L"}, + } + return filterTableRules(rules) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputSerializeJump) LocalAction(ip net.IP) error { + // No-op. + return nil +} + +// Jump to a chain and execute a rule there. +type FilterInputJumpBasic struct{} + +// Name implements TestCase.Name. +func (FilterInputJumpBasic) Name() string { + return "FilterInputJumpBasic" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputJumpBasic) ContainerAction(ip net.IP) error { + rules := [][]string{ + {"-P", "INPUT", "DROP"}, + {"-N", chainName}, + {"-A", "INPUT", "-j", chainName}, + {"-A", chainName, "-j", "ACCEPT"}, + } + if err := filterTableRules(rules); err != nil { + return err + } + + // Listen for UDP packets on acceptPort. + return listenUDP(acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputJumpBasic) LocalAction(ip net.IP) error { + return sendUDPLoop(ip, acceptPort, sendloopDuration) +} + +// Jump, return, and execute a rule. +type FilterInputJumpReturn struct{} + +// Name implements TestCase.Name. +func (FilterInputJumpReturn) Name() string { + return "FilterInputJumpReturn" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputJumpReturn) ContainerAction(ip net.IP) error { + rules := [][]string{ + {"-N", chainName}, + {"-P", "INPUT", "ACCEPT"}, + {"-A", "INPUT", "-j", chainName}, + {"-A", chainName, "-j", "RETURN"}, + {"-A", chainName, "-j", "DROP"}, + } + if err := filterTableRules(rules); err != nil { + return err + } + + // Listen for UDP packets on acceptPort. + return listenUDP(acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputJumpReturn) LocalAction(ip net.IP) error { + return sendUDPLoop(ip, acceptPort, sendloopDuration) +} + +type FilterInputJumpReturnDrop struct{} + +// Name implements TestCase.Name. +func (FilterInputJumpReturnDrop) Name() string { + return "FilterInputJumpReturnDrop" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputJumpReturnDrop) ContainerAction(ip net.IP) error { + rules := [][]string{ + {"-N", chainName}, + {"-A", "INPUT", "-j", chainName}, + {"-A", "INPUT", "-j", "DROP"}, + {"-A", chainName, "-j", "RETURN"}, + } + if err := filterTableRules(rules); err != nil { + return err + } + + // Listen for UDP packets on dropPort. + if err := listenUDP(dropPort, sendloopDuration); err == nil { + return fmt.Errorf("packets on port %d should have been dropped, but got a packet", dropPort) + } else if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() { + return fmt.Errorf("error reading: %v", err) + } + + // At this point we know that reading timed out and never received a + // packet. + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputJumpReturnDrop) LocalAction(ip net.IP) error { + return sendUDPLoop(ip, dropPort, sendloopDuration) +} + +// Jumping to a top-levl chain is illegal. +type FilterInputJumpBuiltin struct{} + +// Name implements TestCase.Name. +func (FilterInputJumpBuiltin) Name() string { + return "FilterInputJumpBuiltin" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputJumpBuiltin) ContainerAction(ip net.IP) error { + if err := filterTable("-A", "INPUT", "-j", "OUTPUT"); err == nil { + return fmt.Errorf("iptables should be unable to jump to a built-in chain") + } + return nil +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputJumpBuiltin) LocalAction(ip net.IP) error { + // No-op. + return nil +} + +// Jump twice, then return twice and execute a rule. +type FilterInputJumpTwice struct{} + +// Name implements TestCase.Name. +func (FilterInputJumpTwice) Name() string { + return "FilterInputJumpTwice" +} + +// ContainerAction implements TestCase.ContainerAction. +func (FilterInputJumpTwice) ContainerAction(ip net.IP) error { + const chainName2 = chainName + "2" + rules := [][]string{ + {"-P", "INPUT", "DROP"}, + {"-N", chainName}, + {"-N", chainName2}, + {"-A", "INPUT", "-j", chainName}, + {"-A", chainName, "-j", chainName2}, + {"-A", "INPUT", "-j", "ACCEPT"}, + } + if err := filterTableRules(rules); err != nil { + return err + } + + // UDP packets should jump and return twice, eventually hitting the + // ACCEPT rule. + return listenUDP(acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (FilterInputJumpTwice) LocalAction(ip net.IP) error { + return sendUDPLoop(ip, acceptPort, sendloopDuration) +} diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go index 46a7c99b0..0621861eb 100644 --- a/test/iptables/iptables_test.go +++ b/test/iptables/iptables_test.go @@ -249,3 +249,39 @@ func TestFilterOutputDropTCPSrcPort(t *testing.T) { t.Fatal(err) } } + +func TestJumpSerialize(t *testing.T) { + if err := singleTest(FilterInputSerializeJump{}); err != nil { + t.Fatal(err) + } +} + +func TestJumpBasic(t *testing.T) { + if err := singleTest(FilterInputJumpBasic{}); err != nil { + t.Fatal(err) + } +} + +func TestJumpReturn(t *testing.T) { + if err := singleTest(FilterInputJumpReturn{}); err != nil { + t.Fatal(err) + } +} + +func TestJumpReturnDrop(t *testing.T) { + if err := singleTest(FilterInputJumpReturnDrop{}); err != nil { + t.Fatal(err) + } +} + +func TestJumpBuiltin(t *testing.T) { + if err := singleTest(FilterInputJumpBuiltin{}); err != nil { + t.Fatal(err) + } +} + +func TestJumpTwice(t *testing.T) { + if err := singleTest(FilterInputJumpTwice{}); err != nil { + t.Fatal(err) + } +} diff --git a/test/iptables/iptables_util.go b/test/iptables/iptables_util.go index 043114c78..293c4e6ed 100644 --- a/test/iptables/iptables_util.go +++ b/test/iptables/iptables_util.go @@ -35,6 +35,16 @@ func filterTable(args ...string) error { return nil } +// filterTableRules is like filterTable, but runs multiple iptables commands. +func filterTableRules(argsList [][]string) error { + for _, args := range argsList { + if err := filterTable(args...); err != nil { + return err + } + } + return nil +} + // listenUDP listens on a UDP port and returns the value of net.Conn.Read() for // the first read on that port. func listenUDP(port int, timeout time.Duration) error { -- cgit v1.2.3 From 56fd9504aab44a738d3df164cbee8e572b309f28 Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Tue, 18 Feb 2020 15:44:22 -0800 Subject: Enable IPV6_RECVTCLASS socket option for datagram sockets Added the ability to get/set the IP_RECVTCLASS socket option on UDP endpoints. If enabled, traffic class from the incoming Network Header passed as ancillary data in the ControlMessages. Adding Get/SetSockOptBool to decrease the overhead of getting/setting simple options. (This was absorbed in a CL that will be landing before this one). Test: * Added unit test to udp_test.go that tests getting/setting as well as verifying that we receive expected TOS from incoming packet. * Added a syscall test for verifying getting/setting * Removed test skip for existing syscall test to enable end to end test. PiperOrigin-RevId: 295840218 --- pkg/sentry/socket/control/control.go | 2 +- pkg/sentry/socket/netstack/netstack.go | 27 +++++- pkg/tcpip/checker/checker.go | 14 +++ pkg/tcpip/tcpip.go | 15 ++- pkg/tcpip/transport/udp/endpoint.go | 38 +++++++- pkg/tcpip/transport/udp/udp_test.go | 120 ++++++++++++++---------- test/syscalls/linux/ip_socket_test_util.h | 16 ++-- test/syscalls/linux/socket_ip_udp_generic.cc | 133 +++++++++++++++++++-------- test/syscalls/linux/udp_socket_test_cases.cc | 4 - 9 files changed, 260 insertions(+), 109 deletions(-) (limited to 'pkg/tcpip') diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index 4667373d2..8834a1e1a 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -329,7 +329,7 @@ func PackTOS(t *kernel.Task, tos uint8, buf []byte) []byte { } // PackTClass packs an IPV6_TCLASS socket control message. -func PackTClass(t *kernel.Task, tClass int32, buf []byte) []byte { +func PackTClass(t *kernel.Task, tClass uint32, buf []byte) []byte { return putCmsgStruct( buf, linux.SOL_IPV6, diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 9757fbfba..e187276c5 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1318,6 +1318,22 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf } return ib, nil + case linux.IPV6_RECVTCLASS: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v, err := ep.GetSockOptBool(tcpip.ReceiveTClassOption) + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + var o int32 + if v { + o = 1 + } + return o, nil + default: emitUnimplementedEventIPv6(t, name) } @@ -1803,6 +1819,14 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) } return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv6TrafficClassOption(v))) + case linux.IPV6_RECVTCLASS: + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTClassOption, v != 0)) + default: emitUnimplementedEventIPv6(t, name) } @@ -2086,7 +2110,6 @@ func emitUnimplementedEventIPv6(t *kernel.Task, name int) { linux.IPV6_RECVPATHMTU, linux.IPV6_RECVPKTINFO, linux.IPV6_RECVRTHDR, - linux.IPV6_RECVTCLASS, linux.IPV6_RTHDR, linux.IPV6_RTHDRDSTOPTS, linux.IPV6_TCLASS, @@ -2424,6 +2447,8 @@ func (s *SocketOperations) controlMessages() socket.ControlMessages { Timestamp: s.readCM.Timestamp, HasTOS: s.readCM.HasTOS, TOS: s.readCM.TOS, + HasTClass: s.readCM.HasTClass, + TClass: s.readCM.TClass, HasIPPacketInfo: s.readCM.HasIPPacketInfo, PacketInfo: s.readCM.PacketInfo, }, diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 4d6ae0871..c6c160dfc 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -161,6 +161,20 @@ func FragmentFlags(flags uint8) NetworkChecker { } } +// ReceiveTClass creates a checker that checks the TCLASS field in +// ControlMessages. +func ReceiveTClass(want uint32) ControlMessagesChecker { + return func(t *testing.T, cm tcpip.ControlMessages) { + t.Helper() + if !cm.HasTClass { + t.Fatalf("got cm.HasTClass = %t, want cm.TClass = %d", cm.HasTClass, want) + } + if got := cm.TClass; got != want { + t.Fatalf("got cm.TClass = %d, want %d", got, want) + } + } +} + // ReceiveTOS creates a checker that checks the TOS field in ControlMessages. func ReceiveTOS(want uint8) ControlMessagesChecker { return func(t *testing.T, cm tcpip.ControlMessages) { diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 9ca39ce40..ce5527391 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -323,11 +323,11 @@ type ControlMessages struct { // TOS is the IPv4 type of service of the associated packet. TOS uint8 - // HasTClass indicates whether Tclass is valid/set. + // HasTClass indicates whether TClass is valid/set. HasTClass bool - // Tclass is the IPv6 traffic class of the associated packet. - TClass int32 + // TClass is the IPv6 traffic class of the associated packet. + TClass uint32 // HasIPPacketInfo indicates whether PacketInfo is set. HasIPPacketInfo bool @@ -502,9 +502,13 @@ type WriteOptions struct { type SockOptBool int const ( + // ReceiveTClassOption is used by SetSockOpt/GetSockOpt to specify if the + // IPV6_TCLASS ancillary message is passed with incoming packets. + ReceiveTClassOption SockOptBool = iota + // ReceiveTOSOption is used by SetSockOpt/GetSockOpt to specify if the TOS // ancillary message is passed with incoming packets. - ReceiveTOSOption SockOptBool = iota + ReceiveTOSOption // V6OnlyOption is used by {G,S}etSockOptBool to specify whether an IPv6 // socket is to be restricted to sending and receiving IPv6 packets only. @@ -514,6 +518,9 @@ const ( // if more inforamtion is provided with incoming packets such // as interface index and address. ReceiveIPPacketInfoOption + + // TODO(b/146901447): convert existing bool socket options to be handled via + // Get/SetSockOptBool ) // SockOptInt represents socket options which values have the int type. diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 3fe91cac2..eff7f3600 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -32,7 +32,8 @@ type udpPacket struct { packetInfo tcpip.IPPacketInfo data buffer.VectorisedView `state:".(buffer.VectorisedView)"` timestamp int64 - tos uint8 + // tos stores either the receiveTOS or receiveTClass value. + tos uint8 } // EndpointState represents the state of a UDP endpoint. @@ -119,6 +120,10 @@ type endpoint struct { // as ancillary data to ControlMessages on Read. receiveTOS bool + // receiveTClass determines if the incoming IPv6 TClass header field is + // passed as ancillary data to ControlMessages on Read. + receiveTClass bool + // receiveIPPacketInfo determines if the packet info is returned by Read. receiveIPPacketInfo bool @@ -258,13 +263,18 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess } e.mu.RLock() receiveTOS := e.receiveTOS + receiveTClass := e.receiveTClass receiveIPPacketInfo := e.receiveIPPacketInfo e.mu.RUnlock() if receiveTOS { cm.HasTOS = true cm.TOS = p.tos } - + if receiveTClass { + cm.HasTClass = true + // Although TClass is an 8-bit value it's read in the CMsg as a uint32. + cm.TClass = uint32(p.tos) + } if receiveIPPacketInfo { cm.HasIPPacketInfo = true cm.PacketInfo = p.packetInfo @@ -490,6 +500,17 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { e.mu.Unlock() return nil + case tcpip.ReceiveTClassOption: + // We only support this option on v6 endpoints. + if e.NetProto != header.IPv6ProtocolNumber { + return tcpip.ErrNotSupported + } + + e.mu.Lock() + e.receiveTClass = v + e.mu.Unlock() + return nil + case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. if e.NetProto != header.IPv6ProtocolNumber { @@ -709,6 +730,17 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { e.mu.RUnlock() return v, nil + case tcpip.ReceiveTClassOption: + // We only support this option on v6 endpoints. + if e.NetProto != header.IPv6ProtocolNumber { + return false, tcpip.ErrNotSupported + } + + e.mu.RLock() + v := e.receiveTClass + e.mu.RUnlock() + return v, nil + case tcpip.V6OnlyOption: // We only recognize this option on v6 endpoints. if e.NetProto != header.IPv6ProtocolNumber { @@ -1273,6 +1305,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk packet.packetInfo.LocalAddr = r.LocalAddress packet.packetInfo.DestinationAddr = r.RemoteAddress packet.packetInfo.NIC = r.NICID() + case header.IPv6ProtocolNumber: + packet.tos, _ = header.IPv6(pkt.NetworkHeader).TOS() } packet.timestamp = e.stack.NowNanoseconds() diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index f0ff3fe71..34b7c2360 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -409,6 +409,7 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool // Initialize the IP header. ip := header.IPv6(buf) ip.Encode(&header.IPv6Fields{ + TrafficClass: testTOS, PayloadLength: uint16(header.UDPMinimumSize + len(payload)), NextHeader: uint8(udp.ProtocolNumber), HopLimit: 65, @@ -1336,7 +1337,7 @@ func TestSetTTL(t *testing.T) { } } -func TestTOSV4(t *testing.T) { +func TestSetTOS(t *testing.T) { for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { c := newDualTestContext(t, defaultMTU) @@ -1347,23 +1348,23 @@ func TestTOSV4(t *testing.T) { const tos = testTOS var v tcpip.IPv4TOSOption if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt failed: %s", err) + c.t.Errorf("GetSockopt(%T) failed: %s", v, err) } // Test for expected default value. if v != 0 { - c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, 0) + c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, 0) } if err := c.ep.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil { - c.t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv4TOSOption(tos), err) + c.t.Errorf("SetSockOpt(%T, 0x%x) failed: %s", v, tcpip.IPv4TOSOption(tos), err) } if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt failed: %s", err) + c.t.Errorf("GetSockopt(%T) failed: %s", v, err) } if want := tcpip.IPv4TOSOption(tos); v != want { - c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want) + c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, want) } testWrite(c, flow, checker.TOS(tos, 0)) @@ -1371,7 +1372,7 @@ func TestTOSV4(t *testing.T) { } } -func TestTOSV6(t *testing.T) { +func TestSetTClass(t *testing.T) { for _, flow := range []testFlow{unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, broadcastIn6} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { c := newDualTestContext(t, defaultMTU) @@ -1379,71 +1380,92 @@ func TestTOSV6(t *testing.T) { c.createEndpointForFlow(flow) - const tos = testTOS + const tClass = testTOS var v tcpip.IPv6TrafficClassOption if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt failed: %s", err) + c.t.Errorf("GetSockopt(%T) failed: %s", v, err) } // Test for expected default value. if v != 0 { - c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, 0) + c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, 0) } - if err := c.ep.SetSockOpt(tcpip.IPv6TrafficClassOption(tos)); err != nil { - c.t.Errorf("SetSockOpt failed: %s", err) + if err := c.ep.SetSockOpt(tcpip.IPv6TrafficClassOption(tClass)); err != nil { + c.t.Errorf("SetSockOpt(%T, 0x%x) failed: %s", v, tcpip.IPv6TrafficClassOption(tClass), err) } if err := c.ep.GetSockOpt(&v); err != nil { - c.t.Errorf("GetSockopt failed: %s", err) + c.t.Errorf("GetSockopt(%T) failed: %s", v, err) } - if want := tcpip.IPv6TrafficClassOption(tos); v != want { - c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want) + if want := tcpip.IPv6TrafficClassOption(tClass); v != want { + c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, want) } - testWrite(c, flow, checker.TOS(tos, 0)) + // The header getter for TClass is called TOS, so use that checker. + testWrite(c, flow, checker.TOS(tClass, 0)) }) } } -func TestReceiveTOSV4(t *testing.T) { - for _, flow := range []testFlow{unicastV4, broadcast} { - t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() +func TestReceiveTosTClass(t *testing.T) { + testCases := []struct { + name string + getReceiveOption tcpip.SockOptBool + tests []testFlow + }{ + {"ReceiveTosOption", tcpip.ReceiveTOSOption, []testFlow{unicastV4, broadcast}}, + {"ReceiveTClassOption", tcpip.ReceiveTClassOption, []testFlow{unicastV4in6, unicastV6, unicastV6Only, broadcastIn6}}, + } + for _, testCase := range testCases { + for _, flow := range testCase.tests { + t.Run(fmt.Sprintf("%s:flow:%s", testCase.name, flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() - c.createEndpointForFlow(flow) + c.createEndpointForFlow(flow) + option := testCase.getReceiveOption + name := testCase.name - // Verify that setting and reading the option works. - v, err := c.ep.GetSockOptBool(tcpip.ReceiveTOSOption) - if err != nil { - c.t.Fatal("GetSockOptBool(tcpip.ReceiveTOSOption) failed:", err) - } - // Test for expected default value. - if v != false { - c.t.Errorf("got GetSockOptBool(tcpip.ReceiveTOSOption) = %t, want = %t", v, false) - } + // Verify that setting and reading the option works. + v, err := c.ep.GetSockOptBool(option) + if err != nil { + c.t.Errorf("GetSockoptBool(%s) failed: %s", name, err) + } + // Test for expected default value. + if v != false { + c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, v, false) + } - want := true - if err := c.ep.SetSockOptBool(tcpip.ReceiveTOSOption, want); err != nil { - c.t.Fatalf("SetSockOptBool(tcpip.ReceiveTOSOption, %t) failed: %s", want, err) - } + want := true + if err := c.ep.SetSockOptBool(option, want); err != nil { + c.t.Fatalf("SetSockOptBool(%s, %t) failed: %s", name, want, err) + } - got, err := c.ep.GetSockOptBool(tcpip.ReceiveTOSOption) - if err != nil { - c.t.Fatal("GetSockOptBool(tcpip.ReceiveTOSOption) failed:", err) - } - if got != want { - c.t.Fatalf("got GetSockOptBool(tcpip.ReceiveTOSOption) = %t, want = %t", got, want) - } + got, err := c.ep.GetSockOptBool(option) + if err != nil { + c.t.Errorf("GetSockoptBool(%s) failed: %s", name, err) + } - // Verify that the correct received TOS is handed through as - // ancillary data to the ControlMessages struct. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatal("Bind failed:", err) - } - testRead(c, flow, checker.ReceiveTOS(testTOS)) - }) + if got != want { + c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, got, want) + } + + // Verify that the correct received TOS or TClass is handed through as + // ancillary data to the ControlMessages struct. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + switch option { + case tcpip.ReceiveTClassOption: + testRead(c, flow, checker.ReceiveTClass(testTOS)) + case tcpip.ReceiveTOSOption: + testRead(c, flow, checker.ReceiveTOS(testTOS)) + default: + t.Fatalf("unknown test variant: %s", name) + } + }) + } } } diff --git a/test/syscalls/linux/ip_socket_test_util.h b/test/syscalls/linux/ip_socket_test_util.h index 083ebbcf0..39fd6709d 100644 --- a/test/syscalls/linux/ip_socket_test_util.h +++ b/test/syscalls/linux/ip_socket_test_util.h @@ -84,20 +84,20 @@ SocketPairKind DualStackUDPBidirectionalBindSocketPair(int type); // SocketPairs created with AF_INET and the given type. SocketPairKind IPv4UDPUnboundSocketPair(int type); -// IPv4UDPUnboundSocketPair returns a SocketKind that represents -// a SimpleSocket created with AF_INET, SOCK_DGRAM, and the given type. +// IPv4UDPUnboundSocket returns a SocketKind that represents a SimpleSocket +// created with AF_INET, SOCK_DGRAM, and the given type. SocketKind IPv4UDPUnboundSocket(int type); -// IPv6UDPUnboundSocketPair returns a SocketKind that represents -// a SimpleSocket created with AF_INET6, SOCK_DGRAM, and the given type. +// IPv6UDPUnboundSocket returns a SocketKind that represents a SimpleSocket +// created with AF_INET6, SOCK_DGRAM, and the given type. SocketKind IPv6UDPUnboundSocket(int type); -// IPv4TCPUnboundSocketPair returns a SocketKind that represents -// a SimpleSocket created with AF_INET, SOCK_STREAM and the given type. +// IPv4TCPUnboundSocket returns a SocketKind that represents a SimpleSocket +// created with AF_INET, SOCK_STREAM and the given type. SocketKind IPv4TCPUnboundSocket(int type); -// IPv6TCPUnboundSocketPair returns a SocketKind that represents -// a SimpleSocket created with AF_INET6, SOCK_STREAM and the given type. +// IPv6TCPUnboundSocket returns a SocketKind that represents a SimpleSocket +// created with AF_INET6, SOCK_STREAM and the given type. SocketKind IPv6TCPUnboundSocket(int type); // IfAddrHelper is a helper class that determines the local interfaces present diff --git a/test/syscalls/linux/socket_ip_udp_generic.cc b/test/syscalls/linux/socket_ip_udp_generic.cc index db5663ecd..1c533fdf2 100644 --- a/test/syscalls/linux/socket_ip_udp_generic.cc +++ b/test/syscalls/linux/socket_ip_udp_generic.cc @@ -14,6 +14,7 @@ #include "test/syscalls/linux/socket_ip_udp_generic.h" +#include #include #include #include @@ -209,46 +210,6 @@ TEST_P(UDPSocketPairTest, SetMulticastLoopChar) { EXPECT_EQ(get, kSockOptOn); } -// Ensure that Receiving TOS is off by default. -TEST_P(UDPSocketPairTest, RecvTosDefault) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - int get = -1; - socklen_t get_len = sizeof(get); - ASSERT_THAT( - getsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOff); -} - -// Test that setting and getting IP_RECVTOS works as expected. -TEST_P(UDPSocketPairTest, SetRecvTos) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, - &kSockOptOff, sizeof(kSockOptOff)), - SyscallSucceeds()); - - int get = -1; - socklen_t get_len = sizeof(get); - ASSERT_THAT( - getsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOff); - - ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, - &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); - - ASSERT_THAT( - getsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, &get, &get_len), - SyscallSucceedsWithValue(0)); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kSockOptOn); -} - TEST_P(UDPSocketPairTest, ReuseAddrDefault) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); @@ -401,5 +362,97 @@ TEST_P(UDPSocketPairTest, SetAndGetIPPKTINFO) { EXPECT_EQ(get_len, sizeof(get)); } +// Holds TOS or TClass information for IPv4 or IPv6 respectively. +struct RecvTosOption { + int level; + int option; +}; + +RecvTosOption GetRecvTosOption(int domain) { + TEST_CHECK(domain == AF_INET || domain == AF_INET6); + RecvTosOption opt; + switch (domain) { + case AF_INET: + opt.level = IPPROTO_IP; + opt.option = IP_RECVTOS; + break; + case AF_INET6: + opt.level = IPPROTO_IPV6; + opt.option = IPV6_RECVTCLASS; + break; + } + return opt; +} + +// Ensure that Receiving TOS or TCLASS is off by default. +TEST_P(UDPSocketPairTest, RecvTosDefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + RecvTosOption t = GetRecvTosOption(GetParam().domain); + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(sockets->first_fd(), t.level, t.option, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOff); +} + +// Test that setting and getting IP_RECVTOS or IPV6_RECVTCLASS works as +// expected. +TEST_P(UDPSocketPairTest, SetRecvTos) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + RecvTosOption t = GetRecvTosOption(GetParam().domain); + + ASSERT_THAT(setsockopt(sockets->first_fd(), t.level, t.option, &kSockOptOff, + sizeof(kSockOptOff)), + SyscallSucceeds()); + + int get = -1; + socklen_t get_len = sizeof(get); + ASSERT_THAT( + getsockopt(sockets->first_fd(), t.level, t.option, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOff); + + ASSERT_THAT(setsockopt(sockets->first_fd(), t.level, t.option, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + ASSERT_THAT( + getsockopt(sockets->first_fd(), t.level, t.option, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOn); +} + +// Test that any socket (including IPv6 only) accepts the IPv4 TOS option: this +// mirrors behavior in linux. +TEST_P(UDPSocketPairTest, TOSRecvMismatch) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + RecvTosOption t = GetRecvTosOption(AF_INET); + int get = -1; + socklen_t get_len = sizeof(get); + + ASSERT_THAT( + getsockopt(sockets->first_fd(), t.level, t.option, &get, &get_len), + SyscallSucceedsWithValue(0)); +} + +// Test that an IPv4 socket does not support the IPv6 TClass option. +TEST_P(UDPSocketPairTest, TClassRecvMismatch) { + // This should only test AF_INET sockets for the mismatch behavior. + SKIP_IF(GetParam().domain != AF_INET); + + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int get = -1; + socklen_t get_len = sizeof(get); + + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IPV6, IPV6_RECVTCLASS, + &get, &get_len), + SyscallFailsWithErrno(EOPNOTSUPP)); +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/udp_socket_test_cases.cc b/test/syscalls/linux/udp_socket_test_cases.cc index 9f8de6b48..57b1a357c 100644 --- a/test/syscalls/linux/udp_socket_test_cases.cc +++ b/test/syscalls/linux/udp_socket_test_cases.cc @@ -1349,9 +1349,6 @@ TEST_P(UdpSocketTest, TimestampIoctlPersistence) { // outgoing packets, and that a receiving socket with IP_RECVTOS or // IPV6_RECVTCLASS will create the corresponding control message. TEST_P(UdpSocketTest, SetAndReceiveTOS) { - // TODO(b/144868438): IPV6_RECVTCLASS not supported for netstack. - SKIP_IF((GetParam() != AddressFamily::kIpv4) && IsRunningOnGvisor() && - !IsRunningWithHostinet()); ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); @@ -1422,7 +1419,6 @@ TEST_P(UdpSocketTest, SetAndReceiveTOS) { // TOS byte on outgoing packets, and that a receiving socket with IP_RECVTOS or // IPV6_RECVTCLASS will create the corresponding control message. TEST_P(UdpSocketTest, SendAndReceiveTOS) { - // TODO(b/144868438): IPV6_RECVTCLASS not supported for netstack. // TODO(b/146661005): Setting TOS via cmsg not supported for netstack. SKIP_IF(IsRunningOnGvisor() && !IsRunningWithHostinet()); ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); -- cgit v1.2.3 From 92d2d78876a938871327685e1104d7b4ff46986e Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Tue, 18 Feb 2020 21:20:41 -0800 Subject: Fix mis-named comment. --- pkg/tcpip/stack/registration.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'pkg/tcpip') diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index ec91f60dd..d83adf0ec 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -277,7 +277,7 @@ type NetworkProtocol interface { // DefaultPrefixLen returns the protocol's default prefix length. DefaultPrefixLen() int - // ParsePorts returns the source and destination addresses stored in a + // ParseAddresses returns the source and destination addresses stored in a // packet of this protocol. ParseAddresses(v buffer.View) (src, dst tcpip.Address) -- cgit v1.2.3 From 67b615b86f2aa1d4ded3dcf2eb8aca4e7fec57a0 Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Thu, 20 Feb 2020 14:31:39 -0800 Subject: Support disabling a NIC - Disabled NICs will have their associated NDP state cleared. - Disabled NICs will not accept incoming packets. - Writes through a Route with a disabled NIC will return an invalid endpoint state error. - stack.Stack.FindRoute will not return a route with a disabled NIC. - NIC's Running flag will report the NIC's enabled status. Tests: - stack_test.TestDisableUnknownNIC - stack_test.TestDisabledNICsNICInfoAndCheckNIC - stack_test.TestRoutesWithDisabledNIC - stack_test.TestRouteWritePacketWithDisabledNIC - stack_test.TestStopStartSolicitingRouters - stack_test.TestCleanupNDPState - stack_test.TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable - stack_test.TestJoinLeaveAllNodesMulticastOnNICEnableDisable PiperOrigin-RevId: 296298588 --- pkg/tcpip/stack/ndp.go | 23 +- pkg/tcpip/stack/ndp_test.go | 822 ++++++++++++++++++++++++------------------ pkg/tcpip/stack/nic.go | 110 +++++- pkg/tcpip/stack/stack.go | 45 ++- pkg/tcpip/stack/stack_test.go | 377 ++++++++++++++++++- 5 files changed, 998 insertions(+), 379 deletions(-) (limited to 'pkg/tcpip') diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 045409bda..19bd05aa3 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -1148,22 +1148,27 @@ func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bo return true } -// cleanupHostOnlyState cleans up any state that is only useful for hosts. +// cleanupState cleans up ndp's state. // -// cleanupHostOnlyState MUST be called when ndp's NIC is transitioning from a -// host to a router. This function will invalidate all discovered on-link -// prefixes, discovered routers, and auto-generated addresses as routers do not -// normally process Router Advertisements to discover default routers and -// on-link prefixes, and auto-generate addresses via SLAAC. +// If hostOnly is true, then only host-specific state will be cleaned up. +// +// cleanupState MUST be called with hostOnly set to true when ndp's NIC is +// transitioning from a host to a router. This function will invalidate all +// discovered on-link prefixes, discovered routers, and auto-generated +// addresses. +// +// If hostOnly is true, then the link-local auto-generated address will not be +// invalidated as routers are also expected to generate a link-local address. // // The NIC that ndp belongs to MUST be locked. -func (ndp *ndpState) cleanupHostOnlyState() { +func (ndp *ndpState) cleanupState(hostOnly bool) { linkLocalSubnet := header.IPv6LinkLocalPrefix.Subnet() linkLocalAddrs := 0 for addr := range ndp.autoGenAddresses { // RFC 4862 section 5 states that routers are also expected to generate a - // link-local address so we do not invalidate them. - if linkLocalSubnet.Contains(addr) { + // link-local address so we do not invalidate them if we are cleaning up + // host-only state. + if hostOnly && linkLocalSubnet.Contains(addr) { linkLocalAddrs++ continue } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 1f6f77439..f7b75b74e 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -592,70 +592,94 @@ func TestDADFail(t *testing.T) { } } -// TestDADStop tests to make sure that the DAD process stops when an address is -// removed. func TestDADStop(t *testing.T) { const nicID = 1 - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - } - ndpConfigs := stack.NDPConfigurations{ - RetransmitTimer: time.Second, - DupAddrDetectTransmits: 2, - } - opts := stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPDisp: &ndpDisp, - NDPConfigs: ndpConfigs, - } + tests := []struct { + name string + stopFn func(t *testing.T, s *stack.Stack) + }{ + // Tests to make sure that DAD stops when an address is removed. + { + name: "Remove address", + stopFn: func(t *testing.T, s *stack.Stack) { + if err := s.RemoveAddress(nicID, addr1); err != nil { + t.Fatalf("RemoveAddress(%d, %s): %s", nicID, addr1, err) + } + }, + }, - e := channel.New(0, 1280, linkAddr1) - s := stack.New(opts) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + // Tests to make sure that DAD stops when the NIC is disabled. + { + name: "Disable NIC", + stopFn: func(t *testing.T, s *stack.Stack) { + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("DisableNIC(%d): %s", nicID, err) + } + }, + }, } - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 1), + } + ndpConfigs := stack.NDPConfigurations{ + RetransmitTimer: time.Second, + DupAddrDetectTransmits: 2, + } + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPDisp: &ndpDisp, + NDPConfigs: ndpConfigs, + } - // Address should not be considered bound to the NIC yet (DAD ongoing). - addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) - } + e := channel.New(0, 1280, linkAddr1) + s := stack.New(opts) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } - // Remove the address. This should stop DAD. - if err := s.RemoveAddress(nicID, addr1); err != nil { - t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr1, err) - } + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, header.IPv6ProtocolNumber, addr1, err) + } - // Wait for DAD to fail (since the address was removed during DAD). - select { - case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second): - // If we don't get a failure event after the expected resolution - // time + extra 1s buffer, something is wrong. - t.Fatal("timed out waiting for DAD failure") - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr1, false, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) - } - } - addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) - if err != nil { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) - } - if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) - } + // Address should not be considered bound to the NIC yet (DAD ongoing). + addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + } + + test.stopFn(t, s) + + // Wait for DAD to fail (since the address was removed during DAD). + select { + case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second): + // If we don't get a failure event after the expected resolution + // time + extra 1s buffer, something is wrong. + t.Fatal("timed out waiting for DAD failure") + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr1, false, nil); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + } + addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) + } + if want := (tcpip.AddressWithPrefix{}); addr != want { + t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) + } - // Should not have sent more than 1 NS message. - if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got > 1 { - t.Fatalf("got NeighborSolicit = %d, want <= 1", got) + // Should not have sent more than 1 NS message. + if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got > 1 { + t.Errorf("got NeighborSolicit = %d, want <= 1", got) + } + }) } } @@ -2886,17 +2910,16 @@ func TestNDPRecursiveDNSServerDispatch(t *testing.T) { } } -// TestCleanupHostOnlyStateOnBecomingRouter tests that all discovered routers -// and prefixes, and non-linklocal auto-generated addresses are invalidated when -// a NIC becomes a router. -func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) { +// TestCleanupNDPState tests that all discovered routers and prefixes, and +// auto-generated addresses are invalidated when a NIC becomes a router. +func TestCleanupNDPState(t *testing.T) { t.Parallel() const ( - lifetimeSeconds = 5 - maxEvents = 4 - nicID1 = 1 - nicID2 = 2 + lifetimeSeconds = 5 + maxRouterAndPrefixEvents = 4 + nicID1 = 1 + nicID2 = 2 ) prefix1, subnet1, e1Addr1 := prefixSubnetAddr(0, linkAddr1) @@ -2912,254 +2935,308 @@ func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) { PrefixLen: 64, } - ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, maxEvents), - rememberRouter: true, - prefixC: make(chan ndpPrefixEvent, maxEvents), - rememberPrefix: true, - autoGenAddrC: make(chan ndpAutoGenAddrEvent, maxEvents), - } - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - AutoGenIPv6LinkLocal: true, - NDPConfigs: stack.NDPConfigurations{ - HandleRAs: true, - DiscoverDefaultRouters: true, - DiscoverOnLinkPrefixes: true, - AutoGenGlobalAddresses: true, + tests := []struct { + name string + cleanupFn func(t *testing.T, s *stack.Stack) + keepAutoGenLinkLocal bool + maxAutoGenAddrEvents int + }{ + // A NIC should still keep its auto-generated link-local address when + // becoming a router. + { + name: "Forwarding Enable", + cleanupFn: func(t *testing.T, s *stack.Stack) { + t.Helper() + s.SetForwarding(true) + }, + keepAutoGenLinkLocal: true, + maxAutoGenAddrEvents: 4, }, - NDPDisp: &ndpDisp, - }) - expectRouterEvent := func() (bool, ndpRouterEvent) { - select { - case e := <-ndpDisp.routerC: - return true, e - default: - } + // A NIC should cleanup all NDP state when it is disabled. + { + name: "NIC Disable", + cleanupFn: func(t *testing.T, s *stack.Stack) { + t.Helper() - return false, ndpRouterEvent{} + if err := s.DisableNIC(nicID1); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID1, err) + } + if err := s.DisableNIC(nicID2); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID2, err) + } + }, + keepAutoGenLinkLocal: false, + maxAutoGenAddrEvents: 6, + }, } - expectPrefixEvent := func() (bool, ndpPrefixEvent) { - select { - case e := <-ndpDisp.prefixC: - return true, e - default: - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ndpDisp := ndpDispatcher{ + routerC: make(chan ndpRouterEvent, maxRouterAndPrefixEvents), + rememberRouter: true, + prefixC: make(chan ndpPrefixEvent, maxRouterAndPrefixEvents), + rememberPrefix: true, + autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents), + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + AutoGenIPv6LinkLocal: true, + NDPConfigs: stack.NDPConfigurations{ + HandleRAs: true, + DiscoverDefaultRouters: true, + DiscoverOnLinkPrefixes: true, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + }) - return false, ndpPrefixEvent{} - } + expectRouterEvent := func() (bool, ndpRouterEvent) { + select { + case e := <-ndpDisp.routerC: + return true, e + default: + } - expectAutoGenAddrEvent := func() (bool, ndpAutoGenAddrEvent) { - select { - case e := <-ndpDisp.autoGenAddrC: - return true, e - default: - } + return false, ndpRouterEvent{} + } - return false, ndpAutoGenAddrEvent{} - } + expectPrefixEvent := func() (bool, ndpPrefixEvent) { + select { + case e := <-ndpDisp.prefixC: + return true, e + default: + } - e1 := channel.New(0, 1280, linkAddr1) - if err := s.CreateNIC(nicID1, e1); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err) - } - // We have other tests that make sure we receive the *correct* events - // on normal discovery of routers/prefixes, and auto-generated - // addresses. Here we just make sure we get an event and let other tests - // handle the correctness check. - expectAutoGenAddrEvent() + return false, ndpPrefixEvent{} + } - e2 := channel.New(0, 1280, linkAddr2) - if err := s.CreateNIC(nicID2, e2); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err) - } - expectAutoGenAddrEvent() + expectAutoGenAddrEvent := func() (bool, ndpAutoGenAddrEvent) { + select { + case e := <-ndpDisp.autoGenAddrC: + return true, e + default: + } - // Receive RAs on NIC(1) and NIC(2) from default routers (llAddr3 and - // llAddr4) w/ PI (for prefix1 in RA from llAddr3 and prefix2 in RA from - // llAddr4) to discover multiple routers and prefixes, and auto-gen - // multiple addresses. + return false, ndpAutoGenAddrEvent{} + } - e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID1) - } - if ok, _ := expectPrefixEvent(); !ok { - t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID1) - } - if ok, _ := expectAutoGenAddrEvent(); !ok { - t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr1, nicID1) - } + e1 := channel.New(0, 1280, linkAddr1) + if err := s.CreateNIC(nicID1, e1); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err) + } + // We have other tests that make sure we receive the *correct* events + // on normal discovery of routers/prefixes, and auto-generated + // addresses. Here we just make sure we get an event and let other tests + // handle the correctness check. + expectAutoGenAddrEvent() - e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID1) - } - if ok, _ := expectPrefixEvent(); !ok { - t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID1) - } - if ok, _ := expectAutoGenAddrEvent(); !ok { - t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID1) - } + e2 := channel.New(0, 1280, linkAddr2) + if err := s.CreateNIC(nicID2, e2); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err) + } + expectAutoGenAddrEvent() - e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID2) - } - if ok, _ := expectPrefixEvent(); !ok { - t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID2) - } - if ok, _ := expectAutoGenAddrEvent(); !ok { - t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID2) - } + // Receive RAs on NIC(1) and NIC(2) from default routers (llAddr3 and + // llAddr4) w/ PI (for prefix1 in RA from llAddr3 and prefix2 in RA from + // llAddr4) to discover multiple routers and prefixes, and auto-gen + // multiple addresses. - e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID2) - } - if ok, _ := expectPrefixEvent(); !ok { - t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID2) - } - if ok, _ := expectAutoGenAddrEvent(); !ok { - t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e2Addr2, nicID2) - } + e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) + if ok, _ := expectRouterEvent(); !ok { + t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID1) + } + if ok, _ := expectPrefixEvent(); !ok { + t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID1) + } + if ok, _ := expectAutoGenAddrEvent(); !ok { + t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr1, nicID1) + } - // We should have the auto-generated addresses added. - nicinfo := s.NICInfo() - nic1Addrs := nicinfo[nicID1].ProtocolAddresses - nic2Addrs := nicinfo[nicID2].ProtocolAddresses - if !containsV6Addr(nic1Addrs, llAddrWithPrefix1) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) - } - if !containsV6Addr(nic1Addrs, e1Addr1) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) - } - if !containsV6Addr(nic1Addrs, e1Addr2) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) - } - if !containsV6Addr(nic2Addrs, llAddrWithPrefix2) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) - } - if !containsV6Addr(nic2Addrs, e2Addr1) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) - } - if !containsV6Addr(nic2Addrs, e2Addr2) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs) - } + e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) + if ok, _ := expectRouterEvent(); !ok { + t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID1) + } + if ok, _ := expectPrefixEvent(); !ok { + t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID1) + } + if ok, _ := expectAutoGenAddrEvent(); !ok { + t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID1) + } - // We can't proceed any further if we already failed the test (missing - // some discovery/auto-generated address events or addresses). - if t.Failed() { - t.FailNow() - } + e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) + if ok, _ := expectRouterEvent(); !ok { + t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID2) + } + if ok, _ := expectPrefixEvent(); !ok { + t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID2) + } + if ok, _ := expectAutoGenAddrEvent(); !ok { + t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID2) + } - s.SetForwarding(true) + e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) + if ok, _ := expectRouterEvent(); !ok { + t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID2) + } + if ok, _ := expectPrefixEvent(); !ok { + t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID2) + } + if ok, _ := expectAutoGenAddrEvent(); !ok { + t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e2Addr2, nicID2) + } - // Collect invalidation events after becoming a router - gotRouterEvents := make(map[ndpRouterEvent]int) - for i := 0; i < maxEvents; i++ { - ok, e := expectRouterEvent() - if !ok { - t.Errorf("expected %d router events after becoming a router; got = %d", maxEvents, i) - break - } - gotRouterEvents[e]++ - } - gotPrefixEvents := make(map[ndpPrefixEvent]int) - for i := 0; i < maxEvents; i++ { - ok, e := expectPrefixEvent() - if !ok { - t.Errorf("expected %d prefix events after becoming a router; got = %d", maxEvents, i) - break - } - gotPrefixEvents[e]++ - } - gotAutoGenAddrEvents := make(map[ndpAutoGenAddrEvent]int) - for i := 0; i < maxEvents; i++ { - ok, e := expectAutoGenAddrEvent() - if !ok { - t.Errorf("expected %d auto-generated address events after becoming a router; got = %d", maxEvents, i) - break - } - gotAutoGenAddrEvents[e]++ - } + // We should have the auto-generated addresses added. + nicinfo := s.NICInfo() + nic1Addrs := nicinfo[nicID1].ProtocolAddresses + nic2Addrs := nicinfo[nicID2].ProtocolAddresses + if !containsV6Addr(nic1Addrs, llAddrWithPrefix1) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) + } + if !containsV6Addr(nic1Addrs, e1Addr1) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) + } + if !containsV6Addr(nic1Addrs, e1Addr2) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) + } + if !containsV6Addr(nic2Addrs, llAddrWithPrefix2) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) + } + if !containsV6Addr(nic2Addrs, e2Addr1) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) + } + if !containsV6Addr(nic2Addrs, e2Addr2) { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs) + } - // No need to proceed any further if we already failed the test (missing - // some invalidation events). - if t.Failed() { - t.FailNow() - } + // We can't proceed any further if we already failed the test (missing + // some discovery/auto-generated address events or addresses). + if t.Failed() { + t.FailNow() + } - expectedRouterEvents := map[ndpRouterEvent]int{ - {nicID: nicID1, addr: llAddr3, discovered: false}: 1, - {nicID: nicID1, addr: llAddr4, discovered: false}: 1, - {nicID: nicID2, addr: llAddr3, discovered: false}: 1, - {nicID: nicID2, addr: llAddr4, discovered: false}: 1, - } - if diff := cmp.Diff(expectedRouterEvents, gotRouterEvents); diff != "" { - t.Errorf("router events mismatch (-want +got):\n%s", diff) - } - expectedPrefixEvents := map[ndpPrefixEvent]int{ - {nicID: nicID1, prefix: subnet1, discovered: false}: 1, - {nicID: nicID1, prefix: subnet2, discovered: false}: 1, - {nicID: nicID2, prefix: subnet1, discovered: false}: 1, - {nicID: nicID2, prefix: subnet2, discovered: false}: 1, - } - if diff := cmp.Diff(expectedPrefixEvents, gotPrefixEvents); diff != "" { - t.Errorf("prefix events mismatch (-want +got):\n%s", diff) - } - expectedAutoGenAddrEvents := map[ndpAutoGenAddrEvent]int{ - {nicID: nicID1, addr: e1Addr1, eventType: invalidatedAddr}: 1, - {nicID: nicID1, addr: e1Addr2, eventType: invalidatedAddr}: 1, - {nicID: nicID2, addr: e2Addr1, eventType: invalidatedAddr}: 1, - {nicID: nicID2, addr: e2Addr2, eventType: invalidatedAddr}: 1, - } - if diff := cmp.Diff(expectedAutoGenAddrEvents, gotAutoGenAddrEvents); diff != "" { - t.Errorf("auto-generated address events mismatch (-want +got):\n%s", diff) - } + test.cleanupFn(t, s) - // Make sure the auto-generated addresses got removed. - nicinfo = s.NICInfo() - nic1Addrs = nicinfo[nicID1].ProtocolAddresses - nic2Addrs = nicinfo[nicID2].ProtocolAddresses - if !containsV6Addr(nic1Addrs, llAddrWithPrefix1) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) - } - if containsV6Addr(nic1Addrs, e1Addr1) { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) - } - if containsV6Addr(nic1Addrs, e1Addr2) { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) - } - if !containsV6Addr(nic2Addrs, llAddrWithPrefix2) { - t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) - } - if containsV6Addr(nic2Addrs, e2Addr1) { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) - } - if containsV6Addr(nic2Addrs, e2Addr2) { - t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs) - } + // Collect invalidation events after having NDP state cleaned up. + gotRouterEvents := make(map[ndpRouterEvent]int) + for i := 0; i < maxRouterAndPrefixEvents; i++ { + ok, e := expectRouterEvent() + if !ok { + t.Errorf("expected %d router events after becoming a router; got = %d", maxRouterAndPrefixEvents, i) + break + } + gotRouterEvents[e]++ + } + gotPrefixEvents := make(map[ndpPrefixEvent]int) + for i := 0; i < maxRouterAndPrefixEvents; i++ { + ok, e := expectPrefixEvent() + if !ok { + t.Errorf("expected %d prefix events after becoming a router; got = %d", maxRouterAndPrefixEvents, i) + break + } + gotPrefixEvents[e]++ + } + gotAutoGenAddrEvents := make(map[ndpAutoGenAddrEvent]int) + for i := 0; i < test.maxAutoGenAddrEvents; i++ { + ok, e := expectAutoGenAddrEvent() + if !ok { + t.Errorf("expected %d auto-generated address events after becoming a router; got = %d", test.maxAutoGenAddrEvents, i) + break + } + gotAutoGenAddrEvents[e]++ + } - // Should not get any more events (invalidation timers should have been - // cancelled when we transitioned into a router). - time.Sleep(lifetimeSeconds*time.Second + defaultTimeout) - select { - case <-ndpDisp.routerC: - t.Error("unexpected router event") - default: - } - select { - case <-ndpDisp.prefixC: - t.Error("unexpected prefix event") - default: - } - select { - case <-ndpDisp.autoGenAddrC: - t.Error("unexpected auto-generated address event") - default: + // No need to proceed any further if we already failed the test (missing + // some invalidation events). + if t.Failed() { + t.FailNow() + } + + expectedRouterEvents := map[ndpRouterEvent]int{ + {nicID: nicID1, addr: llAddr3, discovered: false}: 1, + {nicID: nicID1, addr: llAddr4, discovered: false}: 1, + {nicID: nicID2, addr: llAddr3, discovered: false}: 1, + {nicID: nicID2, addr: llAddr4, discovered: false}: 1, + } + if diff := cmp.Diff(expectedRouterEvents, gotRouterEvents); diff != "" { + t.Errorf("router events mismatch (-want +got):\n%s", diff) + } + expectedPrefixEvents := map[ndpPrefixEvent]int{ + {nicID: nicID1, prefix: subnet1, discovered: false}: 1, + {nicID: nicID1, prefix: subnet2, discovered: false}: 1, + {nicID: nicID2, prefix: subnet1, discovered: false}: 1, + {nicID: nicID2, prefix: subnet2, discovered: false}: 1, + } + if diff := cmp.Diff(expectedPrefixEvents, gotPrefixEvents); diff != "" { + t.Errorf("prefix events mismatch (-want +got):\n%s", diff) + } + expectedAutoGenAddrEvents := map[ndpAutoGenAddrEvent]int{ + {nicID: nicID1, addr: e1Addr1, eventType: invalidatedAddr}: 1, + {nicID: nicID1, addr: e1Addr2, eventType: invalidatedAddr}: 1, + {nicID: nicID2, addr: e2Addr1, eventType: invalidatedAddr}: 1, + {nicID: nicID2, addr: e2Addr2, eventType: invalidatedAddr}: 1, + } + + if !test.keepAutoGenLinkLocal { + expectedAutoGenAddrEvents[ndpAutoGenAddrEvent{nicID: nicID1, addr: llAddrWithPrefix1, eventType: invalidatedAddr}] = 1 + expectedAutoGenAddrEvents[ndpAutoGenAddrEvent{nicID: nicID2, addr: llAddrWithPrefix2, eventType: invalidatedAddr}] = 1 + } + + if diff := cmp.Diff(expectedAutoGenAddrEvents, gotAutoGenAddrEvents); diff != "" { + t.Errorf("auto-generated address events mismatch (-want +got):\n%s", diff) + } + + // Make sure the auto-generated addresses got removed. + nicinfo = s.NICInfo() + nic1Addrs = nicinfo[nicID1].ProtocolAddresses + nic2Addrs = nicinfo[nicID2].ProtocolAddresses + if containsV6Addr(nic1Addrs, llAddrWithPrefix1) != test.keepAutoGenLinkLocal { + if test.keepAutoGenLinkLocal { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) + } else { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs) + } + } + if containsV6Addr(nic1Addrs, e1Addr1) { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs) + } + if containsV6Addr(nic1Addrs, e1Addr2) { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs) + } + if containsV6Addr(nic2Addrs, llAddrWithPrefix2) != test.keepAutoGenLinkLocal { + if test.keepAutoGenLinkLocal { + t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) + } else { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs) + } + } + if containsV6Addr(nic2Addrs, e2Addr1) { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs) + } + if containsV6Addr(nic2Addrs, e2Addr2) { + t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs) + } + + // Should not get any more events (invalidation timers should have been + // cancelled when the NDP state was cleaned up). + time.Sleep(lifetimeSeconds*time.Second + defaultTimeout) + select { + case <-ndpDisp.routerC: + t.Error("unexpected router event") + default: + } + select { + case <-ndpDisp.prefixC: + t.Error("unexpected prefix event") + default: + } + select { + case <-ndpDisp.autoGenAddrC: + t.Error("unexpected auto-generated address event") + default: + } + }) } } @@ -3406,77 +3483,130 @@ func TestRouterSolicitation(t *testing.T) { }) } -// TestStopStartSolicitingRouters tests that when forwarding is enabled or -// disabled, router solicitations are stopped or started, respecitively. func TestStopStartSolicitingRouters(t *testing.T) { t.Parallel() + const nicID = 1 const interval = 500 * time.Millisecond const delay = time.Second const maxRtrSolicitations = 3 - e := channel.New(maxRtrSolicitations, 1280, linkAddr1) - waitForPkt := func(timeout time.Duration) { - t.Helper() - ctx, _ := context.WithTimeout(context.Background(), timeout) - p, ok := e.ReadContext(ctx) - if !ok { - t.Fatal("timed out waiting for packet") - return - } - if p.Proto != header.IPv6ProtocolNumber { - t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) - } - checker.IPv6(t, p.Pkt.Header.View(), - checker.SrcAddr(header.IPv6Any), - checker.DstAddr(header.IPv6AllRoutersMulticastAddress), - checker.TTL(header.NDPHopLimit), - checker.NDPRS()) - } - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - NDPConfigs: stack.NDPConfigurations{ - MaxRtrSolicitations: maxRtrSolicitations, - RtrSolicitationInterval: interval, - MaxRtrSolicitationDelay: delay, + tests := []struct { + name string + startFn func(t *testing.T, s *stack.Stack) + stopFn func(t *testing.T, s *stack.Stack) + }{ + // Tests that when forwarding is enabled or disabled, router solicitations + // are stopped or started, respectively. + { + name: "Forwarding enabled and disabled", + startFn: func(t *testing.T, s *stack.Stack) { + t.Helper() + s.SetForwarding(false) + }, + stopFn: func(t *testing.T, s *stack.Stack) { + t.Helper() + s.SetForwarding(true) + }, }, - }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - // Enable forwarding which should stop router solicitations. - s.SetForwarding(true) - ctx, _ := context.WithTimeout(context.Background(), delay+defaultTimeout) - if _, ok := e.ReadContext(ctx); ok { - // A single RS may have been sent before forwarding was enabled. - ctx, _ = context.WithTimeout(context.Background(), interval+defaultTimeout) - if _, ok = e.ReadContext(ctx); ok { - t.Fatal("Should not have sent more than one RS message") - } - } + // Tests that when a NIC is enabled or disabled, router solicitations + // are started or stopped, respectively. + { + name: "NIC disabled and enabled", + startFn: func(t *testing.T, s *stack.Stack) { + t.Helper() - // Enabling forwarding again should do nothing. - s.SetForwarding(true) - ctx, _ = context.WithTimeout(context.Background(), delay+defaultTimeout) - if _, ok := e.ReadContext(ctx); ok { - t.Fatal("unexpectedly got a packet after becoming a router") - } + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + }, + stopFn: func(t *testing.T, s *stack.Stack) { + t.Helper() - // Disable forwarding which should start router solicitations. - s.SetForwarding(false) - waitForPkt(delay + defaultAsyncEventTimeout) - waitForPkt(interval + defaultAsyncEventTimeout) - waitForPkt(interval + defaultAsyncEventTimeout) - ctx, _ = context.WithTimeout(context.Background(), interval+defaultTimeout) - if _, ok := e.ReadContext(ctx); ok { - t.Fatal("unexpectedly got an extra packet after sending out the expected RSs") + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + }, + }, } - // Disabling forwarding again should do nothing. - s.SetForwarding(false) - ctx, _ = context.WithTimeout(context.Background(), delay+defaultTimeout) - if _, ok := e.ReadContext(ctx); ok { - t.Fatal("unexpectedly got a packet after becoming a router") + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e := channel.New(maxRtrSolicitations, 1280, linkAddr1) + waitForPkt := func(timeout time.Duration) { + t.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + p, ok := e.ReadContext(ctx) + if !ok { + t.Fatal("timed out waiting for packet") + return + } + + if p.Proto != header.IPv6ProtocolNumber { + t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) + } + checker.IPv6(t, p.Pkt.Header.View(), + checker.SrcAddr(header.IPv6Any), + checker.DstAddr(header.IPv6AllRoutersMulticastAddress), + checker.TTL(header.NDPHopLimit), + checker.NDPRS()) + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NDPConfigs: stack.NDPConfigurations{ + MaxRtrSolicitations: maxRtrSolicitations, + RtrSolicitationInterval: interval, + MaxRtrSolicitationDelay: delay, + }, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + // Stop soliciting routers. + test.stopFn(t, s) + ctx, cancel := context.WithTimeout(context.Background(), delay+defaultTimeout) + defer cancel() + if _, ok := e.ReadContext(ctx); ok { + // A single RS may have been sent before forwarding was enabled. + ctx, cancel := context.WithTimeout(context.Background(), interval+defaultTimeout) + defer cancel() + if _, ok = e.ReadContext(ctx); ok { + t.Fatal("should not have sent more than one RS message") + } + } + + // Stopping router solicitations after it has already been stopped should + // do nothing. + test.stopFn(t, s) + ctx, cancel = context.WithTimeout(context.Background(), delay+defaultTimeout) + defer cancel() + if _, ok := e.ReadContext(ctx); ok { + t.Fatal("unexpectedly got a packet after router solicitation has been stopepd") + } + + // Start soliciting routers. + test.startFn(t, s) + waitForPkt(delay + defaultAsyncEventTimeout) + waitForPkt(interval + defaultAsyncEventTimeout) + waitForPkt(interval + defaultAsyncEventTimeout) + ctx, cancel = context.WithTimeout(context.Background(), interval+defaultTimeout) + defer cancel() + if _, ok := e.ReadContext(ctx); ok { + t.Fatal("unexpectedly got an extra packet after sending out the expected RSs") + } + + // Starting router solicitations after it has already completed should do + // nothing. + test.startFn(t, s) + ctx, cancel = context.WithTimeout(context.Background(), delay+defaultTimeout) + defer cancel() + if _, ok := e.ReadContext(ctx); ok { + t.Fatal("unexpectedly got a packet after finishing router solicitations") + } + }) } } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index ca3a7a07e..b2be18e47 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -27,6 +27,14 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" ) +var ipv4BroadcastAddr = tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: header.IPv4Broadcast, + PrefixLen: 8 * header.IPv4AddressSize, + }, +} + // NIC represents a "network interface card" to which the networking stack is // attached. type NIC struct { @@ -36,7 +44,8 @@ type NIC struct { linkEP LinkEndpoint context NICContext - stats NICStats + stats NICStats + attach sync.Once mu struct { sync.RWMutex @@ -135,7 +144,69 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC return nic } -// enable enables the NIC. enable will attach the link to its LinkEndpoint and +// enabled returns true if n is enabled. +func (n *NIC) enabled() bool { + n.mu.RLock() + enabled := n.mu.enabled + n.mu.RUnlock() + return enabled +} + +// disable disables n. +// +// It undoes the work done by enable. +func (n *NIC) disable() *tcpip.Error { + n.mu.RLock() + enabled := n.mu.enabled + n.mu.RUnlock() + if !enabled { + return nil + } + + n.mu.Lock() + defer n.mu.Unlock() + + if !n.mu.enabled { + return nil + } + + // TODO(b/147015577): Should Routes that are currently bound to n be + // invalidated? Currently, Routes will continue to work when a NIC is enabled + // again, and applications may not know that the underlying NIC was ever + // disabled. + + if _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]; ok { + n.mu.ndp.stopSolicitingRouters() + n.mu.ndp.cleanupState(false /* hostOnly */) + + // Stop DAD for all the unicast IPv6 endpoints that are in the + // permanentTentative state. + for _, r := range n.mu.endpoints { + if addr := r.ep.ID().LocalAddress; r.getKind() == permanentTentative && header.IsV6UnicastAddress(addr) { + n.mu.ndp.stopDuplicateAddressDetection(addr) + } + } + + // The NIC may have already left the multicast group. + if err := n.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress { + return err + } + } + + if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok { + // The address may have already been removed. + if err := n.removePermanentAddressLocked(ipv4BroadcastAddr.AddressWithPrefix.Address); err != nil && err != tcpip.ErrBadLocalAddress { + return err + } + } + + // TODO(b/147015577): Should n detach from its LinkEndpoint? + + n.mu.enabled = false + return nil +} + +// enable enables n. enable will attach the nic to its LinkEndpoint and // join the IPv6 All-Nodes Multicast address (ff02::1). func (n *NIC) enable() *tcpip.Error { n.mu.RLock() @@ -158,10 +229,7 @@ func (n *NIC) enable() *tcpip.Error { // Create an endpoint to receive broadcast packets on this interface. if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok { - if _, err := n.addAddressLocked(tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize}, - }, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil { + if _, err := n.addAddressLocked(ipv4BroadcastAddr, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil { return err } } @@ -183,6 +251,14 @@ func (n *NIC) enable() *tcpip.Error { return nil } + // Join the All-Nodes multicast group before starting DAD as responses to DAD + // (NDP NS) messages may be sent to the All-Nodes multicast group if the + // source address of the NDP NS is the unspecified address, as per RFC 4861 + // section 7.2.4. + if err := n.joinGroupLocked(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress); err != nil { + return err + } + // Perform DAD on the all the unicast IPv6 endpoints that are in the permanent // state. // @@ -200,10 +276,6 @@ func (n *NIC) enable() *tcpip.Error { } } - if err := n.joinGroupLocked(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress); err != nil { - return err - } - // Do not auto-generate an IPv6 link-local address for loopback devices. if n.stack.autoGenIPv6LinkLocal && !n.isLoopback() { // The valid and preferred lifetime is infinite for the auto-generated @@ -234,7 +306,7 @@ func (n *NIC) becomeIPv6Router() { n.mu.Lock() defer n.mu.Unlock() - n.mu.ndp.cleanupHostOnlyState() + n.mu.ndp.cleanupState(true /* hostOnly */) n.mu.ndp.stopSolicitingRouters() } @@ -252,7 +324,9 @@ func (n *NIC) becomeIPv6Host() { // attachLinkEndpoint attaches the NIC to the endpoint, which will enable it // to start delivering packets. func (n *NIC) attachLinkEndpoint() { - n.linkEP.Attach(n) + n.attach.Do(func() { + n.linkEP.Attach(n) + }) } // setPromiscuousMode enables or disables promiscuous mode. @@ -712,6 +786,7 @@ func (n *NIC) AllAddresses() []tcpip.ProtocolAddress { case permanentExpired, temporary: continue } + addrs = append(addrs, tcpip.ProtocolAddress{ Protocol: ref.protocol, AddressWithPrefix: tcpip.AddressWithPrefix{ @@ -1009,6 +1084,15 @@ func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { return nil } +// isInGroup returns true if n has joined the multicast group addr. +func (n *NIC) isInGroup(addr tcpip.Address) bool { + n.mu.RLock() + joins := n.mu.mcastJoins[NetworkEndpointID{addr}] + n.mu.RUnlock() + + return joins != 0 +} + func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt tcpip.PacketBuffer) { r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */) r.RemoteLinkAddress = remotelinkAddr @@ -1411,7 +1495,7 @@ func (r *referencedNetworkEndpoint) isValidForOutgoing() bool { // // r's NIC must be read locked. func (r *referencedNetworkEndpoint) isValidForOutgoingRLocked() bool { - return r.getKind() != permanentExpired || r.nic.mu.spoofing + return r.nic.mu.enabled && (r.getKind() != permanentExpired || r.nic.mu.spoofing) } // decRef decrements the ref count and cleans up the endpoint once it reaches diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 6eac16e16..fabc976a7 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -921,23 +921,38 @@ func (s *Stack) EnableNIC(id tcpip.NICID) *tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() - nic := s.nics[id] - if nic == nil { + nic, ok := s.nics[id] + if !ok { return tcpip.ErrUnknownNICID } return nic.enable() } +// DisableNIC disables the given NIC. +func (s *Stack) DisableNIC(id tcpip.NICID) *tcpip.Error { + s.mu.RLock() + defer s.mu.RUnlock() + + nic, ok := s.nics[id] + if !ok { + return tcpip.ErrUnknownNICID + } + + return nic.disable() +} + // CheckNIC checks if a NIC is usable. func (s *Stack) CheckNIC(id tcpip.NICID) bool { s.mu.RLock() + defer s.mu.RUnlock() + nic, ok := s.nics[id] - s.mu.RUnlock() - if ok { - return nic.linkEP.IsAttached() + if !ok { + return false } - return false + + return nic.enabled() } // NICAddressRanges returns a map of NICIDs to their associated subnets. @@ -989,7 +1004,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { for id, nic := range s.nics { flags := NICStateFlags{ Up: true, // Netstack interfaces are always up. - Running: nic.linkEP.IsAttached(), + Running: nic.enabled(), Promiscuous: nic.isPromiscuousMode(), Loopback: nic.isLoopback(), } @@ -1151,7 +1166,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr) needRoute := !(isBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr)) if id != 0 && !needRoute { - if nic, ok := s.nics[id]; ok { + if nic, ok := s.nics[id]; ok && nic.enabled() { if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil { return makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()), nil } @@ -1161,7 +1176,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr)) { continue } - if nic, ok := s.nics[route.NIC]; ok { + if nic, ok := s.nics[route.NIC]; ok && nic.enabled() { if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil { if len(remoteAddr) == 0 { // If no remote address was provided, then the route @@ -1614,6 +1629,18 @@ func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NIC return tcpip.ErrUnknownNICID } +// IsInGroup returns true if the NIC with ID nicID has joined the multicast +// group multicastAddr. +func (s *Stack) IsInGroup(nicID tcpip.NICID, multicastAddr tcpip.Address) (bool, *tcpip.Error) { + s.mu.RLock() + defer s.mu.RUnlock() + + if nic, ok := s.nics[nicID]; ok { + return nic.isInGroup(multicastAddr), nil + } + return false, tcpip.ErrUnknownNICID +} + // IPTables returns the stack's iptables. func (s *Stack) IPTables() iptables.IPTables { s.tablesMu.RLock() diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 7ba604442..eb6f7d1fc 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -33,6 +33,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" @@ -509,6 +510,257 @@ func testNoRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr } } +func TestDisableUnknownNIC(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + + if err := s.DisableNIC(1); err != tcpip.ErrUnknownNICID { + t.Fatalf("got s.DisableNIC(1) = %v, want = %s", err, tcpip.ErrUnknownNICID) + } +} + +func TestDisabledNICsNICInfoAndCheckNIC(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + + e := loopback.New() + nicOpts := stack.NICOptions{Disabled: true} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, nicOpts, err) + } + + checkNIC := func(enabled bool) { + t.Helper() + + allNICInfo := s.NICInfo() + nicInfo, ok := allNICInfo[nicID] + if !ok { + t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo) + } else if nicInfo.Flags.Running != enabled { + t.Errorf("got nicInfo.Flags.Running = %t, want = %t", nicInfo.Flags.Running, enabled) + } + + if got := s.CheckNIC(nicID); got != enabled { + t.Errorf("got s.CheckNIC(%d) = %t, want = %t", nicID, got, enabled) + } + } + + // NIC should initially report itself as disabled. + checkNIC(false) + + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + checkNIC(true) + + // If the NIC is not reporting a correct enabled status, we cannot trust the + // next check so end the test here. + if t.Failed() { + t.FailNow() + } + + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + checkNIC(false) +} + +func TestRoutesWithDisabledNIC(t *testing.T) { + const unspecifiedNIC = 0 + const nicID1 = 1 + const nicID2 = 2 + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + + ep1 := channel.New(0, defaultMTU, "") + if err := s.CreateNIC(nicID1, ep1); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } + + addr1 := tcpip.Address("\x01") + if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err) + } + + ep2 := channel.New(0, defaultMTU, "") + if err := s.CreateNIC(nicID2, ep2); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) + } + + addr2 := tcpip.Address("\x02") + if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err) + } + + // Set a route table that sends all packets with odd destination + // addresses through the first NIC, and all even destination address + // through the second one. + { + subnet0, err := tcpip.NewSubnet("\x00", "\x01") + if err != nil { + t.Fatal(err) + } + subnet1, err := tcpip.NewSubnet("\x01", "\x01") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{ + {Destination: subnet1, Gateway: "\x00", NIC: nicID1}, + {Destination: subnet0, Gateway: "\x00", NIC: nicID2}, + }) + } + + // Test routes to odd address. + testRoute(t, s, unspecifiedNIC, "", "\x05", addr1) + testRoute(t, s, unspecifiedNIC, addr1, "\x05", addr1) + testRoute(t, s, nicID1, addr1, "\x05", addr1) + + // Test routes to even address. + testRoute(t, s, unspecifiedNIC, "", "\x06", addr2) + testRoute(t, s, unspecifiedNIC, addr2, "\x06", addr2) + testRoute(t, s, nicID2, addr2, "\x06", addr2) + + // Disabling NIC1 should result in no routes to odd addresses. Routes to even + // addresses should continue to be available as NIC2 is still enabled. + if err := s.DisableNIC(nicID1); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID1, err) + } + nic1Dst := tcpip.Address("\x05") + testNoRoute(t, s, unspecifiedNIC, "", nic1Dst) + testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst) + testNoRoute(t, s, nicID1, addr1, nic1Dst) + nic2Dst := tcpip.Address("\x06") + testRoute(t, s, unspecifiedNIC, "", nic2Dst, addr2) + testRoute(t, s, unspecifiedNIC, addr2, nic2Dst, addr2) + testRoute(t, s, nicID2, addr2, nic2Dst, addr2) + + // Disabling NIC2 should result in no routes to even addresses. No route + // should be available to any address as routes to odd addresses were made + // unavailable by disabling NIC1 above. + if err := s.DisableNIC(nicID2); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID2, err) + } + testNoRoute(t, s, unspecifiedNIC, "", nic1Dst) + testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst) + testNoRoute(t, s, nicID1, addr1, nic1Dst) + testNoRoute(t, s, unspecifiedNIC, "", nic2Dst) + testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst) + testNoRoute(t, s, nicID2, addr2, nic2Dst) + + // Enabling NIC1 should make routes to odd addresses available again. Routes + // to even addresses should continue to be unavailable as NIC2 is still + // disabled. + if err := s.EnableNIC(nicID1); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID1, err) + } + testRoute(t, s, unspecifiedNIC, "", nic1Dst, addr1) + testRoute(t, s, unspecifiedNIC, addr1, nic1Dst, addr1) + testRoute(t, s, nicID1, addr1, nic1Dst, addr1) + testNoRoute(t, s, unspecifiedNIC, "", nic2Dst) + testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst) + testNoRoute(t, s, nicID2, addr2, nic2Dst) +} + +func TestRouteWritePacketWithDisabledNIC(t *testing.T) { + const unspecifiedNIC = 0 + const nicID1 = 1 + const nicID2 = 2 + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + + ep1 := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicID1, ep1); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } + + addr1 := tcpip.Address("\x01") + if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err) + } + + ep2 := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicID2, ep2); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) + } + + addr2 := tcpip.Address("\x02") + if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err) + } + + // Set a route table that sends all packets with odd destination + // addresses through the first NIC, and all even destination address + // through the second one. + { + subnet0, err := tcpip.NewSubnet("\x00", "\x01") + if err != nil { + t.Fatal(err) + } + subnet1, err := tcpip.NewSubnet("\x01", "\x01") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{ + {Destination: subnet1, Gateway: "\x00", NIC: nicID1}, + {Destination: subnet0, Gateway: "\x00", NIC: nicID2}, + }) + } + + nic1Dst := tcpip.Address("\x05") + r1, err := s.FindRoute(nicID1, addr1, nic1Dst, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID1, addr1, nic1Dst, fakeNetNumber, err) + } + defer r1.Release() + + nic2Dst := tcpip.Address("\x06") + r2, err := s.FindRoute(nicID2, addr2, nic2Dst, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID2, addr2, nic2Dst, fakeNetNumber, err) + } + defer r2.Release() + + // If we failed to get routes r1 or r2, we cannot proceed with the test. + if t.Failed() { + t.FailNow() + } + + buf := buffer.View([]byte{1}) + testSend(t, r1, ep1, buf) + testSend(t, r2, ep2, buf) + + // Writes with Routes that use the disabled NIC1 should fail. + if err := s.DisableNIC(nicID1); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID1, err) + } + testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState) + testSend(t, r2, ep2, buf) + + // Writes with Routes that use the disabled NIC2 should fail. + if err := s.DisableNIC(nicID2); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID2, err) + } + testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState) + testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState) + + // Writes with Routes that use the re-enabled NIC1 should succeed. + // TODO(b/147015577): Should we instead completely invalidate all Routes that + // were bound to a disabled NIC at some point? + if err := s.EnableNIC(nicID1); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID1, err) + } + testSend(t, r1, ep1, buf) + testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState) +} + func TestRoutes(t *testing.T) { // Create a stack with the fake network protocol, two nics, and two // addresses per nic, the first nic has odd address, the second one has @@ -2173,13 +2425,29 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { e := channel.New(0, 1280, test.linkAddr) s := stack.New(opts) - nicOpts := stack.NICOptions{Name: test.nicName} + nicOpts := stack.NICOptions{Name: test.nicName, Disabled: true} if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err) } - var expectedMainAddr tcpip.AddressWithPrefix + // A new disabled NIC should not have any address, even if auto generation + // was enabled. + allStackAddrs := s.AllAddresses() + allNICAddrs, ok := allStackAddrs[nicID] + if !ok { + t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) + } + if l := len(allNICAddrs); l != 0 { + t.Fatalf("got len(allNICAddrs) = %d, want = 0", l) + } + // Enabling the NIC should attempt auto-generation of a link-local + // address. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + + var expectedMainAddr tcpip.AddressWithPrefix if test.shouldGen { expectedMainAddr = tcpip.AddressWithPrefix{ Address: test.expectedAddr, @@ -2609,6 +2877,111 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { } } +func TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable(t *testing.T) { + const nicID = 1 + + e := loopback.New() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + }) + nicOpts := stack.NICOptions{Disabled: true} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err) + } + + allStackAddrs := s.AllAddresses() + allNICAddrs, ok := allStackAddrs[nicID] + if !ok { + t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) + } + if l := len(allNICAddrs); l != 0 { + t.Fatalf("got len(allNICAddrs) = %d, want = 0", l) + } + + // Enabling the NIC should add the IPv4 broadcast address. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + allStackAddrs = s.AllAddresses() + allNICAddrs, ok = allStackAddrs[nicID] + if !ok { + t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) + } + if l := len(allNICAddrs); l != 1 { + t.Fatalf("got len(allNICAddrs) = %d, want = 1", l) + } + want := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: header.IPv4Broadcast, + PrefixLen: 32, + }, + } + if allNICAddrs[0] != want { + t.Fatalf("got allNICAddrs[0] = %+v, want = %+v", allNICAddrs[0], want) + } + + // Disabling the NIC should remove the IPv4 broadcast address. + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + allStackAddrs = s.AllAddresses() + allNICAddrs, ok = allStackAddrs[nicID] + if !ok { + t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) + } + if l := len(allNICAddrs); l != 0 { + t.Fatalf("got len(allNICAddrs) = %d, want = 0", l) + } +} + +func TestJoinLeaveAllNodesMulticastOnNICEnableDisable(t *testing.T) { + const nicID = 1 + + e := loopback.New() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + }) + nicOpts := stack.NICOptions{Disabled: true} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err) + } + + // Should not be in the IPv6 all-nodes multicast group yet because the NIC has + // not been enabled yet. + isInGroup, err := s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress) + if err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err) + } + if isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, header.IPv6AllNodesMulticastAddress) + } + + // The all-nodes multicast group should be joined when the NIC is enabled. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + isInGroup, err = s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress) + if err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err) + } + if !isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, header.IPv6AllNodesMulticastAddress) + } + + // The all-nodes multicast group should be left when the NIC is disabled. + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + isInGroup, err = s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress) + if err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err) + } + if isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, header.IPv6AllNodesMulticastAddress) + } +} + // TestDoDADWhenNICEnabled tests that IPv6 endpoints that were added while a NIC // was disabled have DAD performed on them when the NIC is enabled. func TestDoDADWhenNICEnabled(t *testing.T) { -- cgit v1.2.3 From 4a73bae269ae9f52a962ae3b08a17ccaacf7ba80 Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Thu, 20 Feb 2020 15:19:40 -0800 Subject: Initial network namespace support. TCP/IP will work with netstack networking. hostinet doesn't work, and sockets will have the same behavior as it is now. Before the userspace is able to create device, the default loopback device can be used to test. /proc/net and /sys/net will still be connected to the root network stack; this is the same behavior now. Issue #1833 PiperOrigin-RevId: 296309389 --- pkg/sentry/fs/proc/net.go | 5 +- pkg/sentry/fs/proc/sys_net.go | 4 +- pkg/sentry/fsimpl/proc/tasks_net.go | 5 +- pkg/sentry/fsimpl/proc/tasks_sys.go | 4 +- pkg/sentry/fsimpl/testutil/kernel.go | 1 + pkg/sentry/inet/BUILD | 1 + pkg/sentry/inet/namespace.go | 99 +++++++++++++++++++++++++ pkg/sentry/kernel/kernel.go | 26 ++++--- pkg/sentry/kernel/task.go | 9 +-- pkg/sentry/kernel/task_clone.go | 16 ++-- pkg/sentry/kernel/task_net.go | 19 +++-- pkg/sentry/kernel/task_start.go | 8 +- pkg/tcpip/time_unsafe.go | 2 + runsc/boot/BUILD | 2 +- runsc/boot/controller.go | 11 +-- runsc/boot/loader.go | 121 +++++++++++++++++++++---------- runsc/boot/network.go | 27 +++++++ runsc/boot/pprof.go | 18 ----- runsc/boot/pprof/BUILD | 11 +++ runsc/boot/pprof/pprof.go | 20 +++++ runsc/sandbox/network.go | 25 +------ test/syscalls/BUILD | 2 + test/syscalls/linux/BUILD | 17 +++++ test/syscalls/linux/network_namespace.cc | 121 +++++++++++++++++++++++++++++++ 24 files changed, 451 insertions(+), 123 deletions(-) create mode 100644 pkg/sentry/inet/namespace.go delete mode 100644 runsc/boot/pprof.go create mode 100644 runsc/boot/pprof/BUILD create mode 100644 runsc/boot/pprof/pprof.go create mode 100644 test/syscalls/linux/network_namespace.cc (limited to 'pkg/tcpip') diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go index 6f2775344..95d5817ff 100644 --- a/pkg/sentry/fs/proc/net.go +++ b/pkg/sentry/fs/proc/net.go @@ -43,7 +43,10 @@ import ( // newNet creates a new proc net entry. func (p *proc) newNetDir(ctx context.Context, k *kernel.Kernel, msrc *fs.MountSource) *fs.Inode { var contents map[string]*fs.Inode - if s := p.k.NetworkStack(); s != nil { + // TODO(gvisor.dev/issue/1833): Support for using the network stack in the + // network namespace of the calling process. We should make this per-process, + // a.k.a. /proc/PID/net, and make /proc/net a symlink to /proc/self/net. + if s := p.k.RootNetworkNamespace().Stack(); s != nil { contents = map[string]*fs.Inode{ "dev": seqfile.NewSeqFileInode(ctx, &netDev{s: s}, msrc), "snmp": seqfile.NewSeqFileInode(ctx, &netSnmp{s: s}, msrc), diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go index 0772d4ae4..d4c4b533d 100644 --- a/pkg/sentry/fs/proc/sys_net.go +++ b/pkg/sentry/fs/proc/sys_net.go @@ -357,7 +357,9 @@ func (p *proc) newSysNetIPv4Dir(ctx context.Context, msrc *fs.MountSource, s ine func (p *proc) newSysNetDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode { var contents map[string]*fs.Inode - if s := p.k.NetworkStack(); s != nil { + // TODO(gvisor.dev/issue/1833): Support for using the network stack in the + // network namespace of the calling process. + if s := p.k.RootNetworkNamespace().Stack(); s != nil { contents = map[string]*fs.Inode{ "ipv4": p.newSysNetIPv4Dir(ctx, msrc, s), "core": p.newSysNetCore(ctx, msrc, s), diff --git a/pkg/sentry/fsimpl/proc/tasks_net.go b/pkg/sentry/fsimpl/proc/tasks_net.go index 608fec017..d4e1812d8 100644 --- a/pkg/sentry/fsimpl/proc/tasks_net.go +++ b/pkg/sentry/fsimpl/proc/tasks_net.go @@ -39,7 +39,10 @@ import ( func newNetDir(root *auth.Credentials, inoGen InoGenerator, k *kernel.Kernel) *kernfs.Dentry { var contents map[string]*kernfs.Dentry - if stack := k.NetworkStack(); stack != nil { + // TODO(gvisor.dev/issue/1833): Support for using the network stack in the + // network namespace of the calling process. We should make this per-process, + // a.k.a. /proc/PID/net, and make /proc/net a symlink to /proc/self/net. + if stack := k.RootNetworkNamespace().Stack(); stack != nil { const ( arp = "IP address HW type Flags HW address Mask Device\n" netlink = "sk Eth Pid Groups Rmem Wmem Dump Locks Drops Inode\n" diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go index c7ce74883..3d5dc463c 100644 --- a/pkg/sentry/fsimpl/proc/tasks_sys.go +++ b/pkg/sentry/fsimpl/proc/tasks_sys.go @@ -50,7 +50,9 @@ func newSysDir(root *auth.Credentials, inoGen InoGenerator, k *kernel.Kernel) *k func newSysNetDir(root *auth.Credentials, inoGen InoGenerator, k *kernel.Kernel) *kernfs.Dentry { var contents map[string]*kernfs.Dentry - if stack := k.NetworkStack(); stack != nil { + // TODO(gvisor.dev/issue/1833): Support for using the network stack in the + // network namespace of the calling process. + if stack := k.RootNetworkNamespace().Stack(); stack != nil { contents = map[string]*kernfs.Dentry{ "ipv4": kernfs.NewStaticDir(root, inoGen.NextIno(), 0555, map[string]*kernfs.Dentry{ "tcp_sack": newDentry(root, inoGen.NextIno(), 0644, &tcpSackData{stack: stack}), diff --git a/pkg/sentry/fsimpl/testutil/kernel.go b/pkg/sentry/fsimpl/testutil/kernel.go index d0be32e72..488478e29 100644 --- a/pkg/sentry/fsimpl/testutil/kernel.go +++ b/pkg/sentry/fsimpl/testutil/kernel.go @@ -128,6 +128,7 @@ func CreateTask(ctx context.Context, name string, tc *kernel.ThreadGroup, mntns ThreadGroup: tc, TaskContext: &kernel.TaskContext{Name: name}, Credentials: auth.CredentialsFromContext(ctx), + NetworkNamespace: k.RootNetworkNamespace(), AllowedCPUMask: sched.NewFullCPUSet(k.ApplicationCores()), UTSNamespace: kernel.UTSNamespaceFromContext(ctx), IPCNamespace: kernel.IPCNamespaceFromContext(ctx), diff --git a/pkg/sentry/inet/BUILD b/pkg/sentry/inet/BUILD index 334432abf..07bf39fed 100644 --- a/pkg/sentry/inet/BUILD +++ b/pkg/sentry/inet/BUILD @@ -10,6 +10,7 @@ go_library( srcs = [ "context.go", "inet.go", + "namespace.go", "test_stack.go", ], deps = [ diff --git a/pkg/sentry/inet/namespace.go b/pkg/sentry/inet/namespace.go new file mode 100644 index 000000000..c16667e7f --- /dev/null +++ b/pkg/sentry/inet/namespace.go @@ -0,0 +1,99 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package inet + +// Namespace represents a network namespace. See network_namespaces(7). +// +// +stateify savable +type Namespace struct { + // stack is the network stack implementation of this network namespace. + stack Stack `state:"nosave"` + + // creator allows kernel to create new network stack for network namespaces. + // If nil, no networking will function if network is namespaced. + creator NetworkStackCreator + + // isRoot indicates whether this is the root network namespace. + isRoot bool +} + +// NewRootNamespace creates the root network namespace, with creator +// allowing new network namespaces to be created. If creator is nil, no +// networking will function if the network is namespaced. +func NewRootNamespace(stack Stack, creator NetworkStackCreator) *Namespace { + return &Namespace{ + stack: stack, + creator: creator, + isRoot: true, + } +} + +// NewNamespace creates a new network namespace from the root. +func NewNamespace(root *Namespace) *Namespace { + n := &Namespace{ + creator: root.creator, + } + n.init() + return n +} + +// Stack returns the network stack of n. Stack may return nil if no network +// stack is configured. +func (n *Namespace) Stack() Stack { + return n.stack +} + +// IsRoot returns whether n is the root network namespace. +func (n *Namespace) IsRoot() bool { + return n.isRoot +} + +// RestoreRootStack restores the root network namespace with stack. This should +// only be called when restoring kernel. +func (n *Namespace) RestoreRootStack(stack Stack) { + if !n.isRoot { + panic("RestoreRootStack can only be called on root network namespace") + } + if n.stack != nil { + panic("RestoreRootStack called after a stack has already been set") + } + n.stack = stack +} + +func (n *Namespace) init() { + // Root network namespace will have stack assigned later. + if n.isRoot { + return + } + if n.creator != nil { + var err error + n.stack, err = n.creator.CreateStack() + if err != nil { + panic(err) + } + } +} + +// afterLoad is invoked by stateify. +func (n *Namespace) afterLoad() { + n.init() +} + +// NetworkStackCreator allows new instances of a network stack to be created. It +// is used by the kernel to create new network namespaces when requested. +type NetworkStackCreator interface { + // CreateStack creates a new network stack for a network namespace. + CreateStack() (Stack, error) +} diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 7da0368f1..c62fd6eb1 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -111,7 +111,7 @@ type Kernel struct { timekeeper *Timekeeper tasks *TaskSet rootUserNamespace *auth.UserNamespace - networkStack inet.Stack `state:"nosave"` + rootNetworkNamespace *inet.Namespace applicationCores uint useHostCores bool extraAuxv []arch.AuxEntry @@ -260,8 +260,9 @@ type InitKernelArgs struct { // RootUserNamespace is the root user namespace. RootUserNamespace *auth.UserNamespace - // NetworkStack is the TCP/IP network stack. NetworkStack may be nil. - NetworkStack inet.Stack + // RootNetworkNamespace is the root network namespace. If nil, no networking + // will be available. + RootNetworkNamespace *inet.Namespace // ApplicationCores is the number of logical CPUs visible to sandboxed // applications. The set of logical CPU IDs is [0, ApplicationCores); thus @@ -320,7 +321,10 @@ func (k *Kernel) Init(args InitKernelArgs) error { k.rootUTSNamespace = args.RootUTSNamespace k.rootIPCNamespace = args.RootIPCNamespace k.rootAbstractSocketNamespace = args.RootAbstractSocketNamespace - k.networkStack = args.NetworkStack + k.rootNetworkNamespace = args.RootNetworkNamespace + if k.rootNetworkNamespace == nil { + k.rootNetworkNamespace = inet.NewRootNamespace(nil, nil) + } k.applicationCores = args.ApplicationCores if args.UseHostCores { k.useHostCores = true @@ -543,8 +547,6 @@ func (ts *TaskSet) unregisterEpollWaiters() { func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks) error { loadStart := time.Now() - k.networkStack = net - initAppCores := k.applicationCores // Load the pre-saved CPUID FeatureSet. @@ -575,6 +577,10 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks) log.Infof("Kernel load stats: %s", &stats) log.Infof("Kernel load took [%s].", time.Since(kernelStart)) + // rootNetworkNamespace should be populated after loading the state file. + // Restore the root network stack. + k.rootNetworkNamespace.RestoreRootStack(net) + // Load the memory file's state. memoryStart := time.Now() if err := k.mf.LoadFrom(k.SupervisorContext(), r); err != nil { @@ -905,6 +911,7 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID, FSContext: fsContext, FDTable: args.FDTable, Credentials: args.Credentials, + NetworkNamespace: k.RootNetworkNamespace(), AllowedCPUMask: sched.NewFullCPUSet(k.applicationCores), UTSNamespace: args.UTSNamespace, IPCNamespace: args.IPCNamespace, @@ -1255,10 +1262,9 @@ func (k *Kernel) RootAbstractSocketNamespace() *AbstractSocketNamespace { return k.rootAbstractSocketNamespace } -// NetworkStack returns the network stack. NetworkStack may return nil if no -// network stack is available. -func (k *Kernel) NetworkStack() inet.Stack { - return k.networkStack +// RootNetworkNamespace returns the root network namespace, always non-nil. +func (k *Kernel) RootNetworkNamespace() *inet.Namespace { + return k.rootNetworkNamespace } // GlobalInit returns the thread group with ID 1 in the root PID namespace, or diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index a3443ff21..e37e23231 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -486,13 +486,10 @@ type Task struct { numaPolicy int32 numaNodeMask uint64 - // If netns is true, the task is in a non-root network namespace. Network - // namespaces aren't currently implemented in full; being in a network - // namespace simply prevents the task from observing any network devices - // (including loopback) or using abstract socket addresses (see unix(7)). + // netns is the task's network namespace. netns is never nil. // - // netns is protected by mu. netns is owned by the task goroutine. - netns bool + // netns is protected by mu. + netns *inet.Namespace // If rseqPreempted is true, before the next call to p.Switch(), // interrupt rseq critical regions as defined by rseqAddr and diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go index ba74b4c1c..78866f280 100644 --- a/pkg/sentry/kernel/task_clone.go +++ b/pkg/sentry/kernel/task_clone.go @@ -17,6 +17,7 @@ package kernel import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bpf" + "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -54,8 +55,7 @@ type SharingOptions struct { NewUserNamespace bool // If NewNetworkNamespace is true, the task should have an independent - // network namespace. (Note that network namespaces are not really - // implemented; see comment on Task.netns for details.) + // network namespace. NewNetworkNamespace bool // If NewFiles is true, the task should use an independent file descriptor @@ -199,6 +199,11 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) { ipcns = NewIPCNamespace(userns) } + netns := t.NetworkNamespace() + if opts.NewNetworkNamespace { + netns = inet.NewNamespace(netns) + } + // TODO(b/63601033): Implement CLONE_NEWNS. mntnsVFS2 := t.mountNamespaceVFS2 if mntnsVFS2 != nil { @@ -268,7 +273,7 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) { FDTable: fdTable, Credentials: creds, Niceness: t.Niceness(), - NetworkNamespaced: t.netns, + NetworkNamespace: netns, AllowedCPUMask: t.CPUMask(), UTSNamespace: utsns, IPCNamespace: ipcns, @@ -283,9 +288,6 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) { } else { cfg.InheritParent = t } - if opts.NewNetworkNamespace { - cfg.NetworkNamespaced = true - } nt, err := t.tg.pidns.owner.NewTask(cfg) if err != nil { if opts.NewThreadGroup { @@ -482,7 +484,7 @@ func (t *Task) Unshare(opts *SharingOptions) error { t.mu.Unlock() return syserror.EPERM } - t.netns = true + t.netns = inet.NewNamespace(t.netns) } if opts.NewUTSNamespace { if !haveCapSysAdmin { diff --git a/pkg/sentry/kernel/task_net.go b/pkg/sentry/kernel/task_net.go index 172a31e1d..f7711232c 100644 --- a/pkg/sentry/kernel/task_net.go +++ b/pkg/sentry/kernel/task_net.go @@ -22,14 +22,23 @@ import ( func (t *Task) IsNetworkNamespaced() bool { t.mu.Lock() defer t.mu.Unlock() - return t.netns + return !t.netns.IsRoot() } // NetworkContext returns the network stack used by the task. NetworkContext // may return nil if no network stack is available. +// +// TODO(gvisor.dev/issue/1833): Migrate callers of this method to +// NetworkNamespace(). func (t *Task) NetworkContext() inet.Stack { - if t.IsNetworkNamespaced() { - return nil - } - return t.k.networkStack + t.mu.Lock() + defer t.mu.Unlock() + return t.netns.Stack() +} + +// NetworkNamespace returns the network namespace observed by the task. +func (t *Task) NetworkNamespace() *inet.Namespace { + t.mu.Lock() + defer t.mu.Unlock() + return t.netns } diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go index f9236a842..a5035bb7f 100644 --- a/pkg/sentry/kernel/task_start.go +++ b/pkg/sentry/kernel/task_start.go @@ -17,6 +17,7 @@ package kernel import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/futex" "gvisor.dev/gvisor/pkg/sentry/kernel/sched" @@ -65,9 +66,8 @@ type TaskConfig struct { // Niceness is the niceness of the new task. Niceness int - // If NetworkNamespaced is true, the new task should observe a non-root - // network namespace. - NetworkNamespaced bool + // NetworkNamespace is the network namespace to be used for the new task. + NetworkNamespace *inet.Namespace // AllowedCPUMask contains the cpus that this task can run on. AllowedCPUMask sched.CPUSet @@ -133,7 +133,7 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) { allowedCPUMask: cfg.AllowedCPUMask.Copy(), ioUsage: &usage.IO{}, niceness: cfg.Niceness, - netns: cfg.NetworkNamespaced, + netns: cfg.NetworkNamespace, utsns: cfg.UTSNamespace, ipcns: cfg.IPCNamespace, abstractSockets: cfg.AbstractSocketNamespace, diff --git a/pkg/tcpip/time_unsafe.go b/pkg/tcpip/time_unsafe.go index 48764b978..2f98a996f 100644 --- a/pkg/tcpip/time_unsafe.go +++ b/pkg/tcpip/time_unsafe.go @@ -25,6 +25,8 @@ import ( ) // StdClock implements Clock with the time package. +// +// +stateify savable type StdClock struct{} var _ Clock = (*StdClock)(nil) diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD index ae4dd102a..26f68fe3d 100644 --- a/runsc/boot/BUILD +++ b/runsc/boot/BUILD @@ -19,7 +19,6 @@ go_library( "loader_amd64.go", "loader_arm64.go", "network.go", - "pprof.go", "strace.go", "user.go", ], @@ -91,6 +90,7 @@ go_library( "//pkg/usermem", "//runsc/boot/filter", "//runsc/boot/platforms", + "//runsc/boot/pprof", "//runsc/specutils", "@com_github_golang_protobuf//proto:go_default_library", "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", diff --git a/runsc/boot/controller.go b/runsc/boot/controller.go index 9c9e94864..17e774e0c 100644 --- a/runsc/boot/controller.go +++ b/runsc/boot/controller.go @@ -32,6 +32,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/watchdog" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/urpc" + "gvisor.dev/gvisor/runsc/boot/pprof" "gvisor.dev/gvisor/runsc/specutils" ) @@ -142,7 +143,7 @@ func newController(fd int, l *Loader) (*controller, error) { } srv.Register(manager) - if eps, ok := l.k.NetworkStack().(*netstack.Stack); ok { + if eps, ok := l.k.RootNetworkNamespace().Stack().(*netstack.Stack); ok { net := &Network{ Stack: eps.Stack, } @@ -341,7 +342,7 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error { return fmt.Errorf("creating memory file: %v", err) } k.SetMemoryFile(mf) - networkStack := cm.l.k.NetworkStack() + networkStack := cm.l.k.RootNetworkNamespace().Stack() cm.l.k = k // Set up the restore environment. @@ -365,9 +366,9 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error { } if cm.l.conf.ProfileEnable { - // initializePProf opens /proc/self/maps, so has to be - // called before installing seccomp filters. - initializePProf() + // pprof.Initialize opens /proc/self/maps, so has to be called before + // installing seccomp filters. + pprof.Initialize() } // Seccomp filters have to be applied before parsing the state file. diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index eef43b9df..e7ca98134 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -49,6 +49,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/watchdog" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" @@ -60,6 +61,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/runsc/boot/filter" _ "gvisor.dev/gvisor/runsc/boot/platforms" // register all platforms. + "gvisor.dev/gvisor/runsc/boot/pprof" "gvisor.dev/gvisor/runsc/specutils" // Include supported socket providers. @@ -230,11 +232,8 @@ func New(args Args) (*Loader, error) { return nil, fmt.Errorf("enabling strace: %v", err) } - // Create an empty network stack because the network namespace may be empty at - // this point. Netns is configured before Run() is called. Netstack is - // configured using a control uRPC message. Host network is configured inside - // Run(). - networkStack, err := newEmptyNetworkStack(args.Conf, k, k) + // Create root network namespace/stack. + netns, err := newRootNetworkNamespace(args.Conf, k, k) if err != nil { return nil, fmt.Errorf("creating network: %v", err) } @@ -277,7 +276,7 @@ func New(args Args) (*Loader, error) { FeatureSet: cpuid.HostFeatureSet(), Timekeeper: tk, RootUserNamespace: creds.UserNamespace, - NetworkStack: networkStack, + RootNetworkNamespace: netns, ApplicationCores: uint(args.NumCPU), Vdso: vdso, RootUTSNamespace: kernel.NewUTSNamespace(args.Spec.Hostname, args.Spec.Hostname, creds.UserNamespace), @@ -466,7 +465,7 @@ func (l *Loader) run() error { // Delay host network configuration to this point because network namespace // is configured after the loader is created and before Run() is called. log.Debugf("Configuring host network") - stack := l.k.NetworkStack().(*hostinet.Stack) + stack := l.k.RootNetworkNamespace().Stack().(*hostinet.Stack) if err := stack.Configure(); err != nil { return err } @@ -485,7 +484,7 @@ func (l *Loader) run() error { // l.restore is set by the container manager when a restore call is made. if !l.restore { if l.conf.ProfileEnable { - initializePProf() + pprof.Initialize() } // Finally done with all configuration. Setup filters before user code @@ -908,48 +907,92 @@ func (l *Loader) WaitExit() kernel.ExitStatus { return l.k.GlobalInit().ExitStatus() } -func newEmptyNetworkStack(conf *Config, clock tcpip.Clock, uniqueID stack.UniqueID) (inet.Stack, error) { +func newRootNetworkNamespace(conf *Config, clock tcpip.Clock, uniqueID stack.UniqueID) (*inet.Namespace, error) { + // Create an empty network stack because the network namespace may be empty at + // this point. Netns is configured before Run() is called. Netstack is + // configured using a control uRPC message. Host network is configured inside + // Run(). switch conf.Network { case NetworkHost: - return hostinet.NewStack(), nil + // No network namespacing support for hostinet yet, hence creator is nil. + return inet.NewRootNamespace(hostinet.NewStack(), nil), nil case NetworkNone, NetworkSandbox: - // NetworkNone sets up loopback using netstack. - netProtos := []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol(), arp.NewProtocol()} - transProtos := []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol(), icmp.NewProtocol4()} - s := netstack.Stack{stack.New(stack.Options{ - NetworkProtocols: netProtos, - TransportProtocols: transProtos, - Clock: clock, - Stats: netstack.Metrics, - HandleLocal: true, - // Enable raw sockets for users with sufficient - // privileges. - RawFactory: raw.EndpointFactory{}, - UniqueID: uniqueID, - })} - - // Enable SACK Recovery. - if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(true)); err != nil { - return nil, fmt.Errorf("failed to enable SACK: %v", err) + s, err := newEmptySandboxNetworkStack(clock, uniqueID) + if err != nil { + return nil, err } + creator := &sandboxNetstackCreator{ + clock: clock, + uniqueID: uniqueID, + } + return inet.NewRootNamespace(s, creator), nil - // Set default TTLs as required by socket/netstack. - s.Stack.SetNetworkProtocolOption(ipv4.ProtocolNumber, tcpip.DefaultTTLOption(netstack.DefaultTTL)) - s.Stack.SetNetworkProtocolOption(ipv6.ProtocolNumber, tcpip.DefaultTTLOption(netstack.DefaultTTL)) + default: + panic(fmt.Sprintf("invalid network configuration: %v", conf.Network)) + } - // Enable Receive Buffer Auto-Tuning. - if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil { - return nil, fmt.Errorf("SetTransportProtocolOption failed: %v", err) - } +} - s.FillDefaultIPTables() +func newEmptySandboxNetworkStack(clock tcpip.Clock, uniqueID stack.UniqueID) (inet.Stack, error) { + netProtos := []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol(), arp.NewProtocol()} + transProtos := []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol(), icmp.NewProtocol4()} + s := netstack.Stack{stack.New(stack.Options{ + NetworkProtocols: netProtos, + TransportProtocols: transProtos, + Clock: clock, + Stats: netstack.Metrics, + HandleLocal: true, + // Enable raw sockets for users with sufficient + // privileges. + RawFactory: raw.EndpointFactory{}, + UniqueID: uniqueID, + })} - return &s, nil + // Enable SACK Recovery. + if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(true)); err != nil { + return nil, fmt.Errorf("failed to enable SACK: %v", err) + } - default: - panic(fmt.Sprintf("invalid network configuration: %v", conf.Network)) + // Set default TTLs as required by socket/netstack. + s.Stack.SetNetworkProtocolOption(ipv4.ProtocolNumber, tcpip.DefaultTTLOption(netstack.DefaultTTL)) + s.Stack.SetNetworkProtocolOption(ipv6.ProtocolNumber, tcpip.DefaultTTLOption(netstack.DefaultTTL)) + + // Enable Receive Buffer Auto-Tuning. + if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil { + return nil, fmt.Errorf("SetTransportProtocolOption failed: %v", err) + } + + s.FillDefaultIPTables() + + return &s, nil +} + +// sandboxNetstackCreator implements kernel.NetworkStackCreator. +// +// +stateify savable +type sandboxNetstackCreator struct { + clock tcpip.Clock + uniqueID stack.UniqueID +} + +// CreateStack implements kernel.NetworkStackCreator.CreateStack. +func (f *sandboxNetstackCreator) CreateStack() (inet.Stack, error) { + s, err := newEmptySandboxNetworkStack(f.clock, f.uniqueID) + if err != nil { + return nil, err } + + // Setup loopback. + n := &Network{Stack: s.(*netstack.Stack).Stack} + nicID := tcpip.NICID(f.uniqueID.UniqueID()) + link := DefaultLoopbackLink + linkEP := loopback.New() + if err := n.createNICWithAddrs(nicID, link.Name, linkEP, link.Addresses); err != nil { + return nil, err + } + + return s, nil } // signal sends a signal to one or more processes in a container. If PID is 0, diff --git a/runsc/boot/network.go b/runsc/boot/network.go index 6a8765ec8..bee6ee336 100644 --- a/runsc/boot/network.go +++ b/runsc/boot/network.go @@ -17,6 +17,7 @@ package boot import ( "fmt" "net" + "strings" "syscall" "gvisor.dev/gvisor/pkg/log" @@ -31,6 +32,32 @@ import ( "gvisor.dev/gvisor/pkg/urpc" ) +var ( + // DefaultLoopbackLink contains IP addresses and routes of "127.0.0.1/8" and + // "::1/8" on "lo" interface. + DefaultLoopbackLink = LoopbackLink{ + Name: "lo", + Addresses: []net.IP{ + net.IP("\x7f\x00\x00\x01"), + net.IPv6loopback, + }, + Routes: []Route{ + { + Destination: net.IPNet{ + IP: net.IPv4(0x7f, 0, 0, 0), + Mask: net.IPv4Mask(0xff, 0, 0, 0), + }, + }, + { + Destination: net.IPNet{ + IP: net.IPv6loopback, + Mask: net.IPMask(strings.Repeat("\xff", net.IPv6len)), + }, + }, + }, + } +) + // Network exposes methods that can be used to configure a network stack. type Network struct { Stack *stack.Stack diff --git a/runsc/boot/pprof.go b/runsc/boot/pprof.go deleted file mode 100644 index 463362f02..000000000 --- a/runsc/boot/pprof.go +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package boot - -func initializePProf() { -} diff --git a/runsc/boot/pprof/BUILD b/runsc/boot/pprof/BUILD new file mode 100644 index 000000000..29cb42b2f --- /dev/null +++ b/runsc/boot/pprof/BUILD @@ -0,0 +1,11 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "pprof", + srcs = ["pprof.go"], + visibility = [ + "//runsc:__subpackages__", + ], +) diff --git a/runsc/boot/pprof/pprof.go b/runsc/boot/pprof/pprof.go new file mode 100644 index 000000000..1ded20dee --- /dev/null +++ b/runsc/boot/pprof/pprof.go @@ -0,0 +1,20 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package pprof provides a stub to initialize custom profilers. +package pprof + +// Initialize will be called at boot for initializing custom profilers. +func Initialize() { +} diff --git a/runsc/sandbox/network.go b/runsc/sandbox/network.go index 99e143696..bc093fba5 100644 --- a/runsc/sandbox/network.go +++ b/runsc/sandbox/network.go @@ -21,7 +21,6 @@ import ( "path/filepath" "runtime" "strconv" - "strings" "syscall" specs "github.com/opencontainers/runtime-spec/specs-go" @@ -75,30 +74,8 @@ func setupNetwork(conn *urpc.Client, pid int, spec *specs.Spec, conf *boot.Confi } func createDefaultLoopbackInterface(conn *urpc.Client) error { - link := boot.LoopbackLink{ - Name: "lo", - Addresses: []net.IP{ - net.IP("\x7f\x00\x00\x01"), - net.IPv6loopback, - }, - Routes: []boot.Route{ - { - Destination: net.IPNet{ - - IP: net.IPv4(0x7f, 0, 0, 0), - Mask: net.IPv4Mask(0xff, 0, 0, 0), - }, - }, - { - Destination: net.IPNet{ - IP: net.IPv6loopback, - Mask: net.IPMask(strings.Repeat("\xff", net.IPv6len)), - }, - }, - }, - } if err := conn.Call(boot.NetworkCreateLinksAndRoutes, &boot.CreateLinksAndRoutesArgs{ - LoopbackLinks: []boot.LoopbackLink{link}, + LoopbackLinks: []boot.LoopbackLink{boot.DefaultLoopbackLink}, }, nil); err != nil { return fmt.Errorf("creating loopback link and routes: %v", err) } diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index d69ac8356..d1977d4de 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -258,6 +258,8 @@ syscall_test( syscall_test(test = "//test/syscalls/linux:munmap_test") +syscall_test(test = "//test/syscalls/linux:network_namespace_test") + syscall_test( add_overlay = True, test = "//test/syscalls/linux:open_create_test", diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 05a818795..aa303af84 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -3639,6 +3639,23 @@ cc_binary( ], ) +cc_binary( + name = "network_namespace_test", + testonly = 1, + srcs = ["network_namespace.cc"], + linkstatic = 1, + deps = [ + ":socket_test_util", + gtest, + "//test/util:capability_util", + "//test/util:memory_util", + "//test/util:test_main", + "//test/util:test_util", + "//test/util:thread_util", + "@com_google_absl//absl/synchronization", + ], +) + cc_binary( name = "semaphore_test", testonly = 1, diff --git a/test/syscalls/linux/network_namespace.cc b/test/syscalls/linux/network_namespace.cc new file mode 100644 index 000000000..6ea48c263 --- /dev/null +++ b/test/syscalls/linux/network_namespace.cc @@ -0,0 +1,121 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/synchronization/notification.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/capability_util.h" +#include "test/util/memory_util.h" +#include "test/util/test_util.h" +#include "test/util/thread_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +using TestFunc = std::function; +using RunFunc = std::function; + +struct NamespaceStrategy { + RunFunc run; + + static NamespaceStrategy Of(RunFunc run) { + NamespaceStrategy s; + s.run = run; + return s; + } +}; + +PosixError RunWithUnshare(TestFunc fn) { + PosixError err = PosixError(-1, "function did not return a value"); + ScopedThread t([&] { + if (unshare(CLONE_NEWNET) != 0) { + err = PosixError(errno); + return; + } + err = fn(); + }); + t.Join(); + return err; +} + +PosixError RunWithClone(TestFunc fn) { + struct Args { + absl::Notification n; + TestFunc fn; + PosixError err; + }; + Args args; + args.fn = fn; + args.err = PosixError(-1, "function did not return a value"); + + ASSIGN_OR_RETURN_ERRNO( + Mapping child_stack, + MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE)); + pid_t child = clone( + +[](void *arg) { + Args *args = reinterpret_cast(arg); + args->err = args->fn(); + args->n.Notify(); + syscall(SYS_exit, 0); // Exit manually. No return address on stack. + return 0; + }, + reinterpret_cast(child_stack.addr() + kPageSize), + CLONE_NEWNET | CLONE_THREAD | CLONE_SIGHAND | CLONE_VM, &args); + if (child < 0) { + return PosixError(errno, "clone() failed"); + } + args.n.WaitForNotification(); + return args.err; +} + +class NetworkNamespaceTest + : public ::testing::TestWithParam {}; + +TEST_P(NetworkNamespaceTest, LoopbackExists) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + EXPECT_NO_ERRNO(GetParam().run([]() { + // TODO(gvisor.dev/issue/1833): Update this to test that only "lo" exists. + // Check loopback device exists. + int sock = socket(AF_INET, SOCK_DGRAM, 0); + if (sock < 0) { + return PosixError(errno, "socket() failed"); + } + struct ifreq ifr; + snprintf(ifr.ifr_name, IFNAMSIZ, "lo"); + if (ioctl(sock, SIOCGIFINDEX, &ifr) < 0) { + return PosixError(errno, "ioctl() failed, lo cannot be found"); + } + return NoError(); + })); +} + +INSTANTIATE_TEST_SUITE_P( + AllNetworkNamespaceTest, NetworkNamespaceTest, + ::testing::Values(NamespaceStrategy::Of(RunWithUnshare), + NamespaceStrategy::Of(RunWithClone))); + +} // namespace + +} // namespace testing +} // namespace gvisor -- cgit v1.2.3 From 97c07242c37e56f6cfdc52036d554052ba95f2ae Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Fri, 21 Feb 2020 09:53:56 -0800 Subject: Use Route.MaxHeaderLength when constructing NDP RS Test: stack_test.TestRouterSolicitation PiperOrigin-RevId: 296454766 --- pkg/tcpip/stack/ndp.go | 2 +- pkg/tcpip/stack/ndp_test.go | 78 ++++++++++++++++++++++++++++++++++++++------- 2 files changed, 68 insertions(+), 12 deletions(-) (limited to 'pkg/tcpip') diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 19bd05aa3..f651871ce 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -1235,7 +1235,7 @@ func (ndp *ndpState) startSolicitingRouters() { } payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize - hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadSize) + hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + payloadSize) pkt := header.ICMPv6(hdr.Prepend(payloadSize)) pkt.SetType(header.ICMPv6RouterSolicit) pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index f7b75b74e..6e9306d09 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -267,6 +267,17 @@ func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration s } } +// channelLinkWithHeaderLength is a channel.Endpoint with a configurable +// header length. +type channelLinkWithHeaderLength struct { + *channel.Endpoint + headerLength uint16 +} + +func (l *channelLinkWithHeaderLength) MaxHeaderLength() uint16 { + return l.headerLength +} + // Check e to make sure that the event is for addr on nic with ID 1, and the // resolved flag set to resolved with the specified err. func checkDADEvent(e ndpDADEvent, nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) string { @@ -323,21 +334,46 @@ func TestDADDisabled(t *testing.T) { // DAD for various values of DupAddrDetectTransmits and RetransmitTimer. // Included in the subtests is a test to make sure that an invalid // RetransmitTimer (<1ms) values get fixed to the default RetransmitTimer of 1s. +// This tests also validates the NDP NS packet that is transmitted. func TestDADResolve(t *testing.T) { const nicID = 1 tests := []struct { name string + linkHeaderLen uint16 dupAddrDetectTransmits uint8 retransTimer time.Duration expectedRetransmitTimer time.Duration }{ - {"1:1s:1s", 1, time.Second, time.Second}, - {"2:1s:1s", 2, time.Second, time.Second}, - {"1:2s:2s", 1, 2 * time.Second, 2 * time.Second}, + { + name: "1:1s:1s", + dupAddrDetectTransmits: 1, + retransTimer: time.Second, + expectedRetransmitTimer: time.Second, + }, + { + name: "2:1s:1s", + linkHeaderLen: 1, + dupAddrDetectTransmits: 2, + retransTimer: time.Second, + expectedRetransmitTimer: time.Second, + }, + { + name: "1:2s:2s", + linkHeaderLen: 2, + dupAddrDetectTransmits: 1, + retransTimer: 2 * time.Second, + expectedRetransmitTimer: 2 * time.Second, + }, // 0s is an invalid RetransmitTimer timer and will be fixed to // the default RetransmitTimer value of 1s. - {"1:0s:1s", 1, 0, time.Second}, + { + name: "1:0s:1s", + linkHeaderLen: 3, + dupAddrDetectTransmits: 1, + retransTimer: 0, + expectedRetransmitTimer: time.Second, + }, } for _, test := range tests { @@ -356,10 +392,13 @@ func TestDADResolve(t *testing.T) { opts.NDPConfigs.RetransmitTimer = test.retransTimer opts.NDPConfigs.DupAddrDetectTransmits = test.dupAddrDetectTransmits - e := channel.New(int(test.dupAddrDetectTransmits), 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + e := channelLinkWithHeaderLength{ + Endpoint: channel.New(int(test.dupAddrDetectTransmits), 1280, linkAddr1), + headerLength: test.linkHeaderLen, + } + e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired s := stack.New(opts) - if err := s.CreateNIC(nicID, e); err != nil { + if err := s.CreateNIC(nicID, &e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -445,6 +484,10 @@ func TestDADResolve(t *testing.T) { checker.NDPNSTargetAddress(addr1), checker.NDPNSOptions(nil), )) + + if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want { + t.Errorf("got p.Pkt.Header.AvailableLength() = %d; want = %d", l, want) + } } }) } @@ -3336,8 +3379,11 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { func TestRouterSolicitation(t *testing.T) { t.Parallel() + const nicID = 1 + tests := []struct { name string + linkHeaderLen uint16 maxRtrSolicit uint8 rtrSolicitInt time.Duration effectiveRtrSolicitInt time.Duration @@ -3354,6 +3400,7 @@ func TestRouterSolicitation(t *testing.T) { }, { name: "Two RS with delay", + linkHeaderLen: 1, maxRtrSolicit: 2, rtrSolicitInt: time.Second, effectiveRtrSolicitInt: time.Second, @@ -3362,6 +3409,7 @@ func TestRouterSolicitation(t *testing.T) { }, { name: "Single RS without delay", + linkHeaderLen: 2, maxRtrSolicit: 1, rtrSolicitInt: time.Second, effectiveRtrSolicitInt: time.Second, @@ -3370,6 +3418,7 @@ func TestRouterSolicitation(t *testing.T) { }, { name: "Two RS without delay and invalid zero interval", + linkHeaderLen: 3, maxRtrSolicit: 2, rtrSolicitInt: 0, effectiveRtrSolicitInt: 4 * time.Second, @@ -3407,8 +3456,11 @@ func TestRouterSolicitation(t *testing.T) { t.Run(test.name, func(t *testing.T) { t.Parallel() - e := channel.New(int(test.maxRtrSolicit), 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + e := channelLinkWithHeaderLength{ + Endpoint: channel.New(int(test.maxRtrSolicit), 1280, linkAddr1), + headerLength: test.linkHeaderLen, + } + e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired waitForPkt := func(timeout time.Duration) { t.Helper() ctx, _ := context.WithTimeout(context.Background(), timeout) @@ -3434,6 +3486,10 @@ func TestRouterSolicitation(t *testing.T) { checker.TTL(header.NDPHopLimit), checker.NDPRS(), ) + + if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want { + t.Errorf("got p.Pkt.Header.AvailableLength() = %d; want = %d", l, want) + } } waitForNothing := func(timeout time.Duration) { t.Helper() @@ -3450,8 +3506,8 @@ func TestRouterSolicitation(t *testing.T) { MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, }, }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) + if err := s.CreateNIC(nicID, &e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } // Make sure each RS got sent at the right -- cgit v1.2.3 From a155a23480abfafe096ff50f2c4aaf2c215b6c44 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Fri, 21 Feb 2020 11:20:15 -0800 Subject: Attach LinkEndpoint to NetworkDispatcher immediately Tests: stack_test.TestAttachToLinkEndpointImmediately PiperOrigin-RevId: 296474068 --- pkg/tcpip/stack/nic.go | 25 +++++++------------ pkg/tcpip/stack/stack.go | 5 ++-- pkg/tcpip/stack/stack_test.go | 56 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 68 insertions(+), 18 deletions(-) (limited to 'pkg/tcpip') diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index b2be18e47..862954ab2 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -44,8 +44,7 @@ type NIC struct { linkEP LinkEndpoint context NICContext - stats NICStats - attach sync.Once + stats NICStats mu struct { sync.RWMutex @@ -141,6 +140,8 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC nic.mu.packetEPs[netProto.Number()] = []PacketEndpoint{} } + nic.linkEP.Attach(nic) + return nic } @@ -200,14 +201,16 @@ func (n *NIC) disable() *tcpip.Error { } } - // TODO(b/147015577): Should n detach from its LinkEndpoint? - n.mu.enabled = false return nil } -// enable enables n. enable will attach the nic to its LinkEndpoint and -// join the IPv6 All-Nodes Multicast address (ff02::1). +// enable enables n. +// +// If the stack has IPv6 enabled, enable will join the IPv6 All-Nodes Multicast +// address (ff02::1), start DAD for permanent addresses, and start soliciting +// routers if the stack is not operating as a router. If the stack is also +// configured to auto-generate a link-local address, one will be generated. func (n *NIC) enable() *tcpip.Error { n.mu.RLock() enabled := n.mu.enabled @@ -225,8 +228,6 @@ func (n *NIC) enable() *tcpip.Error { n.mu.enabled = true - n.attachLinkEndpoint() - // Create an endpoint to receive broadcast packets on this interface. if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok { if _, err := n.addAddressLocked(ipv4BroadcastAddr, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil { @@ -321,14 +322,6 @@ func (n *NIC) becomeIPv6Host() { n.mu.ndp.startSolicitingRouters() } -// attachLinkEndpoint attaches the NIC to the endpoint, which will enable it -// to start delivering packets. -func (n *NIC) attachLinkEndpoint() { - n.attach.Do(func() { - n.linkEP.Attach(n) - }) -} - // setPromiscuousMode enables or disables promiscuous mode. func (n *NIC) setPromiscuousMode(enable bool) { n.mu.Lock() diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index fabc976a7..f0ed76fbe 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -881,6 +881,8 @@ type NICOptions struct { // CreateNICWithOptions creates a NIC with the provided id, LinkEndpoint, and // NICOptions. See the documentation on type NICOptions for details on how // NICs can be configured. +// +// LinkEndpoint.Attach will be called to bind ep with a NetworkDispatcher. func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOptions) *tcpip.Error { s.mu.Lock() defer s.mu.Unlock() @@ -900,7 +902,6 @@ func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOp } n := newNIC(s, id, opts.Name, ep, opts.Context) - s.nics[id] = n if !opts.Disabled { return n.enable() @@ -910,7 +911,7 @@ func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOp } // CreateNIC creates a NIC with the provided id and LinkEndpoint and calls -// `LinkEndpoint.Attach` to start delivering packets to it. +// LinkEndpoint.Attach to bind ep with a NetworkDispatcher. func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error { return s.CreateNICWithOptions(id, ep, NICOptions{}) } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index eb6f7d1fc..18016e7db 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -239,6 +239,23 @@ func fakeNetFactory() stack.NetworkProtocol { return &fakeNetworkProtocol{} } +// linkEPWithMockedAttach is a stack.LinkEndpoint that tests can use to verify +// that LinkEndpoint.Attach was called. +type linkEPWithMockedAttach struct { + stack.LinkEndpoint + attached bool +} + +// Attach implements stack.LinkEndpoint.Attach. +func (l *linkEPWithMockedAttach) Attach(d stack.NetworkDispatcher) { + l.LinkEndpoint.Attach(d) + l.attached = true +} + +func (l *linkEPWithMockedAttach) isAttached() bool { + return l.attached +} + func TestNetworkReceive(t *testing.T) { // Create a stack with the fake network protocol, one nic, and two // addresses attached to it: 1 & 2. @@ -510,6 +527,45 @@ func testNoRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr } } +// TestAttachToLinkEndpointImmediately tests that a LinkEndpoint is attached to +// a NetworkDispatcher when the NIC is created. +func TestAttachToLinkEndpointImmediately(t *testing.T) { + const nicID = 1 + + tests := []struct { + name string + nicOpts stack.NICOptions + }{ + { + name: "Create enabled NIC", + nicOpts: stack.NICOptions{Disabled: false}, + }, + { + name: "Create disabled NIC", + nicOpts: stack.NICOptions{Disabled: true}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + + e := linkEPWithMockedAttach{ + LinkEndpoint: loopback.New(), + } + + if err := s.CreateNICWithOptions(nicID, &e, test.nicOpts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, test.nicOpts, err) + } + if !e.isAttached() { + t.Fatalf("link endpoint not attached to a network disatcher") + } + }) + } +} + func TestDisableUnknownNIC(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, -- cgit v1.2.3 From b8f56c79be40d9c75f4e2f279c9d821d1c1c3569 Mon Sep 17 00:00:00 2001 From: Ting-Yu Wang Date: Fri, 21 Feb 2020 15:41:56 -0800 Subject: Implement tap/tun device in vfs. PiperOrigin-RevId: 296526279 --- pkg/abi/linux/BUILD | 1 + pkg/abi/linux/ioctl.go | 26 ++ pkg/abi/linux/ioctl_tun.go | 29 ++ pkg/sentry/fs/dev/BUILD | 5 + pkg/sentry/fs/dev/dev.go | 10 +- pkg/sentry/fs/dev/net_tun.go | 170 +++++++++++ pkg/syserror/syserror.go | 1 + pkg/tcpip/buffer/view.go | 6 + pkg/tcpip/link/channel/BUILD | 1 + pkg/tcpip/link/channel/channel.go | 180 +++++++++--- pkg/tcpip/link/tun/BUILD | 18 +- pkg/tcpip/link/tun/device.go | 352 +++++++++++++++++++++++ pkg/tcpip/link/tun/protocol.go | 56 ++++ pkg/tcpip/stack/nic.go | 32 +++ pkg/tcpip/stack/stack.go | 39 +++ test/syscalls/BUILD | 2 + test/syscalls/linux/BUILD | 30 ++ test/syscalls/linux/dev.cc | 7 + test/syscalls/linux/socket_netlink_route_util.cc | 163 +++++++++++ test/syscalls/linux/socket_netlink_route_util.h | 55 ++++ test/syscalls/linux/tuntap.cc | 346 ++++++++++++++++++++++ 21 files changed, 1490 insertions(+), 39 deletions(-) create mode 100644 pkg/abi/linux/ioctl_tun.go create mode 100644 pkg/sentry/fs/dev/net_tun.go create mode 100644 pkg/tcpip/link/tun/device.go create mode 100644 pkg/tcpip/link/tun/protocol.go create mode 100644 test/syscalls/linux/socket_netlink_route_util.cc create mode 100644 test/syscalls/linux/socket_netlink_route_util.h create mode 100644 test/syscalls/linux/tuntap.cc (limited to 'pkg/tcpip') diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index a89f34d4b..322d1ccc4 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -30,6 +30,7 @@ go_library( "futex.go", "inotify.go", "ioctl.go", + "ioctl_tun.go", "ip.go", "ipc.go", "limits.go", diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go index 0e18db9ef..2062e6a4b 100644 --- a/pkg/abi/linux/ioctl.go +++ b/pkg/abi/linux/ioctl.go @@ -72,3 +72,29 @@ const ( SIOCGMIIPHY = 0x8947 SIOCGMIIREG = 0x8948 ) + +// ioctl(2) directions. Used to calculate requests number. +// Constants from asm-generic/ioctl.h. +const ( + _IOC_NONE = 0 + _IOC_WRITE = 1 + _IOC_READ = 2 +) + +// Constants from asm-generic/ioctl.h. +const ( + _IOC_NRBITS = 8 + _IOC_TYPEBITS = 8 + _IOC_SIZEBITS = 14 + _IOC_DIRBITS = 2 + + _IOC_NRSHIFT = 0 + _IOC_TYPESHIFT = _IOC_NRSHIFT + _IOC_NRBITS + _IOC_SIZESHIFT = _IOC_TYPESHIFT + _IOC_TYPEBITS + _IOC_DIRSHIFT = _IOC_SIZESHIFT + _IOC_SIZEBITS +) + +// IOC outputs the result of _IOC macro in asm-generic/ioctl.h. +func IOC(dir, typ, nr, size uint32) uint32 { + return uint32(dir)<<_IOC_DIRSHIFT | typ<<_IOC_TYPESHIFT | nr<<_IOC_NRSHIFT | size<<_IOC_SIZESHIFT +} diff --git a/pkg/abi/linux/ioctl_tun.go b/pkg/abi/linux/ioctl_tun.go new file mode 100644 index 000000000..c59c9c136 --- /dev/null +++ b/pkg/abi/linux/ioctl_tun.go @@ -0,0 +1,29 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +// ioctl(2) request numbers from linux/if_tun.h +var ( + TUNSETIFF = IOC(_IOC_WRITE, 'T', 202, 4) + TUNGETIFF = IOC(_IOC_READ, 'T', 210, 4) +) + +// Flags from net/if_tun.h +const ( + IFF_TUN = 0x0001 + IFF_TAP = 0x0002 + IFF_NO_PI = 0x1000 + IFF_NOFILTER = 0x1000 +) diff --git a/pkg/sentry/fs/dev/BUILD b/pkg/sentry/fs/dev/BUILD index 4c4b7d5cc..9b6bb26d0 100644 --- a/pkg/sentry/fs/dev/BUILD +++ b/pkg/sentry/fs/dev/BUILD @@ -9,6 +9,7 @@ go_library( "device.go", "fs.go", "full.go", + "net_tun.go", "null.go", "random.go", "tty.go", @@ -19,15 +20,19 @@ go_library( "//pkg/context", "//pkg/rand", "//pkg/safemem", + "//pkg/sentry/arch", "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", "//pkg/sentry/fs/ramfs", "//pkg/sentry/fs/tmpfs", + "//pkg/sentry/kernel", "//pkg/sentry/memmap", "//pkg/sentry/mm", "//pkg/sentry/pgalloc", + "//pkg/sentry/socket/netstack", "//pkg/syserror", + "//pkg/tcpip/link/tun", "//pkg/usermem", "//pkg/waiter", ], diff --git a/pkg/sentry/fs/dev/dev.go b/pkg/sentry/fs/dev/dev.go index 35bd23991..7e66c29b0 100644 --- a/pkg/sentry/fs/dev/dev.go +++ b/pkg/sentry/fs/dev/dev.go @@ -66,8 +66,8 @@ func newMemDevice(ctx context.Context, iops fs.InodeOperations, msrc *fs.MountSo }) } -func newDirectory(ctx context.Context, msrc *fs.MountSource) *fs.Inode { - iops := ramfs.NewDir(ctx, nil, fs.RootOwner, fs.FilePermsFromMode(0555)) +func newDirectory(ctx context.Context, contents map[string]*fs.Inode, msrc *fs.MountSource) *fs.Inode { + iops := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555)) return fs.NewInode(ctx, iops, msrc, fs.StableAttr{ DeviceID: devDevice.DeviceID(), InodeID: devDevice.NextIno(), @@ -111,7 +111,7 @@ func New(ctx context.Context, msrc *fs.MountSource) *fs.Inode { // A devpts is typically mounted at /dev/pts to provide // pseudoterminal support. Place an empty directory there for // the devpts to be mounted over. - "pts": newDirectory(ctx, msrc), + "pts": newDirectory(ctx, nil, msrc), // Similarly, applications expect a ptmx device at /dev/ptmx // connected to the terminals provided by /dev/pts/. Rather // than creating a device directly (which requires a hairy @@ -124,6 +124,10 @@ func New(ctx context.Context, msrc *fs.MountSource) *fs.Inode { "ptmx": newSymlink(ctx, "pts/ptmx", msrc), "tty": newCharacterDevice(ctx, newTTYDevice(ctx, fs.RootOwner, 0666), msrc, ttyDevMajor, ttyDevMinor), + + "net": newDirectory(ctx, map[string]*fs.Inode{ + "tun": newCharacterDevice(ctx, newNetTunDevice(ctx, fs.RootOwner, 0666), msrc, netTunDevMajor, netTunDevMinor), + }, msrc), } iops := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555)) diff --git a/pkg/sentry/fs/dev/net_tun.go b/pkg/sentry/fs/dev/net_tun.go new file mode 100644 index 000000000..755644488 --- /dev/null +++ b/pkg/sentry/fs/dev/net_tun.go @@ -0,0 +1,170 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/socket/netstack" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip/link/tun" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/waiter" +) + +const ( + netTunDevMajor = 10 + netTunDevMinor = 200 +) + +// +stateify savable +type netTunInodeOperations struct { + fsutil.InodeGenericChecker `state:"nosave"` + fsutil.InodeNoExtendedAttributes `state:"nosave"` + fsutil.InodeNoopAllocate `state:"nosave"` + fsutil.InodeNoopRelease `state:"nosave"` + fsutil.InodeNoopTruncate `state:"nosave"` + fsutil.InodeNoopWriteOut `state:"nosave"` + fsutil.InodeNotDirectory `state:"nosave"` + fsutil.InodeNotMappable `state:"nosave"` + fsutil.InodeNotSocket `state:"nosave"` + fsutil.InodeNotSymlink `state:"nosave"` + fsutil.InodeVirtual `state:"nosave"` + + fsutil.InodeSimpleAttributes +} + +var _ fs.InodeOperations = (*netTunInodeOperations)(nil) + +func newNetTunDevice(ctx context.Context, owner fs.FileOwner, mode linux.FileMode) *netTunInodeOperations { + return &netTunInodeOperations{ + InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, owner, fs.FilePermsFromMode(mode), linux.TMPFS_MAGIC), + } +} + +// GetFile implements fs.InodeOperations.GetFile. +func (iops *netTunInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { + return fs.NewFile(ctx, d, flags, &netTunFileOperations{}), nil +} + +// +stateify savable +type netTunFileOperations struct { + fsutil.FileNoSeek `state:"nosave"` + fsutil.FileNoMMap `state:"nosave"` + fsutil.FileNoSplice `state:"nosave"` + fsutil.FileNoopFlush `state:"nosave"` + fsutil.FileNoopFsync `state:"nosave"` + fsutil.FileNotDirReaddir `state:"nosave"` + fsutil.FileUseInodeUnstableAttr `state:"nosave"` + + device tun.Device +} + +var _ fs.FileOperations = (*netTunFileOperations)(nil) + +// Release implements fs.FileOperations.Release. +func (fops *netTunFileOperations) Release() { + fops.device.Release() +} + +// Ioctl implements fs.FileOperations.Ioctl. +func (fops *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + request := args[1].Uint() + data := args[2].Pointer() + + switch request { + case linux.TUNSETIFF: + t := kernel.TaskFromContext(ctx) + if t == nil { + panic("Ioctl should be called from a task context") + } + if !t.HasCapability(linux.CAP_NET_ADMIN) { + return 0, syserror.EPERM + } + stack, ok := t.NetworkContext().(*netstack.Stack) + if !ok { + return 0, syserror.EINVAL + } + + var req linux.IFReq + if _, err := usermem.CopyObjectIn(ctx, io, data, &req, usermem.IOOpts{ + AddressSpaceActive: true, + }); err != nil { + return 0, err + } + flags := usermem.ByteOrder.Uint16(req.Data[:]) + return 0, fops.device.SetIff(stack.Stack, req.Name(), flags) + + case linux.TUNGETIFF: + var req linux.IFReq + + copy(req.IFName[:], fops.device.Name()) + + // Linux adds IFF_NOFILTER (the same value as IFF_NO_PI unfortunately) when + // there is no sk_filter. See __tun_chr_ioctl() in net/drivers/tun.c. + flags := fops.device.Flags() | linux.IFF_NOFILTER + usermem.ByteOrder.PutUint16(req.Data[:], flags) + + _, err := usermem.CopyObjectOut(ctx, io, data, &req, usermem.IOOpts{ + AddressSpaceActive: true, + }) + return 0, err + + default: + return 0, syserror.ENOTTY + } +} + +// Write implements fs.FileOperations.Write. +func (fops *netTunFileOperations) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) { + data := make([]byte, src.NumBytes()) + if _, err := src.CopyIn(ctx, data); err != nil { + return 0, err + } + return fops.device.Write(data) +} + +// Read implements fs.FileOperations.Read. +func (fops *netTunFileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { + data, err := fops.device.Read() + if err != nil { + return 0, err + } + n, err := dst.CopyOut(ctx, data) + if n > 0 && n < len(data) { + // Not an error for partial copying. Packet truncated. + err = nil + } + return int64(n), err +} + +// Readiness implements watier.Waitable.Readiness. +func (fops *netTunFileOperations) Readiness(mask waiter.EventMask) waiter.EventMask { + return fops.device.Readiness(mask) +} + +// EventRegister implements watier.Waitable.EventRegister. +func (fops *netTunFileOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) { + fops.device.EventRegister(e, mask) +} + +// EventUnregister implements watier.Waitable.EventUnregister. +func (fops *netTunFileOperations) EventUnregister(e *waiter.Entry) { + fops.device.EventUnregister(e) +} diff --git a/pkg/syserror/syserror.go b/pkg/syserror/syserror.go index 2269f6237..4b5a0fca6 100644 --- a/pkg/syserror/syserror.go +++ b/pkg/syserror/syserror.go @@ -29,6 +29,7 @@ var ( EACCES = error(syscall.EACCES) EAGAIN = error(syscall.EAGAIN) EBADF = error(syscall.EBADF) + EBADFD = error(syscall.EBADFD) EBUSY = error(syscall.EBUSY) ECHILD = error(syscall.ECHILD) ECONNREFUSED = error(syscall.ECONNREFUSED) diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index 150310c11..17e94c562 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -156,3 +156,9 @@ func (vv *VectorisedView) Append(vv2 VectorisedView) { vv.views = append(vv.views, vv2.views...) vv.size += vv2.size } + +// AppendView appends the given view into this vectorised view. +func (vv *VectorisedView) AppendView(v View) { + vv.views = append(vv.views, v) + vv.size += len(v) +} diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD index 3974c464e..b8b93e78e 100644 --- a/pkg/tcpip/link/channel/BUILD +++ b/pkg/tcpip/link/channel/BUILD @@ -7,6 +7,7 @@ go_library( srcs = ["channel.go"], visibility = ["//visibility:public"], deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/stack", diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index 78d447acd..5944ba190 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -20,6 +20,7 @@ package channel import ( "context" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -33,6 +34,118 @@ type PacketInfo struct { Route stack.Route } +// Notification is the interface for receiving notification from the packet +// queue. +type Notification interface { + // WriteNotify will be called when a write happens to the queue. + WriteNotify() +} + +// NotificationHandle is an opaque handle to the registered notification target. +// It can be used to unregister the notification when no longer interested. +// +// +stateify savable +type NotificationHandle struct { + n Notification +} + +type queue struct { + // mu protects fields below. + mu sync.RWMutex + // c is the outbound packet channel. Sending to c should hold mu. + c chan PacketInfo + numWrite int + numRead int + notify []*NotificationHandle +} + +func (q *queue) Close() { + close(q.c) +} + +func (q *queue) Read() (PacketInfo, bool) { + q.mu.Lock() + defer q.mu.Unlock() + select { + case p := <-q.c: + q.numRead++ + return p, true + default: + return PacketInfo{}, false + } +} + +func (q *queue) ReadContext(ctx context.Context) (PacketInfo, bool) { + // We have to receive from channel without holding the lock, since it can + // block indefinitely. This will cause a window that numWrite - numRead + // produces a larger number, but won't go to negative. numWrite >= numRead + // still holds. + select { + case pkt := <-q.c: + q.mu.Lock() + defer q.mu.Unlock() + q.numRead++ + return pkt, true + case <-ctx.Done(): + return PacketInfo{}, false + } +} + +func (q *queue) Write(p PacketInfo) bool { + wrote := false + + // It's important to make sure nobody can see numWrite until we increment it, + // so numWrite >= numRead holds. + q.mu.Lock() + select { + case q.c <- p: + wrote = true + q.numWrite++ + default: + } + notify := q.notify + q.mu.Unlock() + + if wrote { + // Send notification outside of lock. + for _, h := range notify { + h.n.WriteNotify() + } + } + return wrote +} + +func (q *queue) Num() int { + q.mu.RLock() + defer q.mu.RUnlock() + n := q.numWrite - q.numRead + if n < 0 { + panic("numWrite < numRead") + } + return n +} + +func (q *queue) AddNotify(notify Notification) *NotificationHandle { + q.mu.Lock() + defer q.mu.Unlock() + h := &NotificationHandle{n: notify} + q.notify = append(q.notify, h) + return h +} + +func (q *queue) RemoveNotify(handle *NotificationHandle) { + q.mu.Lock() + defer q.mu.Unlock() + // Make a copy, since we reads the array outside of lock when notifying. + notify := make([]*NotificationHandle, 0, len(q.notify)) + for _, h := range q.notify { + if h != handle { + notify = append(notify, h) + } + } + q.notify = notify +} + // Endpoint is link layer endpoint that stores outbound packets in a channel // and allows injection of inbound packets. type Endpoint struct { @@ -41,14 +154,16 @@ type Endpoint struct { linkAddr tcpip.LinkAddress LinkEPCapabilities stack.LinkEndpointCapabilities - // c is where outbound packets are queued. - c chan PacketInfo + // Outbound packet queue. + q *queue } // New creates a new channel endpoint. func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) *Endpoint { return &Endpoint{ - c: make(chan PacketInfo, size), + q: &queue{ + c: make(chan PacketInfo, size), + }, mtu: mtu, linkAddr: linkAddr, } @@ -57,43 +172,36 @@ func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) *Endpoint { // Close closes e. Further packet injections will panic. Reads continue to // succeed until all packets are read. func (e *Endpoint) Close() { - close(e.c) + e.q.Close() } -// Read does non-blocking read for one packet from the outbound packet queue. +// Read does non-blocking read one packet from the outbound packet queue. func (e *Endpoint) Read() (PacketInfo, bool) { - select { - case pkt := <-e.c: - return pkt, true - default: - return PacketInfo{}, false - } + return e.q.Read() } // ReadContext does blocking read for one packet from the outbound packet queue. // It can be cancelled by ctx, and in this case, it returns false. func (e *Endpoint) ReadContext(ctx context.Context) (PacketInfo, bool) { - select { - case pkt := <-e.c: - return pkt, true - case <-ctx.Done(): - return PacketInfo{}, false - } + return e.q.ReadContext(ctx) } // Drain removes all outbound packets from the channel and counts them. func (e *Endpoint) Drain() int { c := 0 for { - select { - case <-e.c: - c++ - default: + if _, ok := e.Read(); !ok { return c } + c++ } } +// NumQueued returns the number of packet queued for outbound. +func (e *Endpoint) NumQueued() int { + return e.q.Num() +} + // InjectInbound injects an inbound packet. func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) { e.InjectLinkAddr(protocol, "", pkt) @@ -155,10 +263,7 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne Route: route, } - select { - case e.c <- p: - default: - } + e.q.Write(p) return nil } @@ -171,7 +276,6 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.Pac route.Release() payloadView := pkts[0].Data.ToView() n := 0 -packetLoop: for _, pkt := range pkts { off := pkt.DataOffset size := pkt.DataSize @@ -185,12 +289,10 @@ packetLoop: Route: route, } - select { - case e.c <- p: - n++ - default: - break packetLoop + if !e.q.Write(p) { + break } + n++ } return n, nil @@ -204,13 +306,21 @@ func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { GSO: nil, } - select { - case e.c <- p: - default: - } + e.q.Write(p) return nil } // Wait implements stack.LinkEndpoint.Wait. func (*Endpoint) Wait() {} + +// AddNotify adds a notification target for receiving event about outgoing +// packets. +func (e *Endpoint) AddNotify(notify Notification) *NotificationHandle { + return e.q.AddNotify(notify) +} + +// RemoveNotify removes handle from the list of notification targets. +func (e *Endpoint) RemoveNotify(handle *NotificationHandle) { + e.q.RemoveNotify(handle) +} diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD index e5096ea38..e0db6cf54 100644 --- a/pkg/tcpip/link/tun/BUILD +++ b/pkg/tcpip/link/tun/BUILD @@ -4,6 +4,22 @@ package(licenses = ["notice"]) go_library( name = "tun", - srcs = ["tun_unsafe.go"], + srcs = [ + "device.go", + "protocol.go", + "tun_unsafe.go", + ], visibility = ["//visibility:public"], + deps = [ + "//pkg/abi/linux", + "//pkg/refs", + "//pkg/sync", + "//pkg/syserror", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", + "//pkg/tcpip/stack", + "//pkg/waiter", + ], ) diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go new file mode 100644 index 000000000..6ff47a742 --- /dev/null +++ b/pkg/tcpip/link/tun/device.go @@ -0,0 +1,352 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tun + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/waiter" +) + +const ( + // drivers/net/tun.c:tun_net_init() + defaultDevMtu = 1500 + + // Queue length for outbound packet, arriving at fd side for read. Overflow + // causes packet drops. gVisor implementation-specific. + defaultDevOutQueueLen = 1024 +) + +var zeroMAC [6]byte + +// Device is an opened /dev/net/tun device. +// +// +stateify savable +type Device struct { + waiter.Queue + + mu sync.RWMutex `state:"nosave"` + endpoint *tunEndpoint + notifyHandle *channel.NotificationHandle + flags uint16 +} + +// beforeSave is invoked by stateify. +func (d *Device) beforeSave() { + d.mu.Lock() + defer d.mu.Unlock() + // TODO(b/110961832): Restore the device to stack. At this moment, the stack + // is not savable. + if d.endpoint != nil { + panic("/dev/net/tun does not support save/restore when a device is associated with it.") + } +} + +// Release implements fs.FileOperations.Release. +func (d *Device) Release() { + d.mu.Lock() + defer d.mu.Unlock() + + // Decrease refcount if there is an endpoint associated with this file. + if d.endpoint != nil { + d.endpoint.RemoveNotify(d.notifyHandle) + d.endpoint.DecRef() + d.endpoint = nil + } +} + +// SetIff services TUNSETIFF ioctl(2) request. +func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error { + d.mu.Lock() + defer d.mu.Unlock() + + if d.endpoint != nil { + return syserror.EINVAL + } + + // Input validations. + isTun := flags&linux.IFF_TUN != 0 + isTap := flags&linux.IFF_TAP != 0 + supportedFlags := uint16(linux.IFF_TUN | linux.IFF_TAP | linux.IFF_NO_PI) + if isTap && isTun || !isTap && !isTun || flags&^supportedFlags != 0 { + return syserror.EINVAL + } + + prefix := "tun" + if isTap { + prefix = "tap" + } + + endpoint, err := attachOrCreateNIC(s, name, prefix) + if err != nil { + return syserror.EINVAL + } + + d.endpoint = endpoint + d.notifyHandle = d.endpoint.AddNotify(d) + d.flags = flags + return nil +} + +func attachOrCreateNIC(s *stack.Stack, name, prefix string) (*tunEndpoint, error) { + for { + // 1. Try to attach to an existing NIC. + if name != "" { + if nic, found := s.GetNICByName(name); found { + endpoint, ok := nic.LinkEndpoint().(*tunEndpoint) + if !ok { + // Not a NIC created by tun device. + return nil, syserror.EOPNOTSUPP + } + if !endpoint.TryIncRef() { + // Race detected: NIC got deleted in between. + continue + } + return endpoint, nil + } + } + + // 2. Creating a new NIC. + id := tcpip.NICID(s.UniqueID()) + endpoint := &tunEndpoint{ + Endpoint: channel.New(defaultDevOutQueueLen, defaultDevMtu, ""), + stack: s, + nicID: id, + name: name, + } + if endpoint.name == "" { + endpoint.name = fmt.Sprintf("%s%d", prefix, id) + } + err := s.CreateNICWithOptions(endpoint.nicID, endpoint, stack.NICOptions{ + Name: endpoint.name, + }) + switch err { + case nil: + return endpoint, nil + case tcpip.ErrDuplicateNICID: + // Race detected: A NIC has been created in between. + continue + default: + return nil, syserror.EINVAL + } + } +} + +// Write inject one inbound packet to the network interface. +func (d *Device) Write(data []byte) (int64, error) { + d.mu.RLock() + endpoint := d.endpoint + d.mu.RUnlock() + if endpoint == nil { + return 0, syserror.EBADFD + } + if !endpoint.IsAttached() { + return 0, syserror.EIO + } + + dataLen := int64(len(data)) + + // Packet information. + var pktInfoHdr PacketInfoHeader + if !d.hasFlags(linux.IFF_NO_PI) { + if len(data) < PacketInfoHeaderSize { + // Ignore bad packet. + return dataLen, nil + } + pktInfoHdr = PacketInfoHeader(data[:PacketInfoHeaderSize]) + data = data[PacketInfoHeaderSize:] + } + + // Ethernet header (TAP only). + var ethHdr header.Ethernet + if d.hasFlags(linux.IFF_TAP) { + if len(data) < header.EthernetMinimumSize { + // Ignore bad packet. + return dataLen, nil + } + ethHdr = header.Ethernet(data[:header.EthernetMinimumSize]) + data = data[header.EthernetMinimumSize:] + } + + // Try to determine network protocol number, default zero. + var protocol tcpip.NetworkProtocolNumber + switch { + case pktInfoHdr != nil: + protocol = pktInfoHdr.Protocol() + case ethHdr != nil: + protocol = ethHdr.Type() + } + + // Try to determine remote link address, default zero. + var remote tcpip.LinkAddress + switch { + case ethHdr != nil: + remote = ethHdr.SourceAddress() + default: + remote = tcpip.LinkAddress(zeroMAC[:]) + } + + pkt := tcpip.PacketBuffer{ + Data: buffer.View(data).ToVectorisedView(), + } + if ethHdr != nil { + pkt.LinkHeader = buffer.View(ethHdr) + } + endpoint.InjectLinkAddr(protocol, remote, pkt) + return dataLen, nil +} + +// Read reads one outgoing packet from the network interface. +func (d *Device) Read() ([]byte, error) { + d.mu.RLock() + endpoint := d.endpoint + d.mu.RUnlock() + if endpoint == nil { + return nil, syserror.EBADFD + } + + for { + info, ok := endpoint.Read() + if !ok { + return nil, syserror.ErrWouldBlock + } + + v, ok := d.encodePkt(&info) + if !ok { + // Ignore unsupported packet. + continue + } + return v, nil + } +} + +// encodePkt encodes packet for fd side. +func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) { + var vv buffer.VectorisedView + + // Packet information. + if !d.hasFlags(linux.IFF_NO_PI) { + hdr := make(PacketInfoHeader, PacketInfoHeaderSize) + hdr.Encode(&PacketInfoFields{ + Protocol: info.Proto, + }) + vv.AppendView(buffer.View(hdr)) + } + + // If the packet does not already have link layer header, and the route + // does not exist, we can't compute it. This is possibly a raw packet, tun + // device doesn't support this at the moment. + if info.Pkt.LinkHeader == nil && info.Route.RemoteLinkAddress == "" { + return nil, false + } + + // Ethernet header (TAP only). + if d.hasFlags(linux.IFF_TAP) { + // Add ethernet header if not provided. + if info.Pkt.LinkHeader == nil { + hdr := &header.EthernetFields{ + SrcAddr: info.Route.LocalLinkAddress, + DstAddr: info.Route.RemoteLinkAddress, + Type: info.Proto, + } + if hdr.SrcAddr == "" { + hdr.SrcAddr = d.endpoint.LinkAddress() + } + + eth := make(header.Ethernet, header.EthernetMinimumSize) + eth.Encode(hdr) + vv.AppendView(buffer.View(eth)) + } else { + vv.AppendView(info.Pkt.LinkHeader) + } + } + + // Append upper headers. + vv.AppendView(buffer.View(info.Pkt.Header.View()[len(info.Pkt.LinkHeader):])) + // Append data payload. + vv.Append(info.Pkt.Data) + + return vv.ToView(), true +} + +// Name returns the name of the attached network interface. Empty string if +// unattached. +func (d *Device) Name() string { + d.mu.RLock() + defer d.mu.RUnlock() + if d.endpoint != nil { + return d.endpoint.name + } + return "" +} + +// Flags returns the flags set for d. Zero value if unset. +func (d *Device) Flags() uint16 { + d.mu.RLock() + defer d.mu.RUnlock() + return d.flags +} + +func (d *Device) hasFlags(flags uint16) bool { + return d.flags&flags == flags +} + +// Readiness implements watier.Waitable.Readiness. +func (d *Device) Readiness(mask waiter.EventMask) waiter.EventMask { + if mask&waiter.EventIn != 0 { + d.mu.RLock() + endpoint := d.endpoint + d.mu.RUnlock() + if endpoint != nil && endpoint.NumQueued() == 0 { + mask &= ^waiter.EventIn + } + } + return mask & (waiter.EventIn | waiter.EventOut) +} + +// WriteNotify implements channel.Notification.WriteNotify. +func (d *Device) WriteNotify() { + d.Notify(waiter.EventIn) +} + +// tunEndpoint is the link endpoint for the NIC created by the tun device. +// +// It is ref-counted as multiple opening files can attach to the same NIC. +// The last owner is responsible for deleting the NIC. +type tunEndpoint struct { + *channel.Endpoint + + refs.AtomicRefCount + + stack *stack.Stack + nicID tcpip.NICID + name string +} + +// DecRef decrements refcount of e, removes NIC if refcount goes to 0. +func (e *tunEndpoint) DecRef() { + e.DecRefWithDestructor(func() { + e.stack.RemoveNIC(e.nicID) + }) +} diff --git a/pkg/tcpip/link/tun/protocol.go b/pkg/tcpip/link/tun/protocol.go new file mode 100644 index 000000000..89d9d91a9 --- /dev/null +++ b/pkg/tcpip/link/tun/protocol.go @@ -0,0 +1,56 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tun + +import ( + "encoding/binary" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +const ( + // PacketInfoHeaderSize is the size of the packet information header. + PacketInfoHeaderSize = 4 + + offsetFlags = 0 + offsetProtocol = 2 +) + +// PacketInfoFields contains fields sent through the wire if IFF_NO_PI flag is +// not set. +type PacketInfoFields struct { + Flags uint16 + Protocol tcpip.NetworkProtocolNumber +} + +// PacketInfoHeader is the wire representation of the packet information sent if +// IFF_NO_PI flag is not set. +type PacketInfoHeader []byte + +// Encode encodes f into h. +func (h PacketInfoHeader) Encode(f *PacketInfoFields) { + binary.BigEndian.PutUint16(h[offsetFlags:][:2], f.Flags) + binary.BigEndian.PutUint16(h[offsetProtocol:][:2], uint16(f.Protocol)) +} + +// Flags returns the flag field in h. +func (h PacketInfoHeader) Flags() uint16 { + return binary.BigEndian.Uint16(h[offsetFlags:]) +} + +// Protocol returns the protocol field in h. +func (h PacketInfoHeader) Protocol() tcpip.NetworkProtocolNumber { + return tcpip.NetworkProtocolNumber(binary.BigEndian.Uint16(h[offsetProtocol:])) +} diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 862954ab2..46d3a6646 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -298,6 +298,33 @@ func (n *NIC) enable() *tcpip.Error { return nil } +// remove detaches NIC from the link endpoint, and marks existing referenced +// network endpoints expired. This guarantees no packets between this NIC and +// the network stack. +func (n *NIC) remove() *tcpip.Error { + n.mu.Lock() + defer n.mu.Unlock() + + // Detach from link endpoint, so no packet comes in. + n.linkEP.Attach(nil) + + // Remove permanent and permanentTentative addresses, so no packet goes out. + var errs []*tcpip.Error + for nid, ref := range n.mu.endpoints { + switch ref.getKind() { + case permanentTentative, permanent: + if err := n.removePermanentAddressLocked(nid.LocalAddress); err != nil { + errs = append(errs, err) + } + } + } + if len(errs) > 0 { + return errs[0] + } + + return nil +} + // becomeIPv6Router transitions n into an IPv6 router. // // When transitioning into an IPv6 router, host-only state (NDP discovered @@ -1302,6 +1329,11 @@ func (n *NIC) Stack() *Stack { return n.stack } +// LinkEndpoint returns the link endpoint of n. +func (n *NIC) LinkEndpoint() LinkEndpoint { + return n.linkEP +} + // isAddrTentative returns true if addr is tentative on n. // // Note that if addr is not associated with n, then this function will return diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index f0ed76fbe..900dd46c5 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -916,6 +916,18 @@ func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error { return s.CreateNICWithOptions(id, ep, NICOptions{}) } +// GetNICByName gets the NIC specified by name. +func (s *Stack) GetNICByName(name string) (*NIC, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + for _, nic := range s.nics { + if nic.Name() == name { + return nic, true + } + } + return nil, false +} + // EnableNIC enables the given NIC so that the link-layer endpoint can start // delivering packets to it. func (s *Stack) EnableNIC(id tcpip.NICID) *tcpip.Error { @@ -956,6 +968,33 @@ func (s *Stack) CheckNIC(id tcpip.NICID) bool { return nic.enabled() } +// RemoveNIC removes NIC and all related routes from the network stack. +func (s *Stack) RemoveNIC(id tcpip.NICID) *tcpip.Error { + s.mu.Lock() + defer s.mu.Unlock() + + nic, ok := s.nics[id] + if !ok { + return tcpip.ErrUnknownNICID + } + delete(s.nics, id) + + // Remove routes in-place. n tracks the number of routes written. + n := 0 + for i, r := range s.routeTable { + if r.NIC != id { + // Keep this route. + if i > n { + s.routeTable[n] = r + } + n++ + } + } + s.routeTable = s.routeTable[:n] + + return nic.remove() +} + // NICAddressRanges returns a map of NICIDs to their associated subnets. func (s *Stack) NICAddressRanges() map[tcpip.NICID][]tcpip.Subnet { s.mu.RLock() diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index d1977d4de..3518e862d 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -678,6 +678,8 @@ syscall_test( test = "//test/syscalls/linux:truncate_test", ) +syscall_test(test = "//test/syscalls/linux:tuntap_test") + syscall_test(test = "//test/syscalls/linux:udp_bind_test") syscall_test( diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index aa303af84..704bae17b 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -131,6 +131,17 @@ cc_library( ], ) +cc_library( + name = "socket_netlink_route_util", + testonly = 1, + srcs = ["socket_netlink_route_util.cc"], + hdrs = ["socket_netlink_route_util.h"], + deps = [ + ":socket_netlink_util", + "@com_google_absl//absl/types:optional", + ], +) + cc_library( name = "socket_test_util", testonly = 1, @@ -3430,6 +3441,25 @@ cc_binary( ], ) +cc_binary( + name = "tuntap_test", + testonly = 1, + srcs = ["tuntap.cc"], + linkstatic = 1, + deps = [ + ":socket_test_util", + gtest, + "//test/syscalls/linux:socket_netlink_route_util", + "//test/util:capability_util", + "//test/util:file_descriptor", + "//test/util:fs_util", + "//test/util:posix_error", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "udp_socket_test_cases", testonly = 1, diff --git a/test/syscalls/linux/dev.cc b/test/syscalls/linux/dev.cc index 4dd302eed..4e473268c 100644 --- a/test/syscalls/linux/dev.cc +++ b/test/syscalls/linux/dev.cc @@ -153,6 +153,13 @@ TEST(DevTest, TTYExists) { EXPECT_EQ(statbuf.st_mode, S_IFCHR | 0666); } +TEST(DevTest, NetTunExists) { + struct stat statbuf = {}; + ASSERT_THAT(stat("/dev/net/tun", &statbuf), SyscallSucceeds()); + // Check that it's a character device with rw-rw-rw- permissions. + EXPECT_EQ(statbuf.st_mode, S_IFCHR | 0666); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/socket_netlink_route_util.cc b/test/syscalls/linux/socket_netlink_route_util.cc new file mode 100644 index 000000000..53eb3b6b2 --- /dev/null +++ b/test/syscalls/linux/socket_netlink_route_util.cc @@ -0,0 +1,163 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/syscalls/linux/socket_netlink_route_util.h" + +#include +#include +#include + +#include "absl/types/optional.h" +#include "test/syscalls/linux/socket_netlink_util.h" + +namespace gvisor { +namespace testing { +namespace { + +constexpr uint32_t kSeq = 12345; + +} // namespace + +PosixError DumpLinks( + const FileDescriptor& fd, uint32_t seq, + const std::function& fn) { + struct request { + struct nlmsghdr hdr; + struct ifinfomsg ifm; + }; + + struct request req = {}; + req.hdr.nlmsg_len = sizeof(req); + req.hdr.nlmsg_type = RTM_GETLINK; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; + req.hdr.nlmsg_seq = seq; + req.ifm.ifi_family = AF_UNSPEC; + + return NetlinkRequestResponse(fd, &req, sizeof(req), fn, false); +} + +PosixErrorOr> DumpLinks() { + ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE)); + + std::vector links; + RETURN_IF_ERRNO(DumpLinks(fd, kSeq, [&](const struct nlmsghdr* hdr) { + if (hdr->nlmsg_type != RTM_NEWLINK || + hdr->nlmsg_len < NLMSG_SPACE(sizeof(struct ifinfomsg))) { + return; + } + const struct ifinfomsg* msg = + reinterpret_cast(NLMSG_DATA(hdr)); + const auto* rta = FindRtAttr(hdr, msg, IFLA_IFNAME); + if (rta == nullptr) { + // Ignore links that do not have a name. + return; + } + + links.emplace_back(); + links.back().index = msg->ifi_index; + links.back().type = msg->ifi_type; + links.back().name = + std::string(reinterpret_cast(RTA_DATA(rta))); + })); + return links; +} + +PosixErrorOr> FindLoopbackLink() { + ASSIGN_OR_RETURN_ERRNO(auto links, DumpLinks()); + for (const auto& link : links) { + if (link.type == ARPHRD_LOOPBACK) { + return absl::optional(link); + } + } + return absl::optional(); +} + +PosixError LinkAddLocalAddr(int index, int family, int prefixlen, + const void* addr, int addrlen) { + ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE)); + + struct request { + struct nlmsghdr hdr; + struct ifaddrmsg ifaddr; + char attrbuf[512]; + }; + + struct request req = {}; + req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifaddr)); + req.hdr.nlmsg_type = RTM_NEWADDR; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; + req.hdr.nlmsg_seq = kSeq; + req.ifaddr.ifa_index = index; + req.ifaddr.ifa_family = family; + req.ifaddr.ifa_prefixlen = prefixlen; + + struct rtattr* rta = reinterpret_cast( + reinterpret_cast(&req) + NLMSG_ALIGN(req.hdr.nlmsg_len)); + rta->rta_type = IFA_LOCAL; + rta->rta_len = RTA_LENGTH(addrlen); + req.hdr.nlmsg_len = NLMSG_ALIGN(req.hdr.nlmsg_len) + RTA_LENGTH(addrlen); + memcpy(RTA_DATA(rta), addr, addrlen); + + return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len); +} + +PosixError LinkChangeFlags(int index, unsigned int flags, unsigned int change) { + ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE)); + + struct request { + struct nlmsghdr hdr; + struct ifinfomsg ifinfo; + char pad[NLMSG_ALIGNTO]; + }; + + struct request req = {}; + req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifinfo)); + req.hdr.nlmsg_type = RTM_NEWLINK; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; + req.hdr.nlmsg_seq = kSeq; + req.ifinfo.ifi_index = index; + req.ifinfo.ifi_flags = flags; + req.ifinfo.ifi_change = change; + + return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len); +} + +PosixError LinkSetMacAddr(int index, const void* addr, int addrlen) { + ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE)); + + struct request { + struct nlmsghdr hdr; + struct ifinfomsg ifinfo; + char attrbuf[512]; + }; + + struct request req = {}; + req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifinfo)); + req.hdr.nlmsg_type = RTM_NEWLINK; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; + req.hdr.nlmsg_seq = kSeq; + req.ifinfo.ifi_index = index; + + struct rtattr* rta = reinterpret_cast( + reinterpret_cast(&req) + NLMSG_ALIGN(req.hdr.nlmsg_len)); + rta->rta_type = IFLA_ADDRESS; + rta->rta_len = RTA_LENGTH(addrlen); + req.hdr.nlmsg_len = NLMSG_ALIGN(req.hdr.nlmsg_len) + RTA_LENGTH(addrlen); + memcpy(RTA_DATA(rta), addr, addrlen); + + return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len); +} + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_netlink_route_util.h b/test/syscalls/linux/socket_netlink_route_util.h new file mode 100644 index 000000000..2c018e487 --- /dev/null +++ b/test/syscalls/linux/socket_netlink_route_util.h @@ -0,0 +1,55 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NETLINK_ROUTE_UTIL_H_ +#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NETLINK_ROUTE_UTIL_H_ + +#include +#include + +#include + +#include "absl/types/optional.h" +#include "test/syscalls/linux/socket_netlink_util.h" + +namespace gvisor { +namespace testing { + +struct Link { + int index; + int16_t type; + std::string name; +}; + +PosixError DumpLinks(const FileDescriptor& fd, uint32_t seq, + const std::function& fn); + +PosixErrorOr> DumpLinks(); + +PosixErrorOr> FindLoopbackLink(); + +// LinkAddLocalAddr sets IFA_LOCAL attribute on the interface. +PosixError LinkAddLocalAddr(int index, int family, int prefixlen, + const void* addr, int addrlen); + +// LinkChangeFlags changes interface flags. E.g. IFF_UP. +PosixError LinkChangeFlags(int index, unsigned int flags, unsigned int change); + +// LinkSetMacAddr sets IFLA_ADDRESS attribute of the interface. +PosixError LinkSetMacAddr(int index, const void* addr, int addrlen); + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NETLINK_ROUTE_UTIL_H_ diff --git a/test/syscalls/linux/tuntap.cc b/test/syscalls/linux/tuntap.cc new file mode 100644 index 000000000..f6ac9d7b8 --- /dev/null +++ b/test/syscalls/linux/tuntap.cc @@ -0,0 +1,346 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_split.h" +#include "test/syscalls/linux/socket_netlink_route_util.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/capability_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/fs_util.h" +#include "test/util/posix_error.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { +namespace { + +constexpr int kIPLen = 4; + +constexpr const char kDevNetTun[] = "/dev/net/tun"; +constexpr const char kTapName[] = "tap0"; + +constexpr const uint8_t kMacA[ETH_ALEN] = {0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}; +constexpr const uint8_t kMacB[ETH_ALEN] = {0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB}; + +PosixErrorOr> DumpLinkNames() { + ASSIGN_OR_RETURN_ERRNO(auto links, DumpLinks()); + std::set names; + for (const auto& link : links) { + names.emplace(link.name); + } + return names; +} + +PosixErrorOr> GetLinkByName(const std::string& name) { + ASSIGN_OR_RETURN_ERRNO(auto links, DumpLinks()); + for (const auto& link : links) { + if (link.name == name) { + return absl::optional(link); + } + } + return absl::optional(); +} + +struct pihdr { + uint16_t pi_flags; + uint16_t pi_protocol; +} __attribute__((packed)); + +struct ping_pkt { + pihdr pi; + struct ethhdr eth; + struct iphdr ip; + struct icmphdr icmp; + char payload[64]; +} __attribute__((packed)); + +ping_pkt CreatePingPacket(const uint8_t srcmac[ETH_ALEN], const char* srcip, + const uint8_t dstmac[ETH_ALEN], const char* dstip) { + ping_pkt pkt = {}; + + pkt.pi.pi_protocol = htons(ETH_P_IP); + + memcpy(pkt.eth.h_dest, dstmac, sizeof(pkt.eth.h_dest)); + memcpy(pkt.eth.h_source, srcmac, sizeof(pkt.eth.h_source)); + pkt.eth.h_proto = htons(ETH_P_IP); + + pkt.ip.ihl = 5; + pkt.ip.version = 4; + pkt.ip.tos = 0; + pkt.ip.tot_len = htons(sizeof(struct iphdr) + sizeof(struct icmphdr) + + sizeof(pkt.payload)); + pkt.ip.id = 1; + pkt.ip.frag_off = 1 << 6; // Do not fragment + pkt.ip.ttl = 64; + pkt.ip.protocol = IPPROTO_ICMP; + inet_pton(AF_INET, dstip, &pkt.ip.daddr); + inet_pton(AF_INET, srcip, &pkt.ip.saddr); + pkt.ip.check = IPChecksum(pkt.ip); + + pkt.icmp.type = ICMP_ECHO; + pkt.icmp.code = 0; + pkt.icmp.checksum = 0; + pkt.icmp.un.echo.sequence = 1; + pkt.icmp.un.echo.id = 1; + + strncpy(pkt.payload, "abcd", sizeof(pkt.payload)); + pkt.icmp.checksum = ICMPChecksum(pkt.icmp, pkt.payload, sizeof(pkt.payload)); + + return pkt; +} + +struct arp_pkt { + pihdr pi; + struct ethhdr eth; + struct arphdr arp; + uint8_t arp_sha[ETH_ALEN]; + uint8_t arp_spa[kIPLen]; + uint8_t arp_tha[ETH_ALEN]; + uint8_t arp_tpa[kIPLen]; +} __attribute__((packed)); + +std::string CreateArpPacket(const uint8_t srcmac[ETH_ALEN], const char* srcip, + const uint8_t dstmac[ETH_ALEN], const char* dstip) { + std::string buffer; + buffer.resize(sizeof(arp_pkt)); + + arp_pkt* pkt = reinterpret_cast(&buffer[0]); + { + pkt->pi.pi_protocol = htons(ETH_P_ARP); + + memcpy(pkt->eth.h_dest, kMacA, sizeof(pkt->eth.h_dest)); + memcpy(pkt->eth.h_source, kMacB, sizeof(pkt->eth.h_source)); + pkt->eth.h_proto = htons(ETH_P_ARP); + + pkt->arp.ar_hrd = htons(ARPHRD_ETHER); + pkt->arp.ar_pro = htons(ETH_P_IP); + pkt->arp.ar_hln = ETH_ALEN; + pkt->arp.ar_pln = kIPLen; + pkt->arp.ar_op = htons(ARPOP_REPLY); + + memcpy(pkt->arp_sha, srcmac, sizeof(pkt->arp_sha)); + inet_pton(AF_INET, srcip, pkt->arp_spa); + memcpy(pkt->arp_tha, dstmac, sizeof(pkt->arp_tha)); + inet_pton(AF_INET, dstip, pkt->arp_tpa); + } + return buffer; +} + +} // namespace + +class TuntapTest : public ::testing::Test { + protected: + void TearDown() override { + if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))) { + // Bring back capability if we had dropped it in test case. + ASSERT_NO_ERRNO(SetCapability(CAP_NET_ADMIN, true)); + } + } +}; + +TEST_F(TuntapTest, CreateInterfaceNoCap) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + ASSERT_NO_ERRNO(SetCapability(CAP_NET_ADMIN, false)); + + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR)); + + struct ifreq ifr = {}; + ifr.ifr_flags = IFF_TAP; + strncpy(ifr.ifr_name, kTapName, IFNAMSIZ); + + EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr), SyscallFailsWithErrno(EPERM)); +} + +TEST_F(TuntapTest, CreateFixedNameInterface) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR)); + + struct ifreq ifr_set = {}; + ifr_set.ifr_flags = IFF_TAP; + strncpy(ifr_set.ifr_name, kTapName, IFNAMSIZ); + EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr_set), + SyscallSucceedsWithValue(0)); + + struct ifreq ifr_get = {}; + EXPECT_THAT(ioctl(fd.get(), TUNGETIFF, &ifr_get), + SyscallSucceedsWithValue(0)); + + struct ifreq ifr_expect = ifr_set; + // See __tun_chr_ioctl() in net/drivers/tun.c. + ifr_expect.ifr_flags |= IFF_NOFILTER; + + EXPECT_THAT(DumpLinkNames(), + IsPosixErrorOkAndHolds(::testing::Contains(kTapName))); + EXPECT_THAT(memcmp(&ifr_expect, &ifr_get, sizeof(ifr_get)), ::testing::Eq(0)); +} + +TEST_F(TuntapTest, CreateInterface) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR)); + + struct ifreq ifr = {}; + ifr.ifr_flags = IFF_TAP; + // Empty ifr.ifr_name. Let kernel assign. + + EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr), SyscallSucceedsWithValue(0)); + + struct ifreq ifr_get = {}; + EXPECT_THAT(ioctl(fd.get(), TUNGETIFF, &ifr_get), + SyscallSucceedsWithValue(0)); + + std::string ifname = ifr_get.ifr_name; + EXPECT_THAT(ifname, ::testing::StartsWith("tap")); + EXPECT_THAT(DumpLinkNames(), + IsPosixErrorOkAndHolds(::testing::Contains(ifname))); +} + +TEST_F(TuntapTest, InvalidReadWrite) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR)); + + char buf[128] = {}; + EXPECT_THAT(read(fd.get(), buf, sizeof(buf)), SyscallFailsWithErrno(EBADFD)); + EXPECT_THAT(write(fd.get(), buf, sizeof(buf)), SyscallFailsWithErrno(EBADFD)); +} + +TEST_F(TuntapTest, WriteToDownDevice) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + // FIXME: gVisor always creates enabled/up'd interfaces. + SKIP_IF(IsRunningOnGvisor()); + + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR)); + + // Device created should be down by default. + struct ifreq ifr = {}; + ifr.ifr_flags = IFF_TAP; + EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr), SyscallSucceedsWithValue(0)); + + char buf[128] = {}; + EXPECT_THAT(write(fd.get(), buf, sizeof(buf)), SyscallFailsWithErrno(EIO)); +} + +// This test sets up a TAP device and pings kernel by sending ICMP echo request. +// +// It works as the following: +// * Open /dev/net/tun, and create kTapName interface. +// * Use rtnetlink to do initial setup of the interface: +// * Assign IP address 10.0.0.1/24 to kernel. +// * MAC address: kMacA +// * Bring up the interface. +// * Send an ICMP echo reqest (ping) packet from 10.0.0.2 (kMacB) to kernel. +// * Loop to receive packets from TAP device/fd: +// * If packet is an ICMP echo reply, it stops and passes the test. +// * If packet is an ARP request, it responds with canned reply and resends +// the +// ICMP request packet. +TEST_F(TuntapTest, PingKernel) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + // Interface creation. + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR)); + + struct ifreq ifr_set = {}; + ifr_set.ifr_flags = IFF_TAP; + strncpy(ifr_set.ifr_name, kTapName, IFNAMSIZ); + EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr_set), + SyscallSucceedsWithValue(0)); + + absl::optional link = + ASSERT_NO_ERRNO_AND_VALUE(GetLinkByName(kTapName)); + ASSERT_TRUE(link.has_value()); + + // Interface setup. + struct in_addr addr; + inet_pton(AF_INET, "10.0.0.1", &addr); + EXPECT_NO_ERRNO(LinkAddLocalAddr(link->index, AF_INET, /*prefixlen=*/24, + &addr, sizeof(addr))); + + if (!IsRunningOnGvisor()) { + // FIXME: gVisor doesn't support setting MAC address on interfaces yet. + EXPECT_NO_ERRNO(LinkSetMacAddr(link->index, kMacA, sizeof(kMacA))); + + // FIXME: gVisor always creates enabled/up'd interfaces. + EXPECT_NO_ERRNO(LinkChangeFlags(link->index, IFF_UP, IFF_UP)); + } + + ping_pkt ping_req = CreatePingPacket(kMacB, "10.0.0.2", kMacA, "10.0.0.1"); + std::string arp_rep = CreateArpPacket(kMacB, "10.0.0.2", kMacA, "10.0.0.1"); + + // Send ping, this would trigger an ARP request on Linux. + EXPECT_THAT(write(fd.get(), &ping_req, sizeof(ping_req)), + SyscallSucceedsWithValue(sizeof(ping_req))); + + // Receive loop to process inbound packets. + struct inpkt { + union { + pihdr pi; + ping_pkt ping; + arp_pkt arp; + }; + }; + while (1) { + inpkt r = {}; + int n = read(fd.get(), &r, sizeof(r)); + EXPECT_THAT(n, SyscallSucceeds()); + + if (n < sizeof(pihdr)) { + std::cerr << "Ignored packet, protocol: " << r.pi.pi_protocol + << " len: " << n << std::endl; + continue; + } + + // Process ARP packet. + if (n >= sizeof(arp_pkt) && r.pi.pi_protocol == htons(ETH_P_ARP)) { + // Respond with canned ARP reply. + EXPECT_THAT(write(fd.get(), arp_rep.data(), arp_rep.size()), + SyscallSucceedsWithValue(arp_rep.size())); + // First ping request might have been dropped due to mac address not in + // ARP cache. Send it again. + EXPECT_THAT(write(fd.get(), &ping_req, sizeof(ping_req)), + SyscallSucceedsWithValue(sizeof(ping_req))); + } + + // Process ping response packet. + if (n >= sizeof(ping_pkt) && r.pi.pi_protocol == ping_req.pi.pi_protocol && + r.ping.ip.protocol == ping_req.ip.protocol && + !memcmp(&r.ping.ip.saddr, &ping_req.ip.daddr, kIPLen) && + !memcmp(&r.ping.ip.daddr, &ping_req.ip.saddr, kIPLen) && + r.ping.icmp.type == 0 && r.ping.icmp.code == 0) { + // Ends and passes the test. + break; + } + } +} + +} // namespace testing +} // namespace gvisor -- cgit v1.2.3 From c37b196455e8b3816298e3eea98e4ee2dab8d368 Mon Sep 17 00:00:00 2001 From: Ian Gudger Date: Mon, 24 Feb 2020 10:31:01 -0800 Subject: Add support for tearing down protocol dispatchers and TIME_WAIT endpoints. Protocol dispatchers were previously leaked. Bypassing TIME_WAIT is required to test this change. Also fix a race when a socket in SYN-RCVD is closed. This is also required to test this change. PiperOrigin-RevId: 296922548 --- pkg/tcpip/adapters/gonet/gonet_test.go | 63 ++++++++++++++++++++++++++-------- pkg/tcpip/network/arp/arp.go | 20 +++++++---- pkg/tcpip/network/ipv4/ipv4.go | 6 ++++ pkg/tcpip/network/ipv6/ipv6.go | 6 ++++ pkg/tcpip/stack/registration.go | 23 ++++++++++--- pkg/tcpip/stack/stack.go | 14 +++++++- pkg/tcpip/stack/stack_test.go | 6 ++++ pkg/tcpip/stack/transport_demuxer.go | 20 ----------- pkg/tcpip/stack/transport_test.go | 15 +++++++- pkg/tcpip/tcpip.go | 8 ++++- pkg/tcpip/transport/icmp/endpoint.go | 5 +++ pkg/tcpip/transport/icmp/protocol.go | 16 ++++++--- pkg/tcpip/transport/packet/endpoint.go | 5 +++ pkg/tcpip/transport/raw/endpoint.go | 5 +++ pkg/tcpip/transport/tcp/accept.go | 9 ++++- pkg/tcpip/transport/tcp/connect.go | 4 +-- pkg/tcpip/transport/tcp/dispatcher.go | 31 ++++++++++++++++- pkg/tcpip/transport/tcp/endpoint.go | 33 ++++++++++++++++-- pkg/tcpip/transport/tcp/protocol.go | 14 ++++++-- pkg/tcpip/transport/udp/endpoint.go | 5 +++ pkg/tcpip/transport/udp/protocol.go | 14 +++++--- 21 files changed, 256 insertions(+), 66 deletions(-) (limited to 'pkg/tcpip') diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index ea0a0409a..3c552988a 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -127,6 +127,10 @@ func TestCloseReader(t *testing.T) { if err != nil { t.Fatalf("newLoopbackStack() = %v", err) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} @@ -175,6 +179,10 @@ func TestCloseReaderWithForwarder(t *testing.T) { if err != nil { t.Fatalf("newLoopbackStack() = %v", err) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) @@ -225,30 +233,21 @@ func TestCloseRead(t *testing.T) { if terr != nil { t.Fatalf("newLoopbackStack() = %v", terr) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) { var wq waiter.Queue - ep, err := r.CreateEndpoint(&wq) + _, err := r.CreateEndpoint(&wq) if err != nil { t.Fatalf("r.CreateEndpoint() = %v", err) } - defer ep.Close() - r.Complete(false) - - c := NewTCPConn(&wq, ep) - - buf := make([]byte, 256) - n, e := c.Read(buf) - if e != nil || string(buf[:n]) != "abc123" { - t.Fatalf("c.Read() = (%d, %v), want (6, nil)", n, e) - } - - if n, e = c.Write([]byte("abc123")); e != nil { - t.Errorf("c.Write() = (%d, %v), want (6, nil)", n, e) - } + // Endpoint will be closed in deferred s.Close (above). }) s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket) @@ -278,6 +277,10 @@ func TestCloseWrite(t *testing.T) { if terr != nil { t.Fatalf("newLoopbackStack() = %v", terr) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) @@ -334,6 +337,10 @@ func TestUDPForwarder(t *testing.T) { if terr != nil { t.Fatalf("newLoopbackStack() = %v", terr) } + defer func() { + s.Close() + s.Wait() + }() ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr1 := tcpip.FullAddress{NICID, ip1, 11211} @@ -391,6 +398,10 @@ func TestDeadlineChange(t *testing.T) { if err != nil { t.Fatalf("newLoopbackStack() = %v", err) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} @@ -440,6 +451,10 @@ func TestPacketConnTransfer(t *testing.T) { if e != nil { t.Fatalf("newLoopbackStack() = %v", e) } + defer func() { + s.Close() + s.Wait() + }() ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr1 := tcpip.FullAddress{NICID, ip1, 11211} @@ -492,6 +507,10 @@ func TestConnectedPacketConnTransfer(t *testing.T) { if e != nil { t.Fatalf("newLoopbackStack() = %v", e) } + defer func() { + s.Close() + s.Wait() + }() ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr := tcpip.FullAddress{NICID, ip, 11211} @@ -562,6 +581,8 @@ func makePipe() (c1, c2 net.Conn, stop func(), err error) { stop = func() { c1.Close() c2.Close() + s.Close() + s.Wait() } if err := l.Close(); err != nil { @@ -624,6 +645,10 @@ func TestTCPDialError(t *testing.T) { if e != nil { t.Fatalf("newLoopbackStack() = %v", e) } + defer func() { + s.Close() + s.Wait() + }() ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4()) addr := tcpip.FullAddress{NICID, ip, 11211} @@ -641,6 +666,10 @@ func TestDialContextTCPCanceled(t *testing.T) { if err != nil { t.Fatalf("newLoopbackStack() = %v", err) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) @@ -659,6 +688,10 @@ func TestDialContextTCPTimeout(t *testing.T) { if err != nil { t.Fatalf("newLoopbackStack() = %v", err) } + defer func() { + s.Close() + s.Wait() + }() addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211} s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr) diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 4da13c5df..e9fcc89a8 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -148,12 +148,12 @@ func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWi }, nil } -// LinkAddressProtocol implements stack.LinkAddressResolver. +// LinkAddressProtocol implements stack.LinkAddressResolver.LinkAddressProtocol. func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return header.IPv4ProtocolNumber } -// LinkAddressRequest implements stack.LinkAddressResolver. +// LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest. func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error { r := &stack.Route{ RemoteLinkAddress: broadcastMAC, @@ -172,7 +172,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. }) } -// ResolveStaticAddress implements stack.LinkAddressResolver. +// ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress. func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { if addr == header.IPv4Broadcast { return broadcastMAC, true @@ -183,16 +183,22 @@ func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo return tcpip.LinkAddress([]byte(nil)), false } -// SetOption implements NetworkProtocol. -func (p *protocol) SetOption(option interface{}) *tcpip.Error { +// SetOption implements stack.NetworkProtocol.SetOption. +func (*protocol) SetOption(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } -// Option implements NetworkProtocol. -func (p *protocol) Option(option interface{}) *tcpip.Error { +// Option implements stack.NetworkProtocol.Option. +func (*protocol) Option(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } +// Close implements stack.TransportProtocol.Close. +func (*protocol) Close() {} + +// Wait implements stack.TransportProtocol.Wait. +func (*protocol) Wait() {} + var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff}) // NewProtocol returns an ARP network protocol. diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 6597e6781..4f1742938 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -473,6 +473,12 @@ func (p *protocol) DefaultTTL() uint8 { return uint8(atomic.LoadUint32(&p.defaultTTL)) } +// Close implements stack.TransportProtocol.Close. +func (*protocol) Close() {} + +// Wait implements stack.TransportProtocol.Wait. +func (*protocol) Wait() {} + // calculateMTU calculates the network-layer payload MTU based on the link-layer // payload mtu. func calculateMTU(mtu uint32) uint32 { diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 180a480fd..9aef5234b 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -265,6 +265,12 @@ func (p *protocol) DefaultTTL() uint8 { return uint8(atomic.LoadUint32(&p.defaultTTL)) } +// Close implements stack.TransportProtocol.Close. +func (*protocol) Close() {} + +// Wait implements stack.TransportProtocol.Wait. +func (*protocol) Wait() {} + // calculateMTU calculates the network-layer payload MTU based on the link-layer // payload mtu. func calculateMTU(mtu uint32) uint32 { diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index d83adf0ec..f9fd8f18f 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -74,10 +74,11 @@ type TransportEndpoint interface { // HandleControlPacket takes ownership of pkt. HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt tcpip.PacketBuffer) - // Close puts the endpoint in a closed state and frees all resources - // associated with it. This cleanup may happen asynchronously. Wait can - // be used to block on this asynchronous cleanup. - Close() + // Abort initiates an expedited endpoint teardown. It puts the endpoint + // in a closed state and frees all resources associated with it. This + // cleanup may happen asynchronously. Wait can be used to block on this + // asynchronous cleanup. + Abort() // Wait waits for any worker goroutines owned by the endpoint to stop. // @@ -160,6 +161,13 @@ type TransportProtocol interface { // Option returns an error if the option is not supported or the // provided option value is invalid. Option(option interface{}) *tcpip.Error + + // Close requests that any worker goroutines owned by the protocol + // stop. + Close() + + // Wait waits for any worker goroutines owned by the protocol to stop. + Wait() } // TransportDispatcher contains the methods used by the network stack to deliver @@ -293,6 +301,13 @@ type NetworkProtocol interface { // Option returns an error if the option is not supported or the // provided option value is invalid. Option(option interface{}) *tcpip.Error + + // Close requests that any worker goroutines owned by the protocol + // stop. + Close() + + // Wait waits for any worker goroutines owned by the protocol to stop. + Wait() } // NetworkDispatcher contains the methods used by the network stack to deliver diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 900dd46c5..ebb6c5e3b 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1446,7 +1446,13 @@ func (s *Stack) RestoreCleanupEndpoints(es []TransportEndpoint) { // Endpoints created or modified during this call may not get closed. func (s *Stack) Close() { for _, e := range s.RegisteredEndpoints() { - e.Close() + e.Abort() + } + for _, p := range s.transportProtocols { + p.proto.Close() + } + for _, p := range s.networkProtocols { + p.Close() } } @@ -1464,6 +1470,12 @@ func (s *Stack) Wait() { for _, e := range s.CleanupEndpoints() { e.Wait() } + for _, p := range s.transportProtocols { + p.proto.Wait() + } + for _, p := range s.networkProtocols { + p.Wait() + } s.mu.RLock() defer s.mu.RUnlock() diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 18016e7db..edf6bec52 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -235,6 +235,12 @@ func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error { } } +// Close implements TransportProtocol.Close. +func (*fakeNetworkProtocol) Close() {} + +// Wait implements TransportProtocol.Wait. +func (*fakeNetworkProtocol) Wait() {} + func fakeNetFactory() stack.NetworkProtocol { return &fakeNetworkProtocol{} } diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index d686e6eb8..778c0a4d6 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -306,26 +306,6 @@ func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, p ep.mu.RUnlock() // Don't use defer for performance reasons. } -// Close implements stack.TransportEndpoint.Close. -func (ep *multiPortEndpoint) Close() { - ep.mu.RLock() - eps := append([]TransportEndpoint(nil), ep.endpointsArr...) - ep.mu.RUnlock() - for _, e := range eps { - e.Close() - } -} - -// Wait implements stack.TransportEndpoint.Wait. -func (ep *multiPortEndpoint) Wait() { - ep.mu.RLock() - eps := append([]TransportEndpoint(nil), ep.endpointsArr...) - ep.mu.RUnlock() - for _, e := range eps { - e.Wait() - } -} - // singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint // list. The list might be empty already. func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error { diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 869c69a6d..5d1da2f8b 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -61,6 +61,10 @@ func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netP return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} } +func (f *fakeTransportEndpoint) Abort() { + f.Close() +} + func (f *fakeTransportEndpoint) Close() { f.route.Release() } @@ -272,7 +276,7 @@ func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.N return newFakeTransportEndpoint(stack, f, netProto, stack.UniqueID()), nil } -func (f *fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (*fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { return nil, tcpip.ErrUnknownProtocol } @@ -310,6 +314,15 @@ func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error { } } +// Abort implements TransportProtocol.Abort. +func (*fakeTransportProtocol) Abort() {} + +// Close implements tcpip.Endpoint.Close. +func (*fakeTransportProtocol) Close() {} + +// Wait implements TransportProtocol.Wait. +func (*fakeTransportProtocol) Wait() {} + func fakeTransFactory() stack.TransportProtocol { return &fakeTransportProtocol{} } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index ce5527391..3dc5d87d6 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -341,9 +341,15 @@ type ControlMessages struct { // networking stack. type Endpoint interface { // Close puts the endpoint in a closed state and frees all resources - // associated with it. + // associated with it. Close initiates the teardown process, the + // Endpoint may not be fully closed when Close returns. Close() + // Abort initiates an expedited endpoint teardown. As compared to + // Close, Abort prioritizes closing the Endpoint quickly over cleanly. + // Abort is best effort; implementing Abort with Close is acceptable. + Abort() + // Read reads data from the endpoint and optionally returns the sender. // // This method does not block if there is no data pending. It will also diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 42afb3f5b..426da1ee6 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -96,6 +96,11 @@ func (e *endpoint) UniqueID() uint64 { return e.uniqueID } +// Abort implements stack.TransportEndpoint.Abort. +func (e *endpoint) Abort() { + e.Close() +} + // Close puts the endpoint in a closed state and frees all resources // associated with it. func (e *endpoint) Close() { diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 9ce500e80..113d92901 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -104,20 +104,26 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool { +func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool { return true } -// SetOption implements TransportProtocol.SetOption. -func (p *protocol) SetOption(option interface{}) *tcpip.Error { +// SetOption implements stack.TransportProtocol.SetOption. +func (*protocol) SetOption(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } -// Option implements TransportProtocol.Option. -func (p *protocol) Option(option interface{}) *tcpip.Error { +// Option implements stack.TransportProtocol.Option. +func (*protocol) Option(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } +// Close implements stack.TransportProtocol.Close. +func (*protocol) Close() {} + +// Wait implements stack.TransportProtocol.Wait. +func (*protocol) Wait() {} + // NewProtocol4 returns an ICMPv4 transport protocol. func NewProtocol4() stack.TransportProtocol { return &protocol{ProtocolNumber4} diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index fc5bc69fa..5722815e9 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -98,6 +98,11 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb return ep, nil } +// Abort implements stack.TransportEndpoint.Abort. +func (e *endpoint) Abort() { + e.Close() +} + // Close implements tcpip.Endpoint.Close. func (ep *endpoint) Close() { ep.mu.Lock() diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index ee9c4c58b..2ef5fac76 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -121,6 +121,11 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt return e, nil } +// Abort implements stack.TransportEndpoint.Abort. +func (e *endpoint) Abort() { + e.Close() +} + // Close implements tcpip.Endpoint.Close. func (e *endpoint) Close() { e.mu.Lock() diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 08afb7c17..13e383ffc 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -299,6 +299,13 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept) if err := h.execute(); err != nil { ep.Close() + // Wake up any waiters. This is strictly not required normally + // as a socket that was never accepted can't really have any + // registered waiters except when stack.Wait() is called which + // waits for all registered endpoints to stop and expects an + // EventHUp. + ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) + if l.listenEP != nil { l.removePendingEndpoint(ep) } @@ -607,7 +614,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { e.mu.Unlock() // Notify waiters that the endpoint is shutdown. - e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut) + e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr) }() s := sleep.Sleeper{} diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 5c5397823..7730e6445 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -1372,7 +1372,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ e.snd.updateMaxPayloadSize(mtu, count) } - if n¬ifyReset != 0 { + if n¬ifyReset != 0 || n¬ifyAbort != 0 { return tcpip.ErrConnectionAborted } @@ -1655,7 +1655,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) { } case notification: n := e.fetchNotifications() - if n¬ifyClose != 0 { + if n¬ifyClose != 0 || n¬ifyAbort != 0 { return nil } if n¬ifyDrain != 0 { diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go index e18012ac0..d792b07d6 100644 --- a/pkg/tcpip/transport/tcp/dispatcher.go +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -68,17 +68,28 @@ func (q *epQueue) empty() bool { type processor struct { epQ epQueue newEndpointWaker sleep.Waker + closeWaker sleep.Waker id int + wg sync.WaitGroup } func newProcessor(id int) *processor { p := &processor{ id: id, } + p.wg.Add(1) go p.handleSegments() return p } +func (p *processor) close() { + p.closeWaker.Assert() +} + +func (p *processor) wait() { + p.wg.Wait() +} + func (p *processor) queueEndpoint(ep *endpoint) { // Queue an endpoint for processing by the processor goroutine. p.epQ.enqueue(ep) @@ -87,11 +98,17 @@ func (p *processor) queueEndpoint(ep *endpoint) { func (p *processor) handleSegments() { const newEndpointWaker = 1 + const closeWaker = 2 s := sleep.Sleeper{} s.AddWaker(&p.newEndpointWaker, newEndpointWaker) + s.AddWaker(&p.closeWaker, closeWaker) defer s.Done() for { - s.Fetch(true) + id, ok := s.Fetch(true) + if ok && id == closeWaker { + p.wg.Done() + return + } for ep := p.epQ.dequeue(); ep != nil; ep = p.epQ.dequeue() { if ep.segmentQueue.empty() { continue @@ -160,6 +177,18 @@ func newDispatcher(nProcessors int) *dispatcher { } } +func (d *dispatcher) close() { + for _, p := range d.processors { + p.close() + } +} + +func (d *dispatcher) wait() { + for _, p := range d.processors { + p.wait() + } +} + func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) { ep := stackEP.(*endpoint) s := newSegment(r, id, pkt) diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index f2be0e651..f1ad19dac 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -121,6 +121,8 @@ const ( notifyDrain notifyReset notifyResetByPeer + // notifyAbort is a request for an expedited teardown. + notifyAbort notifyKeepaliveChanged notifyMSSChanged // notifyTickleWorker is used to tickle the protocol main loop during a @@ -785,6 +787,24 @@ func (e *endpoint) notifyProtocolGoroutine(n uint32) { } } +// Abort implements stack.TransportEndpoint.Abort. +func (e *endpoint) Abort() { + // The abort notification is not processed synchronously, so no + // synchronization is needed. + // + // If the endpoint becomes connected after this check, we still close + // the endpoint. This worst case results in a slower abort. + // + // If the endpoint disconnected after the check, nothing needs to be + // done, so sending a notification which will potentially be ignored is + // fine. + if e.EndpointState().connected() { + e.notifyProtocolGoroutine(notifyAbort) + return + } + e.Close() +} + // Close puts the endpoint in a closed state and frees all resources associated // with it. It must be called only once and with no other concurrent calls to // the endpoint. @@ -829,9 +849,18 @@ func (e *endpoint) closeNoShutdown() { // Either perform the local cleanup or kick the worker to make sure it // knows it needs to cleanup. tcpip.AddDanglingEndpoint(e) - if !e.workerRunning { + switch e.EndpointState() { + // Sockets in StateSynRecv state(passive connections) are closed when + // the handshake fails or if the listening socket is closed while + // handshake was in progress. In such cases the handshake goroutine + // is already gone by the time Close is called and we need to cleanup + // here. + case StateInitial, StateBound, StateSynRecv: e.cleanupLocked() - } else { + e.setEndpointState(StateClose) + case StateError, StateClose: + // do nothing. + default: e.workerCleanup = true e.notifyProtocolGoroutine(notifyClose) } diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 958c06fa7..73098d904 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -194,7 +194,7 @@ func replyWithReset(s *segment) { sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, flags, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */) } -// SetOption implements TransportProtocol.SetOption. +// SetOption implements stack.TransportProtocol.SetOption. func (p *protocol) SetOption(option interface{}) *tcpip.Error { switch v := option.(type) { case SACKEnabled: @@ -269,7 +269,7 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { } } -// Option implements TransportProtocol.Option. +// Option implements stack.TransportProtocol.Option. func (p *protocol) Option(option interface{}) *tcpip.Error { switch v := option.(type) { case *SACKEnabled: @@ -331,6 +331,16 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { } } +// Close implements stack.TransportProtocol.Close. +func (p *protocol) Close() { + p.dispatcher.close() +} + +// Wait implements stack.TransportProtocol.Wait. +func (p *protocol) Wait() { + p.dispatcher.wait() +} + // NewProtocol returns a TCP transport protocol. func NewProtocol() stack.TransportProtocol { return &protocol{ diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index eff7f3600..1c6a600b8 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -186,6 +186,11 @@ func (e *endpoint) UniqueID() uint64 { return e.uniqueID } +// Abort implements stack.TransportEndpoint.Abort. +func (e *endpoint) Abort() { + e.Close() +} + // Close puts the endpoint in a closed state and frees all resources // associated with it. func (e *endpoint) Close() { diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 259c3072a..8df089d22 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -180,16 +180,22 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans return true } -// SetOption implements TransportProtocol.SetOption. -func (p *protocol) SetOption(option interface{}) *tcpip.Error { +// SetOption implements stack.TransportProtocol.SetOption. +func (*protocol) SetOption(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } -// Option implements TransportProtocol.Option. -func (p *protocol) Option(option interface{}) *tcpip.Error { +// Option implements stack.TransportProtocol.Option. +func (*protocol) Option(option interface{}) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } +// Close implements stack.TransportProtocol.Close. +func (*protocol) Close() {} + +// Wait implements stack.TransportProtocol.Wait. +func (*protocol) Wait() {} + // NewProtocol returns a UDP transport protocol. func NewProtocol() stack.TransportProtocol { return &protocol{} -- cgit v1.2.3 From d7b73792515d1ac34c8d8c41ef5de379f22f002b Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Tue, 25 Feb 2020 11:18:28 -0800 Subject: Deflake TestCurrentConnectedIncrement. TestCurrentConnectedIncrement fails consistently under gotsan due to the sleep to check metrics is exactly the same as the TIME-WAIT duration. Under gotsan things can be slow enough that the increment test is done before the protocol goroutine is run after the TIME-WAIT timer expires and does its cleanup. Increasing the sleep from 1s to 1.2s makes the test pass consistently. PiperOrigin-RevId: 297160181 --- pkg/tcpip/transport/tcp/tcp_test.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'pkg/tcpip') diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index cc118c993..5b2b16afa 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -543,8 +543,9 @@ func TestCurrentConnectedIncrement(t *testing.T) { ), ) - // Wait for the TIME-WAIT state to transition to CLOSED. - time.Sleep(1 * time.Second) + // Wait for a little more than the TIME-WAIT duration for the socket to + // transition to CLOSED state. + time.Sleep(1200 * time.Millisecond) if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got) -- cgit v1.2.3 From 5f1f9dd9d23d2b805c77b5c38d5900d13e6a29fe Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Tue, 25 Feb 2020 15:15:28 -0800 Subject: Use link-local source address for link-local multicast Tests: - header_test.TestIsV6LinkLocalMulticastAddress - header_test.TestScopeForIPv6Address - stack_test.TestIPv6SourceAddressSelectionScopeAndSameAddress PiperOrigin-RevId: 297215576 --- pkg/tcpip/header/ipv6.go | 22 ++++++++++ pkg/tcpip/header/ipv6_test.go | 98 ++++++++++++++++++++++++++++++++++++++++--- pkg/tcpip/stack/stack_test.go | 27 ++++++++---- 3 files changed, 134 insertions(+), 13 deletions(-) (limited to 'pkg/tcpip') diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index 70e6ce095..76e88e9b3 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -115,6 +115,19 @@ const ( // for the secret key used to generate an opaque interface identifier as // outlined by RFC 7217. OpaqueIIDSecretKeyMinBytes = 16 + + // ipv6MulticastAddressScopeByteIdx is the byte where the scope (scop) field + // is located within a multicast IPv6 address, as per RFC 4291 section 2.7. + ipv6MulticastAddressScopeByteIdx = 1 + + // ipv6MulticastAddressScopeMask is the mask for the scope (scop) field, + // within the byte holding the field, as per RFC 4291 section 2.7. + ipv6MulticastAddressScopeMask = 0xF + + // ipv6LinkLocalMulticastScope is the value of the scope (scop) field within + // a multicast IPv6 address that indicates the address has link-local scope, + // as per RFC 4291 section 2.7. + ipv6LinkLocalMulticastScope = 2 ) // IPv6EmptySubnet is the empty IPv6 subnet. It may also be known as the @@ -340,6 +353,12 @@ func IsV6LinkLocalAddress(addr tcpip.Address) bool { return addr[0] == 0xfe && (addr[1]&0xc0) == 0x80 } +// IsV6LinkLocalMulticastAddress determines if the provided address is an IPv6 +// link-local multicast address. +func IsV6LinkLocalMulticastAddress(addr tcpip.Address) bool { + return IsV6MulticastAddress(addr) && addr[ipv6MulticastAddressScopeByteIdx]&ipv6MulticastAddressScopeMask == ipv6LinkLocalMulticastScope +} + // IsV6UniqueLocalAddress determines if the provided address is an IPv6 // unique-local address (within the prefix FC00::/7). func IsV6UniqueLocalAddress(addr tcpip.Address) bool { @@ -411,6 +430,9 @@ func ScopeForIPv6Address(addr tcpip.Address) (IPv6AddressScope, *tcpip.Error) { } switch { + case IsV6LinkLocalMulticastAddress(addr): + return LinkLocalScope, nil + case IsV6LinkLocalAddress(addr): return LinkLocalScope, nil diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go index c3ad503aa..426a873b1 100644 --- a/pkg/tcpip/header/ipv6_test.go +++ b/pkg/tcpip/header/ipv6_test.go @@ -27,11 +27,12 @@ import ( ) const ( - linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - linkLocalAddr = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - globalAddr = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") + linkLocalAddr = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + linkLocalMulticastAddr = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + globalAddr = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") ) func TestEthernetAdddressToModifiedEUI64(t *testing.T) { @@ -256,6 +257,85 @@ func TestIsV6UniqueLocalAddress(t *testing.T) { } } +func TestIsV6LinkLocalMulticastAddress(t *testing.T) { + tests := []struct { + name string + addr tcpip.Address + expected bool + }{ + { + name: "Valid Link Local Multicast", + addr: linkLocalMulticastAddr, + expected: true, + }, + { + name: "Valid Link Local Multicast with flags", + addr: "\xff\xf2\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + expected: true, + }, + { + name: "Link Local Unicast", + addr: linkLocalAddr, + expected: false, + }, + { + name: "IPv4 Multicast", + addr: "\xe0\x00\x00\x01", + expected: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if got := header.IsV6LinkLocalMulticastAddress(test.addr); got != test.expected { + t.Errorf("got header.IsV6LinkLocalMulticastAddress(%s) = %t, want = %t", test.addr, got, test.expected) + } + }) + } +} + +func TestIsV6LinkLocalAddress(t *testing.T) { + tests := []struct { + name string + addr tcpip.Address + expected bool + }{ + { + name: "Valid Link Local Unicast", + addr: linkLocalAddr, + expected: true, + }, + { + name: "Link Local Multicast", + addr: linkLocalMulticastAddr, + expected: false, + }, + { + name: "Unique Local", + addr: uniqueLocalAddr1, + expected: false, + }, + { + name: "Global", + addr: globalAddr, + expected: false, + }, + { + name: "IPv4 Link Local", + addr: "\xa9\xfe\x00\x01", + expected: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if got := header.IsV6LinkLocalAddress(test.addr); got != test.expected { + t.Errorf("got header.IsV6LinkLocalAddress(%s) = %t, want = %t", test.addr, got, test.expected) + } + }) + } +} + func TestScopeForIPv6Address(t *testing.T) { tests := []struct { name string @@ -270,11 +350,17 @@ func TestScopeForIPv6Address(t *testing.T) { err: nil, }, { - name: "Link Local", + name: "Link Local Unicast", addr: linkLocalAddr, scope: header.LinkLocalScope, err: nil, }, + { + name: "Link Local Multicast", + addr: linkLocalMulticastAddr, + scope: header.LinkLocalScope, + err: nil, + }, { name: "Global", addr: globalAddr, diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index edf6bec52..e15db40fb 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -2790,13 +2790,14 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) { func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { const ( - linkLocalAddr1 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - linkLocalAddr2 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - globalAddr1 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - globalAddr2 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - nicID = 1 + linkLocalAddr1 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + linkLocalAddr2 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + linkLocalMulticastAddr = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + globalAddr1 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + globalAddr2 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + nicID = 1 ) // Rule 3 is not tested here, and is instead tested by NDP's AutoGenAddr test. @@ -2869,6 +2870,18 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { connectAddr: linkLocalAddr2, expectedLocalAddr: linkLocalAddr1, }, + { + name: "Link Local most preferred for link local multicast (last address)", + nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1}, + connectAddr: linkLocalMulticastAddr, + expectedLocalAddr: linkLocalAddr1, + }, + { + name: "Link Local most preferred for link local multicast (first address)", + nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, + connectAddr: linkLocalMulticastAddr, + expectedLocalAddr: linkLocalAddr1, + }, { name: "Unique Local most preferred (last address)", nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, linkLocalAddr1}, -- cgit v1.2.3 From abf7ebcd38e8c2750f4542f29115140bb2b44a9b Mon Sep 17 00:00:00 2001 From: Nayana Bidari Date: Thu, 27 Feb 2020 10:59:32 -0800 Subject: Internal change. PiperOrigin-RevId: 297638665 --- pkg/sentry/socket/netstack/netstack.go | 40 +++++++++-- pkg/tcpip/transport/packet/endpoint.go | 21 +++++- test/syscalls/linux/packet_socket.cc | 124 ++++++++++++++++++++++++++++++--- 3 files changed, 167 insertions(+), 18 deletions(-) (limited to 'pkg/tcpip') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index e187276c5..48c268bfa 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -712,14 +712,40 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo // Bind implements the linux syscall bind(2) for sockets backed by // tcpip.Endpoint. func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { - addr, family, err := AddressAndFamily(sockaddr) - if err != nil { - return err - } - if err := s.checkFamily(family, true /* exact */); err != nil { - return err + family := usermem.ByteOrder.Uint16(sockaddr) + var addr tcpip.FullAddress + + // Bind for AF_PACKET requires only family, protocol and ifindex. + // In function AddressAndFamily, we check the address length which is + // not needed for AF_PACKET bind. + if family == linux.AF_PACKET { + var a linux.SockAddrLink + if len(sockaddr) < sockAddrLinkSize { + return syserr.ErrInvalidArgument + } + binary.Unmarshal(sockaddr[:sockAddrLinkSize], usermem.ByteOrder, &a) + + if a.Protocol != uint16(s.protocol) { + return syserr.ErrInvalidArgument + } + + addr = tcpip.FullAddress{ + NIC: tcpip.NICID(a.InterfaceIndex), + Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]), + } + } else { + var err *syserr.Error + addr, family, err = AddressAndFamily(sockaddr) + if err != nil { + return err + } + + if err = s.checkFamily(family, true /* exact */); err != nil { + return err + } + + addr = s.mapFamily(addr, family) } - addr = s.mapFamily(addr, family) // Issue the bind request to the endpoint. return syserr.TranslateNetstackError(s.Endpoint.Bind(addr)) diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 5722815e9..09a1cd436 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -76,6 +76,7 @@ type endpoint struct { sndBufSize int closed bool stats tcpip.TransportEndpointStats `state:"nosave"` + bound bool } // NewEndpoint returns a new packet endpoint. @@ -125,6 +126,7 @@ func (ep *endpoint) Close() { } ep.closed = true + ep.bound = false ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) } @@ -216,7 +218,24 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { // sll_family (should be AF_PACKET), sll_protocol, and sll_ifindex." // - packet(7). - return tcpip.ErrNotSupported + ep.mu.Lock() + defer ep.mu.Unlock() + + if ep.bound { + return tcpip.ErrAlreadyBound + } + + // Unregister endpoint with all the nics. + ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep) + + // Bind endpoint to receive packets from specific interface. + if err := ep.stack.RegisterPacketEndpoint(addr.NIC, ep.netProto, ep); err != nil { + return err + } + + ep.bound = true + + return nil } // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc index 92ae55eec..bc22de788 100644 --- a/test/syscalls/linux/packet_socket.cc +++ b/test/syscalls/linux/packet_socket.cc @@ -13,6 +13,7 @@ // limitations under the License. #include +#include #include #include #include @@ -163,16 +164,11 @@ int CookedPacketTest::GetLoopbackIndex() { return ifr.ifr_ifindex; } -// Receive via a packet socket. -TEST_P(CookedPacketTest, Receive) { - // Let's use a simple IP payload: a UDP datagram. - FileDescriptor udp_sock = - ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); - SendUDPMessage(udp_sock.get()); - +// Receive and verify the message via packet socket on interface. +void ReceiveMessage(int sock, int ifindex) { // Wait for the socket to become readable. struct pollfd pfd = {}; - pfd.fd = socket_; + pfd.fd = sock; pfd.events = POLLIN; EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 2000), SyscallSucceedsWithValue(1)); @@ -182,9 +178,10 @@ TEST_P(CookedPacketTest, Receive) { char buf[64]; struct sockaddr_ll src = {}; socklen_t src_len = sizeof(src); - ASSERT_THAT(recvfrom(socket_, buf, sizeof(buf), 0, + ASSERT_THAT(recvfrom(sock, buf, sizeof(buf), 0, reinterpret_cast(&src), &src_len), SyscallSucceedsWithValue(packet_size)); + // sockaddr_ll ends with an 8 byte physical address field, but ethernet // addresses only use 6 bytes. Linux used to return sizeof(sockaddr_ll)-2 // here, but since commit b2cf86e1563e33a14a1c69b3e508d15dc12f804c returns @@ -194,7 +191,7 @@ TEST_P(CookedPacketTest, Receive) { // TODO(b/129292371): Verify protocol once we return it. // Verify the source address. EXPECT_EQ(src.sll_family, AF_PACKET); - EXPECT_EQ(src.sll_ifindex, GetLoopbackIndex()); + EXPECT_EQ(src.sll_ifindex, ifindex); EXPECT_EQ(src.sll_halen, ETH_ALEN); // This came from the loopback device, so the address is all 0s. for (int i = 0; i < src.sll_halen; i++) { @@ -222,6 +219,18 @@ TEST_P(CookedPacketTest, Receive) { EXPECT_EQ(strncmp(payload, kMessage, sizeof(kMessage)), 0); } +// Receive via a packet socket. +TEST_P(CookedPacketTest, Receive) { + // Let's use a simple IP payload: a UDP datagram. + FileDescriptor udp_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + SendUDPMessage(udp_sock.get()); + + // Receive and verify the data. + int loopback_index = GetLoopbackIndex(); + ReceiveMessage(socket_, loopback_index); +} + // Send via a packet socket. TEST_P(CookedPacketTest, Send) { // TODO(b/129292371): Remove once we support packet socket writing. @@ -313,6 +322,101 @@ TEST_P(CookedPacketTest, Send) { EXPECT_EQ(src.sin_addr.s_addr, htonl(INADDR_LOOPBACK)); } +// Bind and receive via packet socket. +TEST_P(CookedPacketTest, BindReceive) { + struct sockaddr_ll bind_addr = {}; + bind_addr.sll_family = AF_PACKET; + bind_addr.sll_protocol = htons(GetParam()); + bind_addr.sll_ifindex = GetLoopbackIndex(); + + ASSERT_THAT(bind(socket_, reinterpret_cast(&bind_addr), + sizeof(bind_addr)), + SyscallSucceeds()); + + // Let's use a simple IP payload: a UDP datagram. + FileDescriptor udp_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + SendUDPMessage(udp_sock.get()); + + // Receive and verify the data. + ReceiveMessage(socket_, bind_addr.sll_ifindex); +} + +// Double Bind socket. +TEST_P(CookedPacketTest, DoubleBind) { + struct sockaddr_ll bind_addr = {}; + bind_addr.sll_family = AF_PACKET; + bind_addr.sll_protocol = htons(GetParam()); + bind_addr.sll_ifindex = GetLoopbackIndex(); + + ASSERT_THAT(bind(socket_, reinterpret_cast(&bind_addr), + sizeof(bind_addr)), + SyscallSucceeds()); + + // Binding socket again should fail. + ASSERT_THAT( + bind(socket_, reinterpret_cast(&bind_addr), + sizeof(bind_addr)), + // Linux 4.09 returns EINVAL here, but some time before 4.19 it switched + // to EADDRINUSE. + AnyOf(SyscallFailsWithErrno(EADDRINUSE), SyscallFailsWithErrno(EINVAL))); +} + +// Bind and verify we do not receive data on interface which is not bound +TEST_P(CookedPacketTest, BindDrop) { + // Let's use a simple IP payload: a UDP datagram. + FileDescriptor udp_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + + struct ifaddrs* if_addr_list = nullptr; + auto cleanup = Cleanup([&if_addr_list]() { freeifaddrs(if_addr_list); }); + + ASSERT_THAT(getifaddrs(&if_addr_list), SyscallSucceeds()); + + // Get interface other than loopback. + struct ifreq ifr = {}; + for (struct ifaddrs* i = if_addr_list; i; i = i->ifa_next) { + if (strcmp(i->ifa_name, "lo") != 0) { + strncpy(ifr.ifr_name, i->ifa_name, sizeof(ifr.ifr_name)); + break; + } + } + + // Skip if no interface is available other than loopback. + if (strlen(ifr.ifr_name) == 0) { + GTEST_SKIP(); + } + + // Get interface index. + EXPECT_THAT(ioctl(socket_, SIOCGIFINDEX, &ifr), SyscallSucceeds()); + EXPECT_NE(ifr.ifr_ifindex, 0); + + // Bind to packet socket requires only family, protocol and ifindex. + struct sockaddr_ll bind_addr = {}; + bind_addr.sll_family = AF_PACKET; + bind_addr.sll_protocol = htons(GetParam()); + bind_addr.sll_ifindex = ifr.ifr_ifindex; + + ASSERT_THAT(bind(socket_, reinterpret_cast(&bind_addr), + sizeof(bind_addr)), + SyscallSucceeds()); + + // Send to loopback interface. + struct sockaddr_in dest = {}; + dest.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + dest.sin_family = AF_INET; + dest.sin_port = kPort; + EXPECT_THAT(sendto(udp_sock.get(), kMessage, sizeof(kMessage), 0, + reinterpret_cast(&dest), sizeof(dest)), + SyscallSucceedsWithValue(sizeof(kMessage))); + + // Wait and make sure the socket never receives any data. + struct pollfd pfd = {}; + pfd.fd = socket_; + pfd.events = POLLIN; + EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 1000), SyscallSucceedsWithValue(0)); +} + INSTANTIATE_TEST_SUITE_P(AllInetTests, CookedPacketTest, ::testing::Values(ETH_P_IP, ETH_P_ALL)); -- cgit v1.2.3 From c6bdc6b05b4abfcb3c677013496c9a6b1d1365dd Mon Sep 17 00:00:00 2001 From: Ian Gudger Date: Thu, 27 Feb 2020 14:14:34 -0800 Subject: Fix a race in TCP endpoint teardown and teardown the stack in tcp_test. Call stack.Close on stacks when we are done with them in tcp_test. This avoids leaking resources and reduces the test's flakiness when race/gotsan is enabled. It also provides test coverage for the race also fixed in this change, which can be reliably triggered with the stack.Close change (and without the other changes) when race/gotsan is enabled. The race was possible when calling Abort (via stack.Close) on an endpoint processing a SYN segment as part of a passive connect. Updates #1564 PiperOrigin-RevId: 297685432 --- pkg/tcpip/transport/tcp/accept.go | 1 + pkg/tcpip/transport/tcp/connect.go | 2 +- pkg/tcpip/transport/tcp/endpoint.go | 16 +++++++++++++++- pkg/tcpip/transport/tcp/testing/context/context.go | 1 + 4 files changed, 18 insertions(+), 2 deletions(-) (limited to 'pkg/tcpip') diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 13e383ffc..85049e54e 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -236,6 +236,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto} n.rcvBufSize = int(l.rcvWnd) n.amss = mssForRoute(&n.route) + n.setEndpointState(StateConnecting) n.maybeEnableTimestamp(rcvdSynOpts) n.maybeEnableSACKPermitted(rcvdSynOpts) diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 7730e6445..cd247f3e1 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -577,7 +577,7 @@ func (h *handshake) execute() *tcpip.Error { case wakerForNotification: n := h.ep.fetchNotifications() - if n¬ifyClose != 0 { + if (n¬ifyClose)|(n¬ifyAbort) != 0 { return tcpip.ErrAborted } if n¬ifyDrain != 0 { diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index f1ad19dac..9e72730bd 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -798,7 +798,21 @@ func (e *endpoint) Abort() { // If the endpoint disconnected after the check, nothing needs to be // done, so sending a notification which will potentially be ignored is // fine. - if e.EndpointState().connected() { + // + // If the endpoint connecting finishes after the check, the endpoint + // is either in a connected state (where we would notifyAbort anyway), + // SYN-RECV (where we would also notifyAbort anyway), or in an error + // state where nothing is required and the notification can be safely + // ignored. + // + // Endpoints where a Close during connecting or SYN-RECV state would be + // problematic are set to state connecting before being registered (and + // thus possible to be Aborted). They are never available in initial + // state. + // + // Endpoints transitioning from initial to connecting state may be + // safely either closed or sent notifyAbort. + if s := e.EndpointState(); s == StateConnecting || s == StateSynRecv || s.connected() { e.notifyProtocolGoroutine(notifyAbort) return } diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 1e9a0dea3..8cea20fb5 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -204,6 +204,7 @@ func (c *Context) Cleanup() { if c.EP != nil { c.EP.Close() } + c.Stack().Close() } // Stack returns a reference to the stack in the Context. -- cgit v1.2.3 From 8821a7104f8c8f263b88def1a646d518ec3f5dd2 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Mon, 2 Mar 2020 13:14:12 -0800 Subject: Do not read-lock NIC recursively A deadlock may occur if a write lock on a RWMutex is blocked between nested read lock attempts as the inner read lock attempt will be blocked in this scenario. Example (T1 and T2 are differnt goroutines): T1: obtain read-lock T2: attempt write-lock (blocks) T1: attempt inner/nested read-lock (blocks) Here we can see that T1 and T2 are deadlocked. Tests: Existing tests pass. PiperOrigin-RevId: 298426678 --- pkg/tcpip/stack/nic.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'pkg/tcpip') diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 46d3a6646..3e6196aee 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -451,7 +451,7 @@ func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEn cs := make([]ipv6AddrCandidate, 0, len(primaryAddrs)) for _, r := range primaryAddrs { // If r is not valid for outgoing connections, it is not a valid endpoint. - if !r.isValidForOutgoing() { + if !r.isValidForOutgoingRLocked() { continue } -- cgit v1.2.3 From 33101752501fafea99d77f34bbd65f3e0083d22e Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Mon, 2 Mar 2020 14:43:52 -0800 Subject: Fix data-race when reading/writing e.amss. PiperOrigin-RevId: 298451319 --- pkg/tcpip/transport/tcp/connect.go | 11 +++++++++-- pkg/tcpip/transport/tcp/endpoint.go | 29 ++++++++++++++++++----------- test/syscalls/linux/tcp_socket.cc | 15 +++++++++++++++ 3 files changed, 42 insertions(+), 13 deletions(-) (limited to 'pkg/tcpip') diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index cd247f3e1..ae4f3f3a9 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -295,6 +295,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { h.state = handshakeSynRcvd h.ep.mu.Lock() ttl := h.ep.ttl + amss := h.ep.amss h.ep.setEndpointState(StateSynRecv) h.ep.mu.Unlock() synOpts := header.TCPSynOptions{ @@ -307,7 +308,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { // permits SACK. This is not explicitly defined in the RFC but // this is the behaviour implemented by Linux. SACKPermitted: rcvSynOpts.SACKPermitted, - MSS: h.ep.amss, + MSS: amss, } if ttl == 0 { ttl = s.route.DefaultTTL() @@ -356,6 +357,10 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { return tcpip.ErrInvalidEndpointState } + h.ep.mu.RLock() + amss := h.ep.amss + h.ep.mu.RUnlock() + h.resetState() synOpts := header.TCPSynOptions{ WS: h.rcvWndScale, @@ -363,7 +368,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { TSVal: h.ep.timestamp(), TSEcr: h.ep.recentTimestamp(), SACKPermitted: h.ep.sackPermitted, - MSS: h.ep.amss, + MSS: amss, } h.ep.sendSynTCP(&s.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) return nil @@ -530,6 +535,7 @@ func (h *handshake) execute() *tcpip.Error { // Send the initial SYN segment and loop until the handshake is // completed. + h.ep.mu.Lock() h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route) synOpts := header.TCPSynOptions{ @@ -540,6 +546,7 @@ func (h *handshake) execute() *tcpip.Error { SACKPermitted: bool(sackEnabled), MSS: h.ep.amss, } + h.ep.mu.Unlock() // Execute is also called in a listen context so we want to make sure we // only send the TS/SACK option when we received the TS/SACK in the diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 9e72730bd..8b9154e69 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -959,15 +959,18 @@ func (e *endpoint) initialReceiveWindow() int { // ModerateRecvBuf adjusts the receive buffer and the advertised window // based on the number of bytes copied to user space. func (e *endpoint) ModerateRecvBuf(copied int) { + e.mu.RLock() e.rcvListMu.Lock() if e.rcvAutoParams.disabled { e.rcvListMu.Unlock() + e.mu.RUnlock() return } now := time.Now() if rtt := e.rcvAutoParams.rtt; rtt == 0 || now.Sub(e.rcvAutoParams.measureTime) < rtt { e.rcvAutoParams.copied += copied e.rcvListMu.Unlock() + e.mu.RUnlock() return } prevRTTCopied := e.rcvAutoParams.copied + copied @@ -1008,7 +1011,7 @@ func (e *endpoint) ModerateRecvBuf(copied int) { e.rcvBufSize = rcvWnd availAfter := e.receiveBufferAvailableLocked() mask := uint32(notifyReceiveWindowChanged) - if crossed, above := e.windowCrossedACKThreshold(availAfter - availBefore); crossed && above { + if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above { mask |= notifyNonZeroReceiveWindow } e.notifyProtocolGoroutine(mask) @@ -1023,6 +1026,7 @@ func (e *endpoint) ModerateRecvBuf(copied int) { e.rcvAutoParams.measureTime = now e.rcvAutoParams.copied = 0 e.rcvListMu.Unlock() + e.mu.RUnlock() } // IPTables implements tcpip.Endpoint.IPTables. @@ -1052,7 +1056,6 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, v, err := e.readLocked() e.rcvListMu.Unlock() - e.mu.RUnlock() if err == tcpip.ErrClosedForReceive { @@ -1085,7 +1088,7 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { // enough buffer space, to either fit an aMSS or half a receive buffer // (whichever smaller), then notify the protocol goroutine to send a // window update. - if crossed, above := e.windowCrossedACKThreshold(len(v)); crossed && above { + if crossed, above := e.windowCrossedACKThresholdLocked(len(v)); crossed && above { e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) } @@ -1303,9 +1306,9 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro return num, tcpip.ControlMessages{}, nil } -// windowCrossedACKThreshold checks if the receive window to be announced now -// would be under aMSS or under half receive buffer, whichever smaller. This is -// useful as a receive side silly window syndrome prevention mechanism. If +// windowCrossedACKThresholdLocked checks if the receive window to be announced +// now would be under aMSS or under half receive buffer, whichever smaller. This +// is useful as a receive side silly window syndrome prevention mechanism. If // window grows to reasonable value, we should send ACK to the sender to inform // the rx space is now large. We also want ensure a series of small read()'s // won't trigger a flood of spurious tiny ACK's. @@ -1316,7 +1319,9 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro // crossed will be true if the window size crossed the ACK threshold. // above will be true if the new window is >= ACK threshold and false // otherwise. -func (e *endpoint) windowCrossedACKThreshold(deltaBefore int) (crossed bool, above bool) { +// +// Precondition: e.mu and e.rcvListMu must be held. +func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed bool, above bool) { newAvail := e.receiveBufferAvailableLocked() oldAvail := newAvail - deltaBefore if oldAvail < 0 { @@ -1379,6 +1384,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { mask := uint32(notifyReceiveWindowChanged) + e.mu.RLock() e.rcvListMu.Lock() // Make sure the receive buffer size allows us to send a @@ -1405,11 +1411,11 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { // Immediately send an ACK to uncork the sender silly window // syndrome prevetion, when our available space grows above aMSS // or half receive buffer, whichever smaller. - if crossed, above := e.windowCrossedACKThreshold(availAfter - availBefore); crossed && above { + if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above { mask |= notifyNonZeroReceiveWindow } e.rcvListMu.Unlock() - + e.mu.RUnlock() e.notifyProtocolGoroutine(mask) return nil @@ -2414,13 +2420,14 @@ func (e *endpoint) updateSndBufferUsage(v int) { // to be read, or when the connection is closed for receiving (in which case // s will be nil). func (e *endpoint) readyToRead(s *segment) { + e.mu.RLock() e.rcvListMu.Lock() if s != nil { s.incRef() e.rcvBufUsed += s.data.Size() // Increase counter if the receive window falls down below MSS // or half receive buffer size, whichever smaller. - if crossed, above := e.windowCrossedACKThreshold(-s.data.Size()); crossed && !above { + if crossed, above := e.windowCrossedACKThresholdLocked(-s.data.Size()); crossed && !above { e.stats.ReceiveErrors.ZeroRcvWindowState.Increment() } e.rcvList.PushBack(s) @@ -2428,7 +2435,7 @@ func (e *endpoint) readyToRead(s *segment) { e.rcvClosed = true } e.rcvListMu.Unlock() - + e.mu.RUnlock() e.waiterQueue.Notify(waiter.EventIn) } diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index c4591a3b9..579463384 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -1349,6 +1349,21 @@ TEST_P(SimpleTcpSocketTest, RecvOnClosedSocket) { SyscallFailsWithErrno(ENOTCONN)); } +TEST_P(SimpleTcpSocketTest, TCPConnectSoRcvBufRace) { + auto s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(GetParam(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP)); + sockaddr_storage addr = + ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam())); + socklen_t addrlen = sizeof(addr); + + RetryEINTR(connect)(s.get(), reinterpret_cast(&addr), + addrlen); + int buf_sz = 1 << 18; + EXPECT_THAT( + setsockopt(s.get(), SOL_SOCKET, SO_RCVBUF, &buf_sz, sizeof(buf_sz)), + SyscallSucceedsWithValue(0)); +} + INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest, ::testing::Values(AF_INET, AF_INET6)); -- cgit v1.2.3 From c15b8515eb4a07699e5f2401f0332286f0a51043 Mon Sep 17 00:00:00 2001 From: Ian Gudger Date: Tue, 3 Mar 2020 13:40:59 -0800 Subject: Fix datarace on TransportEndpointInfo.ID and clean up semantics. Ensures that all access to TransportEndpointInfo.ID is either: * In a function ending in a Locked suffix. * While holding the appropriate mutex. This primary affects the checkV4Mapped method on affected endpoints, which has been renamed to checkV4MappedLocked. Also document the method and change its argument to be a value instead of a pointer which had caused some awkwardness. This race was possible in the udp and icmp endpoints between Connect and uses of TransportEndpointInfo.ID including in both itself and Bind. The tcp endpoint did not suffer from this bug, but benefited from better documentation. Updates #357 PiperOrigin-RevId: 298682913 --- pkg/tcpip/stack/stack.go | 12 +++++++----- pkg/tcpip/transport/icmp/endpoint.go | 23 +++++++++++------------ pkg/tcpip/transport/tcp/endpoint.go | 15 ++++++++------- pkg/tcpip/transport/udp/endpoint.go | 30 ++++++++++++++++-------------- pkg/tcpip/transport/udp/endpoint_state.go | 3 +++ 5 files changed, 45 insertions(+), 38 deletions(-) (limited to 'pkg/tcpip') diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index ebb6c5e3b..13354d884 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -551,11 +551,13 @@ type TransportEndpointInfo struct { RegisterNICID tcpip.NICID } -// AddrNetProto unwraps the specified address if it is a V4-mapped V6 address -// and returns the network protocol number to be used to communicate with the -// specified address. It returns an error if the passed address is incompatible -// with the receiver. -func (e *TransportEndpointInfo) AddrNetProto(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { +// AddrNetProtoLocked unwraps the specified address if it is a V4-mapped V6 +// address and returns the network protocol number to be used to communicate +// with the specified address. It returns an error if the passed address is +// incompatible with the receiver. +// +// Preconditon: the parent endpoint mu must be held while calling this method. +func (e *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { netProto := e.NetProto switch len(addr.Addr) { case header.IPv4AddressSize: diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 426da1ee6..2a396e9bc 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -291,15 +291,13 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c nicID = e.BindNICID } - toCopy := *to - to = &toCopy - netProto, err := e.checkV4Mapped(to) + dst, netProto, err := e.checkV4MappedLocked(*to) if err != nil { return 0, nil, err } - // Find the enpoint. - r, err := e.stack.FindRoute(nicID, e.BindAddr, to.Addr, netProto, false /* multicastLoop */) + // Find the endpoint. + r, err := e.stack.FindRoute(nicID, e.BindAddr, dst.Addr, netProto, false /* multicastLoop */) if err != nil { return 0, nil, err } @@ -480,13 +478,14 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err }) } -func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) { - unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProto(*addr, false /* v6only */) +// checkV4MappedLocked determines the effective network protocol and converts +// addr to its canonical form. +func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { + unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, false /* v6only */) if err != nil { - return 0, err + return tcpip.FullAddress{}, 0, err } - *addr = unwrapped - return netProto, nil + return unwrapped, netProto, nil } // Disconnect implements tcpip.Endpoint.Disconnect. @@ -517,7 +516,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - netProto, err := e.checkV4Mapped(&addr) + addr, netProto, err := e.checkV4MappedLocked(addr) if err != nil { return err } @@ -630,7 +629,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - netProto, err := e.checkV4Mapped(&addr) + addr, netProto, err := e.checkV4MappedLocked(addr) if err != nil { return err } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 8b9154e69..40cc664c0 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1874,13 +1874,14 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } } -func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) { - unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProto(*addr, e.v6only) +// checkV4MappedLocked determines the effective network protocol and converts +// addr to its canonical form. +func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { + unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only) if err != nil { - return 0, err + return tcpip.FullAddress{}, 0, err } - *addr = unwrapped - return netProto, nil + return unwrapped, netProto, nil } // Disconnect implements tcpip.Endpoint.Disconnect. @@ -1910,7 +1911,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc connectingAddr := addr.Addr - netProto, err := e.checkV4Mapped(&addr) + addr, netProto, err := e.checkV4MappedLocked(addr) if err != nil { return err } @@ -2276,7 +2277,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { } e.BindAddr = addr.Addr - netProto, err := e.checkV4Mapped(&addr) + addr, netProto, err := e.checkV4MappedLocked(addr) if err != nil { return err } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 1c6a600b8..0af4514e1 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -443,19 +443,19 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return 0, nil, tcpip.ErrBroadcastDisabled } - netProto, err := e.checkV4Mapped(to) + dst, netProto, err := e.checkV4MappedLocked(*to) if err != nil { return 0, nil, err } - r, _, err := e.connectRoute(nicID, *to, netProto) + r, _, err := e.connectRoute(nicID, dst, netProto) if err != nil { return 0, nil, err } defer r.Release() route = &r - dstPort = to.Port + dstPort = dst.Port } if route.IsResolutionRequired() { @@ -566,7 +566,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { defer e.mu.Unlock() fa := tcpip.FullAddress{Addr: v.InterfaceAddr} - netProto, err := e.checkV4Mapped(&fa) + fa, netProto, err := e.checkV4MappedLocked(fa) if err != nil { return err } @@ -927,13 +927,14 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u return nil } -func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) { - unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProto(*addr, e.v6only) +// checkV4MappedLocked determines the effective network protocol and converts +// addr to its canonical form. +func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { + unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only) if err != nil { - return 0, err + return tcpip.FullAddress{}, 0, err } - *addr = unwrapped - return netProto, nil + return unwrapped, netProto, nil } // Disconnect implements tcpip.Endpoint.Disconnect. @@ -981,10 +982,6 @@ func (e *endpoint) Disconnect() *tcpip.Error { // Connect connects the endpoint to its peer. Specifying a NIC is optional. func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { - netProto, err := e.checkV4Mapped(&addr) - if err != nil { - return err - } if addr.Port == 0 { // We don't support connecting to port zero. return tcpip.ErrInvalidEndpointState @@ -1012,6 +1009,11 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrInvalidEndpointState } + addr, netProto, err := e.checkV4MappedLocked(addr) + if err != nil { + return err + } + r, nicID, err := e.connectRoute(nicID, addr, netProto) if err != nil { return err @@ -1139,7 +1141,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - netProto, err := e.checkV4Mapped(&addr) + addr, netProto, err := e.checkV4MappedLocked(addr) if err != nil { return err } diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 43fb047ed..466bd9381 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -69,6 +69,9 @@ func (e *endpoint) afterLoad() { // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { + e.mu.Lock() + defer e.mu.Unlock() + e.stack = s for _, m := range e.multicastMemberships { -- cgit v1.2.3 From 371abe00f052dec37106f2dc22921bc84fb94818 Mon Sep 17 00:00:00 2001 From: Tamir Duberstein Date: Tue, 3 Mar 2020 15:06:07 -0800 Subject: Avoid memory leaks Properly discard segments from the segment heap. PiperOrigin-RevId: 298704074 --- pkg/tcpip/transport/tcp/rcv.go | 4 ++++ pkg/tcpip/transport/tcp/segment_heap.go | 1 + 2 files changed, 5 insertions(+) (limited to 'pkg/tcpip') diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index 958f03ac1..d80aff1b6 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -195,6 +195,10 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum for i := first; i < len(r.pendingRcvdSegments); i++ { r.pendingRcvdSegments[i].decRef() + // Note that slice truncation does not allow garbage collection of + // truncated items, thus truncated items must be set to nil to avoid + // memory leaks. + r.pendingRcvdSegments[i] = nil } r.pendingRcvdSegments = r.pendingRcvdSegments[:first] diff --git a/pkg/tcpip/transport/tcp/segment_heap.go b/pkg/tcpip/transport/tcp/segment_heap.go index 9fd061d7d..e28f213ba 100644 --- a/pkg/tcpip/transport/tcp/segment_heap.go +++ b/pkg/tcpip/transport/tcp/segment_heap.go @@ -41,6 +41,7 @@ func (h *segmentHeap) Pop() interface{} { old := *h n := len(old) x := old[n-1] + old[n-1] = nil *h = old[:n-1] return x } -- cgit v1.2.3 From 9b3aad33c4470908953b7b548b12cba77799f342 Mon Sep 17 00:00:00 2001 From: Ian Gudger Date: Thu, 5 Mar 2020 15:55:40 -0800 Subject: Use a pool of arrays to avoid slice headers from escaping in TCP options pool. By putting slices into the pool, the slice header escapes. This can be avoided by not putting the slice header into the pool. This removes an allocation from the TCP segment send path. PiperOrigin-RevId: 299215480 --- pkg/tcpip/transport/tcp/BUILD | 1 + pkg/tcpip/transport/tcp/connect.go | 6 +++--- pkg/tcpip/transport/tcp/connect_unsafe.go | 30 ++++++++++++++++++++++++++++++ 3 files changed, 34 insertions(+), 3 deletions(-) create mode 100644 pkg/tcpip/transport/tcp/connect_unsafe.go (limited to 'pkg/tcpip') diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 272e8f570..a32f9eacf 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -32,6 +32,7 @@ go_library( srcs = [ "accept.go", "connect.go", + "connect_unsafe.go", "cubic.go", "cubic_state.go", "dispatcher.go", diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index ae4f3f3a9..c0f73ef16 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -624,17 +624,17 @@ func parseSynSegmentOptions(s *segment) header.TCPSynOptions { var optionPool = sync.Pool{ New: func() interface{} { - return make([]byte, maxOptionSize) + return &[maxOptionSize]byte{} }, } func getOptions() []byte { - return optionPool.Get().([]byte) + return (*optionPool.Get().(*[maxOptionSize]byte))[:] } func putOptions(options []byte) { // Reslice to full capacity. - optionPool.Put(options[0:cap(options)]) + optionPool.Put(optionsToArray(options)) } func makeSynOptions(opts header.TCPSynOptions) []byte { diff --git a/pkg/tcpip/transport/tcp/connect_unsafe.go b/pkg/tcpip/transport/tcp/connect_unsafe.go new file mode 100644 index 000000000..cfc304616 --- /dev/null +++ b/pkg/tcpip/transport/tcp/connect_unsafe.go @@ -0,0 +1,30 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "reflect" + "unsafe" +) + +// optionsToArray converts a slice of capacity >-= maxOptionSize to an array. +// +// optionsToArray panics if the capacity of options is smaller than +// maxOptionSize. +func optionsToArray(options []byte) *[maxOptionSize]byte { + // Reslice to full capacity. + options = options[0:maxOptionSize] + return (*[maxOptionSize]byte)(unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&options)).Data)) +} -- cgit v1.2.3 From d6f5e71df2c8ff3d763cba703786af68af1f9841 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Fri, 6 Mar 2020 08:01:45 -0800 Subject: Get strings for stack.DHCPv6ConfigurationFromNDPRA Useful for logs to print the string representation of the value instead of the integer value. PiperOrigin-RevId: 299356847 --- pkg/tcpip/stack/BUILD | 1 + .../stack/dhcpv6configurationfromndpra_string.go | 39 ++++++++++++++++++++++ 2 files changed, 40 insertions(+) create mode 100644 pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go (limited to 'pkg/tcpip') diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 705cf01ee..8febd54c8 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -18,6 +18,7 @@ go_template_instance( go_library( name = "stack", srcs = [ + "dhcpv6configurationfromndpra_string.go", "icmp_rate_limit.go", "linkaddrcache.go", "linkaddrentry_list.go", diff --git a/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go b/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go new file mode 100644 index 000000000..8b4213eec --- /dev/null +++ b/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go @@ -0,0 +1,39 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by "stringer -type=DHCPv6ConfigurationFromNDPRA"; DO NOT EDIT. + +package stack + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[DHCPv6NoConfiguration-0] + _ = x[DHCPv6ManagedAddress-1] + _ = x[DHCPv6OtherConfigurations-2] +} + +const _DHCPv6ConfigurationFromNDPRA_name = "DHCPv6NoConfigurationDHCPv6ManagedAddressDHCPv6OtherConfigurations" + +var _DHCPv6ConfigurationFromNDPRA_index = [...]uint8{0, 21, 41, 66} + +func (i DHCPv6ConfigurationFromNDPRA) String() string { + if i < 0 || i >= DHCPv6ConfigurationFromNDPRA(len(_DHCPv6ConfigurationFromNDPRA_index)-1) { + return "DHCPv6ConfigurationFromNDPRA(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _DHCPv6ConfigurationFromNDPRA_name[_DHCPv6ConfigurationFromNDPRA_index[i]:_DHCPv6ConfigurationFromNDPRA_index[i+1]] +} -- cgit v1.2.3 From f50d9a31e9e734a02e0191f6bc91b387bb21f9ab Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Fri, 6 Mar 2020 11:32:13 -0800 Subject: Specify the source of outgoing NDP RS If the NIC has a valid IPv6 address assigned, use it as the source address for outgoing NDP Router Solicitation packets. Test: stack_test.TestRouterSolicitation PiperOrigin-RevId: 299398763 --- pkg/tcpip/checker/checker.go | 106 +++++++++++++++++++++++++++---------------- pkg/tcpip/stack/ndp.go | 29 ++++++++++-- pkg/tcpip/stack/ndp_test.go | 39 +++++++++++++--- 3 files changed, 125 insertions(+), 49 deletions(-) (limited to 'pkg/tcpip') diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index c6c160dfc..8dc0f7c0e 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -785,6 +785,52 @@ func NDPNSTargetAddress(want tcpip.Address) TransportChecker { } } +// ndpOptions checks that optsBuf only contains opts. +func ndpOptions(t *testing.T, optsBuf header.NDPOptions, opts []header.NDPOption) { + t.Helper() + + it, err := optsBuf.Iter(true) + if err != nil { + t.Errorf("optsBuf.Iter(true): %s", err) + return + } + + i := 0 + for { + opt, done, err := it.Next() + if err != nil { + // This should never happen as Iter(true) above did not return an error. + t.Fatalf("unexpected error when iterating over NDP options: %s", err) + } + if done { + break + } + + if i >= len(opts) { + t.Errorf("got unexpected option: %s", opt) + continue + } + + switch wantOpt := opts[i].(type) { + case header.NDPSourceLinkLayerAddressOption: + gotOpt, ok := opt.(header.NDPSourceLinkLayerAddressOption) + if !ok { + t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt) + } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want { + t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want) + } + default: + t.Fatalf("checker not implemented for expected NDP option: %T", wantOpt) + } + + i++ + } + + if missing := opts[i:]; len(missing) > 0 { + t.Errorf("missing options: %s", missing) + } +} + // NDPNSOptions creates a checker that checks that the packet contains the // provided NDP options within an NDP Neighbor Solicitation message. // @@ -796,47 +842,31 @@ func NDPNSOptions(opts []header.NDPOption) TransportChecker { icmp := h.(header.ICMPv6) ns := header.NDPNeighborSolicit(icmp.NDPPayload()) - it, err := ns.Options().Iter(true) - if err != nil { - t.Errorf("opts.Iter(true): %s", err) - return - } - - i := 0 - for { - opt, done, _ := it.Next() - if done { - break - } - - if i >= len(opts) { - t.Errorf("got unexpected option: %s", opt) - continue - } - - switch wantOpt := opts[i].(type) { - case header.NDPSourceLinkLayerAddressOption: - gotOpt, ok := opt.(header.NDPSourceLinkLayerAddressOption) - if !ok { - t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt) - } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want { - t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want) - } - default: - panic("not implemented") - } - - i++ - } - - if missing := opts[i:]; len(missing) > 0 { - t.Errorf("missing options: %s", missing) - } + ndpOptions(t, ns.Options(), opts) } } // NDPRS creates a checker that checks that the packet contains a valid NDP // Router Solicitation message (as per the raw wire format). -func NDPRS() NetworkChecker { - return NDP(header.ICMPv6RouterSolicit, header.NDPRSMinimumSize) +// +// checkers may assume that a valid ICMPv6 is passed to it containing a valid +// NDPRS as far as the size of the message is concerned. The values within the +// message are up to checkers to validate. +func NDPRS(checkers ...TransportChecker) NetworkChecker { + return NDP(header.ICMPv6RouterSolicit, header.NDPRSMinimumSize, checkers...) +} + +// NDPRSOptions creates a checker that checks that the packet contains the +// provided NDP options within an NDP Router Solicitation message. +// +// The returned TransportChecker assumes that a valid ICMPv6 is passed to it +// containing a valid NDPRS message as far as the size is concerned. +func NDPRSOptions(opts []header.NDPOption) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmp := h.(header.ICMPv6) + rs := header.NDPRouterSolicit(icmp.NDPPayload()) + ndpOptions(t, rs.Options(), opts) + } } diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index f651871ce..a9f4d5dad 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -1220,9 +1220,15 @@ func (ndp *ndpState) startSolicitingRouters() { } ndp.rtrSolicitTimer = time.AfterFunc(delay, func() { - // Send an RS message with the unspecified source address. - ref := ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, forceSpoofing) - r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false) + // As per RFC 4861 section 4.1, the source of the RS is an address assigned + // to the sending interface, or the unspecified address if no address is + // assigned to the sending interface. + ref := ndp.nic.primaryIPv6Endpoint(header.IPv6AllRoutersMulticastAddress) + if ref == nil { + ref = ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, forceSpoofing) + } + localAddr := ref.ep.ID().LocalAddress + r := makeRoute(header.IPv6ProtocolNumber, localAddr, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false) defer r.Release() // Route should resolve immediately since @@ -1234,10 +1240,25 @@ func (ndp *ndpState) startSolicitingRouters() { log.Fatalf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID()) } - payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + // As per RFC 4861 section 4.1, an NDP RS SHOULD include the source + // link-layer address option if the source address of the NDP RS is + // specified. This option MUST NOT be included if the source address is + // unspecified. + // + // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by + // LinkEndpoint.LinkAddress) before reaching this point. + var optsSerializer header.NDPOptionsSerializer + if localAddr != header.IPv6Any && header.IsValidUnicastEthernetAddress(r.LocalLinkAddress) { + optsSerializer = header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(r.LocalLinkAddress), + } + } + payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + int(optsSerializer.Length()) hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + payloadSize) pkt := header.ICMPv6(hdr.Prepend(payloadSize)) pkt.SetType(header.ICMPv6RouterSolicit) + rs := header.NDPRouterSolicit(pkt.NDPPayload()) + rs.Options().Serialize(optsSerializer) pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) sent := r.Stats().ICMP.V6PacketsSent diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 6e9306d09..98b1c807c 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -3384,6 +3384,10 @@ func TestRouterSolicitation(t *testing.T) { tests := []struct { name string linkHeaderLen uint16 + linkAddr tcpip.LinkAddress + nicAddr tcpip.Address + expectedSrcAddr tcpip.Address + expectedNDPOpts []header.NDPOption maxRtrSolicit uint8 rtrSolicitInt time.Duration effectiveRtrSolicitInt time.Duration @@ -3392,6 +3396,7 @@ func TestRouterSolicitation(t *testing.T) { }{ { name: "Single RS with delay", + expectedSrcAddr: header.IPv6Any, maxRtrSolicit: 1, rtrSolicitInt: time.Second, effectiveRtrSolicitInt: time.Second, @@ -3401,6 +3406,8 @@ func TestRouterSolicitation(t *testing.T) { { name: "Two RS with delay", linkHeaderLen: 1, + nicAddr: llAddr1, + expectedSrcAddr: llAddr1, maxRtrSolicit: 2, rtrSolicitInt: time.Second, effectiveRtrSolicitInt: time.Second, @@ -3408,8 +3415,14 @@ func TestRouterSolicitation(t *testing.T) { effectiveMaxRtrSolicitDelay: 500 * time.Millisecond, }, { - name: "Single RS without delay", - linkHeaderLen: 2, + name: "Single RS without delay", + linkHeaderLen: 2, + linkAddr: linkAddr1, + nicAddr: llAddr1, + expectedSrcAddr: llAddr1, + expectedNDPOpts: []header.NDPOption{ + header.NDPSourceLinkLayerAddressOption(linkAddr1), + }, maxRtrSolicit: 1, rtrSolicitInt: time.Second, effectiveRtrSolicitInt: time.Second, @@ -3419,6 +3432,8 @@ func TestRouterSolicitation(t *testing.T) { { name: "Two RS without delay and invalid zero interval", linkHeaderLen: 3, + linkAddr: linkAddr1, + expectedSrcAddr: header.IPv6Any, maxRtrSolicit: 2, rtrSolicitInt: 0, effectiveRtrSolicitInt: 4 * time.Second, @@ -3427,6 +3442,8 @@ func TestRouterSolicitation(t *testing.T) { }, { name: "Three RS without delay", + linkAddr: linkAddr1, + expectedSrcAddr: header.IPv6Any, maxRtrSolicit: 3, rtrSolicitInt: 500 * time.Millisecond, effectiveRtrSolicitInt: 500 * time.Millisecond, @@ -3435,6 +3452,8 @@ func TestRouterSolicitation(t *testing.T) { }, { name: "Two RS with invalid negative delay", + linkAddr: linkAddr1, + expectedSrcAddr: header.IPv6Any, maxRtrSolicit: 2, rtrSolicitInt: time.Second, effectiveRtrSolicitInt: time.Second, @@ -3457,7 +3476,7 @@ func TestRouterSolicitation(t *testing.T) { t.Run(test.name, func(t *testing.T) { t.Parallel() e := channelLinkWithHeaderLength{ - Endpoint: channel.New(int(test.maxRtrSolicit), 1280, linkAddr1), + Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr), headerLength: test.linkHeaderLen, } e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired @@ -3481,10 +3500,10 @@ func TestRouterSolicitation(t *testing.T) { checker.IPv6(t, p.Pkt.Header.View(), - checker.SrcAddr(header.IPv6Any), + checker.SrcAddr(test.expectedSrcAddr), checker.DstAddr(header.IPv6AllRoutersMulticastAddress), checker.TTL(header.NDPHopLimit), - checker.NDPRS(), + checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)), ) if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want { @@ -3510,13 +3529,19 @@ func TestRouterSolicitation(t *testing.T) { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - // Make sure each RS got sent at the right - // times. + if addr := test.nicAddr; addr != "" { + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err) + } + } + + // Make sure each RS is sent at the right time. remaining := test.maxRtrSolicit if remaining > 0 { waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultAsyncEventTimeout) remaining-- } + for ; remaining > 0; remaining-- { waitForNothing(test.effectiveRtrSolicitInt - defaultTimeout) waitForPkt(defaultAsyncEventTimeout) -- cgit v1.2.3 From d5dbe366bf7c9f5b648b8114a9dc7f45589899b1 Mon Sep 17 00:00:00 2001 From: Eyal Soha Date: Fri, 6 Mar 2020 11:41:10 -0800 Subject: shutdown(s, SHUT_WR) in TIME-WAIT returns ENOTCONN From RFC 793 s3.9 p61 Event Processing: CLOSE Call during TIME-WAIT: return with "error: connection closing" Fixes #1603 PiperOrigin-RevId: 299401353 --- pkg/tcpip/transport/tcp/endpoint.go | 5 ++++- test/syscalls/linux/tcp_socket.cc | 14 ++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) (limited to 'pkg/tcpip') diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 40cc664c0..dc9c18b6f 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -2117,10 +2117,13 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { // Close for write. if (e.shutdownFlags & tcpip.ShutdownWrite) != 0 { e.sndBufMu.Lock() - if e.sndClosed { // Already closed. e.sndBufMu.Unlock() + if e.EndpointState() == StateTimeWait { + e.mu.Unlock() + return tcpip.ErrNotConnected + } break } diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index 579463384..d9c1ac0e1 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -143,6 +143,20 @@ TEST_P(TcpSocketTest, ConnectOnEstablishedConnection) { SyscallFailsWithErrno(EISCONN)); } +TEST_P(TcpSocketTest, ShutdownWriteInTimeWait) { + EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceeds()); + EXPECT_THAT(shutdown(s_, SHUT_RDWR), SyscallSucceeds()); + absl::SleepFor(absl::Seconds(1)); // Wait to enter TIME_WAIT. + EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallFailsWithErrno(ENOTCONN)); +} + +TEST_P(TcpSocketTest, ShutdownWriteInFinWait1) { + EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceeds()); + EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceeds()); + absl::SleepFor(absl::Seconds(1)); // Wait to enter FIN-WAIT2. + EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceeds()); +} + TEST_P(TcpSocketTest, DataCoalesced) { char buf[10]; -- cgit v1.2.3 From 6fa5cee82c0f515b001dee5f3840e1f875b2f477 Mon Sep 17 00:00:00 2001 From: Tamir Duberstein Date: Fri, 6 Mar 2020 12:30:37 -0800 Subject: Prevent memory leaks in ilist When list elements are removed from a list but not discarded, it becomes important to invalidate the references they hold to their former neighbors to prevent memory leaks. PiperOrigin-RevId: 299412421 --- pkg/ilist/list.go | 8 ++++++-- pkg/sentry/fs/dirent_cache.go | 2 -- pkg/sentry/fs/inotify.go | 5 ++++- pkg/sentry/kernel/epoll/epoll_state.go | 13 ++++++++----- pkg/tcpip/network/fragmentation/fragmentation.go | 8 +++++--- 5 files changed, 23 insertions(+), 13 deletions(-) (limited to 'pkg/tcpip') diff --git a/pkg/ilist/list.go b/pkg/ilist/list.go index f3a609b57..8f93e4d6d 100644 --- a/pkg/ilist/list.go +++ b/pkg/ilist/list.go @@ -169,8 +169,9 @@ func (l *List) InsertBefore(a, e Element) { // Remove removes e from l. func (l *List) Remove(e Element) { - prev := ElementMapper{}.linkerFor(e).Prev() - next := ElementMapper{}.linkerFor(e).Next() + linker := ElementMapper{}.linkerFor(e) + prev := linker.Prev() + next := linker.Next() if prev != nil { ElementMapper{}.linkerFor(prev).SetNext(next) @@ -183,6 +184,9 @@ func (l *List) Remove(e Element) { } else { l.tail = prev } + + linker.SetNext(nil) + linker.SetPrev(nil) } // Entry is a default implementation of Linker. Users can add anonymous fields diff --git a/pkg/sentry/fs/dirent_cache.go b/pkg/sentry/fs/dirent_cache.go index 25514ace4..33de32c69 100644 --- a/pkg/sentry/fs/dirent_cache.go +++ b/pkg/sentry/fs/dirent_cache.go @@ -101,8 +101,6 @@ func (c *DirentCache) remove(d *Dirent) { panic(fmt.Sprintf("trying to remove %v, which is not in the dirent cache", d)) } c.list.Remove(d) - d.SetPrev(nil) - d.SetNext(nil) d.DecRef() c.currentSize-- if c.limit != nil { diff --git a/pkg/sentry/fs/inotify.go b/pkg/sentry/fs/inotify.go index 928c90aa0..e3a715c1f 100644 --- a/pkg/sentry/fs/inotify.go +++ b/pkg/sentry/fs/inotify.go @@ -143,7 +143,10 @@ func (i *Inotify) Read(ctx context.Context, _ *File, dst usermem.IOSequence, _ i } var writeLen int64 - for event := i.events.Front(); event != nil; event = event.Next() { + for it := i.events.Front(); it != nil; { + event := it + it = it.Next() + // Does the buffer have enough remaining space to hold the event we're // about to write out? if dst.NumBytes() < int64(event.sizeOf()) { diff --git a/pkg/sentry/kernel/epoll/epoll_state.go b/pkg/sentry/kernel/epoll/epoll_state.go index a0d35d350..8e9f200d0 100644 --- a/pkg/sentry/kernel/epoll/epoll_state.go +++ b/pkg/sentry/kernel/epoll/epoll_state.go @@ -38,11 +38,14 @@ func (e *EventPoll) afterLoad() { } } - for it := e.waitingList.Front(); it != nil; it = it.Next() { - if it.id.File.Readiness(it.mask) != 0 { - e.waitingList.Remove(it) - e.readyList.PushBack(it) - it.curList = &e.readyList + for it := e.waitingList.Front(); it != nil; { + entry := it + it = it.Next() + + if entry.id.File.Readiness(entry.mask) != 0 { + e.waitingList.Remove(entry) + e.readyList.PushBack(entry) + entry.curList = &e.readyList e.Notify(waiter.EventIn) } } diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go index 92f2aa13a..f42abc4bb 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/fragmentation/fragmentation.go @@ -115,10 +115,12 @@ func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buf // Evict reassemblers if we are consuming more memory than highLimit until // we reach lowLimit. if f.size > f.highLimit { - tail := f.rList.Back() - for f.size > f.lowLimit && tail != nil { + for f.size > f.lowLimit { + tail := f.rList.Back() + if tail == nil { + break + } f.release(tail) - tail = tail.Prev() } } f.mu.Unlock() -- cgit v1.2.3