diff options
Diffstat (limited to 'pkg/tcpip')
66 files changed, 1527 insertions, 642 deletions
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index 23e4b09e7..26f7ba86b 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -12,7 +11,6 @@ go_library( "time_unsafe.go", "timer.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip", visibility = ["//visibility:public"], deps = [ "//pkg/sync", @@ -25,7 +23,7 @@ go_test( name = "tcpip_test", size = "small", srcs = ["tcpip_test.go"], - embed = [":tcpip"], + library = ":tcpip", ) go_test( diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD index 3df7d18d3..a984f1712 100644 --- a/pkg/tcpip/adapters/gonet/BUILD +++ b/pkg/tcpip/adapters/gonet/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "gonet", srcs = ["gonet.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet", visibility = ["//visibility:public"], deps = [ "//pkg/sync", @@ -23,7 +21,7 @@ go_test( name = "gonet_test", size = "small", srcs = ["gonet_test.go"], - embed = [":gonet"], + library = ":gonet", deps = [ "//pkg/tcpip", "//pkg/tcpip/header", diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index a2f44b496..711969b9b 100644 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -556,6 +556,17 @@ type PacketConn struct { wq *waiter.Queue } +// NewPacketConn creates a new PacketConn. +func NewPacketConn(s *stack.Stack, wq *waiter.Queue, ep tcpip.Endpoint) *PacketConn { + c := &PacketConn{ + stack: s, + ep: ep, + wq: wq, + } + c.deadlineTimer.init() + return c +} + // DialUDP creates a new PacketConn. // // If laddr is nil, a local address is automatically chosen. 
@@ -580,12 +591,7 @@ func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.Netw } } - c := PacketConn{ - stack: s, - ep: ep, - wq: &wq, - } - c.deadlineTimer.init() + c := NewPacketConn(s, &wq, ep) if raddr != nil { if err := c.ep.Connect(*raddr); err != nil { @@ -599,7 +605,7 @@ func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.Netw } } - return &c, nil + return c, nil } func (c *PacketConn) newOpError(op string, err error) *net.OpError { @@ -622,7 +628,7 @@ func (c *PacketConn) RemoteAddr() net.Addr { if err != nil { return nil } - return fullToTCPAddr(a) + return fullToUDPAddr(a) } // Read implements net.Conn.Read diff --git a/pkg/tcpip/buffer/BUILD b/pkg/tcpip/buffer/BUILD index d6c31bfa2..563bc78ea 100644 --- a/pkg/tcpip/buffer/BUILD +++ b/pkg/tcpip/buffer/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -9,7 +8,6 @@ go_library( "prependable.go", "view.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/buffer", visibility = ["//visibility:public"], ) @@ -17,5 +15,5 @@ go_test( name = "buffer_test", size = "small", srcs = ["view_test.go"], - embed = [":buffer"], + library = ":buffer", ) diff --git a/pkg/tcpip/checker/BUILD b/pkg/tcpip/checker/BUILD index b6fa6fc37..ed434807f 100644 --- a/pkg/tcpip/checker/BUILD +++ b/pkg/tcpip/checker/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -6,7 +6,6 @@ go_library( name = "checker", testonly = 1, srcs = ["checker.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/checker", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 885d773b0..4d6ae0871 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -771,6 
+771,56 @@ func NDPNSTargetAddress(want tcpip.Address) TransportChecker { } } +// NDPNSOptions creates a checker that checks that the packet contains the +// provided NDP options within an NDP Neighbor Solicitation message. +// +// The returned TransportChecker assumes that a valid ICMPv6 is passed to it +// containing a valid NDPNS message as far as the size is concerned. +func NDPNSOptions(opts []header.NDPOption) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmp := h.(header.ICMPv6) + ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + it, err := ns.Options().Iter(true) + if err != nil { + t.Errorf("opts.Iter(true): %s", err) + return + } + + i := 0 + for { + opt, done, _ := it.Next() + if done { + break + } + + if i >= len(opts) { + t.Errorf("got unexpected option: %s", opt) + continue + } + + switch wantOpt := opts[i].(type) { + case header.NDPSourceLinkLayerAddressOption: + gotOpt, ok := opt.(header.NDPSourceLinkLayerAddressOption) + if !ok { + t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt) + } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want { + t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want) + } + default: + panic("not implemented") + } + + i++ + } + + if missing := opts[i:]; len(missing) > 0 { + t.Errorf("missing options: %s", missing) + } + } +} + // NDPRS creates a checker that checks that the packet contains a valid NDP // Router Solicitation message (as per the raw wire format). 
func NDPRS() NetworkChecker { diff --git a/pkg/tcpip/hash/jenkins/BUILD b/pkg/tcpip/hash/jenkins/BUILD index e648efa71..ff2719291 100644 --- a/pkg/tcpip/hash/jenkins/BUILD +++ b/pkg/tcpip/hash/jenkins/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "jenkins", srcs = ["jenkins.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins", visibility = ["//visibility:public"], ) @@ -16,5 +14,5 @@ go_test( srcs = [ "jenkins_test.go", ], - embed = [":jenkins"], + library = ":jenkins", ) diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD index cd747d100..9da0d71f8 100644 --- a/pkg/tcpip/header/BUILD +++ b/pkg/tcpip/header/BUILD @@ -1,5 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -24,7 +23,6 @@ go_library( "tcp.go", "udp.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/header", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", @@ -59,7 +57,7 @@ go_test( "eth_test.go", "ndp_test.go", ], - embed = [":header"], + library = ":header", deps = [ "//pkg/tcpip", "@com_github_google_go-cmp//cmp:go_default_library", diff --git a/pkg/tcpip/header/checksum.go b/pkg/tcpip/header/checksum.go index 9749c7f4d..14a4b2b44 100644 --- a/pkg/tcpip/header/checksum.go +++ b/pkg/tcpip/header/checksum.go @@ -45,12 +45,139 @@ func calculateChecksum(buf []byte, odd bool, initial uint32) (uint16, bool) { return ChecksumCombine(uint16(v), uint16(v>>16)), odd } +func unrolledCalculateChecksum(buf []byte, odd bool, initial uint32) (uint16, bool) { + v := initial + + if odd { + v += uint32(buf[0]) + buf = buf[1:] + } + + l := len(buf) + odd = l&1 != 0 + if odd { + l-- + v += uint32(buf[l]) << 8 + } + for (l - 64) >= 0 { + i := 0 + v += 
(uint32(buf[i]) << 8) + uint32(buf[i+1]) + v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) + v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5]) + v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7]) + v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9]) + v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11]) + v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13]) + v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15]) + i += 16 + v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) + v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) + v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5]) + v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7]) + v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9]) + v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11]) + v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13]) + v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15]) + i += 16 + v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) + v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) + v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5]) + v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7]) + v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9]) + v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11]) + v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13]) + v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15]) + i += 16 + v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) + v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) + v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5]) + v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7]) + v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9]) + v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11]) + v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13]) + v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15]) + buf = buf[64:] + l = l - 64 + } + if (l - 32) >= 0 { + i := 0 + v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) + v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) + v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5]) + v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7]) + v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9]) + v += (uint32(buf[i+10]) << 8) + 
uint32(buf[i+11]) + v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13]) + v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15]) + i += 16 + v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) + v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) + v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5]) + v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7]) + v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9]) + v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11]) + v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13]) + v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15]) + buf = buf[32:] + l = l - 32 + } + if (l - 16) >= 0 { + i := 0 + v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) + v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) + v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5]) + v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7]) + v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9]) + v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11]) + v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13]) + v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15]) + buf = buf[16:] + l = l - 16 + } + if (l - 8) >= 0 { + i := 0 + v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) + v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) + v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5]) + v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7]) + buf = buf[8:] + l = l - 8 + } + if (l - 4) >= 0 { + i := 0 + v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) + v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3]) + buf = buf[4:] + l = l - 4 + } + + // At this point since l was even before we started unrolling + // there can be only two bytes left to add. + if l != 0 { + v += (uint32(buf[0]) << 8) + uint32(buf[1]) + } + + return ChecksumCombine(uint16(v), uint16(v>>16)), odd +} + +// ChecksumOld calculates the checksum (as defined in RFC 1071) of the bytes in +// the given byte array. This function uses a non-optimized implementation. It's +// only retained for reference and to use as a benchmark/test. Most code should +// use the header.Checksum function. 
+// +// The initial checksum must have been computed on an even number of bytes. +func ChecksumOld(buf []byte, initial uint16) uint16 { + s, _ := calculateChecksum(buf, false, uint32(initial)) + return s +} + // Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the -// given byte array. +// given byte array. This function uses an optimized unrolled version of the +// checksum algorithm. // // The initial checksum must have been computed on an even number of bytes. func Checksum(buf []byte, initial uint16) uint16 { - s, _ := calculateChecksum(buf, false, uint32(initial)) + s, _ := unrolledCalculateChecksum(buf, false, uint32(initial)) return s } @@ -86,7 +213,7 @@ func ChecksumVVWithOffset(vv buffer.VectorisedView, initial uint16, off int, siz } v = v[:l] - sum, odd = calculateChecksum(v, odd, uint32(sum)) + sum, odd = unrolledCalculateChecksum(v, odd, uint32(sum)) size -= len(v) if size == 0 { diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go index 86b466c1c..309403482 100644 --- a/pkg/tcpip/header/checksum_test.go +++ b/pkg/tcpip/header/checksum_test.go @@ -17,6 +17,8 @@ package header_test import ( + "fmt" + "math/rand" "testing" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -107,3 +109,63 @@ func TestChecksumVVWithOffset(t *testing.T) { }) } } + +func TestChecksum(t *testing.T) { + var bufSizes = []int{0, 1, 2, 3, 4, 7, 8, 15, 16, 31, 32, 63, 64, 127, 128, 255, 256, 257, 1023, 1024} + type testCase struct { + buf []byte + initial uint16 + csumOrig uint16 + csumNew uint16 + } + testCases := make([]testCase, 100000) + // Ensure same buffer generation for test consistency. 
+ rnd := rand.New(rand.NewSource(42)) + for i := range testCases { + testCases[i].buf = make([]byte, bufSizes[i%len(bufSizes)]) + testCases[i].initial = uint16(rnd.Intn(65536)) + rnd.Read(testCases[i].buf) + } + + for i := range testCases { + testCases[i].csumOrig = header.ChecksumOld(testCases[i].buf, testCases[i].initial) + testCases[i].csumNew = header.Checksum(testCases[i].buf, testCases[i].initial) + if got, want := testCases[i].csumNew, testCases[i].csumOrig; got != want { + t.Fatalf("new checksum for (buf = %x, initial = %d) does not match old got: %d, want: %d", testCases[i].buf, testCases[i].initial, got, want) + } + } +} + +func BenchmarkChecksum(b *testing.B) { + var bufSizes = []int{64, 128, 256, 512, 1024, 1500, 2048, 4096, 8192, 16384, 32767, 32768, 65535, 65536} + + checkSumImpls := []struct { + fn func([]byte, uint16) uint16 + name string + }{ + {header.ChecksumOld, fmt.Sprintf("checksum_old")}, + {header.Checksum, fmt.Sprintf("checksum")}, + } + + for _, csumImpl := range checkSumImpls { + // Ensure same buffer generation for test consistency. + rnd := rand.New(rand.NewSource(42)) + for _, bufSz := range bufSizes { + b.Run(fmt.Sprintf("%s_%d", csumImpl.name, bufSz), func(b *testing.B) { + tc := struct { + buf []byte + initial uint16 + csum uint16 + }{ + buf: make([]byte, bufSz), + initial: uint16(rnd.Intn(65536)), + } + rnd.Read(tc.buf) + b.ResetTimer() + for i := 0; i < b.N; i++ { + tc.csum = csumImpl.fn(tc.buf, tc.initial) + } + }) + } + } +} diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go index b4037b6c8..c7ee2de57 100644 --- a/pkg/tcpip/header/icmpv6.go +++ b/pkg/tcpip/header/icmpv6.go @@ -52,7 +52,7 @@ const ( // ICMPv6NeighborAdvertSize is size of a neighbor advertisement // including the NDP Target Link Layer option for an Ethernet // address. 
- ICMPv6NeighborAdvertSize = ICMPv6HeaderSize + NDPNAMinimumSize + ndpTargetEthernetLinkLayerAddressSize + ICMPv6NeighborAdvertSize = ICMPv6HeaderSize + NDPNAMinimumSize + NDPLinkLayerAddressSize // ICMPv6EchoMinimumSize is the minimum size of a valid ICMP echo packet. ICMPv6EchoMinimumSize = 8 diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go index 06e0bace2..e6a6ad39b 100644 --- a/pkg/tcpip/header/ndp_options.go +++ b/pkg/tcpip/header/ndp_options.go @@ -17,6 +17,7 @@ package header import ( "encoding/binary" "errors" + "fmt" "math" "time" @@ -24,13 +25,17 @@ import ( ) const ( - // NDPTargetLinkLayerAddressOptionType is the type of the Target - // Link-Layer Address option, as per RFC 4861 section 4.6.1. + // NDPSourceLinkLayerAddressOptionType is the type of the Source Link Layer + // Address option, as per RFC 4861 section 4.6.1. + NDPSourceLinkLayerAddressOptionType = 1 + + // NDPTargetLinkLayerAddressOptionType is the type of the Target Link Layer + // Address option, as per RFC 4861 section 4.6.1. NDPTargetLinkLayerAddressOptionType = 2 - // ndpTargetEthernetLinkLayerAddressSize is the size of a Target - // Link Layer Option for an Ethernet address. - ndpTargetEthernetLinkLayerAddressSize = 8 + // NDPLinkLayerAddressSize is the size of a Source or Target Link Layer + // Address option for an Ethernet address. + NDPLinkLayerAddressSize = 8 // NDPPrefixInformationType is the type of the Prefix Information // option, as per RFC 4861 section 4.6.2. @@ -189,6 +194,9 @@ func (i *NDPOptionIterator) Next() (NDPOption, bool, error) { i.opts = i.opts[numBytes:] switch t { + case NDPSourceLinkLayerAddressOptionType: + return NDPSourceLinkLayerAddressOption(body), false, nil + case NDPTargetLinkLayerAddressOptionType: return NDPTargetLinkLayerAddressOption(body), false, nil @@ -293,6 +301,8 @@ func (b NDPOptions) Serialize(s NDPOptionsSerializer) int { // NDPOption is the set of functions to be implemented by all NDP option types. 
type NDPOption interface { + fmt.Stringer + // Type returns the type of the receiver. Type() uint8 @@ -368,6 +378,46 @@ func (b NDPOptionsSerializer) Length() int { return l } +// NDPSourceLinkLayerAddressOption is the NDP Source Link Layer Option +// as defined by RFC 4861 section 4.6.1. +// +// It is the first X bytes following the NDP option's Type and Length field +// where X is the value in Length multiplied by lengthByteUnits - 2 bytes. +type NDPSourceLinkLayerAddressOption tcpip.LinkAddress + +// Type implements NDPOption.Type. +func (o NDPSourceLinkLayerAddressOption) Type() uint8 { + return NDPSourceLinkLayerAddressOptionType +} + +// Length implements NDPOption.Length. +func (o NDPSourceLinkLayerAddressOption) Length() int { + return len(o) +} + +// serializeInto implements NDPOption.serializeInto. +func (o NDPSourceLinkLayerAddressOption) serializeInto(b []byte) int { + return copy(b, o) +} + +// String implements fmt.Stringer.String. +func (o NDPSourceLinkLayerAddressOption) String() string { + return fmt.Sprintf("%T(%s)", o, tcpip.LinkAddress(o)) +} + +// EthernetAddress will return an ethernet (MAC) address if the +// NDPSourceLinkLayerAddressOption's body has at minimum EthernetAddressSize +// bytes. If the body has more than EthernetAddressSize bytes, only the first +// EthernetAddressSize bytes are returned as that is all that is needed for an +// Ethernet address. +func (o NDPSourceLinkLayerAddressOption) EthernetAddress() tcpip.LinkAddress { + if len(o) >= EthernetAddressSize { + return tcpip.LinkAddress(o[:EthernetAddressSize]) + } + + return tcpip.LinkAddress([]byte(nil)) +} + // NDPTargetLinkLayerAddressOption is the NDP Target Link Layer Option // as defined by RFC 4861 section 4.6.1. // @@ -390,6 +440,11 @@ func (o NDPTargetLinkLayerAddressOption) serializeInto(b []byte) int { return copy(b, o) } +// String implements fmt.Stringer.String. 
+func (o NDPTargetLinkLayerAddressOption) String() string { + return fmt.Sprintf("%T(%s)", o, tcpip.LinkAddress(o)) +} + // EthernetAddress will return an ethernet (MAC) address if the // NDPTargetLinkLayerAddressOption's body has at minimum EthernetAddressSize // bytes. If the body has more than EthernetAddressSize bytes, only the first @@ -436,6 +491,17 @@ func (o NDPPrefixInformation) serializeInto(b []byte) int { return used } +// String implements fmt.Stringer.String. +func (o NDPPrefixInformation) String() string { + return fmt.Sprintf("%T(O=%t, A=%t, PL=%s, VL=%s, Prefix=%s)", + o, + o.OnLinkFlag(), + o.AutonomousAddressConfigurationFlag(), + o.PreferredLifetime(), + o.ValidLifetime(), + o.Subnet()) +} + // PrefixLength returns the value in the number of leading bits in the Prefix // that are valid. // @@ -545,6 +611,11 @@ func (o NDPRecursiveDNSServer) serializeInto(b []byte) int { return used } +// String implements fmt.Stringer.String. +func (o NDPRecursiveDNSServer) String() string { + return fmt.Sprintf("%T(%s valid for %s)", o, o.Addresses(), o.Lifetime()) +} + // Lifetime returns the length of time that the DNS server addresses // in this option may be used for name resolution. // diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go index 2c439d70c..1cb9f5dc8 100644 --- a/pkg/tcpip/header/ndp_test.go +++ b/pkg/tcpip/header/ndp_test.go @@ -153,6 +153,125 @@ func TestNDPRouterAdvert(t *testing.T) { } } +// TestNDPSourceLinkLayerAddressOptionEthernetAddress tests getting the +// Ethernet address from an NDPSourceLinkLayerAddressOption. 
+func TestNDPSourceLinkLayerAddressOptionEthernetAddress(t *testing.T) { + tests := []struct { + name string + buf []byte + expected tcpip.LinkAddress + }{ + { + "ValidMAC", + []byte{1, 2, 3, 4, 5, 6}, + tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), + }, + { + "SLLBodyTooShort", + []byte{1, 2, 3, 4, 5}, + tcpip.LinkAddress([]byte(nil)), + }, + { + "SLLBodyLargerThanNeeded", + []byte{1, 2, 3, 4, 5, 6, 7, 8}, + tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + sll := NDPSourceLinkLayerAddressOption(test.buf) + if got := sll.EthernetAddress(); got != test.expected { + t.Errorf("got sll.EthernetAddress = %s, want = %s", got, test.expected) + } + }) + } +} + +// TestNDPSourceLinkLayerAddressOptionSerialize tests serializing a +// NDPSourceLinkLayerAddressOption. +func TestNDPSourceLinkLayerAddressOptionSerialize(t *testing.T) { + tests := []struct { + name string + buf []byte + expectedBuf []byte + addr tcpip.LinkAddress + }{ + { + "Ethernet", + make([]byte, 8), + []byte{1, 1, 1, 2, 3, 4, 5, 6}, + "\x01\x02\x03\x04\x05\x06", + }, + { + "Padding", + []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + []byte{1, 2, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0}, + "\x01\x02\x03\x04\x05\x06\x07\x08", + }, + { + "Empty", + nil, + nil, + "", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + opts := NDPOptions(test.buf) + serializer := NDPOptionsSerializer{ + NDPSourceLinkLayerAddressOption(test.addr), + } + if got, want := int(serializer.Length()), len(test.expectedBuf); got != want { + t.Fatalf("got Length = %d, want = %d", got, want) + } + opts.Serialize(serializer) + if !bytes.Equal(test.buf, test.expectedBuf) { + t.Fatalf("got b = %d, want = %d", test.buf, test.expectedBuf) + } + + it, err := opts.Iter(true) + if err != nil { + t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) + } + + if len(test.expectedBuf) > 0 { + next, done, err := it.Next() + if err 
!= nil { + t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if done { + t.Fatal("got Next = (_, true, _), want = (_, false, _)") + } + if got := next.Type(); got != NDPSourceLinkLayerAddressOptionType { + t.Fatalf("got Type = %d, want = %d", got, NDPSourceLinkLayerAddressOptionType) + } + sll := next.(NDPSourceLinkLayerAddressOption) + if got, want := []byte(sll), test.expectedBuf[2:]; !bytes.Equal(got, want) { + t.Fatalf("got Next = (%x, _, _), want = (%x, _, _)", got, want) + } + + if got, want := sll.EthernetAddress(), tcpip.LinkAddress(test.expectedBuf[2:][:EthernetAddressSize]); got != want { + t.Errorf("got sll.EthernetAddress = %s, want = %s", got, want) + } + } + + // Iterator should not return anything else. + next, done, err := it.Next() + if err != nil { + t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if !done { + t.Error("got Next = (_, false, _), want = (_, true, _)") + } + if next != nil { + t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next) + } + }) + } +} + // TestNDPTargetLinkLayerAddressOptionEthernetAddress tests getting the // Ethernet address from an NDPTargetLinkLayerAddressOption. 
func TestNDPTargetLinkLayerAddressOptionEthernetAddress(t *testing.T) { @@ -186,7 +305,6 @@ func TestNDPTargetLinkLayerAddressOptionEthernetAddress(t *testing.T) { } }) } - } // TestNDPTargetLinkLayerAddressOptionSerialize tests serializing a @@ -212,8 +330,8 @@ func TestNDPTargetLinkLayerAddressOptionSerialize(t *testing.T) { }, { "Empty", - []byte{}, - []byte{}, + nil, + nil, "", }, } @@ -246,7 +364,7 @@ func TestNDPTargetLinkLayerAddressOptionSerialize(t *testing.T) { t.Fatal("got Next = (_, true, _), want = (_, false, _)") } if got := next.Type(); got != NDPTargetLinkLayerAddressOptionType { - t.Fatalf("got Type %= %d, want = %d", got, NDPTargetLinkLayerAddressOptionType) + t.Fatalf("got Type = %d, want = %d", got, NDPTargetLinkLayerAddressOptionType) } tll := next.(NDPTargetLinkLayerAddressOption) if got, want := []byte(tll), test.expectedBuf[2:]; !bytes.Equal(got, want) { @@ -254,7 +372,7 @@ func TestNDPTargetLinkLayerAddressOptionSerialize(t *testing.T) { } if got, want := tll.EthernetAddress(), tcpip.LinkAddress(test.expectedBuf[2:][:EthernetAddressSize]); got != want { - t.Errorf("got tll.MACAddress = %s, want = %s", got, want) + t.Errorf("got tll.EthernetAddress = %s, want = %s", got, want) } } @@ -510,7 +628,7 @@ func TestNDPRecursiveDNSServerOption(t *testing.T) { t.Fatal("got Next = (_, true, _), want = (_, false, _)") } if got := next.Type(); got != NDPRecursiveDNSServerOptionType { - t.Fatalf("got Type %= %d, want = %d", got, NDPRecursiveDNSServerOptionType) + t.Fatalf("got Type = %d, want = %d", got, NDPRecursiveDNSServerOptionType) } opt, ok := next.(NDPRecursiveDNSServer) @@ -553,6 +671,16 @@ func TestNDPOptionsIterCheck(t *testing.T) { ErrNDPOptZeroLength, }, { + "ValidSourceLinkLayerAddressOption", + []byte{1, 1, 1, 2, 3, 4, 5, 6}, + nil, + }, + { + "TooSmallSourceLinkLayerAddressOption", + []byte{1, 1, 1, 2, 3, 4, 5}, + ErrNDPOptBufExhausted, + }, + { "ValidTargetLinkLayerAddressOption", []byte{2, 1, 1, 2, 3, 4, 5, 6}, nil, @@ -603,10 +731,13 
@@ func TestNDPOptionsIterCheck(t *testing.T) { ErrNDPOptMalformedBody, }, { - "ValidTargetLinkLayerAddressWithPrefixInformation", + "ValidSourceAndTargetLinkLayerAddressWithPrefixInformation", []byte{ + // Source Link-Layer Address. + 1, 1, 1, 2, 3, 4, 5, 6, + // Target Link-Layer Address. - 2, 1, 1, 2, 3, 4, 5, 6, + 2, 1, 7, 8, 9, 10, 11, 12, // Prefix information. 3, 4, 43, 64, @@ -621,10 +752,13 @@ func TestNDPOptionsIterCheck(t *testing.T) { nil, }, { - "ValidTargetLinkLayerAddressWithPrefixInformationWithUnrecognized", + "ValidSourceAndTargetLinkLayerAddressWithPrefixInformationWithUnrecognized", []byte{ + // Source Link-Layer Address. + 1, 1, 1, 2, 3, 4, 5, 6, + // Target Link-Layer Address. - 2, 1, 1, 2, 3, 4, 5, 6, + 2, 1, 7, 8, 9, 10, 11, 12, // 255 is an unrecognized type. If 255 ends up // being the type for some recognized type, @@ -714,8 +848,11 @@ func TestNDPOptionsIterCheck(t *testing.T) { // here. func TestNDPOptionsIter(t *testing.T) { buf := []byte{ + // Source Link-Layer Address. + 1, 1, 1, 2, 3, 4, 5, 6, + // Target Link-Layer Address. - 2, 1, 1, 2, 3, 4, 5, 6, + 2, 1, 7, 8, 9, 10, 11, 12, // 255 is an unrecognized type. If 255 ends up being the type // for some recognized type, update 255 to some other @@ -740,7 +877,7 @@ func TestNDPOptionsIter(t *testing.T) { t.Fatalf("got Iter = (_, %s), want = (_, nil)", err) } - // Test the first (Taret Link-Layer) option. + // Test the first (Source Link-Layer) option. 
next, done, err := it.Next() if err != nil { t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) @@ -748,7 +885,22 @@ func TestNDPOptionsIter(t *testing.T) { if done { t.Fatal("got Next = (_, true, _), want = (_, false, _)") } - if got, want := []byte(next.(NDPTargetLinkLayerAddressOption)), buf[2:][:6]; !bytes.Equal(got, want) { + if got, want := []byte(next.(NDPSourceLinkLayerAddressOption)), buf[2:][:6]; !bytes.Equal(got, want) { + t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want) + } + if got := next.Type(); got != NDPSourceLinkLayerAddressOptionType { + t.Errorf("got Type = %d, want = %d", got, NDPSourceLinkLayerAddressOptionType) + } + + // Test the next (Target Link-Layer) option. + next, done, err = it.Next() + if err != nil { + t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err) + } + if done { + t.Fatal("got Next = (_, true, _), want = (_, false, _)") + } + if got, want := []byte(next.(NDPTargetLinkLayerAddressOption)), buf[10:][:6]; !bytes.Equal(got, want) { t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want) } if got := next.Type(); got != NDPTargetLinkLayerAddressOptionType { @@ -764,7 +916,7 @@ func TestNDPOptionsIter(t *testing.T) { if done { t.Fatal("got Next = (_, true, _), want = (_, false, _)") } - if got, want := next.(NDPPrefixInformation), buf[26:][:30]; !bytes.Equal(got, want) { + if got, want := next.(NDPPrefixInformation), buf[34:][:30]; !bytes.Equal(got, want) { t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want) } if got := next.Type(); got != NDPPrefixInformationType { diff --git a/pkg/tcpip/iptables/BUILD b/pkg/tcpip/iptables/BUILD index 2893c80cd..d1b73cfdf 100644 --- a/pkg/tcpip/iptables/BUILD +++ b/pkg/tcpip/iptables/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -9,10 +9,10 @@ go_library( "targets.go", "types.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/iptables", visibility 
= ["//visibility:public"], deps = [ "//pkg/log", "//pkg/tcpip", + "//pkg/tcpip/header", ], ) diff --git a/pkg/tcpip/iptables/iptables.go b/pkg/tcpip/iptables/iptables.go index 605a71679..4bfb3149e 100644 --- a/pkg/tcpip/iptables/iptables.go +++ b/pkg/tcpip/iptables/iptables.go @@ -1,4 +1,4 @@ -// Copyright 2019 The gVisor authors. +// Copyright 2019 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" ) // Table names. @@ -184,8 +185,16 @@ func (it *IPTables) checkTable(hook Hook, pkt tcpip.PacketBuffer, tablename stri panic(fmt.Sprintf("Traversed past the entire list of iptables rules in table %q.", tablename)) } +// Precondition: pkt.NetworkHeader is set. func (it *IPTables) checkRule(hook Hook, pkt tcpip.PacketBuffer, table Table, ruleIdx int) Verdict { rule := table.Rules[ruleIdx] + + // First check whether the packet matches the IP header filter. + // TODO(gvisor.dev/issue/170): Support other fields of the filter. + if rule.Filter.Protocol != 0 && rule.Filter.Protocol != header.IPv4(pkt.NetworkHeader).TransportProtocol() { + return Continue + } + // Go through each rule matcher. If they all match, run // the rule target. for _, matcher := range rule.Matchers { diff --git a/pkg/tcpip/iptables/types.go b/pkg/tcpip/iptables/types.go index 9f6906100..50893cc55 100644 --- a/pkg/tcpip/iptables/types.go +++ b/pkg/tcpip/iptables/types.go @@ -1,4 +1,4 @@ -// Copyright 2019 The gVisor authors. +// Copyright 2019 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,7 +14,9 @@ package iptables -import "gvisor.dev/gvisor/pkg/tcpip" +import ( + "gvisor.dev/gvisor/pkg/tcpip" +) // A Hook specifies one of the hooks built into the network stack. 
// @@ -151,6 +153,9 @@ func (table *Table) SetMetadata(metadata interface{}) { // packets this rule applies to. If there are no matchers in the rule, it // applies to any packet. type Rule struct { + // Filter holds basic IP filtering fields common to every rule. + Filter IPHeaderFilter + // Matchers is the list of matchers for this rule. Matchers []Matcher @@ -158,6 +163,12 @@ type Rule struct { Target Target } +// IPHeaderFilter holds basic IP filtering data common to every rule. +type IPHeaderFilter struct { + // Protocol matches the transport protocol. + Protocol tcpip.TransportProtocolNumber +} + // A Matcher is the interface for matching packets. type Matcher interface { // Match returns whether the packet matches and whether the packet diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD index 7dbc05754..3974c464e 100644 --- a/pkg/tcpip/link/channel/BUILD +++ b/pkg/tcpip/link/channel/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "channel", srcs = ["channel.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/channel", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index 70188551f..71b9da797 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -18,6 +18,8 @@ package channel import ( + "context" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -38,25 +40,52 @@ type Endpoint struct { linkAddr tcpip.LinkAddress GSO bool - // C is where outbound packets are queued. - C chan PacketInfo + // c is where outbound packets are queued. + c chan PacketInfo } // New creates a new channel endpoint. 
func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) *Endpoint { return &Endpoint{ - C: make(chan PacketInfo, size), + c: make(chan PacketInfo, size), mtu: mtu, linkAddr: linkAddr, } } +// Close closes e. Further packet injections will panic. Reads continue to +// succeed until all packets are read. +func (e *Endpoint) Close() { + close(e.c) +} + +// Read does non-blocking read for one packet from the outbound packet queue. +func (e *Endpoint) Read() (PacketInfo, bool) { + select { + case pkt := <-e.c: + return pkt, true + default: + return PacketInfo{}, false + } +} + +// ReadContext does blocking read for one packet from the outbound packet queue. +// It can be cancelled by ctx, and in this case, it returns false. +func (e *Endpoint) ReadContext(ctx context.Context) (PacketInfo, bool) { + select { + case pkt := <-e.c: + return pkt, true + case <-ctx.Done(): + return PacketInfo{}, false + } +} + // Drain removes all outbound packets from the channel and counts them. func (e *Endpoint) Drain() int { c := 0 for { select { - case <-e.C: + case <-e.c: c++ default: return c @@ -125,7 +154,7 @@ func (e *Endpoint) WritePacket(_ *stack.Route, gso *stack.GSO, protocol tcpip.Ne } select { - case e.C <- p: + case e.c <- p: default: } @@ -150,7 +179,7 @@ packetLoop: } select { - case e.C <- p: + case e.c <- p: n++ default: break packetLoop @@ -169,7 +198,7 @@ func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { } select { - case e.C <- p: + case e.c <- p: default: } diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD index 66cc53ed4..abe725548 100644 --- a/pkg/tcpip/link/fdbased/BUILD +++ b/pkg/tcpip/link/fdbased/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -13,7 +12,6 @@ go_library( "mmap_unsafe.go", "packet_dispatchers.go", ], - importpath = 
"gvisor.dev/gvisor/pkg/tcpip/link/fdbased", visibility = ["//visibility:public"], deps = [ "//pkg/sync", @@ -30,7 +28,7 @@ go_test( name = "fdbased_test", size = "small", srcs = ["endpoint_test.go"], - embed = [":fdbased"], + library = ":fdbased", deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/link/loopback/BUILD b/pkg/tcpip/link/loopback/BUILD index f35fcdff4..6bf3805b7 100644 --- a/pkg/tcpip/link/loopback/BUILD +++ b/pkg/tcpip/link/loopback/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "loopback", srcs = ["loopback.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/loopback", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD index 1ac7948b6..82b441b79 100644 --- a/pkg/tcpip/link/muxed/BUILD +++ b/pkg/tcpip/link/muxed/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "muxed", srcs = ["injectable.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/muxed", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", @@ -19,7 +17,7 @@ go_test( name = "muxed_test", size = "small", srcs = ["injectable_test.go"], - embed = [":muxed"], + library = ":muxed", deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD index d8211e93d..14b527bc2 100644 --- a/pkg/tcpip/link/rawfile/BUILD +++ b/pkg/tcpip/link/rawfile/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -12,7 +12,6 @@ go_library( "errors.go", "rawfile_unsafe.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/rawfile", visibility = 
["//visibility:public"], deps = [ "//pkg/tcpip", diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD index 09165dd4c..13243ebbb 100644 --- a/pkg/tcpip/link/sharedmem/BUILD +++ b/pkg/tcpip/link/sharedmem/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -11,7 +10,6 @@ go_library( "sharedmem_unsafe.go", "tx.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem", visibility = ["//visibility:public"], deps = [ "//pkg/log", @@ -30,7 +28,7 @@ go_test( srcs = [ "sharedmem_test.go", ], - embed = [":sharedmem"], + library = ":sharedmem", deps = [ "//pkg/sync", "//pkg/tcpip", diff --git a/pkg/tcpip/link/sharedmem/pipe/BUILD b/pkg/tcpip/link/sharedmem/pipe/BUILD index a0d4ad0be..87020ec08 100644 --- a/pkg/tcpip/link/sharedmem/pipe/BUILD +++ b/pkg/tcpip/link/sharedmem/pipe/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -11,7 +10,6 @@ go_library( "rx.go", "tx.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe", visibility = ["//visibility:public"], ) @@ -20,6 +18,6 @@ go_test( srcs = [ "pipe_test.go", ], - embed = [":pipe"], + library = ":pipe", deps = ["//pkg/sync"], ) diff --git a/pkg/tcpip/link/sharedmem/queue/BUILD b/pkg/tcpip/link/sharedmem/queue/BUILD index 8c9234d54..3ba06af73 100644 --- a/pkg/tcpip/link/sharedmem/queue/BUILD +++ b/pkg/tcpip/link/sharedmem/queue/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -9,7 +8,6 @@ go_library( "rx.go", "tx.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue", visibility = 
["//visibility:public"], deps = [ "//pkg/log", @@ -22,7 +20,7 @@ go_test( srcs = [ "queue_test.go", ], - embed = [":queue"], + library = ":queue", deps = [ "//pkg/tcpip/link/sharedmem/pipe", ], diff --git a/pkg/tcpip/link/sniffer/BUILD b/pkg/tcpip/link/sniffer/BUILD index d6ae0368a..230a8d53a 100644 --- a/pkg/tcpip/link/sniffer/BUILD +++ b/pkg/tcpip/link/sniffer/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -8,7 +8,6 @@ go_library( "pcap.go", "sniffer.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sniffer", visibility = ["//visibility:public"], deps = [ "//pkg/log", diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD index a71a493fc..e5096ea38 100644 --- a/pkg/tcpip/link/tun/BUILD +++ b/pkg/tcpip/link/tun/BUILD @@ -1,10 +1,9 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "tun", srcs = ["tun_unsafe.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/tun", visibility = ["//visibility:public"], ) diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD index 134837943..0956d2c65 100644 --- a/pkg/tcpip/link/waitable/BUILD +++ b/pkg/tcpip/link/waitable/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -8,7 +7,6 @@ go_library( srcs = [ "waitable.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/link/waitable", visibility = ["//visibility:public"], deps = [ "//pkg/gate", @@ -23,7 +21,7 @@ go_test( srcs = [ "waitable_test.go", ], - embed = [":waitable"], + library = ":waitable", deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD index 9d16ff8c9..6a4839fb8 100644 --- a/pkg/tcpip/network/BUILD +++ 
b/pkg/tcpip/network/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_test") package(licenses = ["notice"]) diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index e7617229b..eddf7b725 100644 --- a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "arp", srcs = ["arp.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/network/arp", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 8e6048a21..03cf03b6d 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -15,6 +15,7 @@ package arp_test import ( + "context" "strconv" "testing" "time" @@ -83,7 +84,7 @@ func newTestContext(t *testing.T) *testContext { } func (c *testContext) cleanup() { - close(c.linkEP.C) + c.linkEP.Close() } func TestDirectRequest(t *testing.T) { @@ -110,7 +111,7 @@ func TestDirectRequest(t *testing.T) { for i, address := range []tcpip.Address{stackAddr1, stackAddr2} { t.Run(strconv.Itoa(i), func(t *testing.T) { inject(address) - pi := <-c.linkEP.C + pi, _ := c.linkEP.ReadContext(context.Background()) if pi.Proto != arp.ProtocolNumber { t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto) } @@ -134,12 +135,11 @@ func TestDirectRequest(t *testing.T) { } inject(stackAddrBad) - select { - case pkt := <-c.linkEP.C: + // Sleep tests are gross, but this will only potentially flake + // if there's a bug. If there is no bug this will reliably + // succeed. 
+ ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + if pkt, ok := c.linkEP.ReadContext(ctx); ok { t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto) - case <-time.After(100 * time.Millisecond): - // Sleep tests are gross, but this will only potentially flake - // if there's a bug. If there is no bug this will reliably - // succeed. } } diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD index ed16076fd..d1c728ccf 100644 --- a/pkg/tcpip/network/fragmentation/BUILD +++ b/pkg/tcpip/network/fragmentation/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -24,7 +23,6 @@ go_library( "reassembler.go", "reassembler_list.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation", visibility = ["//visibility:public"], deps = [ "//pkg/log", @@ -42,6 +40,6 @@ go_test( "fragmentation_test.go", "reassembler_test.go", ], - embed = [":fragmentation"], + library = ":fragmentation", deps = ["//pkg/tcpip/buffer"], ) diff --git a/pkg/tcpip/network/hash/BUILD b/pkg/tcpip/network/hash/BUILD index e6db5c0b0..872165866 100644 --- a/pkg/tcpip/network/hash/BUILD +++ b/pkg/tcpip/network/hash/BUILD @@ -1,11 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "hash", srcs = ["hash.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/network/hash", visibility = ["//visibility:public"], deps = [ "//pkg/rand", diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 4e2aae9a3..0fef2b1f1 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") 
-load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -9,7 +8,6 @@ go_library( "icmp.go", "ipv4.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/network/ipv4", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 0a1453b31..85512f9b2 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -353,7 +353,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt tcpip.PacketBuffer) { } pkt.NetworkHeader = headerView[:h.HeaderLength()] - // iptables filtering. + // iptables filtering. All packets that reach here are intended for + // this machine and will not be forwarded. ipt := e.stack.IPTables() if ok := ipt.Check(iptables.Input, pkt); !ok { // iptables is telling us to drop the packet. diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index e4e273460..fb11874c6 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -1,5 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -9,7 +8,6 @@ go_library( "icmp.go", "ipv6.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/network/ipv6", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", @@ -27,7 +25,7 @@ go_test( "ipv6_test.go", "ndp_test.go", ], - embed = [":ipv6"], + library = ":ipv6", deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 1c3410618..dc20c0fd7 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -137,21 +137,24 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P } ns := header.NDPNeighborSolicit(h.NDPPayload()) + it, err := ns.Options().Iter(true) + if err != nil { + // If 
we have a malformed NDP NS option, drop the packet. + received.Invalid.Increment() + return + } + targetAddr := ns.TargetAddress() s := r.Stack() rxNICID := r.NICID() - - isTentative, err := s.IsAddrTentative(rxNICID, targetAddr) - if err != nil { + if isTentative, err := s.IsAddrTentative(rxNICID, targetAddr); err != nil { // We will only get an error if rxNICID is unrecognized, // which should not happen. For now short-circuit this // packet. // // TODO(b/141002840): Handle this better? return - } - - if isTentative { + } else if isTentative { // If the target address is tentative and the source // of the packet is a unicast (specified) address, then // the source of the packet is attempting to perform @@ -185,6 +188,23 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P return } + // If the NS message has the source link layer option, update the link + // address cache with the link address for the sender of the message. + // + // TODO(b/148429853): Properly process the NS message and do Neighbor + // Unreachability Detection. + for { + opt, done, _ := it.Next() + if done { + break + } + + switch opt := opt.(type) { + case header.NDPSourceLinkLayerAddressOption: + e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, opt.EthernetAddress()) + } + } + optsSerializer := header.NDPOptionsSerializer{ header.NDPTargetLinkLayerAddressOption(r.LocalLinkAddress[:]), } @@ -211,15 +231,6 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, pkt tcpip.P r.LocalAddress = targetAddr packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - // TODO(tamird/ghanan): there exists an explicit NDP option that is - // used to update the neighbor table with link addresses for a - // neighbor from an NS (see the Source Link Layer option RFC - // 4861 section 4.6.1 and section 7.2.3). - // - // Furthermore, the entirety of NDP handling here seems to be - // contradicted by RFC 4861. 
- e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, r.RemoteLinkAddress) - // RFC 4861 Neighbor Discovery for IP version 6 (IPv6) // // 7.1.2. Validation of Neighbor Advertisements diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index a2fdc5dcd..7a6820643 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -15,6 +15,7 @@ package ipv6 import ( + "context" "reflect" "strings" "testing" @@ -264,8 +265,8 @@ func newTestContext(t *testing.T) *testContext { } func (c *testContext) cleanup() { - close(c.linkEP0.C) - close(c.linkEP1.C) + c.linkEP0.Close() + c.linkEP1.Close() } type routeArgs struct { @@ -276,7 +277,7 @@ type routeArgs struct { func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.ICMPv6)) { t.Helper() - pi := <-args.src.C + pi, _ := args.src.ReadContext(context.Background()) { views := []buffer.View{pi.Pkt.Header.View(), pi.Pkt.Data.ToView()} diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index fe895b376..bd732f93f 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -70,6 +70,141 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack return s, ep } +// TestNeighorSolicitationWithSourceLinkLayerOption tests that receiving an +// NDP NS message with the Source Link Layer Address option results in a +// new entry in the link address cache for the sender of the message. 
+func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + e := channel.New(0, 1280, linkAddr0) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) + } + + ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize + hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) + pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) + pkt.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + ns.SetTargetAddress(lladdr0) + ns.Options().Serialize(header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(linkAddr1), + }) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: 255, + SrcAddr: lladdr1, + DstAddr: lladdr0, + }) + e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + + linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil) + if err != nil { + t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err) + } + if c != nil { + t.Errorf("got unexpected channel") + } + if linkAddr != linkAddr1 { + t.Errorf("got link address = %s, want = %s", linkAddr, linkAddr1) + } +} + +// TestNeighorSolicitationWithInvalidSourceLinkLayerOption tests that receiving +// an NDP NS message with an invalid Source Link Layer Address option does not +// result in a new entry in the link address cache for the sender of the 
+// message. +func TestNeighorSolicitationWithInvalidSourceLinkLayerOption(t *testing.T) { + const nicID = 1 + + tests := []struct { + name string + optsBuf []byte + }{ + { + name: "Too Small", + optsBuf: []byte{1, 1, 1, 2, 3, 4, 5}, + }, + { + name: "Invalid Length", + optsBuf: []byte{1, 2, 1, 2, 3, 4, 5, 6}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + e := channel.New(0, 1280, linkAddr0) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) + } + + ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + len(test.optsBuf) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) + pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) + pkt.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + ns.SetTargetAddress(lladdr0) + opts := ns.Options() + copy(opts, test.optsBuf) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: 255, + SrcAddr: lladdr1, + DstAddr: lladdr0, + }) + + invalid := s.Stats().ICMP.V6PacketsReceived.Invalid + + // Invalid count should initially be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + + e.InjectInbound(ProtocolNumber, tcpip.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + + // Invalid count should have increased. 
+ if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + + linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil) + if err != tcpip.ErrWouldBlock { + t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock) + } + if c == nil { + t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber) + } + if linkAddr != "" { + t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (%s, _, ), want = ('', _, _)", nicID, lladdr1, lladdr0, ProtocolNumber, linkAddr) + } + }) + } +} + // TestHopLimitValidation is a test that makes sure that NDP packets are only // received if their IP header's hop limit is set to 255. func TestHopLimitValidation(t *testing.T) { diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD index a6ef3bdcc..2bad05a2e 100644 --- a/pkg/tcpip/ports/BUILD +++ b/pkg/tcpip/ports/BUILD @@ -1,12 +1,10 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "ports", srcs = ["ports.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/ports", visibility = ["//visibility:public"], deps = [ "//pkg/sync", @@ -17,7 +15,7 @@ go_library( go_test( name = "ports_test", srcs = ["ports_test.go"], - embed = [":ports"], + library = ":ports", deps = [ "//pkg/tcpip", ], diff --git a/pkg/tcpip/sample/tun_tcp_connect/BUILD b/pkg/tcpip/sample/tun_tcp_connect/BUILD index d7496fde6..cf0a5fefe 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/BUILD +++ b/pkg/tcpip/sample/tun_tcp_connect/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("//tools:defs.bzl", "go_binary") package(licenses = ["notice"]) diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go 
index 2239c1e66..0ab089208 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/main.go +++ b/pkg/tcpip/sample/tun_tcp_connect/main.go @@ -164,7 +164,7 @@ func main() { // Create TCP endpoint. var wq waiter.Queue ep, e := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) - if err != nil { + if e != nil { log.Fatal(e) } diff --git a/pkg/tcpip/sample/tun_tcp_echo/BUILD b/pkg/tcpip/sample/tun_tcp_echo/BUILD index 875561566..43264b76d 100644 --- a/pkg/tcpip/sample/tun_tcp_echo/BUILD +++ b/pkg/tcpip/sample/tun_tcp_echo/BUILD @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_binary") +load("//tools:defs.bzl", "go_binary") package(licenses = ["notice"]) diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go index bca73cbb1..9e37cab18 100644 --- a/pkg/tcpip/sample/tun_tcp_echo/main.go +++ b/pkg/tcpip/sample/tun_tcp_echo/main.go @@ -168,7 +168,7 @@ func main() { // Create TCP endpoint, bind it, then start listening. var wq waiter.Queue ep, e := s.NewEndpoint(tcp.ProtocolNumber, proto, &wq) - if err != nil { + if e != nil { log.Fatal(e) } diff --git a/pkg/tcpip/seqnum/BUILD b/pkg/tcpip/seqnum/BUILD index b31ddba2f..45f503845 100644 --- a/pkg/tcpip/seqnum/BUILD +++ b/pkg/tcpip/seqnum/BUILD @@ -1,10 +1,9 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "seqnum", srcs = ["seqnum.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/seqnum", visibility = ["//visibility:public"], ) diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 783351a69..f5b750046 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -30,7 +29,6 @@ go_library( "stack_global_state.go", 
"transport_demuxer.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/stack", visibility = ["//visibility:public"], deps = [ "//pkg/ilist", @@ -81,7 +79,7 @@ go_test( name = "stack_test", size = "small", srcs = ["linkaddrcache_test.go"], - embed = [":stack"], + library = ":stack", deps = [ "//pkg/sleep", "//pkg/sync", diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 7d4b41dfa..31294345d 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -15,7 +15,6 @@ package stack import ( - "fmt" "log" "math/rand" "time" @@ -168,8 +167,8 @@ type NDPDispatcher interface { // reason, such as the address being removed). If an error occured // during DAD, err will be set and resolved must be ignored. // - // This function is permitted to block indefinitely without interfering - // with the stack's operation. + // This function is not permitted to block indefinitely. This function + // is also not permitted to call into the stack. OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) // OnDefaultRouterDiscovered will be called when a new default router is @@ -429,8 +428,13 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref return tcpip.ErrAddressFamilyNotSupported } - // Should not attempt to perform DAD on an address that is currently in - // the DAD process. + if ref.getKind() != permanentTentative { + // The endpoint should be marked as tentative since we are starting DAD. + log.Fatalf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.nic.ID()) + } + + // Should not attempt to perform DAD on an address that is currently in the + // DAD process. if _, ok := ndp.dad[addr]; ok { // Should never happen because we should only ever call this function for // newly created addresses. 
If we attemped to "add" an address that already @@ -438,77 +442,79 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref // address, or its reference count would have been increased without doing // the work that would have been done for an address that was brand new. // See NIC.addAddressLocked. - panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.nic.ID())) + log.Fatalf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.nic.ID()) } remaining := ndp.configs.DupAddrDetectTransmits - - { - done, err := ndp.doDuplicateAddressDetection(addr, remaining, ref) - if err != nil { - return err - } - if done { - return nil - } + if remaining == 0 { + ref.setKind(permanent) + return nil } - remaining-- - var done bool var timer *time.Timer - timer = time.AfterFunc(ndp.configs.RetransmitTimer, func() { - var d bool - var err *tcpip.Error - - // doDadIteration does a single iteration of the DAD loop. - // - // Returns true if the integrator needs to be informed of DAD - // completing. - doDadIteration := func() bool { - ndp.nic.mu.Lock() - defer ndp.nic.mu.Unlock() - - if done { - // If we reach this point, it means that the DAD - // timer fired after another goroutine already - // obtained the NIC lock and stopped DAD before - // this function obtained the NIC lock. Simply - // return here and do nothing further. - return false - } + // We initially start a timer to fire immediately because some of the DAD work + // cannot be done while holding the NIC's lock. This is effectively the same + // as starting a goroutine but we use a timer that fires immediately so we can + // reset it for the next DAD iteration. + timer = time.AfterFunc(0, func() { + ndp.nic.mu.RLock() + if done { + // If we reach this point, it means that the DAD timer fired after + // another goroutine already obtained the NIC lock and stopped DAD + // before this function obtained the NIC lock. 
Simply return here and do + // nothing further. + ndp.nic.mu.RUnlock() + return + } - ref, ok := ndp.nic.endpoints[NetworkEndpointID{addr}] - if !ok { - // This should never happen. - // We should have an endpoint for addr since we - // are still performing DAD on it. If the - // endpoint does not exist, but we are doing DAD - // on it, then we started DAD at some point, but - // forgot to stop it when the endpoint was - // deleted. - panic(fmt.Sprintf("ndpdad: unrecognized addr %s for NIC(%d)", addr, ndp.nic.ID())) - } + if ref.getKind() != permanentTentative { + // The endpoint should still be marked as tentative since we are still + // performing DAD on it. + log.Fatalf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.nic.ID()) + } - d, err = ndp.doDuplicateAddressDetection(addr, remaining, ref) - if err != nil || d { - delete(ndp.dad, addr) + dadDone := remaining == 0 + ndp.nic.mu.RUnlock() - if err != nil { - log.Printf("ndpdad: Error occured during DAD iteration for addr (%s) on NIC(%d); err = %s", addr, ndp.nic.ID(), err) - } + var err *tcpip.Error + if !dadDone { + err = ndp.sendDADPacket(addr) + } - // Let the integrator know DAD has completed. - return true - } + ndp.nic.mu.Lock() + if done { + // If we reach this point, it means that DAD was stopped after we released + // the NIC's read lock and before we obtained the write lock. + ndp.nic.mu.Unlock() + return + } + if dadDone { + // DAD has resolved. + ref.setKind(permanent) + } else if err == nil { + // DAD is not done and we had no errors when sending the last NDP NS, + // schedule the next DAD timer. remaining-- timer.Reset(ndp.nic.stack.ndpConfigs.RetransmitTimer) - return false + + ndp.nic.mu.Unlock() + return } - if doDadIteration() && ndp.nic.stack.ndpDisp != nil { - ndp.nic.stack.ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, d, err) + // At this point we know that either DAD is done or we hit an error sending + // the last NDP NS. 
Either way, clean up addr's DAD state and let the + // integrator know DAD has completed. + delete(ndp.dad, addr) + ndp.nic.mu.Unlock() + + if err != nil { + log.Printf("ndpdad: error occured during DAD iteration for addr (%s) on NIC(%d); err = %s", addr, ndp.nic.ID(), err) + } + + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, dadDone, err) } }) @@ -520,45 +526,16 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref return nil } -// doDuplicateAddressDetection is called on every iteration of the timer, and -// when DAD starts. -// -// It handles resolving the address (if there are no more NS to send), or -// sending the next NS if there are more NS to send. -// -// This function must only be called by IPv6 addresses that are currently -// tentative. -// -// The NIC that ndp belongs to (n) MUST be locked. +// sendDADPacket sends a NS message to see if any nodes on ndp's NIC's link owns +// addr. // -// Returns true if DAD has resolved; false if DAD is still ongoing. -func (ndp *ndpState) doDuplicateAddressDetection(addr tcpip.Address, remaining uint8, ref *referencedNetworkEndpoint) (bool, *tcpip.Error) { - if ref.getKind() != permanentTentative { - // The endpoint should still be marked as tentative - // since we are still performing DAD on it. - panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.nic.ID())) - } - - if remaining == 0 { - // DAD has resolved. - ref.setKind(permanent) - return true, nil - } - - // Send a new NS. +// addr must be a tentative IPv6 address on ndp's NIC. +func (ndp *ndpState) sendDADPacket(addr tcpip.Address) *tcpip.Error { snmc := header.SolicitedNodeAddr(addr) - snmcRef, ok := ndp.nic.endpoints[NetworkEndpointID{snmc}] - if !ok { - // This should never happen as if we have the - // address, we should have the solicited-node - // address. 
- panic(fmt.Sprintf("ndpdad: NIC(%d) is not in the solicited-node multicast group (%s) but it has addr %s", ndp.nic.ID(), snmc, addr)) - } - snmcRef.incRef() - // Use the unspecified address as the source address when performing - // DAD. - r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, ndp.nic.linkEP.LinkAddress(), snmcRef, false, false) + // Use the unspecified address as the source address when performing DAD. + ref := ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, forceSpoofing) + r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false) defer r.Release() hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborSolicitMinimumSize) @@ -569,15 +546,19 @@ func (ndp *ndpState) doDuplicateAddressDetection(addr tcpip.Address, remaining u pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) sent := r.Stats().ICMP.V6PacketsSent - if err := r.WritePacket(nil, NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS}, tcpip.PacketBuffer{ - Header: hdr, - }); err != nil { + if err := r.WritePacket(nil, + NetworkHeaderParams{ + Protocol: header.ICMPv6ProtocolNumber, + TTL: header.NDPHopLimit, + TOS: DefaultTOS, + }, tcpip.PacketBuffer{Header: hdr}, + ); err != nil { sent.Dropped.Increment() - return false, err + return err } sent.NeighborSolicit.Increment() - return false, nil + return nil } // stopDuplicateAddressDetection ends a running Duplicate Address Detection @@ -608,8 +589,8 @@ func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) { delete(ndp.dad, addr) // Let the integrator know DAD did not resolve. 
- if ndp.nic.stack.ndpDisp != nil { - go ndp.nic.stack.ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, false, nil) + if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil { + ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, false, nil) } } @@ -1212,7 +1193,7 @@ func (ndp *ndpState) startSolicitingRouters() { ndp.rtrSolicitTimer = time.AfterFunc(delay, func() { // Send an RS message with the unspecified source address. - ref := ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, true) + ref := ndp.nic.getRefOrCreateTemp(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint, forceSpoofing) r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false) defer r.Release() diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 1a52e0e68..bc7cfbcb4 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -15,6 +15,7 @@ package stack_test import ( + "context" "encoding/binary" "fmt" "testing" @@ -35,13 +36,14 @@ import ( ) const ( - addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") - addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") - linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") - linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07") - linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08") - defaultTimeout = 100 * time.Millisecond + addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") + linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") + linkAddr2 = 
tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07") + linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08") + defaultTimeout = 100 * time.Millisecond + defaultAsyncEventTimeout = time.Second ) var ( @@ -301,6 +303,8 @@ func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration s // Included in the subtests is a test to make sure that an invalid // RetransmitTimer (<1ms) values get fixed to the default RetransmitTimer of 1s. func TestDADResolve(t *testing.T) { + const nicID = 1 + tests := []struct { name string dupAddrDetectTransmits uint8 @@ -331,44 +335,36 @@ func TestDADResolve(t *testing.T) { opts.NDPConfigs.RetransmitTimer = test.retransTimer opts.NDPConfigs.DupAddrDetectTransmits = test.dupAddrDetectTransmits - e := channel.New(10, 1280, linkAddr1) + e := channel.New(int(test.dupAddrDetectTransmits), 1280, linkAddr1) s := stack.New(opts) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err) - } - - stat := s.Stats().ICMP.V6PacketsSent.NeighborSolicit - - // Should have sent an NDP NS immediately. - if got := stat.Value(); got != 1 { - t.Fatalf("got NeighborSolicit = %d, want = 1", got) - + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err) } // Address should not be considered bound to the NIC yet // (DAD ongoing). 
- addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) if err != nil { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) } if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) } // Wait for the remaining time - some delta (500ms), to // make sure the address is still not resolved. const delta = 500 * time.Millisecond time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - delta) - addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) if err != nil { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) } if want := (tcpip.AddressWithPrefix{}); addr != want { - t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want) } // Wait for DAD to resolve. 
@@ -385,8 +381,8 @@ func TestDADResolve(t *testing.T) { if e.err != nil { t.Fatal("got DAD error: ", e.err) } - if e.nicID != 1 { - t.Fatalf("got DAD event w/ nicID = %d, want = 1", e.nicID) + if e.nicID != nicID { + t.Fatalf("got DAD event w/ nicID = %d, want = %d", e.nicID, nicID) } if e.addr != addr1 { t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1) @@ -395,37 +391,44 @@ func TestDADResolve(t *testing.T) { t.Fatal("got DAD event w/ resolved = false, want = true") } } - addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber) + addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber) if err != nil { - t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err) } if addr.Address != addr1 { - t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, addr1) + t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, addr, addr1) } // Should not have sent any more NS messages. - if got := stat.Value(); got != uint64(test.dupAddrDetectTransmits) { + if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != uint64(test.dupAddrDetectTransmits) { t.Fatalf("got NeighborSolicit = %d, want = %d", got, test.dupAddrDetectTransmits) } // Validate the sent Neighbor Solicitation messages. for i := uint8(0); i < test.dupAddrDetectTransmits; i++ { - p := <-e.C + p, _ := e.ReadContext(context.Background()) // Make sure its an IPv6 packet. if p.Proto != header.IPv6ProtocolNumber { t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) } - // Check NDP packet. + // Check NDP NS packet. + // + // As per RFC 4861 section 4.3, a possible option is the Source Link + // Layer option, but this option MUST NOT be included when the source + // address of the packet is the unspecified address. 
checker.IPv6(t, p.Pkt.Header.View().ToVectorisedView().First(), + checker.SrcAddr(header.IPv6Any), + checker.DstAddr(header.SolicitedNodeAddr(addr1)), checker.TTL(header.NDPHopLimit), checker.NDPNS( - checker.NDPNSTargetAddress(addr1))) + checker.NDPNSTargetAddress(addr1), + checker.NDPNSOptions(nil), + )) } }) } - } // TestDADFail tests to make sure that the DAD process fails if another node is @@ -498,7 +501,7 @@ func TestDADFail(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent), + dadC: make(chan ndpDADEvent, 1), } ndpConfigs := stack.DefaultNDPConfigurations() opts := stack.Options{ @@ -577,7 +580,7 @@ func TestDADFail(t *testing.T) { // removed. func TestDADStop(t *testing.T) { ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent), + dadC: make(chan ndpDADEvent, 1), } ndpConfigs := stack.NDPConfigurations{ RetransmitTimer: time.Second, @@ -1093,7 +1096,7 @@ func TestRouterDiscovery(t *testing.T) { // Wait for the normal lifetime plus an extra bit for the // router to get invalidated. If we don't get an invalidation // event after this time, then something is wrong. - expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultTimeout) + expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncEventTimeout) // Rx an RA from lladdr2 with huge lifetime. e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) @@ -1110,7 +1113,7 @@ func TestRouterDiscovery(t *testing.T) { // Wait for the normal lifetime plus an extra bit for the // router to get invalidated. If we don't get an invalidation // event after this time, then something is wrong. 
- expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultTimeout) + expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncEventTimeout) } // TestRouterDiscoveryMaxRouters tests that only @@ -1349,7 +1352,7 @@ func TestPrefixDiscovery(t *testing.T) { if diff := checkPrefixEvent(e, subnet2, false); diff != "" { t.Errorf("prefix event mismatch (-want +got):\n%s", diff) } - case <-time.After(time.Duration(lifetime)*time.Second + defaultTimeout): + case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncEventTimeout): t.Fatal("timed out waiting for prefix discovery event") } @@ -1688,7 +1691,7 @@ func TestAutoGenAddr(t *testing.T) { if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(newMinVLDuration + defaultTimeout): + case <-time.After(newMinVLDuration + defaultAsyncEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } if contains(s.NICInfo()[1].ProtocolAddresses, addr1) { @@ -1994,7 +1997,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { expectPrimaryAddr(addr1) // Wait for addr of prefix1 to be deprecated. - expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultTimeout) + expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncEventTimeout) if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should not have %s in the list of addresses", addr1) } @@ -2034,7 +2037,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { expectPrimaryAddr(addr1) // Wait for addr of prefix1 to be deprecated. 
- expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultTimeout) + expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncEventTimeout) if !contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should not have %s in the list of addresses", addr1) } @@ -2048,7 +2051,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { } // Wait for addr of prefix1 to be invalidated. - expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultTimeout) + expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncEventTimeout) if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should not have %s in the list of addresses", addr1) } @@ -2080,7 +2083,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultTimeout): + case <-time.After(defaultAsyncEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" { @@ -2095,7 +2098,7 @@ func TestAutoGenAddrTimerDeprecation(t *testing.T) { t.Fatalf("got unexpected auto-generated event") } - case <-time.After(newMinVLDuration + defaultTimeout): + case <-time.After(newMinVLDuration + defaultAsyncEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { @@ -2220,7 +2223,7 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(minVLSeconds*time.Second + defaultTimeout): + case <-time.After(minVLSeconds*time.Second + defaultAsyncEventTimeout): t.Fatal("timeout waiting for addr auto gen event") } }) @@ -2708,7 +2711,7 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { if diff := 
checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultTimeout): + case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultAsyncEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } if contains(s.NICInfo()[nicID].ProtocolAddresses, addr1) { @@ -3291,29 +3294,29 @@ func TestRouterSolicitation(t *testing.T) { e := channel.New(int(test.maxRtrSolicit), 1280, linkAddr1) waitForPkt := func(timeout time.Duration) { t.Helper() - select { - case p := <-e.C: - if p.Proto != header.IPv6ProtocolNumber { - t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) - } - checker.IPv6(t, - p.Pkt.Header.View(), - checker.SrcAddr(header.IPv6Any), - checker.DstAddr(header.IPv6AllRoutersMulticastAddress), - checker.TTL(header.NDPHopLimit), - checker.NDPRS(), - ) - - case <-time.After(timeout): + ctx, _ := context.WithTimeout(context.Background(), timeout) + p, ok := e.ReadContext(ctx) + if !ok { t.Fatal("timed out waiting for packet") + return } + + if p.Proto != header.IPv6ProtocolNumber { + t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) + } + checker.IPv6(t, + p.Pkt.Header.View(), + checker.SrcAddr(header.IPv6Any), + checker.DstAddr(header.IPv6AllRoutersMulticastAddress), + checker.TTL(header.NDPHopLimit), + checker.NDPRS(), + ) } waitForNothing := func(timeout time.Duration) { t.Helper() - select { - case <-e.C: + ctx, _ := context.WithTimeout(context.Background(), timeout) + if _, ok := e.ReadContext(ctx); ok { t.Fatal("unexpectedly got a packet") - case <-time.After(timeout): } } s := stack.New(stack.Options{ @@ -3332,12 +3335,12 @@ func TestRouterSolicitation(t *testing.T) { // times. 
remaining := test.maxRtrSolicit if remaining > 0 { - waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultTimeout) + waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultAsyncEventTimeout) remaining-- } for ; remaining > 0; remaining-- { waitForNothing(test.effectiveRtrSolicitInt - defaultTimeout) - waitForPkt(2 * defaultTimeout) + waitForPkt(defaultAsyncEventTimeout) } // Make sure no more RS. @@ -3368,20 +3371,21 @@ func TestStopStartSolicitingRouters(t *testing.T) { e := channel.New(maxRtrSolicitations, 1280, linkAddr1) waitForPkt := func(timeout time.Duration) { t.Helper() - select { - case p := <-e.C: - if p.Proto != header.IPv6ProtocolNumber { - t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) - } - checker.IPv6(t, p.Pkt.Header.View(), - checker.SrcAddr(header.IPv6Any), - checker.DstAddr(header.IPv6AllRoutersMulticastAddress), - checker.TTL(header.NDPHopLimit), - checker.NDPRS()) - - case <-time.After(timeout): + ctx, _ := context.WithTimeout(context.Background(), timeout) + p, ok := e.ReadContext(ctx) + if !ok { t.Fatal("timed out waiting for packet") + return } + + if p.Proto != header.IPv6ProtocolNumber { + t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) + } + checker.IPv6(t, p.Pkt.Header.View(), + checker.SrcAddr(header.IPv6Any), + checker.DstAddr(header.IPv6AllRoutersMulticastAddress), + checker.TTL(header.NDPHopLimit), + checker.NDPRS()) } s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, @@ -3397,41 +3401,36 @@ func TestStopStartSolicitingRouters(t *testing.T) { // Enable forwarding which should stop router solicitations. s.SetForwarding(true) - select { - case <-e.C: + ctx, _ := context.WithTimeout(context.Background(), delay+defaultTimeout) + if _, ok := e.ReadContext(ctx); ok { // A single RS may have been sent before forwarding was enabled. 
- select { - case <-e.C: + ctx, _ = context.WithTimeout(context.Background(), interval+defaultTimeout) + if _, ok = e.ReadContext(ctx); ok { t.Fatal("Should not have sent more than one RS message") - case <-time.After(interval + defaultTimeout): } - case <-time.After(delay + defaultTimeout): } // Enabling forwarding again should do nothing. s.SetForwarding(true) - select { - case <-e.C: + ctx, _ = context.WithTimeout(context.Background(), delay+defaultTimeout) + if _, ok := e.ReadContext(ctx); ok { t.Fatal("unexpectedly got a packet after becoming a router") - case <-time.After(delay + defaultTimeout): } // Disable forwarding which should start router solicitations. s.SetForwarding(false) - waitForPkt(delay + defaultTimeout) - waitForPkt(interval + defaultTimeout) - waitForPkt(interval + defaultTimeout) - select { - case <-e.C: + waitForPkt(delay + defaultAsyncEventTimeout) + waitForPkt(interval + defaultAsyncEventTimeout) + waitForPkt(interval + defaultAsyncEventTimeout) + ctx, _ = context.WithTimeout(context.Background(), interval+defaultTimeout) + if _, ok := e.ReadContext(ctx); ok { t.Fatal("unexpectedly got an extra packet after sending out the expected RSs") - case <-time.After(interval + defaultTimeout): } // Disabling forwarding again should do nothing. 
s.SetForwarding(false) - select { - case <-e.C: + ctx, _ = context.WithTimeout(context.Background(), delay+defaultTimeout) + if _, ok := e.ReadContext(ctx); ok { t.Fatal("unexpectedly got a packet after becoming a router") - case <-time.After(delay + defaultTimeout): } } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index de88c0bfa..7dad9a8cb 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -35,24 +35,21 @@ type NIC struct { linkEP LinkEndpoint context NICContext - mu sync.RWMutex - spoofing bool - promiscuous bool - primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint - endpoints map[NetworkEndpointID]*referencedNetworkEndpoint - addressRanges []tcpip.Subnet - mcastJoins map[NetworkEndpointID]int32 - // packetEPs is protected by mu, but the contained PacketEndpoint - // values are not. - packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint - stats NICStats - // ndp is the NDP related state for NIC. - // - // Note, read and write operations on ndp require that the NIC is - // appropriately locked. - ndp ndpState + mu struct { + sync.RWMutex + spoofing bool + promiscuous bool + primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint + endpoints map[NetworkEndpointID]*referencedNetworkEndpoint + addressRanges []tcpip.Subnet + mcastJoins map[NetworkEndpointID]int32 + // packetEPs is protected by mu, but the contained PacketEndpoint + // values are not. + packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint + ndp ndpState + } } // NICStats includes transmitted and received stats. @@ -97,15 +94,11 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC // of IPv6 is supported on this endpoint's LinkEndpoint. 
nic := &NIC{ - stack: stack, - id: id, - name: name, - linkEP: ep, - context: ctx, - primary: make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint), - endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint), - mcastJoins: make(map[NetworkEndpointID]int32), - packetEPs: make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint), + stack: stack, + id: id, + name: name, + linkEP: ep, + context: ctx, stats: NICStats{ Tx: DirectionStats{ Packets: &tcpip.StatCounter{}, @@ -116,22 +109,26 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC Bytes: &tcpip.StatCounter{}, }, }, - ndp: ndpState{ - configs: stack.ndpConfigs, - dad: make(map[tcpip.Address]dadState), - defaultRouters: make(map[tcpip.Address]defaultRouterState), - onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState), - autoGenAddresses: make(map[tcpip.Address]autoGenAddressState), - }, } - nic.ndp.nic = nic + nic.mu.primary = make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint) + nic.mu.endpoints = make(map[NetworkEndpointID]*referencedNetworkEndpoint) + nic.mu.mcastJoins = make(map[NetworkEndpointID]int32) + nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint) + nic.mu.ndp = ndpState{ + nic: nic, + configs: stack.ndpConfigs, + dad: make(map[tcpip.Address]dadState), + defaultRouters: make(map[tcpip.Address]defaultRouterState), + onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState), + autoGenAddresses: make(map[tcpip.Address]autoGenAddressState), + } // Register supported packet endpoint protocols. for _, netProto := range header.Ethertypes { - nic.packetEPs[netProto] = []PacketEndpoint{} + nic.mu.packetEPs[netProto] = []PacketEndpoint{} } for _, netProto := range stack.networkProtocols { - nic.packetEPs[netProto.Number()] = []PacketEndpoint{} + nic.mu.packetEPs[netProto.Number()] = []PacketEndpoint{} } return nic @@ -215,7 +212,7 @@ func (n *NIC) enable() *tcpip.Error { // and default routers). 
Therefore, soliciting RAs from other routers on // a link is unnecessary for routers. if !n.stack.forwarding { - n.ndp.startSolicitingRouters() + n.mu.ndp.startSolicitingRouters() } return nil @@ -230,8 +227,8 @@ func (n *NIC) becomeIPv6Router() { n.mu.Lock() defer n.mu.Unlock() - n.ndp.cleanupHostOnlyState() - n.ndp.stopSolicitingRouters() + n.mu.ndp.cleanupHostOnlyState() + n.mu.ndp.stopSolicitingRouters() } // becomeIPv6Host transitions n into an IPv6 host. @@ -242,7 +239,7 @@ func (n *NIC) becomeIPv6Host() { n.mu.Lock() defer n.mu.Unlock() - n.ndp.startSolicitingRouters() + n.mu.ndp.startSolicitingRouters() } // attachLinkEndpoint attaches the NIC to the endpoint, which will enable it @@ -254,13 +251,13 @@ func (n *NIC) attachLinkEndpoint() { // setPromiscuousMode enables or disables promiscuous mode. func (n *NIC) setPromiscuousMode(enable bool) { n.mu.Lock() - n.promiscuous = enable + n.mu.promiscuous = enable n.mu.Unlock() } func (n *NIC) isPromiscuousMode() bool { n.mu.RLock() - rv := n.promiscuous + rv := n.mu.promiscuous n.mu.RUnlock() return rv } @@ -272,7 +269,7 @@ func (n *NIC) isLoopback() bool { // setSpoofing enables or disables address spoofing. 
func (n *NIC) setSpoofing(enable bool) { n.mu.Lock() - n.spoofing = enable + n.mu.spoofing = enable n.mu.Unlock() } @@ -291,8 +288,8 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr t defer n.mu.RUnlock() var deprecatedEndpoint *referencedNetworkEndpoint - for _, r := range n.primary[protocol] { - if !r.isValidForOutgoing() { + for _, r := range n.mu.primary[protocol] { + if !r.isValidForOutgoingRLocked() { continue } @@ -342,7 +339,7 @@ func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEn n.mu.RLock() defer n.mu.RUnlock() - primaryAddrs := n.primary[header.IPv6ProtocolNumber] + primaryAddrs := n.mu.primary[header.IPv6ProtocolNumber] if len(primaryAddrs) == 0 { return nil @@ -425,7 +422,7 @@ func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEn // hasPermanentAddrLocked returns true if n has a permanent (including currently // tentative) address, addr. func (n *NIC) hasPermanentAddrLocked(addr tcpip.Address) bool { - ref, ok := n.endpoints[NetworkEndpointID{addr}] + ref, ok := n.mu.endpoints[NetworkEndpointID{addr}] if !ok { return false @@ -436,24 +433,54 @@ func (n *NIC) hasPermanentAddrLocked(addr tcpip.Address) bool { return kind == permanent || kind == permanentTentative } +type getRefBehaviour int + +const ( + // spoofing indicates that the NIC's spoofing flag should be observed when + // getting a NIC's referenced network endpoint. + spoofing getRefBehaviour = iota + + // promiscuous indicates that the NIC's promiscuous flag should be observed + // when getting a NIC's referenced network endpoint. + promiscuous + + // forceSpoofing indicates that the NIC should be assumed to be spoofing, + // regardless of what the NIC's spoofing flag is when getting a NIC's + // referenced network endpoint. 
+ forceSpoofing +) + func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint { - return n.getRefOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, n.promiscuous) + return n.getRefOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, promiscuous) } // findEndpoint finds the endpoint, if any, with the given address. func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint { - return n.getRefOrCreateTemp(protocol, address, peb, n.spoofing) + return n.getRefOrCreateTemp(protocol, address, peb, spoofing) } // getRefEpOrCreateTemp returns the referenced network endpoint for the given -// protocol and address. If none exists a temporary one may be created if -// we are in promiscuous mode or spoofing. -func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, spoofingOrPromiscuous bool) *referencedNetworkEndpoint { +// protocol and address. +// +// If none exists a temporary one may be created if we are in promiscuous mode +// or spoofing. Promiscuous mode will only be checked if promiscuous is true. +// Similarly, spoofing will only be checked if spoofing is true. +func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, tempRef getRefBehaviour) *referencedNetworkEndpoint { id := NetworkEndpointID{address} n.mu.RLock() - if ref, ok := n.endpoints[id]; ok { + var spoofingOrPromiscuous bool + switch tempRef { + case spoofing: + spoofingOrPromiscuous = n.mu.spoofing + case promiscuous: + spoofingOrPromiscuous = n.mu.promiscuous + case forceSpoofing: + spoofingOrPromiscuous = true + } + + if ref, ok := n.mu.endpoints[id]; ok { // An endpoint with this id exists, check if it can be used and return it. 
switch ref.getKind() { case permanentExpired: @@ -474,7 +501,7 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t // the caller or if the address is found in the NIC's subnets. createTempEP := spoofingOrPromiscuous if !createTempEP { - for _, sn := range n.addressRanges { + for _, sn := range n.mu.addressRanges { // Skip the subnet address. if address == sn.ID() { continue @@ -502,7 +529,7 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t // endpoint, create a new "temporary" endpoint. It will only exist while // there's a route through it. n.mu.Lock() - if ref, ok := n.endpoints[id]; ok { + if ref, ok := n.mu.endpoints[id]; ok { // No need to check the type as we are ok with expired endpoints at this // point. if ref.tryIncRef() { @@ -543,7 +570,7 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar // Sanity check. id := NetworkEndpointID{LocalAddress: protocolAddress.AddressWithPrefix.Address} - if ref, ok := n.endpoints[id]; ok { + if ref, ok := n.mu.endpoints[id]; ok { // Endpoint already exists. if kind != permanent { return nil, tcpip.ErrDuplicateAddress @@ -562,7 +589,7 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar ref.deprecated = deprecated ref.configType = configType - refs := n.primary[ref.protocol] + refs := n.mu.primary[ref.protocol] for i, r := range refs { if r == ref { switch peb { @@ -572,9 +599,9 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar if i == 0 { return ref, nil } - n.primary[r.protocol] = append(refs[:i], refs[i+1:]...) + n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...) case NeverPrimaryEndpoint: - n.primary[r.protocol] = append(refs[:i], refs[i+1:]...) + n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...) 
return ref, nil } } @@ -637,13 +664,13 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar } } - n.endpoints[id] = ref + n.mu.endpoints[id] = ref n.insertPrimaryEndpointLocked(ref, peb) // If we are adding a tentative IPv6 address, start DAD. if isIPv6Unicast && kind == permanentTentative { - if err := n.ndp.startDuplicateAddressDetection(protocolAddress.AddressWithPrefix.Address, ref); err != nil { + if err := n.mu.ndp.startDuplicateAddressDetection(protocolAddress.AddressWithPrefix.Address, ref); err != nil { return nil, err } } @@ -668,8 +695,8 @@ func (n *NIC) AllAddresses() []tcpip.ProtocolAddress { n.mu.RLock() defer n.mu.RUnlock() - addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints)) - for nid, ref := range n.endpoints { + addrs := make([]tcpip.ProtocolAddress, 0, len(n.mu.endpoints)) + for nid, ref := range n.mu.endpoints { // Don't include tentative, expired or temporary endpoints to // avoid confusion and prevent the caller from using those. switch ref.getKind() { @@ -695,7 +722,7 @@ func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress { defer n.mu.RUnlock() var addrs []tcpip.ProtocolAddress - for proto, list := range n.primary { + for proto, list := range n.mu.primary { for _, ref := range list { // Don't include tentative, expired or tempory endpoints // to avoid confusion and prevent the caller from using @@ -726,7 +753,7 @@ func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWit n.mu.RLock() defer n.mu.RUnlock() - list, ok := n.primary[proto] + list, ok := n.mu.primary[proto] if !ok { return tcpip.AddressWithPrefix{} } @@ -769,7 +796,7 @@ func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWit // address. 
func (n *NIC) AddAddressRange(protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) { n.mu.Lock() - n.addressRanges = append(n.addressRanges, subnet) + n.mu.addressRanges = append(n.mu.addressRanges, subnet) n.mu.Unlock() } @@ -778,13 +805,13 @@ func (n *NIC) RemoveAddressRange(subnet tcpip.Subnet) { n.mu.Lock() // Use the same underlying array. - tmp := n.addressRanges[:0] - for _, sub := range n.addressRanges { + tmp := n.mu.addressRanges[:0] + for _, sub := range n.mu.addressRanges { if sub != subnet { tmp = append(tmp, sub) } } - n.addressRanges = tmp + n.mu.addressRanges = tmp n.mu.Unlock() } @@ -793,8 +820,8 @@ func (n *NIC) RemoveAddressRange(subnet tcpip.Subnet) { func (n *NIC) AddressRanges() []tcpip.Subnet { n.mu.RLock() defer n.mu.RUnlock() - sns := make([]tcpip.Subnet, 0, len(n.addressRanges)+len(n.endpoints)) - for nid := range n.endpoints { + sns := make([]tcpip.Subnet, 0, len(n.mu.addressRanges)+len(n.mu.endpoints)) + for nid := range n.mu.endpoints { sn, err := tcpip.NewSubnet(nid.LocalAddress, tcpip.AddressMask(strings.Repeat("\xff", len(nid.LocalAddress)))) if err != nil { // This should never happen as the mask has been carefully crafted to @@ -803,7 +830,7 @@ func (n *NIC) AddressRanges() []tcpip.Subnet { } sns = append(sns, sn) } - return append(sns, n.addressRanges...) + return append(sns, n.mu.addressRanges...) } // insertPrimaryEndpointLocked adds r to n's primary endpoint list as required @@ -813,9 +840,9 @@ func (n *NIC) AddressRanges() []tcpip.Subnet { func (n *NIC) insertPrimaryEndpointLocked(r *referencedNetworkEndpoint, peb PrimaryEndpointBehavior) { switch peb { case CanBePrimaryEndpoint: - n.primary[r.protocol] = append(n.primary[r.protocol], r) + n.mu.primary[r.protocol] = append(n.mu.primary[r.protocol], r) case FirstPrimaryEndpoint: - n.primary[r.protocol] = append([]*referencedNetworkEndpoint{r}, n.primary[r.protocol]...) + n.mu.primary[r.protocol] = append([]*referencedNetworkEndpoint{r}, n.mu.primary[r.protocol]...) 
} } @@ -827,7 +854,7 @@ func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) { // and was waiting (on the lock) to be removed and 2) the same address was // re-added in the meantime by removing this endpoint from the list and // adding a new one. - if n.endpoints[id] != r { + if n.mu.endpoints[id] != r { return } @@ -835,11 +862,11 @@ func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) { panic("Reference count dropped to zero before being removed") } - delete(n.endpoints, id) - refs := n.primary[r.protocol] + delete(n.mu.endpoints, id) + refs := n.mu.primary[r.protocol] for i, ref := range refs { if ref == r { - n.primary[r.protocol] = append(refs[:i], refs[i+1:]...) + n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...) break } } @@ -854,7 +881,7 @@ func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) { } func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { - r, ok := n.endpoints[NetworkEndpointID{addr}] + r, ok := n.mu.endpoints[NetworkEndpointID{addr}] if !ok { return tcpip.ErrBadLocalAddress } @@ -870,13 +897,13 @@ func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { // If we are removing a tentative IPv6 unicast address, stop // DAD. if kind == permanentTentative { - n.ndp.stopDuplicateAddressDetection(addr) + n.mu.ndp.stopDuplicateAddressDetection(addr) } // If we are removing an address generated via SLAAC, cleanup // its SLAAC resources and notify the integrator. if r.configType == slaac { - n.ndp.cleanupAutoGenAddrResourcesAndNotify(addr) + n.mu.ndp.cleanupAutoGenAddrResourcesAndNotify(addr) } } @@ -926,7 +953,7 @@ func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.A // outlined in RFC 3810 section 5. 
id := NetworkEndpointID{addr} - joins := n.mcastJoins[id] + joins := n.mu.mcastJoins[id] if joins == 0 { netProto, ok := n.stack.networkProtocols[protocol] if !ok { @@ -942,7 +969,7 @@ func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.A return err } } - n.mcastJoins[id] = joins + 1 + n.mu.mcastJoins[id] = joins + 1 return nil } @@ -960,7 +987,7 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error { // before leaveGroupLocked is called. func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { id := NetworkEndpointID{addr} - joins := n.mcastJoins[id] + joins := n.mu.mcastJoins[id] switch joins { case 0: // There are no joins with this address on this NIC. @@ -971,7 +998,7 @@ func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { return err } } - n.mcastJoins[id] = joins - 1 + n.mu.mcastJoins[id] = joins - 1 return nil } @@ -1006,12 +1033,12 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link // Are any packet sockets listening for this network protocol? n.mu.RLock() - packetEPs := n.packetEPs[protocol] + packetEPs := n.mu.packetEPs[protocol] // Check whether there are packet sockets listening for every protocol. // If we received a packet with protocol EthernetProtocolAll, then the // previous for loop will have handled it. if protocol != header.EthernetProtocolAll { - packetEPs = append(packetEPs, n.packetEPs[header.EthernetProtocolAll]...) + packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...) } n.mu.RUnlock() for _, ep := range packetEPs { @@ -1060,8 +1087,8 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link // Found a NIC. 
n := r.ref.nic n.mu.RLock() - ref, ok := n.endpoints[NetworkEndpointID{dst}] - ok = ok && ref.isValidForOutgoing() && ref.tryIncRef() + ref, ok := n.mu.endpoints[NetworkEndpointID{dst}] + ok = ok && ref.isValidForOutgoingRLocked() && ref.tryIncRef() n.mu.RUnlock() if ok { r.RemoteAddress = src @@ -1181,7 +1208,10 @@ func (n *NIC) Stack() *Stack { // false. It will only return true if the address is associated with the NIC // AND it is tentative. func (n *NIC) isAddrTentative(addr tcpip.Address) bool { - ref, ok := n.endpoints[NetworkEndpointID{addr}] + n.mu.RLock() + defer n.mu.RUnlock() + + ref, ok := n.mu.endpoints[NetworkEndpointID{addr}] if !ok { return false } @@ -1197,7 +1227,7 @@ func (n *NIC) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error { n.mu.Lock() defer n.mu.Unlock() - ref, ok := n.endpoints[NetworkEndpointID{addr}] + ref, ok := n.mu.endpoints[NetworkEndpointID{addr}] if !ok { return tcpip.ErrBadAddress } @@ -1217,7 +1247,7 @@ func (n *NIC) setNDPConfigs(c NDPConfigurations) { c.validate() n.mu.Lock() - n.ndp.configs = c + n.mu.ndp.configs = c n.mu.Unlock() } @@ -1226,7 +1256,7 @@ func (n *NIC) handleNDPRA(ip tcpip.Address, ra header.NDPRouterAdvert) { n.mu.Lock() defer n.mu.Unlock() - n.ndp.handleRA(ip, ra) + n.mu.ndp.handleRA(ip, ra) } type networkEndpointKind int32 @@ -1268,11 +1298,11 @@ func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep Pa n.mu.Lock() defer n.mu.Unlock() - eps, ok := n.packetEPs[netProto] + eps, ok := n.mu.packetEPs[netProto] if !ok { return tcpip.ErrNotSupported } - n.packetEPs[netProto] = append(eps, ep) + n.mu.packetEPs[netProto] = append(eps, ep) return nil } @@ -1281,14 +1311,14 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep n.mu.Lock() defer n.mu.Unlock() - eps, ok := n.packetEPs[netProto] + eps, ok := n.mu.packetEPs[netProto] if !ok { return } for i, epOther := range eps { if epOther == ep { - n.packetEPs[netProto] = append(eps[:i], eps[i+1:]...) 
+ n.mu.packetEPs[netProto] = append(eps[:i], eps[i+1:]...) return } } @@ -1346,14 +1376,19 @@ func (r *referencedNetworkEndpoint) setKind(kind networkEndpointKind) { // packet. It requires the endpoint to not be marked expired (i.e., its address // has been removed), or the NIC to be in spoofing mode. func (r *referencedNetworkEndpoint) isValidForOutgoing() bool { - return r.getKind() != permanentExpired || r.nic.spoofing + r.nic.mu.RLock() + defer r.nic.mu.RUnlock() + + return r.isValidForOutgoingRLocked() } -// isValidForIncoming returns true if the endpoint can accept an incoming -// packet. It requires the endpoint to not be marked expired (i.e., its address -// has been removed), or the NIC to be in promiscuous mode. -func (r *referencedNetworkEndpoint) isValidForIncoming() bool { - return r.getKind() != permanentExpired || r.nic.promiscuous +// isValidForOutgoingRLocked returns true if the endpoint can be used to send +// out a packet. It requires the endpoint to not be marked expired (i.e., its +// address has been removed), or the NIC to be in spoofing mode. +// +// r's NIC must be read locked. 
+func (r *referencedNetworkEndpoint) isValidForOutgoingRLocked() bool { + return r.getKind() != permanentExpired || r.nic.mu.spoofing } // decRef decrements the ref count and cleans up the endpoint once it reaches diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index dad288642..834fe9487 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -1880,9 +1880,7 @@ func TestNICForwarding(t *testing.T) { Data: buf.ToVectorisedView(), }) - select { - case <-ep2.C: - default: + if _, ok := ep2.Read(); !ok { t.Fatal("Packet not forwarded") } diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index f50604a8a..869c69a6d 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -623,10 +623,8 @@ func TestTransportForwarding(t *testing.T) { t.Fatalf("Write failed: %v", err) } - var p channel.PacketInfo - select { - case p = <-ep2.C: - default: + p, ok := ep2.Read() + if !ok { t.Fatal("Response packet not forwarded") } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 59c9b3fb0..0fa141d58 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -626,6 +626,12 @@ type TCPLingerTimeoutOption time.Duration // before being marked closed. type TCPTimeWaitTimeoutOption time.Duration +// TCPDeferAcceptOption is used by SetSockOpt/GetSockOpt to allow a +// accept to return a completed connection only when there is data to be +// read. This usually means the listening socket will drop the final ACK +// for a handshake till the specified timeout until a segment with data arrives. +type TCPDeferAcceptOption time.Duration + // MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default // TTL value for multicast messages. The default is 1. 
type MulticastTTLOption uint8 diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD index 3aa23d529..ac18ec5b1 100644 --- a/pkg/tcpip/transport/icmp/BUILD +++ b/pkg/tcpip/transport/icmp/BUILD @@ -1,5 +1,5 @@ +load("//tools:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -23,7 +23,6 @@ go_library( "icmp_packet_list.go", "protocol.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/icmp", imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], visibility = ["//visibility:public"], deps = [ diff --git a/pkg/tcpip/transport/packet/BUILD b/pkg/tcpip/transport/packet/BUILD index 4858d150c..d22de6b26 100644 --- a/pkg/tcpip/transport/packet/BUILD +++ b/pkg/tcpip/transport/packet/BUILD @@ -1,5 +1,5 @@ +load("//tools:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -22,7 +22,6 @@ go_library( "endpoint_state.go", "packet_list.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/packet", imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], visibility = ["//visibility:public"], deps = [ diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD index 2f2131ff7..c9baf4600 100644 --- a/pkg/tcpip/transport/raw/BUILD +++ b/pkg/tcpip/transport/raw/BUILD @@ -1,5 +1,5 @@ +load("//tools:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -23,7 +23,6 @@ go_library( "protocol.go", "raw_packet_list.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/raw", imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], visibility = ["//visibility:public"], deps = [ diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 0e3ab05ad..272e8f570 100644 --- 
a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -55,10 +54,10 @@ go_library( "tcp_segment_list.go", "timer.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/tcp", imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], visibility = ["//visibility:public"], deps = [ + "//pkg/log", "//pkg/rand", "//pkg/sleep", "//pkg/sync", @@ -92,6 +91,7 @@ go_test( tags = ["flaky"], deps = [ ":tcp", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index d469758eb..6101f2945 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -222,13 +222,13 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu // createConnectingEndpoint creates a new endpoint in a connecting state, with // the connection parameters given by the arguments. -func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions) (*endpoint, *tcpip.Error) { +func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) { // Create a new endpoint. netProto := l.netProto if netProto == 0 { netProto = s.route.NetProto } - n := newEndpoint(l.stack, netProto, nil) + n := newEndpoint(l.stack, netProto, queue) n.v6only = l.v6only n.ID = s.id n.boundNICID = s.route.NICID() @@ -273,16 +273,17 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i // createEndpoint creates a new endpoint in connected state and then performs // the TCP 3-way handshake. 
-func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions) (*endpoint, *tcpip.Error) { +func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) { // Create new endpoint. irs := s.sequenceNumber isn := generateSecureISN(s.id, l.stack.Seed()) - ep, err := l.createConnectingEndpoint(s, isn, irs, opts) + ep, err := l.createConnectingEndpoint(s, isn, irs, opts, queue) if err != nil { return nil, err } // listenEP is nil when listenContext is used by tcp.Forwarder. + deferAccept := time.Duration(0) if l.listenEP != nil { l.listenEP.mu.Lock() if l.listenEP.EndpointState() != StateListen { @@ -290,13 +291,12 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head return nil, tcpip.ErrConnectionAborted } l.addPendingEndpoint(ep) + deferAccept = l.listenEP.deferAccept l.listenEP.mu.Unlock() } // Perform the 3-way handshake. - h := newHandshake(ep, seqnum.Size(ep.initialReceiveWindow())) - - h.resetToSynRcvd(isn, irs, opts) + h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept) if err := h.execute(); err != nil { ep.Close() if l.listenEP != nil { @@ -377,16 +377,14 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header defer e.decSynRcvdCount() defer s.decRef() - n, err := ctx.createEndpointAndPerformHandshake(s, opts) + n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{}) if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() return } ctx.removePendingEndpoint(n) - // Start the protocol goroutine. 
- wq := &waiter.Queue{} - n.startAcceptedLoop(wq) + n.startAcceptedLoop() e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() e.deliverAccepted(n) @@ -546,7 +544,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr } - n, err := ctx.createConnectingEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions) + n, err := ctx.createConnectingEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions, &waiter.Queue{}) if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() @@ -576,8 +574,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // space available in the backlog. // Start the protocol goroutine. - wq := &waiter.Queue{} - n.startAcceptedLoop(wq) + n.startAcceptedLoop() e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() go e.deliverAccepted(n) } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 4e3c5419c..9ff7ac261 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -86,6 +86,19 @@ type handshake struct { // rcvWndScale is the receive window scale, as defined in RFC 1323. rcvWndScale int + + // startTime is the time at which the first SYN/SYN-ACK was sent. + startTime time.Time + + // deferAccept if non-zero will drop the final ACK for a passive + // handshake till an ACK segment with data is received or the timeout is + // hit. + deferAccept time.Duration + + // acked is true if the the final ACK for a 3-way handshake has + // been received. This is required to stop retransmitting the + // original SYN-ACK when deferAccept is enabled. 
+ acked bool } func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake { @@ -112,6 +125,12 @@ func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake { return h } +func newPassiveHandshake(ep *endpoint, rcvWnd seqnum.Size, isn, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) handshake { + h := newHandshake(ep, rcvWnd) + h.resetToSynRcvd(isn, irs, opts, deferAccept) + return h +} + // FindWndScale determines the window scale to use for the given maximum window // size. func FindWndScale(wnd seqnum.Size) int { @@ -181,7 +200,7 @@ func (h *handshake) effectiveRcvWndScale() uint8 { // resetToSynRcvd resets the state of the handshake object to the SYN-RCVD // state. -func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions) { +func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) { h.active = false h.state = handshakeSynRcvd h.flags = header.TCPFlagSyn | header.TCPFlagAck @@ -189,6 +208,7 @@ func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *hea h.ackNum = irs + 1 h.mss = opts.MSS h.sndWndScale = opts.WS + h.deferAccept = deferAccept h.ep.mu.Lock() h.ep.setEndpointState(StateSynRecv) h.ep.mu.Unlock() @@ -352,6 +372,14 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { // We have previously received (and acknowledged) the peer's SYN. If the // peer acknowledges our SYN, the handshake is completed. if s.flagIsSet(header.TCPFlagAck) { + // If deferAccept is not zero and this is a bare ACK and the + // timeout is not hit then drop the ACK. 
+ if h.deferAccept != 0 && s.data.Size() == 0 && time.Since(h.startTime) < h.deferAccept { + h.acked = true + h.ep.stack.Stats().DroppedPackets.Increment() + return nil + } + // If the timestamp option is negotiated and the segment does // not carry a timestamp option then the segment must be dropped // as per https://tools.ietf.org/html/rfc7323#section-3.2. @@ -365,10 +393,16 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber) } h.state = handshakeCompleted + h.ep.mu.Lock() h.ep.transitionToStateEstablishedLocked(h) + // If the segment has data then requeue it for the receiver + // to process it again once main loop is started. + if s.data.Size() > 0 { + s.incRef() + h.ep.enqueueSegment(s) + } h.ep.mu.Unlock() - return nil } @@ -471,6 +505,7 @@ func (h *handshake) execute() *tcpip.Error { } } + h.startTime = time.Now() // Initialize the resend timer. resendWaker := sleep.Waker{} timeOut := time.Duration(time.Second) @@ -524,11 +559,21 @@ func (h *handshake) execute() *tcpip.Error { switch index, _ := s.Fetch(true); index { case wakerForResend: timeOut *= 2 - if timeOut > 60*time.Second { + if timeOut > MaxRTO { return tcpip.ErrTimeout } rt.Reset(timeOut) - h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + // Resend the SYN/SYN-ACK only if the following conditions hold. + // - It's an active handshake (deferAccept does not apply) + // - It's a passive handshake and we have not yet got the final-ACK. + // - It's a passive handshake and we got an ACK but deferAccept is + // enabled and we are now past the deferAccept duration. + // The last is required to provide a way for the peer to complete + // the connection with another ACK or data (as ACKs are never + // retransmitted on their own). 
+ if h.active || !h.acked || h.deferAccept != 0 && time.Since(h.startTime) > h.deferAccept { + h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + } case wakerForNotification: n := h.ep.fetchNotifications() diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 13718ff55..b5a8e15ee 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -498,6 +498,13 @@ type endpoint struct { // without any data being acked. userTimeout time.Duration + // deferAccept if non-zero specifies a user specified time during + // which the final ACK of a handshake will be dropped provided the + // ACK is a bare ACK and carries no data. If the timeout is crossed then + // the bare ACK is accepted and the connection is delivered to the + // listener. + deferAccept time.Duration + // pendingAccepted is a synchronization primitive used to track number // of connections that are queued up to be delivered to the accepted // channel. We use this to ensure that all goroutines blocked on writing @@ -1574,6 +1581,15 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.mu.Unlock() return nil + case tcpip.TCPDeferAcceptOption: + e.mu.Lock() + if time.Duration(v) > MaxRTO { + v = tcpip.TCPDeferAcceptOption(MaxRTO) + } + e.deferAccept = time.Duration(v) + e.mu.Unlock() + return nil + default: return nil } @@ -1798,6 +1814,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.mu.Unlock() return nil + case *tcpip.TCPDeferAcceptOption: + e.mu.Lock() + *o = tcpip.TCPDeferAcceptOption(e.deferAccept) + e.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -2025,8 +2047,14 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { // work mutex is available. 
if e.workMu.TryLock() { e.mu.Lock() - e.resetConnectionLocked(tcpip.ErrConnectionAborted) - e.notifyProtocolGoroutine(notifyTickleWorker) + // We need to double check here to make + // sure worker has not transitioned the + // endpoint out of a connected state + // before trying to send a reset. + if e.EndpointState().connected() { + e.resetConnectionLocked(tcpip.ErrConnectionAborted) + e.notifyProtocolGoroutine(notifyTickleWorker) + } e.mu.Unlock() e.workMu.Unlock() } else { @@ -2149,9 +2177,8 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { // startAcceptedLoop sets up required state and starts a goroutine with the // main loop for accepted connections. -func (e *endpoint) startAcceptedLoop(waiterQueue *waiter.Queue) { +func (e *endpoint) startAcceptedLoop() { e.mu.Lock() - e.waiterQueue = waiterQueue e.workerRunning = true e.mu.Unlock() wakerInitDone := make(chan struct{}) @@ -2177,7 +2204,6 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { default: return nil, nil, tcpip.ErrWouldBlock } - return n, n.waiterQueue, nil } diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 7eb613be5..c9ee5bf06 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -157,13 +157,13 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, TSVal: r.synOptions.TSVal, TSEcr: r.synOptions.TSEcr, SACKPermitted: r.synOptions.SACKPermitted, - }) + }, queue) if err != nil { return nil, err } // Start the protocol goroutine. 
- ep.startAcceptedLoop(queue) + ep.startAcceptedLoop() return ep, nil } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index df2fb1071..2c1505067 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -21,6 +21,7 @@ import ( "testing" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" @@ -6787,3 +6788,183 @@ func TestIncreaseWindowOnBufferResize(t *testing.T) { ), ) } + +func TestTCPDeferAccept(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10); err != nil { + t.Fatal("Listen failed:", err) + } + + const tcpDeferAccept = 1 * time.Second + if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil { + t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %v", tcpDeferAccept, err) + } + + irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) + + if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock { + t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: %s", err, tcpip.ErrWouldBlock) + } + + // Send data. This should result in an acceptable endpoint. + c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + }) + + // Receive ACK for the data we sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(iss+1)), + checker.AckNum(uint32(irs+5)))) + + // Give a bit of time for the socket to be delivered to the accept queue. 
+ time.Sleep(50 * time.Millisecond) + aep, _, err := c.EP.Accept() + if err != nil { + t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: nil", err) + } + + aep.Close() + // Closing aep without reading the data should trigger a RST. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), + checker.SeqNum(uint32(iss+1)), + checker.AckNum(uint32(irs+5)))) +} + +func TestTCPDeferAcceptTimeout(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1) + + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatal("Bind failed:", err) + } + + if err := c.EP.Listen(10); err != nil { + t.Fatal("Listen failed:", err) + } + + const tcpDeferAccept = 1 * time.Second + if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil { + t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %v", tcpDeferAccept, err) + } + + irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) + + if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock { + t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: %s", err, tcpip.ErrWouldBlock) + } + + // Sleep for a little of the tcpDeferAccept timeout. + time.Sleep(tcpDeferAccept + 100*time.Millisecond) + + // On timeout expiry we should get a SYN-ACK retransmission. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), + checker.AckNum(uint32(irs)+1))) + + // Send data. This should result in an acceptable endpoint. + c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + }) + + // Receive ACK for the data we sent. 
+ checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(iss+1)), + checker.AckNum(uint32(irs+5)))) + + // Give sometime for the endpoint to be delivered to the accept queue. + time.Sleep(50 * time.Millisecond) + aep, _, err := c.EP.Accept() + if err != nil { + t.Fatalf("c.EP.Accept() returned unexpected error got: %v, want: nil", err) + } + + aep.Close() + // Closing aep without reading the data should trigger a RST. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), + checker.SeqNum(uint32(iss+1)), + checker.AckNum(uint32(irs+5)))) +} + +func TestResetDuringClose(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + iss := seqnum.Value(789) + c.CreateConnected(iss, 30000, -1 /* epRecvBuf */) + // Send some data to make sure there is some unread + // data to trigger a reset on c.Close. + irs := c.IRS + c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: iss.Add(1), + AckNum: irs.Add(1), + RcvWnd: 30000, + }) + + // Receive ACK for the data we sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(irs.Add(1))), + checker.AckNum(uint32(iss.Add(5))))) + + // Close in a separate goroutine so that we can trigger + // a race with the RST we send below. This should not + // panic due to the route being released depeding on + // whether Close() sends an active RST or the RST sent + // below is processed by the worker first. 
+ var wg sync.WaitGroup + + wg.Add(1) + go func() { + defer wg.Done() + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + SeqNum: iss.Add(5), + AckNum: c.IRS.Add(5), + RcvWnd: 30000, + Flags: header.TCPFlagRst, + }) + }() + + wg.Add(1) + go func() { + defer wg.Done() + c.EP.Close() + }() + + wg.Wait() +} diff --git a/pkg/tcpip/transport/tcp/testing/context/BUILD b/pkg/tcpip/transport/tcp/testing/context/BUILD index b33ec2087..ce6a2c31d 100644 --- a/pkg/tcpip/transport/tcp/testing/context/BUILD +++ b/pkg/tcpip/transport/tcp/testing/context/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -6,7 +6,6 @@ go_library( name = "context", testonly = 1, srcs = ["context.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context", visibility = [ "//visibility:public", ], diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 822907998..730ac4292 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -18,6 +18,7 @@ package context import ( "bytes" + "context" "testing" "time" @@ -215,11 +216,9 @@ func (c *Context) Stack() *stack.Stack { func (c *Context) CheckNoPacketTimeout(errMsg string, wait time.Duration) { c.t.Helper() - select { - case <-c.linkEP.C: + ctx, _ := context.WithTimeout(context.Background(), wait) + if _, ok := c.linkEP.ReadContext(ctx); ok { c.t.Fatal(errMsg) - - case <-time.After(wait): } } @@ -234,27 +233,27 @@ func (c *Context) CheckNoPacket(errMsg string) { // 2 seconds. 
func (c *Context) GetPacket() []byte { c.t.Helper() - select { - case p := <-c.linkEP.C: - if p.Proto != ipv4.ProtocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) - } - hdr := p.Pkt.Header.View() - b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) + ctx, _ := context.WithTimeout(context.Background(), 2*time.Second) + p, ok := c.linkEP.ReadContext(ctx) + if !ok { + c.t.Fatalf("Packet wasn't written out") + return nil + } - if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize { - c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize) - } + if p.Proto != ipv4.ProtocolNumber { + c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) + } - checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) - return b + hdr := p.Pkt.Header.View() + b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) - case <-time.After(2 * time.Second): - c.t.Fatalf("Packet wasn't written out") + if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize { + c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize) } - return nil + checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) + return b } // GetPacketNonBlocking reads a packet from the link layer endpoint @@ -263,20 +262,21 @@ func (c *Context) GetPacket() []byte { // nil immediately. func (c *Context) GetPacketNonBlocking() []byte { c.t.Helper() - select { - case p := <-c.linkEP.C: - if p.Proto != ipv4.ProtocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) - } - hdr := p.Pkt.Header.View() - b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) 
- - checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) - return b - default: + p, ok := c.linkEP.Read() + if !ok { return nil } + + if p.Proto != ipv4.ProtocolNumber { + c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) + } + + hdr := p.Pkt.Header.View() + b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) + + checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) + return b } // SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint. @@ -484,23 +484,23 @@ func (c *Context) CreateV6Endpoint(v6only bool) { // and asserts that it is an IPv6 Packet with the expected src/dest addresses. func (c *Context) GetV6Packet() []byte { c.t.Helper() - select { - case p := <-c.linkEP.C: - if p.Proto != ipv6.ProtocolNumber { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber) - } - b := make([]byte, p.Pkt.Header.UsedLength()+p.Pkt.Data.Size()) - copy(b, p.Pkt.Header.View()) - copy(b[p.Pkt.Header.UsedLength():], p.Pkt.Data.ToView()) - - checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr)) - return b - case <-time.After(2 * time.Second): + ctx, _ := context.WithTimeout(context.Background(), 2*time.Second) + p, ok := c.linkEP.ReadContext(ctx) + if !ok { c.t.Fatalf("Packet wasn't written out") + return nil + } + + if p.Proto != ipv6.ProtocolNumber { + c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber) } + b := make([]byte, p.Pkt.Header.UsedLength()+p.Pkt.Data.Size()) + copy(b, p.Pkt.Header.View()) + copy(b[p.Pkt.Header.UsedLength():], p.Pkt.Data.ToView()) - return nil + checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr)) + return b } // SendV6Packet builds and sends an IPv6 Packet via the link layer endpoint of diff --git a/pkg/tcpip/transport/tcpconntrack/BUILD b/pkg/tcpip/transport/tcpconntrack/BUILD index 43fcc27f0..3ad6994a7 100644 --- 
a/pkg/tcpip/transport/tcpconntrack/BUILD +++ b/pkg/tcpip/transport/tcpconntrack/BUILD @@ -1,12 +1,10 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "tcpconntrack", srcs = ["tcp_conntrack.go"], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack", visibility = ["//visibility:public"], deps = [ "//pkg/tcpip/header", diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index 57ff123e3..adc908e24 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -1,6 +1,5 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") -load("//tools/go_stateify:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -25,7 +24,6 @@ go_library( "protocol.go", "udp_packet_list.go", ], - importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/udp", imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], visibility = ["//visibility:public"], deps = [ diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index c6927cfe3..f0ff3fe71 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -16,6 +16,7 @@ package udp_test import ( "bytes" + "context" "fmt" "math/rand" "testing" @@ -357,30 +358,29 @@ func (c *testContext) createEndpointForFlow(flow testFlow) { func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte { c.t.Helper() - select { - case p := <-c.linkEP.C: - if p.Proto != flow.netProto() { - c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto()) - } - - hdr := p.Pkt.Header.View() - b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) 
- - h := flow.header4Tuple(outgoing) - checkers := append( - checkers, - checker.SrcAddr(h.srcAddr.Addr), - checker.DstAddr(h.dstAddr.Addr), - checker.UDP(checker.DstPort(h.dstAddr.Port)), - ) - flow.checkerFn()(c.t, b, checkers...) - return b - - case <-time.After(2 * time.Second): + ctx, _ := context.WithTimeout(context.Background(), 2*time.Second) + p, ok := c.linkEP.ReadContext(ctx) + if !ok { c.t.Fatalf("Packet wasn't written out") + return nil } - return nil + if p.Proto != flow.netProto() { + c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto()) + } + + hdr := p.Pkt.Header.View() + b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) + + h := flow.header4Tuple(outgoing) + checkers = append( + checkers, + checker.SrcAddr(h.srcAddr.Addr), + checker.DstAddr(h.dstAddr.Addr), + checker.UDP(checker.DstPort(h.dstAddr.Port)), + ) + flow.checkerFn()(c.t, b, checkers...) + return b } // injectPacket creates a packet of the given flow and with the given payload, @@ -1541,48 +1541,50 @@ func TestV4UnknownDestination(t *testing.T) { } c.injectPacket(tc.flow, payload) if !tc.icmpRequired { - select { - case p := <-c.linkEP.C: + ctx, _ := context.WithTimeout(context.Background(), time.Second) + if p, ok := c.linkEP.ReadContext(ctx); ok { t.Fatalf("unexpected packet received: %+v", p) - case <-time.After(1 * time.Second): - return } + return } - select { - case p := <-c.linkEP.C: - var pkt []byte - pkt = append(pkt, p.Pkt.Header.View()...) - pkt = append(pkt, p.Pkt.Data.ToView()...) - if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want { - t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) - } + // ICMP required. 
+ ctx, _ := context.WithTimeout(context.Background(), time.Second) + p, ok := c.linkEP.ReadContext(ctx) + if !ok { + t.Fatalf("packet wasn't written out") + return + } - hdr := header.IPv4(pkt) - checker.IPv4(t, hdr, checker.ICMPv4( - checker.ICMPv4Type(header.ICMPv4DstUnreachable), - checker.ICMPv4Code(header.ICMPv4PortUnreachable))) + var pkt []byte + pkt = append(pkt, p.Pkt.Header.View()...) + pkt = append(pkt, p.Pkt.Data.ToView()...) + if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want { + t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) + } - icmpPkt := header.ICMPv4(hdr.Payload()) - payloadIPHeader := header.IPv4(icmpPkt.Payload()) - wantLen := len(payload) - if tc.largePayload { - wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize - } + hdr := header.IPv4(pkt) + checker.IPv4(t, hdr, checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4DstUnreachable), + checker.ICMPv4Code(header.ICMPv4PortUnreachable))) - // In case of large payloads the IP packet may be truncated. Update - // the length field before retrieving the udp datagram payload. 
- payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize)) + icmpPkt := header.ICMPv4(hdr.Payload()) + payloadIPHeader := header.IPv4(icmpPkt.Payload()) + wantLen := len(payload) + if tc.largePayload { + wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize + } - origDgram := header.UDP(payloadIPHeader.Payload()) - if got, want := len(origDgram.Payload()), wantLen; got != want { - t.Fatalf("unexpected payload length got: %d, want: %d", got, want) - } - if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { - t.Fatalf("unexpected payload got: %d, want: %d", got, want) - } - case <-time.After(1 * time.Second): - t.Fatalf("packet wasn't written out") + // In case of large payloads the IP packet may be truncated. Update + // the length field before retrieving the udp datagram payload. + payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize)) + + origDgram := header.UDP(payloadIPHeader.Payload()) + if got, want := len(origDgram.Payload()), wantLen; got != want { + t.Fatalf("unexpected payload length got: %d, want: %d", got, want) + } + if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { + t.Fatalf("unexpected payload got: %d, want: %d", got, want) } }) } @@ -1615,47 +1617,49 @@ func TestV6UnknownDestination(t *testing.T) { } c.injectPacket(tc.flow, payload) if !tc.icmpRequired { - select { - case p := <-c.linkEP.C: + ctx, _ := context.WithTimeout(context.Background(), time.Second) + if p, ok := c.linkEP.ReadContext(ctx); ok { t.Fatalf("unexpected packet received: %+v", p) - case <-time.After(1 * time.Second): - return } + return } - select { - case p := <-c.linkEP.C: - var pkt []byte - pkt = append(pkt, p.Pkt.Header.View()...) - pkt = append(pkt, p.Pkt.Data.ToView()...) 
- if got, want := len(pkt), header.IPv6MinimumMTU; got > want { - t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) - } + // ICMP required. + ctx, _ := context.WithTimeout(context.Background(), time.Second) + p, ok := c.linkEP.ReadContext(ctx) + if !ok { + t.Fatalf("packet wasn't written out") + return + } + + var pkt []byte + pkt = append(pkt, p.Pkt.Header.View()...) + pkt = append(pkt, p.Pkt.Data.ToView()...) + if got, want := len(pkt), header.IPv6MinimumMTU; got > want { + t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want) + } - hdr := header.IPv6(pkt) - checker.IPv6(t, hdr, checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6DstUnreachable), - checker.ICMPv6Code(header.ICMPv6PortUnreachable))) + hdr := header.IPv6(pkt) + checker.IPv6(t, hdr, checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6DstUnreachable), + checker.ICMPv6Code(header.ICMPv6PortUnreachable))) - icmpPkt := header.ICMPv6(hdr.Payload()) - payloadIPHeader := header.IPv6(icmpPkt.Payload()) - wantLen := len(payload) - if tc.largePayload { - wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize - } - // In case of large payloads the IP packet may be truncated. Update - // the length field before retrieving the udp datagram payload. - payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize)) + icmpPkt := header.ICMPv6(hdr.Payload()) + payloadIPHeader := header.IPv6(icmpPkt.Payload()) + wantLen := len(payload) + if tc.largePayload { + wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize + } + // In case of large payloads the IP packet may be truncated. Update + // the length field before retrieving the udp datagram payload. 
+ payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize)) - origDgram := header.UDP(payloadIPHeader.Payload()) - if got, want := len(origDgram.Payload()), wantLen; got != want { - t.Fatalf("unexpected payload length got: %d, want: %d", got, want) - } - if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { - t.Fatalf("unexpected payload got: %v, want: %v", got, want) - } - case <-time.After(1 * time.Second): - t.Fatalf("packet wasn't written out") + origDgram := header.UDP(payloadIPHeader.Payload()) + if got, want := len(origDgram.Payload()), wantLen; got != want { + t.Fatalf("unexpected payload length got: %d, want: %d", got, want) + } + if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) { + t.Fatalf("unexpected payload got: %v, want: %v", got, want) } }) } |