diff options
Diffstat (limited to 'pkg/tcpip')
114 files changed, 4920 insertions, 3365 deletions
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index ea46c30da..f00cfd0f5 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -39,12 +39,29 @@ go_library( deps_test( name = "netstack_deps_test", allowed = [ + # gVisor deps. + "//pkg/atomicbitops", + "//pkg/buffer", + "//pkg/context", + "//pkg/gohacks", + "//pkg/goid", + "//pkg/ilist", + "//pkg/linewriter", + "//pkg/log", + "//pkg/rand", + "//pkg/sleep", + "//pkg/state", + "//pkg/state/wire", + "//pkg/sync", + "//pkg/waiter", + + # Other deps. "@com_github_google_btree//:go_default_library", "@org_golang_x_sys//unix:go_default_library", "@org_golang_x_time//rate:go_default_library", ], allowed_prefixes = [ - "//", + "//pkg/tcpip", "@org_golang_x_sys//internal/unsafeheader", ], targets = [ diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 18e6cc3cd..e0dfe5813 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -50,7 +50,7 @@ func IPv4(t *testing.T, b []byte, checkers ...NetworkChecker) { ipv4 := header.IPv4(b) if !ipv4.IsValid(len(b)) { - t.Error("Not a valid IPv4 packet") + t.Fatalf("Not a valid IPv4 packet: %x", ipv4) } if !ipv4.IsChecksumValid() { @@ -72,7 +72,7 @@ func IPv6(t *testing.T, b []byte, checkers ...NetworkChecker) { ipv6 := header.IPv6(b) if !ipv6.IsValid(len(b)) { - t.Error("Not a valid IPv6 packet") + t.Fatalf("Not a valid IPv6 packet: %x", ipv6) } for _, f := range checkers { @@ -701,7 +701,7 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp if !ok { return } - opts := []byte(tcp.Options()) + opts := tcp.Options() limit := len(opts) foundTS := false tsVal := uint32(0) @@ -748,12 +748,6 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp } } -// TCPNoSACKBlockChecker creates a checker that verifies that the segment does -// not contain any SACK blocks in the TCP options. -func TCPNoSACKBlockChecker() TransportChecker { - return TCPSACKBlockChecker(nil) -} - // TCPSACKBlockChecker creates a checker that verifies that the segment does // contain the specified SACK blocks in the TCP options. func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker { @@ -765,7 +759,7 @@ func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker { } var gotSACKBlocks []header.SACKBlock - opts := []byte(tcp.Options()) + opts := tcp.Options() limit := len(opts) for i := 0; i < limit; { switch opts[i] { diff --git a/pkg/tcpip/faketime/faketime.go b/pkg/tcpip/faketime/faketime.go index fb819d7a8..9f8f51647 100644 --- a/pkg/tcpip/faketime/faketime.go +++ b/pkg/tcpip/faketime/faketime.go @@ -29,14 +29,14 @@ type NullClock struct{} var _ tcpip.Clock = (*NullClock)(nil) -// NowNanoseconds implements tcpip.Clock.NowNanoseconds. -func (*NullClock) NowNanoseconds() int64 { - return 0 +// Now implements tcpip.Clock.Now. +func (*NullClock) Now() time.Time { + return time.Time{} } // NowMonotonic implements tcpip.Clock.NowMonotonic. -func (*NullClock) NowMonotonic() int64 { - return 0 +func (*NullClock) NowMonotonic() tcpip.MonotonicTime { + return tcpip.MonotonicTime{} } // AfterFunc implements tcpip.Clock.AfterFunc. @@ -118,16 +118,17 @@ func NewManualClock() *ManualClock { var _ tcpip.Clock = (*ManualClock)(nil) -// NowNanoseconds implements tcpip.Clock.NowNanoseconds. -func (mc *ManualClock) NowNanoseconds() int64 { +// Now implements tcpip.Clock.Now. +func (mc *ManualClock) Now() time.Time { mc.mu.RLock() defer mc.mu.RUnlock() - return mc.mu.now.UnixNano() + return mc.mu.now } // NowMonotonic implements tcpip.Clock.NowMonotonic. -func (mc *ManualClock) NowMonotonic() int64 { - return mc.NowNanoseconds() +func (mc *ManualClock) NowMonotonic() tcpip.MonotonicTime { + var mt tcpip.MonotonicTime + return mt.Add(mc.Now().Sub(time.Unix(0, 0))) } // AfterFunc implements tcpip.Clock.AfterFunc. @@ -218,6 +219,12 @@ func (mc *ManualClock) stopTimerLocked(mt *manualTimer) { } } +// RunImmediatelyScheduledJobs runs all jobs scheduled to run at the current +// time. +func (mc *ManualClock) RunImmediatelyScheduledJobs() { + mc.Advance(0) +} + // Advance executes all work that have been scheduled to execute within d from // the current time. Blocks until all work has completed execution. func (mc *ManualClock) Advance(d time.Duration) { diff --git a/pkg/tcpip/faketime/faketime_test.go b/pkg/tcpip/faketime/faketime_test.go index c2704df2c..fd2bb470a 100644 --- a/pkg/tcpip/faketime/faketime_test.go +++ b/pkg/tcpip/faketime/faketime_test.go @@ -26,7 +26,7 @@ func TestManualClockAdvance(t *testing.T) { clock := faketime.NewManualClock() start := clock.NowMonotonic() clock.Advance(timeout) - if got, want := time.Duration(clock.NowMonotonic()-start)*time.Nanosecond, timeout; got != want { + if got, want := clock.NowMonotonic().Sub(start), timeout; got != want { t.Errorf("got = %d, want = %d", got, want) } } @@ -87,7 +87,7 @@ func TestManualClockAfterFunc(t *testing.T) { if got, want := counter2, test.wantCounter2; got != want { t.Errorf("got counter2 = %d, want = %d", got, want) } - if got, want := time.Duration(clock.NowMonotonic()-start)*time.Nanosecond, test.advance; got != want { + if got, want := clock.NowMonotonic().Sub(start), test.advance; got != want { t.Errorf("got elapsed = %d, want = %d", got, want) } }) diff --git a/pkg/tcpip/header/checksum.go b/pkg/tcpip/header/checksum.go index 6aa9acfa8..e2c85e220 100644 --- a/pkg/tcpip/header/checksum.go +++ b/pkg/tcpip/header/checksum.go @@ -18,6 +18,7 @@ package header import ( "encoding/binary" + "fmt" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -234,3 +235,64 @@ func PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, srcAddr tcpip. return Checksum([]byte{0, uint8(protocol)}, xsum) } + +// checksumUpdate2ByteAlignedUint16 updates a uint16 value in a calculated +// checksum. +// +// The value MUST begin at a 2-byte boundary in the original buffer. +func checksumUpdate2ByteAlignedUint16(xsum, old, new uint16) uint16 { + // As per RFC 1071 page 4, + // (4) Incremental Update + // + // ... + // + // To update the checksum, simply add the differences of the + // sixteen bit integers that have been changed. To see why this + // works, observe that every 16-bit integer has an additive inverse + // and that addition is associative. From this it follows that + // given the original value m, the new value m', and the old + // checksum C, the new checksum C' is: + // + // C' = C + (-m) + m' = C + (m' - m) + return ChecksumCombine(xsum, ChecksumCombine(new, ^old)) +} + +// checksumUpdate2ByteAlignedAddress updates an address in a calculated +// checksum. +// +// The addresses must have the same length and must contain an even number +// of bytes. The address MUST begin at a 2-byte boundary in the original buffer. +func checksumUpdate2ByteAlignedAddress(xsum uint16, old, new tcpip.Address) uint16 { + const uint16Bytes = 2 + + if len(old) != len(new) { + panic(fmt.Sprintf("buffer lengths are different; old = %d, new = %d", len(old), len(new))) + } + + if len(old)%uint16Bytes != 0 { + panic(fmt.Sprintf("buffer has an odd number of bytes; got = %d", len(old))) + } + + // As per RFC 1071 page 4, + // (4) Incremental Update + // + // ... + // + // To update the checksum, simply add the differences of the + // sixteen bit integers that have been changed. To see why this + // works, observe that every 16-bit integer has an additive inverse + // and that addition is associative. From this it follows that + // given the original value m, the new value m', and the old + // checksum C, the new checksum C' is: + // + // C' = C + (-m) + m' = C + (m' - m) + for len(old) != 0 { + // Convert the 2 byte sequences to uint16 values then apply the increment + // update. + xsum = checksumUpdate2ByteAlignedUint16(xsum, (uint16(old[0])<<8)+uint16(old[1]), (uint16(new[0])<<8)+uint16(new[1])) + old = old[uint16Bytes:] + new = new[uint16Bytes:] + } + + return xsum +} diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go index d267dabd0..3445511f4 100644 --- a/pkg/tcpip/header/checksum_test.go +++ b/pkg/tcpip/header/checksum_test.go @@ -23,6 +23,7 @@ import ( "sync" "testing" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -256,3 +257,205 @@ func TestICMPv6Checksum(t *testing.T) { }) }, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView())) } + +func randomAddress(size int) tcpip.Address { + s := make([]byte, size) + for i := 0; i < size; i++ { + s[i] = byte(rand.Uint32()) + } + return tcpip.Address(s) +} + +func TestChecksummableNetworkUpdateAddress(t *testing.T) { + tests := []struct { + name string + update func(header.IPv4, tcpip.Address) + }{ + { + name: "SetSourceAddressWithChecksumUpdate", + update: header.IPv4.SetSourceAddressWithChecksumUpdate, + }, + { + name: "SetDestinationAddressWithChecksumUpdate", + update: header.IPv4.SetDestinationAddressWithChecksumUpdate, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for i := 0; i < 1000; i++ { + var origBytes [header.IPv4MinimumSize]byte + header.IPv4(origBytes[:]).Encode(&header.IPv4Fields{ + TOS: 1, + TotalLength: header.IPv4MinimumSize, + ID: 2, + Flags: 3, + FragmentOffset: 4, + TTL: 5, + Protocol: 6, + Checksum: 0, + SrcAddr: randomAddress(header.IPv4AddressSize), + DstAddr: randomAddress(header.IPv4AddressSize), + }) + + addr := randomAddress(header.IPv4AddressSize) + + bytesCopy := origBytes + h := header.IPv4(bytesCopy[:]) + origXSum := h.CalculateChecksum() + h.SetChecksum(^origXSum) + + test.update(h, addr) + got := ^h.Checksum() + h.SetChecksum(0) + want := h.CalculateChecksum() + if got != want { + t.Errorf("got h.Checksum() = 0x%x, want = 0x%x; originalBytes = 0x%x, new addr = %s", got, want, origBytes, addr) + } + } + }) + } +} + +func TestChecksummableTransportUpdatePort(t *testing.T) { + // The fields in the pseudo header is not tested here so we just use 0. + const pseudoHeaderXSum = 0 + + tests := []struct { + name string + transportHdr func(_, _ uint16) (header.ChecksummableTransport, func(uint16) uint16) + proto tcpip.TransportProtocolNumber + }{ + { + name: "TCP", + transportHdr: func(src, dst uint16) (header.ChecksummableTransport, func(uint16) uint16) { + h := header.TCP(make([]byte, header.TCPMinimumSize)) + h.Encode(&header.TCPFields{ + SrcPort: src, + DstPort: dst, + SeqNum: 1, + AckNum: 2, + DataOffset: header.TCPMinimumSize, + Flags: 3, + WindowSize: 4, + Checksum: 0, + UrgentPointer: 5, + }) + h.SetChecksum(^h.CalculateChecksum(pseudoHeaderXSum)) + return h, h.CalculateChecksum + }, + proto: header.TCPProtocolNumber, + }, + { + name: "UDP", + transportHdr: func(src, dst uint16) (header.ChecksummableTransport, func(uint16) uint16) { + h := header.UDP(make([]byte, header.UDPMinimumSize)) + h.Encode(&header.UDPFields{ + SrcPort: src, + DstPort: dst, + Length: 0, + Checksum: 0, + }) + h.SetChecksum(^h.CalculateChecksum(pseudoHeaderXSum)) + return h, h.CalculateChecksum + }, + proto: header.UDPProtocolNumber, + }, + } + + for i := 0; i < 1000; i++ { + origSrcPort := uint16(rand.Uint32()) + origDstPort := uint16(rand.Uint32()) + newPort := uint16(rand.Uint32()) + + t.Run(fmt.Sprintf("OrigSrcPort=%d,OrigDstPort=%d,NewPort=%d", origSrcPort, origDstPort, newPort), func(*testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, subTest := range []struct { + name string + update func(header.ChecksummableTransport) + }{ + { + name: "Source port", + update: func(h header.ChecksummableTransport) { h.SetSourcePortWithChecksumUpdate(newPort) }, + }, + { + name: "Destination port", + update: func(h header.ChecksummableTransport) { h.SetDestinationPortWithChecksumUpdate(newPort) }, + }, + } { + t.Run(subTest.name, func(t *testing.T) { + h, calcXSum := test.transportHdr(origSrcPort, origDstPort) + subTest.update(h) + // TCP and UDP hold the 1s complement of the fully calculated + // checksum. + got := ^h.Checksum() + h.SetChecksum(0) + + if want := calcXSum(pseudoHeaderXSum); got != want { + h, _ := test.transportHdr(origSrcPort, origDstPort) + t.Errorf("got Checksum() = 0x%x, want = 0x%x; originalBytes = %#v, new port = %d", got, want, h, newPort) + } + }) + } + }) + } + }) + } +} + +func TestChecksummableTransportUpdatePseudoHeaderAddress(t *testing.T) { + const addressSize = 6 + + tests := []struct { + name string + transportHdr func() header.ChecksummableTransport + proto tcpip.TransportProtocolNumber + }{ + { + name: "TCP", + transportHdr: func() header.ChecksummableTransport { return header.TCP(make([]byte, header.TCPMinimumSize)) }, + proto: header.TCPProtocolNumber, + }, + { + name: "UDP", + transportHdr: func() header.ChecksummableTransport { return header.UDP(make([]byte, header.UDPMinimumSize)) }, + proto: header.UDPProtocolNumber, + }, + } + + for i := 0; i < 1000; i++ { + permanent := randomAddress(addressSize) + old := randomAddress(addressSize) + new := randomAddress(addressSize) + + t.Run(fmt.Sprintf("Permanent=%q,Old=%q,New=%q", permanent, old, new), func(t *testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, fullChecksum := range []bool{true, false} { + t.Run(fmt.Sprintf("FullChecksum=%t", fullChecksum), func(t *testing.T) { + initialXSum := header.PseudoHeaderChecksum(test.proto, permanent, old, 0) + if fullChecksum { + // TCP and UDP hold the 1s complement of the fully calculated + // checksum. + initialXSum = ^initialXSum + } + + h := test.transportHdr() + h.SetChecksum(initialXSum) + h.UpdateChecksumPseudoHeaderAddress(old, new, fullChecksum) + + got := h.Checksum() + if fullChecksum { + got = ^got + } + if want := header.PseudoHeaderChecksum(test.proto, permanent, new, 0); got != want { + t.Errorf("got Checksum() = 0x%x, want = 0x%x; h = %#v", got, want, h) + } + }) + } + }) + } + }) + } +} diff --git a/pkg/tcpip/header/interfaces.go b/pkg/tcpip/header/interfaces.go index 861cbbb70..3a41adfc4 100644 --- a/pkg/tcpip/header/interfaces.go +++ b/pkg/tcpip/header/interfaces.go @@ -53,6 +53,31 @@ type Transport interface { Payload() []byte } +// ChecksummableTransport is a Transport that supports checksumming. +type ChecksummableTransport interface { + Transport + + // SetSourcePortWithChecksumUpdate sets the source port and updates + // the checksum. + // + // The receiver's checksum must be a fully calculated checksum. + SetSourcePortWithChecksumUpdate(port uint16) + + // SetDestinationPortWithChecksumUpdate sets the destination port and updates + // the checksum. + // + // The receiver's checksum must be a fully calculated checksum. + SetDestinationPortWithChecksumUpdate(port uint16) + + // UpdateChecksumPseudoHeaderAddress updates the checksum to reflect an + // updated address in the pseudo header. + // + // If fullChecksum is true, the receiver's checksum field is assumed to hold a + // fully calculated checksum. Otherwise, it is assumed to hold a partially + // calculated checksum which only reflects the pseudo header. + UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool) +} + // Network offers generic methods to query and/or update the fields of the // header of a network protocol buffer. type Network interface { @@ -90,3 +115,16 @@ type Network interface { // SetTOS sets the values of the "type of service" and "flow label" fields. SetTOS(t uint8, l uint32) } + +// ChecksummableNetwork is a Network that supports checksumming. +type ChecksummableNetwork interface { + Network + + // SetSourceAddressAndChecksum sets the source address and updates the + // checksum to reflect the new address. + SetSourceAddressWithChecksumUpdate(tcpip.Address) + + // SetDestinationAddressAndChecksum sets the destination address and + // updates the checksum to reflect the new address. + SetDestinationAddressWithChecksumUpdate(tcpip.Address) +} diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index 2be21ec75..dcc549c7b 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -17,6 +17,7 @@ package header import ( "encoding/binary" "fmt" + "time" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -181,7 +182,7 @@ const ( // ipv4LinkLocalUnicastSubnet is the IPv4 link local unicast subnet as defined // by RFC 3927 section 1. var ipv4LinkLocalUnicastSubnet = func() tcpip.Subnet { - subnet, err := tcpip.NewSubnet("\xa9\xfe\x00\x00", tcpip.AddressMask("\xff\xff\x00\x00")) + subnet, err := tcpip.NewSubnet("\xa9\xfe\x00\x00", "\xff\xff\x00\x00") if err != nil { panic(err) } @@ -191,7 +192,7 @@ var ipv4LinkLocalUnicastSubnet = func() tcpip.Subnet { // ipv4LinkLocalMulticastSubnet is the IPv4 link local multicast subnet as // defined by RFC 5771 section 4. var ipv4LinkLocalMulticastSubnet = func() tcpip.Subnet { - subnet, err := tcpip.NewSubnet("\xe0\x00\x00\x00", tcpip.AddressMask("\xff\xff\xff\x00")) + subnet, err := tcpip.NewSubnet("\xe0\x00\x00\x00", "\xff\xff\xff\x00") if err != nil { panic(err) } @@ -304,6 +305,18 @@ func (b IPv4) DestinationAddress() tcpip.Address { return tcpip.Address(b[dstAddr : dstAddr+IPv4AddressSize]) } +// SetSourceAddressWithChecksumUpdate implements ChecksummableNetwork. +func (b IPv4) SetSourceAddressWithChecksumUpdate(new tcpip.Address) { + b.SetChecksum(^checksumUpdate2ByteAlignedAddress(^b.Checksum(), b.SourceAddress(), new)) + b.SetSourceAddress(new) +} + +// SetDestinationAddressWithChecksumUpdate implements ChecksummableNetwork. +func (b IPv4) SetDestinationAddressWithChecksumUpdate(new tcpip.Address) { + b.SetChecksum(^checksumUpdate2ByteAlignedAddress(^b.Checksum(), b.DestinationAddress(), new)) + b.SetDestinationAddress(new) +} + // padIPv4OptionsLength returns the total length for IPv4 options of length l // after applying padding according to RFC 791: // The internet header padding is used to ensure that the internet @@ -572,7 +585,7 @@ func (o *IPv4OptionGeneric) Type() IPv4OptionType { func (o *IPv4OptionGeneric) Size() uint8 { return uint8(len(*o)) } // Contents implements IPv4Option. -func (o *IPv4OptionGeneric) Contents() []byte { return []byte(*o) } +func (o *IPv4OptionGeneric) Contents() []byte { return *o } // IPv4OptionIterator is an iterator pointing to a specific IP option // at any point of time. It also holds information as to a new options buffer @@ -610,7 +623,7 @@ func (i *IPv4OptionIterator) InitReplacement(option IPv4Option) IPv4Options { // RemainingBuffer returns the remaining (unused) part of the new option buffer, // into which a new option may be written. func (i *IPv4OptionIterator) RemainingBuffer() IPv4Options { - return IPv4Options(i.newOptions[i.writePoint:]) + return i.newOptions[i.writePoint:] } // ConsumeBuffer marks a portion of the new buffer as used. @@ -813,9 +826,12 @@ const ( // ipv4TimestampTime provides the current time as specified in RFC 791. func ipv4TimestampTime(clock tcpip.Clock) uint32 { - const millisecondsPerDay = 24 * 3600 * 1000 - const nanoPerMilli = 1000000 - return uint32((clock.NowNanoseconds() / nanoPerMilli) % millisecondsPerDay) + // Per RFC 791 page 21: + // The Timestamp is a right-justified, 32-bit timestamp in + // milliseconds since midnight UT. + now := clock.Now().UTC() + midnight := now.Truncate(24 * time.Hour) + return uint32(now.Sub(midnight).Milliseconds()) } // IP Timestamp option fields. @@ -843,7 +859,7 @@ func (ts *IPv4OptionTimestamp) Type() IPv4OptionType { return IPv4OptionTimestam func (ts *IPv4OptionTimestamp) Size() uint8 { return uint8(len(*ts)) } // Contents implements IPv4Option. -func (ts *IPv4OptionTimestamp) Contents() []byte { return []byte(*ts) } +func (ts *IPv4OptionTimestamp) Contents() []byte { return *ts } // Pointer returns the pointer field in the IP Timestamp option. func (ts *IPv4OptionTimestamp) Pointer() uint8 { @@ -947,7 +963,7 @@ func (rr *IPv4OptionRecordRoute) Type() IPv4OptionType { return IPv4OptionRecord func (rr *IPv4OptionRecordRoute) Size() uint8 { return uint8(len(*rr)) } // Contents implements IPv4Option. -func (rr *IPv4OptionRecordRoute) Contents() []byte { return []byte(*rr) } +func (rr *IPv4OptionRecordRoute) Contents() []byte { return *rr } // Router Alert option specific related constants. // @@ -992,7 +1008,7 @@ func (*IPv4OptionRouterAlert) Type() IPv4OptionType { return IPv4OptionRouterAle func (ra *IPv4OptionRouterAlert) Size() uint8 { return uint8(len(*ra)) } // Contents implements IPv4Option. -func (ra *IPv4OptionRouterAlert) Contents() []byte { return []byte(*ra) } +func (ra *IPv4OptionRouterAlert) Contents() []byte { return *ra } // Value returns the value of the IPv4OptionRouterAlert. func (ra *IPv4OptionRouterAlert) Value() uint16 { diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go index 3d1bccd15..b1f39e6e6 100644 --- a/pkg/tcpip/header/ndp_options.go +++ b/pkg/tcpip/header/ndp_options.go @@ -77,12 +77,12 @@ const ( // ndpPrefixInformationOnLinkFlagMask is the mask of the On-Link Flag // field in the flags byte within an NDPPrefixInformation. - ndpPrefixInformationOnLinkFlagMask = (1 << 7) + ndpPrefixInformationOnLinkFlagMask = 1 << 7 // ndpPrefixInformationAutoAddrConfFlagMask is the mask of the // Autonomous Address-Configuration flag field in the flags byte within // an NDPPrefixInformation. - ndpPrefixInformationAutoAddrConfFlagMask = (1 << 6) + ndpPrefixInformationAutoAddrConfFlagMask = 1 << 6 // ndpPrefixInformationReserved1FlagsMask is the mask of the Reserved1 // field in the flags byte within an NDPPrefixInformation. @@ -148,15 +148,10 @@ const ( // NDP option. That is, the length field for NDP options is in units of // 8 octets, as per RFC 4861 section 4.6. lengthByteUnits = 8 -) -var ( // NDPInfiniteLifetime is a value that represents infinity for the // 4-byte lifetime fields found in various NDP options. Its value is // (2^32 - 1)s = 4294967295s. - // - // This is a variable instead of a constant so that tests can change - // this value to a smaller value. It should only be modified by tests. NDPInfiniteLifetime = time.Second * math.MaxUint32 ) @@ -451,7 +446,7 @@ func (o NDPNonceOption) String() string { // Nonce returns the nonce value this option holds. func (o NDPNonceOption) Nonce() []byte { - return []byte(o) + return o } // NDPSourceLinkLayerAddressOption is the NDP Source Link Layer Option diff --git a/pkg/tcpip/header/ndp_router_advert.go b/pkg/tcpip/header/ndp_router_advert.go index bf7610863..7e2f0c797 100644 --- a/pkg/tcpip/header/ndp_router_advert.go +++ b/pkg/tcpip/header/ndp_router_advert.go @@ -19,12 +19,72 @@ import ( "time" ) +// NDPRoutePreference is the preference values for default routers or +// more-specific routes. +// +// As per RFC 4191 section 2.1, +// +// Default router preferences and preferences for more-specific routes +// are encoded the same way. +// +// Preference values are encoded as a two-bit signed integer, as +// follows: +// +// 01 High +// 00 Medium (default) +// 11 Low +// 10 Reserved - MUST NOT be sent +// +// Note that implementations can treat the value as a two-bit signed +// integer. +// +// Having just three values reinforces that they are not metrics and +// more values do not appear to be necessary for reasonable scenarios. +type NDPRoutePreference uint8 + +const ( + // HighRoutePreference indicates a high preference, as per + // RFC 4191 section 2.1. + HighRoutePreference NDPRoutePreference = 0b01 + + // MediumRoutePreference indicates a medium preference, as per + // RFC 4191 section 2.1. + // + // This is the default preference value. + MediumRoutePreference = 0b00 + + // LowRoutePreference indicates a low preference, as per + // RFC 4191 section 2.1. + LowRoutePreference = 0b11 + + // ReservedRoutePreference is a reserved preference value, as per + // RFC 4191 section 2.1. + // + // It MUST NOT be sent. + ReservedRoutePreference = 0b10 +) + // NDPRouterAdvert is an NDP Router Advertisement message. It will only contain // the body of an ICMPv6 packet. // -// See RFC 4861 section 4.2 for more details. +// See RFC 4861 section 4.2 and RFC 4191 section 2.2 for more details. type NDPRouterAdvert []byte +// As per RFC 4191 section 2.2, +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type | Code | Checksum | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Cur Hop Limit |M|O|H|Prf|Resvd| Router Lifetime | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Reachable Time | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Retrans Timer | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Options ... +// +-+-+-+-+-+-+-+-+-+-+-+- const ( // NDPRAMinimumSize is the minimum size of a valid NDP Router // Advertisement message (body of an ICMPv6 packet). @@ -47,6 +107,14 @@ const ( // within the bit-field/flags byte of an NDPRouterAdvert. ndpRAOtherConfFlagMask = (1 << 6) + // ndpDefaultRouterPreferenceShift is the shift of the Prf (Default Router + // Preference) field within the flags byte of an NDPRouterAdvert. + ndpDefaultRouterPreferenceShift = 3 + + // ndpDefaultRouterPreferenceMask is the mask of the Prf (Default Router + // Preference) field within the flags byte of an NDPRouterAdvert. + ndpDefaultRouterPreferenceMask = (0b11 << ndpDefaultRouterPreferenceShift) + // ndpRARouterLifetimeOffset is the start of the 2-byte Router Lifetime // field within an NDPRouterAdvert. ndpRARouterLifetimeOffset = 2 @@ -80,6 +148,11 @@ func (b NDPRouterAdvert) OtherConfFlag() bool { return b[ndpRAFlagsOffset]&ndpRAOtherConfFlagMask != 0 } +// DefaultRouterPreference returns the Default Router Preference field. +func (b NDPRouterAdvert) DefaultRouterPreference() NDPRoutePreference { + return NDPRoutePreference((b[ndpRAFlagsOffset] & ndpDefaultRouterPreferenceMask) >> ndpDefaultRouterPreferenceShift) +} + // RouterLifetime returns the lifetime associated with the default router. A // value of 0 means the source of the Router Advertisement is not a default // router and SHOULD NOT appear on the default router list. Note, a value of 0 diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go index 1b5093e58..8fd1f7d13 100644 --- a/pkg/tcpip/header/ndp_test.go +++ b/pkg/tcpip/header/ndp_test.go @@ -126,36 +126,83 @@ func TestNDPNeighborAdvert(t *testing.T) { } func TestNDPRouterAdvert(t *testing.T) { - b := []byte{ - 64, 128, 1, 2, - 3, 4, 5, 6, - 7, 8, 9, 10, + tests := []struct { + hopLimit uint8 + managedFlag, otherConfFlag bool + prf NDPRoutePreference + routerLifetimeS uint16 + reachableTimeMS, retransTimerMS uint32 + }{ + { + hopLimit: 1, + managedFlag: false, + otherConfFlag: true, + prf: HighRoutePreference, + routerLifetimeS: 2, + reachableTimeMS: 3, + retransTimerMS: 4, + }, + { + hopLimit: 64, + managedFlag: true, + otherConfFlag: false, + prf: LowRoutePreference, + routerLifetimeS: 258, + reachableTimeMS: 78492, + retransTimerMS: 13213, + }, } - ra := NDPRouterAdvert(b) + for i, test := range tests { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + flags := uint8(0) + if test.managedFlag { + flags |= 1 << 7 + } + if test.otherConfFlag { + flags |= 1 << 6 + } + flags |= uint8(test.prf) << 3 - if got := ra.CurrHopLimit(); got != 64 { - t.Errorf("got ra.CurrHopLimit = %d, want = 64", got) - } + b := []byte{ + test.hopLimit, flags, 1, 2, + 3, 4, 5, 6, + 7, 8, 9, 10, + } + binary.BigEndian.PutUint16(b[2:], test.routerLifetimeS) + binary.BigEndian.PutUint32(b[4:], test.reachableTimeMS) + binary.BigEndian.PutUint32(b[8:], test.retransTimerMS) - if got := ra.ManagedAddrConfFlag(); !got { - t.Errorf("got ManagedAddrConfFlag = false, want = true") - } + ra := NDPRouterAdvert(b) - if got := ra.OtherConfFlag(); got { - t.Errorf("got OtherConfFlag = true, want = false") - } + if got := ra.CurrHopLimit(); got != test.hopLimit { + t.Errorf("got ra.CurrHopLimit() = %d, want = %d", got, test.hopLimit) + } - if got, want := ra.RouterLifetime(), time.Second*258; got != want { - t.Errorf("got ra.RouterLifetime = %d, want = %d", got, want) - } + if got := ra.ManagedAddrConfFlag(); got != test.managedFlag { + t.Errorf("got ManagedAddrConfFlag() = %t, want = %t", got, test.managedFlag) + } - if got, want := ra.ReachableTime(), time.Millisecond*50595078; got != want { - t.Errorf("got ra.ReachableTime = %d, want = %d", got, want) - } + if got := ra.OtherConfFlag(); got != test.otherConfFlag { + t.Errorf("got OtherConfFlag() = %t, want = %t", got, test.otherConfFlag) + } + + if got := ra.DefaultRouterPreference(); got != test.prf { + t.Errorf("got DefaultRouterPreference() = %d, want = %d", got, test.prf) + } - if got, want := ra.RetransTimer(), time.Millisecond*117967114; got != want { - t.Errorf("got ra.RetransTimer = %d, want = %d", got, want) + if got, want := ra.RouterLifetime(), time.Second*time.Duration(test.routerLifetimeS); got != want { + t.Errorf("got ra.RouterLifetime() = %d, want = %d", got, want) + } + + if got, want := ra.ReachableTime(), time.Millisecond*time.Duration(test.reachableTimeMS); got != want { + t.Errorf("got ra.ReachableTime() = %d, want = %d", got, want) + } + + if got, want := ra.RetransTimer(), time.Millisecond*time.Duration(test.retransTimerMS); got != want { + t.Errorf("got ra.RetransTimer() = %d, want = %d", got, want) + } + }) } } diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go index 0df517000..a75e51a28 100644 --- a/pkg/tcpip/header/tcp.go +++ b/pkg/tcpip/header/tcp.go @@ -48,6 +48,16 @@ const ( // TCPFlags is the dedicated type for TCP flags. type TCPFlags uint8 +// Intersects returns true iff there are flags common to both f and o. +func (f TCPFlags) Intersects(o TCPFlags) bool { + return f&o != 0 +} + +// Contains returns true iff all the flags in o are contained within f. +func (f TCPFlags) Contains(o TCPFlags) bool { + return f&o == o +} + // String implements Stringer.String. func (f TCPFlags) String() string { flagsStr := []byte("FSRPAU") @@ -380,6 +390,35 @@ func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32 b.SetChecksum(^checksum) } +// SetSourcePortWithChecksumUpdate implements ChecksummableTransport. +func (b TCP) SetSourcePortWithChecksumUpdate(new uint16) { + old := b.SourcePort() + b.SetSourcePort(new) + b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new)) +} + +// SetDestinationPortWithChecksumUpdate implements ChecksummableTransport. +func (b TCP) SetDestinationPortWithChecksumUpdate(new uint16) { + old := b.DestinationPort() + b.SetDestinationPort(new) + b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new)) +} + +// UpdateChecksumPseudoHeaderAddress implements ChecksummableTransport. +func (b TCP) UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool) { + xsum := b.Checksum() + if fullChecksum { + xsum = ^xsum + } + + xsum = checksumUpdate2ByteAlignedAddress(xsum, old, new) + if fullChecksum { + xsum = ^xsum + } + + b.SetChecksum(xsum) +} + // ParseSynOptions parses the options received in a SYN segment and returns the // relevant ones. opts should point to the option part of the TCP header. func ParseSynOptions(opts []byte, isAck bool) TCPSynOptions { diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go index ae9d167ff..f69d53314 100644 --- a/pkg/tcpip/header/udp.go +++ b/pkg/tcpip/header/udp.go @@ -130,3 +130,32 @@ func (b UDP) Encode(u *UDPFields) { binary.BigEndian.PutUint16(b[udpLength:], u.Length) binary.BigEndian.PutUint16(b[udpChecksum:], u.Checksum) } + +// SetSourcePortWithChecksumUpdate implements ChecksummableTransport. +func (b UDP) SetSourcePortWithChecksumUpdate(new uint16) { + old := b.SourcePort() + b.SetSourcePort(new) + b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new)) +} + +// SetDestinationPortWithChecksumUpdate implements ChecksummableTransport. +func (b UDP) SetDestinationPortWithChecksumUpdate(new uint16) { + old := b.DestinationPort() + b.SetDestinationPort(new) + b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new)) +} + +// UpdateChecksumPseudoHeaderAddress implements ChecksummableTransport. +func (b UDP) UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool) { + xsum := b.Checksum() + if fullChecksum { + xsum = ^xsum + } + + xsum = checksumUpdate2ByteAlignedAddress(xsum, old, new) + if fullChecksum { + xsum = ^xsum + } + + b.SetChecksum(xsum) +} diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index ef9126deb..f26c857eb 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -288,5 +288,5 @@ func (*Endpoint) ARPHardwareType() header.ARPHardwareType { } // AddHeader implements stack.LinkEndpoint.AddHeader. -func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { +func (*Endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) { } diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD index d971194e6..1d0163823 100644 --- a/pkg/tcpip/link/fdbased/BUILD +++ b/pkg/tcpip/link/fdbased/BUILD @@ -14,7 +14,6 @@ go_library( ], visibility = ["//visibility:public"], deps = [ - "//pkg/iovec", "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index bddb1d0a2..1b56d2b72 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -41,11 +41,9 @@ package fdbased import ( "fmt" - "math" "sync/atomic" "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/iovec" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -139,6 +137,20 @@ type endpoint struct { // gsoKind is the supported kind of GSO. gsoKind stack.SupportedGSO + + // maxSyscallHeaderBytes has the same meaning as + // Options.MaxSyscallHeaderBytes. + maxSyscallHeaderBytes uintptr + + // writevMaxIovs is the maximum number of iovecs that may be passed to + // rawfile.NonBlockingWriteIovec, as possibly limited by + // maxSyscallHeaderBytes. (No analogous limit is defined for + // rawfile.NonBlockingSendMMsg, since in that case the maximum number of + // iovecs also depends on the number of mmsghdrs. Instead, if sendBatch + // encounters a packet whose iovec count is limited by + // maxSyscallHeaderBytes, it falls back to writing the packet using writev + // via WritePacket.) + writevMaxIovs int } // Options specify the details about the fd-based endpoint to be created. @@ -187,6 +199,11 @@ type Options struct { // RXChecksumOffload if true, indicates that this endpoints capability // set should include CapabilityRXChecksumOffload. RXChecksumOffload bool + + // If MaxSyscallHeaderBytes is non-zero, it is the maximum number of bytes + // of struct iovec, msghdr, and mmsghdr that may be passed by each host + // system call. + MaxSyscallHeaderBytes int } // fanoutID is used for AF_PACKET based endpoints to enable PACKET_FANOUT @@ -196,8 +213,12 @@ type Options struct { // option for an FD with a fanoutID already in use by another FD for a different // NIC will return an EINVAL. // +// Since fanoutID must be unique within the network namespace, we start with +// the PID to avoid collisions. The only way to be sure of avoiding collisions +// is to run in a new network namespace. +// // Must be accessed using atomic operations. -var fanoutID int32 = 0 +var fanoutID int32 = int32(unix.Getpid()) // New creates a new fd-based endpoint. // @@ -232,14 +253,25 @@ func New(opts *Options) (stack.LinkEndpoint, error) { return nil, fmt.Errorf("opts.FD is empty, at least one FD must be specified") } + if opts.MaxSyscallHeaderBytes < 0 { + return nil, fmt.Errorf("opts.MaxSyscallHeaderBytes is negative") + } + e := &endpoint{ - fds: opts.FDs, - mtu: opts.MTU, - caps: caps, - closed: opts.ClosedFunc, - addr: opts.Address, - hdrSize: hdrSize, - packetDispatchMode: opts.PacketDispatchMode, + fds: opts.FDs, + mtu: opts.MTU, + caps: caps, + closed: opts.ClosedFunc, + addr: opts.Address, + hdrSize: hdrSize, + packetDispatchMode: opts.PacketDispatchMode, + maxSyscallHeaderBytes: uintptr(opts.MaxSyscallHeaderBytes), + writevMaxIovs: rawfile.MaxIovs, + } + if e.maxSyscallHeaderBytes != 0 { + if max := int(e.maxSyscallHeaderBytes / rawfile.SizeofIovec); max < e.writevMaxIovs { + e.writevMaxIovs = max + } } // Increment fanoutID to ensure that we don't re-use the same fanoutID for @@ -292,11 +324,6 @@ func createInboundDispatcher(e *endpoint, fd int, isSocket bool, fID int32) (lin } switch sa.(type) { case *unix.SockaddrLinklayer: - // See: PACKET_FANOUT_MAX in net/packet/internal.h - const packetFanoutMax = 1 << 16 - if fID > packetFanoutMax { - return nil, fmt.Errorf("host fanoutID limit exceeded, fanoutID must be <= %d", math.MaxUint16) - } // Enable PACKET_FANOUT mode if the underlying socket is of type // AF_PACKET. We do not enable PACKET_FANOUT_FLAG_DEFRAG as that will // prevent gvisor from receiving fragmented packets and the host does the @@ -317,7 +344,7 @@ func createInboundDispatcher(e *endpoint, fd int, isSocket bool, fID int32) (lin // // See: https://github.com/torvalds/linux/blob/7acac4b3196caee5e21fb5ea53f8bc124e6a16fc/net/packet/af_packet.c#L3881 const fanoutType = unix.PACKET_FANOUT_HASH - fanoutArg := int(fID) | fanoutType<<16 + fanoutArg := (int(fID) & 0xffff) | fanoutType<<16 if err := unix.SetsockoptInt(fd, unix.SOL_PACKET, unix.PACKET_FANOUT, fanoutArg); err != nil { return nil, fmt.Errorf("failed to enable PACKET_FANOUT option: %v", err) } @@ -472,9 +499,8 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocol e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) } - var builder iovec.Builder - fd := e.fds[pkt.Hash%uint32(len(e.fds))] + var vnetHdrBuf []byte if e.gsoKind == stack.HWGSOSupported { vnetHdr := virtioNetHdr{} if pkt.GSOOptions.Type != stack.GSONone { @@ -496,71 +522,123 @@ func (e *endpoint) WritePacket(r stack.RouteInfo, protocol tcpip.NetworkProtocol vnetHdr.gsoSize = pkt.GSOOptions.MSS } } + vnetHdrBuf = vnetHdr.marshal() + } - vnetHdrBuf := vnetHdr.marshal() - builder.Add(vnetHdrBuf) + views := pkt.Views() + numIovecs := len(views) + if len(vnetHdrBuf) != 0 { + numIovecs++ + } + if numIovecs > e.writevMaxIovs { + numIovecs = e.writevMaxIovs } - for _, v := range pkt.Views() { - builder.Add(v) + // Allocate small iovec arrays on the stack. + var iovecsArr [8]unix.Iovec + iovecs := iovecsArr[:0] + if numIovecs > len(iovecsArr) { + iovecs = make([]unix.Iovec, 0, numIovecs) } - return rawfile.NonBlockingWriteIovec(fd, builder.Build()) + iovecs = rawfile.AppendIovecFromBytes(iovecs, vnetHdrBuf, numIovecs) + for _, v := range views { + iovecs = rawfile.AppendIovecFromBytes(iovecs, v, numIovecs) + } + return rawfile.NonBlockingWriteIovec(fd, iovecs) } -func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, tcpip.Error) { +func (e *endpoint) sendBatch(batchFD int, pkts []*stack.PacketBuffer) (int, tcpip.Error) { // Send a batch of packets through batchFD. - mmsgHdrs := make([]rawfile.MMsgHdr, 0, len(batch)) - for _, pkt := range batch { - if e.hdrSize > 0 { - e.AddHeader(pkt.EgressRoute.LocalLinkAddress, pkt.EgressRoute.RemoteLinkAddress, pkt.NetworkProtocolNumber, pkt) - } + mmsgHdrsStorage := make([]rawfile.MMsgHdr, 0, len(pkts)) + packets := 0 + for packets < len(pkts) { + mmsgHdrs := mmsgHdrsStorage + batch := pkts[packets:] + syscallHeaderBytes := uintptr(0) + for _, pkt := range batch { + if e.hdrSize > 0 { + e.AddHeader(pkt.EgressRoute.LocalLinkAddress, pkt.EgressRoute.RemoteLinkAddress, pkt.NetworkProtocolNumber, pkt) + } - var vnetHdrBuf []byte - if e.gsoKind == stack.HWGSOSupported { - vnetHdr := virtioNetHdr{} - if pkt.GSOOptions.Type != stack.GSONone { - vnetHdr.hdrLen = uint16(pkt.HeaderSize()) - if pkt.GSOOptions.NeedsCsum { - vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM - vnetHdr.csumStart = header.EthernetMinimumSize + pkt.GSOOptions.L3HdrLen - vnetHdr.csumOffset = pkt.GSOOptions.CsumOffset - } - if pkt.GSOOptions.Type != stack.GSONone && uint16(pkt.Data().Size()) > pkt.GSOOptions.MSS { - switch pkt.GSOOptions.Type { - case stack.GSOTCPv4: - vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4 - case stack.GSOTCPv6: - vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV6 - default: - panic(fmt.Sprintf("Unknown gso type: %v", pkt.GSOOptions.Type)) + var vnetHdrBuf []byte + if e.gsoKind == stack.HWGSOSupported { + vnetHdr := virtioNetHdr{} + if pkt.GSOOptions.Type != stack.GSONone { + vnetHdr.hdrLen = uint16(pkt.HeaderSize()) + if pkt.GSOOptions.NeedsCsum { + vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM + vnetHdr.csumStart = header.EthernetMinimumSize + pkt.GSOOptions.L3HdrLen + vnetHdr.csumOffset = pkt.GSOOptions.CsumOffset + } + if pkt.GSOOptions.Type != stack.GSONone && uint16(pkt.Data().Size()) > pkt.GSOOptions.MSS { + switch pkt.GSOOptions.Type { + case stack.GSOTCPv4: + vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4 + case stack.GSOTCPv6: + vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV6 + default: + panic(fmt.Sprintf("Unknown gso type: %v", pkt.GSOOptions.Type)) + } + vnetHdr.gsoSize = pkt.GSOOptions.MSS } - vnetHdr.gsoSize = pkt.GSOOptions.MSS } + vnetHdrBuf = vnetHdr.marshal() } - vnetHdrBuf = vnetHdr.marshal() - } - var builder iovec.Builder - builder.Add(vnetHdrBuf) - for _, v := range pkt.Views() { - builder.Add(v) - } - iovecs := builder.Build() + views := pkt.Views() + numIovecs := len(views) + if len(vnetHdrBuf) != 0 { + numIovecs++ + } + if numIovecs > rawfile.MaxIovs { + numIovecs = rawfile.MaxIovs + } + if e.maxSyscallHeaderBytes != 0 { + syscallHeaderBytes += rawfile.SizeofMMsgHdr + uintptr(numIovecs)*rawfile.SizeofIovec + if syscallHeaderBytes > e.maxSyscallHeaderBytes { + // We can't fit this packet into this call to sendmmsg(). + // We could potentially do so if we reduced numIovecs + // further, but this might incur considerable extra + // copying. Leave it to the next batch instead. + break + } + } - var mmsgHdr rawfile.MMsgHdr - mmsgHdr.Msg.Iov = &iovecs[0] - mmsgHdr.Msg.SetIovlen((len(iovecs))) - mmsgHdrs = append(mmsgHdrs, mmsgHdr) - } + // We can't easily allocate iovec arrays on the stack here since + // they will escape this loop iteration via mmsgHdrs. + iovecs := make([]unix.Iovec, 0, numIovecs) + iovecs = rawfile.AppendIovecFromBytes(iovecs, vnetHdrBuf, numIovecs) + for _, v := range views { + iovecs = rawfile.AppendIovecFromBytes(iovecs, v, numIovecs) + } - packets := 0 - for len(mmsgHdrs) > 0 { - sent, err := rawfile.NonBlockingSendMMsg(batchFD, mmsgHdrs) - if err != nil { - return packets, err + var mmsgHdr rawfile.MMsgHdr + mmsgHdr.Msg.Iov = &iovecs[0] + mmsgHdr.Msg.SetIovlen(len(iovecs)) + mmsgHdrs = append(mmsgHdrs, mmsgHdr) + } + + if len(mmsgHdrs) == 0 { + // We can't fit batch[0] into a mmsghdr while staying under + // e.maxSyscallHeaderBytes. Use WritePacket, which will avoid the + // mmsghdr (by using writev) and re-buffer iovecs more aggressively + // if necessary (by using e.writevMaxIovs instead of + // rawfile.MaxIovs). + pkt := batch[0] + if err := e.WritePacket(pkt.EgressRoute, pkt.NetworkProtocolNumber, pkt); err != nil { + return packets, err + } + packets++ + } else { + for len(mmsgHdrs) > 0 { + sent, err := rawfile.NonBlockingSendMMsg(batchFD, mmsgHdrs) + if err != nil { + return packets, err + } + packets += sent + mmsgHdrs = mmsgHdrs[sent:] + } } - packets += sent - mmsgHdrs = mmsgHdrs[sent:] } return packets, nil @@ -678,8 +756,9 @@ func NewInjectable(fd int, mtu uint32, capabilities stack.LinkEndpointCapabiliti unix.SetNonblock(fd, true) return &InjectableEndpoint{endpoint: endpoint{ - fds: []int{fd}, - mtu: mtu, - caps: capabilities, + fds: []int{fd}, + mtu: mtu, + caps: capabilities, + writevMaxIovs: rawfile.MaxIovs, }} } diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go index ba92aedbc..43fe57830 100644 --- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go +++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go @@ -19,12 +19,66 @@ package rawfile import ( + "reflect" "unsafe" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/tcpip" ) +// SizeofIovec is the size of a unix.Iovec in bytes. +const SizeofIovec = unsafe.Sizeof(unix.Iovec{}) + +// MaxIovs is UIO_MAXIOV, the maximum number of iovecs that may be passed to a +// host system call in a single array. +const MaxIovs = 1024 + +// IovecFromBytes returns a unix.Iovec representing bs. +// +// Preconditions: len(bs) > 0. +func IovecFromBytes(bs []byte) unix.Iovec { + iov := unix.Iovec{ + Base: &bs[0], + } + iov.SetLen(len(bs)) + return iov +} + +func bytesFromIovec(iov unix.Iovec) (bs []byte) { + sh := (*reflect.SliceHeader)(unsafe.Pointer(&bs)) + sh.Data = uintptr(unsafe.Pointer(iov.Base)) + sh.Len = int(iov.Len) + sh.Cap = int(iov.Len) + return +} + +// AppendIovecFromBytes returns append(iovs, IovecFromBytes(bs)). If len(bs) == +// 0, AppendIovecFromBytes returns iovs without modification. If len(iovs) >= +// max, AppendIovecFromBytes replaces the final iovec in iovs with one that +// also includes the contents of bs. Note that this implies that +// AppendIovecFromBytes is only usable when the returned iovec slice is used as +// the source of a write. +func AppendIovecFromBytes(iovs []unix.Iovec, bs []byte, max int) []unix.Iovec { + if len(bs) == 0 { + return iovs + } + if len(iovs) < max { + return append(iovs, IovecFromBytes(bs)) + } + iovs[len(iovs)-1] = IovecFromBytes(append(bytesFromIovec(iovs[len(iovs)-1]), bs...)) + return iovs +} + +// MMsgHdr represents the mmsg_hdr structure required by recvmmsg() on linux. +type MMsgHdr struct { + Msg unix.Msghdr + Len uint32 + _ [4]byte +} + +// SizeofMMsgHdr is the size of a MMsgHdr in bytes. +const SizeofMMsgHdr = unsafe.Sizeof(MMsgHdr{}) + // GetMTU determines the MTU of a network interface device. func GetMTU(name string) (uint32, error) { fd, err := unix.Socket(unix.AF_UNIX, unix.SOCK_DGRAM, 0) @@ -137,13 +191,6 @@ func BlockingReadv(fd int, iovecs []unix.Iovec) (int, tcpip.Error) { } } -// MMsgHdr represents the mmsg_hdr structure required by recvmmsg() on linux. -type MMsgHdr struct { - Msg unix.Msghdr - Len uint32 - _ [4]byte -} - // BlockingRecvMMsg reads from a file descriptor that is set up as non-blocking // and stores the received messages in a slice of MMsgHdr structures. If no data // is available, it will block in a poll() syscall until the file descriptor diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD index 7656cca6a..4758a99ad 100644 --- a/pkg/tcpip/link/tun/BUILD +++ b/pkg/tcpip/link/tun/BUILD @@ -26,6 +26,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/errors/linuxerr", "//pkg/log", "//pkg/refs", "//pkg/refsvfs2", diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index af9a3d759..59ac5cb8c 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -18,6 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" @@ -88,12 +89,12 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags Flags) error { defer d.mu.Unlock() if d.endpoint != nil { - return syserror.EINVAL + return linuxerr.EINVAL } // Input validation. if flags.TAP && flags.TUN || !flags.TAP && !flags.TUN { - return syserror.EINVAL + return linuxerr.EINVAL } prefix := "tun" @@ -108,7 +109,7 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags Flags) error { endpoint, err := attachOrCreateNIC(s, name, prefix, linkCaps) if err != nil { - return syserror.EINVAL + return linuxerr.EINVAL } d.endpoint = endpoint @@ -159,7 +160,7 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE // Race detected: A NIC has been created in between. continue default: - return nil, syserror.EINVAL + return nil, linuxerr.EINVAL } } } diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index a72eb1aad..6fa1aee18 100644 --- a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -28,13 +28,13 @@ go_test( ":arp", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/sniffer", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/stack", "//pkg/tcpip/testutil", - "//pkg/tcpip/transport/icmp", "@com_github_google_go_cmp//cmp:go_default_library", "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", ], diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 0efa3a926..6515c31e5 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -278,7 +278,7 @@ func (*protocol) ParseAddresses(buffer.View) (src, dst tcpip.Address) { return "", "" } -func (p *protocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { +func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.TransportDispatcher) stack.NetworkEndpoint { e := &endpoint{ protocol: p, nic: nic, diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 94209b026..5fcbfeaa2 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -15,15 +15,14 @@ package arp_test import ( - "context" "fmt" "testing" - "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" @@ -31,7 +30,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/testutil" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" ) const ( @@ -39,15 +37,6 @@ const ( stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c") remoteLinkAddr = tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06") - - defaultChannelSize = 1 - defaultMTU = 65536 - - // eventChanSize defines the size of event channels used by the neighbor - // cache's event dispatcher. The size chosen here needs to be sufficient to - // queue all the events received during tests before consumption. - // If eventChanSize is too small, the tests may deadlock. - eventChanSize = 32 ) var ( @@ -123,24 +112,6 @@ func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.Neighbo d.C <- e } -func (d *arpDispatcher) waitForEvent(ctx context.Context, want eventInfo) error { - select { - case got := <-d.C: - if diff := cmp.Diff(want, got, cmp.AllowUnexported(got), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAtNanos")); diff != "" { - return fmt.Errorf("got invalid event (-want +got):\n%s", diff) - } - case <-ctx.Done(): - return fmt.Errorf("%s for %s", ctx.Err(), want) - } - return nil -} - -func (d *arpDispatcher) waitForEventWithTimeout(want eventInfo, timeout time.Duration) error { - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - return d.waitForEvent(ctx, want) -} - func (d *arpDispatcher) nextEvent() (eventInfo, bool) { select { case event := <-d.C: @@ -153,55 +124,45 @@ func (d *arpDispatcher) nextEvent() (eventInfo, bool) { type testContext struct { s *stack.Stack linkEP *channel.Endpoint - nudDisp *arpDispatcher + nudDisp arpDispatcher } -func newTestContext(t *testing.T) *testContext { - c := stack.DefaultNUDConfigurations() - // Transition from Reachable to Stale almost immediately to test if receiving - // probes refreshes positive reachability. - c.BaseReachableTime = time.Microsecond - - d := arpDispatcher{ - // Create an event channel large enough so the neighbor cache doesn't block - // while dispatching events. Blocking could interfere with the timing of - // NUD transitions. - C: make(chan eventInfo, eventChanSize), +func makeTestContext(t *testing.T, eventDepth int, packetDepth int) testContext { + t.Helper() + + tc := testContext{ + nudDisp: arpDispatcher{ + C: make(chan eventInfo, eventDepth), + }, } - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, - NUDConfigs: c, - NUDDisp: &d, + tc.s = stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol}, + NUDDisp: &tc.nudDisp, + Clock: &faketime.NullClock{}, }) - ep := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr) - ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - wep := stack.LinkEndpoint(ep) + tc.linkEP = channel.New(packetDepth, header.IPv4MinimumMTU, stackLinkAddr) + tc.linkEP.LinkEPCapabilities |= stack.CapabilityResolutionRequired + wep := stack.LinkEndpoint(tc.linkEP) if testing.Verbose() { - wep = sniffer.New(ep) + wep = sniffer.New(wep) } - if err := s.CreateNIC(nicID, wep); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + if err := tc.s.CreateNIC(nicID, wep); err != nil { + t.Fatalf("CreateNIC failed: %s", err) } - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { - t.Fatalf("AddAddress for ipv4 failed: %v", err) + if err := tc.s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { + t.Fatalf("AddAddress for ipv4 failed: %s", err) } - s.SetRouteTable([]tcpip.Route{{ + tc.s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv4EmptySubnet, NIC: nicID, }}) - return &testContext{ - s: s, - linkEP: ep, - nudDisp: &d, - } + return tc } func (c *testContext) cleanup() { @@ -209,7 +170,7 @@ func (c *testContext) cleanup() { } func TestMalformedPacket(t *testing.T) { - c := newTestContext(t) + c := makeTestContext(t, 0, 0) defer c.cleanup() v := make(buffer.View, header.ARPSize) @@ -228,7 +189,7 @@ func TestMalformedPacket(t *testing.T) { } func TestDisabledEndpoint(t *testing.T) { - c := newTestContext(t) + c := makeTestContext(t, 0, 0) defer c.cleanup() ep, err := c.s.GetNetworkEndpoint(nicID, header.ARPProtocolNumber) @@ -253,7 +214,7 @@ func TestDisabledEndpoint(t *testing.T) { } func TestDirectReply(t *testing.T) { - c := newTestContext(t) + c := makeTestContext(t, 0, 0) defer c.cleanup() const senderMAC = "\x01\x02\x03\x04\x05\x06" @@ -284,7 +245,7 @@ func TestDirectReply(t *testing.T) { } func TestDirectRequest(t *testing.T) { - c := newTestContext(t) + c := makeTestContext(t, 1, 1) defer c.cleanup() tests := []struct { @@ -391,17 +352,21 @@ func TestDirectRequest(t *testing.T) { } // Verify the sender was saved in the neighbor cache. - wantEvent := eventInfo{ - eventType: entryAdded, - nicID: nicID, - entry: stack.NeighborEntry{ - Addr: test.senderAddr, - LinkAddr: tcpip.LinkAddress(test.senderLinkAddr), - State: stack.Stale, - }, - } - if err := c.nudDisp.waitForEventWithTimeout(wantEvent, time.Second); err != nil { - t.Fatal(err) + if got, ok := c.nudDisp.nextEvent(); ok { + want := eventInfo{ + eventType: entryAdded, + nicID: nicID, + entry: stack.NeighborEntry{ + Addr: test.senderAddr, + LinkAddr: test.senderLinkAddr, + State: stack.Stale, + }, + } + if diff := cmp.Diff(want, got, cmp.AllowUnexported(eventInfo{}), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAt")); diff != "" { + t.Errorf("got invalid event (-want +got):\n%s", diff) + } + } else { + t.Fatal("event didn't arrive") } neighbors, err := c.s.Neighbors(nicID, ipv4.ProtocolNumber) @@ -589,7 +554,7 @@ func TestLinkAddressRequest(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, }) - linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr) + linkEP := channel.New(1, header.IPv4MinimumMTU, stackLinkAddr) if err := s.CreateNIC(nicID, &testLinkEndpoint{LinkEndpoint: linkEP, writeErr: test.linkErr}); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } @@ -663,15 +628,16 @@ func TestLinkAddressRequest(t *testing.T) { } func TestDADARPRequestPacket(t *testing.T) { + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocolWithOptions(arp.Options{ DADConfigs: stack.DADConfigurations{ DupAddrDetectTransmits: 1, - RetransmitTimer: time.Second, }, }), ipv4.NewProtocol}, + Clock: clock, }) - e := channel.New(1, defaultMTU, stackLinkAddr) + e := channel.New(1, header.IPv4MinimumMTU, stackLinkAddr) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } @@ -682,7 +648,8 @@ func TestDADARPRequestPacket(t *testing.T) { t.Fatalf("got s.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", nicID, header.IPv4ProtocolNumber, remoteAddr, res, stack.DADStarting) } - pkt, ok := e.ReadContext(context.Background()) + clock.RunImmediatelyScheduledJobs() + pkt, ok := e.Read() if !ok { t.Fatal("expected to send an ARP request") } diff --git a/pkg/tcpip/network/internal/fragmentation/fragmentation.go b/pkg/tcpip/network/internal/fragmentation/fragmentation.go index 5168f5361..1ba4d0d36 100644 --- a/pkg/tcpip/network/internal/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/internal/fragmentation/fragmentation.go @@ -251,7 +251,7 @@ func (f *Fragmentation) releaseReassemblersLocked() { // The list is empty. break } - elapsed := time.Duration(now-r.creationTime) * time.Nanosecond + elapsed := now.Sub(r.createdAt) if f.timeout > elapsed { // If the oldest reassembler has not expired, schedule the release // job so that this function is called back when it has expired. diff --git a/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go b/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go index 7daf64b4a..dadfc28cc 100644 --- a/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go +++ b/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go @@ -275,15 +275,23 @@ func TestMemoryLimits(t *testing.T) { highLimit := 3 * lowLimit // Allow at most 3 such packets. f := NewFragmentation(minBlockSize, highLimit, lowLimit, reassembleTimeout, &faketime.NullClock{}, nil) // Send first fragment with id = 0. - f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, pkt(1, "0")) + if _, _, _, err := f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, pkt(1, "0")); err != nil { + t.Fatal(err) + } // Send first fragment with id = 1. - f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, pkt(1, "1")) + if _, _, _, err := f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, pkt(1, "1")); err != nil { + t.Fatal(err) + } // Send first fragment with id = 2. - f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, pkt(1, "2")) + if _, _, _, err := f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, pkt(1, "2")); err != nil { + t.Fatal(err) + } // Send first fragment with id = 3. This should caused id = 0 and id = 1 to be // evicted. - f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, pkt(1, "3")) + if _, _, _, err := f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, pkt(1, "3")); err != nil { + t.Fatal(err) + } if _, ok := f.reassemblers[FragmentID{ID: 0}]; ok { t.Errorf("Memory limits are not respected: id=0 has not been evicted.") @@ -300,9 +308,13 @@ func TestMemoryLimitsIgnoresDuplicates(t *testing.T) { memSize := pkt(1, "0").MemSize() f := NewFragmentation(minBlockSize, memSize, 0, reassembleTimeout, &faketime.NullClock{}, nil) // Send first fragment with id = 0. - f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")) + if _, _, _, err := f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")); err != nil { + t.Fatal(err) + } // Send the same packet again. - f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")) + if _, _, _, err := f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")); err != nil { + t.Fatal(err) + } if got, want := f.memSize, memSize; got != want { t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want) diff --git a/pkg/tcpip/network/internal/fragmentation/reassembler.go b/pkg/tcpip/network/internal/fragmentation/reassembler.go index 56b76a284..5b7e4b361 100644 --- a/pkg/tcpip/network/internal/fragmentation/reassembler.go +++ b/pkg/tcpip/network/internal/fragmentation/reassembler.go @@ -35,21 +35,21 @@ type hole struct { type reassembler struct { reassemblerEntry - id FragmentID - memSize int - proto uint8 - mu sync.Mutex - holes []hole - filled int - done bool - creationTime int64 - pkt *stack.PacketBuffer + id FragmentID + memSize int + proto uint8 + mu sync.Mutex + holes []hole + filled int + done bool + createdAt tcpip.MonotonicTime + pkt *stack.PacketBuffer } func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler { r := &reassembler{ - id: id, - creationTime: clock.NowMonotonic(), + id: id, + createdAt: clock.NowMonotonic(), } r.holes = append(r.holes, hole{ first: 0, diff --git a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go b/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go index a22b712c6..24687cf06 100644 --- a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go +++ b/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go @@ -133,7 +133,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADDisabled) } // Wait for any initially fired timers to complete. - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() if diff := dad.check(nil); diff != "" { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } @@ -147,7 +147,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) { if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting) } - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() if diff := dad.check([]tcpip.Address{addr1}); diff != "" { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } @@ -156,7 +156,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) { if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADAlreadyRunning { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADAlreadyRunning) } - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() if diff := dad.check(nil); diff != "" { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } @@ -170,7 +170,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) { if res := dad.checkDuplicateAddress(addr2, handler(ch, addr2)); res != stack.DADStarting { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting) } - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() if diff := dad.check([]tcpip.Address{addr2}); diff != "" { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } @@ -208,7 +208,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) { if res := dad.checkDuplicateAddress(addr2, handler(ch, addr2)); res != stack.DADStarting { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting) } - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() if diff := dad.check([]tcpip.Address{addr2, addr2}); diff != "" { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } @@ -247,7 +247,7 @@ func TestDADStop(t *testing.T) { if res := dad.checkDuplicateAddress(addr3, handler(ch, addr3)); res != stack.DADStarting { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting) } - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() if diff := dad.check([]tcpip.Address{addr1, addr2, addr3}); diff != "" { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } @@ -272,7 +272,7 @@ func TestDADStop(t *testing.T) { if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting) } - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() if diff := dad.check([]tcpip.Address{addr1}); diff != "" { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } @@ -347,7 +347,7 @@ func TestNonce(t *testing.T) { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting) } - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() for i, want := range test.expectedResults { if got := dad.extendIfNonceEqual(addr1, test.mockedReceivedNonce); got != want { t.Errorf("(i=%d) got dad.extendIfNonceEqual(%s, _) = %d, want = %d", i, addr1, got, want) diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go index d22974b12..671dfbf32 100644 --- a/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go +++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go @@ -611,7 +611,7 @@ func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddr // If a timer for any address is already running, it is reset to the new // random value only if the requested Maximum Response Delay is less than // the remaining value of the running timer. - now := time.Unix(0 /* seconds */, g.opts.Clock.NowNanoseconds()) + now := g.opts.Clock.Now() if info.state == delayingMember { if info.delayedReportJobFiresAt.IsZero() { panic(fmt.Sprintf("delayed report unscheduled while in the delaying member state; group = %s", groupAddress)) diff --git a/pkg/tcpip/network/internal/ip/stats.go b/pkg/tcpip/network/internal/ip/stats.go index 0c2b62127..40ab21cb6 100644 --- a/pkg/tcpip/network/internal/ip/stats.go +++ b/pkg/tcpip/network/internal/ip/stats.go @@ -42,6 +42,10 @@ type MultiCounterIPForwardingStats struct { // were too big for the outgoing MTU. PacketTooBig tcpip.MultiCounterStat + // HostUnreachable is the number of IP packets received which could not be + // successfully forwarded due to an unresolvable next hop. + HostUnreachable tcpip.MultiCounterStat + // ExtensionHeaderProblem is the number of IP packets which were dropped // because of a problem encountered when processing an IPv6 extension // header. @@ -61,6 +65,7 @@ func (m *MultiCounterIPForwardingStats) Init(a, b *tcpip.IPForwardingStats) { m.ExtensionHeaderProblem.Init(a.ExtensionHeaderProblem, b.ExtensionHeaderProblem) m.PacketTooBig.Init(a.PacketTooBig, b.PacketTooBig) m.ExhaustedTTL.Init(a.ExhaustedTTL, b.ExhaustedTTL) + m.HostUnreachable.Init(a.HostUnreachable, b.HostUnreachable) } // LINT.ThenChange(:MultiCounterIPForwardingStats, ../../../tcpip.go:IPForwardingStats) diff --git a/pkg/tcpip/network/internal/testutil/BUILD b/pkg/tcpip/network/internal/testutil/BUILD index cec3e62c4..a180e5c75 100644 --- a/pkg/tcpip/network/internal/testutil/BUILD +++ b/pkg/tcpip/network/internal/testutil/BUILD @@ -10,6 +10,7 @@ go_library( "//pkg/tcpip/network/internal/fragmentation:__pkg__", "//pkg/tcpip/network/ipv4:__pkg__", "//pkg/tcpip/network/ipv6:__pkg__", + "//pkg/tcpip/tests/integration:__pkg__", ], deps = [ "//pkg/tcpip", diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index bd63e0289..771b9173a 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -88,6 +88,7 @@ type testObject struct { dataCalls int controlCalls int + rawCalls int } // checkValues verifies that the transport protocol, data contents, src & dst @@ -148,6 +149,10 @@ func (t *testObject) DeliverTransportError(local, remote tcpip.Address, net tcpi t.controlCalls++ } +func (t *testObject) DeliverRawPacket(tcpip.TransportProtocolNumber, *stack.PacketBuffer) { + t.rawCalls++ +} + // Attach is only implemented to satisfy the LinkEndpoint interface. func (*testObject) Attach(stack.NetworkDispatcher) {} @@ -717,7 +722,10 @@ func TestReceive(t *testing.T) { } test.handlePacket(t, ep, &nic) if nic.testObject.dataCalls != 1 { - t.Errorf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) + t.Errorf("Bad number of data calls: got %d, want 1", nic.testObject.dataCalls) + } + if nic.testObject.rawCalls != 1 { + t.Errorf("Bad number of raw calls: got %d, want 1", nic.testObject.rawCalls) } if got := stat.Value(); got != 1 { t.Errorf("got s.Stats().IP.PacketsReceived.Value() = %d, want = 1", got) @@ -968,7 +976,10 @@ func TestIPv4FragmentationReceive(t *testing.T) { ep.HandlePacket(pkt) if nic.testObject.dataCalls != 0 { - t.Fatalf("Bad number of data calls: got %x, want 0", nic.testObject.dataCalls) + t.Fatalf("Bad number of data calls: got %d, want 0", nic.testObject.dataCalls) + } + if nic.testObject.rawCalls != 0 { + t.Errorf("Bad number of raw calls: got %d, want 0", nic.testObject.rawCalls) } // Send second segment. @@ -977,7 +988,10 @@ func TestIPv4FragmentationReceive(t *testing.T) { }) ep.HandlePacket(pkt) if nic.testObject.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) + t.Fatalf("Bad number of data calls: got %d, want 1", nic.testObject.dataCalls) + } + if nic.testObject.rawCalls != 1 { + t.Errorf("Bad number of raw calls: got %d, want 1", nic.testObject.rawCalls) } } @@ -1310,7 +1324,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { Protocol: transportProto, TTL: ipv4.DefaultTTL, SrcAddr: src, - DstAddr: header.IPv4Any, + DstAddr: remoteIPv4Addr, }) return hdr.View().ToVectorisedView() }, @@ -1351,7 +1365,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { Protocol: transportProto, TTL: ipv4.DefaultTTL, SrcAddr: src, - DstAddr: header.IPv4Any, + DstAddr: remoteIPv4Addr, }) ip.SetHeaderLength(header.IPv4MinimumSize - 1) return hdr.View().ToVectorisedView() @@ -1370,7 +1384,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { Protocol: transportProto, TTL: ipv4.DefaultTTL, SrcAddr: src, - DstAddr: header.IPv4Any, + DstAddr: remoteIPv4Addr, }) return buffer.View(ip[:len(ip)-1]).ToVectorisedView() }, @@ -1388,7 +1402,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { Protocol: transportProto, TTL: ipv4.DefaultTTL, SrcAddr: src, - DstAddr: header.IPv4Any, + DstAddr: remoteIPv4Addr, }) return buffer.View(ip).ToVectorisedView() }, @@ -1430,7 +1444,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { Protocol: transportProto, TTL: ipv4.DefaultTTL, SrcAddr: src, - DstAddr: header.IPv4Any, + DstAddr: remoteIPv4Addr, Options: ipv4Options, }) return hdr.View().ToVectorisedView() @@ -1469,7 +1483,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { Protocol: transportProto, TTL: ipv4.DefaultTTL, SrcAddr: src, - DstAddr: header.IPv4Any, + DstAddr: remoteIPv4Addr, Options: ipv4Options, }) vv := buffer.View(ip).ToVectorisedView() @@ -1515,7 +1529,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { TransportProtocol: transportProto, HopLimit: ipv6.DefaultTTL, SrcAddr: src, - DstAddr: header.IPv4Any, + DstAddr: remoteIPv6Addr, }) return hdr.View().ToVectorisedView() }, @@ -1560,7 +1574,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { TransportProtocol: tcpip.TransportProtocolNumber(header.IPv6FragmentExtHdrIdentifier), HopLimit: ipv6.DefaultTTL, SrcAddr: src, - DstAddr: header.IPv4Any, + DstAddr: remoteIPv6Addr, }) return hdr.View().ToVectorisedView() }, @@ -1595,7 +1609,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { TransportProtocol: transportProto, HopLimit: ipv6.DefaultTTL, SrcAddr: src, - DstAddr: header.IPv4Any, + DstAddr: remoteIPv6Addr, }) return buffer.View(ip).ToVectorisedView() }, @@ -1630,7 +1644,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { TransportProtocol: transportProto, HopLimit: ipv6.DefaultTTL, SrcAddr: src, - DstAddr: header.IPv4Any, + DstAddr: remoteIPv4Addr, }) return buffer.View(ip[:len(ip)-1]).ToVectorisedView() }, diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index d1a82b584..2aa38eb98 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -173,9 +173,8 @@ func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.Packet func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { received := e.stats.icmp.packetsReceived - // TODO(gvisor.dev/issue/170): ICMP packets don't have their - // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a - // full explanation. + // ICMP packets don't have their TransportHeader fields set. See + // icmp/protocol.go:protocol.Parse for a full explanation. v, ok := pkt.Data().PullUp(header.ICMPv4MinimumSize) if !ok { received.invalid.Increment() @@ -222,7 +221,6 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { _ = e.protocol.returnError(&icmpReasonParamProblem{ pointer: optProblem.Pointer, }, pkt) - e.protocol.stack.Stats().MalformedRcvdPackets.Increment() e.stats.ip.MalformedPacketsReceived.Increment() } return @@ -481,6 +479,22 @@ func (*icmpReasonFragmentationNeeded) isForwarding() bool { return true } +// icmpReasonHostUnreachable is an error in which the host specified in the +// internet destination field of the datagram is unreachable. +type icmpReasonHostUnreachable struct{} + +func (*icmpReasonHostUnreachable) isICMPReason() {} +func (*icmpReasonHostUnreachable) isForwarding() bool { + // If we hit a Host Unreachable error, then we know we are operating as a + // router. As per RFC 792 page 5, Destination Unreachable Message, + // + // In addition, in some networks, the gateway may be able to determine + // if the internet destination host is unreachable. Gateways in these + // networks may send destination unreachable messages to the source host + // when the destination host is unreachable. + return true +} + // returnError takes an error descriptor and generates the appropriate ICMP // error packet for IPv4 and sends it back to the remote device that sent // the problematic packet. It incorporates as much of that packet as @@ -537,7 +551,12 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip defer route.Release() p.mu.Lock() - netEP, ok := p.mu.eps[pkt.NICID] + // We retrieve an endpoint using the newly constructed route's NICID rather + // than the packet's NICID. The packet's NICID corresponds to the NIC on + // which it arrived, which isn't necessarily the same as the NIC on which it + // will be transmitted. On the other hand, the route's NIC *is* guaranteed + // to be the NIC on which the packet will be transmitted. + netEP, ok := p.mu.eps[route.NICID()] p.mu.Unlock() if !ok { return &tcpip.ErrNotConnected{} @@ -653,6 +672,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip icmpHdr.SetType(header.ICMPv4DstUnreachable) icmpHdr.SetCode(header.ICMPv4NetUnreachable) counter = sent.dstUnreachable + case *icmpReasonHostUnreachable: + icmpHdr.SetType(header.ICMPv4DstUnreachable) + icmpHdr.SetCode(header.ICMPv4HostUnreachable) + counter = sent.dstUnreachable case *icmpReasonFragmentationNeeded: icmpHdr.SetType(header.ICMPv4DstUnreachable) icmpHdr.SetCode(header.ICMPv4FragmentationNeeded) diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 23178277a..44c85bdb8 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -82,7 +82,7 @@ type endpoint struct { protocol *protocol stats sharedStats - // enabled is set to 1 when the enpoint is enabled and 0 when it is + // enabled is set to 1 when the endpoint is enabled and 0 when it is // disabled. // // Must be accessed using atomic operations. @@ -104,6 +104,16 @@ type endpoint struct { // HandleLinkResolutionFailure implements stack.LinkResolvableNetworkEndpoint. func (e *endpoint) HandleLinkResolutionFailure(pkt *stack.PacketBuffer) { + // If we are operating as a router, return an ICMP error to the original + // packet's sender. + if pkt.NetworkPacketInfo.IsForwardedPacket { + // TODO(gvisor.dev/issue/6005): Propagate asynchronously generated ICMP + // errors to local endpoints. + e.protocol.returnError(&icmpReasonHostUnreachable{}, pkt) + e.stats.ip.Forwarding.Errors.Increment() + e.stats.ip.Forwarding.HostUnreachable.Increment() + return + } // handleControl expects the entire offending packet to be in the packet // buffer's data field. pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -322,8 +332,8 @@ func (e *endpoint) DefaultTTL() uint8 { return e.protocol.DefaultTTL() } -// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus -// the network layer max header length. +// MTU implements stack.NetworkEndpoint. It returns the link-layer MTU minus the +// network layer max header length. func (e *endpoint) MTU() uint32 { networkMTU, err := calculateNetworkMTU(e.nic.MTU(), header.IPv4MinimumSize) if err != nil { @@ -338,7 +348,7 @@ func (e *endpoint) MaxHeaderLength() uint16 { return e.nic.MaxHeaderLength() + header.IPv4MaximumHeaderSize } -// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber. +// NetworkProtocolNumber implements stack.NetworkEndpoint. func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return e.protocol.Number() } @@ -353,7 +363,7 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet if hdrLen > header.IPv4MaximumHeaderSize { return &tcpip.ErrMessageTooLong{} } - ip := header.IPv4(pkt.NetworkHeader().Push(hdrLen)) + ipH := header.IPv4(pkt.NetworkHeader().Push(hdrLen)) length := pkt.Size() if length > math.MaxUint16 { return &tcpip.ErrMessageTooLong{} @@ -362,7 +372,7 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet // datagrams. Since the DF bit is never being set here, all datagrams // are non-atomic and need an ID. id := atomic.AddUint32(&e.protocol.ids[hashRoute(srcAddr, dstAddr, params.Protocol, e.protocol.hashIV)%buckets], 1) - ip.Encode(&header.IPv4Fields{ + ipH.Encode(&header.IPv4Fields{ TotalLength: uint16(length), ID: uint16(id), TTL: params.TTL, @@ -372,7 +382,7 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet DstAddr: dstAddr, Options: options, }) - ip.SetChecksum(^ip.CalculateChecksum()) + ipH.SetChecksum(^ipH.CalculateChecksum()) pkt.NetworkProtocolNumber = ProtocolNumber return nil } @@ -381,7 +391,7 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet // fragment. It returns the number of fragments handled and the number of // fragments left to be processed. The IP header must already be present in the // original packet. -func (e *endpoint) handleFragments(r *stack.Route, networkMTU uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) tcpip.Error) (int, int, tcpip.Error) { +func (e *endpoint) handleFragments(_ *stack.Route, networkMTU uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) tcpip.Error) (int, int, tcpip.Error) { // Round the MTU down to align to 8 bytes. fragmentPayloadSize := networkMTU &^ 7 networkHeader := header.IPv4(pkt.NetworkHeader().View()) @@ -419,9 +429,9 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams, // based on destination address and do not send the packet to link // layer. // - // TODO(gvisor.dev/issue/170): We should do this for every - // packet, rather than only NATted packets, but removing this check - // short circuits broadcasts before they are sent out to other hosts. + // We should do this for every packet, rather than only NATted packets, but + // removing this check short circuits broadcasts before they are sent out to + // other hosts. if pkt.NatDone { netHeader := header.IPv4(pkt.NetworkHeader().View()) if ep := e.protocol.findEndpointWithAddress(netHeader.DestinationAddress()); ep != nil { @@ -490,7 +500,7 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, headerIn return nil } -// WritePackets implements stack.NetworkEndpoint.WritePackets. +// WritePackets implements stack.NetworkEndpoint. func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) { if r.Loop()&stack.PacketLoop != 0 { panic("multiple packets in local loop") @@ -593,34 +603,30 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu if !ok { return &tcpip.ErrMalformedHeader{} } - ip := header.IPv4(h) + ipH := header.IPv4(h) // Always set the total length. pktSize := pkt.Data().Size() - ip.SetTotalLength(uint16(pktSize)) + ipH.SetTotalLength(uint16(pktSize)) // Set the source address when zero. - if ip.SourceAddress() == header.IPv4Any { - ip.SetSourceAddress(r.LocalAddress()) + if ipH.SourceAddress() == header.IPv4Any { + ipH.SetSourceAddress(r.LocalAddress()) } - // Set the destination. If the packet already included a destination, it will - // be part of the route anyways. - ip.SetDestinationAddress(r.RemoteAddress()) - // Set the packet ID when zero. - if ip.ID() == 0 { + if ipH.ID() == 0 { // RFC 6864 section 4.3 mandates uniqueness of ID values for // non-atomic datagrams, so assign an ID to all such datagrams // according to the definition given in RFC 6864 section 4. - if ip.Flags()&header.IPv4FlagDontFragment == 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 || ip.FragmentOffset() > 0 { - ip.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r.LocalAddress(), r.RemoteAddress(), 0 /* protocol */, e.protocol.hashIV)%buckets], 1))) + if ipH.Flags()&header.IPv4FlagDontFragment == 0 || ipH.Flags()&header.IPv4FlagMoreFragments != 0 || ipH.FragmentOffset() > 0 { + ipH.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r.LocalAddress(), r.RemoteAddress(), 0 /* protocol */, e.protocol.hashIV)%buckets], 1))) } } // Always set the checksum. - ip.SetChecksum(0) - ip.SetChecksum(^ip.CalculateChecksum()) + ipH.SetChecksum(0) + ipH.SetChecksum(^ipH.CalculateChecksum()) // Populate the packet buffer's network header and don't allow an invalid // packet to be sent. @@ -710,7 +716,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { return nil } - ep.handleValidatedPacket(h, pkt) + // The packet originally arrived on e so provide its NIC as the input NIC. + ep.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */) return nil } @@ -826,7 +833,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { } } - e.handleValidatedPacket(h, pkt) + e.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */) } // handleLocalPacket is like HandlePacket except it does not perform the @@ -845,10 +852,17 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum return } - e.handleValidatedPacket(h, pkt) + e.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */) } -func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) { +func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer, inNICName string) { + // Raw socket packets are delivered based solely on the transport protocol + // number. We only require that the packet be valid IPv4, and that they not + // be fragmented. + if !h.More() && h.FragmentOffset() == 0 { + e.dispatcher.DeliverRawPacket(h.TransportProtocol(), pkt) + } + pkt.NICID = e.nic.ID() stats := e.stats stats.ip.ValidPacketsReceived.Increment() @@ -898,7 +912,6 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) case *ip.ErrNoRoute: stats.ip.Forwarding.Unrouteable.Increment() case *ip.ErrParameterProblem: - e.protocol.stack.Stats().MalformedRcvdPackets.Increment() stats.ip.MalformedPacketsReceived.Increment() case *ip.ErrMessageTooLong: stats.ip.Forwarding.PacketTooBig.Increment() @@ -911,8 +924,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. - inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNICName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. stats.ip.IPTablesInputDropped.Increment() return @@ -935,7 +947,6 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) _ = e.protocol.returnError(&icmpReasonParamProblem{ pointer: optProblem.Pointer, }, pkt) - e.protocol.stack.Stats().MalformedRcvdPackets.Increment() e.stats.ip.MalformedPacketsReceived.Increment() } return @@ -985,8 +996,11 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) // The reassembler doesn't take care of fixing up the header, so we need // to do it here. - h.SetTotalLength(uint16(pkt.Data().Size() + len((h)))) + h.SetTotalLength(uint16(pkt.Data().Size() + len(h))) h.SetFlagsFragmentOffset(0, 0) + + // Now that the packet is reassembled, it can be sent to raw sockets. + e.dispatcher.DeliverRawPacket(h.TransportProtocol(), pkt) } stats.ip.PacketsDelivered.Increment() @@ -1008,7 +1022,6 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) _ = e.protocol.returnError(&icmpReasonParamProblem{ pointer: optProblem.Pointer, }, pkt) - e.protocol.stack.Stats().MalformedRcvdPackets.Increment() stats.ip.MalformedPacketsReceived.Increment() } return @@ -1217,13 +1230,13 @@ func (p *protocol) DefaultPrefixLen() int { return header.IPv4AddressSize * 8 } -// ParseAddresses implements NetworkProtocol.ParseAddresses. +// ParseAddresses implements stack.NetworkProtocol. func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { h := header.IPv4(v) return h.SourceAddress(), h.DestinationAddress() } -// SetOption implements NetworkProtocol.SetOption. +// SetOption implements stack.NetworkProtocol. func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.DefaultTTLOption: @@ -1234,7 +1247,7 @@ func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.E } } -// Option implements NetworkProtocol.Option. +// Option implements stack.NetworkProtocol. func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.DefaultTTLOption: @@ -1255,10 +1268,10 @@ func (p *protocol) DefaultTTL() uint8 { return uint8(atomic.LoadUint32(&p.defaultTTL)) } -// Close implements stack.TransportProtocol.Close. +// Close implements stack.TransportProtocol. func (*protocol) Close() {} -// Wait implements stack.TransportProtocol.Wait. +// Wait implements stack.TransportProtocol. func (*protocol) Wait() {} // parseAndValidate parses the packet (including its transport layer header) and @@ -1296,7 +1309,7 @@ func (p *protocol) parseAndValidate(pkt *stack.PacketBuffer) (header.IPv4, bool) return h, true } -// Parse implements stack.NetworkProtocol.Parse. +// Parse implements stack.NetworkProtocol. func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { if ok := parse.IPv4(pkt); !ok { return 0, false, false @@ -1326,7 +1339,7 @@ func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, tcpip.Error networkMTU = MaxTotalSize } - return networkMTU - uint32(networkHeaderSize), nil + return networkMTU - networkHeaderSize, nil } func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32) bool { @@ -1735,9 +1748,8 @@ type optionTracker struct { // // If there were no errors during parsing, the new set of options is returned as // a new buffer. -func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Options, usage optionsUsage) (header.IPv4Options, optionTracker, *header.IPv4OptParameterProblem) { +func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, opts header.IPv4Options, usage optionsUsage) (header.IPv4Options, optionTracker, *header.IPv4OptParameterProblem) { stats := e.stats.ip - opts := header.IPv4Options(orig) optIter := opts.MakeIterator() // Except NOP, each option must only appear at most once (RFC 791 section 3.1, diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index da9cc0ae8..4a4448cf9 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -16,7 +16,6 @@ package ipv4_test import ( "bytes" - "context" "encoding/hex" "fmt" "io/ioutil" @@ -36,10 +35,10 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/network/arp" - "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" + iptestutil "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/stack" - tcptestutil "gvisor.dev/gvisor/pkg/tcpip/testutil" + "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/raw" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" @@ -112,10 +111,6 @@ func TestExcludeBroadcast(t *testing.T) { }) } -type forwardedPacket struct { - fragments []fragmentInfo -} - func TestForwarding(t *testing.T) { const ( incomingNICID = 1 @@ -134,11 +129,11 @@ func TestForwarding(t *testing.T) { PrefixLen: 8, } outgoingLinkAddr := tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") - remoteIPv4Addr1 := tcptestutil.MustParse4("10.0.0.2") - remoteIPv4Addr2 := tcptestutil.MustParse4("11.0.0.2") - unreachableIPv4Addr := tcptestutil.MustParse4("12.0.0.2") - multicastIPv4Addr := tcptestutil.MustParse4("225.0.0.0") - linkLocalIPv4Addr := tcptestutil.MustParse4("169.254.0.0") + remoteIPv4Addr1 := testutil.MustParse4("10.0.0.2") + remoteIPv4Addr2 := testutil.MustParse4("11.0.0.2") + unreachableIPv4Addr := testutil.MustParse4("12.0.0.2") + multicastIPv4Addr := testutil.MustParse4("225.0.0.0") + linkLocalIPv4Addr := testutil.MustParse4("169.254.0.0") tests := []struct { name string @@ -402,13 +397,13 @@ func TestForwarding(t *testing.T) { totalLength := ipHeaderLength + icmpHeaderLength + test.payloadLength hdr := buffer.NewPrependable(totalLength) hdr.Prepend(test.payloadLength) - icmp := header.ICMPv4(hdr.Prepend(icmpHeaderLength)) - icmp.SetIdent(randomIdent) - icmp.SetSequence(randomSequence) - icmp.SetType(header.ICMPv4Echo) - icmp.SetCode(header.ICMPv4UnusedCode) - icmp.SetChecksum(0) - icmp.SetChecksum(^header.Checksum(icmp, 0)) + icmpH := header.ICMPv4(hdr.Prepend(icmpHeaderLength)) + icmpH.SetIdent(randomIdent) + icmpH.SetSequence(randomSequence) + icmpH.SetType(header.ICMPv4Echo) + icmpH.SetCode(header.ICMPv4UnusedCode) + icmpH.SetChecksum(0) + icmpH.SetChecksum(^header.Checksum(icmpH, 0)) ip := header.IPv4(hdr.Prepend(ipHeaderLength)) ip.Encode(&header.IPv4Fields{ TotalLength: uint16(totalLength), @@ -453,7 +448,7 @@ func TestForwarding(t *testing.T) { return len(hdr.View()) } - checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())), + checker.IPv4(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), checker.SrcAddr(incomingIPv4Addr.Address), checker.DstAddr(test.sourceAddr), checker.TTL(ipv4.DefaultTTL), @@ -461,7 +456,7 @@ func TestForwarding(t *testing.T) { checker.ICMPv4Checksum(), checker.ICMPv4Type(test.icmpType), checker.ICMPv4Code(test.icmpCode), - checker.ICMPv4Payload([]byte(hdr.View()[0:expectedICMPPayloadLength()])), + checker.ICMPv4Payload(hdr.View()[:expectedICMPPayloadLength()]), ), ) } else if ok { @@ -470,7 +465,7 @@ func TestForwarding(t *testing.T) { if test.expectPacketForwarded { if len(test.expectedFragmentsForwarded) != 0 { - fragmentedPackets := []*stack.PacketBuffer{} + var fragmentedPackets []*stack.PacketBuffer for i := 0; i < len(test.expectedFragmentsForwarded); i++ { reply, ok = outgoingEndpoint.Read() if !ok { @@ -487,7 +482,7 @@ func TestForwarding(t *testing.T) { // maximum IP header size and the maximum size allocated for link layer // headers. In this case, no size is allocated for link layer headers. expectedAvailableHeaderBytes := header.IPv4MaximumHeaderSize - if err := compareFragments(fragmentedPackets, requestPkt, uint32(test.mtu), test.expectedFragmentsForwarded, header.ICMPv4ProtocolNumber, true /* withIPHeader */, expectedAvailableHeaderBytes); err != nil { + if err := compareFragments(fragmentedPackets, requestPkt, test.mtu, test.expectedFragmentsForwarded, header.ICMPv4ProtocolNumber, true /* withIPHeader */, expectedAvailableHeaderBytes); err != nil { t.Error(err) } } else { @@ -496,7 +491,7 @@ func TestForwarding(t *testing.T) { t.Fatal("expected ICMP Echo packet through outgoing NIC") } - checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())), + checker.IPv4(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), checker.SrcAddr(test.sourceAddr), checker.DstAddr(test.destAddr), checker.TTL(test.TTL-1), @@ -1212,15 +1207,15 @@ func TestIPv4Sanity(t *testing.T) { } totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize) hdr := buffer.NewPrependable(int(totalLen)) - icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + icmpH := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) // Specify ident/seq to make sure we get the same in the response. - icmp.SetIdent(randomIdent) - icmp.SetSequence(randomSequence) - icmp.SetType(header.ICMPv4Echo) - icmp.SetCode(header.ICMPv4UnusedCode) - icmp.SetChecksum(0) - icmp.SetChecksum(^header.Checksum(icmp, 0)) + icmpH.SetIdent(randomIdent) + icmpH.SetSequence(randomSequence) + icmpH.SetType(header.ICMPv4Echo) + icmpH.SetCode(header.ICMPv4UnusedCode) + icmpH.SetChecksum(0) + icmpH.SetChecksum(^header.Checksum(icmpH, 0)) ip := header.IPv4(hdr.Prepend(ipHeaderLength)) if test.maxTotalLength < totalLen { totalLen = test.maxTotalLength @@ -1315,7 +1310,7 @@ func TestIPv4Sanity(t *testing.T) { checker.ICMPv4Type(test.ICMPType), checker.ICMPv4Code(test.ICMPCode), checker.ICMPv4Pointer(test.paramProblemPointer), - checker.ICMPv4Payload([]byte(hdr.View())), + checker.ICMPv4Payload(hdr.View()), ), ) return @@ -1334,7 +1329,7 @@ func TestIPv4Sanity(t *testing.T) { checker.ICMPv4( checker.ICMPv4Type(test.ICMPType), checker.ICMPv4Code(test.ICMPCode), - checker.ICMPv4Payload([]byte(hdr.View())), + checker.ICMPv4Payload(hdr.View()), ), ) return @@ -1546,9 +1541,9 @@ func TestFragmentationWritePacket(t *testing.T) { for _, ft := range fragmentationTests { t.Run(ft.description, func(t *testing.T) { - ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) + ep := iptestutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) r := buildRoute(t, ep) - pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) + pkt := iptestutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) source := pkt.Clone() err := r.WritePacket(stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, @@ -1602,7 +1597,7 @@ func TestFragmentationWritePackets(t *testing.T) { insertAfter: 1, }, } - tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv4MinimumSize, []int{1}, header.IPv4ProtocolNumber) + tinyPacket := iptestutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv4MinimumSize, []int{1}, header.IPv4ProtocolNumber) for _, test := range writePacketsTests { t.Run(test.description, func(t *testing.T) { @@ -1612,13 +1607,13 @@ func TestFragmentationWritePackets(t *testing.T) { for i := 0; i < test.insertBefore; i++ { pkts.PushBack(tinyPacket.Clone()) } - pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) + pkt := iptestutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) pkts.PushBack(pkt.Clone()) for i := 0; i < test.insertAfter; i++ { pkts.PushBack(tinyPacket.Clone()) } - ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) + ep := iptestutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) r := buildRoute(t, ep) wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter @@ -1726,8 +1721,8 @@ func TestFragmentationErrors(t *testing.T) { for _, ft := range tests { t.Run(ft.description, func(t *testing.T) { - pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) - ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) + pkt := iptestutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) + ep := iptestutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) r := buildRoute(t, ep) err := r.WritePacket(stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, @@ -2301,7 +2296,7 @@ func TestFragmentReassemblyTimeout(t *testing.T) { checker.ICMPv4Type(header.ICMPv4TimeExceeded), checker.ICMPv4Code(header.ICMPv4ReassemblyTimeout), checker.ICMPv4Checksum(), - checker.ICMPv4Payload([]byte(firstFragmentSent)), + checker.ICMPv4Payload(firstFragmentSent), ), ) }) @@ -2705,7 +2700,7 @@ func TestReceiveFragments(t *testing.T) { TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, RawFactory: raw.EndpointFactory{}, }) - e := channel.New(0, 1280, tcpip.LinkAddress("\xf0\x00")) + e := channel.New(0, 1280, "\xf0\x00") if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -2948,7 +2943,7 @@ func TestWriteStats(t *testing.T) { t.Run(writer.name, func(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets) + ep := iptestutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets) rt := buildRoute(t, ep) var pkts stack.PacketBufferList @@ -3035,7 +3030,7 @@ func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string, string) return false, false } -func TestPacketQueing(t *testing.T) { +func TestPacketQueuing(t *testing.T) { const nicID = 1 var ( @@ -3074,7 +3069,7 @@ func TestPacketQueing(t *testing.T) { Length: header.UDPMinimumSize, }) sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv4Addr.AddressWithPrefix.Address, host1IPv4Addr.AddressWithPrefix.Address, header.UDPMinimumSize) - sum = header.Checksum(header.UDP([]byte{}), sum) + sum = header.Checksum(nil, sum) u.SetChecksum(^u.CalculateChecksum(sum)) ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) ip.Encode(&header.IPv4Fields{ @@ -3090,7 +3085,7 @@ func TestPacketQueing(t *testing.T) { })) }, checkResp: func(t *testing.T, e *channel.Endpoint) { - p, ok := e.ReadContext(context.Background()) + p, ok := e.Read() if !ok { t.Fatalf("timed out waiting for packet") } @@ -3133,7 +3128,7 @@ func TestPacketQueing(t *testing.T) { })) }, checkResp: func(t *testing.T, e *channel.Endpoint) { - p, ok := e.ReadContext(context.Background()) + p, ok := e.Read() if !ok { t.Fatalf("timed out waiting for packet") } @@ -3157,9 +3152,11 @@ func TestPacketQueing(t *testing.T) { t.Run(test.name, func(t *testing.T) { e := channel.New(1, defaultMTU, host1NICLinkAddr) e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + Clock: clock, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -3182,7 +3179,8 @@ func TestPacketQueing(t *testing.T) { // Wait for a ARP request since link address resolution should be // performed. { - p, ok := e.ReadContext(context.Background()) + clock.RunImmediatelyScheduledJobs() + p, ok := e.Read() if !ok { t.Fatalf("timed out waiting for packet") } @@ -3223,6 +3221,7 @@ func TestPacketQueing(t *testing.T) { } // Expect the response now that the link address has resolved. + clock.RunImmediatelyScheduledJobs() test.checkResp(t, e) // Since link resolution was already performed, it shouldn't be performed @@ -3244,8 +3243,8 @@ func TestCloseLocking(t *testing.T) { ) var ( - src = tcptestutil.MustParse4("16.0.0.1") - dst = tcptestutil.MustParse4("16.0.0.2") + src = testutil.MustParse4("16.0.0.1") + dst = testutil.MustParse4("16.0.0.2") ) s := stack.New(stack.Options{ @@ -3253,7 +3252,7 @@ func TestCloseLocking(t *testing.T) { TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) - // Perform NAT so that the endoint tries to search for a sibling endpoint + // Perform NAT so that the endpoint tries to search for a sibling endpoint // which ends up taking the protocol and endpoint lock (in that order). table := stack.Table{ Rules: []stack.Rule{ diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 307e1972d..94caaae6c 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -285,8 +285,8 @@ func isMLDValid(pkt *stack.PacketBuffer, iph header.IPv6, routerAlert *header.IP func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, routerAlert *header.IPv6RouterAlertOption) { sent := e.stats.icmp.packetsSent received := e.stats.icmp.packetsReceived - // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader - // fields set. See icmp/protocol.go:protocol.Parse for a full explanation. + // ICMP packets don't have their TransportHeader fields set. See + // icmp/protocol.go:protocol.Parse for a full explanation. v, ok := pkt.Data().PullUp(header.ICMPv6HeaderSize) if !ok { received.invalid.Increment() @@ -1029,6 +1029,26 @@ func (*icmpReasonNetUnreachable) respondsToMulticast() bool { return false } +// icmpReasonHostUnreachable is an error in which the host specified in the +// internet destination field of the datagram is unreachable. +type icmpReasonHostUnreachable struct{} + +func (*icmpReasonHostUnreachable) isICMPReason() {} +func (*icmpReasonHostUnreachable) isForwarding() bool { + // If we hit a Host Unreachable error, then we know we are operating as a + // router. As per RFC 4443 page 8, Destination Unreachable Message, + // + // If the reason for the failure to deliver cannot be mapped to any of + // other codes, the Code field is set to 3. Example of such cases are + // an inability to resolve the IPv6 destination address into a + // corresponding link address, or a link-specific problem of some sort. + return true +} + +func (*icmpReasonHostUnreachable) respondsToMulticast() bool { + return false +} + // icmpReasonFragmentationNeeded is an error where a packet is to big to be sent // out through the outgoing MTU, as per RFC 4443 page 9, Packet Too Big Message. type icmpReasonPacketTooBig struct{} @@ -1143,7 +1163,12 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip defer route.Release() p.mu.Lock() - netEP, ok := p.mu.eps[pkt.NICID] + // We retrieve an endpoint using the newly constructed route's NICID rather + // than the packet's NICID. The packet's NICID corresponds to the NIC on + // which it arrived, which isn't necessarily the same as the NIC on which it + // will be transmitted. On the other hand, the route's NIC *is* guaranteed + // to be the NIC on which the packet will be transmitted. + netEP, ok := p.mu.eps[route.NICID()] p.mu.Unlock() if !ok { return &tcpip.ErrNotConnected{} @@ -1222,6 +1247,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip icmpHdr.SetType(header.ICMPv6DstUnreachable) icmpHdr.SetCode(header.ICMPv6NetworkUnreachable) counter = sent.dstUnreachable + case *icmpReasonHostUnreachable: + icmpHdr.SetType(header.ICMPv6DstUnreachable) + icmpHdr.SetCode(header.ICMPv6AddressUnreachable) + counter = sent.dstUnreachable case *icmpReasonPacketTooBig: icmpHdr.SetType(header.ICMPv6PacketTooBig) icmpHdr.SetCode(header.ICMPv6UnusedCode) diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 040cd4bc8..7c2a3e56b 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -16,17 +16,16 @@ package ipv6 import ( "bytes" - "context" "net" "reflect" "strings" "testing" - "time" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" @@ -46,16 +45,12 @@ const ( defaultChannelSize = 1 defaultMTU = 65536 - // Extra time to use when waiting for an async event to occur. - defaultAsyncPositiveEventTimeout = 30 * time.Second - arbitraryHopLimit = 42 ) var ( lladdr0 = header.LinkLocalAddr(linkAddr0) lladdr1 = header.LinkLocalAddr(linkAddr1) - lladdr2 = header.LinkLocalAddr(linkAddr2) ) type stubLinkEndpoint struct { @@ -95,6 +90,10 @@ func (*stubDispatcher) DeliverTransportPacket(tcpip.TransportProtocolNumber, *st return stack.TransportPacketHandled } +func (*stubDispatcher) DeliverRawPacket(tcpip.TransportProtocolNumber, *stack.PacketBuffer) { + // No-op. +} + var _ stack.NetworkInterface = (*testInterface)(nil) type testInterface struct { @@ -371,6 +370,8 @@ type testContext struct { linkEP0 *channel.Endpoint linkEP1 *channel.Endpoint + + clock *faketime.ManualClock } type endpointWithResolutionCapability struct { @@ -382,15 +383,19 @@ func (e endpointWithResolutionCapability) Capabilities() stack.LinkEndpointCapab } func newTestContext(t *testing.T) *testContext { + clock := faketime.NewManualClock() c := &testContext{ s0: stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, + Clock: clock, }), s1: stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, + Clock: clock, }), + clock: clock, } c.linkEP0 = channel.New(defaultChannelSize, defaultMTU, linkAddr0) @@ -457,10 +462,14 @@ type routeArgs struct { remoteLinkAddr tcpip.LinkAddress } -func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.ICMPv6)) { +func routeICMPv6Packet(t *testing.T, clock *faketime.ManualClock, args routeArgs, fn func(*testing.T, header.ICMPv6)) { t.Helper() - pi, _ := args.src.ReadContext(context.Background()) + clock.RunImmediatelyScheduledJobs() + pi, ok := args.src.Read() + if !ok { + t.Fatal("packet didn't arrive") + } { pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -533,7 +542,7 @@ func TestLinkResolution(t *testing.T) { {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit, remoteLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.SolicitedNodeAddr(lladdr1))}, {src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6NeighborAdvert}, } { - routeICMPv6Packet(t, args, func(t *testing.T, icmpv6 header.ICMPv6) { + routeICMPv6Packet(t, c.clock, args, func(t *testing.T, icmpv6 header.ICMPv6) { if got, want := tcpip.Address(icmpv6[8:][:16]), lladdr1; got != want { t.Errorf("%d: got target = %s, want = %s", icmpv6.Type(), got, want) } @@ -544,7 +553,7 @@ func TestLinkResolution(t *testing.T) { {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6EchoRequest}, {src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6EchoReply}, } { - routeICMPv6Packet(t, args, nil) + routeICMPv6Packet(t, c.clock, args, nil) } } @@ -1309,7 +1318,7 @@ func TestPacketQueing(t *testing.T) { Length: header.UDPMinimumSize, }) sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, header.UDPMinimumSize) - sum = header.Checksum(header.UDP([]byte{}), sum) + sum = header.Checksum(nil, sum) u.SetChecksum(^u.CalculateChecksum(sum)) payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) @@ -1325,7 +1334,7 @@ func TestPacketQueing(t *testing.T) { })) }, checkResp: func(t *testing.T, e *channel.Endpoint) { - p, ok := e.ReadContext(context.Background()) + p, ok := e.Read() if !ok { t.Fatalf("timed out waiting for packet") } @@ -1371,7 +1380,7 @@ func TestPacketQueing(t *testing.T) { })) }, checkResp: func(t *testing.T, e *channel.Endpoint) { - p, ok := e.ReadContext(context.Background()) + p, ok := e.Read() if !ok { t.Fatalf("timed out waiting for packet") } @@ -1396,9 +1405,11 @@ func TestPacketQueing(t *testing.T) { e := channel.New(1, header.IPv6MinimumMTU, host1NICLinkAddr) e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + Clock: clock, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -1421,7 +1432,8 @@ func TestPacketQueing(t *testing.T) { // Wait for a neighbor solicitation since link address resolution should // be performed. { - p, ok := e.ReadContext(context.Background()) + clock.RunImmediatelyScheduledJobs() + p, ok := e.Read() if !ok { t.Fatalf("timed out waiting for packet") } @@ -1475,6 +1487,7 @@ func TestPacketQueing(t *testing.T) { } // Expect the response now that the link address has resolved. + clock.RunImmediatelyScheduledJobs() test.checkResp(t, e) // Since link resolution was already performed, it shouldn't be performed diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 95e11ac51..b1aec5312 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -184,7 +184,6 @@ type endpoint struct { nic stack.NetworkInterface dispatcher stack.TransportDispatcher protocol *protocol - stack *stack.Stack stats sharedStats // enabled is set to 1 when the endpoint is enabled and 0 when it is @@ -231,11 +230,11 @@ type endpoint struct { // If the NIC was created with a name, it is passed to NICNameFromID. // // NICNameFromID SHOULD return unique NIC names so unique opaque IIDs are -// generated for the same prefix on differnt NICs. +// generated for the same prefix on different NICs. type NICNameFromID func(tcpip.NICID, string) string // OpaqueInterfaceIdentifierOptions holds the options related to the generation -// of opaque interface indentifiers (IIDs) as defined by RFC 7217. +// of opaque interface identifiers (IIDs) as defined by RFC 7217. type OpaqueInterfaceIdentifierOptions struct { // NICNameFromID is a function that returns a stable name for a specified NIC, // even if the NIC ID changes over time. @@ -282,6 +281,16 @@ func (*endpoint) DuplicateAddressProtocol() tcpip.NetworkProtocolNumber { // HandleLinkResolutionFailure implements stack.LinkResolvableNetworkEndpoint. func (e *endpoint) HandleLinkResolutionFailure(pkt *stack.PacketBuffer) { + // If we are operating as a router, we should return an ICMP error to the + // original packet's sender. + if pkt.NetworkPacketInfo.IsForwardedPacket { + // TODO(gvisor.dev/issue/6005): Propagate asynchronously generated ICMP + // errors to local endpoints. + e.protocol.returnError(&icmpReasonHostUnreachable{}, pkt) + e.stats.ip.Forwarding.Errors.Increment() + e.stats.ip.Forwarding.HostUnreachable.Increment() + return + } // handleControl expects the entire offending packet to be in the packet // buffer's data field. pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -335,7 +344,10 @@ func (e *endpoint) onAddressAssignedLocked(addr tcpip.Address) { func (e *endpoint) InvalidateDefaultRouter(rtr tcpip.Address) { e.mu.Lock() defer e.mu.Unlock() - e.mu.ndp.invalidateDefaultRouter(rtr) + + // We represent default routers with a default (off-link) route through the + // router. + e.mu.ndp.invalidateOffLinkRoute(offLinkRoute{dest: header.IPv6EmptySubnet, router: rtr}) } // SetNDPConfigurations implements NDPEndpoint. @@ -536,7 +548,7 @@ func (e *endpoint) Enable() tcpip.Error { // Perform DAD on the all the unicast IPv6 endpoints that are in the permanent // state. // - // Addresses may have aleady completed DAD but in the time since the endpoint + // Addresses may have already completed DAD but in the time since the endpoint // was last enabled, other devices may have acquired the same addresses. var err tcpip.Error e.mu.addressableEndpointState.ForEachEndpoint(func(addressEndpoint stack.AddressEndpoint) bool { @@ -641,8 +653,8 @@ func (e *endpoint) DefaultTTL() uint8 { return e.protocol.DefaultTTL() } -// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus -// the network layer max header length. +// MTU implements stack.NetworkEndpoint. It returns the link-layer MTU minus the +// network layer max header length. func (e *endpoint) MTU() uint32 { networkMTU, err := calculateNetworkMTU(e.nic.MTU(), header.IPv6MinimumSize) if err != nil { @@ -666,8 +678,7 @@ func addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params if length > math.MaxUint16 { return &tcpip.ErrMessageTooLong{} } - ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen)) - ip.Encode(&header.IPv6Fields{ + header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen)).Encode(&header.IPv6Fields{ PayloadLength: uint16(length), TransportProtocol: params.Protocol, HopLimit: params.TTL, @@ -747,9 +758,9 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams, // based on destination address and do not send the packet to link // layer. // - // TODO(gvisor.dev/issue/170): We should do this for every - // packet, rather than only NATted packets, but removing this check - // short circuits broadcasts before they are sent out to other hosts. + // We should do this for every packet, rather than only NATted packets, but + // removing this check short circuits broadcasts before they are sent out to + // other hosts. if pkt.NatDone { netHeader := header.IPv6(pkt.NetworkHeader().View()) if ep := e.protocol.findEndpointWithAddress(netHeader.DestinationAddress()); ep != nil { @@ -818,7 +829,7 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, protocol return nil } -// WritePackets implements stack.NetworkEndpoint.WritePackets. +// WritePackets implements stack.NetworkEndpoint. func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) { if r.Loop()&stack.PacketLoop != 0 { panic("not implemented") @@ -909,21 +920,17 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu if !ok { return &tcpip.ErrMalformedHeader{} } - ip := header.IPv6(h) + ipH := header.IPv6(h) // Always set the payload length. pktSize := pkt.Data().Size() - ip.SetPayloadLength(uint16(pktSize - header.IPv6MinimumSize)) + ipH.SetPayloadLength(uint16(pktSize - header.IPv6MinimumSize)) // Set the source address when zero. - if ip.SourceAddress() == header.IPv6Any { - ip.SetSourceAddress(r.LocalAddress()) + if ipH.SourceAddress() == header.IPv6Any { + ipH.SetSourceAddress(r.LocalAddress()) } - // Set the destination. If the packet already included a destination, it will - // be part of the route anyways. - ip.SetDestinationAddress(r.RemoteAddress()) - // Populate the packet buffer's network header and don't allow an invalid // packet to be sent. // @@ -983,7 +990,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { return nil } - ep.handleValidatedPacket(h, pkt) + // The packet originally arrived on e so provide its NIC as the input NIC. + ep.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */) return nil } @@ -1096,7 +1104,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { } } - e.handleValidatedPacket(h, pkt) + e.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */) } // handleLocalPacket is like HandlePacket except it does not perform the @@ -1115,10 +1123,14 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum return } - e.handleValidatedPacket(h, pkt) + e.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */) } -func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) { +func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer, inNICName string) { + // Raw socket packets are delivered based solely on the transport protocol + // number. We only require that the packet be valid IPv6. + e.dispatcher.DeliverRawPacket(h.TransportProtocol(), pkt) + pkt.NICID = e.nic.ID() stats := e.stats.ip stats.ValidPacketsReceived.Increment() @@ -1167,8 +1179,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) // iptables filtering. All packets that reach here are intended for // this machine and need not be forwarded. - inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNICName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. stats.IPTablesInputDropped.Increment() return @@ -1610,7 +1621,7 @@ func (e *endpoint) Close() { e.protocol.forgetEndpoint(e.nic.ID()) } -// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber. +// NetworkProtocolNumber implements stack.NetworkEndpoint. func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return e.protocol.Number() } @@ -1619,8 +1630,8 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) { // TODO(b/169350103): add checks here after making sure we no longer receive // an empty address. - e.mu.RLock() - defer e.mu.RUnlock() + e.mu.Lock() + defer e.mu.Unlock() return e.addAndAcquirePermanentAddressLocked(addr, peb, configType, deprecated) } @@ -1661,8 +1672,8 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre // RemovePermanentAddress implements stack.AddressableEndpoint. func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) tcpip.Error { - e.mu.RLock() - defer e.mu.RUnlock() + e.mu.Lock() + defer e.mu.Unlock() addressEndpoint := e.getAddressRLocked(addr) if addressEndpoint == nil || !addressEndpoint.GetKind().IsPermanent() { @@ -2004,7 +2015,7 @@ func (p *protocol) DefaultPrefixLen() int { return header.IPv6AddressSize * 8 } -// ParseAddresses implements NetworkProtocol.ParseAddresses. +// ParseAddresses implements stack.NetworkProtocol. func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { h := header.IPv6(v) return h.SourceAddress(), h.DestinationAddress() @@ -2081,7 +2092,7 @@ func (p *protocol) forgetEndpoint(nicID tcpip.NICID) { delete(p.mu.eps, nicID) } -// SetOption implements NetworkProtocol.SetOption. +// SetOption implements stack.NetworkProtocol. func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.DefaultTTLOption: @@ -2092,7 +2103,7 @@ func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.E } } -// Option implements NetworkProtocol.Option. +// Option implements stack.NetworkProtocol. func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.DefaultTTLOption: @@ -2113,10 +2124,10 @@ func (p *protocol) DefaultTTL() uint8 { return uint8(atomic.LoadUint32(&p.defaultTTL)) } -// Close implements stack.TransportProtocol.Close. +// Close implements stack.TransportProtocol. func (*protocol) Close() {} -// Wait implements stack.TransportProtocol.Wait. +// Wait implements stack.TransportProtocol. func (*protocol) Wait() {} // parseAndValidate parses the packet (including its transport layer header) and @@ -2150,7 +2161,7 @@ func (p *protocol) parseAndValidate(pkt *stack.PacketBuffer) (header.IPv6, bool) return h, true } -// Parse implements stack.NetworkProtocol.Parse. +// Parse implements stack.NetworkProtocol. func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { proto, _, fragOffset, fragMore, ok := parse.IPv6(pkt) if !ok { @@ -2180,7 +2191,7 @@ func calculateNetworkMTU(linkMTU, networkHeadersLen uint32) (uint32, tcpip.Error return 0, &tcpip.ErrMalformedHeader{} } - networkMTU := linkMTU - uint32(networkHeadersLen) + networkMTU := linkMTU - networkHeadersLen if networkMTU > maxPayloadSize { networkMTU = maxPayloadSize } @@ -2199,7 +2210,7 @@ type Options struct { // Note, setting this to true does not mean that a link-local address is // assigned right away, or at all. If Duplicate Address Detection is enabled, // an address is only assigned if it successfully resolves. If it fails, no - // further attempts are made to auto-generate a link-local adddress. + // further attempts are made to auto-generate a link-local address. // // The generated link-local address follows RFC 4291 Appendix A guidelines. AutoGenLinkLocal bool @@ -2215,7 +2226,7 @@ type Options struct { // TempIIDSeed is used to seed the initial temporary interface identifier // history value used to generate IIDs for temporary SLAAC addresses. // - // Temporary SLAAC adresses are short-lived addresses which are unpredictable + // Temporary SLAAC addresses are short-lived addresses which are unpredictable // and random from the perspective of other nodes on the network. It is // recommended that the seed be a random byte buffer of at least // header.IIDSize bytes to make sure that temporary SLAAC addresses are diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index afc6c3547..d2a23fd4f 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -129,7 +129,7 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, header.UDPMinimumSize) // UDP checksum - sum = header.Checksum(header.UDP([]byte{}), sum) + sum = header.Checksum(nil, sum) u.SetChecksum(^u.CalculateChecksum(sum)) payloadLength := hdr.UsedLength() @@ -402,7 +402,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { }{ { name: "None", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, nextHdr }, + extHdr: func(nextHdr uint8) ([]byte, uint8) { return nil, nextHdr }, shouldAccept: true, expectICMP: false, }, @@ -612,8 +612,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { { name: "No next header", extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{}, - noNextHdrID + return nil, noNextHdrID }, shouldAccept: false, expectICMP: false, @@ -1160,7 +1159,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2), []buffer.View{ // Fragment extension header. - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}, ipv6Payload1Addr1ToAddr2, }, @@ -1180,7 +1179,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { fragmentExtHdrLen+len(ipv6Payload3Addr1ToAddr2), []buffer.View{ // Fragment extension header. - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}, ipv6Payload3Addr1ToAddr2, }, @@ -1202,7 +1201,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[:64], }, @@ -1218,7 +1217,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[64:], }, @@ -1240,7 +1239,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[64:], }, @@ -1256,7 +1255,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[:64], }, @@ -1278,7 +1277,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[:64], }, @@ -1296,7 +1295,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment offset = 8, More = false, ID = 1 // NextHeader value is different than the one in the first fragment, so // this NextHeader should be ignored. - buffer.View([]byte{uint8(header.IPv6NoNextHeaderIdentifier), 0, 0, 64, 0, 0, 0, 1}), + []byte{uint8(header.IPv6NoNextHeaderIdentifier), 0, 0, 64, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[64:], }, @@ -1318,7 +1317,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, ipv6Payload3Addr1ToAddr2[:64], }, @@ -1334,7 +1333,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}, ipv6Payload3Addr1ToAddr2[64:], }, @@ -1356,7 +1355,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, ipv6Payload3Addr1ToAddr2[:63], }, @@ -1372,7 +1371,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}, ipv6Payload3Addr1ToAddr2[63:], }, @@ -1394,7 +1393,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[:64], }, @@ -1410,7 +1409,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 8, More = false, ID = 2 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 2}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 2}, ipv6Payload1Addr1ToAddr2[64:], }, @@ -1432,7 +1431,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, ipv6Payload4Addr1ToAddr2[:udpMaximumSizeMinus15], }, @@ -1448,10 +1447,10 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = udpMaximumSizeMinus15/8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, + []byte{uint8(header.UDPProtocolNumber), 0, udpMaximumSizeMinus15 >> 8, udpMaximumSizeMinus15 & 0xff, - 0, 0, 0, 1}), + 0, 0, 0, 1}, ipv6Payload4Addr1ToAddr2[udpMaximumSizeMinus15:], }, @@ -1473,7 +1472,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, ipv6Payload4Addr1ToAddr2[:udpMaximumSizeMinus15], }, @@ -1489,10 +1488,10 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = udpMaximumSizeMinus15/8, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, + []byte{uint8(header.UDPProtocolNumber), 0, udpMaximumSizeMinus15 >> 8, (udpMaximumSizeMinus15 & 0xff) + 1, - 0, 0, 0, 1}), + 0, 0, 0, 1}, ipv6Payload4Addr1ToAddr2[udpMaximumSizeMinus15:], }, @@ -1514,12 +1513,12 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Routing extension header. // // Segments left = 0. - buffer.View([]byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5}), + []byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5}, // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[:64], }, @@ -1535,12 +1534,12 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Routing extension header. // // Segments left = 0. - buffer.View([]byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5}), + []byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5}, // Fragment extension header. // // Fragment offset = 8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[64:], }, @@ -1562,12 +1561,12 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Routing extension header. // // Segments left = 1. - buffer.View([]byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5}), + []byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5}, // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[:64], }, @@ -1583,12 +1582,12 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Routing extension header. // // Segments left = 1. - buffer.View([]byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5}), + []byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5}, // Fragment extension header. // // Fragment offset = 9, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 72, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 72, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[64:], }, @@ -1610,12 +1609,12 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}), + []byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}, // Routing extension header. // // Segments left = 0. - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 1, 0, 2, 3, 4, 5}), + []byte{uint8(header.UDPProtocolNumber), 0, 1, 0, 2, 3, 4, 5}, ipv6Payload1Addr1ToAddr2[:64], }, @@ -1631,7 +1630,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 9, More = false, ID = 1 - buffer.View([]byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}), + []byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[64:], }, @@ -1653,12 +1652,12 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}), + []byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}, // Routing extension header. // // Segments left = 1. - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 1, 1, 2, 3, 4, 5}), + []byte{uint8(header.UDPProtocolNumber), 0, 1, 1, 2, 3, 4, 5}, ipv6Payload1Addr1ToAddr2[:64], }, @@ -1674,7 +1673,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 9, More = false, ID = 1 - buffer.View([]byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}), + []byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[64:], }, @@ -1699,12 +1698,12 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}), + []byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}, // Routing extension header (part 1) // // Segments left = 0. - buffer.View([]byte{uint8(header.UDPProtocolNumber), 1, 1, 0, 2, 3, 4, 5}), + []byte{uint8(header.UDPProtocolNumber), 1, 1, 0, 2, 3, 4, 5}, }, ), }, @@ -1721,10 +1720,10 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 1, More = false, ID = 1 - buffer.View([]byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1}), + []byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1}, // Routing extension header (part 2) - buffer.View([]byte{6, 7, 8, 9, 10, 11, 12, 13}), + []byte{6, 7, 8, 9, 10, 11, 12, 13}, ipv6Payload1Addr1ToAddr2, }, @@ -1749,12 +1748,12 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}), + []byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}, // Routing extension header (part 1) // // Segments left = 1. - buffer.View([]byte{uint8(header.UDPProtocolNumber), 1, 1, 1, 2, 3, 4, 5}), + []byte{uint8(header.UDPProtocolNumber), 1, 1, 1, 2, 3, 4, 5}, }, ), }, @@ -1771,10 +1770,10 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 1, More = false, ID = 1 - buffer.View([]byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1}), + []byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1}, // Routing extension header (part 2) - buffer.View([]byte{6, 7, 8, 9, 10, 11, 12, 13}), + []byte{6, 7, 8, 9, 10, 11, 12, 13}, ipv6Payload1Addr1ToAddr2, }, @@ -1798,7 +1797,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[:64], }, @@ -1816,7 +1815,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 1}, ipv6Payload2Addr1ToAddr2, }, @@ -1832,7 +1831,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[64:], }, @@ -1854,7 +1853,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[:64], }, @@ -1870,7 +1869,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 2 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 2}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 2}, ipv6Payload2Addr1ToAddr2[:32], }, @@ -1886,7 +1885,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[64:], }, @@ -1902,7 +1901,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 4, More = false, ID = 2 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 2}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 2}, ipv6Payload2Addr1ToAddr2[32:], }, @@ -1924,7 +1923,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[:64], }, @@ -1940,7 +1939,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, ipv6Payload1Addr3ToAddr2[:32], }, @@ -1956,7 +1955,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[64:], }, @@ -1972,7 +1971,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 4, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 1}, ipv6Payload1Addr3ToAddr2[32:], }, @@ -2208,7 +2207,7 @@ func TestInvalidIPv6Fragments(t *testing.T) { checker.ICMPv6Type(test.expectICMPType), checker.ICMPv6Code(test.expectICMPCode), checker.ICMPv6TypeSpecific(test.expectICMPTypeSpecific), - checker.ICMPv6Payload([]byte(expectICMPPayload)), + checker.ICMPv6Payload(expectICMPPayload), ), ) }) @@ -2459,7 +2458,7 @@ func TestFragmentReassemblyTimeout(t *testing.T) { checker.ICMPv6( checker.ICMPv6Type(header.ICMPv6TimeExceeded), checker.ICMPv6Code(header.ICMPv6ReassemblyTimeout), - checker.ICMPv6Payload([]byte(firstFragmentSent)), + checker.ICMPv6Payload(firstFragmentSent), ), ) }) @@ -2795,11 +2794,7 @@ var fragmentationTests = []struct { } func TestFragmentationWritePacket(t *testing.T) { - const ( - ttl = 42 - tos = stack.DefaultTOS - transportProto = tcp.ProtocolNumber - ) + const ttl = 42 for _, ft := range fragmentationTests { t.Run(ft.description, func(t *testing.T) { @@ -3335,7 +3330,7 @@ func TestForwarding(t *testing.T) { } transportProtocol := header.ICMPv6ProtocolNumber - extHdrBytes := []byte{} + var extHdrBytes []byte extHdrChecker := checker.IPv6ExtHdr() if test.extHdr != nil { nextHdrID := hopByHopExtHdrID @@ -3349,15 +3344,15 @@ func TestForwarding(t *testing.T) { totalLength := ipHeaderLength + icmpHeaderLength + test.payloadLength + extHdrLen hdr := buffer.NewPrependable(totalLength) hdr.Prepend(test.payloadLength) - icmp := header.ICMPv6(hdr.Prepend(icmpHeaderLength)) - - icmp.SetIdent(randomIdent) - icmp.SetSequence(randomSequence) - icmp.SetType(header.ICMPv6EchoRequest) - icmp.SetCode(header.ICMPv6UnusedCode) - icmp.SetChecksum(0) - icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: icmp, + icmpH := header.ICMPv6(hdr.Prepend(icmpHeaderLength)) + + icmpH.SetIdent(randomIdent) + icmpH.SetSequence(randomSequence) + icmpH.SetType(header.ICMPv6EchoRequest) + icmpH.SetCode(header.ICMPv6UnusedCode) + icmpH.SetChecksum(0) + icmpH.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmpH, Src: test.sourceAddr, Dst: test.destAddr, })) @@ -3395,14 +3390,14 @@ func TestForwarding(t *testing.T) { return len(hdr.View()) } - checker.IPv6(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())), + checker.IPv6(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), checker.SrcAddr(incomingIPv6Addr.Address), checker.DstAddr(test.sourceAddr), checker.TTL(DefaultTTL), checker.ICMPv6( checker.ICMPv6Type(test.icmpType), checker.ICMPv6Code(test.icmpCode), - checker.ICMPv6Payload([]byte(hdr.View()[0:expectedICMPPayloadLength()])), + checker.ICMPv6Payload(hdr.View()[:expectedICMPPayloadLength()]), ), ) @@ -3419,7 +3414,7 @@ func TestForwarding(t *testing.T) { t.Fatal("expected ICMP Echo Request packet through outgoing NIC") } - checker.IPv6WithExtHdr(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())), + checker.IPv6WithExtHdr(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), checker.SrcAddr(test.sourceAddr), checker.DstAddr(test.destAddr), checker.TTL(test.TTL-1), diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go index 71d1c3e28..bc9cf6999 100644 --- a/pkg/tcpip/network/ipv6/mld_test.go +++ b/pkg/tcpip/network/ipv6/mld_test.go @@ -16,6 +16,7 @@ package ipv6_test import ( "bytes" + "math/rand" "testing" "time" @@ -138,7 +139,8 @@ func TestSendQueuedMLDReports(t *testing.T) { var secureRNG bytes.Reader secureRNG.Reset(secureRNGBytes[:]) s := stack.New(stack.Options{ - SecureRNG: &secureRNG, + SecureRNG: &secureRNG, + RandSource: rand.NewSource(time.Now().UnixNano()), NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ DADConfigs: stack.DADConfigurations{ DupAddrDetectTransmits: test.dadTransmits, diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index f0ff111c5..9cd283eba 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -16,7 +16,6 @@ package ipv6 import ( "fmt" - "math/rand" "time" "gvisor.dev/gvisor/pkg/sync" @@ -79,13 +78,13 @@ const ( // we cannot have a negative delay. minimumMaxRtrSolicitationDelay = 0 - // MaxDiscoveredDefaultRouters is the maximum number of discovered - // default routers. The stack should stop discovering new routers after - // discovering MaxDiscoveredDefaultRouters routers. + // MaxDiscoveredOffLinkRoutes is the maximum number of discovered off-link + // routes. The stack should stop discovering new off-link routes after + // this limit is reached. // // This value MUST be at minimum 2 as per RFC 4861 section 6.3.4, and // SHOULD be more. - MaxDiscoveredDefaultRouters = 10 + MaxDiscoveredOffLinkRoutes = 10 // MaxDiscoveredOnLinkPrefixes is the maximum number of discovered // on-link prefixes. The stack should stop discovering new on-link @@ -128,25 +127,17 @@ const ( // maxSLAACAddrLocalRegenAttempts is the maximum number of times to attempt // SLAAC address regenerations in response to an IPv6 endpoint-local conflict. maxSLAACAddrLocalRegenAttempts = 10 -) -var ( // MinPrefixInformationValidLifetimeForUpdate is the minimum Valid // Lifetime to update the valid lifetime of a generated address by // SLAAC. // - // This is exported as a variable (instead of a constant) so tests - // can update it to a smaller value. - // // Min = 2hrs. MinPrefixInformationValidLifetimeForUpdate = 2 * time.Hour // MaxDesyncFactor is the upper bound for the preferred lifetime's desync // factor for temporary SLAAC addresses. // - // This is exported as a variable (instead of a constant) so tests - // can update it to a smaller value. - // // Must be greater than 0. // // Max = 10m (from RFC 4941 section 5). @@ -155,9 +146,6 @@ var ( // MinMaxTempAddrPreferredLifetime is the minimum value allowed for the // maximum preferred lifetime for temporary SLAAC addresses. // - // This is exported as a variable (instead of a constant) so tests - // can update it to a smaller value. - // // This value guarantees that a temporary address is preferred for at // least 1hr if the SLAAC prefix is valid for at least that time. MinMaxTempAddrPreferredLifetime = defaultRegenAdvanceDuration + MaxDesyncFactor + time.Hour @@ -165,9 +153,6 @@ var ( // MinMaxTempAddrValidLifetime is the minimum value allowed for the // maximum valid lifetime for temporary SLAAC addresses. // - // This is exported as a variable (instead of a constant) so tests - // can update it to a smaller value. - // // This value guarantees that a temporary address is valid for at least // 2hrs if the SLAAC prefix is valid for at least that time. MinMaxTempAddrValidLifetime = 2 * time.Hour @@ -215,28 +200,23 @@ type NDPDispatcher interface { // is also not permitted to call into the stack. OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult) - // OnDefaultRouterDiscovered is called when a new default router is - // discovered. Implementations must return true if the newly discovered - // router should be remembered. + // OnOffLinkRouteUpdated is called when an off-link route is updated. // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. - OnDefaultRouterDiscovered(tcpip.NICID, tcpip.Address) bool + OnOffLinkRouteUpdated(tcpip.NICID, tcpip.Subnet, tcpip.Address, header.NDPRoutePreference) - // OnDefaultRouterInvalidated is called when a discovered default router that - // was remembered is invalidated. + // OnOffLinkRouteInvalidated is called when an off-link route is invalidated. // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. - OnDefaultRouterInvalidated(tcpip.NICID, tcpip.Address) + OnOffLinkRouteInvalidated(tcpip.NICID, tcpip.Subnet, tcpip.Address) // OnOnLinkPrefixDiscovered is called when a new on-link prefix is discovered. - // Implementations must return true if the newly discovered on-link prefix - // should be remembered. // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. - OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool + OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) // OnOnLinkPrefixInvalidated is called when a discovered on-link prefix that // was remembered is invalidated. @@ -246,13 +226,11 @@ type NDPDispatcher interface { OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) // OnAutoGenAddress is called when a new prefix with its autonomous address- - // configuration flag set is received and SLAAC was performed. Implementations - // may prevent the stack from assigning the address to the NIC by returning - // false. + // configuration flag set is received and SLAAC was performed. // // This function is not permitted to block indefinitely. It must not // call functions on the stack itself. - OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool + OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) // OnAutoGenAddressDeprecated is called when an auto-generated address (SLAAC) // is deprecated, but is still considered valid. Note, if an address is @@ -470,6 +448,11 @@ type timer struct { timer tcpip.Timer } +type offLinkRoute struct { + dest tcpip.Subnet + router tcpip.Address +} + // ndpState is the per-Interface NDP state. type ndpState struct { // Do not allow overwriting this state. @@ -484,8 +467,8 @@ type ndpState struct { // The DAD timers to send the next NS message, or resolve the address. dad ip.DAD - // The default routers discovered through Router Advertisements. - defaultRouters map[tcpip.Address]defaultRouterState + // The off-link routes discovered through Router Advertisements. + offLinkRoutes map[offLinkRoute]offLinkRouteState // rtrSolicitTimer is the timer used to send the next router solicitation // message. @@ -513,10 +496,12 @@ type ndpState struct { temporaryAddressDesyncFactor time.Duration } -// defaultRouterState holds data associated with a default router discovered by +// offLinkRouteState holds data associated with an off-link route discovered by // a Router Advertisement (RA). -type defaultRouterState struct { - // Job to invalidate the default router. +type offLinkRouteState struct { + prf header.NDPRoutePreference + + // Job to invalidate the route. // // Must not be nil. invalidationJob *tcpip.Job @@ -549,7 +534,7 @@ type tempSLAACAddrState struct { // Must not be nil. regenJob *tcpip.Job - createdAt time.Time + createdAt tcpip.MonotonicTime // The address's endpoint. // @@ -572,11 +557,11 @@ type slaacPrefixState struct { // Must not be nil. invalidationJob *tcpip.Job - // Nonzero only when the address is not valid forever. - validUntil time.Time + // nil iff the address is valid forever. + validUntil *tcpip.MonotonicTime - // Nonzero only when the address is not preferred forever. - preferredUntil time.Time + // nil iff the address is preferred forever. + preferredUntil *tcpip.MonotonicTime // State associated with the stable address generated for the prefix. stableAddr struct { @@ -734,30 +719,22 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { // Is the IPv6 endpoint configured to discover default routers? if ndp.configs.DiscoverDefaultRouters { - rtr, ok := ndp.defaultRouters[ip] - rl := ra.RouterLifetime() - switch { - case !ok && rl != 0: - // This is a new default router we are discovering. + prf := ra.DefaultRouterPreference() + if prf == header.ReservedRoutePreference { + // As per RFC 4191 section 2.2, // - // Only remember it if we currently know about less than - // MaxDiscoveredDefaultRouters routers. - if len(ndp.defaultRouters) < MaxDiscoveredDefaultRouters { - ndp.rememberDefaultRouter(ip, rl) - } - - case ok && rl != 0: - // This is an already discovered default router. Update - // the invalidation job. - rtr.invalidationJob.Cancel() - rtr.invalidationJob.Schedule(rl) - ndp.defaultRouters[ip] = rtr - - case ok && rl == 0: - // We know about the router but it is no longer to be - // used as a default router so invalidate it. - ndp.invalidateDefaultRouter(ip) + // Prf (Default Router Preference) + // + // If the Reserved (10) value is received, the receiver MUST treat the + // value as if it were (00). + // + // Note that the value 00 is the medium (default) router preference value. + prf = header.MediumRoutePreference } + + // We represent default routers with a default (off-link) route through the + // router. + ndp.handleOffLinkRouteDiscovery(offLinkRoute{dest: header.IPv6EmptySubnet, router: ip}, ra.RouterLifetime(), prf) } // TODO(b/141556115): Do (RetransTimer, ReachableTime)) Parameter @@ -815,55 +792,75 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { } } -// invalidateDefaultRouter invalidates a discovered default router. +// invalidateOffLinkRoute invalidates a discovered off-link route. // // The IPv6 endpoint that ndp belongs to MUST be locked. -func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { - rtr, ok := ndp.defaultRouters[ip] - - // Is the router still discovered? +func (ndp *ndpState) invalidateOffLinkRoute(route offLinkRoute) { + state, ok := ndp.offLinkRoutes[route] if !ok { - // ...Nope, do nothing further. return } - rtr.invalidationJob.Cancel() - delete(ndp.defaultRouters, ip) + state.invalidationJob.Cancel() + delete(ndp.offLinkRoutes, route) - // Let the integrator know a discovered default router is invalidated. + // Let the integrator know a discovered off-link route is invalidated. if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { - ndpDisp.OnDefaultRouterInvalidated(ndp.ep.nic.ID(), ip) + ndpDisp.OnOffLinkRouteInvalidated(ndp.ep.nic.ID(), route.dest, route.router) } } -// rememberDefaultRouter remembers a newly discovered default router with IPv6 -// link-local address ip with lifetime rl. -// -// The router identified by ip MUST NOT already be known by the IPv6 endpoint. +// handleOffLinkRouteDiscovery handles the discovery of an off-link route. // -// The IPv6 endpoint that ndp belongs to MUST be locked. -func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { +// Precondition: ndp.ep.mu must be locked. +func (ndp *ndpState) handleOffLinkRouteDiscovery(route offLinkRoute, lifetime time.Duration, prf header.NDPRoutePreference) { ndpDisp := ndp.ep.protocol.options.NDPDisp if ndpDisp == nil { return } - // Inform the integrator when we discovered a default router. - if !ndpDisp.OnDefaultRouterDiscovered(ndp.ep.nic.ID(), ip) { - // Informed by the integrator to not remember the router, do - // nothing further. - return - } + state, ok := ndp.offLinkRoutes[route] + switch { + case !ok && lifetime != 0: + // This is a new route we are discovering. + // + // Only remember it if we currently know about less than + // MaxDiscoveredOffLinkRoutes routers. + if len(ndp.offLinkRoutes) < MaxDiscoveredOffLinkRoutes { + // Inform the integrator when we discovered an off-link route. + ndpDisp.OnOffLinkRouteUpdated(ndp.ep.nic.ID(), route.dest, route.router, prf) + + state := offLinkRouteState{ + prf: prf, + invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { + ndp.invalidateOffLinkRoute(route) + }), + } - state := defaultRouterState{ - invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { - ndp.invalidateDefaultRouter(ip) - }), - } + state.invalidationJob.Schedule(lifetime) + + ndp.offLinkRoutes[route] = state + } + + case ok && lifetime != 0: + // This is an already discovered off-link route. Update the lifetime. + state.invalidationJob.Cancel() + state.invalidationJob.Schedule(lifetime) + + if prf != state.prf { + state.prf = prf - state.invalidationJob.Schedule(rl) + // Inform the integrator about route preference updates. + ndpDisp.OnOffLinkRouteUpdated(ndp.ep.nic.ID(), route.dest, route.router, prf) + } + + ndp.offLinkRoutes[route] = state - ndp.defaultRouters[ip] = state + case ok && lifetime == 0: + // The already discovered off-link route is no longer considered valid so we + // invalidate it immediately. + ndp.invalidateOffLinkRoute(route) + } } // rememberOnLinkPrefix remembers a newly discovered on-link prefix with IPv6 @@ -879,11 +876,7 @@ func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) } // Inform the integrator when we discovered an on-link prefix. - if !ndpDisp.OnOnLinkPrefixDiscovered(ndp.ep.nic.ID(), prefix) { - // Informed by the integrator to not remember the prefix, do - // nothing further. - return - } + ndpDisp.OnOnLinkPrefixDiscovered(ndp.ep.nic.ID(), prefix) state := onLinkPrefixState{ invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { @@ -1051,12 +1044,13 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { maxGenerationAttempts: ndp.configs.AutoGenAddressConflictRetries + 1, } - now := time.Now() + now := ndp.ep.protocol.stack.Clock().NowMonotonic() // The time an address is preferred until is needed to properly generate the // address. if pl < header.NDPInfiniteLifetime { - state.preferredUntil = now.Add(pl) + t := now.Add(pl) + state.preferredUntil = &t } if !ndp.generateSLAACAddr(prefix, &state) { @@ -1074,7 +1068,8 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { if vl < header.NDPInfiniteLifetime { state.invalidationJob.Schedule(vl) - state.validUntil = now.Add(vl) + t := now.Add(vl) + state.validUntil = &t } // If the address is assigned (DAD resolved), generate a temporary address. @@ -1097,16 +1092,13 @@ func (ndp *ndpState) addAndAcquireSLAACAddr(addr tcpip.AddressWithPrefix, config return nil } - if !ndpDisp.OnAutoGenAddress(ndp.ep.nic.ID(), addr) { - // Informed by the integrator not to add the address. - return nil - } - addressEndpoint, err := ndp.ep.addAndAcquirePermanentAddressLocked(addr, stack.FirstPrimaryEndpoint, configType, deprecated) if err != nil { panic(fmt.Sprintf("ndp: error when adding SLAAC address %+v: %s", addr, err)) } + ndpDisp.OnAutoGenAddress(ndp.ep.nic.ID(), addr) + return addressEndpoint } @@ -1182,7 +1174,8 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt state.stableAddr.localGenerationFailures++ } - if addressEndpoint := ndp.addAndAcquireSLAACAddr(generatedAddr, stack.AddressConfigSlaac, time.Since(state.preferredUntil) >= 0 /* deprecated */); addressEndpoint != nil { + deprecated := state.preferredUntil != nil && !state.preferredUntil.After(ndp.ep.protocol.stack.Clock().NowMonotonic()) + if addressEndpoint := ndp.addAndAcquireSLAACAddr(generatedAddr, stack.AddressConfigSlaac, deprecated); addressEndpoint != nil { state.stableAddr.addressEndpoint = addressEndpoint state.generationAttempts++ return true @@ -1237,13 +1230,13 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla } stableAddr := prefixState.stableAddr.addressEndpoint.AddressWithPrefix().Address - now := time.Now() + now := ndp.ep.protocol.stack.Clock().NowMonotonic() // As per RFC 4941 section 3.3 step 4, the valid lifetime of a temporary // address is the lower of the valid lifetime of the stable address or the // maximum temporary address valid lifetime. vl := ndp.configs.MaxTempAddrValidLifetime - if prefixState.validUntil != (time.Time{}) { + if prefixState.validUntil != nil { if prefixVL := prefixState.validUntil.Sub(now); vl > prefixVL { vl = prefixVL } @@ -1259,7 +1252,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla // maximum temporary address preferred lifetime - the temporary address desync // factor. pl := ndp.configs.MaxTempAddrPreferredLifetime - ndp.temporaryAddressDesyncFactor - if prefixState.preferredUntil != (time.Time{}) { + if prefixState.preferredUntil != nil { if prefixPL := prefixState.preferredUntil.Sub(now); pl > prefixPL { // Respect the preferred lifetime of the prefix, as per RFC 4941 section // 3.3 step 4. @@ -1394,16 +1387,17 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat // deprecation job so it can be reset. prefixState.deprecationJob.Cancel() - now := time.Now() + now := ndp.ep.protocol.stack.Clock().NowMonotonic() // Schedule the deprecation job if prefix has a finite preferred lifetime. if pl < header.NDPInfiniteLifetime { if !deprecated { prefixState.deprecationJob.Schedule(pl) } - prefixState.preferredUntil = now.Add(pl) + t := now.Add(pl) + prefixState.preferredUntil = &t } else { - prefixState.preferredUntil = time.Time{} + prefixState.preferredUntil = nil } // As per RFC 4862 section 5.5.3.e, update the valid lifetime for prefix: @@ -1421,17 +1415,17 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat // Handle the infinite valid lifetime separately as we do not schedule a // job in this case. prefixState.invalidationJob.Cancel() - prefixState.validUntil = time.Time{} + prefixState.validUntil = nil } else { var effectiveVl time.Duration var rl time.Duration // If the prefix was originally set to be valid forever, assume the // remaining time to be the maximum possible value. - if prefixState.validUntil == (time.Time{}) { + if prefixState.validUntil == nil { rl = header.NDPInfiniteLifetime } else { - rl = time.Until(prefixState.validUntil) + rl = prefixState.validUntil.Sub(now) } if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl { @@ -1443,7 +1437,8 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat if effectiveVl != 0 { prefixState.invalidationJob.Cancel() prefixState.invalidationJob.Schedule(effectiveVl) - prefixState.validUntil = now.Add(effectiveVl) + t := now.Add(effectiveVl) + prefixState.validUntil = &t } } @@ -1463,8 +1458,8 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat // maximum temporary address valid lifetime. Note, the valid lifetime of a // temporary address is relative to the address's creation time. validUntil := tempAddrState.createdAt.Add(ndp.configs.MaxTempAddrValidLifetime) - if prefixState.validUntil != (time.Time{}) && validUntil.Sub(prefixState.validUntil) > 0 { - validUntil = prefixState.validUntil + if prefixState.validUntil != nil && prefixState.validUntil.Before(validUntil) { + validUntil = *prefixState.validUntil } // If the address is no longer valid, invalidate it immediately. Otherwise, @@ -1483,14 +1478,15 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat // desync factor. Note, the preferred lifetime of a temporary address is // relative to the address's creation time. preferredUntil := tempAddrState.createdAt.Add(ndp.configs.MaxTempAddrPreferredLifetime - ndp.temporaryAddressDesyncFactor) - if prefixState.preferredUntil != (time.Time{}) && preferredUntil.Sub(prefixState.preferredUntil) > 0 { - preferredUntil = prefixState.preferredUntil + if prefixState.preferredUntil != nil && prefixState.preferredUntil.Before(preferredUntil) { + preferredUntil = *prefixState.preferredUntil } // If the address is no longer preferred, deprecate it immediately. // Otherwise, schedule the deprecation job again. newPreferredLifetime := preferredUntil.Sub(now) tempAddrState.deprecationJob.Cancel() + if newPreferredLifetime <= 0 { ndp.deprecateSLAACAddress(tempAddrState.addressEndpoint) } else { @@ -1680,12 +1676,12 @@ func (ndp *ndpState) cleanupState() { panic(fmt.Sprintf("ndp: still have discovered on-link prefixes after cleaning up; found = %d", got)) } - for router := range ndp.defaultRouters { - ndp.invalidateDefaultRouter(router) + for route := range ndp.offLinkRoutes { + ndp.invalidateOffLinkRoute(route) } - if got := len(ndp.defaultRouters); got != 0 { - panic(fmt.Sprintf("ndp: still have discovered default routers after cleaning up; found = %d", got)) + if got := len(ndp.offLinkRoutes); got != 0 { + panic(fmt.Sprintf("ndp: still have discovered off-link routes after cleaning up; found = %d", got)) } ndp.dhcpv6Configuration = 0 @@ -1718,7 +1714,7 @@ func (ndp *ndpState) startSolicitingRouters() { // 4861 section 6.3.7. var delay time.Duration if ndp.configs.MaxRtrSolicitationDelay > 0 { - delay = time.Duration(rand.Int63n(int64(ndp.configs.MaxRtrSolicitationDelay))) + delay = time.Duration(ndp.ep.protocol.stack.Rand().Int63n(int64(ndp.configs.MaxRtrSolicitationDelay))) } // Protected by ndp.ep.mu. @@ -1754,7 +1750,7 @@ func (ndp *ndpState) startSolicitingRouters() { header.NDPSourceLinkLayerAddressOption(linkAddress), } } - payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + int(optsSerializer.Length()) + payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + optsSerializer.Length() icmpData := header.ICMPv6(buffer.NewView(payloadSize)) icmpData.SetType(header.ICMPv6RouterSolicit) rs := header.NDPRouterSolicit(icmpData.MessageBody()) @@ -1848,21 +1844,19 @@ func (ndp *ndpState) stopSolicitingRouters() { } func (ndp *ndpState) init(ep *endpoint, dadOptions ip.DADOptions) { - if ndp.defaultRouters != nil { + if ndp.offLinkRoutes != nil { panic("attempted to initialize NDP state twice") } ndp.ep = ep ndp.configs = ep.protocol.options.NDPConfigs ndp.dad.Init(&ndp.ep.mu, ep.protocol.options.DADConfigs, dadOptions) - ndp.defaultRouters = make(map[tcpip.Address]defaultRouterState) + ndp.offLinkRoutes = make(map[offLinkRoute]offLinkRouteState) ndp.onLinkPrefixes = make(map[tcpip.Subnet]onLinkPrefixState) ndp.slaacPrefixes = make(map[tcpip.Subnet]slaacPrefixState) header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.options.TempIIDSeed, ndp.ep.nic.ID()) - if MaxDesyncFactor != 0 { - ndp.temporaryAddressDesyncFactor = time.Duration(rand.Int63n(int64(MaxDesyncFactor))) - } + ndp.temporaryAddressDesyncFactor = time.Duration(ep.protocol.stack.Rand().Int63n(int64(MaxDesyncFactor))) } func (ndp *ndpState) SendDADMessage(addr tcpip.Address, nonce []byte) tcpip.Error { diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 234e34952..f0186c64e 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -16,7 +16,7 @@ package ipv6 import ( "bytes" - "context" + "math/rand" "strings" "testing" "time" @@ -32,58 +32,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" ) -// setupStackAndEndpoint creates a stack with a single NIC with a link-local -// address llladdr and an IPv6 endpoint to a remote with link-local address -// rlladdr -func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack.Stack, stack.NetworkEndpoint) { - t.Helper() - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, - }) - - if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - { - subnet, err := tcpip.NewSubnet(rlladdr, tcpip.AddressMask(strings.Repeat("\xff", len(rlladdr)))) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable( - []tcpip.Route{{ - Destination: subnet, - NIC: 1, - }}, - ) - } - - netProto := s.NetworkProtocolInstance(ProtocolNumber) - if netProto == nil { - t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) - } - - ep := netProto.NewEndpoint(&testInterface{}, &stubDispatcher{}) - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } - t.Cleanup(ep.Close) - - addressableEndpoint, ok := ep.(stack.AddressableEndpoint) - if !ok { - t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint") - } - addr := llladdr.WithPrefix() - if addressEP, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) - } else { - addressEP.DecRef() - } - - return s, ep -} - var _ NDPDispatcher = (*testNDPDispatcher)(nil) // testNDPDispatcher is an NDPDispatcher only allows default router discovery. @@ -94,24 +42,21 @@ type testNDPDispatcher struct { func (*testNDPDispatcher) OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult) { } -func (t *testNDPDispatcher) OnDefaultRouterDiscovered(_ tcpip.NICID, addr tcpip.Address) bool { +func (t *testNDPDispatcher) OnOffLinkRouteUpdated(_ tcpip.NICID, _ tcpip.Subnet, addr tcpip.Address, _ header.NDPRoutePreference) { t.addr = addr - return true } -func (t *testNDPDispatcher) OnDefaultRouterInvalidated(_ tcpip.NICID, addr tcpip.Address) { +func (t *testNDPDispatcher) OnOffLinkRouteInvalidated(_ tcpip.NICID, _ tcpip.Subnet, addr tcpip.Address) { t.addr = addr } -func (*testNDPDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool { - return false +func (*testNDPDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) { } func (*testNDPDispatcher) OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) { } -func (*testNDPDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool { - return false +func (*testNDPDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) { } func (*testNDPDispatcher) OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) { @@ -148,7 +93,7 @@ func TestStackNDPEndpointInvalidateDefaultRouter(t *testing.T) { ipv6EP := ep.(*endpoint) ipv6EP.mu.Lock() - ipv6EP.mu.ndp.rememberDefaultRouter(lladdr1, time.Hour) + ipv6EP.mu.ndp.handleOffLinkRouteDiscovery(offLinkRoute{dest: header.IPv6EmptySubnet, router: lladdr1}, time.Hour, header.MediumRoutePreference) ipv6EP.mu.Unlock() if ndpDisp.addr != lladdr1 { @@ -163,11 +108,6 @@ func TestStackNDPEndpointInvalidateDefaultRouter(t *testing.T) { } } -type linkResolutionResult struct { - linkAddr tcpip.LinkAddress - ok bool -} - // TestNeighborSolicitationWithSourceLinkLayerOption tests that receiving a // valid NDP NS message with the Source Link Layer Address option results in a // new entry in the link address cache for the sender of the message. @@ -456,8 +396,10 @@ func TestNeighborSolicitationResponse(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + Clock: clock, }) e := channel.New(1, 1280, nicLinkAddr) e.LinkEPCapabilities |= stack.CapabilityResolutionRequired @@ -527,7 +469,8 @@ func TestNeighborSolicitationResponse(t *testing.T) { } if test.performsLinkResolution { - p, got := e.ReadContext(context.Background()) + clock.RunImmediatelyScheduledJobs() + p, got := e.Read() if !got { t.Fatal("expected an NDP NS response") } @@ -582,7 +525,8 @@ func TestNeighborSolicitationResponse(t *testing.T) { })) } - p, got := e.ReadContext(context.Background()) + clock.RunImmediatelyScheduledJobs() + p, got := e.Read() if !got { t.Fatal("expected an NDP NA response") } @@ -906,12 +850,12 @@ func TestNDPValidation(t *testing.T) { routerOnly := stats.RouterOnlyPacketsDroppedByHost typStat := typ.statCounter(stats) - icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) - copy(icmp[typ.size:], typ.extraData) - icmp.SetType(typ.typ) - icmp.SetCode(test.code) - icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: icmp[:typ.size], + icmpH := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) + copy(icmpH[typ.size:], typ.extraData) + icmpH.SetType(typ.typ) + icmpH.SetCode(test.code) + icmpH.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmpH[:typ.size], Src: lladdr0, Dst: lladdr1, PayloadCsum: header.Checksum(typ.extraData /* initial */, 0), @@ -937,7 +881,7 @@ func TestNDPValidation(t *testing.T) { t.FailNow() } - handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep) + handleIPv6Payload(buffer.View(icmpH), test.hopLimit, test.atomicFragment, ep) // Rx count of the NDP packet should have increased. if got := typStat.Value(); got != 1 { @@ -1297,8 +1241,9 @@ func TestCheckDuplicateAddress(t *testing.T) { var secureRNG bytes.Reader secureRNG.Reset(secureRNGBytes[:]) s := stack.New(stack.Options{ - SecureRNG: &secureRNG, - Clock: clock, + Clock: clock, + RandSource: rand.NewSource(time.Now().UnixNano()), + SecureRNG: &secureRNG, NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocolWithOptions(Options{ DADConfigs: dadConfigs, })}, @@ -1315,7 +1260,8 @@ func TestCheckDuplicateAddress(t *testing.T) { snmc := header.SolicitedNodeAddr(lladdr0) remoteLinkAddr := header.EthernetAddressFromMulticastIPv6Address(snmc) checkDADMsg := func() { - p, ok := e.ReadContext(context.Background()) + clock.RunImmediatelyScheduledJobs() + p, ok := e.Read() if !ok { t.Fatalf("expected %d-th DAD message", dadPacketsSent) } @@ -1363,7 +1309,7 @@ func TestCheckDuplicateAddress(t *testing.T) { t.Fatalf("RemoveAddress(%d, %s): %s", nicID, lladdr0, err) } // Should not restart DAD since we already requested DAD above - the handler - // should be called when the original request compeletes so we should not send + // should be called when the original request completes so we should not send // an extra DAD message here. dadRequestsMade++ if res, err := s.CheckDuplicateAddress(nicID, ProtocolNumber, lladdr0, func(r stack.DADResult) { diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go index b5b013b64..854d6a6ba 100644 --- a/pkg/tcpip/ports/ports.go +++ b/pkg/tcpip/ports/ports.go @@ -101,7 +101,7 @@ func (dc destToCounter) intersectionFlags(res Reservation) (BitFlags, int) { // Wildcard destinations affect all destinations for TupleOnly. if dest.addr == anyIPAddress || res.Dest.Addr == anyIPAddress { // Only bitwise and the TupleOnlyFlag. - intersection &= ((^TupleOnlyFlag) | counter.SharedFlags()) + intersection &= (^TupleOnlyFlag) | counter.SharedFlags() count++ } } @@ -238,13 +238,13 @@ type PortTester func(port uint16) (good bool, err tcpip.Error) // possible ephemeral ports, allowing the caller to decide whether a given port // is suitable for its needs, and stopping when a port is found or an error // occurs. -func (pm *PortManager) PickEphemeralPort(testPort PortTester) (port uint16, err tcpip.Error) { +func (pm *PortManager) PickEphemeralPort(rng *rand.Rand, testPort PortTester) (port uint16, err tcpip.Error) { pm.ephemeralMu.RLock() firstEphemeral := pm.firstEphemeral numEphemeral := pm.numEphemeral pm.ephemeralMu.RUnlock() - offset := uint32(rand.Int31n(int32(numEphemeral))) + offset := uint32(rng.Int31n(int32(numEphemeral))) return pickEphemeralPort(offset, firstEphemeral, numEphemeral, testPort) } @@ -303,7 +303,7 @@ func pickEphemeralPort(offset uint32, first, count uint16, testPort PortTester) // An optional PortTester can be passed in which if provided will be used to // test if the picked port can be used. The function should return true if the // port is safe to use, false otherwise. -func (pm *PortManager) ReservePort(res Reservation, testPort PortTester) (reservedPort uint16, err tcpip.Error) { +func (pm *PortManager) ReservePort(rng *rand.Rand, res Reservation, testPort PortTester) (reservedPort uint16, err tcpip.Error) { pm.mu.Lock() defer pm.mu.Unlock() @@ -328,7 +328,7 @@ func (pm *PortManager) ReservePort(res Reservation, testPort PortTester) (reserv } // A port wasn't specified, so try to find one. - return pm.PickEphemeralPort(func(p uint16) (bool, tcpip.Error) { + return pm.PickEphemeralPort(rng, func(p uint16) (bool, tcpip.Error) { res.Port = p if !pm.reserveSpecificPortLocked(res) { return false, nil diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go index 6c4fb8c68..a91b130df 100644 --- a/pkg/tcpip/ports/ports_test.go +++ b/pkg/tcpip/ports/ports_test.go @@ -18,6 +18,7 @@ import ( "math" "math/rand" "testing" + "time" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" @@ -331,6 +332,7 @@ func TestPortReservation(t *testing.T) { t.Run(test.tname, func(t *testing.T) { pm := NewPortManager() net := []tcpip.NetworkProtocolNumber{fakeNetworkNumber} + rng := rand.New(rand.NewSource(time.Now().UnixNano())) for _, test := range test.actions { first, _ := pm.PortRange() @@ -356,7 +358,7 @@ func TestPortReservation(t *testing.T) { BindToDevice: test.device, Dest: test.dest, } - gotPort, err := pm.ReservePort(portRes, nil /* testPort */) + gotPort, err := pm.ReservePort(rng, portRes, nil /* testPort */) if diff := cmp.Diff(test.want, err); diff != "" { t.Fatalf("unexpected error from ReservePort(%+v, _), (-want, +got):\n%s", portRes, diff) } @@ -417,10 +419,11 @@ func TestPickEphemeralPort(t *testing.T) { } { t.Run(test.name, func(t *testing.T) { pm := NewPortManager() + rng := rand.New(rand.NewSource(time.Now().UnixNano())) if err := pm.SetPortRange(firstEphemeral, firstEphemeral+numEphemeralPorts); err != nil { t.Fatalf("failed to set ephemeral port range: %s", err) } - port, err := pm.PickEphemeralPort(test.f) + port, err := pm.PickEphemeralPort(rng, test.f) if diff := cmp.Diff(test.wantErr, err); diff != "" { t.Fatalf("unexpected error from PickEphemeralPort(..), (-want, +got):\n%s", diff) } diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go index b26936b7f..0ea85f9ed 100644 --- a/pkg/tcpip/socketops.go +++ b/pkg/tcpip/socketops.go @@ -222,7 +222,7 @@ type SocketOptions struct { getReceiveBufferLimits GetReceiveBufferLimits `state:"manual"` // receiveBufferSize determines the receive buffer size for this socket. - receiveBufferSize int64 + receiveBufferSize atomicbitops.AlignedAtomicInt64 // mu protects the access to the below fields. mu sync.Mutex `state:"nosave"` @@ -601,9 +601,10 @@ func (so *SocketOptions) GetBindToDevice() int32 { return atomic.LoadInt32(&so.bindToDevice) } -// SetBindToDevice sets value for SO_BINDTODEVICE option. +// SetBindToDevice sets value for SO_BINDTODEVICE option. If bindToDevice is +// zero, the socket device binding is removed. func (so *SocketOptions) SetBindToDevice(bindToDevice int32) Error { - if !so.handler.HasNIC(bindToDevice) { + if bindToDevice != 0 && !so.handler.HasNIC(bindToDevice) { return &ErrUnknownDevice{} } @@ -653,13 +654,13 @@ func (so *SocketOptions) SetSendBufferSize(sendBufferSize int64, notify bool) { // GetReceiveBufferSize gets value for SO_RCVBUF option. func (so *SocketOptions) GetReceiveBufferSize() int64 { - return atomic.LoadInt64(&so.receiveBufferSize) + return so.receiveBufferSize.Load() } // SetReceiveBufferSize sets value for SO_RCVBUF option. func (so *SocketOptions) SetReceiveBufferSize(receiveBufferSize int64, notify bool) { if !notify { - atomic.StoreInt64(&so.receiveBufferSize, receiveBufferSize) + so.receiveBufferSize.Store(receiveBufferSize) return } @@ -684,8 +685,8 @@ func (so *SocketOptions) SetReceiveBufferSize(receiveBufferSize int64, notify bo v = math.MaxInt32 } - oldSz := atomic.LoadInt64(&so.receiveBufferSize) + oldSz := so.receiveBufferSize.Load() // Notify endpoint about change in buffer size. newSz := so.handler.OnSetReceiveBufferSize(v, oldSz) - atomic.StoreInt64(&so.receiveBufferSize, newSz) + so.receiveBufferSize.Store(newSz) } diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 84aa6a9e4..e0847e58a 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -56,6 +56,7 @@ go_library( "neighbor_entry_list.go", "neighborstate_string.go", "nic.go", + "nic_stats.go", "nud.go", "packet_buffer.go", "packet_buffer_list.go", @@ -94,7 +95,7 @@ go_library( go_test( name = "stack_x_test", - size = "medium", + size = "small", srcs = [ "addressable_endpoint_state_test.go", "ndp_test.go", diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 5720e7543..782e74b24 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -35,7 +35,6 @@ import ( // Currently, only TCP tracking is supported. // Our hash table has 16K buckets. -// TODO(gvisor.dev/issue/170): These should be tunable. const numBuckets = 1 << 14 // Direction of the tuple. @@ -125,6 +124,8 @@ type conn struct { tcb tcpconntrack.TCB // lastUsed is the last time the connection saw a relevant packet, and // is updated by each packet on the connection. It is protected by mu. + // + // TODO(gvisor.dev/issue/5939): do not use the ambient clock. lastUsed time.Time `state:".(unixTime)"` } @@ -163,8 +164,6 @@ func (cn *conn) updateLocked(tcpHeader header.TCP, hook Hook) { // Update the state of tcb. tcb assumes it's always initialized on the // client. However, we only need to know whether the connection is // established or not, so the client/server distinction isn't important. - // TODO(gvisor.dev/issue/170): Add support in tcpconntrack to handle - // other tcp states. if cn.tcb.IsEmpty() { cn.tcb.Init(tcpHeader) } else if hook == cn.tcbHook { @@ -244,8 +243,7 @@ func (ct *ConnTrack) init() { // connFor gets the conn for pkt if it exists, or returns nil // if it does not. It returns an error when pkt does not contain a valid TCP // header. -// TODO(gvisor.dev/issue/170): Only TCP packets are supported. Need to support -// other transport protocols. +// TODO(gvisor.dev/issue/6168): Support UDP. func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) { tid, err := packetToTupleID(pkt) if err != nil { @@ -383,7 +381,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { return false } - // TODO(gvisor.dev/issue/170): Support other transport protocols. + // TODO(gvisor.dev/issue/6168): Support UDP. if pkt.Network().TransportProtocol() != header.TCPProtocolNumber { return false } @@ -407,16 +405,23 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { // validated if checksum offloading is off. It may require IP defrag if the // packets are fragmented. + var newAddr tcpip.Address + var newPort uint16 + + updateSRCFields := false + switch hook { case Prerouting, Output: if conn.manip == manipDestination { switch dir { case dirOriginal: - tcpHeader.SetDestinationPort(conn.reply.srcPort) - netHeader.SetDestinationAddress(conn.reply.srcAddr) + newPort = conn.reply.srcPort + newAddr = conn.reply.srcAddr case dirReply: - tcpHeader.SetSourcePort(conn.original.dstPort) - netHeader.SetSourceAddress(conn.original.dstAddr) + newPort = conn.original.dstPort + newAddr = conn.original.dstAddr + + updateSRCFields = true } pkt.NatDone = true } @@ -424,11 +429,13 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { if conn.manip == manipSource { switch dir { case dirOriginal: - tcpHeader.SetSourcePort(conn.reply.dstPort) - netHeader.SetSourceAddress(conn.reply.dstAddr) + newPort = conn.reply.dstPort + newAddr = conn.reply.dstAddr + + updateSRCFields = true case dirReply: - tcpHeader.SetDestinationPort(conn.original.srcPort) - netHeader.SetDestinationAddress(conn.original.srcAddr) + newPort = conn.original.srcPort + newAddr = conn.original.srcAddr } pkt.NatDone = true } @@ -439,33 +446,33 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { return false } + fullChecksum := false + updatePseudoHeader := false switch hook { case Prerouting, Input: case Output, Postrouting: // Calculate the TCP checksum and set it. - tcpHeader.SetChecksum(0) - length := uint16(len(tcpHeader) + pkt.Data().Size()) - xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) if pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum { - tcpHeader.SetChecksum(xsum) + updatePseudoHeader = true } else if r.RequiresTXTransportChecksum() { - xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) - tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum)) + fullChecksum = true + updatePseudoHeader = true } default: panic(fmt.Sprintf("unrecognized hook = %s", hook)) } - // After modification, IPv4 packets need a valid checksum. - if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { - netHeader := header.IPv4(pkt.NetworkHeader().View()) - netHeader.SetChecksum(0) - netHeader.SetChecksum(^netHeader.CalculateChecksum()) - } + rewritePacket( + netHeader, + tcpHeader, + updateSRCFields, + fullChecksum, + updatePseudoHeader, + newPort, + newAddr, + ) // Update the state of tcb. - // TODO(gvisor.dev/issue/170): Add support in tcpcontrack to handle - // other tcp states. conn.mu.Lock() defer conn.mu.Unlock() @@ -542,8 +549,6 @@ func (ct *ConnTrack) bucket(id tupleID) int { // reapUnused returns the next bucket that should be checked and the time after // which it should be called again. func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, time.Duration) { - // TODO(gvisor.dev/issue/170): This can be more finely controlled, as - // it is in Linux via sysctl. const fractionPerReaping = 128 const maxExpiredPct = 50 const maxFullTraversal = 60 * time.Second diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index 7107d598d..72f66441f 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -114,10 +115,6 @@ func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { return f.nic.MaxHeaderLength() + fwdTestNetHeaderLen } -func (*fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { - return 0 -} - func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return f.proto.Number() } @@ -134,7 +131,7 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, params NetworkHeaderParam } // WritePackets implements LinkEndpoint.WritePackets. -func (*fwdTestNetworkEndpoint) WritePackets(r *Route, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) { +func (*fwdTestNetworkEndpoint) WritePackets(*Route, PacketBufferList, NetworkHeaderParams) (int, tcpip.Error) { panic("not implemented") } @@ -224,7 +221,7 @@ func (*fwdTestNetworkProtocol) Wait() {} func (f *fwdTestNetworkEndpoint) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error { if fn := f.proto.onLinkAddressResolved; fn != nil { - time.AfterFunc(f.proto.addrResolveDelay, func() { + f.proto.stack.clock.AfterFunc(f.proto.addrResolveDelay, func() { fn(f.proto.neigh, addr, remoteLinkAddr) }) } @@ -319,7 +316,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr } -func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { +func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { p := fwdTestPacketInfo{ RemoteLinkAddress: r.RemoteLinkAddress, LocalLinkAddress: r.LocalLinkAddress, @@ -354,17 +351,19 @@ func (*fwdTestLinkEndpoint) ARPHardwareType() header.ARPHardwareType { } // AddHeader implements stack.LinkEndpoint.AddHeader. -func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { +func (e *fwdTestLinkEndpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *PacketBuffer) { panic("not implemented") } -func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) { +func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (*faketime.ManualClock, *fwdTestLinkEndpoint, *fwdTestLinkEndpoint) { + clock := faketime.NewManualClock() // Create a stack with the network protocol and two NICs. s := New(Options{ NetworkProtocols: []NetworkProtocolFactory{func(s *Stack) NetworkProtocol { proto.stack = s return proto }}, + Clock: clock, }) protoNum := proto.Number() @@ -373,7 +372,7 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f } // NIC 1 has the link address "a", and added the network address 1. - ep1 = &fwdTestLinkEndpoint{ + ep1 := &fwdTestLinkEndpoint{ C: make(chan fwdTestPacketInfo, 300), mtu: fwdTestNetDefaultMTU, linkAddr: "a", @@ -386,7 +385,7 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f } // NIC 2 has the link address "b", and added the network address 2. - ep2 = &fwdTestLinkEndpoint{ + ep2 := &fwdTestLinkEndpoint{ C: make(chan fwdTestPacketInfo, 300), mtu: fwdTestNetDefaultMTU, linkAddr: "b", @@ -416,7 +415,7 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: 2}}) } - return ep1, ep2 + return clock, ep1, ep2 } func TestForwardingWithStaticResolver(t *testing.T) { @@ -432,7 +431,7 @@ func TestForwardingWithStaticResolver(t *testing.T) { }, } - ep1, ep2 := fwdTestNetFactory(t, proto) + clock, ep1, ep2 := fwdTestNetFactory(t, proto) // Inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. @@ -444,6 +443,7 @@ func TestForwardingWithStaticResolver(t *testing.T) { var p fwdTestPacketInfo + clock.Advance(proto.addrResolveDelay) select { case p = <-ep2.C: default: @@ -475,7 +475,7 @@ func TestForwardingWithFakeResolver(t *testing.T) { }) }, } - ep1, ep2 := fwdTestNetFactory(t, &proto) + clock, ep1, ep2 := fwdTestNetFactory(t, &proto) // Inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. @@ -487,9 +487,10 @@ func TestForwardingWithFakeResolver(t *testing.T) { var p fwdTestPacketInfo + clock.Advance(proto.addrResolveDelay) select { case p = <-ep2.C: - case <-time.After(time.Second): + default: t.Fatal("packet not forwarded") } @@ -508,7 +509,7 @@ func TestForwardingWithNoResolver(t *testing.T) { // Whether or not we use the neighbor cache here does not matter since // neither linkAddrCache nor neighborCache will be used. - ep1, ep2 := fwdTestNetFactory(t, proto) + clock, ep1, ep2 := fwdTestNetFactory(t, proto) // inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. @@ -518,10 +519,11 @@ func TestForwardingWithNoResolver(t *testing.T) { Data: buf.ToVectorisedView(), })) + clock.Advance(proto.addrResolveDelay) select { case <-ep2.C: t.Fatal("Packet should not be forwarded") - case <-time.After(time.Second): + default: } } @@ -533,7 +535,7 @@ func TestForwardingResolutionFailsForQueuedPackets(t *testing.T) { }, } - ep1, ep2 := fwdTestNetFactory(t, proto) + clock, ep1, ep2 := fwdTestNetFactory(t, proto) const numPackets int = 5 // These packets will all be enqueued in the packet queue to wait for link @@ -547,12 +549,12 @@ func TestForwardingResolutionFailsForQueuedPackets(t *testing.T) { } // All packets should fail resolution. - // TODO(gvisor.dev/issue/5141): Use a fake clock. for i := 0; i < numPackets; i++ { + clock.Advance(proto.addrResolveDelay) select { case got := <-ep2.C: t.Fatalf("got %#v; packets should have failed resolution and not been forwarded", got) - case <-time.After(100 * time.Millisecond): + default: } } } @@ -576,7 +578,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { } }, } - ep1, ep2 := fwdTestNetFactory(t, &proto) + clock, ep1, ep2 := fwdTestNetFactory(t, &proto) // Inject an inbound packet to address 4 on NIC 1. This packet should // not be forwarded. @@ -596,9 +598,10 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { var p fwdTestPacketInfo + clock.Advance(proto.addrResolveDelay) select { case p = <-ep2.C: - case <-time.After(time.Second): + default: t.Fatal("packet not forwarded") } @@ -631,7 +634,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { }) }, } - ep1, ep2 := fwdTestNetFactory(t, &proto) + clock, ep1, ep2 := fwdTestNetFactory(t, &proto) // Inject two inbound packets to address 3 on NIC 1. for i := 0; i < 2; i++ { @@ -645,9 +648,10 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { for i := 0; i < 2; i++ { var p fwdTestPacketInfo + clock.Advance(proto.addrResolveDelay) select { case p = <-ep2.C: - case <-time.After(time.Second): + default: t.Fatal("packet not forwarded") } @@ -681,7 +685,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { }) }, } - ep1, ep2 := fwdTestNetFactory(t, &proto) + clock, ep1, ep2 := fwdTestNetFactory(t, &proto) for i := 0; i < maxPendingPacketsPerResolution+5; i++ { // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1. @@ -697,9 +701,10 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { for i := 0; i < maxPendingPacketsPerResolution; i++ { var p fwdTestPacketInfo + clock.Advance(proto.addrResolveDelay) select { case p = <-ep2.C: - case <-time.After(time.Second): + default: t.Fatal("packet not forwarded") } @@ -745,7 +750,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { }) }, } - ep1, ep2 := fwdTestNetFactory(t, &proto) + clock, ep1, ep2 := fwdTestNetFactory(t, &proto) for i := 0; i < maxPendingResolutions+5; i++ { // Inject inbound 'maxPendingResolutions + 5' packets on NIC 1. @@ -761,9 +766,10 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { for i := 0; i < maxPendingResolutions; i++ { var p fwdTestPacketInfo + clock.Advance(proto.addrResolveDelay) select { case p = <-ep2.C: - case <-time.After(time.Second): + default: t.Fatal("packet not forwarded") } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 3670d5995..f152c0d83 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -42,7 +42,7 @@ const reaperDelay = 5 * time.Second // DefaultTables returns a default set of tables. Each chain is set to accept // all packets. -func DefaultTables() *IPTables { +func DefaultTables(seed uint32) *IPTables { return &IPTables{ v4Tables: [NumTables]Table{ NATID: { @@ -182,7 +182,7 @@ func DefaultTables() *IPTables { Postrouting: {MangleID, NATID}, }, connections: ConnTrack{ - seed: generateRandUint32(), + seed: seed, }, reaperDone: make(chan struct{}, 1), } @@ -268,10 +268,6 @@ const ( // should continue traversing the network stack and false when it should be // dropped. // -// TODO(gvisor.dev/issue/170): PacketBuffer should hold the route, from -// which address can be gathered. Currently, address is only needed for -// prerouting. -// // Precondition: pkt.NetworkHeader is set. func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) bool { if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber { @@ -371,6 +367,7 @@ func (it *IPTables) startReaper(interval time.Duration) { select { case <-it.reaperDone: return + // TODO(gvisor.dev/issue/5939): do not use the ambient clock. case <-time.After(interval): bucket, interval = it.connections.reapUnused(bucket, interval) } diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 2812c89aa..96cc899bb 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -87,9 +87,6 @@ func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Addre // destination port/IP. Outgoing packets are redirected to the loopback device, // and incoming packets are redirected to the incoming interface (rather than // forwarded). -// -// TODO(gvisor.dev/issue/170): Other flags need to be added after we support -// them. type RedirectTarget struct { // Port indicates port used to redirect. It is immutable. Port uint16 @@ -100,9 +97,6 @@ type RedirectTarget struct { } // Action implements Target.Action. -// TODO(gvisor.dev/issue/170): Parse headers without copying. The current -// implementation only works for Prerouting and calls pkt.Clone(), neither -// of which should be the case. func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, address tcpip.Address) (RuleVerdict, int) { // Sanity check. if rt.NetworkProtocol != pkt.NetworkProtocolNumber { @@ -136,34 +130,26 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r panic("redirect target is supported only on output and prerouting hooks") } - // TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if - // we need to change dest address (for OUTPUT chain) or ports. switch protocol := pkt.TransportProtocolNumber; protocol { case header.UDPProtocolNumber: udpHeader := header.UDP(pkt.TransportHeader().View()) - udpHeader.SetDestinationPort(rt.Port) - // Calculate UDP checksum and set it. if hook == Output { - udpHeader.SetChecksum(0) - netHeader := pkt.Network() - netHeader.SetDestinationAddress(address) - // Only calculate the checksum if offloading isn't supported. - if r.RequiresTXTransportChecksum() { - length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View())) - xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) - xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) - udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum)) - } + requiresChecksum := r.RequiresTXTransportChecksum() + rewritePacket( + pkt.Network(), + udpHeader, + false, /* updateSRCFields */ + requiresChecksum, + requiresChecksum, + rt.Port, + address, + ) + } else { + udpHeader.SetDestinationPort(rt.Port) } - // After modification, IPv4 packets need a valid checksum. - if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { - netHeader := header.IPv4(pkt.NetworkHeader().View()) - netHeader.SetChecksum(0) - netHeader.SetChecksum(^netHeader.CalculateChecksum()) - } pkt.NatDone = true case header.TCPProtocolNumber: if ct == nil { @@ -222,26 +208,18 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou switch protocol := pkt.TransportProtocolNumber; protocol { case header.UDPProtocolNumber: - udpHeader := header.UDP(pkt.TransportHeader().View()) - udpHeader.SetChecksum(0) - udpHeader.SetSourcePort(st.Port) - netHeader := pkt.Network() - netHeader.SetSourceAddress(st.Addr) - // Only calculate the checksum if offloading isn't supported. - if r.RequiresTXTransportChecksum() { - length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View())) - xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) - xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum()) - udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum)) - } + requiresChecksum := r.RequiresTXTransportChecksum() + rewritePacket( + pkt.Network(), + header.UDP(pkt.TransportHeader().View()), + true, /* updateSRCFields */ + requiresChecksum, + requiresChecksum, + st.Port, + st.Addr, + ) - // After modification, IPv4 packets need a valid checksum. - if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { - netHeader := header.IPv4(pkt.NetworkHeader().View()) - netHeader.SetChecksum(0) - netHeader.SetChecksum(^netHeader.CalculateChecksum()) - } pkt.NatDone = true case header.TCPProtocolNumber: if ct == nil { @@ -260,3 +238,42 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou return RuleAccept, 0 } + +func rewritePacket(n header.Network, t header.ChecksummableTransport, updateSRCFields, fullChecksum, updatePseudoHeader bool, newPort uint16, newAddr tcpip.Address) { + if updateSRCFields { + if fullChecksum { + t.SetSourcePortWithChecksumUpdate(newPort) + } else { + t.SetSourcePort(newPort) + } + } else { + if fullChecksum { + t.SetDestinationPortWithChecksumUpdate(newPort) + } else { + t.SetDestinationPort(newPort) + } + } + + if updatePseudoHeader { + var oldAddr tcpip.Address + if updateSRCFields { + oldAddr = n.SourceAddress() + } else { + oldAddr = n.DestinationAddress() + } + + t.UpdateChecksumPseudoHeaderAddress(oldAddr, newAddr, fullChecksum) + } + + if checksummableNetHeader, ok := n.(header.ChecksummableNetwork); ok { + if updateSRCFields { + checksummableNetHeader.SetSourceAddressWithChecksumUpdate(newAddr) + } else { + checksummableNetHeader.SetDestinationAddressWithChecksumUpdate(newAddr) + } + } else if updateSRCFields { + n.SetSourceAddress(newAddr) + } else { + n.SetDestinationAddress(newAddr) + } +} diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 93592e7f5..66e5f22ac 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -242,7 +242,6 @@ type IPHeaderFilter struct { func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, inNicName, outNicName string) bool { // Extract header fields. var ( - // TODO(gvisor.dev/issue/170): Support other filter fields. transProto tcpip.TransportProtocolNumber dstAddr tcpip.Address srcAddr tcpip.Address @@ -291,7 +290,6 @@ func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, inNicName, outNicNa return true case Postrouting: - // TODO(gvisor.dev/issue/170): Add the check for POSTROUTING. return true default: panic(fmt.Sprintf("unknown hook: %d", hook)) diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index ac2fa777e..9623d9c28 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -16,14 +16,14 @@ package stack_test import ( "bytes" - "context" "encoding/binary" "fmt" + "math/rand" "testing" "time" "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/rand" + cryptorand "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" @@ -52,17 +52,6 @@ const ( linkAddr4 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x09") defaultPrefixLen = 128 - - // Extra time to use when waiting for an async event to occur. - defaultAsyncPositiveEventTimeout = 10 * time.Second - - // Extra time to use when waiting for an async event to not occur. - // - // Since a negative check is used to make sure an event did not happen, it is - // okay to use a smaller timeout compared to the positive case since execution - // stall in regards to the monotonic clock will not affect the expected - // outcome. - defaultAsyncNegativeEventTimeout = time.Second ) var ( @@ -112,11 +101,13 @@ type ndpDADEvent struct { res stack.DADResult } -type ndpRouterEvent struct { - nicID tcpip.NICID - addr tcpip.Address - // true if router was discovered, false if invalidated. - discovered bool +type ndpOffLinkRouteEvent struct { + nicID tcpip.NICID + subnet tcpip.Subnet + router tcpip.Address + prf header.NDPRoutePreference + // true if route was updated, false if invalidated. + updated bool } type ndpPrefixEvent struct { @@ -140,6 +131,10 @@ type ndpAutoGenAddrEvent struct { eventType ndpAutoGenAddrEventType } +func (e ndpAutoGenAddrEvent) String() string { + return fmt.Sprintf("%T{nicID=%d addr=%s eventType=%d}", e, e.nicID, e.addr, e.eventType) +} + type ndpRDNSS struct { addrs []tcpip.Address lifetime time.Duration @@ -167,10 +162,8 @@ var _ ipv6.NDPDispatcher = (*ndpDispatcher)(nil) // related events happen for test purposes. type ndpDispatcher struct { dadC chan ndpDADEvent - routerC chan ndpRouterEvent - rememberRouter bool + offLinkRouteC chan ndpOffLinkRouteEvent prefixC chan ndpPrefixEvent - rememberPrefix bool autoGenAddrC chan ndpAutoGenAddrEvent rdnssC chan ndpRDNSSEvent dnsslC chan ndpDNSSLEvent @@ -189,32 +182,35 @@ func (n *ndpDispatcher) OnDuplicateAddressDetectionResult(nicID tcpip.NICID, add } } -// Implements ipv6.NDPDispatcher.OnDefaultRouterDiscovered. -func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool { - if c := n.routerC; c != nil { - c <- ndpRouterEvent{ +// Implements ipv6.NDPDispatcher.OnOffLinkRouteUpdated. +func (n *ndpDispatcher) OnOffLinkRouteUpdated(nicID tcpip.NICID, subnet tcpip.Subnet, router tcpip.Address, prf header.NDPRoutePreference) { + if c := n.offLinkRouteC; c != nil { + c <- ndpOffLinkRouteEvent{ nicID, - addr, + subnet, + router, + prf, true, } } - - return n.rememberRouter } -// Implements ipv6.NDPDispatcher.OnDefaultRouterInvalidated. -func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) { - if c := n.routerC; c != nil { - c <- ndpRouterEvent{ +// Implements ipv6.NDPDispatcher.OnOffLinkRouteInvalidated. +func (n *ndpDispatcher) OnOffLinkRouteInvalidated(nicID tcpip.NICID, subnet tcpip.Subnet, router tcpip.Address) { + if c := n.offLinkRouteC; c != nil { + var prf header.NDPRoutePreference + c <- ndpOffLinkRouteEvent{ nicID, - addr, + subnet, + router, + prf, false, } } } // Implements ipv6.NDPDispatcher.OnOnLinkPrefixDiscovered. -func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool { +func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) { if c := n.prefixC; c != nil { c <- ndpPrefixEvent{ nicID, @@ -222,8 +218,6 @@ func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip true, } } - - return n.rememberPrefix } // Implements ipv6.NDPDispatcher.OnOnLinkPrefixInvalidated. @@ -237,7 +231,7 @@ func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpi } } -func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) bool { +func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) { if c := n.autoGenAddrC; c != nil { c <- ndpAutoGenAddrEvent{ nicID, @@ -245,7 +239,6 @@ func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWi newAddr, } } - return true } func (n *ndpDispatcher) OnAutoGenAddressDeprecated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) { @@ -497,8 +490,9 @@ func TestDADResolve(t *testing.T) { clock := faketime.NewManualClock() s := stack.New(stack.Options{ - Clock: clock, - SecureRNG: &secureRNG, + Clock: clock, + RandSource: rand.NewSource(time.Now().UnixNano()), + SecureRNG: &secureRNG, NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPDisp: &ndpDisp, DADConfigs: stack.DADConfigurations{ @@ -605,7 +599,10 @@ func TestDADResolve(t *testing.T) { // Validate the sent Neighbor Solicitation messages. for i := uint8(0); i < test.dupAddrDetectTransmits; i++ { - p, _ := e.ReadContext(context.Background()) + p, ok := e.Read() + if !ok { + t.Fatal("packet didn't arrive") + } // Make sure its an IPv6 packet. if p.Proto != header.IPv6ProtocolNumber { @@ -731,11 +728,13 @@ func TestDADFail(t *testing.T) { dadConfigs.RetransmitTimer = time.Second * 2 e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPDisp: &ndpDisp, DADConfigs: dadConfigs, })}, + Clock: clock, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -761,16 +760,17 @@ func TestDADFail(t *testing.T) { // Wait for DAD to fail and make sure the address did // not get resolved. + clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer) select { - case <-time.After(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer + time.Second): - // If we don't get a failure event after the - // expected resolution time + extra 1s buffer, - // something is wrong. - t.Fatal("timed out waiting for DAD failure") case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr1, &stack.DADDupAddrDetected{HolderLinkAddress: test.expectedHolderLinkAddress}); diff != "" { t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } + default: + // If we don't get a failure event after the + // expected resolution time + extra 1s buffer, + // something is wrong. + t.Fatal("timed out waiting for DAD failure") } if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { t.Fatal(err) @@ -839,11 +839,13 @@ func TestDADStop(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPDisp: &ndpDisp, DADConfigs: dadConfigs, })}, + Clock: clock, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) @@ -861,15 +863,16 @@ func TestDADStop(t *testing.T) { test.stopFn(t, s) // Wait for DAD to fail (since the address was removed during DAD). + clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer) select { - case <-time.After(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer + time.Second): - // If we don't get a failure event after the expected resolution - // time + extra 1s buffer, something is wrong. - t.Fatal("timed out waiting for DAD failure") case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr1, &stack.DADAborted{}); diff != "" { t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } + default: + // If we don't get a failure event after the expected resolution + // time + extra 1s buffer, something is wrong. + t.Fatal("timed out waiting for DAD failure") } if !test.skipFinalAddrCheck { @@ -920,10 +923,12 @@ func TestSetNDPConfigurations(t *testing.T) { dadC: make(chan ndpDADEvent, 1), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPDisp: &ndpDisp, })}, + Clock: clock, }) expectDADEvent := func(nicID tcpip.NICID, addr tcpip.Address) { @@ -1002,28 +1007,23 @@ func TestSetNDPConfigurations(t *testing.T) { t.Fatal(err) } - // Sleep until right (500ms before) before resolution to - // make sure the address didn't resolve on NIC(1) yet. - const delta = 500 * time.Millisecond - time.Sleep(time.Duration(test.dupAddrDetectTransmits)*test.expectedRetransmitTimer - delta) + // Sleep until right before resolution to make sure the address didn't + // resolve on NIC(1) yet. + const delta = 1 + clock.Advance(time.Duration(test.dupAddrDetectTransmits)*test.expectedRetransmitTimer - delta) if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { t.Fatal(err) } // Wait for DAD to resolve. + clock.Advance(delta) select { - case <-time.After(2 * delta): - // We should get a resolution event after 500ms - // (delta) since we wait for 500ms less than the - // expected resolution time above to make sure - // that the address did not yet resolve. Waiting - // for 1s (2x delta) without a resolution event - // means something is wrong. - t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID1, addr1, &stack.DADSucceeded{}); diff != "" { t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } + default: + t.Fatal("timed out waiting for DAD resolution") } if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, addrWithPrefix1); err != nil { t.Fatal(err) @@ -1032,10 +1032,13 @@ func TestSetNDPConfigurations(t *testing.T) { } } -// raBufWithOptsAndDHCPv6 returns a valid NDP Router Advertisement with options -// and DHCPv6 configurations specified. -func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) *stack.PacketBuffer { - icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + int(optSer.Length()) +// raBuf returns a valid NDP Router Advertisement with options, router +// preference and DHCPv6 configurations specified. +func raBuf(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, prf header.NDPRoutePreference, optSer header.NDPOptionsSerializer) *stack.PacketBuffer { + const flagsByte = 1 + const routerLifetimeOffset = 2 + + icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + optSer.Length() hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) pkt := header.ICMPv6(hdr.Prepend(icmpSize)) pkt.SetType(header.ICMPv6RouterAdvert) @@ -1043,19 +1046,19 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo raPayload := pkt.MessageBody() ra := header.NDPRouterAdvert(raPayload) // Populate the Router Lifetime. - binary.BigEndian.PutUint16(raPayload[2:], rl) + binary.BigEndian.PutUint16(raPayload[routerLifetimeOffset:], rl) // Populate the Managed Address flag field. if managedAddress { - // The Managed Addresses flag field is the 7th bit of byte #1 (0-indexing) - // of the RA payload. - raPayload[1] |= (1 << 7) + // The Managed Addresses flag field is the 7th bit of the flags byte. + raPayload[flagsByte] |= 1 << 7 } // Populate the Other Configurations flag field. if otherConfigurations { - // The Other Configurations flag field is the 6th bit of byte #1 - // (0-indexing) of the RA payload. - raPayload[1] |= (1 << 6) + // The Other Configurations flag field is the 6th bit of the flags byte. + raPayload[flagsByte] |= 1 << 6 } + // The Prf field is held in the flags byte. + raPayload[flagsByte] |= byte(prf) << 3 opts := ra.Options() opts.Serialize(optSer) pkt.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ @@ -1083,7 +1086,7 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo // Note, raBufWithOpts does not populate any of the RA fields other than the // Router Lifetime. func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) *stack.PacketBuffer { - return raBufWithOptsAndDHCPv6(ip, rl, false, false, optSer) + return raBuf(ip, rl, false /* managedAddress */, false /* otherConfigurations */, 0 /* prf */, optSer) } // raBufWithDHCPv6 returns a valid NDP Router Advertisement with DHCPv6 related @@ -1091,18 +1094,26 @@ func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializ // // Note, raBufWithDHCPv6 does not populate any of the RA fields other than the // DHCPv6 related ones. -func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) *stack.PacketBuffer { - return raBufWithOptsAndDHCPv6(ip, 0, managedAddresses, otherConfiguratiosns, header.NDPOptionsSerializer{}) +func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfigurations bool) *stack.PacketBuffer { + return raBuf(ip, 0, managedAddresses, otherConfigurations, 0 /* prf */, header.NDPOptionsSerializer{}) } // raBuf returns a valid NDP Router Advertisement. // // Note, raBuf does not populate any of the RA fields other than the // Router Lifetime. -func raBuf(ip tcpip.Address, rl uint16) *stack.PacketBuffer { +func raBufSimple(ip tcpip.Address, rl uint16) *stack.PacketBuffer { return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{}) } +// raBufWithPrf returns a valid NDP Router Advertisement with a preference. +// +// Note, raBufWithPrf does not populate any of the RA fields other than the +// Router Lifetime and Default Router Preference fields. +func raBufWithPrf(ip tcpip.Address, rl uint16, prf header.NDPRoutePreference) *stack.PacketBuffer { + return raBuf(ip, rl, false /* managedAddress */, false /* otherConfigurations */, prf, header.NDPOptionsSerializer{}) +} + // raBufWithPI returns a valid NDP Router Advertisement with a single Prefix // Information option. // @@ -1162,7 +1173,7 @@ func TestDynamicConfigurationsDisabled(t *testing.T) { config: func(enable bool) ipv6.NDPConfigurations { return ipv6.NDPConfigurations{DiscoverDefaultRouters: enable} }, - ra: raBuf(llAddr2, 1000), + ra: raBufSimple(llAddr2, 1000), }, { name: "No Prefix Discovery", @@ -1198,9 +1209,9 @@ func TestDynamicConfigurationsDisabled(t *testing.T) { t.Run(fmt.Sprintf("HandleRAs(%s), Forwarding(%t), Enabled(%t)", handle, forwarding, enable), func(t *testing.T) { ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - prefixC: make(chan ndpPrefixEvent, 1), - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1), + prefixC: make(chan ndpPrefixEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } ndpConfigs := test.config(enable) ndpConfigs.HandleRAs = handle @@ -1270,8 +1281,8 @@ func TestDynamicConfigurationsDisabled(t *testing.T) { t.Errorf("got v6Stats.UnhandledRouterAdvertisements.Value() = %d, want = %d", got, want) } select { - case e := <-ndpDisp.routerC: - t.Errorf("unexpectedly discovered a router when configured not to: %#v", e) + case e := <-ndpDisp.offLinkRouteC: + t.Errorf("unexpectedly updated an off-link route when configured not to: %#v", e) default: } select { @@ -1297,10 +1308,8 @@ func boolToUint64(v bool) uint64 { return 0 } -// Check e to make sure that the event is for addr on nic with ID 1, and the -// discovered flag set to discovered. -func checkRouterEvent(e ndpRouterEvent, addr tcpip.Address, discovered bool) string { - return cmp.Diff(ndpRouterEvent{nicID: 1, addr: addr, discovered: discovered}, e, cmp.AllowUnexported(e)) +func checkOffLinkRouteEvent(e ndpOffLinkRouteEvent, nicID tcpip.NICID, router tcpip.Address, prf header.NDPRoutePreference, updated bool) string { + return cmp.Diff(ndpOffLinkRouteEvent{nicID: nicID, subnet: header.IPv6EmptySubnet, router: router, prf: prf, updated: updated}, e, cmp.AllowUnexported(e)) } func testWithRAs(t *testing.T, f func(*testing.T, ipv6.HandleRAsConfiguration, bool)) { @@ -1333,56 +1342,15 @@ func testWithRAs(t *testing.T, f func(*testing.T, ipv6.HandleRAsConfiguration, b } } -// TestRouterDiscoveryDispatcherNoRemember tests that the stack does not -// remember a discovered router when the dispatcher asks it not to. -func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { - ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - DiscoverDefaultRouters: true, - }, - NDPDisp: &ndpDisp, - })}, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Receive an RA for a router we should not remember. - const lifetimeSeconds = 1 - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, lifetimeSeconds)) - select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, llAddr2, true); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected router discovery event") - } - - // Wait for the invalidation time plus some buffer to make sure we do - // not actually receive any invalidation events as we should not have - // remembered the router in the first place. - select { - case <-ndpDisp.routerC: - t.Fatal("should not have received any router events") - case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout): - } -} - func TestRouterDiscovery(t *testing.T) { + const nicID = 1 + testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) { ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - rememberRouter: true, + offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -1391,30 +1359,33 @@ func TestRouterDiscovery(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, }) - expectRouterEvent := func(addr tcpip.Address, discovered bool) { + expectOffLinkRouteEvent := func(addr tcpip.Address, prf header.NDPRoutePreference, updated bool) { t.Helper() select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, addr, discovered); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) + case e := <-ndpDisp.offLinkRouteC: + if diff := checkOffLinkRouteEvent(e, nicID, addr, prf, updated); diff != "" { + t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) } default: t.Fatal("expected router discovery event") } } - expectAsyncRouterInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) { + expectAsyncOffLinkRouteInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) { t.Helper() + clock.Advance(timeout) select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, addr, false); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) + case e := <-ndpDisp.offLinkRouteC: + var prf header.NDPRoutePreference + if diff := checkOffLinkRouteEvent(e, nicID, addr, prf, false); diff != "" { + t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) } - case <-time.After(timeout): + default: t.Fatal("timed out waiting for router discovery event") } } @@ -1423,37 +1394,44 @@ func TestRouterDiscovery(t *testing.T) { t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) } - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } // Rx an RA from lladdr2 with zero lifetime. It should not be // remembered. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr2, 0)) select { - case <-ndpDisp.routerC: - t.Fatal("unexpectedly discovered a router with 0 lifetime") + case <-ndpDisp.offLinkRouteC: + t.Fatal("unexpectedly updated an off-link route with 0 lifetime") default: } - // Rx an RA from lladdr2 with a huge lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - expectRouterEvent(llAddr2, true) + // Rx an RA from lladdr2 with a huge lifetime and reserved preference value + // (which should be interpreted as the default (medium) preference value). + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPrf(llAddr2, 1000, header.ReservedRoutePreference)) + expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, true) - // Rx an RA from another router (lladdr3) with non-zero lifetime. + // Rx an RA from another router (lladdr3) with non-zero lifetime and + // non-default preference value. const l3LifetimeSeconds = 6 - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, l3LifetimeSeconds)) - expectRouterEvent(llAddr3, true) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPrf(llAddr3, l3LifetimeSeconds, header.HighRoutePreference)) + expectOffLinkRouteEvent(llAddr3, header.HighRoutePreference, true) - // Rx an RA from lladdr2 with lesser lifetime. + // Rx an RA from lladdr2 with lesser lifetime and default (medium) + // preference value. const l2LifetimeSeconds = 2 - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, l2LifetimeSeconds)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr2, l2LifetimeSeconds)) select { - case <-ndpDisp.routerC: - t.Fatal("Should not receive a router event when updating lifetimes for known routers") + case <-ndpDisp.offLinkRouteC: + t.Fatal("should not receive a off-link route event when updating lifetimes for known routers") default: } + // Rx an RA from lladdr2 with a different preference. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPrf(llAddr2, l2LifetimeSeconds, header.LowRoutePreference)) + expectOffLinkRouteEvent(llAddr2, header.LowRoutePreference, true) + // Wait for lladdr2's router invalidation job to execute. The lifetime // of the router should have been updated to the most recent (smaller) // lifetime. @@ -1461,15 +1439,15 @@ func TestRouterDiscovery(t *testing.T) { // Wait for the normal lifetime plus an extra bit for the // router to get invalidated. If we don't get an invalidation // event after this time, then something is wrong. - expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout) + expectAsyncOffLinkRouteInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second) // Rx an RA from lladdr2 with huge lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - expectRouterEvent(llAddr2, true) + e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr2, 1000)) + expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, true) // Rx an RA from lladdr2 with zero lifetime. It should be invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) - expectRouterEvent(llAddr2, false) + e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr2, 0)) + expectOffLinkRouteEvent(llAddr2, header.MediumRoutePreference, false) // Wait for lladdr3's router invalidation job to execute. The lifetime // of the router should have been updated to the most recent (smaller) @@ -1478,16 +1456,17 @@ func TestRouterDiscovery(t *testing.T) { // Wait for the normal lifetime plus an extra bit for the // router to get invalidated. If we don't get an invalidation // event after this time, then something is wrong. - expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout) + expectAsyncOffLinkRouteInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second) }) } // TestRouterDiscoveryMaxRouters tests that only -// ipv6.MaxDiscoveredDefaultRouters discovered routers are remembered. +// ipv6.MaxDiscoveredOffLinkRoutes discovered routers are remembered. func TestRouterDiscoveryMaxRouters(t *testing.T) { + const nicID = 1 + ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - rememberRouter: true, + offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1), } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ @@ -1500,23 +1479,23 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { })}, }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } // Receive an RA from 2 more than the max number of discovered routers. - for i := 1; i <= ipv6.MaxDiscoveredDefaultRouters+2; i++ { + for i := 1; i <= ipv6.MaxDiscoveredOffLinkRoutes+2; i++ { linkAddr := []byte{2, 2, 3, 4, 5, 0} linkAddr[5] = byte(i) llAddr := header.LinkLocalAddr(tcpip.LinkAddress(linkAddr)) - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr, 5)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufSimple(llAddr, 5)) - if i <= ipv6.MaxDiscoveredDefaultRouters { + if i <= ipv6.MaxDiscoveredOffLinkRoutes { select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, llAddr, true); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) + case e := <-ndpDisp.offLinkRouteC: + if diff := checkOffLinkRouteEvent(e, nicID, llAddr, header.MediumRoutePreference, true); diff != "" { + t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) } default: t.Fatal("expected router discovery event") @@ -1524,7 +1503,7 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { } else { select { - case <-ndpDisp.routerC: + case <-ndpDisp.offLinkRouteC: t.Fatal("should not have discovered a new router after we already discovered the max number of routers") default: } @@ -1538,51 +1517,6 @@ func checkPrefixEvent(e ndpPrefixEvent, prefix tcpip.Subnet, discovered bool) st return cmp.Diff(ndpPrefixEvent{nicID: 1, prefix: prefix, discovered: discovered}, e, cmp.AllowUnexported(e)) } -// TestPrefixDiscoveryDispatcherNoRemember tests that the stack does not -// remember a discovered on-link prefix when the dispatcher asks it not to. -func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { - prefix, subnet, _ := prefixSubnetAddr(0, "") - - ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, - })}, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Receive an RA with prefix that we should not remember. - const lifetimeSeconds = 1 - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, lifetimeSeconds, 0)) - select { - case e := <-ndpDisp.prefixC: - if diff := checkPrefixEvent(e, subnet, true); diff != "" { - t.Errorf("prefix event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected prefix discovery event") - } - - // Wait for the invalidation time plus some buffer to make sure we do - // not actually receive any invalidation events as we should not have - // remembered the prefix in the first place. - select { - case <-ndpDisp.prefixC: - t.Fatal("should not have received any prefix events") - case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout): - } -} - func TestPrefixDiscovery(t *testing.T) { prefix1, subnet1, _ := prefixSubnetAddr(0, "") prefix2, subnet2, _ := prefixSubnetAddr(1, "") @@ -1590,10 +1524,10 @@ func TestPrefixDiscovery(t *testing.T) { testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) { ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 1), - rememberPrefix: true, + prefixC: make(chan ndpPrefixEvent, 1), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -1602,6 +1536,7 @@ func TestPrefixDiscovery(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, }) if err := s.CreateNIC(1, e); err != nil { @@ -1662,12 +1597,13 @@ func TestPrefixDiscovery(t *testing.T) { // Wait for prefix2's most recent invalidation job plus some buffer to // expire. + clock.Advance(time.Duration(lifetime) * time.Second) select { case e := <-ndpDisp.prefixC: if diff := checkPrefixEvent(e, subnet2, false); diff != "" { t.Errorf("prefix event mismatch (-want +got):\n%s", diff) } - case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for prefix discovery event") } @@ -1678,17 +1614,6 @@ func TestPrefixDiscovery(t *testing.T) { } func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { - // Update the infinite lifetime value to a smaller value so we can test - // that when we receive a PI with such a lifetime value, we do not - // invalidate the prefix. - const testInfiniteLifetimeSeconds = 2 - const testInfiniteLifetime = testInfiniteLifetimeSeconds * time.Second - saved := header.NDPInfiniteLifetime - header.NDPInfiniteLifetime = testInfiniteLifetime - defer func() { - header.NDPInfiniteLifetime = saved - }() - prefix := tcpip.AddressWithPrefix{ Address: testutil.MustParse6("102:304:506:708::"), PrefixLen: 64, @@ -1696,10 +1621,10 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { subnet := prefix.Subnet() ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 1), - rememberPrefix: true, + prefixC: make(chan ndpPrefixEvent, 1), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -1708,6 +1633,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, }) if err := s.CreateNIC(1, e); err != nil { @@ -1729,46 +1655,39 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { // Receive an RA with prefix in an NDP Prefix Information option (PI) // with infinite valid lifetime which should not get invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, infiniteLifetimeSeconds, 0)) expectPrefixEvent(subnet, true) + clock.Advance(header.NDPInfiniteLifetime) select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") - case <-time.After(testInfiniteLifetime + defaultAsyncNegativeEventTimeout): + default: } // Receive an RA with finite lifetime. - // The prefix should get invalidated after 1s. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, infiniteLifetimeSeconds-1, 0)) + clock.Advance(header.NDPInfiniteLifetime - time.Second) select { case e := <-ndpDisp.prefixC: if diff := checkPrefixEvent(e, subnet, false); diff != "" { t.Errorf("prefix event mismatch (-want +got):\n%s", diff) } - case <-time.After(testInfiniteLifetime): + default: t.Fatal("timed out waiting for prefix discovery event") } // Receive an RA with finite lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, infiniteLifetimeSeconds-1, 0)) expectPrefixEvent(subnet, true) // Receive an RA with prefix with an infinite lifetime. // The prefix should not be invalidated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0)) - select { - case <-ndpDisp.prefixC: - t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") - case <-time.After(testInfiniteLifetime + defaultAsyncNegativeEventTimeout): - } - - // Receive an RA with a prefix with a lifetime value greater than the - // set infinite lifetime value. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds+1, 0)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, infiniteLifetimeSeconds, 0)) + clock.Advance(header.NDPInfiniteLifetime) select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") - case <-time.After((testInfiniteLifetimeSeconds+1)*time.Second + defaultAsyncNegativeEventTimeout): + default: } // Receive an RA with 0 lifetime. @@ -1781,8 +1700,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { // ipv6.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered. func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, ipv6.MaxDiscoveredOnLinkPrefixes+3), - rememberPrefix: true, + prefixC: make(chan ndpPrefixEvent, ipv6.MaxDiscoveredOnLinkPrefixes+3), } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ @@ -1859,17 +1777,12 @@ func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix, return cmp.Diff(ndpAutoGenAddrEvent{nicID: 1, addr: addr, eventType: eventType}, e, cmp.AllowUnexported(e)) } +const minVLSeconds = uint32(ipv6.MinPrefixInformationValidLifetimeForUpdate / time.Second) +const infiniteLifetimeSeconds = uint32(header.NDPInfiniteLifetime / time.Second) + // TestAutoGenAddr tests that an address is properly generated and invalidated // when configured to do so. func TestAutoGenAddr(t *testing.T) { - const newMinVL = 2 - newMinVLDuration := newMinVL * time.Second - saved := ipv6.MinPrefixInformationValidLifetimeForUpdate - defer func() { - ipv6.MinPrefixInformationValidLifetimeForUpdate = saved - }() - ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration - prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) @@ -1878,6 +1791,7 @@ func TestAutoGenAddr(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -1886,6 +1800,7 @@ func TestAutoGenAddr(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, }) if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, forwarding); err != nil { @@ -1935,8 +1850,9 @@ func TestAutoGenAddr(t *testing.T) { default: } - // Receive an RA with prefix2 in a PI. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) + // Receive an RA with prefix2 in a PI with a valid lifetime that exceeds + // the minimum. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, minVLSeconds+1, 0)) expectAutoGenAddrEvent(addr2, newAddr) if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { t.Fatalf("Should have %s in the list of addresses", addr1) @@ -1946,7 +1862,7 @@ func TestAutoGenAddr(t *testing.T) { } // Refresh valid lifetime for addr of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix") @@ -1954,12 +1870,13 @@ func TestAutoGenAddr(t *testing.T) { } // Wait for addr of prefix1 to be invalidated. + clock.Advance(ipv6.MinPrefixInformationValidLifetimeForUpdate) select { case e := <-ndpDisp.autoGenAddrC: if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for addr auto gen event") } if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) { @@ -1989,20 +1906,7 @@ func addressCheck(addrs []tcpip.ProtocolAddress, containList, notContainList []t // TestAutoGenTempAddr tests that temporary SLAAC addresses are generated when // configured to do so as part of IPv6 Privacy Extensions. func TestAutoGenTempAddr(t *testing.T) { - const ( - nicID = 1 - newMinVL = 5 - newMinVLDuration = newMinVL * time.Second - ) - - savedMinPrefixInformationValidLifetimeForUpdate := ipv6.MinPrefixInformationValidLifetimeForUpdate - savedMaxDesync := ipv6.MaxDesyncFactor - defer func() { - ipv6.MinPrefixInformationValidLifetimeForUpdate = savedMinPrefixInformationValidLifetimeForUpdate - ipv6.MaxDesyncFactor = savedMaxDesync - }() - ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration - ipv6.MaxDesyncFactor = time.Nanosecond + const nicID = 1 prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) @@ -2022,218 +1926,211 @@ func TestAutoGenTempAddr(t *testing.T) { }, } - // This Run will not return until the parallel tests finish. - // - // We need this because we need to do some teardown work after the - // parallel tests complete. - // - // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for - // more details. - t.Run("group", func(t *testing.T) { - for i, test := range tests { - i := i - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - seed := []byte{uint8(i)} - var tempIIDHistory [header.IIDSize]byte - header.InitialTempIID(tempIIDHistory[:], seed, nicID) - newTempAddr := func(stableAddr tcpip.Address) tcpip.AddressWithPrefix { - return header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableAddr) - } - - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 2), - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - DADConfigs: stack.DADConfigurations{ - DupAddrDetectTransmits: test.dupAddrTransmits, - RetransmitTimer: test.retransmitTimer, - }, - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - TempIIDSeed: seed, - })}, - }) - - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } - } - - expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() + for i, test := range tests { + t.Run(test.name, func(t *testing.T) { + seed := []byte{uint8(i)} + var tempIIDHistory [header.IIDSize]byte + header.InitialTempIID(tempIIDHistory[:], seed, nicID) + newTempAddr := func(stableAddr tcpip.Address) tcpip.AddressWithPrefix { + return header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableAddr) + } - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - } + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 2), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), + } + e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + DADConfigs: stack.DADConfigurations{ + DupAddrDetectTransmits: test.dupAddrTransmits, + RetransmitTimer: test.retransmitTimer, + }, + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + MaxTempAddrValidLifetime: 2 * ipv6.MinPrefixInformationValidLifetimeForUpdate, + MaxTempAddrPreferredLifetime: 2 * ipv6.MinPrefixInformationValidLifetimeForUpdate, + }, + NDPDisp: &ndpDisp, + TempIIDSeed: seed, + })}, + Clock: clock, + }) - expectDADEventAsync := func(addr tcpip.Address) { - t.Helper() + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr, &stack.DADSucceeded{}); diff != "" { - t.Errorf("DAD event mismatch (-want +got):\n%s", diff) - } - case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for DAD event") - } - } + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with zero valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0)) select { case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly auto-generated an address with 0 lifetime; event = %+v", e) + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } default: + t.Fatal("expected addr auto gen event") } + } - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with non-zero valid lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) - expectAutoGenAddrEvent(addr1, newAddr) - expectDADEventAsync(addr1.Address) + expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + clock.RunImmediatelyScheduledJobs() select { case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly got an auto gen addr event = %+v", e) + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } default: + t.Fatal("timed out waiting for addr auto gen event") } - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1}, nil); mismatch != "" { - t.Fatal(mismatch) - } + } - // Receive an RA with prefix1 in an NDP Prefix Information option (PI) - // with non-zero valid & preferred lifetimes. - tempAddr1 := newTempAddr(addr1.Address) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) - expectAutoGenAddrEvent(tempAddr1, newAddr) - expectDADEventAsync(tempAddr1.Address) - if mismatch := addressCheck(s.NICInfo()[1].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" { - t.Fatal(mismatch) - } + expectDADEventAsync := func(addr tcpip.Address) { + t.Helper() - // Receive an RA with prefix2 in an NDP Prefix Information option (PI) - // with preferred lifetime > valid lifetime - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6)) + clock.Advance(time.Duration(test.dupAddrTransmits) * test.retransmitTimer) select { - case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly auto-generated an address with preferred lifetime > valid lifetime; event = %+v", e) + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, addr, &stack.DADSucceeded{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) + } default: + t.Fatal("timed out waiting for DAD event") } - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" { - t.Fatal(mismatch) - } + } - // Receive an RA with prefix2 in a PI w/ non-zero valid and preferred - // lifetimes. - tempAddr2 := newTempAddr(addr2.Address) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) - expectAutoGenAddrEvent(addr2, newAddr) - expectDADEventAsync(addr2.Address) - expectAutoGenAddrEventAsync(tempAddr2, newAddr) - expectDADEventAsync(tempAddr2.Address) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { - t.Fatal(mismatch) - } + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with zero valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0)) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly auto-generated an address with 0 lifetime; event = %+v", e) + default: + } - // Deprecate prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) - expectAutoGenAddrEvent(addr1, deprecatedAddr) - expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { - t.Fatal(mismatch) - } + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero valid lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) + expectAutoGenAddrEvent(addr1, newAddr) + expectDADEventAsync(addr1.Address) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly got an auto gen addr event = %+v", e) + default: + } + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1}, nil); mismatch != "" { + t.Fatal(mismatch) + } - // Refresh lifetimes for prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { - t.Fatal(mismatch) - } + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) + // with non-zero valid & preferred lifetimes. + tempAddr1 := newTempAddr(addr1.Address) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) + expectAutoGenAddrEvent(tempAddr1, newAddr) + expectDADEventAsync(tempAddr1.Address) + if mismatch := addressCheck(s.NICInfo()[1].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" { + t.Fatal(mismatch) + } - // Reduce valid lifetime and deprecate addresses of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) - expectAutoGenAddrEvent(addr1, deprecatedAddr) - expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { - t.Fatal(mismatch) - } + // Receive an RA with prefix2 in an NDP Prefix Information option (PI) + // with preferred lifetime > valid lifetime + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6)) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpectedly auto-generated an address with preferred lifetime > valid lifetime; event = %+v", e) + default: + } + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" { + t.Fatal(mismatch) + } - // Wait for addrs of prefix1 to be invalidated. They should be - // invalidated at the same time. - select { - case e := <-ndpDisp.autoGenAddrC: - var nextAddr tcpip.AddressWithPrefix - if e.addr == addr1 { - if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - nextAddr = tempAddr1 - } else { - if diff := checkAutoGenAddrEvent(e, tempAddr1, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - nextAddr = addr1 - } + // Receive an RA with prefix2 in a PI with a valid lifetime that exceeds + // the minimum and won't be reached in this test. + tempAddr2 := newTempAddr(addr2.Address) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 2*minVLSeconds, 2*minVLSeconds)) + expectAutoGenAddrEvent(addr2, newAddr) + expectDADEventAsync(addr2.Address) + expectAutoGenAddrEventAsync(tempAddr2, newAddr) + expectDADEventAsync(tempAddr2.Address) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { + t.Fatal(mismatch) + } - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, nextAddr, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") + // Deprecate prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) + expectAutoGenAddrEvent(addr1, deprecatedAddr) + expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Refresh lifetimes for prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Reduce valid lifetime and deprecate addresses of prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, minVLSeconds, 0)) + expectAutoGenAddrEvent(addr1, deprecatedAddr) + expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" { + t.Fatal(mismatch) + } + + // Wait for addrs of prefix1 to be invalidated. They should be + // invalidated at the same time. + clock.Advance(ipv6.MinPrefixInformationValidLifetimeForUpdate) + select { + case e := <-ndpDisp.autoGenAddrC: + var nextAddr tcpip.AddressWithPrefix + if e.addr == addr1 { + if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" { - t.Fatal(mismatch) + nextAddr = tempAddr1 + } else { + if diff := checkAutoGenAddrEvent(e, tempAddr1, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + nextAddr = addr1 } - // Receive an RA with prefix2 in a PI w/ 0 lifetimes. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 0, 0)) - expectAutoGenAddrEvent(addr2, deprecatedAddr) - expectAutoGenAddrEvent(tempAddr2, deprecatedAddr) select { case e := <-ndpDisp.autoGenAddrC: - t.Errorf("got unexpected auto gen addr event = %+v", e) + if diff := checkAutoGenAddrEvent(e, nextAddr, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } default: + t.Fatal("timed out waiting for addr auto gen event") } - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" { - t.Fatal(mismatch) - } - }) - } - }) + default: + t.Fatal("timed out waiting for addr auto gen event") + } + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" { + t.Fatal(mismatch) + } + + // Receive an RA with prefix2 in a PI w/ 0 lifetimes. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 0, 0)) + expectAutoGenAddrEvent(addr2, deprecatedAddr) + expectAutoGenAddrEvent(tempAddr2, deprecatedAddr) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Errorf("got unexpected auto gen addr event = %+v", e) + default: + } + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" { + t.Fatal(mismatch) + } + }) + } } // TestNoAutoGenTempAddrForLinkLocal test that temporary SLAAC addresses are not @@ -2241,12 +2138,6 @@ func TestAutoGenTempAddr(t *testing.T) { func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) { const nicID = 1 - savedMaxDesyncFactor := ipv6.MaxDesyncFactor - defer func() { - ipv6.MaxDesyncFactor = savedMaxDesyncFactor - }() - ipv6.MaxDesyncFactor = time.Nanosecond - tests := []struct { name string dupAddrTransmits uint8 @@ -2262,66 +2153,56 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) { }, } - // This Run will not return until the parallel tests finish. - // - // We need this because we need to do some teardown work after the - // parallel tests complete. - // - // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for - // more details. - t.Run("group", func(t *testing.T) { - for _, test := range tests { - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent, 1), - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - AutoGenTempGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - AutoGenLinkLocal: true, - })}, - }) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ndpDisp := ndpDispatcher{ + dadC: make(chan ndpDADEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + AutoGenTempGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + AutoGenLinkLocal: true, + })}, + Clock: clock, + }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } - // The stable link-local address should auto-generate and resolve DAD. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, tcpip.AddressWithPrefix{Address: llAddr1, PrefixLen: header.IIDOffsetInIPv6Address * 8}, newAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") + // The stable link-local address should auto-generate and resolve DAD. + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, tcpip.AddressWithPrefix{Address: llAddr1, PrefixLen: header.IIDOffsetInIPv6Address * 8}, newAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - select { - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, llAddr1, &stack.DADSucceeded{}); diff != "" { - t.Errorf("DAD event mismatch (-want +got):\n%s", diff) - } - case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for DAD event") + default: + t.Fatal("expected addr auto gen event") + } + clock.Advance(time.Duration(test.dupAddrTransmits) * test.retransmitTimer) + select { + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, llAddr1, &stack.DADSucceeded{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } + default: + t.Fatal("timed out waiting for DAD event") + } - // No new addresses should be generated. - select { - case e := <-ndpDisp.autoGenAddrC: - t.Errorf("got unxpected auto gen addr event = %+v", e) - case <-time.After(defaultAsyncNegativeEventTimeout): - } - }) - } - }) + // No new addresses should be generated. + select { + case e := <-ndpDisp.autoGenAddrC: + t.Errorf("got unxpected auto gen addr event = %+v", e) + default: + } + }) + } } // TestNoAutoGenTempAddrWithoutStableAddr tests that a temporary SLAAC address @@ -2334,12 +2215,6 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { retransmitTimer = 2 * time.Second ) - savedMaxDesyncFactor := ipv6.MaxDesyncFactor - defer func() { - ipv6.MaxDesyncFactor = savedMaxDesyncFactor - }() - ipv6.MaxDesyncFactor = 0 - prefix, _, addr := prefixSubnetAddr(0, linkAddr1) var tempIIDHistory [header.IIDSize]byte header.InitialTempIID(tempIIDHistory[:], nil, nicID) @@ -2350,6 +2225,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ DADConfigs: stack.DADConfigurations{ @@ -2363,6 +2239,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -2392,12 +2269,13 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { // Wait for DAD to complete for the stable address then expect the temporary // address to be generated. + clock.Advance(dadTransmits * retransmitTimer) select { case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADSucceeded{}); diff != "" { t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } - case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for DAD event") } select { @@ -2405,7 +2283,7 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { if diff := checkAutoGenAddrEvent(e, tempAddr, newAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for addr auto gen event") } } @@ -2414,46 +2292,44 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { // regenerated. func TestAutoGenTempAddrRegen(t *testing.T) { const ( - nicID = 1 - regenAfter = 2 * time.Second - newMinVL = 10 - newMinVLDuration = newMinVL * time.Second - ) + nicID = 1 + regenAdv = 2 * time.Second - savedMaxDesyncFactor := ipv6.MaxDesyncFactor - savedMinMaxTempAddrPreferredLifetime := ipv6.MinMaxTempAddrPreferredLifetime - savedMinMaxTempAddrValidLifetime := ipv6.MinMaxTempAddrValidLifetime - defer func() { - ipv6.MaxDesyncFactor = savedMaxDesyncFactor - ipv6.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime - ipv6.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime - }() - ipv6.MaxDesyncFactor = 0 - ipv6.MinMaxTempAddrPreferredLifetime = newMinVLDuration - ipv6.MinMaxTempAddrValidLifetime = newMinVLDuration + numTempAddrs = 3 + maxTempAddrValidLifetime = numTempAddrs * ipv6.MinPrefixInformationValidLifetimeForUpdate + ) prefix, _, addr := prefixSubnetAddr(0, linkAddr1) var tempIIDHistory [header.IIDSize]byte header.InitialTempIID(tempIIDHistory[:], nil, nicID) - tempAddr1 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) - tempAddr2 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) - tempAddr3 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + var tempAddrs [numTempAddrs]tcpip.AddressWithPrefix + for i := 0; i < len(tempAddrs); i++ { + tempAddrs[i] = header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + } ndpDisp := ndpDispatcher{ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), } e := channel.New(0, 1280, linkAddr1) ndpConfigs := ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - RegenAdvanceDuration: newMinVLDuration - regenAfter, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + RegenAdvanceDuration: regenAdv, + MaxTempAddrValidLifetime: maxTempAddrValidLifetime, + MaxTempAddrPreferredLifetime: ipv6.MinPrefixInformationValidLifetimeForUpdate, + } + clock := faketime.NewManualClock() + randSource := savingRandSource{ + s: rand.NewSource(time.Now().UnixNano()), } s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ndpConfigs, NDPDisp: &ndpDisp, })}, + Clock: clock, + RandSource: &randSource, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -2476,36 +2352,43 @@ func TestAutoGenTempAddrRegen(t *testing.T) { expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { t.Helper() + clock.Advance(timeout) select { case e := <-ndpDisp.autoGenAddrC: if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(timeout): + default: t.Fatal("timed out waiting for addr auto gen event") } } + tempDesyncFactor := time.Duration(randSource.lastInt63) % ipv6.MaxDesyncFactor + effectiveMaxTempAddrPL := ipv6.MinPrefixInformationValidLifetimeForUpdate - tempDesyncFactor + // The time since the last regeneration before a new temporary address is + // generated. + tempAddrRegenenerationTime := effectiveMaxTempAddrPL - regenAdv + // Receive an RA with prefix1 in an NDP Prefix Information option (PI) // with non-zero valid & preferred lifetimes. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, minVLSeconds)) expectAutoGenAddrEvent(addr, newAddr) - expectAutoGenAddrEvent(tempAddr1, newAddr) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1}, nil); mismatch != "" { + expectAutoGenAddrEvent(tempAddrs[0], newAddr) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddrs[0]}, nil); mismatch != "" { t.Fatal(mismatch) } // Wait for regeneration - expectAutoGenAddrEventAsync(tempAddr2, newAddr, regenAfter+defaultAsyncPositiveEventTimeout) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2}, nil); mismatch != "" { + expectAutoGenAddrEventAsync(tempAddrs[1], newAddr, tempAddrRegenenerationTime) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, minVLSeconds)) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddrs[0], tempAddrs[1]}, nil); mismatch != "" { t.Fatal(mismatch) } + expectAutoGenAddrEventAsync(tempAddrs[0], deprecatedAddr, regenAdv) // Wait for regeneration - expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2, tempAddr3}, nil); mismatch != "" { - t.Fatal(mismatch) - } + expectAutoGenAddrEventAsync(tempAddrs[2], newAddr, tempAddrRegenenerationTime-regenAdv) + expectAutoGenAddrEventAsync(tempAddrs[1], deprecatedAddr, regenAdv) // Stop generating temporary addresses ndpConfigs.AutoGenTempGlobalAddresses = false @@ -2516,45 +2399,24 @@ func TestAutoGenTempAddrRegen(t *testing.T) { ndpEP.SetNDPConfigurations(ndpConfigs) } + // Refresh lifetimes and wait for the last temporary address to be deprecated. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, minVLSeconds)) + expectAutoGenAddrEventAsync(tempAddrs[2], deprecatedAddr, effectiveMaxTempAddrPL-regenAdv) + + // Refresh lifetimes such that the prefix is valid and preferred forever. + // + // This should not affect the lifetimes of temporary addresses because they + // are capped by the maximum valid and preferred lifetimes for temporary + // addresses. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, infiniteLifetimeSeconds, infiniteLifetimeSeconds)) + // Wait for all the temporary addresses to get invalidated. - tempAddrs := []tcpip.AddressWithPrefix{tempAddr1, tempAddr2, tempAddr3} - invalidateAfter := newMinVLDuration - 2*regenAfter + invalidateAfter := maxTempAddrValidLifetime - clock.NowMonotonic().Sub(tcpip.MonotonicTime{}) for _, addr := range tempAddrs { - // Wait for a deprecation then invalidation event, or just an invalidation - // event. We need to cover both cases but cannot deterministically hit both - // cases because the deprecation and invalidation jobs could execute in any - // order. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, deprecatedAddr); diff == "" { - // If we get a deprecation event first, we should get an invalidation - // event almost immediately after. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - } else if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff == "" { - // If we get an invalidation event first, we shouldn't get a deprecation - // event after. - select { - case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpectedly got an auto-generated event = %+v", e) - case <-time.After(defaultAsyncNegativeEventTimeout): - } - } else { - t.Fatalf("got unexpected auto-generated event = %+v", e) - } - case <-time.After(invalidateAfter + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - - invalidateAfter = regenAfter + expectAutoGenAddrEventAsync(addr, invalidatedAddr, invalidateAfter) + invalidateAfter = tempAddrRegenenerationTime } - if mismatch := addressCheck(s.NICInfo()[1].ProtocolAddresses, []tcpip.AddressWithPrefix{addr}, tempAddrs); mismatch != "" { + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr}, tempAddrs[:]); mismatch != "" { t.Fatal(mismatch) } } @@ -2563,52 +2425,54 @@ func TestAutoGenTempAddrRegen(t *testing.T) { // regeneration job gets updated when refreshing the address's lifetimes. func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { const ( - nicID = 1 - regenAfter = 2 * time.Second - newMinVL = 10 - newMinVLDuration = newMinVL * time.Second - ) + nicID = 1 + regenAdv = 2 * time.Second - savedMaxDesyncFactor := ipv6.MaxDesyncFactor - savedMinMaxTempAddrPreferredLifetime := ipv6.MinMaxTempAddrPreferredLifetime - savedMinMaxTempAddrValidLifetime := ipv6.MinMaxTempAddrValidLifetime - defer func() { - ipv6.MaxDesyncFactor = savedMaxDesyncFactor - ipv6.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime - ipv6.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime - }() - ipv6.MaxDesyncFactor = 0 - ipv6.MinMaxTempAddrPreferredLifetime = newMinVLDuration - ipv6.MinMaxTempAddrValidLifetime = newMinVLDuration + numTempAddrs = 3 + maxTempAddrPreferredLifetime = ipv6.MinPrefixInformationValidLifetimeForUpdate + maxTempAddrPreferredLifetimeSeconds = uint32(maxTempAddrPreferredLifetime / time.Second) + ) prefix, _, addr := prefixSubnetAddr(0, linkAddr1) var tempIIDHistory [header.IIDSize]byte header.InitialTempIID(tempIIDHistory[:], nil, nicID) - tempAddr1 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) - tempAddr2 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) - tempAddr3 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + var tempAddrs [numTempAddrs]tcpip.AddressWithPrefix + for i := 0; i < len(tempAddrs); i++ { + tempAddrs[i] = header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address) + } ndpDisp := ndpDispatcher{ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), } e := channel.New(0, 1280, linkAddr1) ndpConfigs := ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - AutoGenTempGlobalAddresses: true, - RegenAdvanceDuration: newMinVLDuration - regenAfter, + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, + AutoGenGlobalAddresses: true, + AutoGenTempGlobalAddresses: true, + RegenAdvanceDuration: regenAdv, + MaxTempAddrPreferredLifetime: maxTempAddrPreferredLifetime, + MaxTempAddrValidLifetime: maxTempAddrPreferredLifetime * 2, + } + clock := faketime.NewManualClock() + initialTime := clock.NowMonotonic() + randSource := savingRandSource{ + s: rand.NewSource(time.Now().UnixNano()), } s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ndpConfigs, NDPDisp: &ndpDisp, })}, + Clock: clock, + RandSource: &randSource, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } + tempDesyncFactor := time.Duration(randSource.lastInt63) % ipv6.MaxDesyncFactor + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { t.Helper() @@ -2625,22 +2489,23 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { t.Helper() + clock.Advance(timeout) select { case e := <-ndpDisp.autoGenAddrC: if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(timeout): + default: t.Fatal("timed out waiting for addr auto gen event") } } // Receive an RA with prefix1 in an NDP Prefix Information option (PI) // with non-zero valid & preferred lifetimes. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, maxTempAddrPreferredLifetimeSeconds, maxTempAddrPreferredLifetimeSeconds)) expectAutoGenAddrEvent(addr, newAddr) - expectAutoGenAddrEvent(tempAddr1, newAddr) - if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1}, nil); mismatch != "" { + expectAutoGenAddrEvent(tempAddrs[0], newAddr) + if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddrs[0]}, nil); mismatch != "" { t.Fatal(mismatch) } @@ -2648,13 +2513,27 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { // // A new temporary address should be generated after the regeneration // time has passed since the prefix is deprecated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 0)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, maxTempAddrPreferredLifetimeSeconds, 0)) expectAutoGenAddrEvent(addr, deprecatedAddr) - expectAutoGenAddrEvent(tempAddr1, deprecatedAddr) + expectAutoGenAddrEvent(tempAddrs[0], deprecatedAddr) select { case e := <-ndpDisp.autoGenAddrC: - t.Fatalf("unexpected auto gen addr event = %+v", e) - case <-time.After(regenAfter + defaultAsyncNegativeEventTimeout): + t.Fatalf("unexpected auto gen addr event = %#v", e) + default: + } + + effectiveMaxTempAddrPL := maxTempAddrPreferredLifetime - tempDesyncFactor + // The time since the last regeneration before a new temporary address is + // generated. + tempAddrRegenenerationTime := effectiveMaxTempAddrPL - regenAdv + + // Advance the clock by the regeneration time but don't expect a new temporary + // address as the prefix is deprecated. + clock.Advance(tempAddrRegenenerationTime) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpected auto gen addr event = %#v", e) + default: } // Prefer the prefix again. @@ -2662,8 +2541,15 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { // A new temporary address should immediately be generated since the // regeneration time has already passed since the last address was generated // - this regeneration does not depend on a job. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) - expectAutoGenAddrEvent(tempAddr2, newAddr) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, maxTempAddrPreferredLifetimeSeconds, maxTempAddrPreferredLifetimeSeconds)) + expectAutoGenAddrEvent(tempAddrs[1], newAddr) + // Wait for the first temporary address to be deprecated. + expectAutoGenAddrEventAsync(tempAddrs[0], deprecatedAddr, regenAdv) + select { + case e := <-ndpDisp.autoGenAddrC: + t.Fatalf("unexpected auto gen addr event = %s", e) + default: + } // Increase the maximum lifetimes for temporary addresses to large values // then refresh the lifetimes of the prefix. @@ -2674,34 +2560,30 @@ func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { // regenerate a new temporary address. Note, new addresses are only // regenerated after the preferred lifetime - the regenerate advance duration // as paased. - ndpConfigs.MaxTempAddrValidLifetime = 100 * time.Second - ndpConfigs.MaxTempAddrPreferredLifetime = 100 * time.Second + const largeLifetimeSeconds = minVLSeconds * 2 + const largeLifetime = time.Duration(largeLifetimeSeconds) * time.Second + ndpConfigs.MaxTempAddrValidLifetime = 2 * largeLifetime + ndpConfigs.MaxTempAddrPreferredLifetime = largeLifetime ipv6Ep, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber) if err != nil { t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) } ndpEP := ipv6Ep.(ipv6.NDPEndpoint) ndpEP.SetNDPConfigurations(ndpConfigs) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) + timeSinceInitialTime := clock.NowMonotonic().Sub(initialTime) + clock.Advance(largeLifetime - timeSinceInitialTime) + expectAutoGenAddrEvent(tempAddrs[0], deprecatedAddr) + // to offset the advement of time to test the first temporary address's + // deprecation after the second was generated + advLess := regenAdv + expectAutoGenAddrEventAsync(tempAddrs[2], newAddr, timeSinceInitialTime-advLess-(tempDesyncFactor+regenAdv)) + expectAutoGenAddrEventAsync(tempAddrs[1], deprecatedAddr, regenAdv) select { case e := <-ndpDisp.autoGenAddrC: t.Fatalf("unexpected auto gen addr event = %+v", e) - case <-time.After(regenAfter + defaultAsyncNegativeEventTimeout): + default: } - - // Set the maximum lifetimes for temporary addresses such that on the next - // RA, the regeneration job gets scheduled again. - // - // The maximum lifetime is the sum of the minimum lifetimes for temporary - // addresses + the time that has already passed since the last address was - // generated so that the regeneration job is needed to generate the next - // address. - newLifetimes := newMinVLDuration + regenAfter + defaultAsyncNegativeEventTimeout - ndpConfigs.MaxTempAddrValidLifetime = newLifetimes - ndpConfigs.MaxTempAddrPreferredLifetime = newLifetimes - ndpEP.SetNDPConfigurations(ndpConfigs) - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) - expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout) } // TestMixedSLAACAddrConflictRegen tests SLAAC address regeneration in response @@ -2929,13 +2811,14 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { // stack.Stack will have a default route through the router (llAddr3) installed // and a static link-address (linkAddr3) added to the link address cache for the // router. -func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*ndpDispatcher, *channel.Endpoint, *stack.Stack) { +func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*ndpDispatcher, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) { t.Helper() ndpDisp := &ndpDispatcher{ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } e := channel.New(0, 1280, linkAddr1) e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -2945,6 +2828,7 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*nd NDPDisp: ndpDisp, })}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + Clock: clock, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -2958,7 +2842,7 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*nd if err := s.AddStaticNeighbor(nicID, ipv6.ProtocolNumber, llAddr3, linkAddr3); err != nil { t.Fatalf("s.AddStaticNeighbor(%d, %d, %s, %s): %s", nicID, ipv6.ProtocolNumber, llAddr3, linkAddr3, err) } - return ndpDisp, e, s + return ndpDisp, e, s, clock } // addrForNewConnectionTo returns the local address used when creating a new @@ -3032,7 +2916,7 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) { prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) + ndpDisp, e, s, _ := stackAndNdpDispatcherWithDefaultRoute(t, nicID) expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { t.Helper() @@ -3135,19 +3019,11 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) { // when its preferred lifetime expires. func TestAutoGenAddrJobDeprecation(t *testing.T) { const nicID = 1 - const newMinVL = 2 - newMinVLDuration := newMinVL * time.Second - - saved := ipv6.MinPrefixInformationValidLifetimeForUpdate - defer func() { - ipv6.MinPrefixInformationValidLifetimeForUpdate = saved - }() - ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) + ndpDisp, e, s, clock := stackAndNdpDispatcherWithDefaultRoute(t, nicID) expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { t.Helper() @@ -3165,12 +3041,13 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { t.Helper() + clock.Advance(timeout) select { case e := <-ndpDisp.autoGenAddrC: if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(timeout): + default: t.Fatal("timed out waiting for addr auto gen event") } } @@ -3188,7 +3065,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { } // Receive PI for prefix2. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, infiniteLifetimeSeconds, infiniteLifetimeSeconds)) expectAutoGenAddrEvent(addr2, newAddr) if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { t.Fatalf("should have %s in the list of addresses", addr2) @@ -3207,7 +3084,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { expectPrimaryAddr(addr1) // Refresh lifetime for addr of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, minVLSeconds, minVLSeconds-1)) select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly got an auto-generated event") @@ -3216,7 +3093,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { expectPrimaryAddr(addr1) // Wait for addr of prefix1 to be deprecated. - expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) + expectAutoGenAddrEventAfter(addr1, deprecatedAddr, ipv6.MinPrefixInformationValidLifetimeForUpdate-time.Second) if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should not have %s in the list of addresses", addr1) } @@ -3226,6 +3103,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { // addr2 should be the primary endpoint now since addr1 is deprecated but // addr2 is not. expectPrimaryAddr(addr2) + // addr1 is deprecated but if explicitly requested, it should be used. fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID} if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { @@ -3234,7 +3112,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { // Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make // sure we do not get a deprecation event again. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, minVLSeconds, 0)) select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly got an auto-generated event") @@ -3246,7 +3124,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { } // Refresh lifetimes for addr of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, minVLSeconds, minVLSeconds-1)) select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly got an auto-generated event") @@ -3256,7 +3134,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { expectPrimaryAddr(addr1) // Wait for addr of prefix1 to be deprecated. - expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) + expectAutoGenAddrEventAfter(addr1, deprecatedAddr, ipv6.MinPrefixInformationValidLifetimeForUpdate-time.Second) if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should not have %s in the list of addresses", addr1) } @@ -3270,7 +3148,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { } // Wait for addr of prefix1 to be invalidated. - expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncPositiveEventTimeout) + expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second) if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { t.Fatalf("should not have %s in the list of addresses", addr1) } @@ -3280,7 +3158,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { expectPrimaryAddr(addr2) // Refresh both lifetimes for addr of prefix2 to the same value. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, newMinVL, newMinVL)) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, minVLSeconds, minVLSeconds)) select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly got an auto-generated event") @@ -3292,6 +3170,17 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { // cases because the deprecation and invalidation handlers could be handled in // either deprecation then invalidation, or invalidation then deprecation // (which should be cancelled by the invalidation handler). + // + // Since we're about to cause both events to fire, we need the dispatcher + // channel to be able to hold both. + if got, want := len(ndpDisp.autoGenAddrC), 0; got != want { + t.Fatalf("got len(ndpDisp.autoGenAddrC) = %d, want %d", got, want) + } + if got, want := cap(ndpDisp.autoGenAddrC), 1; got != want { + t.Fatalf("got cap(ndpDisp.autoGenAddrC) = %d, want %d", got, want) + } + ndpDisp.autoGenAddrC = make(chan ndpAutoGenAddrEvent, 2) + clock.Advance(ipv6.MinPrefixInformationValidLifetimeForUpdate) select { case e := <-ndpDisp.autoGenAddrC: if diff := checkAutoGenAddrEvent(e, addr2, deprecatedAddr); diff == "" { @@ -3302,21 +3191,21 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for addr auto gen event") } } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" { - // If we get an invalidation event first, we should not get a deprecation + // If we get an invalidation event first, we should not get a deprecation // event after. select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly got an auto-generated event") - case <-time.After(defaultAsyncNegativeEventTimeout): + default: } } else { t.Fatalf("got unexpected auto-generated event") } - case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for addr auto gen event") } if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { @@ -3353,15 +3242,6 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { // infinite values. func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { const infiniteVLSeconds = 2 - const minVLSeconds = 1 - savedIL := header.NDPInfiniteLifetime - savedMinVL := ipv6.MinPrefixInformationValidLifetimeForUpdate - defer func() { - ipv6.MinPrefixInformationValidLifetimeForUpdate = savedMinVL - header.NDPInfiniteLifetime = savedIL - }() - ipv6.MinPrefixInformationValidLifetimeForUpdate = minVLSeconds * time.Second - header.NDPInfiniteLifetime = infiniteVLSeconds * time.Second prefix, _, addr := prefixSubnetAddr(0, linkAddr1) @@ -3385,68 +3265,58 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { }, } - // This Run will not return until the parallel tests finish. - // - // We need this because we need to do some teardown work after the - // parallel tests complete. - // - // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for - // more details. - t.Run("group", func(t *testing.T) { - for _, test := range tests { - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - })}, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + } + e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, + Clock: clock, + }) - // Receive an RA with finite prefix. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0)) - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } - default: - t.Fatal("expected addr auto gen event") + // Receive an RA with finite prefix. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0)) + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - // Receive an new RA with prefix with infinite VL. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.infiniteVL, 0)) + default: + t.Fatal("expected addr auto gen event") + } - // Receive a new RA with prefix with finite VL. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0)) + // Receive an new RA with prefix with infinite VL. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.infiniteVL, 0)) - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } + // Receive a new RA with prefix with finite VL. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0)) - case <-time.After(minVLSeconds*time.Second + defaultAsyncPositiveEventTimeout): - t.Fatal("timeout waiting for addr auto gen event") + clock.Advance(ipv6.MinPrefixInformationValidLifetimeForUpdate) + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - }) - } - }) + + default: + t.Fatal("timeout waiting for addr auto gen event") + } + }) + } } // TestAutoGenAddrValidLifetimeUpdates tests that the valid lifetime of an @@ -3454,12 +3324,6 @@ func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) { // RFC 4862 section 5.5.3.e. func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { const infiniteVL = 4294967295 - const newMinVL = 4 - saved := ipv6.MinPrefixInformationValidLifetimeForUpdate - defer func() { - ipv6.MinPrefixInformationValidLifetimeForUpdate = saved - }() - ipv6.MinPrefixInformationValidLifetimeForUpdate = newMinVL * time.Second prefix, _, addr := prefixSubnetAddr(0, linkAddr1) @@ -3470,137 +3334,129 @@ func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) { evl uint32 }{ // Should update the VL to the minimum VL for updating if the - // new VL is less than newMinVL but was originally greater than + // new VL is less than minVLSeconds but was originally greater than // it. { "LargeVLToVLLessThanMinVLForUpdate", 9999, 1, - newMinVL, + minVLSeconds, }, { "LargeVLTo0", 9999, 0, - newMinVL, + minVLSeconds, }, { "InfiniteVLToVLLessThanMinVLForUpdate", infiniteVL, 1, - newMinVL, + minVLSeconds, }, { "InfiniteVLTo0", infiniteVL, 0, - newMinVL, + minVLSeconds, }, - // Should not update VL if original VL was less than newMinVL - // and the new VL is also less than newMinVL. + // Should not update VL if original VL was less than minVLSeconds + // and the new VL is also less than minVLSeconds. { "ShouldNotUpdateWhenBothOldAndNewAreLessThanMinVLForUpdate", - newMinVL - 1, - newMinVL - 3, - newMinVL - 1, + minVLSeconds - 1, + minVLSeconds - 3, + minVLSeconds - 1, }, // Should take the new VL if the new VL is greater than the - // remaining time or is greater than newMinVL. + // remaining time or is greater than minVLSeconds. { "MorethanMinVLToLesserButStillMoreThanMinVLForUpdate", - newMinVL + 5, - newMinVL + 3, - newMinVL + 3, + minVLSeconds + 5, + minVLSeconds + 3, + minVLSeconds + 3, }, { "SmallVLToGreaterVLButStillLessThanMinVLForUpdate", - newMinVL - 3, - newMinVL - 1, - newMinVL - 1, + minVLSeconds - 3, + minVLSeconds - 1, + minVLSeconds - 1, }, { "SmallVLToGreaterVLThatIsMoreThaMinVLForUpdate", - newMinVL - 3, - newMinVL + 1, - newMinVL + 1, + minVLSeconds - 3, + minVLSeconds + 1, + minVLSeconds + 1, }, } - // This Run will not return until the parallel tests finish. - // - // We need this because we need to do some teardown work after the - // parallel tests complete. - // - // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for - // more details. - t.Run("group", func(t *testing.T) { - for _, test := range tests { - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - ndpDisp := ndpDispatcher{ - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), - } - e := channel.New(10, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - AutoGenGlobalAddresses: true, - }, - NDPDisp: &ndpDisp, - })}, - }) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ndpDisp := ndpDispatcher{ + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10), + } + e := channel.New(10, 1280, linkAddr1) + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, + AutoGenGlobalAddresses: true, + }, + NDPDisp: &ndpDisp, + })}, + Clock: clock, + }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(1) = %s", err) + } - // Receive an RA with prefix with initial VL, - // test.ovl. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.ovl, 0)) - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") + // Receive an RA with prefix with initial VL, + // test.ovl. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.ovl, 0)) + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } + default: + t.Fatal("expected addr auto gen event") + } - // Receive an new RA with prefix with new VL, - // test.nvl. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.nvl, 0)) + // Receive an new RA with prefix with new VL, + // test.nvl. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.nvl, 0)) - // - // Validate that the VL for the address got set - // to test.evl. - // + // + // Validate that the VL for the address got set + // to test.evl. + // - // The address should not be invalidated until the effective valid - // lifetime has passed. - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly received an auto gen addr event") - case <-time.After(time.Duration(test.evl)*time.Second - defaultAsyncNegativeEventTimeout): - } + // The address should not be invalidated until the effective valid + // lifetime has passed. + const delta = 1 + clock.Advance(time.Duration(test.evl)*time.Second - delta) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly received an auto gen addr event") + default: + } - // Wait for the invalidation event. - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - case <-time.After(defaultAsyncPositiveEventTimeout): - t.Fatal("timeout waiting for addr auto gen event") + // Wait for the invalidation event. + clock.Advance(delta) + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - }) - } - }) + default: + t.Fatal("timeout waiting for addr auto gen event") + } + }) + } } // TestAutoGenAddrRemoval tests that when auto-generated addresses are removed @@ -3613,6 +3469,7 @@ func TestAutoGenAddrRemoval(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -3621,6 +3478,7 @@ func TestAutoGenAddrRemoval(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, }) if err := s.CreateNIC(1, e); err != nil { @@ -3654,10 +3512,11 @@ func TestAutoGenAddrRemoval(t *testing.T) { // Wait for the original valid lifetime to make sure the original job got // cancelled/cleaned up. + clock.Advance(lifetimeSeconds * time.Second) select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly received an auto gen addr event") - case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout): + default: } } @@ -3668,7 +3527,7 @@ func TestAutoGenAddrAfterRemoval(t *testing.T) { prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) + ndpDisp, e, s, _ := stackAndNdpDispatcherWithDefaultRoute(t, nicID) expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { t.Helper() @@ -3779,6 +3638,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -3787,6 +3647,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, }) if err := s.CreateNIC(1, e); err != nil { @@ -3816,30 +3677,36 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { // Should not get an invalidation event after the PI's invalidation // time. + clock.Advance(lifetimeSeconds * time.Second) select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly received an auto gen addr event") - case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout): + default: } if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) { t.Fatalf("Should have %s in the list of addresses", addr1) } } +func makeSecretKey(t *testing.T) []byte { + secretKey := make([]byte, header.OpaqueIIDSecretKeyMinBytes) + n, err := cryptorand.Read(secretKey) + if err != nil { + t.Fatalf("cryptorand.Read(_): %s", err) + } + if l := len(secretKey); n != l { + t.Fatalf("got cryptorand.Read(_) = (%d, nil), want = (%d, nil)", n, l) + } + return secretKey +} + // TestAutoGenAddrWithOpaqueIID tests that SLAAC generated addresses will use // opaque interface identifiers when configured to do so. func TestAutoGenAddrWithOpaqueIID(t *testing.T) { const nicID = 1 const nicName = "nic1" - var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte - secretKey := secretKeyBuf[:] - n, err := rand.Read(secretKey) - if err != nil { - t.Fatalf("rand.Read(_): %s", err) - } - if n != header.OpaqueIIDSecretKeyMinBytes { - t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes) - } + + secretKey := makeSecretKey(t) prefix1, subnet1, _ := prefixSubnetAddr(0, linkAddr1) prefix2, subnet2, _ := prefixSubnetAddr(1, linkAddr1) @@ -3861,6 +3728,7 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -3875,6 +3743,7 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { SecretKey: secretKey, }, })}, + Clock: clock, }) opts := stack.NICOptions{Name: nicName} if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { @@ -3913,12 +3782,13 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { } // Wait for addr of prefix1 to be invalidated. + clock.Advance(validLifetimeSecondPrefix1 * time.Second) select { case e := <-ndpDisp.autoGenAddrC: if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for addr auto gen event") } if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { @@ -3937,22 +3807,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { const maxMaxRetries = 3 const lifetimeSeconds = 10 - // Needed for the temporary address sub test. - savedMaxDesync := ipv6.MaxDesyncFactor - defer func() { - ipv6.MaxDesyncFactor = savedMaxDesync - }() - ipv6.MaxDesyncFactor = time.Nanosecond - - var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte - secretKey := secretKeyBuf[:] - n, err := rand.Read(secretKey) - if err != nil { - t.Fatalf("rand.Read(_): %s", err) - } - if n != header.OpaqueIIDSecretKeyMinBytes { - t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes) - } + secretKey := makeSecretKey(t) prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1) @@ -3977,22 +3832,24 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { } } - expectAutoGenAddrEventAsync := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + expectAutoGenAddrEventAsync := func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { t.Helper() + clock.RunImmediatelyScheduledJobs() select { case e := <-ndpDisp.autoGenAddrC: if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for addr auto gen event") } } - expectDADEvent := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) { + expectDADEvent := func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) { t.Helper() + clock.RunImmediatelyScheduledJobs() select { case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr, res); diff != "" { @@ -4003,15 +3860,16 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { } } - expectDADEventAsync := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) { + expectDADEventAsync := func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) { t.Helper() + clock.Advance(dadTransmits * retransmitTimer) select { case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr, res); diff != "" { t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } - case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for DAD event") } } @@ -4022,7 +3880,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { name string ndpConfigs ipv6.NDPConfigurations autoGenLinkLocal bool - prepareFn func(t *testing.T, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix + prepareFn func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix addrGenFn func(dadCounter uint8, tempIIDHistory []byte) tcpip.AddressWithPrefix }{ { @@ -4031,7 +3889,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, AutoGenGlobalAddresses: true, }, - prepareFn: func(_ *testing.T, _ *ndpDispatcher, e *channel.Endpoint, _ []byte) []tcpip.AddressWithPrefix { + prepareFn: func(_ *testing.T, _ *faketime.ManualClock, _ *ndpDispatcher, e *channel.Endpoint, _ []byte) []tcpip.AddressWithPrefix { // Receive an RA with prefix1 in a PI. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds)) return nil @@ -4045,7 +3903,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { name: "LinkLocal address", ndpConfigs: ipv6.NDPConfigurations{}, autoGenLinkLocal: true, - prepareFn: func(*testing.T, *ndpDispatcher, *channel.Endpoint, []byte) []tcpip.AddressWithPrefix { + prepareFn: func(*testing.T, *faketime.ManualClock, *ndpDispatcher, *channel.Endpoint, []byte) []tcpip.AddressWithPrefix { return nil }, addrGenFn: func(dadCounter uint8, _ []byte) tcpip.AddressWithPrefix { @@ -4059,14 +3917,14 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { AutoGenGlobalAddresses: true, AutoGenTempGlobalAddresses: true, }, - prepareFn: func(t *testing.T, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix { + prepareFn: func(t *testing.T, clock *faketime.ManualClock, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix { header.InitialTempIID(tempIIDHistory, nil, nicID) // Generate a stable SLAAC address so temporary addresses will be // generated. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) expectAutoGenAddrEvent(t, ndpDisp, stableAddrForTempAddrTest, newAddr) - expectDADEventAsync(t, ndpDisp, stableAddrForTempAddrTest.Address, &stack.DADSucceeded{}) + expectDADEventAsync(t, clock, ndpDisp, stableAddrForTempAddrTest.Address, &stack.DADSucceeded{}) // The stable address will be assigned throughout the test. return []tcpip.AddressWithPrefix{stableAddrForTempAddrTest} @@ -4078,14 +3936,6 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { } for _, addrType := range addrTypes { - // This Run will not return until the parallel tests finish. - // - // We need this because we need to do some teardown work after the parallel - // tests complete and limit the number of parallel tests running at the same - // time to reduce flakes. - // - // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for - // more details. t.Run(addrType.name, func(t *testing.T) { for maxRetries := uint8(0); maxRetries <= maxMaxRetries; maxRetries++ { for numFailures := uint8(0); numFailures <= maxRetries+1; numFailures++ { @@ -4094,8 +3944,6 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { addrType := addrType t.Run(fmt.Sprintf("%d max retries and %d failures", maxRetries, numFailures), func(t *testing.T) { - t.Parallel() - ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent, 1), autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), @@ -4103,6 +3951,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { e := channel.New(0, 1280, linkAddr1) ndpConfigs := addrType.ndpConfigs ndpConfigs.AutoGenAddressConflictRetries = maxRetries + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ AutoGenLinkLocal: addrType.autoGenLinkLocal, @@ -4119,6 +3968,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { SecretKey: secretKey, }, })}, + Clock: clock, }) opts := stack.NICOptions{Name: nicName} if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { @@ -4126,12 +3976,12 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { } var tempIIDHistory [header.IIDSize]byte - stableAddrs := addrType.prepareFn(t, &ndpDisp, e, tempIIDHistory[:]) + stableAddrs := addrType.prepareFn(t, clock, &ndpDisp, e, tempIIDHistory[:]) // Simulate DAD conflicts so the address is regenerated. for i := uint8(0); i < numFailures; i++ { addr := addrType.addrGenFn(i, tempIIDHistory[:]) - expectAutoGenAddrEventAsync(t, &ndpDisp, addr, newAddr) + expectAutoGenAddrEventAsync(t, clock, &ndpDisp, addr, newAddr) // Should not have any new addresses assigned to the NIC. if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, stableAddrs, nil); mismatch != "" { @@ -4141,7 +3991,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { // Simulate a DAD conflict. rxNDPSolicit(e, addr.Address) expectAutoGenAddrEvent(t, &ndpDisp, addr, invalidatedAddr) - expectDADEvent(t, &ndpDisp, addr.Address, &stack.DADDupAddrDetected{}) + expectDADEvent(t, clock, &ndpDisp, addr.Address, &stack.DADDupAddrDetected{}) // Attempting to add the address manually should not fail if the // address's state was cleaned up when DAD failed. @@ -4151,7 +4001,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { if err := s.RemoveAddress(nicID, addr.Address); err != nil { t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err) } - expectDADEvent(t, &ndpDisp, addr.Address, &stack.DADAborted{}) + expectDADEvent(t, clock, &ndpDisp, addr.Address, &stack.DADAborted{}) } // Should not have any new addresses assigned to the NIC. @@ -4163,8 +4013,8 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { // an address after DAD resolves. if maxRetries+1 > numFailures { addr := addrType.addrGenFn(numFailures, tempIIDHistory[:]) - expectAutoGenAddrEventAsync(t, &ndpDisp, addr, newAddr) - expectDADEventAsync(t, &ndpDisp, addr.Address, &stack.DADSucceeded{}) + expectAutoGenAddrEventAsync(t, clock, &ndpDisp, addr, newAddr) + expectDADEventAsync(t, clock, &ndpDisp, addr.Address, &stack.DADSucceeded{}) if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, append(stableAddrs, addr), nil); mismatch != "" { t.Fatal(mismatch) } @@ -4174,7 +4024,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { select { case e := <-ndpDisp.autoGenAddrC: t.Fatalf("unexpectedly got an auto-generated address event = %+v", e) - case <-time.After(defaultAsyncNegativeEventTimeout): + default: } }) } @@ -4231,13 +4081,12 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { addrType := addrType t.Run(addrType.name, func(t *testing.T) { - t.Parallel() - ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent, 1), autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ AutoGenLinkLocal: addrType.autoGenLinkLocal, @@ -4248,6 +4097,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { RetransmitTimer: retransmitTimer, }, })}, + Clock: clock, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -4292,7 +4142,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { select { case e := <-ndpDisp.autoGenAddrC: t.Fatalf("unexpectedly got an auto-generated address event = %+v", e) - case <-time.After(defaultAsyncNegativeEventTimeout): + default: } }) } @@ -4309,15 +4159,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { const maxRetries = 1 const lifetimeSeconds = 5 - var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte - secretKey := secretKeyBuf[:] - n, err := rand.Read(secretKey) - if err != nil { - t.Fatalf("rand.Read(_): %s", err) - } - if n != header.OpaqueIIDSecretKeyMinBytes { - t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes) - } + secretKey := makeSecretKey(t) prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1) @@ -4326,6 +4168,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ DADConfigs: stack.DADConfigurations{ @@ -4345,6 +4188,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { SecretKey: secretKey, }, })}, + Clock: clock, }) opts := stack.NICOptions{Name: nicName} if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { @@ -4375,7 +4219,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { expectAutoGenAddrEvent(addr, newAddr) // Simulate a DAD conflict after some time has passed. - time.Sleep(failureTimer) + clock.Advance(failureTimer) rxNDPSolicit(e, addr.Address) expectAutoGenAddrEvent(addr, invalidatedAddr) select { @@ -4390,12 +4234,13 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { // Let the next address resolve. addr.Address = tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, 1, secretKey)) expectAutoGenAddrEvent(addr, newAddr) + clock.Advance(dadTransmits * retransmitTimer) select { case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADSucceeded{}); diff != "" { t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } - case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for DAD event") } @@ -4409,6 +4254,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { // // We expect either just the invalidation event or the deprecation event // followed by the invalidation event. + clock.Advance(lifetimeSeconds*time.Second - failureTimer - dadTransmits*retransmitTimer) select { case e := <-ndpDisp.autoGenAddrC: if e.eventType == deprecatedAddr { @@ -4421,7 +4267,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for invalidated auto gen addr event after deprecation") } } else { @@ -4429,7 +4275,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } } - case <-time.After(lifetimeSeconds*time.Second - failureTimer - dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for auto gen addr event") } } @@ -4691,11 +4537,9 @@ func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) { ) ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - rememberRouter: true, - prefixC: make(chan ndpPrefixEvent, 1), - rememberPrefix: true, - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1), + prefixC: make(chan ndpPrefixEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ @@ -4738,17 +4582,17 @@ func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) { ), ) select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, llAddr3, true /* discovered */); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) + case e := <-ndpDisp.offLinkRouteC: + if diff := checkOffLinkRouteEvent(e, nicID, llAddr3, header.MediumRoutePreference, true /* discovered */); diff != "" { + t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) } default: - t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID) + t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr3, nicID) } select { case e := <-ndpDisp.prefixC: if diff := checkPrefixEvent(e, subnet, true /* discovered */); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) + t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) } default: t.Errorf("expected prefix event for %s on NIC(%d)", prefix, nicID) @@ -4770,8 +4614,8 @@ func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) { t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) } select { - case e := <-ndpDisp.routerC: - t.Errorf("unexpected router event = %#v", e) + case e := <-ndpDisp.offLinkRouteC: + t.Errorf("unexpected off-link route event = %#v", e) default: } select { @@ -4857,12 +4701,11 @@ func TestCleanupNDPState(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, maxRouterAndPrefixEvents), - rememberRouter: true, - prefixC: make(chan ndpPrefixEvent, maxRouterAndPrefixEvents), - rememberPrefix: true, - autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents), + offLinkRouteC: make(chan ndpOffLinkRouteEvent, maxRouterAndPrefixEvents), + prefixC: make(chan ndpPrefixEvent, maxRouterAndPrefixEvents), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents), } + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ AutoGenLinkLocal: true, @@ -4874,16 +4717,17 @@ func TestCleanupNDPState(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, }) - expectRouterEvent := func() (bool, ndpRouterEvent) { + expectOffLinkRouteEvent := func() (bool, ndpOffLinkRouteEvent) { select { - case e := <-ndpDisp.routerC: + case e := <-ndpDisp.offLinkRouteC: return true, e default: } - return false, ndpRouterEvent{} + return false, ndpOffLinkRouteEvent{} } expectPrefixEvent := func() (bool, ndpPrefixEvent) { @@ -4928,8 +4772,8 @@ func TestCleanupNDPState(t *testing.T) { // multiple addresses. e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID1) + if ok, _ := expectOffLinkRouteEvent(); !ok { + t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr3, nicID1) } if ok, _ := expectPrefixEvent(); !ok { t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID1) @@ -4939,8 +4783,8 @@ func TestCleanupNDPState(t *testing.T) { } e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID1) + if ok, _ := expectOffLinkRouteEvent(); !ok { + t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr4, nicID1) } if ok, _ := expectPrefixEvent(); !ok { t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID1) @@ -4950,8 +4794,8 @@ func TestCleanupNDPState(t *testing.T) { } e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID2) + if ok, _ := expectOffLinkRouteEvent(); !ok { + t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr3, nicID2) } if ok, _ := expectPrefixEvent(); !ok { t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID2) @@ -4961,8 +4805,8 @@ func TestCleanupNDPState(t *testing.T) { } e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID2) + if ok, _ := expectOffLinkRouteEvent(); !ok { + t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr4, nicID2) } if ok, _ := expectPrefixEvent(); !ok { t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID2) @@ -5003,14 +4847,14 @@ func TestCleanupNDPState(t *testing.T) { test.cleanupFn(t, s) // Collect invalidation events after having NDP state cleaned up. - gotRouterEvents := make(map[ndpRouterEvent]int) + gotOffLinkRouteEvents := make(map[ndpOffLinkRouteEvent]int) for i := 0; i < maxRouterAndPrefixEvents; i++ { - ok, e := expectRouterEvent() + ok, e := expectOffLinkRouteEvent() if !ok { - t.Errorf("expected %d router events after becoming a router; got = %d", maxRouterAndPrefixEvents, i) + t.Errorf("expected %d off-link route events after becoming a router; got = %d", maxRouterAndPrefixEvents, i) break } - gotRouterEvents[e]++ + gotOffLinkRouteEvents[e]++ } gotPrefixEvents := make(map[ndpPrefixEvent]int) for i := 0; i < maxRouterAndPrefixEvents; i++ { @@ -5037,14 +4881,14 @@ func TestCleanupNDPState(t *testing.T) { t.FailNow() } - expectedRouterEvents := map[ndpRouterEvent]int{ - {nicID: nicID1, addr: llAddr3, discovered: false}: 1, - {nicID: nicID1, addr: llAddr4, discovered: false}: 1, - {nicID: nicID2, addr: llAddr3, discovered: false}: 1, - {nicID: nicID2, addr: llAddr4, discovered: false}: 1, + expectedOffLinkRouteEvents := map[ndpOffLinkRouteEvent]int{ + {nicID: nicID1, subnet: header.IPv6EmptySubnet, router: llAddr3, updated: false}: 1, + {nicID: nicID1, subnet: header.IPv6EmptySubnet, router: llAddr4, updated: false}: 1, + {nicID: nicID2, subnet: header.IPv6EmptySubnet, router: llAddr3, updated: false}: 1, + {nicID: nicID2, subnet: header.IPv6EmptySubnet, router: llAddr4, updated: false}: 1, } - if diff := cmp.Diff(expectedRouterEvents, gotRouterEvents); diff != "" { - t.Errorf("router events mismatch (-want +got):\n%s", diff) + if diff := cmp.Diff(expectedOffLinkRouteEvents, gotOffLinkRouteEvents); diff != "" { + t.Errorf("off-link route events mismatch (-want +got):\n%s", diff) } expectedPrefixEvents := map[ndpPrefixEvent]int{ {nicID: nicID1, prefix: subnet1, discovered: false}: 1, @@ -5106,10 +4950,10 @@ func TestCleanupNDPState(t *testing.T) { // Should not get any more events (invalidation timers should have been // cancelled when the NDP state was cleaned up). - time.Sleep(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout) + clock.Advance(lifetimeSeconds * time.Second) select { - case <-ndpDisp.routerC: - t.Error("unexpected router event") + case <-ndpDisp.offLinkRouteC: + t.Error("unexpected off-link route event") default: } select { @@ -5134,7 +4978,6 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { ndpDisp := ndpDispatcher{ dhcpv6ConfigurationC: make(chan ndpDHCPv6Event, 1), - rememberRouter: true, } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ @@ -5236,6 +5079,23 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { expectNoDHCPv6Event() } +var _ rand.Source = (*savingRandSource)(nil) + +type savingRandSource struct { + s rand.Source + + lastInt63 int64 +} + +func (d *savingRandSource) Int63() int64 { + i := d.s.Int63() + d.lastInt63 = i + return i +} +func (d *savingRandSource) Seed(seed int64) { + d.s.Seed(seed) +} + // TestRouterSolicitation tests the initial Router Solicitations that are sent // when a NIC newly becomes enabled. func TestRouterSolicitation(t *testing.T) { @@ -5402,6 +5262,9 @@ func TestRouterSolicitation(t *testing.T) { t.Fatalf("unexpectedly got a packet = %#v", p) } } + randSource := savingRandSource{ + s: rand.NewSource(time.Now().UnixNano()), + } s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -5411,8 +5274,10 @@ func TestRouterSolicitation(t *testing.T) { MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, }, })}, - Clock: clock, + Clock: clock, + RandSource: &randSource, }) + if err := s.CreateNIC(nicID, &e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -5425,19 +5290,27 @@ func TestRouterSolicitation(t *testing.T) { // Make sure each RS is sent at the right time. remaining := test.maxRtrSolicit - if remaining > 0 { - waitForPkt(test.effectiveMaxRtrSolicitDelay) + if remaining != 0 { + maxRtrSolicitDelay := test.maxRtrSolicitDelay + if maxRtrSolicitDelay < 0 { + maxRtrSolicitDelay = ipv6.DefaultNDPConfigurations().MaxRtrSolicitationDelay + } + var actualRtrSolicitDelay time.Duration + if maxRtrSolicitDelay != 0 { + actualRtrSolicitDelay = time.Duration(randSource.lastInt63) % maxRtrSolicitDelay + } + waitForPkt(actualRtrSolicitDelay) remaining-- } subTest.afterFirstRS(t, s) - for ; remaining > 0; remaining-- { - if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout { + for ; remaining != 0; remaining-- { + if test.effectiveRtrSolicitInt != 0 { waitForNothing(test.effectiveRtrSolicitInt - time.Nanosecond) waitForPkt(time.Nanosecond) } else { - waitForPkt(test.effectiveRtrSolicitInt) + waitForPkt(0) } } @@ -5533,12 +5406,11 @@ func TestStopStartSolicitingRouters(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { e := channel.New(maxRtrSolicitations, 1280, linkAddr1) - waitForPkt := func(timeout time.Duration) { + waitForPkt := func(clock *faketime.ManualClock, timeout time.Duration) { t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - p, ok := e.ReadContext(ctx) + clock.Advance(timeout) + p, ok := e.Read() if !ok { t.Fatal("timed out waiting for packet") } @@ -5552,6 +5424,7 @@ func TestStopStartSolicitingRouters(t *testing.T) { checker.TTL(header.NDPHopLimit), checker.NDPRS()) } + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -5561,6 +5434,7 @@ func TestStopStartSolicitingRouters(t *testing.T) { MaxRtrSolicitationDelay: delay, }, })}, + Clock: clock, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -5568,13 +5442,11 @@ func TestStopStartSolicitingRouters(t *testing.T) { // Stop soliciting routers. test.stopFn(t, s, true /* first */) - ctx, cancel := context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout) - defer cancel() - if _, ok := e.ReadContext(ctx); ok { + clock.Advance(delay) + if _, ok := e.Read(); ok { // A single RS may have been sent before solicitations were stopped. - ctx, cancel := context.WithTimeout(context.Background(), interval+defaultAsyncNegativeEventTimeout) - defer cancel() - if _, ok = e.ReadContext(ctx); ok { + clock.Advance(interval) + if _, ok = e.Read(); ok { t.Fatal("should not have sent more than one RS message") } } @@ -5582,9 +5454,8 @@ func TestStopStartSolicitingRouters(t *testing.T) { // Stopping router solicitations after it has already been stopped should // do nothing. test.stopFn(t, s, false /* first */) - ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout) - defer cancel() - if _, ok := e.ReadContext(ctx); ok { + clock.Advance(delay) + if _, ok := e.Read(); ok { t.Fatal("unexpectedly got a packet after router solicitation has been stopepd") } @@ -5595,21 +5466,19 @@ func TestStopStartSolicitingRouters(t *testing.T) { // Start soliciting routers. test.startFn(t, s) - waitForPkt(delay + defaultAsyncPositiveEventTimeout) - waitForPkt(interval + defaultAsyncPositiveEventTimeout) - waitForPkt(interval + defaultAsyncPositiveEventTimeout) - ctx, cancel = context.WithTimeout(context.Background(), interval+defaultAsyncNegativeEventTimeout) - defer cancel() - if _, ok := e.ReadContext(ctx); ok { + waitForPkt(clock, delay) + waitForPkt(clock, interval) + waitForPkt(clock, interval) + clock.Advance(interval) + if _, ok := e.Read(); ok { t.Fatal("unexpectedly got an extra packet after sending out the expected RSs") } // Starting router solicitations after it has already completed should do // nothing. test.startFn(t, s) - ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout) - defer cancel() - if _, ok := e.ReadContext(ctx); ok { + clock.Advance(interval) + if _, ok := e.Read(); ok { t.Fatal("unexpectedly got a packet after finishing router solicitations") } }) diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go index 509f5ce5c..08857e1a9 100644 --- a/pkg/tcpip/stack/neighbor_cache.go +++ b/pkg/tcpip/stack/neighbor_cache.go @@ -310,7 +310,7 @@ func (n *neighborCache) handleUpperLevelConfirmation(addr tcpip.Address) { func (n *neighborCache) init(nic *nic, r LinkAddressResolver) { *n = neighborCache{ nic: nic, - state: NewNUDState(nic.stack.nudConfigs, nic.stack.randomGenerator), + state: NewNUDState(nic.stack.nudConfigs, nic.stack.clock, nic.stack.randomGenerator), linkRes: r, } n.mu.Lock() diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index 9821a18d3..7de25fe37 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -15,8 +15,6 @@ package stack import ( - "bytes" - "encoding/binary" "fmt" "math" "math/rand" @@ -48,9 +46,6 @@ const ( // be sent to all nodes. testEntryBroadcastAddr = tcpip.Address("broadcast") - // testEntryLocalAddr is the source address of neighbor probes. - testEntryLocalAddr = tcpip.Address("local_addr") - // testEntryBroadcastLinkAddr is a special link address sent back to // multicast neighbor probes. testEntryBroadcastLinkAddr = tcpip.LinkAddress("mac_broadcast") @@ -95,7 +90,7 @@ func newTestNeighborResolver(nudDisp NUDDispatcher, config NUDConfigurations, cl randomGenerator: rng, }, id: 1, - stats: makeNICStats(), + stats: makeNICStats(tcpip.NICStats{}.FillIn()), }, linkRes) return linkRes } @@ -106,20 +101,24 @@ type testEntryStore struct { entriesMap map[tcpip.Address]NeighborEntry } -func toAddress(i int) tcpip.Address { - buf := new(bytes.Buffer) - binary.Write(buf, binary.BigEndian, uint8(1)) - binary.Write(buf, binary.BigEndian, uint8(0)) - binary.Write(buf, binary.BigEndian, uint16(i)) - return tcpip.Address(buf.String()) +func toAddress(i uint16) tcpip.Address { + return tcpip.Address([]byte{ + 1, + 0, + byte(i >> 8), + byte(i), + }) } -func toLinkAddress(i int) tcpip.LinkAddress { - buf := new(bytes.Buffer) - binary.Write(buf, binary.BigEndian, uint8(1)) - binary.Write(buf, binary.BigEndian, uint8(0)) - binary.Write(buf, binary.BigEndian, uint32(i)) - return tcpip.LinkAddress(buf.String()) +func toLinkAddress(i uint16) tcpip.LinkAddress { + return tcpip.LinkAddress([]byte{ + 1, + 0, + 0, + 0, + byte(i >> 8), + byte(i), + }) } // newTestEntryStore returns a testEntryStore pre-populated with entries. @@ -127,7 +126,7 @@ func newTestEntryStore() *testEntryStore { store := &testEntryStore{ entriesMap: make(map[tcpip.Address]NeighborEntry), } - for i := 0; i < entryStoreSize; i++ { + for i := uint16(0); i < entryStoreSize; i++ { addr := toAddress(i) linkAddr := toLinkAddress(i) @@ -140,15 +139,15 @@ func newTestEntryStore() *testEntryStore { } // size returns the number of entries in the store. -func (s *testEntryStore) size() int { +func (s *testEntryStore) size() uint16 { s.mu.RLock() defer s.mu.RUnlock() - return len(s.entriesMap) + return uint16(len(s.entriesMap)) } // entry returns the entry at index i. Returns an empty entry and false if i is // out of bounds. -func (s *testEntryStore) entry(i int) (NeighborEntry, bool) { +func (s *testEntryStore) entry(i uint16) (NeighborEntry, bool) { return s.entryByAddr(toAddress(i)) } @@ -166,7 +165,7 @@ func (s *testEntryStore) entries() []NeighborEntry { entries := make([]NeighborEntry, 0, len(s.entriesMap)) s.mu.RLock() defer s.mu.RUnlock() - for i := 0; i < entryStoreSize; i++ { + for i := uint16(0); i < entryStoreSize; i++ { addr := toAddress(i) if entry, ok := s.entriesMap[addr]; ok { entries = append(entries, entry) @@ -176,7 +175,7 @@ func (s *testEntryStore) entries() []NeighborEntry { } // set modifies the link addresses of an entry. -func (s *testEntryStore) set(i int, linkAddr tcpip.LinkAddress) { +func (s *testEntryStore) set(i uint16, linkAddr tcpip.LinkAddress) { addr := toAddress(i) s.mu.Lock() defer s.mu.Unlock() @@ -236,13 +235,6 @@ func (*testNeighborResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return 0 } -type entryEvent struct { - nicID tcpip.NICID - address tcpip.Address - linkAddr tcpip.LinkAddress - state NeighborState -} - func TestNeighborCacheGetConfig(t *testing.T) { nudDisp := testNUDDispatcher{} c := DefaultNUDConfigurations() @@ -301,10 +293,10 @@ func addReachableEntryWithRemoved(nudDisp *testNUDDispatcher, clock *faketime.Ma EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: removedEntry.Addr, - LinkAddr: removedEntry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: removedEntry.Addr, + LinkAddr: removedEntry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now(), }, }) } @@ -313,10 +305,10 @@ func addReachableEntryWithRemoved(nudDisp *testNUDDispatcher, clock *faketime.Ma EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: "", - State: Incomplete, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: "", + State: Incomplete, + UpdatedAt: clock.Now(), }, }) @@ -347,10 +339,10 @@ func addReachableEntryWithRemoved(nudDisp *testNUDDispatcher, clock *faketime.Ma EventType: entryTestChanged, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -419,10 +411,10 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -461,7 +453,7 @@ func newTestContext(c NUDConfigurations) testContext { } type overflowOptions struct { - startAtEntryIndex int + startAtEntryIndex uint16 wantStaticEntries []NeighborEntry } @@ -500,12 +492,12 @@ func (c *testContext) overflowCache(opts overflowOptions) error { if !ok { return fmt.Errorf("got c.linkRes.entries.entry(%d) = _, false, want = true", i) } - durationReachableNanos := int64(c.linkRes.entries.size()-i-1) * typicalLatency.Nanoseconds() + durationReachableNanos := time.Duration(c.linkRes.entries.size()-i-1) * typicalLatency wantEntry := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: c.clock.NowNanoseconds() - durationReachableNanos, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: c.clock.Now().Add(-durationReachableNanos), } wantUnorderedEntries = append(wantUnorderedEntries, wantEntry) } @@ -571,10 +563,10 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: c.clock.Now(), }, }, } @@ -616,10 +608,10 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) { EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -663,10 +655,10 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -689,10 +681,10 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) EventType: entryTestChanged, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -733,10 +725,10 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -758,10 +750,10 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -814,20 +806,20 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: c.clock.Now(), }, }, { EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -844,10 +836,10 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { startAtEntryIndex: 1, wantStaticEntries: []NeighborEntry{ { - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -875,10 +867,10 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { t.Errorf("unexpected error from c.linkRes.neigh.entry(%s, \"\", nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), } if diff := cmp.Diff(want, e); diff != "" { t.Errorf("c.linkRes.neigh.entry(%s, \"\", nil) mismatch (-want, +got):\n%s", entry.Addr, diff) @@ -890,10 +882,10 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -910,10 +902,10 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { startAtEntryIndex: 1, wantStaticEntries: []NeighborEntry{ { - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -947,10 +939,10 @@ func TestNeighborCacheClear(t *testing.T) { EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Static, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Static, + UpdatedAt: clock.Now(), }, }, } @@ -973,20 +965,20 @@ func TestNeighborCacheClear(t *testing.T) { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now(), }, }, { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Static, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Static, + UpdatedAt: clock.Now(), }, }, } @@ -1027,10 +1019,10 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: c.clock.Now(), }, }, } @@ -1062,13 +1054,13 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { clock := faketime.NewManualClock() linkRes := newTestNeighborResolver(&nudDisp, config, clock) - startedAt := clock.NowNanoseconds() + startedAt := clock.Now() // The following logic is very similar to overflowCache, but // periodically refreshes the frequently used entry. // Fill the neighbor cache to capacity - for i := 0; i < neighborCacheSize; i++ { + for i := uint16(0); i < neighborCacheSize; i++ { entry, ok := linkRes.entries.entry(i) if !ok { t.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i) @@ -1084,7 +1076,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { } // Keep adding more entries - for i := neighborCacheSize; i < linkRes.entries.size(); i++ { + for i := uint16(neighborCacheSize); i < linkRes.entries.size(); i++ { // Periodically refresh the frequently used entry if i%(neighborCacheSize/2) == 0 { if _, _, err := linkRes.neigh.entry(frequentlyUsedEntry.Addr, "", nil); err != nil { @@ -1118,7 +1110,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { State: Reachable, // Can be inferred since the frequently used entry is the first to // be created and transitioned to Reachable. - UpdatedAtNanos: startedAt + typicalLatency.Nanoseconds(), + UpdatedAt: startedAt.Add(typicalLatency), }, } @@ -1127,12 +1119,12 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { if !ok { t.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i) } - durationReachableNanos := int64(linkRes.entries.size()-i-1) * typicalLatency.Nanoseconds() + durationReachableNanos := time.Duration(linkRes.entries.size()-i-1) * typicalLatency wantUnsortedEntries = append(wantUnsortedEntries, NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds() - durationReachableNanos, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now().Add(-durationReachableNanos), }) } @@ -1190,12 +1182,12 @@ func TestNeighborCacheConcurrent(t *testing.T) { if !ok { t.Errorf("got linkRes.entries.entry(%d) = _, false, want = true", i) } - durationReachableNanos := int64(linkRes.entries.size()-i-1) * typicalLatency.Nanoseconds() + durationReachableNanos := time.Duration(linkRes.entries.size()-i-1) * typicalLatency wantUnsortedEntries = append(wantUnsortedEntries, NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds() - durationReachableNanos, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now().Add(-durationReachableNanos), }) } @@ -1244,10 +1236,10 @@ func TestNeighborCacheReplace(t *testing.T) { t.Fatalf("linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: updatedLinkAddr, - State: Delay, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: updatedLinkAddr, + State: Delay, + UpdatedAt: clock.Now(), } if diff := cmp.Diff(want, e); diff != "" { t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) @@ -1263,10 +1255,10 @@ func TestNeighborCacheReplace(t *testing.T) { t.Errorf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: updatedLinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: updatedLinkAddr, + State: Reachable, + UpdatedAt: clock.Now(), } if diff := cmp.Diff(want, e); diff != "" { t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) @@ -1301,10 +1293,10 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { t.Fatalf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now(), } if diff := cmp.Diff(want, got); diff != "" { t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) @@ -1405,10 +1397,10 @@ func TestNeighborCacheRetryResolution(t *testing.T) { EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: "", - State: Incomplete, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: "", + State: Incomplete, + UpdatedAt: clock.Now(), }, }, } @@ -1436,10 +1428,10 @@ func TestNeighborCacheRetryResolution(t *testing.T) { EventType: entryTestChanged, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: "", - State: Unreachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: "", + State: Unreachable, + UpdatedAt: clock.Now(), }, }, } @@ -1455,10 +1447,10 @@ func TestNeighborCacheRetryResolution(t *testing.T) { { wantEntries := []NeighborEntry{ { - Addr: entry.Addr, - LinkAddr: "", - State: Unreachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: "", + State: Unreachable, + UpdatedAt: clock.Now(), }, } if diff := cmp.Diff(linkRes.neigh.entries(), wantEntries, unorderedEntriesDiffOpts()...); diff != "" { @@ -1488,10 +1480,10 @@ func TestNeighborCacheRetryResolution(t *testing.T) { EventType: entryTestChanged, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: "", - State: Incomplete, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: "", + State: Incomplete, + UpdatedAt: clock.Now(), }, }, } @@ -1518,10 +1510,10 @@ func TestNeighborCacheRetryResolution(t *testing.T) { EventType: entryTestChanged, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -1541,10 +1533,10 @@ func TestNeighborCacheRetryResolution(t *testing.T) { } wantEntry := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now(), } if diff := cmp.Diff(gotEntry, wantEntry); diff != "" { t.Fatalf("neighbor entry mismatch (-got, +want):\n%s", diff) @@ -1561,9 +1553,9 @@ func BenchmarkCacheClear(b *testing.B) { linkRes.delay = 0 // Clear for every possible size of the cache - for cacheSize := 0; cacheSize < neighborCacheSize; cacheSize++ { + for cacheSize := uint16(0); cacheSize < neighborCacheSize; cacheSize++ { // Fill the neighbor cache to capacity. - for i := 0; i < cacheSize; i++ { + for i := uint16(0); i < cacheSize; i++ { entry, ok := linkRes.entries.entry(i) if !ok { b.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i) diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index 6d95e1664..0a59eecdd 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -31,10 +31,10 @@ const ( // NeighborEntry describes a neighboring device in the local network. type NeighborEntry struct { - Addr tcpip.Address - LinkAddr tcpip.LinkAddress - State NeighborState - UpdatedAtNanos int64 + Addr tcpip.Address + LinkAddr tcpip.LinkAddress + State NeighborState + UpdatedAt time.Time } // NeighborState defines the state of a NeighborEntry within the Neighbor @@ -138,10 +138,10 @@ func newNeighborEntry(cache *neighborCache, remoteAddr tcpip.Address, nudState * // calling `setStateLocked`. func newStaticNeighborEntry(cache *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress, state *NUDState) *neighborEntry { entry := NeighborEntry{ - Addr: addr, - LinkAddr: linkAddr, - State: Static, - UpdatedAtNanos: cache.nic.stack.clock.NowNanoseconds(), + Addr: addr, + LinkAddr: linkAddr, + State: Static, + UpdatedAt: cache.nic.stack.clock.Now(), } n := &neighborEntry{ cache: cache, @@ -166,14 +166,20 @@ func (e *neighborEntry) notifyCompletionLocked(err tcpip.Error) { if ch := e.mu.done; ch != nil { close(ch) e.mu.done = nil - // Dequeue the pending packets in a new goroutine to not hold up the current + // Dequeue the pending packets asynchronously to not hold up the current // goroutine as writing packets may be a costly operation. // // At the time of writing, when writing packets, a neighbor's link address // is resolved (which ends up obtaining the entry's lock) while holding the - // link resolution queue's lock. Dequeuing packets in a new goroutine avoids - // a lock ordering violation. - go e.cache.nic.linkResQueue.dequeue(ch, e.mu.neigh.LinkAddr, err) + // link resolution queue's lock. Dequeuing packets asynchronously avoids a + // lock ordering violation. + // + // NB: this is equivalent to spawning a goroutine directly using the go + // keyword but allows tests that use manual clocks to deterministically + // wait for this work to complete. + e.cache.nic.stack.clock.AfterFunc(0, func() { + e.cache.nic.linkResQueue.dequeue(ch, e.mu.neigh.LinkAddr, err) + }) } } @@ -224,7 +230,7 @@ func (e *neighborEntry) cancelTimerLocked() { // // Precondition: e.mu MUST be locked. func (e *neighborEntry) removeLocked() { - e.mu.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds() + e.mu.neigh.UpdatedAt = e.cache.nic.stack.clock.Now() e.dispatchRemoveEventLocked() e.cancelTimerLocked() // TODO(https://gvisor.dev/issues/5583): test the case where this function is @@ -246,7 +252,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { prev := e.mu.neigh.State e.mu.neigh.State = next - e.mu.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds() + e.mu.neigh.UpdatedAt = e.cache.nic.stack.clock.Now() config := e.nudState.Config() switch next { @@ -307,7 +313,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { // a shared lock. e.mu.timer = timer{ done: &done, - timer: e.cache.nic.stack.Clock().AfterFunc(0, func() { + timer: e.cache.nic.stack.Clock().AfterFunc(immediateDuration, func() { var err tcpip.Error = &tcpip.ErrTimeout{} if remaining != 0 { err = e.cache.linkRes.LinkAddressRequest(addr, "" /* localAddr */, linkAddr) @@ -354,14 +360,14 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { case Unknown, Unreachable: prev := e.mu.neigh.State e.mu.neigh.State = Incomplete - e.mu.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds() + e.mu.neigh.UpdatedAt = e.cache.nic.stack.clock.Now() switch prev { case Unknown: e.dispatchAddEventLocked() case Unreachable: e.dispatchChangeEventLocked() - e.cache.nic.stats.Neighbor.UnreachableEntryLookups.Increment() + e.cache.nic.stats.neighbor.unreachableEntryLookups.Increment() } config := e.nudState.Config() @@ -378,7 +384,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { // a shared lock. e.mu.timer = timer{ done: &done, - timer: e.cache.nic.stack.Clock().AfterFunc(0, func() { + timer: e.cache.nic.stack.Clock().AfterFunc(immediateDuration, func() { var err tcpip.Error = &tcpip.ErrTimeout{} if remaining != 0 { // As per RFC 4861 section 7.2.2: diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index 1d39ee73d..59d86d6d4 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -36,11 +36,6 @@ const ( entryTestLinkAddr1 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x01") entryTestLinkAddr2 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x02") - - // entryTestNetDefaultMTU is the MTU, in bytes, used throughout the tests, - // except where another value is explicitly used. It is chosen to match the - // MTU of loopback interfaces on Linux systems. - entryTestNetDefaultMTU = 65536 ) var ( @@ -196,13 +191,13 @@ func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.A // ResolveStaticAddress attempts to resolve address without sending requests. // It either resolves the name immediately or returns the empty LinkAddress. -func (r *entryTestLinkResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { +func (*entryTestLinkResolver) ResolveStaticAddress(tcpip.Address) (tcpip.LinkAddress, bool) { return "", false } // LinkAddressProtocol returns the network protocol of the addresses this // resolver can resolve. -func (r *entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { +func (*entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return entryTestNetNumber } @@ -219,7 +214,7 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e nudConfigs: c, randomGenerator: rand.New(rand.NewSource(time.Now().UnixNano())), }, - stats: makeNICStats(), + stats: makeNICStats(tcpip.NICStats{}.FillIn()), } netEP := (&testIPv6Protocol{}).NewEndpoint(&nic, nil) nic.networkEndpoints = map[tcpip.NetworkProtocolNumber]NetworkEndpoint{ @@ -354,10 +349,10 @@ func unknownToIncomplete(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes * EventType: entryTestAdded, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + UpdatedAt: clock.Now(), }, }, } @@ -415,10 +410,10 @@ func unknownToStale(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entry EventType: entryTestAdded, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -446,7 +441,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { // UpdatedAt should remain the same during address resolution. e.mu.Lock() - startedAt := e.mu.neigh.UpdatedAtNanos + startedAt := e.mu.neigh.UpdatedAt e.mu.Unlock() // Wait for the rest of the reachability probe transmissions, signifying @@ -470,7 +465,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { } e.mu.Lock() - if got, want := e.mu.neigh.UpdatedAtNanos, startedAt; got != want { + if got, want := e.mu.neigh.UpdatedAt, startedAt; got != want { t.Errorf("got e.mu.neigh.UpdatedAt = %q, want = %q", got, want) } e.mu.Unlock() @@ -485,10 +480,10 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Unreachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Unreachable, + UpdatedAt: clock.Now(), }, }, } @@ -547,10 +542,10 @@ func incompleteToReachableWithFlags(e *neighborEntry, nudDisp *testNUDDispatcher EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -644,10 +639,10 @@ func TestEntryIncompleteToStaleWhenUnsolicitedConfirmation(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -678,10 +673,10 @@ func TestEntryIncompleteToStaleWhenProbe(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -757,10 +752,10 @@ func incompleteToUnreachable(c NUDConfigurations, e *neighborEntry, nudDisp *tes EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Unreachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Unreachable, + UpdatedAt: clock.Now(), }, }, } @@ -943,10 +938,10 @@ func reachableToStale(c NUDConfigurations, e *neighborEntry, nudDisp *testNUDDis EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -998,10 +993,10 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -1050,10 +1045,10 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -1102,10 +1097,10 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -1191,10 +1186,10 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -1243,10 +1238,10 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -1284,10 +1279,10 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -1332,10 +1327,10 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -1391,10 +1386,10 @@ func staleToDelay(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTe EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + UpdatedAt: clock.Now(), }, }, } @@ -1443,10 +1438,10 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -1498,10 +1493,10 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -1553,10 +1548,10 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -1645,10 +1640,10 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -1697,10 +1692,10 @@ func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -1770,10 +1765,10 @@ func delayToProbe(c NUDConfigurations, e *neighborEntry, nudDisp *testNUDDispatc EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + UpdatedAt: clock.Now(), }, }, } @@ -1827,10 +1822,10 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -1882,10 +1877,10 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -2003,10 +1998,10 @@ func probeToReachableWithFlags(e *neighborEntry, nudDisp *testNUDDispatcher, lin EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: linkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: linkAddr, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -2155,10 +2150,10 @@ func probeToUnreachable(c NUDConfigurations, e *neighborEntry, nudDisp *testNUDD EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Unreachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Unreachable, + UpdatedAt: clock.Now(), }, }, } @@ -2227,10 +2222,10 @@ func unreachableToIncomplete(e *neighborEntry, nudDisp *testNUDDispatcher, linkR EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + UpdatedAt: clock.Now(), }, }, } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index dbba2c79f..b854d868c 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -51,7 +51,7 @@ type nic struct { name string context NICContext - stats NICStats + stats sharedStats // The network endpoints themselves may be modified by calling the interface's // methods, but the map reference and entries must be constant. @@ -78,26 +78,13 @@ type nic struct { } } -// NICStats hold statistics for a NIC. -type NICStats struct { - Tx DirectionStats - Rx DirectionStats - - DisabledRx DirectionStats - - Neighbor NeighborStats -} - -func makeNICStats() NICStats { - var s NICStats - tcpip.InitStatCounters(reflect.ValueOf(&s).Elem()) - return s -} - -// DirectionStats includes packet and byte counts. -type DirectionStats struct { - Packets *tcpip.StatCounter - Bytes *tcpip.StatCounter +// makeNICStats initializes the NIC statistics and associates them to the global +// NIC statistics. +func makeNICStats(global tcpip.NICStats) sharedStats { + var stats sharedStats + tcpip.InitStatCounters(reflect.ValueOf(&stats.local).Elem()) + stats.init(&stats.local, &global) + return stats } type packetEndpointList struct { @@ -150,7 +137,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC id: id, name: name, context: ctx, - stats: makeNICStats(), + stats: makeNICStats(stack.Stats().NICs), networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]*linkResolver), duplicateAddressDetectors: make(map[tcpip.NetworkProtocolNumber]DuplicateAddressDetector), @@ -382,8 +369,8 @@ func (n *nic) writePacket(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt return err } - n.stats.Tx.Packets.Increment() - n.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) + n.stats.tx.packets.Increment() + n.stats.tx.bytes.IncrementBy(uint64(numBytes)) return nil } @@ -399,13 +386,13 @@ func (n *nic) writePackets(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pk } writtenPackets, err := n.LinkEndpoint.WritePackets(r, pkts, protocol) - n.stats.Tx.Packets.IncrementBy(uint64(writtenPackets)) + n.stats.tx.packets.IncrementBy(uint64(writtenPackets)) writtenBytes := 0 for i, pb := 0, pkts.Front(); i < writtenPackets && pb != nil; i, pb = i+1, pb.Next() { writtenBytes += pb.Size() } - n.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes)) + n.stats.tx.bytes.IncrementBy(uint64(writtenBytes)) return writtenPackets, err } @@ -718,18 +705,18 @@ func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp if !enabled { n.mu.RUnlock() - n.stats.DisabledRx.Packets.Increment() - n.stats.DisabledRx.Bytes.IncrementBy(uint64(pkt.Data().Size())) + n.stats.disabledRx.packets.Increment() + n.stats.disabledRx.bytes.IncrementBy(uint64(pkt.Data().Size())) return } - n.stats.Rx.Packets.Increment() - n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data().Size())) + n.stats.rx.packets.Increment() + n.stats.rx.bytes.IncrementBy(uint64(pkt.Data().Size())) networkEndpoint, ok := n.networkEndpoints[protocol] if !ok { n.mu.RUnlock() - n.stack.stats.UnknownProtocolRcvdPackets.Increment() + n.stats.unknownL3ProtocolRcvdPackets.Increment() return } @@ -786,41 +773,35 @@ func (n *nic) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition { state, ok := n.stack.transportProtocols[protocol] if !ok { - n.stack.stats.UnknownProtocolRcvdPackets.Increment() + n.stats.unknownL4ProtocolRcvdPackets.Increment() return TransportPacketProtocolUnreachable } transProto := state.proto - // Raw socket packets are delivered based solely on the transport - // protocol number. We do not inspect the payload to ensure it's - // validly formed. - n.stack.demux.deliverRawPacket(protocol, pkt) - // TransportHeader is empty only when pkt is an ICMP packet or was reassembled // from fragments. if pkt.TransportHeader().View().IsEmpty() { - // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader - // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a - // full explanation. + // ICMP packets don't have their TransportHeader fields set yet, parse it + // here. See icmp/protocol.go:protocol.Parse for a full explanation. if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber { // ICMP packets may be longer, but until icmp.Parse is implemented, here // we parse it using the minimum size. if _, ok := pkt.TransportHeader().Consume(transProto.MinimumPacketSize()); !ok { - n.stack.stats.MalformedRcvdPackets.Increment() + n.stats.malformedL4RcvdPackets.Increment() // We consider a malformed transport packet handled because there is // nothing the caller can do. return TransportPacketHandled } } else if !transProto.Parse(pkt) { - n.stack.stats.MalformedRcvdPackets.Increment() + n.stats.malformedL4RcvdPackets.Increment() return TransportPacketHandled } } srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader().View()) if err != nil { - n.stack.stats.MalformedRcvdPackets.Increment() + n.stats.malformedL4RcvdPackets.Increment() return TransportPacketHandled } @@ -852,7 +833,7 @@ func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt // If it doesn't handle it then we should do so. switch res := transProto.HandleUnknownDestinationPacket(id, pkt); res { case UnknownDestinationPacketMalformed: - n.stack.stats.MalformedRcvdPackets.Increment() + n.stats.malformedL4RcvdPackets.Increment() return TransportPacketHandled case UnknownDestinationPacketUnhandled: return TransportPacketDestinationPortUnreachable @@ -891,6 +872,17 @@ func (n *nic) DeliverTransportError(local, remote tcpip.Address, net tcpip.Netwo } } +// DeliverRawPacket implements TransportDispatcher. +func (n *nic) DeliverRawPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) { + // For ICMPv4 only we validate the header length for compatibility with + // raw(7) ICMP_FILTER. The same check is made in Linux here: + // https://github.com/torvalds/linux/blob/70585216/net/ipv4/raw.c#L189. + if protocol == header.ICMPv4ProtocolNumber && pkt.TransportHeader().View().Size()+pkt.Data().Size() < header.ICMPv4MinimumSize { + return + } + n.stack.demux.deliverRawPacket(protocol, pkt) +} + // ID implements NetworkInterface. func (n *nic) ID() tcpip.NICID { return n.id diff --git a/pkg/tcpip/stack/nic_stats.go b/pkg/tcpip/stack/nic_stats.go new file mode 100644 index 000000000..1773d5e8d --- /dev/null +++ b/pkg/tcpip/stack/nic_stats.go @@ -0,0 +1,74 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "gvisor.dev/gvisor/pkg/tcpip" +) + +type sharedStats struct { + local tcpip.NICStats + multiCounterNICStats +} + +// LINT.IfChange(multiCounterNICPacketStats) + +type multiCounterNICPacketStats struct { + packets tcpip.MultiCounterStat + bytes tcpip.MultiCounterStat +} + +func (m *multiCounterNICPacketStats) init(a, b *tcpip.NICPacketStats) { + m.packets.Init(a.Packets, b.Packets) + m.bytes.Init(a.Bytes, b.Bytes) +} + +// LINT.ThenChange(../../tcpip.go:NICPacketStats) + +// LINT.IfChange(multiCounterNICNeighborStats) + +type multiCounterNICNeighborStats struct { + unreachableEntryLookups tcpip.MultiCounterStat +} + +func (m *multiCounterNICNeighborStats) init(a, b *tcpip.NICNeighborStats) { + m.unreachableEntryLookups.Init(a.UnreachableEntryLookups, b.UnreachableEntryLookups) +} + +// LINT.ThenChange(../../tcpip.go:NICNeighborStats) + +// LINT.IfChange(multiCounterNICStats) + +type multiCounterNICStats struct { + unknownL3ProtocolRcvdPackets tcpip.MultiCounterStat + unknownL4ProtocolRcvdPackets tcpip.MultiCounterStat + malformedL4RcvdPackets tcpip.MultiCounterStat + tx multiCounterNICPacketStats + rx multiCounterNICPacketStats + disabledRx multiCounterNICPacketStats + neighbor multiCounterNICNeighborStats +} + +func (m *multiCounterNICStats) init(a, b *tcpip.NICStats) { + m.unknownL3ProtocolRcvdPackets.Init(a.UnknownL3ProtocolRcvdPackets, b.UnknownL3ProtocolRcvdPackets) + m.unknownL4ProtocolRcvdPackets.Init(a.UnknownL4ProtocolRcvdPackets, b.UnknownL4ProtocolRcvdPackets) + m.malformedL4RcvdPackets.Init(a.MalformedL4RcvdPackets, b.MalformedL4RcvdPackets) + m.tx.init(&a.Tx, &b.Tx) + m.rx.init(&a.Rx, &b.Rx) + m.disabledRx.init(&a.DisabledRx, &b.DisabledRx) + m.neighbor.init(&a.Neighbor, &b.Neighbor) +} + +// LINT.ThenChange(../../tcpip.go:NICStats) diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index 8a3005295..5cb342f78 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -15,11 +15,13 @@ package stack import ( + "reflect" "testing" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) var _ AddressableEndpoint = (*testIPv6Endpoint)(nil) @@ -171,19 +173,19 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { // When the NIC is disabled, the only field that matters is the stats field. // This test is limited to stats counter checks. nic := nic{ - stats: makeNICStats(), + stats: makeNICStats(tcpip.NICStats{}.FillIn()), } - if got := nic.stats.DisabledRx.Packets.Value(); got != 0 { + if got := nic.stats.local.DisabledRx.Packets.Value(); got != 0 { t.Errorf("got DisabledRx.Packets = %d, want = 0", got) } - if got := nic.stats.DisabledRx.Bytes.Value(); got != 0 { + if got := nic.stats.local.DisabledRx.Bytes.Value(); got != 0 { t.Errorf("got DisabledRx.Bytes = %d, want = 0", got) } - if got := nic.stats.Rx.Packets.Value(); got != 0 { + if got := nic.stats.local.Rx.Packets.Value(); got != 0 { t.Errorf("got Rx.Packets = %d, want = 0", got) } - if got := nic.stats.Rx.Bytes.Value(); got != 0 { + if got := nic.stats.local.Rx.Bytes.Value(); got != 0 { t.Errorf("got Rx.Bytes = %d, want = 0", got) } @@ -195,16 +197,28 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView(), })) - if got := nic.stats.DisabledRx.Packets.Value(); got != 1 { + if got := nic.stats.local.DisabledRx.Packets.Value(); got != 1 { t.Errorf("got DisabledRx.Packets = %d, want = 1", got) } - if got := nic.stats.DisabledRx.Bytes.Value(); got != 4 { + if got := nic.stats.local.DisabledRx.Bytes.Value(); got != 4 { t.Errorf("got DisabledRx.Bytes = %d, want = 4", got) } - if got := nic.stats.Rx.Packets.Value(); got != 0 { + if got := nic.stats.local.Rx.Packets.Value(); got != 0 { t.Errorf("got Rx.Packets = %d, want = 0", got) } - if got := nic.stats.Rx.Bytes.Value(); got != 0 { + if got := nic.stats.local.Rx.Bytes.Value(); got != 0 { t.Errorf("got Rx.Bytes = %d, want = 0", got) } } + +func TestMultiCounterStatsInitialization(t *testing.T) { + global := tcpip.NICStats{}.FillIn() + nic := nic{ + stats: makeNICStats(global), + } + multi := nic.stats.multiCounterNICStats + local := nic.stats.local + if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&multi).Elem(), []reflect.Value{reflect.ValueOf(&local).Elem(), reflect.ValueOf(&global).Elem()}); err != nil { + t.Error(err) + } +} diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go index 5a94e9ac6..ca9822bca 100644 --- a/pkg/tcpip/stack/nud.go +++ b/pkg/tcpip/stack/nud.go @@ -16,6 +16,7 @@ package stack import ( "math" + "math/rand" "sync" "time" @@ -313,45 +314,36 @@ func calcMaxRandomFactor(minRandomFactor float32) float32 { return defaultMaxRandomFactor } -// A Rand is a source of random numbers. -type Rand interface { - // Float32 returns, as a float32, a pseudo-random number in [0.0,1.0). - Float32() float32 -} - // NUDState stores states needed for calculating reachable time. type NUDState struct { - rng Rand + clock tcpip.Clock + rng *rand.Rand - // mu protects the fields below. - // - // It is necessary for NUDState to handle its own locking since neighbor - // entries may access the NUD state from within the goroutine spawned by - // time.AfterFunc(). This goroutine may run concurrently with the main - // process for controlling the neighbor cache and would otherwise introduce - // race conditions if NUDState was not locked properly. - mu sync.RWMutex + mu struct { + sync.RWMutex - config NUDConfigurations + config NUDConfigurations - // reachableTime is the duration to wait for a REACHABLE entry to - // transition into STALE after inactivity. This value is calculated with - // the algorithm defined in RFC 4861 section 6.3.2. - reachableTime time.Duration + // reachableTime is the duration to wait for a REACHABLE entry to + // transition into STALE after inactivity. This value is calculated with + // the algorithm defined in RFC 4861 section 6.3.2. + reachableTime time.Duration - expiration time.Time - prevBaseReachableTime time.Duration - prevMinRandomFactor float32 - prevMaxRandomFactor float32 + expiration time.Time + prevBaseReachableTime time.Duration + prevMinRandomFactor float32 + prevMaxRandomFactor float32 + } } // NewNUDState returns new NUDState using c as configuration and the specified // random number generator for use in recomputing ReachableTime. -func NewNUDState(c NUDConfigurations, rng Rand) *NUDState { +func NewNUDState(c NUDConfigurations, clock tcpip.Clock, rng *rand.Rand) *NUDState { s := &NUDState{ - rng: rng, + clock: clock, + rng: rng, } - s.config = c + s.mu.config = c return s } @@ -359,14 +351,14 @@ func NewNUDState(c NUDConfigurations, rng Rand) *NUDState { func (s *NUDState) Config() NUDConfigurations { s.mu.RLock() defer s.mu.RUnlock() - return s.config + return s.mu.config } // SetConfig replaces the existing NUD configurations with c. func (s *NUDState) SetConfig(c NUDConfigurations) { s.mu.Lock() defer s.mu.Unlock() - s.config = c + s.mu.config = c } // ReachableTime returns the duration to wait for a REACHABLE entry to @@ -377,13 +369,13 @@ func (s *NUDState) ReachableTime() time.Duration { s.mu.Lock() defer s.mu.Unlock() - if time.Now().After(s.expiration) || - s.config.BaseReachableTime != s.prevBaseReachableTime || - s.config.MinRandomFactor != s.prevMinRandomFactor || - s.config.MaxRandomFactor != s.prevMaxRandomFactor { + if s.clock.Now().After(s.mu.expiration) || + s.mu.config.BaseReachableTime != s.mu.prevBaseReachableTime || + s.mu.config.MinRandomFactor != s.mu.prevMinRandomFactor || + s.mu.config.MaxRandomFactor != s.mu.prevMaxRandomFactor { s.recomputeReachableTimeLocked() } - return s.reachableTime + return s.mu.reachableTime } // recomputeReachableTimeLocked forces a recalculation of ReachableTime using @@ -408,23 +400,23 @@ func (s *NUDState) ReachableTime() time.Duration { // // s.mu MUST be locked for writing. func (s *NUDState) recomputeReachableTimeLocked() { - s.prevBaseReachableTime = s.config.BaseReachableTime - s.prevMinRandomFactor = s.config.MinRandomFactor - s.prevMaxRandomFactor = s.config.MaxRandomFactor + s.mu.prevBaseReachableTime = s.mu.config.BaseReachableTime + s.mu.prevMinRandomFactor = s.mu.config.MinRandomFactor + s.mu.prevMaxRandomFactor = s.mu.config.MaxRandomFactor - randomFactor := s.config.MinRandomFactor + s.rng.Float32()*(s.config.MaxRandomFactor-s.config.MinRandomFactor) + randomFactor := s.mu.config.MinRandomFactor + s.rng.Float32()*(s.mu.config.MaxRandomFactor-s.mu.config.MinRandomFactor) // Check for overflow, given that minRandomFactor and maxRandomFactor are // guaranteed to be positive numbers. - if float32(math.MaxInt64)/randomFactor < float32(s.config.BaseReachableTime) { - s.reachableTime = time.Duration(math.MaxInt64) + if math.MaxInt64/randomFactor < float32(s.mu.config.BaseReachableTime) { + s.mu.reachableTime = time.Duration(math.MaxInt64) } else if randomFactor == 1 { // Avoid loss of precision when a large base reachable time is used. - s.reachableTime = s.config.BaseReachableTime + s.mu.reachableTime = s.mu.config.BaseReachableTime } else { - reachableTime := int64(float32(s.config.BaseReachableTime) * randomFactor) - s.reachableTime = time.Duration(reachableTime) + reachableTime := int64(float32(s.mu.config.BaseReachableTime) * randomFactor) + s.mu.reachableTime = time.Duration(reachableTime) } - s.expiration = time.Now().Add(2 * time.Hour) + s.mu.expiration = s.clock.Now().Add(2 * time.Hour) } diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go index e1253f310..1aeb2f8a5 100644 --- a/pkg/tcpip/stack/nud_test.go +++ b/pkg/tcpip/stack/nud_test.go @@ -16,6 +16,7 @@ package stack_test import ( "math" + "math/rand" "testing" "time" @@ -28,17 +29,15 @@ import ( ) const ( - defaultBaseReachableTime = 30 * time.Second - minimumBaseReachableTime = time.Millisecond - defaultMinRandomFactor = 0.5 - defaultMaxRandomFactor = 1.5 - defaultRetransmitTimer = time.Second - minimumRetransmitTimer = time.Millisecond - defaultDelayFirstProbeTime = 5 * time.Second - defaultMaxMulticastProbes = 3 - defaultMaxUnicastProbes = 3 - defaultMaxAnycastDelayTime = time.Second - defaultMaxReachbilityConfirmations = 3 + defaultBaseReachableTime = 30 * time.Second + minimumBaseReachableTime = time.Millisecond + defaultMinRandomFactor = 0.5 + defaultMaxRandomFactor = 1.5 + defaultRetransmitTimer = time.Second + minimumRetransmitTimer = time.Millisecond + defaultDelayFirstProbeTime = 5 * time.Second + defaultMaxMulticastProbes = 3 + defaultMaxUnicastProbes = 3 defaultFakeRandomNum = 0.5 ) @@ -48,12 +47,14 @@ type fakeRand struct { num float32 } -var _ stack.Rand = (*fakeRand)(nil) +var _ rand.Source = (*fakeRand)(nil) -func (f *fakeRand) Float32() float32 { - return f.num +func (f *fakeRand) Int63() int64 { + return int64(f.num * float32(1<<63)) } +func (*fakeRand) Seed(int64) {} + func TestNUDFunctions(t *testing.T) { const nicID = 1 @@ -169,7 +170,7 @@ func TestNUDFunctions(t *testing.T) { t.Errorf("s.Neigbors(%d, %d) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) } else if test.expectedErr == nil { if diff := cmp.Diff( - []stack.NeighborEntry{{Addr: llAddr2, LinkAddr: linkAddr1, State: stack.Static, UpdatedAtNanos: clock.NowNanoseconds()}}, + []stack.NeighborEntry{{Addr: llAddr2, LinkAddr: linkAddr1, State: stack.Static, UpdatedAt: clock.Now()}}, neighbors, ); diff != "" { t.Errorf("neighbors mismatch (-want +got):\n%s", diff) @@ -710,7 +711,8 @@ func TestNUDStateReachableTime(t *testing.T) { rng := fakeRand{ num: defaultFakeRandomNum, } - s := stack.NewNUDState(c, &rng) + var clock faketime.NullClock + s := stack.NewNUDState(c, &clock, rand.New(&rng)) if got, want := s.ReachableTime(), test.want; got != want { t.Errorf("got ReachableTime = %q, want = %q", got, want) } @@ -782,7 +784,8 @@ func TestNUDStateRecomputeReachableTime(t *testing.T) { rng := fakeRand{ num: defaultFakeRandomNum, } - s := stack.NewNUDState(c, &rng) + var clock faketime.NullClock + s := stack.NewNUDState(c, &clock, rand.New(&rng)) old := s.ReachableTime() if got, want := s.ReachableTime(), old; got != want { diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 01652fbe7..9192d8433 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -134,7 +134,7 @@ type PacketBuffer struct { // https://www.man7.org/linux/man-pages/man7/packet.7.html. PktType tcpip.PacketType - // NICID is the ID of the interface the network packet was received at. + // NICID is the ID of the last interface the network packet was handled at. NICID tcpip.NICID // RXTransportChecksumValidated indicates that transport checksum verification @@ -245,10 +245,10 @@ func (pk *PacketBuffer) dataOffset() int { func (pk *PacketBuffer) push(typ headerType, size int) tcpipbuffer.View { h := &pk.headers[typ] if h.length > 0 { - panic(fmt.Sprintf("push must not be called twice: type %s", typ)) + panic(fmt.Sprintf("push(%s, %d) called after previous push", typ, size)) } if pk.pushed+size > pk.reserved { - panic("not enough headroom reserved") + panic(fmt.Sprintf("push(%s, %d) overflows; pushed=%d reserved=%d", typ, size, pk.pushed, pk.reserved)) } pk.pushed += size h.offset = -pk.pushed diff --git a/pkg/tcpip/stack/rand.go b/pkg/tcpip/stack/rand.go index 421fb5c15..c8294eb6e 100644 --- a/pkg/tcpip/stack/rand.go +++ b/pkg/tcpip/stack/rand.go @@ -15,7 +15,7 @@ package stack import ( - mathrand "math/rand" + "math/rand" "gvisor.dev/gvisor/pkg/sync" ) @@ -23,7 +23,7 @@ import ( // lockedRandomSource provides a threadsafe rand.Source. type lockedRandomSource struct { mu sync.Mutex - src mathrand.Source + src rand.Source } func (r *lockedRandomSource) Int63() (n int64) { diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 85bb87b4b..dfe2c886f 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -265,6 +265,11 @@ type TransportDispatcher interface { // // DeliverTransportError takes ownership of the packet buffer. DeliverTransportError(local, remote tcpip.Address, _ tcpip.NetworkProtocolNumber, _ tcpip.TransportProtocolNumber, _ TransportError, _ *PacketBuffer) + + // DeliverRawPacket delivers a packet to any subscribed raw sockets. + // + // DeliverRawPacket does NOT take ownership of the packet buffer. + DeliverRawPacket(tcpip.TransportProtocolNumber, *PacketBuffer) } // PacketLooping specifies where an outbound packet should be sent. @@ -420,7 +425,7 @@ const ( PermanentExpired // Temporary is an endpoint, created on a one-off basis to temporarily - // consider the NIC bound an an address that it is not explictiy bound to + // consider the NIC bound an an address that it is not explicitly bound to // (such as a permanent address). Its reference count must not be biased by 1 // so that the address is removed immediately when references to it are no // longer held. @@ -630,7 +635,7 @@ type NetworkEndpoint interface { // HandlePacket takes ownership of pkt. HandlePacket(pkt *PacketBuffer) - // Close is called when the endpoint is reomved from a stack. + // Close is called when the endpoint is removed from a stack. Close() // NetworkProtocolNumber returns the tcpip.NetworkProtocolNumber for @@ -968,7 +973,7 @@ type DuplicateAddressDetector interface { // called with the result of the original DAD request. CheckDuplicateAddress(tcpip.Address, DADCompletionHandler) DADCheckAddressDisposition - // SetDADConfiguations sets the configurations for DAD. + // SetDADConfigurations sets the configurations for DAD. SetDADConfigurations(c DADConfigurations) // DuplicateAddressProtocol returns the network protocol the receiver can @@ -979,7 +984,7 @@ type DuplicateAddressDetector interface { // LinkAddressResolver handles link address resolution for a network protocol. type LinkAddressResolver interface { // LinkAddressRequest sends a request for the link address of the target - // address. The request is broadcasted on the local network if a remote link + // address. The request is broadcast on the local network if a remote link // address is not provided. LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error @@ -1072,4 +1077,4 @@ type GSOEndpoint interface { // SoftwareGSOMaxSize is a maximum allowed size of a software GSO segment. // This isn't a hard limit, because it is never set into packet headers. -const SoftwareGSOMaxSize = (1 << 16) +const SoftwareGSOMaxSize = 1 << 16 diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 8814f45a6..81fabe29a 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -20,17 +20,16 @@ package stack import ( - "bytes" "encoding/binary" "fmt" "io" - mathrand "math/rand" + "math/rand" "sync/atomic" "time" "golang.org/x/time/rate" "gvisor.dev/gvisor/pkg/atomicbitops" - "gvisor.dev/gvisor/pkg/rand" + cryptorand "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -40,13 +39,6 @@ import ( ) const ( - // ageLimit is set to the same cache stale time used in Linux. - ageLimit = 1 * time.Minute - // resolutionTimeout is set to the same ARP timeout used in Linux. - resolutionTimeout = 1 * time.Second - // resolutionAttempts is set to the same ARP retries used in Linux. - resolutionAttempts = 3 - // DefaultTOS is the default type of service value for network endpoints. DefaultTOS = 0 ) @@ -116,7 +108,7 @@ type Stack struct { handleLocal bool // tables are the iptables packet filtering and manipulation rules. - // TODO(gvisor.dev/issue/170): S/R this field. + // TODO(gvisor.dev/issue/4595): S/R this field. tables *IPTables // resumableEndpoints is a list of endpoints that need to be resumed if the @@ -145,7 +137,7 @@ type Stack struct { // randomGenerator is an injectable pseudo random generator that can be // used when a random number is required. - randomGenerator *mathrand.Rand + randomGenerator *rand.Rand // secureRNG is a cryptographically secure random number generator. secureRNG io.Reader @@ -196,9 +188,9 @@ type Options struct { // TransportProtocols lists the transport protocols to enable. TransportProtocols []TransportProtocolFactory - // Clock is an optional clock source used for timestampping packets. + // Clock is an optional clock used for timekeeping. // - // If no Clock is specified, the clock source will be time.Now. + // If Clock is nil, tcpip.NewStdClock() will be used. Clock tcpip.Clock // Stats are optional statistic counters. @@ -225,15 +217,21 @@ type Options struct { // RandSource is an optional source to use to generate random // numbers. If omitted it defaults to a Source seeded by the data - // returned by rand.Read(). + // returned by the stack secure RNG. // // RandSource must be thread-safe. - RandSource mathrand.Source + RandSource rand.Source - // IPTables are the initial iptables rules. If nil, iptables will allow + // IPTables are the initial iptables rules. If nil, DefaultIPTables will be + // used to construct the initial iptables rules. // all traffic. IPTables *IPTables + // DefaultIPTables is an optional iptables rules constructor that is called + // if IPTables is nil. If both fields are nil, iptables will allow all + // traffic. + DefaultIPTables func(uint32) *IPTables + // SecureRNG is a cryptographically secure random number generator. SecureRNG io.Reader } @@ -331,23 +329,32 @@ func New(opts Options) *Stack { opts.UniqueID = new(uniqueIDGenerator) } + if opts.SecureRNG == nil { + opts.SecureRNG = cryptorand.Reader + } + randSrc := opts.RandSource if randSrc == nil { - // Source provided by mathrand.NewSource is not thread-safe so + var v int64 + if err := binary.Read(opts.SecureRNG, binary.LittleEndian, &v); err != nil { + panic(err) + } + // Source provided by rand.NewSource is not thread-safe so // we wrap it in a simple thread-safe version. - randSrc = &lockedRandomSource{src: mathrand.NewSource(generateRandInt64())} + randSrc = &lockedRandomSource{src: rand.NewSource(v)} } + randomGenerator := rand.New(randSrc) + seed := randomGenerator.Uint32() if opts.IPTables == nil { - opts.IPTables = DefaultTables() + if opts.DefaultIPTables == nil { + opts.DefaultIPTables = DefaultTables + } + opts.IPTables = opts.DefaultIPTables(seed) } opts.NUDConfigs.resetInvalidFields() - if opts.SecureRNG == nil { - opts.SecureRNG = rand.Reader - } - s := &Stack{ transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), @@ -360,11 +367,11 @@ func New(opts Options) *Stack { handleLocal: opts.HandleLocal, tables: opts.IPTables, icmpRateLimiter: NewICMPRateLimiter(), - seed: generateRandUint32(), + seed: seed, nudConfigs: opts.NUDConfigs, uniqueIDGenerator: opts.UniqueID, nudDisp: opts.NUDDisp, - randomGenerator: mathrand.New(randSrc), + randomGenerator: randomGenerator, secureRNG: opts.SecureRNG, sendBufferSize: tcpip.SendBufferSizeOption{ Min: MinBufferSize, @@ -804,7 +811,7 @@ type NICInfo struct { // MTU is the maximum transmission unit. MTU uint32 - Stats NICStats + Stats tcpip.NICStats // NetworkStats holds the stats of each NetworkEndpoint bound to the NIC. NetworkStats map[tcpip.NetworkProtocolNumber]NetworkEndpointStats @@ -856,7 +863,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { ProtocolAddresses: nic.primaryAddresses(), Flags: flags, MTU: nic.LinkEndpoint.MTU(), - Stats: nic.stats, + Stats: nic.stats.local, NetworkStats: netStats, Context: nic.context, ARPHardwareType: nic.LinkEndpoint.ARPHardwareType(), @@ -1819,7 +1826,7 @@ func (s *Stack) Seed() uint32 { // Rand returns a reference to a pseudo random generator that can be used // to generate random numbers as required. -func (s *Stack) Rand() *mathrand.Rand { +func (s *Stack) Rand() *rand.Rand { return s.randomGenerator } @@ -1829,27 +1836,6 @@ func (s *Stack) SecureRNG() io.Reader { return s.secureRNG } -func generateRandUint32() uint32 { - b := make([]byte, 4) - if _, err := rand.Read(b); err != nil { - panic(err) - } - return binary.LittleEndian.Uint32(b) -} - -func generateRandInt64() int64 { - b := make([]byte, 8) - if _, err := rand.Read(b); err != nil { - panic(err) - } - buf := bytes.NewReader(b) - var v int64 - if err := binary.Read(buf, binary.LittleEndian, &v); err != nil { - panic(err) - } - return v -} - // FindNICNameFromID returns the name of the NIC for the given NICID. func (s *Stack) FindNICNameFromID(id tcpip.NICID) string { s.mu.RLock() @@ -1886,9 +1872,8 @@ const ( // ParsePacketBufferTransport parses the provided packet buffer's transport // header. func (s *Stack) ParsePacketBufferTransport(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) ParseResult { - // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader - // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a - // full explanation. + // ICMP packets don't have their TransportHeader fields set yet, parse it + // here. See icmp/protocol.go:protocol.Parse for a full explanation. if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber { return ParsedOK } diff --git a/pkg/tcpip/stack/stack_global_state.go b/pkg/tcpip/stack/stack_global_state.go index 33824afd0..dfec4258a 100644 --- a/pkg/tcpip/stack/stack_global_state.go +++ b/pkg/tcpip/stack/stack_global_state.go @@ -14,78 +14,6 @@ package stack -import "time" - // StackFromEnv is the global stack created in restore run. // FIXME(b/36201077) var StackFromEnv *Stack - -// saveT is invoked by stateify. -func (t *TCPCubicState) saveT() unixTime { - return unixTime{t.T.Unix(), t.T.UnixNano()} -} - -// loadT is invoked by stateify. -func (t *TCPCubicState) loadT(unix unixTime) { - t.T = time.Unix(unix.second, unix.nano) -} - -// saveXmitTime is invoked by stateify. -func (t *TCPRACKState) saveXmitTime() unixTime { - return unixTime{t.XmitTime.Unix(), t.XmitTime.UnixNano()} -} - -// loadXmitTime is invoked by stateify. -func (t *TCPRACKState) loadXmitTime(unix unixTime) { - t.XmitTime = time.Unix(unix.second, unix.nano) -} - -// saveLastSendTime is invoked by stateify. -func (t *TCPSenderState) saveLastSendTime() unixTime { - return unixTime{t.LastSendTime.Unix(), t.LastSendTime.UnixNano()} -} - -// loadLastSendTime is invoked by stateify. -func (t *TCPSenderState) loadLastSendTime(unix unixTime) { - t.LastSendTime = time.Unix(unix.second, unix.nano) -} - -// saveRTTMeasureTime is invoked by stateify. -func (t *TCPSenderState) saveRTTMeasureTime() unixTime { - return unixTime{t.RTTMeasureTime.Unix(), t.RTTMeasureTime.UnixNano()} -} - -// loadRTTMeasureTime is invoked by stateify. -func (t *TCPSenderState) loadRTTMeasureTime(unix unixTime) { - t.RTTMeasureTime = time.Unix(unix.second, unix.nano) -} - -// saveMeasureTime is invoked by stateify. -func (r *RcvBufAutoTuneParams) saveMeasureTime() unixTime { - return unixTime{r.MeasureTime.Unix(), r.MeasureTime.UnixNano()} -} - -// loadMeasureTime is invoked by stateify. -func (r *RcvBufAutoTuneParams) loadMeasureTime(unix unixTime) { - r.MeasureTime = time.Unix(unix.second, unix.nano) -} - -// saveRTTMeasureTime is invoked by stateify. -func (r *RcvBufAutoTuneParams) saveRTTMeasureTime() unixTime { - return unixTime{r.RTTMeasureTime.Unix(), r.RTTMeasureTime.UnixNano()} -} - -// loadRTTMeasureTime is invoked by stateify. -func (r *RcvBufAutoTuneParams) loadRTTMeasureTime(unix unixTime) { - r.RTTMeasureTime = time.Unix(unix.second, unix.nano) -} - -// saveSegTime is invoked by stateify. -func (t *TCPEndpointState) saveSegTime() unixTime { - return unixTime{t.SegTime.Unix(), t.SegTime.UnixNano()} -} - -// loadSegTime is invoked by stateify. -func (t *TCPEndpointState) loadSegTime(unix unixTime) { - t.SegTime = time.Unix(unix.second, unix.nano) -} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 02d54d29b..21951d05a 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -166,10 +166,6 @@ func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { return f.nic.MaxHeaderLength() + fakeNetHeaderLen } -func (*fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { - return 0 -} - func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return f.proto.Number() } @@ -197,11 +193,11 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, params stack.NetworkHe } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (*fakeNetworkEndpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) { +func (*fakeNetworkEndpoint) WritePackets(*stack.Route, stack.PacketBufferList, stack.NetworkHeaderParams) (int, tcpip.Error) { panic("not implemented") } -func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) tcpip.Error { +func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(*stack.Route, *stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} } @@ -463,14 +459,14 @@ func testSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer } } -func testFailingSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr tcpip.Error) { +func testFailingSend(t *testing.T, r *stack.Route, payload buffer.View, wantErr tcpip.Error) { t.Helper() if gotErr := send(r, payload); gotErr != wantErr { t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr) } } -func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View, wantErr tcpip.Error) { +func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, payload buffer.View, wantErr tcpip.Error) { t.Helper() if gotErr := sendTo(s, addr, payload); gotErr != wantErr { t.Errorf("sendto failed: got = %s, want = %s ", gotErr, wantErr) @@ -920,15 +916,15 @@ func TestRouteWithDownNIC(t *testing.T) { if err := test.downFn(s, nicID1); err != nil { t.Fatalf("test.downFn(_, %d): %s", nicID1, err) } - testFailingSend(t, r1, ep1, buf, &tcpip.ErrInvalidEndpointState{}) + testFailingSend(t, r1, buf, &tcpip.ErrInvalidEndpointState{}) testSend(t, r2, ep2, buf) // Writes with Routes that use NIC2 after being brought down should fail. if err := test.downFn(s, nicID2); err != nil { t.Fatalf("test.downFn(_, %d): %s", nicID2, err) } - testFailingSend(t, r1, ep1, buf, &tcpip.ErrInvalidEndpointState{}) - testFailingSend(t, r2, ep2, buf, &tcpip.ErrInvalidEndpointState{}) + testFailingSend(t, r1, buf, &tcpip.ErrInvalidEndpointState{}) + testFailingSend(t, r2, buf, &tcpip.ErrInvalidEndpointState{}) if upFn := test.upFn; upFn != nil { // Writes with Routes that use NIC1 after being brought up should @@ -941,7 +937,7 @@ func TestRouteWithDownNIC(t *testing.T) { t.Fatalf("test.upFn(_, %d): %s", nicID1, err) } testSend(t, r1, ep1, buf) - testFailingSend(t, r2, ep2, buf, &tcpip.ErrInvalidEndpointState{}) + testFailingSend(t, r2, buf, &tcpip.ErrInvalidEndpointState{}) } }) } @@ -1066,7 +1062,7 @@ func TestAddressRemoval(t *testing.T) { t.Fatal("RemoveAddress failed:", err) } testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) // Check that removing the same address fails. err := s.RemoveAddress(1, localAddr) @@ -1118,8 +1114,8 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { t.Fatal("RemoveAddress failed:", err) } testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - testFailingSend(t, r, ep, nil, &tcpip.ErrInvalidEndpointState{}) - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSend(t, r, nil, &tcpip.ErrInvalidEndpointState{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) // Check that removing the same address fails. { @@ -1140,7 +1136,7 @@ func verifyAddress(t *testing.T, s *stack.Stack, nicID tcpip.NICID, addr tcpip.A // No address given, verify that there is no address assigned to the NIC. for _, a := range info.ProtocolAddresses { if a.Protocol == fakeNetNumber && a.AddressWithPrefix != (tcpip.AddressWithPrefix{}) { - t.Errorf("verify no-address: got = %s, want = %s", a.AddressWithPrefix, (tcpip.AddressWithPrefix{})) + t.Errorf("verify no-address: got = %s, want = %s", a.AddressWithPrefix, tcpip.AddressWithPrefix{}) } } return @@ -1220,7 +1216,7 @@ func TestEndpointExpiration(t *testing.T) { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) } // 2. Add Address, everything should work. @@ -1248,7 +1244,7 @@ func TestEndpointExpiration(t *testing.T) { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) } // 4. Add Address back, everything should work again. @@ -1287,8 +1283,8 @@ func TestEndpointExpiration(t *testing.T) { testSend(t, r, ep, nil) testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSend(t, r, ep, nil, &tcpip.ErrInvalidEndpointState{}) - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSend(t, r, nil, &tcpip.ErrInvalidEndpointState{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) } // 7. Add Address back, everything should work again. @@ -1324,7 +1320,7 @@ func TestEndpointExpiration(t *testing.T) { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) } }) } @@ -1574,7 +1570,7 @@ func TestSpoofingNoAddress(t *testing.T) { t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) } // Sending a packet fails. - testFailingSendTo(t, s, dstAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, dstAddr, nil, &tcpip.ErrNoRoute{}) // With address spoofing enabled, FindRoute permits any address to be used // as the source. @@ -1615,7 +1611,7 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { } } - protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Any, 0}} + protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{Address: header.IPv4Any}} if err := s.AddProtocolAddress(1, protoAddr); err != nil { t.Fatalf("AddProtocolAddress(1, %v) failed: %v", protoAddr, err) } @@ -1641,12 +1637,12 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { } func TestOutgoingBroadcastWithRouteTable(t *testing.T) { - defaultAddr := tcpip.AddressWithPrefix{header.IPv4Any, 0} + defaultAddr := tcpip.AddressWithPrefix{Address: header.IPv4Any} // Local subnet on NIC1: 192.168.1.58/24, gateway 192.168.1.1. - nic1Addr := tcpip.AddressWithPrefix{"\xc0\xa8\x01\x3a", 24} + nic1Addr := tcpip.AddressWithPrefix{Address: "\xc0\xa8\x01\x3a", PrefixLen: 24} nic1Gateway := testutil.MustParse4("192.168.1.1") // Local subnet on NIC2: 10.10.10.5/24, gateway 10.10.10.1. - nic2Addr := tcpip.AddressWithPrefix{"\x0a\x0a\x0a\x05", 24} + nic2Addr := tcpip.AddressWithPrefix{Address: "\x0a\x0a\x0a\x05", PrefixLen: 24} nic2Gateway := testutil.MustParse4("10.10.10.1") // Create a new stack with two NICs. @@ -1660,12 +1656,12 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { if err := s.CreateNIC(2, ep); err != nil { t.Fatalf("CreateNIC failed: %s", err) } - nic1ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic1Addr} + nic1ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic1Addr} if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil { t.Fatalf("AddProtocolAddress(1, %v) failed: %v", nic1ProtoAddr, err) } - nic2ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic2Addr} + nic2ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic2Addr} if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil { t.Fatalf("AddAddress(2, %v) failed: %v", nic2ProtoAddr, err) } @@ -1709,7 +1705,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { // 2. Case: Having an explicit route for broadcast will select that one. rt = append( []tcpip.Route{ - {Destination: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize}.Subnet(), NIC: 1}, + {Destination: tcpip.AddressWithPrefix{Address: header.IPv4Broadcast, PrefixLen: 8 * header.IPv4AddressSize}.Subnet(), NIC: 1}, }, rt..., ) @@ -2049,7 +2045,7 @@ func TestAddAddress(t *testing.T) { } expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{address, fakeDefaultPrefixLen}, + AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen}, }) } @@ -2113,7 +2109,7 @@ func TestAddAddressWithOptions(t *testing.T) { } expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{address, fakeDefaultPrefixLen}, + AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen}, }) } } @@ -2234,7 +2230,7 @@ func TestCreateNICWithOptions(t *testing.T) { for _, test := range tests { t.Run(test.desc, func(t *testing.T) { s := stack.New(stack.Options{}) - ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")) + ep := channel.New(0, 0, "\x00\x00\x00\x00\x00\x00") for _, call := range test.calls { if got, want := s.CreateNICWithOptions(call.nicID, ep, call.opts), call.err; got != want { t.Fatalf("CreateNICWithOptions(%v, _, %+v) = %v, want %v", call.nicID, call.opts, got, want) @@ -2248,46 +2244,87 @@ func TestNICStats(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) - ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatal("CreateNIC failed: ", err) - } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) + + nics := []struct { + addr tcpip.Address + txByteCount int + rxByteCount int + }{ + { + addr: "\x01", + txByteCount: 30, + rxByteCount: 10, + }, + { + addr: "\x02", + txByteCount: 50, + rxByteCount: 20, + }, } - // Route all packets for address \x01 to NIC 1. - { - subnet, err := tcpip.NewSubnet("\x01", "\xff") - if err != nil { - t.Fatal(err) + + var txBytesTotal, rxBytesTotal, txPacketsTotal, rxPacketsTotal int + for i, nic := range nics { + nicid := tcpip.NICID(i) + ep := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicid, ep); err != nil { + t.Fatal("CreateNIC failed: ", err) + } + if err := s.AddAddress(nicid, fakeNetNumber, nic.addr); err != nil { + t.Fatal("AddAddress failed:", err) } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - // Send a packet to address 1. - buf := buffer.NewView(30) - ep1.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want { - t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) - } + { + subnet, err := tcpip.NewSubnet(nic.addr, "\xff") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicid}}) + } + + nicStats := s.NICInfo()[nicid].Stats - if got, want := s.NICInfo()[1].Stats.Rx.Bytes.Value(), uint64(len(buf)); got != want { - t.Errorf("got Rx.Bytes.Value() = %d, want = %d", got, want) + // Inbound packet. + rxBuffer := buffer.NewView(nic.rxByteCount) + ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: rxBuffer.ToVectorisedView(), + })) + if got, want := nicStats.Rx.Packets.Value(), uint64(1); got != want { + t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) + } + if got, want := nicStats.Rx.Bytes.Value(), uint64(nic.rxByteCount); got != want { + t.Errorf("got Rx.Bytes.Value() = %d, want = %d", got, want) + } + rxPacketsTotal++ + rxBytesTotal += nic.rxByteCount + + // Outbound packet. + txBuffer := buffer.NewView(nic.txByteCount) + actualTxLength := nic.txByteCount + fakeNetHeaderLen + if err := sendTo(s, nic.addr, txBuffer); err != nil { + t.Fatal("sendTo failed: ", err) + } + want := ep.Drain() + if got := nicStats.Tx.Packets.Value(); got != uint64(want) { + t.Errorf("got Tx.Packets.Value() = %d, ep.Drain() = %d", got, want) + } + if got, want := nicStats.Tx.Bytes.Value(), uint64(actualTxLength); got != want { + t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) + } + txPacketsTotal += want + txBytesTotal += actualTxLength } - payload := buffer.NewView(10) - // Write a packet out via the address for NIC 1 - if err := sendTo(s, "\x01", payload); err != nil { - t.Fatal("sendTo failed: ", err) + // Now verify that each NIC stats was correctly aggregated at the stack level. + if got, want := s.Stats().NICs.Rx.Packets.Value(), uint64(rxPacketsTotal); got != want { + t.Errorf("got s.Stats().NIC.Rx.Packets.Value() = %d, want = %d", got, want) } - want := uint64(ep1.Drain()) - if got := s.NICInfo()[1].Stats.Tx.Packets.Value(); got != want { - t.Errorf("got Tx.Packets.Value() = %d, ep1.Drain() = %d", got, want) + if got, want := s.Stats().NICs.Rx.Bytes.Value(), uint64(rxBytesTotal); got != want { + t.Errorf("got s.Stats().Rx.Bytes.Value() = %d, want = %d", got, want) } - - if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)+fakeNetHeaderLen); got != want { + if got, want := s.Stats().NICs.Tx.Packets.Value(), uint64(txPacketsTotal); got != want { + t.Errorf("got Tx.Packets.Value() = %d, ep.Drain() = %d", got, want) + } + if got, want := s.Stats().NICs.Tx.Bytes.Value(), uint64(txBytesTotal); got != want { t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) } } @@ -2316,7 +2353,7 @@ func TestNICContextPreservation(t *testing.T) { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{}) id := tcpip.NICID(1) - ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")) + ep := channel.New(0, 0, "\x00\x00\x00\x00\x00\x00") if err := s.CreateNICWithOptions(id, ep, test.opts); err != nil { t.Fatalf("got stack.CreateNICWithOptions(%d, %+v, %+v) = %s, want nil", id, ep, test.opts, err) } @@ -2603,15 +2640,17 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { const nicID = 1 ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent), + dadC: make(chan ndpDADEvent, 1), } dadConfigs := stack.DefaultDADConfigurations() + clock := faketime.NewManualClock() opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ AutoGenLinkLocal: true, NDPDisp: &ndpDisp, DADConfigs: dadConfigs, })}, + Clock: clock, } e := channel.New(int(dadConfigs.DupAddrDetectTransmits), 1280, linkAddr1) @@ -2629,17 +2668,18 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { linkLocalAddr := header.LinkLocalAddr(linkAddr1) // Wait for DAD to resolve. + clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer) select { - case <-time.After(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer + time.Second): + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, linkLocalAddr, &stack.DADSucceeded{}); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + default: // We should get a resolution event after 1s (default time to // resolve as per default NDP configurations). Waiting for that // resolution time + an extra 1s without a resolution event // means something is wrong. t.Fatal("timed out waiting for DAD resolution") - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, linkLocalAddr, &stack.DADSucceeded{}); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) - } } if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); err != nil { t.Fatal(err) @@ -3270,8 +3310,9 @@ func TestDoDADWhenNICEnabled(t *testing.T) { const nicID = 1 ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent), + dadC: make(chan ndpDADEvent, 1), } + clock := faketime.NewManualClock() opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ DADConfigs: stack.DADConfigurations{ @@ -3280,6 +3321,7 @@ func TestDoDADWhenNICEnabled(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, } e := channel.New(dadTransmits, 1280, linkAddr1) @@ -3324,13 +3366,14 @@ func TestDoDADWhenNICEnabled(t *testing.T) { } // Wait for DAD to resolve. + clock.Advance(dadTransmits * retransmitTimer) select { - case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr.AddressWithPrefix.Address, &stack.DADSucceeded{}); diff != "" { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } + default: + t.Fatal("timed out waiting for DAD resolution") } if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) @@ -3837,8 +3880,6 @@ func TestGetMainNICAddressWhenNICDisabled(t *testing.T) { // TestAddRoute tests Stack.AddRoute func TestAddRoute(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{}) subnet1, err := tcpip.NewSubnet("\x00", "\x00") @@ -3875,8 +3916,6 @@ func TestAddRoute(t *testing.T) { // TestRemoveRoutes tests Stack.RemoveRoutes func TestRemoveRoutes(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{}) addressToRemove := tcpip.Address("\x01") @@ -4223,7 +4262,7 @@ func TestFindRouteWithForwarding(t *testing.T) { s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}}) r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */) - if r != nil { + if err == nil { defer r.Release() } if diff := cmp.Diff(test.findRouteErr, err); diff != "" { @@ -4394,7 +4433,7 @@ func TestClearNeighborCacheOnNICDisable(t *testing.T) { if neighbors, err := s.Neighbors(nicID, addr.proto); err != nil { t.Fatalf("s.Neighbors(%d, %d): %s", nicID, addr.proto, err) } else if diff := cmp.Diff( - []stack.NeighborEntry{{Addr: addr.addr, LinkAddr: linkAddr, State: stack.Static, UpdatedAtNanos: clock.NowNanoseconds()}}, + []stack.NeighborEntry{{Addr: addr.addr, LinkAddr: linkAddr, State: stack.Static, UpdatedAt: clock.Now()}}, neighbors, ); diff != "" { t.Fatalf("proto=%d neighbors mismatch (-want +got):\n%s", addr.proto, diff) diff --git a/pkg/tcpip/stack/tcp.go b/pkg/tcpip/stack/tcp.go index ddff6e2d6..90a8ba6cf 100644 --- a/pkg/tcpip/stack/tcp.go +++ b/pkg/tcpip/stack/tcp.go @@ -39,7 +39,7 @@ type TCPCubicState struct { WMax float64 // T is the time when the current congestion avoidance was entered. - T time.Time `state:".(unixTime)"` + T tcpip.MonotonicTime // TimeSinceLastCongestion denotes the time since the current // congestion avoidance was entered. @@ -78,7 +78,7 @@ type TCPCubicState struct { type TCPRACKState struct { // XmitTime is the transmission timestamp of the most recent // acknowledged segment. - XmitTime time.Time `state:".(unixTime)"` + XmitTime tcpip.MonotonicTime // EndSequence is the ending TCP sequence number of the most recent // acknowledged segment. @@ -216,7 +216,7 @@ type TCPRTTState struct { // +stateify savable type TCPSenderState struct { // LastSendTime is the timestamp at which we sent the last segment. - LastSendTime time.Time `state:".(unixTime)"` + LastSendTime tcpip.MonotonicTime // DupAckCount is the number of Duplicate ACKs received. It is used for // fast retransmit. @@ -256,7 +256,7 @@ type TCPSenderState struct { RTTMeasureSeqNum seqnum.Value // RTTMeasureTime is the time when the RTTMeasureSeqNum was sent. - RTTMeasureTime time.Time `state:".(unixTime)"` + RTTMeasureTime tcpip.MonotonicTime // Closed indicates that the caller has closed the endpoint for // sending. @@ -313,7 +313,7 @@ type TCPSACKInfo struct { type RcvBufAutoTuneParams struct { // MeasureTime is the time at which the current measurement was // started. - MeasureTime time.Time `state:".(unixTime)"` + MeasureTime tcpip.MonotonicTime // CopiedBytes is the number of bytes copied to user space since this // measure began. @@ -341,7 +341,7 @@ type RcvBufAutoTuneParams struct { // RTTMeasureTime is the absolute time at which the current RTT // measurement period began. - RTTMeasureTime time.Time `state:".(unixTime)"` + RTTMeasureTime tcpip.MonotonicTime // Disabled is true if an explicit receive buffer is set for the // endpoint. @@ -380,9 +380,6 @@ type TCPSndBufState struct { // SndClosed indicates that the endpoint has been closed for sends. SndClosed bool - // SndBufInQueue is the number of bytes in the send queue. - SndBufInQueue seqnum.Size - // PacketTooBigCount is used to notify the main protocol routine how // many times a "packet too big" control packet is received. PacketTooBigCount int @@ -429,7 +426,7 @@ type TCPEndpointState struct { ID TCPEndpointID // SegTime denotes the absolute time when this segment was received. - SegTime time.Time `state:".(unixTime)"` + SegTime tcpip.MonotonicTime // RcvBufState contains information about the state of the endpoint's // receive socket buffer. diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 80ad1a9d4..dda57e225 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -16,7 +16,6 @@ package stack import ( "fmt" - "math/rand" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -217,13 +216,20 @@ func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto t netProto: netProto, transProto: transProto, } - epsByNIC.endpoints[bindToDevice] = multiPortEp } - return multiPortEp.singleRegisterEndpoint(t, flags) + if err := multiPortEp.singleRegisterEndpoint(t, flags); err != nil { + return err + } + // Only add this newly created multiportEndpoint if the singleRegisterEndpoint + // succeeded. + if !ok { + epsByNIC.endpoints[bindToDevice] = multiPortEp + } + return nil } -func (epsByNIC *endpointsByNIC) checkEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { +func (epsByNIC *endpointsByNIC) checkEndpoint(flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { epsByNIC.mu.RLock() defer epsByNIC.mu.RUnlock() @@ -407,7 +413,6 @@ func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *Packet func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags ports.Flags) tcpip.Error { ep.mu.Lock() defer ep.mu.Unlock() - bits := flags.Bits() & ports.MultiBindFlagMask if len(ep.endpoints) != 0 { @@ -470,17 +475,21 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol eps.mu.Lock() defer eps.mu.Unlock() - epsByNIC, ok := eps.endpoints[id] if !ok { epsByNIC = &endpointsByNIC{ endpoints: make(map[tcpip.NICID]*multiPortEndpoint), - seed: rand.Uint32(), + seed: d.stack.Seed(), } + } + if err := epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice); err != nil { + return err + } + // Only add this newly created epsByNIC if registerEndpoint succeeded. + if !ok { eps.endpoints[id] = epsByNIC } - - return epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice) + return nil } func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { @@ -502,7 +511,7 @@ func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNum return nil } - return epsByNIC.checkEndpoint(d, netProto, protocol, flags, bindToDevice) + return epsByNIC.checkEndpoint(flags, bindToDevice) } // unregisterEndpoint unregisters the endpoint with the given id such that it diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 4848495c9..45b09110d 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -18,6 +18,7 @@ import ( "io/ioutil" "math" "math/rand" + "strconv" "testing" "gvisor.dev/gvisor/pkg/tcpip" @@ -84,7 +85,8 @@ func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICI } type headers struct { - srcPort, dstPort uint16 + srcPort uint16 + dstPort uint16 } func newPayload() []byte { @@ -201,6 +203,56 @@ func TestTransportDemuxerRegister(t *testing.T) { } } +func TestTransportDemuxerRegisterMultiple(t *testing.T) { + type test struct { + flags ports.Flags + want tcpip.Error + } + for _, subtest := range []struct { + name string + tests []test + }{ + {"zeroFlags", []test{ + {ports.Flags{}, nil}, + {ports.Flags{}, &tcpip.ErrPortInUse{}}, + }}, + {"multibindFlags", []test{ + // Allow multiple registrations same TransportEndpointID with multibind flags. + {ports.Flags{LoadBalanced: true, MostRecent: true}, nil}, + {ports.Flags{LoadBalanced: true, MostRecent: true}, nil}, + // Disallow registration w/same ID for a non-multibindflag. + {ports.Flags{TupleOnly: true}, &tcpip.ErrPortInUse{}}, + }}, + } { + t.Run(subtest.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + }) + var eps []tcpip.Endpoint + for idx, test := range subtest.tests { + var wq waiter.Queue + ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatal(err) + } + eps = append(eps, ep) + tEP, ok := ep.(stack.TransportEndpoint) + if !ok { + t.Fatalf("%T does not implement stack.TransportEndpoint", ep) + } + id := stack.TransportEndpointID{LocalPort: 1} + if got, want := s.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{ipv4.ProtocolNumber}, udp.ProtocolNumber, id, tEP, test.flags, 0), test.want; got != want { + t.Fatalf("test index: %d, s.RegisterTransportEndpoint(ipv4.ProtocolNumber, udp.ProtocolNumber, _, _, %+v, 0) = %s, want %s", idx, test.flags, got, want) + } + } + for _, ep := range eps { + ep.Close() + } + }) + } +} + // TestBindToDeviceDistribution injects varied packets on input devices and checks that // the distribution of packets received matches expectations. func TestBindToDeviceDistribution(t *testing.T) { @@ -208,7 +260,7 @@ func TestBindToDeviceDistribution(t *testing.T) { reuse bool bindToDevice tcpip.NICID } - for _, test := range []struct { + tcs := []struct { name string // endpoints will received the inject packets. endpoints []endpointSockopts @@ -217,29 +269,29 @@ func TestBindToDeviceDistribution(t *testing.T) { wantDistributions map[tcpip.NICID][]float64 }{ { - "BindPortReuse", + name: "BindPortReuse", // 5 endpoints that all have reuse set. - []endpointSockopts{ + endpoints: []endpointSockopts{ {reuse: true, bindToDevice: 0}, {reuse: true, bindToDevice: 0}, {reuse: true, bindToDevice: 0}, {reuse: true, bindToDevice: 0}, {reuse: true, bindToDevice: 0}, }, - map[tcpip.NICID][]float64{ + wantDistributions: map[tcpip.NICID][]float64{ // Injected packets on dev0 get distributed evenly. 1: {0.2, 0.2, 0.2, 0.2, 0.2}, }, }, { - "BindToDevice", + name: "BindToDevice", // 3 endpoints with various bindings. - []endpointSockopts{ + endpoints: []endpointSockopts{ {reuse: false, bindToDevice: 1}, {reuse: false, bindToDevice: 2}, {reuse: false, bindToDevice: 3}, }, - map[tcpip.NICID][]float64{ + wantDistributions: map[tcpip.NICID][]float64{ // Injected packets on dev0 go only to the endpoint bound to dev0. 1: {1, 0, 0}, // Injected packets on dev1 go only to the endpoint bound to dev1. @@ -249,9 +301,9 @@ func TestBindToDeviceDistribution(t *testing.T) { }, }, { - "ReuseAndBindToDevice", + name: "ReuseAndBindToDevice", // 6 endpoints with various bindings. - []endpointSockopts{ + endpoints: []endpointSockopts{ {reuse: true, bindToDevice: 1}, {reuse: true, bindToDevice: 1}, {reuse: true, bindToDevice: 2}, @@ -259,7 +311,7 @@ func TestBindToDeviceDistribution(t *testing.T) { {reuse: true, bindToDevice: 2}, {reuse: true, bindToDevice: 0}, }, - map[tcpip.NICID][]float64{ + wantDistributions: map[tcpip.NICID][]float64{ // Injected packets on dev0 get distributed among endpoints bound to // dev0. 1: {0.5, 0.5, 0, 0, 0, 0}, @@ -270,35 +322,42 @@ func TestBindToDeviceDistribution(t *testing.T) { 1000: {0, 0, 0, 0, 0, 1}, }, }, - } { - for protoName, netProtoNum := range map[string]tcpip.NetworkProtocolNumber{ - "IPv4": ipv4.ProtocolNumber, - "IPv6": ipv6.ProtocolNumber, - } { + } + protos := map[string]tcpip.NetworkProtocolNumber{ + "IPv4": ipv4.ProtocolNumber, + "IPv6": ipv6.ProtocolNumber, + } + + for _, test := range tcs { + for protoName, protoNum := range protos { for device, wantDistribution := range test.wantDistributions { - t.Run(test.name+protoName+string(device), func(t *testing.T) { + t.Run(test.name+protoName+"-"+strconv.Itoa(int(device)), func(t *testing.T) { + // Create the NICs. var devices []tcpip.NICID for d := range test.wantDistributions { devices = append(devices, d) } c := newDualTestContextMultiNIC(t, defaultMTU, devices) + // Create endpoints and bind each to a NIC, sometimes reusing ports. eps := make(map[tcpip.Endpoint]int) - pollChannel := make(chan tcpip.Endpoint) for i, endpoint := range test.endpoints { // Try to receive the data. wq := waiter.Queue{} we, ch := waiter.NewChannelEntry(nil) wq.EventRegister(&we, waiter.ReadableEvents) - defer wq.EventUnregister(&we) - defer close(ch) + t.Cleanup(func() { + wq.EventUnregister(&we) + close(ch) + }) var err tcpip.Error - ep, err := c.s.NewEndpoint(udp.ProtocolNumber, netProtoNum, &wq) + ep, err := c.s.NewEndpoint(udp.ProtocolNumber, protoNum, &wq) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) } + t.Cleanup(ep.Close) eps[ep] = i go func(ep tcpip.Endpoint) { @@ -307,32 +366,34 @@ func TestBindToDeviceDistribution(t *testing.T) { } }(ep) - defer ep.Close() ep.SocketOptions().SetReusePort(endpoint.reuse) if err := ep.SocketOptions().SetBindToDevice(int32(endpoint.bindToDevice)); err != nil { t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", endpoint.bindToDevice, endpoint.bindToDevice, i, err) } var dstAddr tcpip.Address - switch netProtoNum { + switch protoNum { case ipv4.ProtocolNumber: dstAddr = testDstAddrV4 case ipv6.ProtocolNumber: dstAddr = testDstAddrV6 default: - t.Fatalf("unexpected protocol number: %d", netProtoNum) + t.Fatalf("unexpected protocol number: %d", protoNum) } if err := ep.Bind(tcpip.FullAddress{Addr: dstAddr, Port: testDstPort}); err != nil { t.Fatalf("ep.Bind(...) on endpoint %d failed: %s", i, err) } } - npackets := 100000 - nports := 10000 + // Send packets across a range of ports, checking that packets from + // the same source port are always demultiplexed to the same + // destination endpoint. + npackets := 10_000 + nports := 1_000 if got, want := len(test.endpoints), len(wantDistribution); got != want { t.Fatalf("got len(test.endpoints) = %d, want %d", got, want) } - ports := make(map[uint16]tcpip.Endpoint) + endpoints := make(map[uint16]tcpip.Endpoint) stats := make(map[tcpip.Endpoint]int) for i := 0; i < npackets; i++ { // Send a packet. @@ -342,13 +403,13 @@ func TestBindToDeviceDistribution(t *testing.T) { srcPort: testSrcPort + port, dstPort: testDstPort, } - switch netProtoNum { + switch protoNum { case ipv4.ProtocolNumber: c.sendV4Packet(payload, hdrs, device) case ipv6.ProtocolNumber: c.sendV6Packet(payload, hdrs, device) default: - t.Fatalf("unexpected protocol number: %d", netProtoNum) + t.Fatalf("unexpected protocol number: %d", protoNum) } ep := <-pollChannel @@ -357,11 +418,11 @@ func TestBindToDeviceDistribution(t *testing.T) { } stats[ep]++ if i < nports { - ports[uint16(i)] = ep + endpoints[uint16(i)] = ep } else { // Check that all packets from one client are handled by the same // socket. - if want, got := ports[port], ep; want != got { + if want, got := endpoints[port], ep; want != got { t.Fatalf("Packet sent on port %d expected on endpoint %d but received on endpoint %d", port, eps[want], eps[got]) } } diff --git a/pkg/tcpip/stdclock.go b/pkg/tcpip/stdclock.go index 7ce43a68e..371da2f40 100644 --- a/pkg/tcpip/stdclock.go +++ b/pkg/tcpip/stdclock.go @@ -60,11 +60,11 @@ type stdClock struct { // monotonicOffset is assigned maxMonotonic after restore so that the // monotonic time will continue from where it "left off" before saving as part // of S/R. - monotonicOffset int64 `state:"nosave"` + monotonicOffset MonotonicTime `state:"nosave"` // monotonicMU protects maxMonotonic. monotonicMU sync.Mutex `state:"nosave"` - maxMonotonic int64 + maxMonotonic MonotonicTime } // NewStdClock returns an instance of a clock that uses the time package. @@ -76,25 +76,25 @@ func NewStdClock() Clock { var _ Clock = (*stdClock)(nil) -// NowNanoseconds implements Clock.NowNanoseconds. -func (*stdClock) NowNanoseconds() int64 { - return time.Now().UnixNano() +// Now implements Clock.Now. +func (*stdClock) Now() time.Time { + return time.Now() } // NowMonotonic implements Clock.NowMonotonic. -func (s *stdClock) NowMonotonic() int64 { +func (s *stdClock) NowMonotonic() MonotonicTime { sinceBase := time.Since(s.baseTime) if sinceBase < 0 { panic(fmt.Sprintf("got negative duration = %s since base time = %s", sinceBase, s.baseTime)) } - monotonicValue := sinceBase.Nanoseconds() + s.monotonicOffset + monotonicValue := s.monotonicOffset.Add(sinceBase) s.monotonicMU.Lock() defer s.monotonicMU.Unlock() // Monotonic time values must never decrease. - if monotonicValue > s.maxMonotonic { + if s.maxMonotonic.Before(monotonicValue) { s.maxMonotonic = monotonicValue } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 797778e08..8f2658f64 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -64,17 +64,47 @@ func (e *ErrSaveRejection) Error() string { return "save rejected due to unsupported networking state: " + e.Err.Error() } +// MonotonicTime is a monotonic clock reading. +// +// +stateify savable +type MonotonicTime struct { + nanoseconds int64 +} + +// Before reports whether the monotonic clock reading mt is before u. +func (mt MonotonicTime) Before(u MonotonicTime) bool { + return mt.nanoseconds < u.nanoseconds +} + +// After reports whether the monotonic clock reading mt is after u. +func (mt MonotonicTime) After(u MonotonicTime) bool { + return mt.nanoseconds > u.nanoseconds +} + +// Add returns the monotonic clock reading mt+d. +func (mt MonotonicTime) Add(d time.Duration) MonotonicTime { + return MonotonicTime{ + nanoseconds: time.Unix(0, mt.nanoseconds).Add(d).Sub(time.Unix(0, 0)).Nanoseconds(), + } +} + +// Sub returns the duration mt-u. If the result exceeds the maximum (or minimum) +// value that can be stored in a Duration, the maximum (or minimum) duration +// will be returned. To compute t-d for a duration d, use t.Add(-d). +func (mt MonotonicTime) Sub(u MonotonicTime) time.Duration { + return time.Unix(0, mt.nanoseconds).Sub(time.Unix(0, u.nanoseconds)) +} + // A Clock provides the current time and schedules work for execution. // // Times returned by a Clock should always be used for application-visible // time. Only monotonic times should be used for netstack internal timekeeping. type Clock interface { - // NowNanoseconds returns the current real time as a number of - // nanoseconds since the Unix epoch. - NowNanoseconds() int64 + // Now returns the current local time. + Now() time.Time - // NowMonotonic returns a monotonic time value at nanosecond resolution. - NowMonotonic() int64 + // NowMonotonic returns the current monotonic clock reading. + NowMonotonic() MonotonicTime // AfterFunc waits for the duration to elapse and then calls f in its own // goroutine. It returns a Timer that can be used to cancel the call using @@ -435,11 +465,11 @@ type ControlMessages struct { // PacketOwner is used to get UID and GID of the packet. type PacketOwner interface { - // UID returns UID of the packet. - UID() uint32 + // UID returns KUID of the packet. + KUID() uint32 - // GID returns GID of the packet. - GID() uint32 + // GID returns KGID of the packet. + KGID() uint32 } // ReadOptions contains options for Endpoint.Read. @@ -861,6 +891,9 @@ type SettableSocketOption interface { isSettableSocketOption() } +// EndpointState represents the state of an endpoint. +type EndpointState uint8 + // CongestionControlState indicates the current congestion control state for // TCP sender. type CongestionControlState int @@ -897,6 +930,9 @@ type TCPInfoOption struct { // RTO is the retransmission timeout for the endpoint. RTO time.Duration + // State is the current endpoint protocol state. + State EndpointState + // CcState is the congestion control state. CcState CongestionControlState @@ -1552,6 +1588,10 @@ type IPForwardingStats struct { // were too big for the outgoing MTU. PacketTooBig *StatCounter + // HostUnreachable is the number of IP packets received which could not be + // successfully forwarded due to an unresolvable next hop. + HostUnreachable *StatCounter + // ExtensionHeaderProblem is the number of IP packets which were dropped // because of a problem encountered when processing an IPv6 extension // header. @@ -1835,37 +1875,104 @@ type UDPStats struct { ChecksumErrors *StatCounter } +// NICNeighborStats holds metrics for the neighbor table. +type NICNeighborStats struct { + // LINT.IfChange(NICNeighborStats) + + // UnreachableEntryLookups counts the number of lookups performed on an + // entry in Unreachable state. + UnreachableEntryLookups *StatCounter + + // LINT.ThenChange(stack/nic_stats.go:multiCounterNICNeighborStats) +} + +// NICPacketStats holds basic packet statistics. +type NICPacketStats struct { + // LINT.IfChange(NICPacketStats) + + // Packets is the number of packets counted. + Packets *StatCounter + + // Bytes is the number of bytes counted. + Bytes *StatCounter + + // LINT.ThenChange(stack/nic_stats.go:multiCounterNICPacketStats) +} + +// NICStats holds NIC statistics. +type NICStats struct { + // LINT.IfChange(NICStats) + + // UnknownL3ProtocolRcvdPackets is the number of packets received that were + // for an unknown or unsupported network protocol. + UnknownL3ProtocolRcvdPackets *StatCounter + + // UnknownL4ProtocolRcvdPackets is the number of packets received that were + // for an unknown or unsupported transport protocol. + UnknownL4ProtocolRcvdPackets *StatCounter + + // MalformedL4RcvdPackets is the number of packets received by a NIC that + // could not be delivered to a transport endpoint because the L4 header could + // not be parsed. + MalformedL4RcvdPackets *StatCounter + + // Tx contains statistics about transmitted packets. + Tx NICPacketStats + + // Rx contains statistics about received packets. + Rx NICPacketStats + + // DisabledRx contains statistics about received packets on disabled NICs. + DisabledRx NICPacketStats + + // Neighbor contains statistics about neighbor entries. + Neighbor NICNeighborStats + + // LINT.ThenChange(stack/nic_stats.go:multiCounterNICStats) +} + +// FillIn returns a copy of s with nil fields initialized to new StatCounters. +func (s NICStats) FillIn() NICStats { + InitStatCounters(reflect.ValueOf(&s).Elem()) + return s +} + // Stats holds statistics about the networking stack. -// -// All fields are optional. type Stats struct { - // UnknownProtocolRcvdPackets is the number of packets received by the - // stack that were for an unknown or unsupported protocol. - UnknownProtocolRcvdPackets *StatCounter + // TODO(https://gvisor.dev/issues/5986): Make the DroppedPackets stat less + // ambiguous. - // MalformedRcvdPackets is the number of packets received by the stack - // that were deemed malformed. - MalformedRcvdPackets *StatCounter - - // DroppedPackets is the number of packets dropped due to full queues. + // DroppedPackets is the number of packets dropped at the transport layer. DroppedPackets *StatCounter - // ICMP breaks out ICMP-specific stats (both v4 and v6). + // NICs is an aggregation of every NIC's statistics. These should not be + // incremented using this field, but using the relevant NIC multicounters. + NICs NICStats + + // ICMP is an aggregation of every NetworkEndpoint's ICMP statistics (both v4 + // and v6). These should not be incremented using this field, but using the + // relevant NetworkEndpoint ICMP multicounters. ICMP ICMPStats - // IGMP breaks out IGMP-specific stats. + // IGMP is an aggregation of every NetworkEndpoint's IGMP statistics. These + // should not be incremented using this field, but using the relevant + // NetworkEndpoint IGMP multicounters. IGMP IGMPStats - // IP breaks out IP-specific stats (both v4 and v6). + // IP is an aggregation of every NetworkEndpoint's IP statistics. These should + // not be incremented using this field, but using the relevant NetworkEndpoint + // IP multicounters. IP IPStats - // ARP breaks out ARP-specific stats. + // ARP is an aggregation of every NetworkEndpoint's ARP statistics. These + // should not be incremented using this field, but using the relevant + // NetworkEndpoint ARP multicounters. ARP ARPStats - // TCP breaks out TCP-specific stats. + // TCP holds TCP-specific stats. TCP TCPStats - // UDP breaks out UDP-specific stats. + // UDP holds UDP-specific stats. UDP UDPStats } diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go index 269081ff8..c96ae2f02 100644 --- a/pkg/tcpip/tcpip_test.go +++ b/pkg/tcpip/tcpip_test.go @@ -19,7 +19,6 @@ import ( "fmt" "io" "net" - "strings" "testing" "github.com/google/go-cmp/cmp" @@ -210,26 +209,6 @@ func TestAddressString(t *testing.T) { } } -func TestStatsString(t *testing.T) { - got := fmt.Sprintf("%+v", Stats{}.FillIn()) - - matchers := []string{ - // Print root-level stats correctly. - "UnknownProtocolRcvdPackets:0", - // Print protocol-specific stats correctly. - "TCP:{ActiveConnectionOpenings:0", - } - - for _, m := range matchers { - if !strings.Contains(got, m) { - t.Errorf("string.Contains(got, %q) = false", m) - } - } - if t.Failed() { - t.Logf(`got = fmt.Sprintf("%%+v", Stats{}.FillIn()) = %q`, got) - } -} - func TestAddressWithPrefixSubnet(t *testing.T) { tests := []struct { addr Address diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index ab2dab60c..181ef799e 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -53,12 +53,14 @@ go_test( "//pkg/tcpip/checker", "//pkg/tcpip/faketime", "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", "//pkg/tcpip/link/pipe", "//pkg/tcpip/network/arp", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", "//pkg/tcpip/tests/utils", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go index 07ba2b837..f9ab7d0af 100644 --- a/pkg/tcpip/tests/integration/iptables_test.go +++ b/pkg/tcpip/tests/integration/iptables_test.go @@ -166,7 +166,7 @@ func TestIPTablesStatsForInput(t *testing.T) { // Make sure the packet is not dropped by the next rule. filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err) } }, genPacket: genPacketV6, @@ -187,7 +187,7 @@ func TestIPTablesStatsForInput(t *testing.T) { filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{nicName}} filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err) } }, genPacket: genPacketV4, @@ -207,7 +207,7 @@ func TestIPTablesStatsForInput(t *testing.T) { filter.Rules[ruleIdx].Target = &stack.DropTarget{} filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err) } }, genPacket: genPacketV6, @@ -227,7 +227,7 @@ func TestIPTablesStatsForInput(t *testing.T) { filter.Rules[ruleIdx].Target = &stack.DropTarget{} filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err) } }, genPacket: genPacketV4, @@ -250,7 +250,7 @@ func TestIPTablesStatsForInput(t *testing.T) { filter.Rules[ruleIdx].Target = &stack.DropTarget{} filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err) } }, genPacket: genPacketV6, @@ -273,7 +273,7 @@ func TestIPTablesStatsForInput(t *testing.T) { filter.Rules[ruleIdx].Target = &stack.DropTarget{} filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err) } }, genPacket: genPacketV4, @@ -293,7 +293,7 @@ func TestIPTablesStatsForInput(t *testing.T) { filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}} filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err) } }, genPacket: genPacketV6, @@ -313,7 +313,7 @@ func TestIPTablesStatsForInput(t *testing.T) { filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}} filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err) } }, genPacket: genPacketV4, @@ -465,7 +465,7 @@ func TestIPTableWritePackets(t *testing.T) { } if err := s.IPTables().ReplaceTable(stack.FilterID, table, false /* ipv4 */); err != nil { - t.Fatalf("RelaceTable(%d, _, false): %s", stack.FilterID, err) + t.Fatalf("ReplaceTable(%d, _, false): %s", stack.FilterID, err) } }, genPacket: func(r *stack.Route) stack.PacketBufferList { @@ -556,7 +556,7 @@ func TestIPTableWritePackets(t *testing.T) { } if err := s.IPTables().ReplaceTable(stack.FilterID, table, true /* ipv6 */); err != nil { - t.Fatalf("RelaceTable(%d, _, true): %s", stack.FilterID, err) + t.Fatalf("ReplaceTable(%d, _, true): %s", stack.FilterID, err) } }, genPacket: func(r *stack.Route) stack.PacketBufferList { @@ -681,6 +681,32 @@ func forwardedICMPv6EchoReplyChecker(t *testing.T, b []byte, src, dst tcpip.Addr checker.ICMPv6Type(header.ICMPv6EchoReply))) } +func boolToInt(v bool) uint64 { + if v { + return 1 + } + return 0 +} + +func setupDropFilter(hook stack.Hook, f stack.IPHeaderFilter) func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { + return func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber) { + t.Helper() + + ipv6 := netProto == ipv6.ProtocolNumber + + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, ipv6) + ruleIdx := filter.BuiltinChains[hook] + filter.Rules[ruleIdx].Filter = f + filter.Rules[ruleIdx].Target = &stack.DropTarget{NetworkProtocol: netProto} + // Make sure the packet is not dropped by the next rule. + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{NetworkProtocol: netProto} + if err := ipt.ReplaceTable(stack.FilterID, filter, ipv6); err != nil { + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, ipv6, err) + } + } +} + func TestForwardingHook(t *testing.T) { const ( nicID1 = 1 @@ -740,32 +766,6 @@ func TestForwardingHook(t *testing.T) { }, } - setupDropFilter := func(f stack.IPHeaderFilter) func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { - return func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber) { - t.Helper() - - ipv6 := netProto == ipv6.ProtocolNumber - - ipt := s.IPTables() - filter := ipt.GetTable(stack.FilterID, ipv6) - ruleIdx := filter.BuiltinChains[stack.Forward] - filter.Rules[ruleIdx].Filter = f - filter.Rules[ruleIdx].Target = &stack.DropTarget{NetworkProtocol: netProto} - // Make sure the packet is not dropped by the next rule. - filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{NetworkProtocol: netProto} - if err := ipt.ReplaceTable(stack.FilterID, filter, ipv6); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, ipv6, err) - } - } - } - - boolToInt := func(v bool) uint64 { - if v { - return 1 - } - return 0 - } - subTests := []struct { name string setupFilter func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) @@ -779,59 +779,59 @@ func TestForwardingHook(t *testing.T) { { name: "Drop", - setupFilter: setupDropFilter(stack.IPHeaderFilter{}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{}), expectForward: false, }, { name: "Drop with input NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name}), expectForward: false, }, { name: "Drop with output NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: nic2Name}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: nic2Name}), expectForward: false, }, { name: "Drop with input and output NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: nic2Name}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: nic2Name}), expectForward: false, }, { name: "Drop with other input NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName}), expectForward: true, }, { name: "Drop with other output NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: otherNICName}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: otherNICName}), expectForward: true, }, { name: "Drop with other input and output NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: nic2Name}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: nic2Name}), expectForward: true, }, { name: "Drop with input and other output NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: otherNICName}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: otherNICName}), expectForward: true, }, { name: "Drop with other input and other output NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: otherNICName}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: otherNICName}), expectForward: true, }, { name: "Drop with inverted input NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, InputInterfaceInvert: true}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, InputInterfaceInvert: true}), expectForward: true, }, { name: "Drop with inverted output NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: nic2Name, OutputInterfaceInvert: true}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: nic2Name, OutputInterfaceInvert: true}), expectForward: true, }, } @@ -941,3 +941,194 @@ func TestForwardingHook(t *testing.T) { }) } } + +func TestInputHookWithLocalForwarding(t *testing.T) { + const ( + nicID1 = 1 + nicID2 = 2 + + nic1Name = "nic1" + nic2Name = "nic2" + + otherNICName = "otherNIC" + ) + + tests := []struct { + name string + netProto tcpip.NetworkProtocolNumber + rx func(*channel.Endpoint) + checker func(*testing.T, []byte) + }{ + { + name: "IPv4", + netProto: ipv4.ProtocolNumber, + rx: func(e *channel.Endpoint) { + utils.RxICMPv4EchoRequest(e, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address, ttl) + }, + checker: func(t *testing.T, b []byte) { + checker.IPv4(t, b, + checker.SrcAddr(utils.Ipv4Addr2.AddressWithPrefix.Address), + checker.DstAddr(utils.RemoteIPv4Addr), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4EchoReply))) + }, + }, + { + name: "IPv6", + netProto: ipv6.ProtocolNumber, + rx: func(e *channel.Endpoint) { + utils.RxICMPv6EchoRequest(e, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address, ttl) + }, + checker: func(t *testing.T, b []byte) { + checker.IPv6(t, b, + checker.SrcAddr(utils.Ipv6Addr2.AddressWithPrefix.Address), + checker.DstAddr(utils.RemoteIPv6Addr), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6EchoReply))) + }, + }, + } + + subTests := []struct { + name string + setupFilter func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) + expectDrop bool + }{ + { + name: "Accept", + setupFilter: func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { /* no filter */ }, + expectDrop: false, + }, + + { + name: "Drop", + setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{}), + expectDrop: true, + }, + { + name: "Drop with input NIC filtering on arrival NIC", + setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: nic1Name}), + expectDrop: true, + }, + { + name: "Drop with input NIC filtering on delivered NIC", + setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: nic2Name}), + expectDrop: false, + }, + + { + name: "Drop with input NIC filtering on other NIC", + setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: otherNICName}), + expectDrop: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, subTest := range subTests { + t.Run(subTest.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + }) + + subTest.setupFilter(t, s, test.netProto) + + e1 := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNICWithOptions(nicID1, e1, stack.NICOptions{Name: nic1Name}); err != nil { + t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID1, err) + } + if err := s.AddProtocolAddress(nicID1, utils.Ipv4Addr1); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID1, utils.Ipv4Addr1, err) + } + if err := s.AddProtocolAddress(nicID1, utils.Ipv6Addr1); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID1, utils.Ipv6Addr1, err) + } + + e2 := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNICWithOptions(nicID2, e2, stack.NICOptions{Name: nic2Name}); err != nil { + t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err) + } + if err := s.AddProtocolAddress(nicID2, utils.Ipv4Addr2); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID2, utils.Ipv4Addr2, err) + } + if err := s.AddProtocolAddress(nicID2, utils.Ipv6Addr2); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID2, utils.Ipv6Addr2, err) + } + + if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err) + } + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv4EmptySubnet, + NIC: nicID1, + }, + { + Destination: header.IPv6EmptySubnet, + NIC: nicID1, + }, + }) + + test.rx(e1) + + ep1, err := s.GetNetworkEndpoint(nicID1, test.netProto) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, test.netProto, err) + } + ep1Stats := ep1.Stats() + ipEP1Stats, ok := ep1Stats.(stack.IPNetworkEndpointStats) + if !ok { + t.Fatalf("got ep1Stats = %T, want = stack.IPNetworkEndpointStats", ep1Stats) + } + ip1Stats := ipEP1Stats.IPStats() + + if got := ip1Stats.PacketsReceived.Value(); got != 1 { + t.Errorf("got ip1Stats.PacketsReceived.Value() = %d, want = 1", got) + } + if got := ip1Stats.ValidPacketsReceived.Value(); got != 1 { + t.Errorf("got ip1Stats.ValidPacketsReceived.Value() = %d, want = 1", got) + } + if got, want := ip1Stats.PacketsSent.Value(), boolToInt(!subTest.expectDrop); got != want { + t.Errorf("got ip1Stats.PacketsSent.Value() = %d, want = %d", got, want) + } + + ep2, err := s.GetNetworkEndpoint(nicID2, test.netProto) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID2, test.netProto, err) + } + ep2Stats := ep2.Stats() + ipEP2Stats, ok := ep2Stats.(stack.IPNetworkEndpointStats) + if !ok { + t.Fatalf("got ep2Stats = %T, want = stack.IPNetworkEndpointStats", ep2Stats) + } + ip2Stats := ipEP2Stats.IPStats() + if got := ip2Stats.PacketsReceived.Value(); got != 0 { + t.Errorf("got ip2Stats.PacketsReceived.Value() = %d, want = 0", got) + } + if got := ip2Stats.ValidPacketsReceived.Value(); got != 1 { + t.Errorf("got ip2Stats.ValidPacketsReceived.Value() = %d, want = 1", got) + } + if got, want := ip2Stats.IPTablesInputDropped.Value(), boolToInt(subTest.expectDrop); got != want { + t.Errorf("got ip2Stats.IPTablesInputDropped.Value() = %d, want = %d", got, want) + } + if got := ip2Stats.PacketsSent.Value(); got != 0 { + t.Errorf("got ip2Stats.PacketsSent.Value() = %d, want = 0", got) + } + + if p, ok := e1.Read(); ok == subTest.expectDrop { + t.Errorf("got e1.Read() = (%#v, %t), want = (_, %t)", p, ok, !subTest.expectDrop) + } else if !subTest.expectDrop { + test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader())) + } + if p, ok := e2.Read(); ok { + t.Errorf("got e1.Read() = (%#v, true), want = (_, false)", p) + } + }) + } + }) + } +} diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index c657714ba..27caa0c28 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -17,6 +17,8 @@ package link_resolution_test import ( "bytes" "fmt" + "net" + "runtime" "testing" "time" @@ -27,12 +29,14 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/pipe" "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/tests/utils" + tcptestutil "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" @@ -283,9 +287,11 @@ func TestTCPLinkResolutionFailure(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() stackOpts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + Clock: clock, } host1Stack, host2Stack := setupStack(t, stackOpts, host1NICID, host2NICID) @@ -329,7 +335,17 @@ func TestTCPLinkResolutionFailure(t *testing.T) { // Wait for an error due to link resolution failing, or the endpoint to be // writable. + if test.expectedWriteErr != nil { + nudConfigs, err := host1Stack.NUDConfigurations(host1NICID, test.netProto) + if err != nil { + t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", host1NICID, test.netProto, err) + } + clock.Advance(time.Duration(nudConfigs.MaxMulticastProbes) * nudConfigs.RetransmitTimer) + } else { + clock.RunImmediatelyScheduledJobs() + } <-ch + { var r bytes.Reader r.Reset([]byte{0}) @@ -395,6 +411,242 @@ func TestTCPLinkResolutionFailure(t *testing.T) { } } +func TestForwardingWithLinkResolutionFailure(t *testing.T) { + const ( + incomingNICID = 1 + outgoingNICID = 2 + ttl = 2 + expectedHostUnreachableErrorCount = 1 + ) + outgoingLinkAddr := tcptestutil.MustParseLink("02:03:03:04:05:06") + + rxICMPv4EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) { + utils.RxICMPv4EchoRequest(e, src, dst, ttl) + } + + rxICMPv6EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) { + utils.RxICMPv6EchoRequest(e, src, dst, ttl) + } + + arpChecker := func(t *testing.T, request channel.PacketInfo, src, dst tcpip.Address) { + if request.Proto != arp.ProtocolNumber { + t.Errorf("got request.Proto = %d, want = %d", request.Proto, arp.ProtocolNumber) + } + if request.Route.RemoteLinkAddress != header.EthernetBroadcastAddress { + t.Errorf("got request.Route.RemoteLinkAddress = %s, want = %s", request.Route.RemoteLinkAddress, header.EthernetBroadcastAddress) + } + rep := header.ARP(request.Pkt.NetworkHeader().View()) + if got := rep.Op(); got != header.ARPRequest { + t.Errorf("got Op() = %d, want = %d", got, header.ARPRequest) + } + if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != outgoingLinkAddr { + t.Errorf("got HardwareAddressSender = %s, want = %s", got, outgoingLinkAddr) + } + if got := tcpip.Address(rep.ProtocolAddressSender()); got != src { + t.Errorf("got ProtocolAddressSender = %s, want = %s", got, src) + } + if got := tcpip.Address(rep.ProtocolAddressTarget()); got != dst { + t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, dst) + } + } + + ndpChecker := func(t *testing.T, request channel.PacketInfo, src, dst tcpip.Address) { + if request.Proto != header.IPv6ProtocolNumber { + t.Fatalf("got Proto = %d, want = %d", request.Proto, header.IPv6ProtocolNumber) + } + + snmc := header.SolicitedNodeAddr(dst) + if want := header.EthernetAddressFromMulticastIPv6Address(snmc); request.Route.RemoteLinkAddress != want { + t.Errorf("got remote link address = %s, want = %s", request.Route.RemoteLinkAddress, want) + } + + checker.IPv6(t, stack.PayloadSince(request.Pkt.NetworkHeader()), + checker.SrcAddr(src), + checker.DstAddr(snmc), + checker.TTL(header.NDPHopLimit), + checker.NDPNS( + checker.NDPNSTargetAddress(dst), + )) + } + + icmpv4Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) { + checker.IPv4(t, b, + checker.SrcAddr(src), + checker.DstAddr(dst), + checker.TTL(ipv4.DefaultTTL), + checker.ICMPv4( + checker.ICMPv4Checksum(), + checker.ICMPv4Type(header.ICMPv4DstUnreachable), + checker.ICMPv4Code(header.ICMPv4HostUnreachable), + ), + ) + } + + icmpv6Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) { + checker.IPv6(t, b, + checker.SrcAddr(src), + checker.DstAddr(dst), + checker.TTL(ipv6.DefaultTTL), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6DstUnreachable), + checker.ICMPv6Code(header.ICMPv6AddressUnreachable), + ), + ) + } + + tests := []struct { + name string + networkProtocolFactory []stack.NetworkProtocolFactory + networkProtocolNumber tcpip.NetworkProtocolNumber + sourceAddr tcpip.Address + destAddr tcpip.Address + incomingAddr tcpip.AddressWithPrefix + outgoingAddr tcpip.AddressWithPrefix + transportProtocol func(*stack.Stack) stack.TransportProtocol + rx func(*channel.Endpoint, tcpip.Address, tcpip.Address) + linkResolutionRequestChecker func(*testing.T, channel.PacketInfo, tcpip.Address, tcpip.Address) + icmpReplyChecker func(*testing.T, []byte, tcpip.Address, tcpip.Address) + mtu uint32 + }{ + { + name: "IPv4 Host unreachable", + networkProtocolFactory: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, + networkProtocolNumber: header.IPv4ProtocolNumber, + sourceAddr: tcptestutil.MustParse4("10.0.0.2"), + destAddr: tcptestutil.MustParse4("11.0.0.2"), + incomingAddr: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()), + PrefixLen: 8, + }, + outgoingAddr: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("11.0.0.1").To4()), + PrefixLen: 8, + }, + transportProtocol: icmp.NewProtocol4, + linkResolutionRequestChecker: arpChecker, + icmpReplyChecker: icmpv4Checker, + rx: rxICMPv4EchoRequest, + mtu: ipv4.MaxTotalSize, + }, + { + name: "IPv6 Host unreachable", + networkProtocolFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + networkProtocolNumber: header.IPv6ProtocolNumber, + sourceAddr: tcptestutil.MustParse6("10::2"), + destAddr: tcptestutil.MustParse6("11::2"), + incomingAddr: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("10::1").To16()), + PrefixLen: 64, + }, + outgoingAddr: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("11::1").To16()), + PrefixLen: 64, + }, + transportProtocol: icmp.NewProtocol6, + linkResolutionRequestChecker: ndpChecker, + icmpReplyChecker: icmpv6Checker, + rx: rxICMPv6EchoRequest, + mtu: header.IPv6MinimumMTU, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + + s := stack.New(stack.Options{ + NetworkProtocols: test.networkProtocolFactory, + TransportProtocols: []stack.TransportProtocolFactory{test.transportProtocol}, + Clock: clock, + }) + + // Set up endpoint through which we will receive packets. + incomingEndpoint := channel.New(1, test.mtu, "") + if err := s.CreateNIC(incomingNICID, incomingEndpoint); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err) + } + incomingProtoAddr := tcpip.ProtocolAddress{ + Protocol: test.networkProtocolNumber, + AddressWithPrefix: test.incomingAddr, + } + if err := s.AddProtocolAddress(incomingNICID, incomingProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingProtoAddr, err) + } + + // Set up endpoint through which we will attempt to forward packets. + outgoingEndpoint := channel.New(1, test.mtu, outgoingLinkAddr) + outgoingEndpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired + if err := s.CreateNIC(outgoingNICID, outgoingEndpoint); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err) + } + outgoingProtoAddr := tcpip.ProtocolAddress{ + Protocol: test.networkProtocolNumber, + AddressWithPrefix: test.outgoingAddr, + } + if err := s.AddProtocolAddress(outgoingNICID, outgoingProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingProtoAddr, err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: test.incomingAddr.Subnet(), + NIC: incomingNICID, + }, + { + Destination: test.outgoingAddr.Subnet(), + NIC: outgoingNICID, + }, + }) + + if err := s.SetForwardingDefaultAndAllNICs(test.networkProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", test.networkProtocolNumber, err) + } + + test.rx(incomingEndpoint, test.sourceAddr, test.destAddr) + + nudConfigs, err := s.NUDConfigurations(outgoingNICID, test.networkProtocolNumber) + if err != nil { + t.Fatalf("s.NUDConfigurations(%d, %d): %s", outgoingNICID, test.networkProtocolNumber, err) + } + // Trigger the first packet on the endpoint. + clock.RunImmediatelyScheduledJobs() + + for i := 0; i < int(nudConfigs.MaxMulticastProbes); i++ { + request, ok := outgoingEndpoint.Read() + if !ok { + t.Fatal("expected ARP packet through outgoing NIC") + } + + test.linkResolutionRequestChecker(t, request, test.outgoingAddr.Address, test.destAddr) + + // Advance the clock the span of one request timeout. + clock.Advance(nudConfigs.RetransmitTimer) + } + + // Next, we make a blocking read to retrieve the error packet. This is + // necessary because outgoing packets are dequeued asynchronously when + // link resolution fails, and this dequeue is what triggers the ICMP + // error. + reply, ok := incomingEndpoint.Read() + if !ok { + t.Fatal("expected ICMP packet through incoming NIC") + } + + test.icmpReplyChecker(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), test.incomingAddr.Address, test.sourceAddr) + + // Since link resolution failed, we don't expect the packet to be + // forwarded. + forwardedPacket, ok := outgoingEndpoint.Read() + if ok { + t.Fatalf("expected no ICMP Echo packet through outgoing NIC, instead found: %#v", forwardedPacket) + } + + if got, want := s.Stats().IP.Forwarding.HostUnreachable.Value(), expectedHostUnreachableErrorCount; int(got) != want { + t.Errorf("got rt.Stats().IP.Forwarding.HostUnreachable.Value() = %d, want = %d", got, want) + } + }) + } +} + func TestGetLinkAddress(t *testing.T) { const ( host1NICID = 1 @@ -449,8 +701,10 @@ func TestGetLinkAddress(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() stackOpts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + Clock: clock, } host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID) @@ -466,8 +720,20 @@ func TestGetLinkAddress(t *testing.T) { if test.expectedErr == nil { wantRes.LinkAddress = utils.LinkAddr2 } - if diff := cmp.Diff(wantRes, <-ch); diff != "" { - t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff) + + nudConfigs, err := host1Stack.NUDConfigurations(host1NICID, test.netProto) + if err != nil { + t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", host1NICID, test.netProto, err) + } + + clock.Advance(time.Duration(nudConfigs.MaxMulticastProbes) * nudConfigs.RetransmitTimer) + select { + case got := <-ch: + if diff := cmp.Diff(wantRes, got); diff != "" { + t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("event didn't arrive") } }) } @@ -544,8 +810,10 @@ func TestRouteResolvedFields(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() stackOpts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + Clock: clock, } host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID) @@ -575,8 +843,20 @@ func TestRouteResolvedFields(t *testing.T) { if _, ok := err.(*tcpip.ErrWouldBlock); !ok { t.Errorf("got r.ResolvedFields(_) = %s, want = %s", err, &tcpip.ErrWouldBlock{}) } - if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Err: test.expectedErr}, <-ch, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" { - t.Errorf("route resolve result mismatch (-want +got):\n%s", diff) + + nudConfigs, err := host1Stack.NUDConfigurations(host1NICID, test.netProto) + if err != nil { + t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", host1NICID, test.netProto, err) + } + clock.Advance(time.Duration(nudConfigs.MaxMulticastProbes) * nudConfigs.RetransmitTimer) + + select { + case got := <-ch: + if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Err: test.expectedErr}, got, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" { + t.Errorf("route resolve result mismatch (-want +got):\n%s", diff) + } + default: + t.Fatalf("event didn't arrive") } if test.expectedErr != nil { @@ -789,11 +1069,16 @@ func (d *nudDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.Neighbo d.c <- e } -func (d *nudDispatcher) waitForEvent(want eventInfo) error { - if diff := cmp.Diff(want, <-d.c, cmp.AllowUnexported(eventInfo{}), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAtNanos")); diff != "" { - return fmt.Errorf("got invalid event (-want +got):\n%s", diff) +func (d *nudDispatcher) expectEvent(want eventInfo) error { + select { + case got := <-d.c: + if diff := cmp.Diff(want, got, cmp.AllowUnexported(eventInfo{}), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAt")); diff != "" { + return fmt.Errorf("got invalid event (-want +got):\n%s", diff) + } + return nil + default: + return fmt.Errorf("event didn't arrive") } - return nil } // TestTCPConfirmNeighborReachability tests that TCP informs layers beneath it @@ -804,7 +1089,7 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { netProto tcpip.NetworkProtocolNumber remoteAddr tcpip.Address neighborAddr tcpip.Address - getEndpoints func(*testing.T, *stack.Stack, *stack.Stack, *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) + getEndpoints func(*testing.T, *stack.Stack, *stack.Stack, *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) isHost1Listener bool }{ { @@ -812,23 +1097,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { netProto: ipv4.ProtocolNumber, remoteAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address, neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { var listenerWQ waiter.Queue + listenerWE, listenerCH := waiter.NewChannelEntry(nil) + listenerWQ.EventRegister(&listenerWE, waiter.EventIn) listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) if err != nil { t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) } + t.Cleanup(listenerEP.Close) var clientWQ waiter.Queue clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.WritableEvents) + clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) if err != nil { - listenerEP.Close() t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) } - return listenerEP, clientEP, clientCH + return listenerEP, listenerCH, clientEP, clientCH }, }, { @@ -836,23 +1123,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { netProto: ipv6.ProtocolNumber, remoteAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address, neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { var listenerWQ waiter.Queue + listenerWE, listenerCH := waiter.NewChannelEntry(nil) + listenerWQ.EventRegister(&listenerWE, waiter.EventIn) listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) if err != nil { t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) } + t.Cleanup(listenerEP.Close) var clientWQ waiter.Queue clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.WritableEvents) + clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) if err != nil { - listenerEP.Close() t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) } - return listenerEP, clientEP, clientCH + return listenerEP, listenerCH, clientEP, clientCH }, }, { @@ -860,23 +1149,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { netProto: ipv4.ProtocolNumber, remoteAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { var listenerWQ waiter.Queue + listenerWE, listenerCH := waiter.NewChannelEntry(nil) + listenerWQ.EventRegister(&listenerWE, waiter.EventIn) listenerEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) if err != nil { t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) } + t.Cleanup(listenerEP.Close) var clientWQ waiter.Queue clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.WritableEvents) + clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) if err != nil { - listenerEP.Close() t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) } - return listenerEP, clientEP, clientCH + return listenerEP, listenerCH, clientEP, clientCH }, }, { @@ -884,23 +1175,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { netProto: ipv6.ProtocolNumber, remoteAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { var listenerWQ waiter.Queue + listenerWE, listenerCH := waiter.NewChannelEntry(nil) + listenerWQ.EventRegister(&listenerWE, waiter.EventIn) listenerEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) if err != nil { t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) } + t.Cleanup(listenerEP.Close) var clientWQ waiter.Queue clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.WritableEvents) + clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) if err != nil { - listenerEP.Close() t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) } - return listenerEP, clientEP, clientCH + return listenerEP, listenerCH, clientEP, clientCH }, }, { @@ -908,23 +1201,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { netProto: ipv4.ProtocolNumber, remoteAddr: utils.Host1IPv4Addr.AddressWithPrefix.Address, neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { var listenerWQ waiter.Queue + listenerWE, listenerCH := waiter.NewChannelEntry(nil) + listenerWQ.EventRegister(&listenerWE, waiter.EventIn) listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) if err != nil { t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) } + t.Cleanup(listenerEP.Close) var clientWQ waiter.Queue clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.WritableEvents) + clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) clientEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) if err != nil { - listenerEP.Close() t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) } - return listenerEP, clientEP, clientCH + return listenerEP, listenerCH, clientEP, clientCH }, isHost1Listener: true, }, @@ -933,23 +1228,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { netProto: ipv6.ProtocolNumber, remoteAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address, neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { var listenerWQ waiter.Queue + listenerWE, listenerCH := waiter.NewChannelEntry(nil) + listenerWQ.EventRegister(&listenerWE, waiter.EventIn) listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) if err != nil { t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) } + t.Cleanup(listenerEP.Close) var clientWQ waiter.Queue clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.WritableEvents) + clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) clientEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) if err != nil { - listenerEP.Close() t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) } - return listenerEP, clientEP, clientCH + return listenerEP, listenerCH, clientEP, clientCH }, isHost1Listener: true, }, @@ -958,23 +1255,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { netProto: ipv4.ProtocolNumber, remoteAddr: utils.Host1IPv4Addr.AddressWithPrefix.Address, neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { var listenerWQ waiter.Queue + listenerWE, listenerCH := waiter.NewChannelEntry(nil) + listenerWQ.EventRegister(&listenerWE, waiter.EventIn) listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) if err != nil { t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) } + t.Cleanup(listenerEP.Close) var clientWQ waiter.Queue clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.WritableEvents) + clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) clientEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) if err != nil { - listenerEP.Close() t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) } - return listenerEP, clientEP, clientCH + return listenerEP, listenerCH, clientEP, clientCH }, isHost1Listener: true, }, @@ -983,23 +1282,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { netProto: ipv6.ProtocolNumber, remoteAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address, neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { var listenerWQ waiter.Queue + listenerWE, listenerCH := waiter.NewChannelEntry(nil) + listenerWQ.EventRegister(&listenerWE, waiter.EventIn) listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) if err != nil { t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) } + t.Cleanup(listenerEP.Close) var clientWQ waiter.Queue clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.WritableEvents) + clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) clientEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) if err != nil { - listenerEP.Close() t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) } - return listenerEP, clientEP, clientCH + return listenerEP, listenerCH, clientEP, clientCH }, isHost1Listener: true, }, @@ -1037,14 +1338,14 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { t.Fatalf("link resolution mismatch (-want +got):\n%s", diff) } } - if err := nudDisp.waitForEvent(eventInfo{ + if err := nudDisp.expectEvent(eventInfo{ eventType: entryAdded, nicID: utils.Host1NICID, entry: stack.NeighborEntry{State: stack.Incomplete, Addr: test.neighborAddr}, }); err != nil { t.Fatalf("error waiting for initial NUD event: %s", err) } - if err := nudDisp.waitForEvent(eventInfo{ + if err := nudDisp.expectEvent(eventInfo{ eventType: entryChanged, nicID: utils.Host1NICID, entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, @@ -1064,7 +1365,7 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { // See NUDConfigurations.BaseReachableTime for more information. maxReachableTime := time.Duration(float32(nudConfigs.BaseReachableTime) * nudConfigs.MaxRandomFactor) clock.Advance(maxReachableTime) - if err := nudDisp.waitForEvent(eventInfo{ + if err := nudDisp.expectEvent(eventInfo{ eventType: entryChanged, nicID: utils.Host1NICID, entry: stack.NeighborEntry{State: stack.Stale, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, @@ -1072,8 +1373,7 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { t.Fatalf("error waiting for stale NUD event: %s", err) } - listenerEP, clientEP, clientCH := test.getEndpoints(t, host1Stack, routerStack, host2Stack) - defer listenerEP.Close() + listenerEP, listenerCH, clientEP, clientCH := test.getEndpoints(t, host1Stack, routerStack, host2Stack) defer clientEP.Close() listenerAddr := tcpip.FullAddress{Addr: test.remoteAddr, Port: 1234} if err := listenerEP.Bind(listenerAddr); err != nil { @@ -1094,14 +1394,15 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { // with confirmation that the neighbor is reachable (indicated by a // successful 3-way handshake). <-clientCH - if err := nudDisp.waitForEvent(eventInfo{ + if err := nudDisp.expectEvent(eventInfo{ eventType: entryChanged, nicID: utils.Host1NICID, entry: stack.NeighborEntry{State: stack.Delay, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, }); err != nil { t.Fatalf("error waiting for delay NUD event: %s", err) } - if err := nudDisp.waitForEvent(eventInfo{ + <-listenerCH + if err := nudDisp.expectEvent(eventInfo{ eventType: entryChanged, nicID: utils.Host1NICID, entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, @@ -1109,26 +1410,55 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { t.Fatalf("error waiting for reachable NUD event: %s", err) } + peerEP, peerWQ, err := listenerEP.Accept(nil) + if err != nil { + t.Fatalf("listenerEP.Accept(): %s", err) + } + defer peerEP.Close() + peerWE, peerCH := waiter.NewChannelEntry(nil) + peerWQ.EventRegister(&peerWE, waiter.ReadableEvents) + // Wait for the neighbor to be stale again then send data to the remote. // // On successful transmission, the neighbor should become reachable // without probing the neighbor as a TCP ACK would be received which is an // indication of the neighbor being reachable. clock.Advance(maxReachableTime) - if err := nudDisp.waitForEvent(eventInfo{ + if err := nudDisp.expectEvent(eventInfo{ eventType: entryChanged, nicID: utils.Host1NICID, entry: stack.NeighborEntry{State: stack.Stale, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, }); err != nil { t.Fatalf("error waiting for stale NUD event: %s", err) } - var r bytes.Reader - r.Reset([]byte{0}) - var wOpts tcpip.WriteOptions - if _, err := clientEP.Write(&r, wOpts); err != nil { - t.Errorf("clientEP.Write(_, %#v): %s", wOpts, err) + { + var r bytes.Reader + r.Reset([]byte{0}) + var wOpts tcpip.WriteOptions + if _, err := clientEP.Write(&r, wOpts); err != nil { + t.Errorf("clientEP.Write(_, %#v): %s", wOpts, err) + } + } + // Heads up, there is a race here. + // + // Incoming TCP segments are handled in + // tcp.(*endpoint).handleSegmentLocked: + // + // - tcp.(*endpoint).rcv.handleRcvdSegment puts the segment on the + // segment queue and notifies waiting readers (such as this channel) + // + // - tcp.(*endpoint).snd.handleRcvdSegment sends an ACK for the segment + // and notifies the NUD machinery that the peer is reachable + // + // Thus we must permit a delay between the readable signal and the + // expected NUD event. + // + // At the time of writing, this race is reliably hit with gotsan. + <-peerCH + for len(nudDisp.c) == 0 { + runtime.Gosched() } - if err := nudDisp.waitForEvent(eventInfo{ + if err := nudDisp.expectEvent(eventInfo{ eventType: entryChanged, nicID: utils.Host1NICID, entry: stack.NeighborEntry{State: stack.Delay, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, @@ -1141,7 +1471,7 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { // TCP should not mark the route reachable and NUD should go through the // probe state. clock.Advance(nudConfigs.DelayFirstProbeTime) - if err := nudDisp.waitForEvent(eventInfo{ + if err := nudDisp.expectEvent(eventInfo{ eventType: entryChanged, nicID: utils.Host1NICID, entry: stack.NeighborEntry{State: stack.Probe, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, @@ -1149,7 +1479,16 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { t.Fatalf("error waiting for probe NUD event: %s", err) } } - if err := nudDisp.waitForEvent(eventInfo{ + { + var r bytes.Reader + r.Reset([]byte{0}) + var wOpts tcpip.WriteOptions + if _, err := peerEP.Write(&r, wOpts); err != nil { + t.Errorf("peerEP.Write(_, %#v): %s", wOpts, err) + } + } + <-clientCH + if err := nudDisp.expectEvent(eventInfo{ eventType: entryChanged, nicID: utils.Host1NICID, entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index 87d36e1dd..b2008f0b2 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -44,20 +44,17 @@ type ndpDispatcher struct{} func (*ndpDispatcher) OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult) { } -func (*ndpDispatcher) OnDefaultRouterDiscovered(tcpip.NICID, tcpip.Address) bool { - return false +func (*ndpDispatcher) OnOffLinkRouteUpdated(tcpip.NICID, tcpip.Subnet, tcpip.Address, header.NDPRoutePreference) { } -func (*ndpDispatcher) OnDefaultRouterInvalidated(tcpip.NICID, tcpip.Address) {} +func (*ndpDispatcher) OnOffLinkRouteInvalidated(tcpip.NICID, tcpip.Subnet, tcpip.Address) {} -func (*ndpDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool { - return false +func (*ndpDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) { } func (*ndpDispatcher) OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) {} -func (*ndpDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool { - return true +func (*ndpDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) { } func (*ndpDispatcher) OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) {} diff --git a/pkg/tcpip/testutil/testutil.go b/pkg/tcpip/testutil/testutil.go index f84d399fb..94b580a70 100644 --- a/pkg/tcpip/testutil/testutil.go +++ b/pkg/tcpip/testutil/testutil.go @@ -109,3 +109,15 @@ func ValidateMultiCounterStats(multi reflect.Value, counters []reflect.Value) er return nil } + +// MustParseLink parses a Link string into a tcpip.LinkAddress, panicking on +// error. +// +// The string must be in the format aa:bb:cc:dd:ee:ff or aa-bb-cc-dd-ee-ff. +func MustParseLink(addr string) tcpip.LinkAddress { + parsed, err := tcpip.ParseMACAddress(addr) + if err != nil { + panic(fmt.Sprintf("tcpip.ParseMACAddress(%s): %s", addr, err)) + } + return parsed +} diff --git a/pkg/tcpip/time.s b/pkg/tcpip/time.s deleted file mode 100644 index fb37360ac..000000000 --- a/pkg/tcpip/time.s +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Empty assembly file so empty func definitions work. diff --git a/pkg/tcpip/timer_test.go b/pkg/tcpip/timer_test.go index 1633d0aeb..ed1ed8ac6 100644 --- a/pkg/tcpip/timer_test.go +++ b/pkg/tcpip/timer_test.go @@ -15,6 +15,7 @@ package tcpip_test import ( + "math" "sync" "testing" "time" @@ -22,10 +23,85 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" ) +func TestMonotonicTimeBefore(t *testing.T) { + var mt tcpip.MonotonicTime + if mt.Before(mt) { + t.Errorf("%#v.Before(%#v)", mt, mt) + } + + one := mt.Add(1) + if one.Before(mt) { + t.Errorf("%#v.Before(%#v)", one, mt) + } + if !mt.Before(one) { + t.Errorf("!%#v.Before(%#v)", mt, one) + } +} + +func TestMonotonicTimeAfter(t *testing.T) { + var mt tcpip.MonotonicTime + if mt.After(mt) { + t.Errorf("%#v.After(%#v)", mt, mt) + } + + one := mt.Add(1) + if mt.After(one) { + t.Errorf("%#v.After(%#v)", mt, one) + } + if !one.After(mt) { + t.Errorf("!%#v.After(%#v)", one, mt) + } +} + +func TestMonotonicTimeAddSub(t *testing.T) { + var mt tcpip.MonotonicTime + if one, two := mt.Add(2), mt.Add(1).Add(1); one != two { + t.Errorf("mt.Add(2) != mt.Add(1).Add(1) (%#v != %#v)", one, two) + } + + min := mt.Add(math.MinInt64) + max := mt.Add(math.MaxInt64) + + if overflow := mt.Add(1).Add(math.MaxInt64); overflow != max { + t.Errorf("mt.Add(math.MaxInt64) != mt.Add(1).Add(math.MaxInt64) (%#v != %#v)", max, overflow) + } + if underflow := mt.Add(-1).Add(math.MinInt64); underflow != min { + t.Errorf("mt.Add(math.MinInt64) != mt.Add(-1).Add(math.MinInt64) (%#v != %#v)", min, underflow) + } + + if got, want := min.Sub(min), time.Duration(0); want != got { + t.Errorf("got min.Sub(min) = %d, want %d", got, want) + } + if got, want := max.Sub(max), time.Duration(0); want != got { + t.Errorf("got max.Sub(max) = %d, want %d", got, want) + } + + if overflow, want := max.Sub(min), time.Duration(math.MaxInt64); overflow != want { + t.Errorf("mt.Add(math.MaxInt64).Sub(mt.Add(math.MinInt64) != %s (%#v)", want, overflow) + } + if underflow, want := min.Sub(max), time.Duration(math.MinInt64); underflow != want { + t.Errorf("mt.Add(math.MinInt64).Sub(mt.Add(math.MaxInt64) != %s (%#v)", want, underflow) + } +} + +func TestMonotonicTimeSub(t *testing.T) { + var mt tcpip.MonotonicTime + + if one, two := mt.Add(2), mt.Add(1).Add(1); one != two { + t.Errorf("mt.Add(2) != mt.Add(1).Add(1) (%#v != %#v)", one, two) + } + + if max, overflow := mt.Add(math.MaxInt64), mt.Add(1).Add(math.MaxInt64); max != overflow { + t.Errorf("mt.Add(math.MaxInt64) != mt.Add(1).Add(math.MaxInt64) (%#v != %#v)", max, overflow) + } + if max, underflow := mt.Add(math.MinInt64), mt.Add(-1).Add(math.MinInt64); max != underflow { + t.Errorf("mt.Add(math.MinInt64) != mt.Add(-1).Add(math.MinInt64) (%#v != %#v)", max, underflow) + } +} + const ( shortDuration = 1 * time.Nanosecond middleDuration = 100 * time.Millisecond - longDuration = 1 * time.Second ) func TestJobReschedule(t *testing.T) { @@ -53,10 +129,14 @@ func TestJobReschedule(t *testing.T) { wg.Wait() } +func stdClockWithAfter() (tcpip.Clock, func(time.Duration) <-chan time.Time) { + return tcpip.NewStdClock(), time.After +} + func TestJobExecution(t *testing.T) { t.Parallel() - clock := tcpip.NewStdClock() + clock, after := stdClockWithAfter() var lock sync.Mutex ch := make(chan struct{}) @@ -68,7 +148,7 @@ func TestJobExecution(t *testing.T) { // Wait for timer to fire. select { case <-ch: - case <-time.After(middleDuration): + case <-after(middleDuration): t.Fatal("timed out waiting for timer to fire") } @@ -76,14 +156,14 @@ func TestJobExecution(t *testing.T) { select { case <-ch: t.Fatal("no other timers should have fired") - case <-time.After(middleDuration): + case <-after(middleDuration): } } func TestCancellableTimerResetFromLongDuration(t *testing.T) { t.Parallel() - clock := tcpip.NewStdClock() + clock, after := stdClockWithAfter() var lock sync.Mutex ch := make(chan struct{}) @@ -99,7 +179,7 @@ func TestCancellableTimerResetFromLongDuration(t *testing.T) { // Wait for timer to fire. select { case <-ch: - case <-time.After(middleDuration): + case <-after(middleDuration): t.Fatal("timed out waiting for timer to fire") } @@ -107,14 +187,14 @@ func TestCancellableTimerResetFromLongDuration(t *testing.T) { select { case <-ch: t.Fatal("no other timers should have fired") - case <-time.After(middleDuration): + case <-after(middleDuration): } } func TestJobRescheduleFromShortDuration(t *testing.T) { t.Parallel() - clock := tcpip.NewStdClock() + clock, after := stdClockWithAfter() var lock sync.Mutex ch := make(chan struct{}) @@ -128,7 +208,7 @@ func TestJobRescheduleFromShortDuration(t *testing.T) { select { case <-ch: t.Fatal("timer fired after being stopped") - case <-time.After(middleDuration): + case <-after(middleDuration): } job.Schedule(shortDuration) @@ -136,7 +216,7 @@ func TestJobRescheduleFromShortDuration(t *testing.T) { // Wait for timer to fire. select { case <-ch: - case <-time.After(middleDuration): + case <-after(middleDuration): t.Fatal("timed out waiting for timer to fire") } @@ -144,14 +224,14 @@ func TestJobRescheduleFromShortDuration(t *testing.T) { select { case <-ch: t.Fatal("no other timers should have fired") - case <-time.After(middleDuration): + case <-after(middleDuration): } } func TestJobImmediatelyCancel(t *testing.T) { t.Parallel() - clock := tcpip.NewStdClock() + clock, after := stdClockWithAfter() var lock sync.Mutex ch := make(chan struct{}) @@ -167,14 +247,19 @@ func TestJobImmediatelyCancel(t *testing.T) { select { case <-ch: t.Fatal("timer fired after being stopped") - case <-time.After(middleDuration): + case <-after(middleDuration): } } +func stdClockWithAfterAndSleep() (tcpip.Clock, func(time.Duration) <-chan time.Time, func(time.Duration)) { + clock, after := stdClockWithAfter() + return clock, after, time.Sleep +} + func TestJobCancelledRescheduleWithoutLock(t *testing.T) { t.Parallel() - clock := tcpip.NewStdClock() + clock, after, sleep := stdClockWithAfterAndSleep() var lock sync.Mutex ch := make(chan struct{}) @@ -189,7 +274,7 @@ func TestJobCancelledRescheduleWithoutLock(t *testing.T) { lock.Lock() // Sleep until the timer fires and gets blocked trying to take the lock. - time.Sleep(middleDuration * 2) + sleep(middleDuration * 2) job.Cancel() lock.Unlock() } @@ -199,14 +284,14 @@ func TestJobCancelledRescheduleWithoutLock(t *testing.T) { select { case <-ch: t.Fatal("timer fired after being stopped") - case <-time.After(middleDuration * 2): + case <-after(middleDuration * 2): } } func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) { t.Parallel() - clock := tcpip.NewStdClock() + clock, after, sleep := stdClockWithAfterAndSleep() var lock sync.Mutex ch := make(chan struct{}) @@ -215,7 +300,7 @@ func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) { job.Schedule(shortDuration) for i := 0; i < 10; i++ { // Sleep until the timer fires and gets blocked trying to take the lock. - time.Sleep(middleDuration) + sleep(middleDuration) job.Cancel() job.Schedule(shortDuration) } @@ -224,7 +309,7 @@ func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) { // Wait for double the duration for the last timer to fire. select { case <-ch: - case <-time.After(middleDuration): + case <-after(middleDuration): t.Fatal("timed out waiting for timer to fire") } @@ -232,14 +317,14 @@ func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) { select { case <-ch: t.Fatal("no other timers should have fired") - case <-time.After(middleDuration): + case <-after(middleDuration): } } func TestManyJobReschedulesUnderLock(t *testing.T) { t.Parallel() - clock := tcpip.NewStdClock() + clock, after := stdClockWithAfter() var lock sync.Mutex ch := make(chan struct{}) @@ -255,7 +340,7 @@ func TestManyJobReschedulesUnderLock(t *testing.T) { // Wait for double the duration for the last timer to fire. select { case <-ch: - case <-time.After(middleDuration): + case <-after(middleDuration): t.Fatal("timed out waiting for timer to fire") } @@ -263,6 +348,6 @@ func TestManyJobReschedulesUnderLock(t *testing.T) { select { case <-ch: t.Fatal("no other timers should have fired") - case <-time.After(middleDuration): + case <-after(middleDuration): } } diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD index 7e5c79776..bbc0e3ecc 100644 --- a/pkg/tcpip/transport/icmp/BUILD +++ b/pkg/tcpip/transport/icmp/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) @@ -38,3 +38,22 @@ go_library( "//pkg/waiter", ], ) + +go_test( + name = "icmp_x_test", + size = "small", + srcs = ["icmp_test.go"], + deps = [ + ":icmp", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/checker", + "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", + "//pkg/tcpip/link/sniffer", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/stack", + "//pkg/tcpip/testutil", + "//pkg/waiter", + ], +) diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 8afde7fca..cb316d27a 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -16,6 +16,7 @@ package icmp import ( "io" + "time" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -26,14 +27,12 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -// TODO(https://gvisor.dev/issues/5623): Unit test this package. - // +stateify savable type icmpPacket struct { icmpPacketEntry senderAddress tcpip.FullAddress data buffer.VectorisedView `state:".(buffer.VectorisedView)"` - timestamp int64 + receivedAt time.Time `state:".(int64)"` } type endpointState int @@ -133,7 +132,8 @@ func (e *endpoint) Close() { e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite switch e.state { case stateBound, stateConnected: - e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, 0 /* bindToDevice */) + bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) + e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, bindToDevice) } // Close the receive list and drain it. @@ -160,7 +160,7 @@ func (e *endpoint) Close() { } // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. -func (e *endpoint) ModerateRecvBuf(copied int) {} +func (*endpoint) ModerateRecvBuf(int) {} // SetOwner implements tcpip.Endpoint.SetOwner. func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { @@ -193,7 +193,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult Total: p.data.Size(), ControlMessages: tcpip.ControlMessages{ HasTimestamp: true, - Timestamp: p.timestamp, + Timestamp: p.receivedAt.UnixNano(), }, } if opts.NeedRemoteAddr { @@ -304,6 +304,9 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp // Reject destination address if it goes through a different // NIC than the endpoint was bound to. nicID := to.NIC + if nicID == 0 { + nicID = tcpip.NICID(e.ops.GetBindToDevice()) + } if e.BindNICID != 0 { if nicID != 0 && nicID != e.BindNICID { return 0, &tcpip.ErrNoRoute{} @@ -348,8 +351,15 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp return int64(len(v)), nil } +var _ tcpip.SocketOptionsHandler = (*endpoint)(nil) + +// HasNIC implements tcpip.SocketOptionsHandler. +func (e *endpoint) HasNIC(id int32) bool { + return e.stack.HasNIC(tcpip.NICID(id)) +} + // SetSockOpt sets a socket option. -func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { +func (*endpoint) SetSockOpt(tcpip.SettableSocketOption) tcpip.Error { return nil } @@ -390,7 +400,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { +func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error { return &tcpip.ErrUnknownProtocolOption{} } @@ -606,18 +616,19 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi return nil, nil, &tcpip.ErrNotSupported{} } -func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) { +func (e *endpoint) registerWithStack(_ tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) { + bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) if id.LocalPort != 0 { // The endpoint already has a local port, just attempt to // register it. - err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindToDevice */) + err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, bindToDevice) return id, err } // We need to find a port for the endpoint. - _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, tcpip.Error) { + _, err := e.stack.PickEphemeralPort(e.stack.Rand(), func(p uint16) (bool, tcpip.Error) { id.LocalPort = p - err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindtodevice */) + err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, bindToDevice) switch err.(type) { case nil: return true, nil @@ -747,8 +758,6 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB switch e.NetProto { case header.IPv4ProtocolNumber: h := header.ICMPv4(pkt.TransportHeader().View()) - // TODO(gvisor.dev/issue/170): Determine if len(h) check is still needed - // after early parsing. if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() @@ -756,8 +765,6 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB } case header.IPv6ProtocolNumber: h := header.ICMPv6(pkt.TransportHeader().View()) - // TODO(gvisor.dev/issue/170): Determine if len(h) check is still needed - // after early parsing. if len(h) < header.ICMPv6MinimumSize || h.Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() @@ -800,7 +807,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB e.rcvList.PushBack(packet) e.rcvBufSize += packet.data.Size() - packet.timestamp = e.stack.Clock().NowNanoseconds() + packet.receivedAt = e.stack.Clock().Now() e.rcvMu.Unlock() e.stats.PacketsReceived.Increment() diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go index 28a56a2d5..b8b839e4a 100644 --- a/pkg/tcpip/transport/icmp/endpoint_state.go +++ b/pkg/tcpip/transport/icmp/endpoint_state.go @@ -15,11 +15,23 @@ package icmp import ( + "time" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" ) +// saveReceivedAt is invoked by stateify. +func (p *icmpPacket) saveReceivedAt() int64 { + return p.receivedAt.UnixNano() +} + +// loadReceivedAt is invoked by stateify. +func (p *icmpPacket) loadReceivedAt(nsec int64) { + p.receivedAt = time.Unix(0, nsec) +} + // saveData saves icmpPacket.data field. func (p *icmpPacket) saveData() buffer.VectorisedView { // We cannot save p.data directly as p.data.views may alias to p.views, diff --git a/pkg/tcpip/transport/icmp/icmp_test.go b/pkg/tcpip/transport/icmp/icmp_test.go new file mode 100644 index 000000000..cc950cbde --- /dev/null +++ b/pkg/tcpip/transport/icmp/icmp_test.go @@ -0,0 +1,235 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package icmp_test + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/testutil" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/waiter" +) + +// TODO(https://gvisor.dev/issues/5623): Finish unit testing the icmp package. +// See the issue for remaining areas of work. + +var ( + localV4Addr1 = testutil.MustParse4("10.0.0.1") + localV4Addr2 = testutil.MustParse4("10.0.0.2") + remoteV4Addr = testutil.MustParse4("10.0.0.3") +) + +func addNICWithDefaultRoute(t *testing.T, s *stack.Stack, id tcpip.NICID, name string, addrV4 tcpip.Address) *channel.Endpoint { + t.Helper() + + ep := channel.New(1 /* size */, header.IPv4MinimumMTU, "" /* linkAddr */) + t.Cleanup(ep.Close) + + wep := stack.LinkEndpoint(ep) + if testing.Verbose() { + wep = sniffer.New(ep) + } + + opts := stack.NICOptions{Name: name} + if err := s.CreateNICWithOptions(id, wep, opts); err != nil { + t.Fatalf("s.CreateNIC(%d, _) = %s", id, err) + } + + if err := s.AddAddress(id, ipv4.ProtocolNumber, addrV4); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s) = %s", id, ipv4.ProtocolNumber, addrV4, err) + } + + s.AddRoute(tcpip.Route{ + Destination: header.IPv4EmptySubnet, + NIC: id, + }) + + return ep +} + +func writePayload(buf []byte) { + for i := range buf { + buf[i] = byte(i) + } +} + +func newICMPv4EchoRequest(payloadSize uint32) buffer.View { + buf := buffer.NewView(header.ICMPv4MinimumSize + int(payloadSize)) + writePayload(buf[header.ICMPv4MinimumSize:]) + + icmp := header.ICMPv4(buf) + icmp.SetType(header.ICMPv4Echo) + // No need to set the checksum; it is reset by the socket before the packet + // is sent. + + return buf +} + +// TestWriteUnboundWithBindToDevice exercises writing to an unbound ICMP socket +// when SO_BINDTODEVICE is set to the non-default NIC for that subnet. +// +// Only IPv4 is tested. The logic to determine which NIC to use is agnostic to +// the version of IP. +func TestWriteUnboundWithBindToDevice(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, + HandleLocal: true, + }) + + // Add two NICs, both with default routes on the same subnet. The first NIC + // added will be the default NIC for that subnet. + defaultEP := addNICWithDefaultRoute(t, s, 1, "default", localV4Addr1) + alternateEP := addNICWithDefaultRoute(t, s, 2, "alternate", localV4Addr2) + + socket, err := s.NewEndpoint(icmp.ProtocolNumber4, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _) = %s", icmp.ProtocolNumber4, ipv4.ProtocolNumber, err) + } + defer socket.Close() + + echoPayloadSize := defaultEP.MTU() - header.IPv4MinimumSize - header.ICMPv4MinimumSize + + // Send a packet without SO_BINDTODEVICE. This verifies that the first NIC + // to be added is the default NIC to send packets when not explicitly bound. + { + buf := newICMPv4EchoRequest(echoPayloadSize) + r := buf.Reader() + n, err := socket.Write(&r, tcpip.WriteOptions{ + To: &tcpip.FullAddress{Addr: remoteV4Addr}, + }) + if err != nil { + t.Fatalf("socket.Write(_, {To:%s}) = %s", remoteV4Addr, err) + } + if n != int64(len(buf)) { + t.Fatalf("got n = %d, want n = %d", n, len(buf)) + } + + // Verify the packet was sent out the default NIC. + p, ok := defaultEP.Read() + if !ok { + t.Fatalf("got defaultEP.Read(_) = _, false; want = _, true (packet wasn't written out)") + } + + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + b := vv.ToView() + + checker.IPv4(t, b, []checker.NetworkChecker{ + checker.SrcAddr(localV4Addr1), + checker.DstAddr(remoteV4Addr), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4Echo), + checker.ICMPv4Payload(buf[header.ICMPv4MinimumSize:]), + ), + }...) + + // Verify the packet was not sent out the alternate NIC. + if p, ok := alternateEP.Read(); ok { + t.Fatalf("got alternateEP.Read(_) = %+v, true; want = _, false", p) + } + } + + // Send a packet with SO_BINDTODEVICE. This exercises reliance on + // SO_BINDTODEVICE to route the packet to the alternate NIC. + { + // Use SO_BINDTODEVICE to send over the alternate NIC by default. + socket.SocketOptions().SetBindToDevice(2) + + buf := newICMPv4EchoRequest(echoPayloadSize) + r := buf.Reader() + n, err := socket.Write(&r, tcpip.WriteOptions{ + To: &tcpip.FullAddress{Addr: remoteV4Addr}, + }) + if err != nil { + t.Fatalf("socket.Write(_, {To:%s}) = %s", tcpip.Address(remoteV4Addr), err) + } + if n != int64(len(buf)) { + t.Fatalf("got n = %d, want n = %d", n, len(buf)) + } + + // Verify the packet was not sent out the default NIC. + if p, ok := defaultEP.Read(); ok { + t.Fatalf("got defaultEP.Read(_) = %+v, true; want = _, false", p) + } + + // Verify the packet was sent out the alternate NIC. + p, ok := alternateEP.Read() + if !ok { + t.Fatalf("got alternateEP.Read(_) = _, false; want = _, true (packet wasn't written out)") + } + + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + b := vv.ToView() + + checker.IPv4(t, b, []checker.NetworkChecker{ + checker.SrcAddr(localV4Addr2), + checker.DstAddr(remoteV4Addr), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4Echo), + checker.ICMPv4Payload(buf[header.ICMPv4MinimumSize:]), + ), + }...) + } + + // Send a packet with SO_BINDTODEVICE cleared. This verifies that clearing + // the device binding will fallback to using the default NIC to send + // packets. + { + socket.SocketOptions().SetBindToDevice(0) + + buf := newICMPv4EchoRequest(echoPayloadSize) + r := buf.Reader() + n, err := socket.Write(&r, tcpip.WriteOptions{ + To: &tcpip.FullAddress{Addr: remoteV4Addr}, + }) + if err != nil { + t.Fatalf("socket.Write(_, {To:%s}) = %s", tcpip.Address(remoteV4Addr), err) + } + if n != int64(len(buf)) { + t.Fatalf("got n = %d, want n = %d", n, len(buf)) + } + + // Verify the packet was sent out the default NIC. + p, ok := defaultEP.Read() + if !ok { + t.Fatalf("got defaultEP.Read(_) = _, false; want = _, true (packet wasn't written out)") + } + + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + b := vv.ToView() + + checker.IPv4(t, b, []checker.NetworkChecker{ + checker.SrcAddr(localV4Addr1), + checker.DstAddr(remoteV4Addr), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4Echo), + checker.ICMPv4Payload(buf[header.ICMPv4MinimumSize:]), + ), + }...) + + // Verify the packet was not sent out the alternate NIC. + if p, ok := alternateEP.Read(); ok { + t.Fatalf("got alternateEP.Read(_) = %+v, true; want = _, false", p) + } + } +} diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 47f7dd1cb..fa82affc1 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -123,8 +123,6 @@ func (*protocol) Wait() {} // Parse implements stack.TransportProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) bool { - // TODO(gvisor.dev/issue/170): Implement parsing of ICMP. - // // Right now, the Parse() method is tied to enabled protocols passed into // stack.New. This works for UDP and TCP, but we handle ICMP traffic even // when netstack users don't pass ICMP as a supported protocol. diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 496eca581..8e7bb6c6e 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -27,6 +27,7 @@ package packet import ( "fmt" "io" + "time" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -41,9 +42,8 @@ type packet struct { packetEntry // data holds the actual packet data, including any headers and // payload. - data buffer.VectorisedView `state:".(buffer.VectorisedView)"` - // timestampNS is the unix time at which the packet was received. - timestampNS int64 + data buffer.VectorisedView `state:".(buffer.VectorisedView)"` + receivedAt time.Time `state:".(int64)"` // senderAddr is the network address of the sender. senderAddr tcpip.FullAddress // packetInfo holds additional information like the protocol @@ -159,7 +159,7 @@ func (ep *endpoint) Close() { } // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. -func (ep *endpoint) ModerateRecvBuf(copied int) {} +func (*endpoint) ModerateRecvBuf(int) {} // Read implements tcpip.Endpoint.Read. func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) { @@ -189,7 +189,7 @@ func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResul Total: packet.data.Size(), ControlMessages: tcpip.ControlMessages{ HasTimestamp: true, - Timestamp: packet.timestampNS, + Timestamp: packet.receivedAt.UnixNano(), }, } if opts.NeedRemoteAddr { @@ -208,7 +208,6 @@ func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResul } func (*endpoint) Write(tcpip.Payloader, tcpip.WriteOptions) (int64, tcpip.Error) { - // TODO(gvisor.dev/issue/173): Implement. return 0, &tcpip.ErrInvalidOptionValue{} } @@ -220,19 +219,19 @@ func (*endpoint) Disconnect() tcpip.Error { // Connect implements tcpip.Endpoint.Connect. Packet sockets cannot be // connected, and this function always returnes *tcpip.ErrNotSupported. -func (*endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { +func (*endpoint) Connect(tcpip.FullAddress) tcpip.Error { return &tcpip.ErrNotSupported{} } // Shutdown implements tcpip.Endpoint.Shutdown. Packet sockets cannot be used // with Shutdown, and this function always returns *tcpip.ErrNotSupported. -func (*endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { +func (*endpoint) Shutdown(tcpip.ShutdownFlags) tcpip.Error { return &tcpip.ErrNotSupported{} } // Listen implements tcpip.Endpoint.Listen. Packet sockets cannot be used with // Listen, and this function always returns *tcpip.ErrNotSupported. -func (*endpoint) Listen(backlog int) tcpip.Error { +func (*endpoint) Listen(int) tcpip.Error { return &tcpip.ErrNotSupported{} } @@ -244,8 +243,6 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi // Bind implements tcpip.Endpoint.Bind. func (ep *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { - // TODO(gvisor.dev/issue/173): Add Bind support. - // "By default, all packets of the specified protocol type are passed // to a packet socket. To get packets only from a specific interface // use bind(2) specifying an address in a struct sockaddr_ll to bind @@ -318,7 +315,7 @@ func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { } // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. -func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { +func (*endpoint) SetSockOptInt(tcpip.SockOptInt, int) tcpip.Error { return &tcpip.ErrUnknownProtocolOption{} } @@ -339,7 +336,7 @@ func (ep *endpoint) UpdateLastError(err tcpip.Error) { } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { +func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error { return &tcpip.ErrNotSupported{} } @@ -385,7 +382,6 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, // Push new packet into receive list and increment the buffer size. var packet packet - // TODO(gvisor.dev/issue/173): Return network protocol. if !pkt.LinkHeader().View().IsEmpty() { // Get info directly from the ethernet header. hdr := header.Ethernet(pkt.LinkHeader().View()) @@ -424,7 +420,6 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, default: panic(fmt.Sprintf("unexpected PktType in pkt: %+v", pkt)) } - } else { // Raw packets need their ethernet headers prepended before // queueing. @@ -451,7 +446,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, packet.data = buffer.NewVectorisedView(pkt.Size(), pkt.Views()) } } - packet.timestampNS = ep.stack.Clock().NowNanoseconds() + packet.receivedAt = ep.stack.Clock().Now() ep.rcvList.PushBack(&packet) ep.rcvBufSize += packet.data.Size() @@ -484,7 +479,7 @@ func (ep *endpoint) Stats() tcpip.EndpointStats { } // SetOwner implements tcpip.Endpoint.SetOwner. -func (ep *endpoint) SetOwner(owner tcpip.PacketOwner) {} +func (*endpoint) SetOwner(tcpip.PacketOwner) {} // SocketOptions implements tcpip.Endpoint.SocketOptions. func (ep *endpoint) SocketOptions() *tcpip.SocketOptions { diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go index 5bd860d20..e729921db 100644 --- a/pkg/tcpip/transport/packet/endpoint_state.go +++ b/pkg/tcpip/transport/packet/endpoint_state.go @@ -15,11 +15,23 @@ package packet import ( + "time" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" ) +// saveReceivedAt is invoked by stateify. +func (p *packet) saveReceivedAt() int64 { + return p.receivedAt.UnixNano() +} + +// loadReceivedAt is invoked by stateify. +func (p *packet) loadReceivedAt(nsec int64) { + p.receivedAt = time.Unix(0, nsec) +} + // saveData saves packet.data field. func (p *packet) saveData() buffer.VectorisedView { // We cannot save p.data directly as p.data.views may alias to p.views, diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index bcec3d2e7..b6687911a 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -27,6 +27,7 @@ package raw import ( "io" + "time" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -41,9 +42,8 @@ type rawPacket struct { rawPacketEntry // data holds the actual packet data, including any headers and // payload. - data buffer.VectorisedView `state:".(buffer.VectorisedView)"` - // timestampNS is the unix time at which the packet was received. - timestampNS int64 + data buffer.VectorisedView `state:".(buffer.VectorisedView)"` + receivedAt time.Time `state:".(int64)"` // senderAddr is the network address of the sender. senderAddr tcpip.FullAddress } @@ -183,7 +183,7 @@ func (e *endpoint) Close() { } // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. -func (e *endpoint) ModerateRecvBuf(copied int) {} +func (*endpoint) ModerateRecvBuf(int) {} func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.mu.Lock() @@ -219,7 +219,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult Total: pkt.data.Size(), ControlMessages: tcpip.ControlMessages{ HasTimestamp: true, - Timestamp: pkt.timestampNS, + Timestamp: pkt.receivedAt.UnixNano(), }, } if opts.NeedRemoteAddr { @@ -286,26 +286,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp return nil, nil, nil, &tcpip.ErrBadBuffer{} } - // If this is an unassociated socket and callee provided a nonzero - // destination address, route using that address. - if e.ops.GetHeaderIncluded() { - ip := header.IPv4(payloadBytes) - if !ip.IsValid(len(payloadBytes)) { - return nil, nil, nil, &tcpip.ErrInvalidOptionValue{} - } - dstAddr := ip.DestinationAddress() - // Update dstAddr with the address in the IP header, unless - // opts.To is set (e.g. if sendto specifies a specific - // address). - if dstAddr != tcpip.Address([]byte{0, 0, 0, 0}) && opts.To == nil { - opts.To = &tcpip.FullAddress{ - NIC: 0, // NIC is unset. - Addr: dstAddr, // The address from the payload. - Port: 0, // There are no ports here. - } - } - } - // Did the user caller provide a destination? If not, use the connected // destination. if opts.To == nil { @@ -402,7 +382,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { } // Find a route to the destination. - route, err := e.stack.FindRoute(nic, tcpip.Address(""), addr.Addr, e.NetProto, false) + route, err := e.stack.FindRoute(nic, "", addr.Addr, e.NetProto, false) if err != nil { return err } @@ -428,7 +408,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { } // Shutdown implements tcpip.Endpoint.Shutdown. It's a noop for raw sockets. -func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { +func (e *endpoint) Shutdown(tcpip.ShutdownFlags) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() @@ -439,7 +419,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { } // Listen implements tcpip.Endpoint.Listen. -func (*endpoint) Listen(backlog int) tcpip.Error { +func (*endpoint) Listen(int) tcpip.Error { return &tcpip.ErrNotSupported{} } @@ -513,12 +493,12 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { } } -func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { +func (*endpoint) SetSockOptInt(tcpip.SockOptInt, int) tcpip.Error { return &tcpip.ErrUnknownProtocolOption{} } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { +func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error { return &tcpip.ErrUnknownProtocolOption{} } @@ -621,7 +601,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { } combinedVV.Append(pkt.Data().ExtractVV()) packet.data = combinedVV - packet.timestampNS = e.stack.Clock().NowNanoseconds() + packet.receivedAt = e.stack.Clock().Now() e.rcvList.PushBack(packet) e.rcvBufSize += packet.data.Size() diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go index 5d6f2709c..39669b445 100644 --- a/pkg/tcpip/transport/raw/endpoint_state.go +++ b/pkg/tcpip/transport/raw/endpoint_state.go @@ -15,11 +15,23 @@ package raw import ( + "time" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" ) +// saveReceivedAt is invoked by stateify. +func (p *rawPacket) saveReceivedAt() int64 { + return p.receivedAt.UnixNano() +} + +// loadReceivedAt is invoked by stateify. +func (p *rawPacket) loadReceivedAt(nsec int64) { + p.receivedAt = time.Unix(0, nsec) +} + // saveData saves rawPacket.data field. func (p *rawPacket) saveData() buffer.VectorisedView { // We cannot save p.data directly as p.data.views may alias to p.views, diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 0f20d3856..8436d2cf0 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -41,7 +41,6 @@ go_library( "protocol.go", "rack.go", "rcv.go", - "rcv_state.go", "reno.go", "reno_recovery.go", "sack.go", @@ -53,7 +52,6 @@ go_library( "segment_state.go", "segment_unsafe.go", "snd.go", - "snd_state.go", "tcp_endpoint_list.go", "tcp_segment_list.go", "timer.go", @@ -134,6 +132,7 @@ go_test( deps = [ "//pkg/sleep", "//pkg/tcpip/buffer", + "//pkg/tcpip/faketime", "//pkg/tcpip/stack", "@com_github_google_go_cmp//cmp:go_default_library", ], diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index d4bd4e80e..d807b13b7 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -114,8 +114,8 @@ type listenContext struct { } // timeStamp returns an 8-bit timestamp with a granularity of 64 seconds. -func timeStamp() uint32 { - return uint32(time.Now().Unix()>>6) & tsMask +func timeStamp(clock tcpip.Clock) uint32 { + return uint32(clock.NowMonotonic().Sub(tcpip.MonotonicTime{}).Seconds()) >> 6 & tsMask } // newListenContext creates a new listen context. @@ -171,7 +171,7 @@ func (l *listenContext) cookieHash(id stack.TransportEndpointID, ts uint32, nonc // createCookie creates a SYN cookie for the given id and incoming sequence // number. func (l *listenContext) createCookie(id stack.TransportEndpointID, seq seqnum.Value, data uint32) seqnum.Value { - ts := timeStamp() + ts := timeStamp(l.stack.Clock()) v := l.cookieHash(id, 0, 0) + uint32(seq) + (ts << tsOffset) v += (l.cookieHash(id, ts, 1) + data) & hashMask return seqnum.Value(v) @@ -181,7 +181,7 @@ func (l *listenContext) createCookie(id stack.TransportEndpointID, seq seqnum.Va // sequence number. If it is, it also returns the data originally encoded in the // cookie when createCookie was called. func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnum.Value, seq seqnum.Value) (uint32, bool) { - ts := timeStamp() + ts := timeStamp(l.stack.Clock()) v := uint32(cookie) - l.cookieHash(id, 0, 0) - uint32(seq) cookieTS := v >> tsOffset if ((ts - cookieTS) & tsMask) > maxTSDiff { @@ -247,7 +247,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts *header func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*handshake, tcpip.Error) { // Create new endpoint. irs := s.sequenceNumber - isn := generateSecureISN(s.id, l.stack.Seed()) + isn := generateSecureISN(s.id, l.stack.Clock(), l.stack.Seed()) ep, err := l.createConnectingEndpoint(s, opts, queue) if err != nil { return nil, err @@ -550,7 +550,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err e.rcvQueueInfo.rcvQueueMu.Lock() rcvClosed := e.rcvQueueInfo.RcvClosed e.rcvQueueInfo.rcvQueueMu.Unlock() - if rcvClosed || s.flagsAreSet(header.TCPFlagSyn|header.TCPFlagAck) { + if rcvClosed || s.flags.Contains(header.TCPFlagSyn|header.TCPFlagAck) { // If the endpoint is shutdown, reply with reset. // // RFC 793 section 3.4 page 35 (figure 12) outlines that a RST @@ -560,6 +560,10 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err } switch { + case s.flags.Contains(header.TCPFlagRst): + e.stack.Stats().DroppedPackets.Increment() + return nil + case s.flags == header.TCPFlagSyn: if e.acceptQueueIsFull() { e.stack.Stats().TCP.ListenOverflowSynDrop.Increment() @@ -591,7 +595,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err synOpts := header.TCPSynOptions{ WS: -1, TS: opts.TS, - TSVal: tcpTimeStamp(time.Now(), timeStampOffset()), + TSVal: tcpTimeStamp(e.stack.Clock().NowMonotonic(), timeStampOffset(e.stack.Rand())), TSEcr: opts.TSVal, MSS: calculateAdvertisedMSS(e.userMSS, route), } @@ -611,7 +615,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment() return nil - case (s.flags & header.TCPFlagAck) != 0: + case s.flags.Contains(header.TCPFlagAck): if e.acceptQueueIsFull() { // Silently drop the ack as the application can't accept // the connection at this point. The ack will be @@ -736,6 +740,13 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err mss: rcvdSynOptions.MSS, }) + // Requeue the segment if the ACK completing the handshake has more info + // to be procesed by the newly established endpoint. + if (s.flags.Contains(header.TCPFlagFin) || s.data.Size() > 0) && n.enqueueSegment(s) { + s.incRef() + n.newSegmentWaker.Assert() + } + // Do the delivery in a separate goroutine so // that we don't block the listen loop in case // the application is slow to accept or stops @@ -753,6 +764,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err return nil default: + e.stack.Stats().DroppedPackets.Increment() return nil } } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 5e03e7715..e39d1623d 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -19,7 +19,6 @@ import ( "math" "time" - "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -92,7 +91,7 @@ type handshake struct { rcvWndScale int // startTime is the time at which the first SYN/SYN-ACK was sent. - startTime time.Time + startTime tcpip.MonotonicTime // deferAccept if non-zero will drop the final ACK for a passive // handshake till an ACK segment with data is received or the timeout is @@ -147,21 +146,16 @@ func FindWndScale(wnd seqnum.Size) int { // resetState resets the state of the handshake object such that it becomes // ready for a new 3-way handshake. func (h *handshake) resetState() { - b := make([]byte, 4) - if _, err := rand.Read(b); err != nil { - panic(err) - } - h.state = handshakeSynSent h.flags = header.TCPFlagSyn h.ackNum = 0 h.mss = 0 - h.iss = generateSecureISN(h.ep.TransportEndpointInfo.ID, h.ep.stack.Seed()) + h.iss = generateSecureISN(h.ep.TransportEndpointInfo.ID, h.ep.stack.Clock(), h.ep.stack.Seed()) } // generateSecureISN generates a secure Initial Sequence number based on the // recommendation here https://tools.ietf.org/html/rfc6528#page-3. -func generateSecureISN(id stack.TransportEndpointID, seed uint32) seqnum.Value { +func generateSecureISN(id stack.TransportEndpointID, clock tcpip.Clock, seed uint32) seqnum.Value { isnHasher := jenkins.Sum32(seed) isnHasher.Write([]byte(id.LocalAddress)) isnHasher.Write([]byte(id.RemoteAddress)) @@ -180,7 +174,7 @@ func generateSecureISN(id stack.TransportEndpointID, seed uint32) seqnum.Value { // // Which sort of guarantees that we won't reuse the ISN for a new // connection for the same tuple for at least 274s. - isn := isnHasher.Sum32() + uint32(time.Now().UnixNano()>>6) + isn := isnHasher.Sum32() + uint32(clock.NowMonotonic().Sub(tcpip.MonotonicTime{}).Nanoseconds()>>6) return seqnum.Value(isn) } @@ -212,7 +206,7 @@ func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *hea // a TCP 3-way handshake is valid. If it's not, a RST segment is sent back in // response. func (h *handshake) checkAck(s *segment) bool { - if s.flagIsSet(header.TCPFlagAck) && s.ackNumber != h.iss+1 { + if s.flags.Contains(header.TCPFlagAck) && s.ackNumber != h.iss+1 { // RFC 793, page 36, states that a reset must be generated when // the connection is in any non-synchronized state and an // incoming segment acknowledges something not yet sent. The @@ -230,8 +224,8 @@ func (h *handshake) checkAck(s *segment) bool { func (h *handshake) synSentState(s *segment) tcpip.Error { // RFC 793, page 37, states that in the SYN-SENT state, a reset is // acceptable if the ack field acknowledges the SYN. - if s.flagIsSet(header.TCPFlagRst) { - if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == h.iss+1 { + if s.flags.Contains(header.TCPFlagRst) { + if s.flags.Contains(header.TCPFlagAck) && s.ackNumber == h.iss+1 { // RFC 793, page 67, states that "If the RST bit is set [and] If the ACK // was acceptable then signal the user "error: connection reset", drop // the segment, enter CLOSED state, delete TCB, and return." @@ -249,7 +243,7 @@ func (h *handshake) synSentState(s *segment) tcpip.Error { // We are in the SYN-SENT state. We only care about segments that have // the SYN flag. - if !s.flagIsSet(header.TCPFlagSyn) { + if !s.flags.Contains(header.TCPFlagSyn) { return nil } @@ -270,7 +264,7 @@ func (h *handshake) synSentState(s *segment) tcpip.Error { // If this is a SYN ACK response, we only need to acknowledge the SYN // and the handshake is completed. - if s.flagIsSet(header.TCPFlagAck) { + if s.flags.Contains(header.TCPFlagAck) { h.state = handshakeCompleted h.ep.transitionToStateEstablishedLocked(h) @@ -316,7 +310,7 @@ func (h *handshake) synSentState(s *segment) tcpip.Error { // synRcvdState handles a segment received when the TCP 3-way handshake is in // the SYN-RCVD state. func (h *handshake) synRcvdState(s *segment) tcpip.Error { - if s.flagIsSet(header.TCPFlagRst) { + if s.flags.Contains(header.TCPFlagRst) { // RFC 793, page 37, states that in the SYN-RCVD state, a reset // is acceptable if the sequence number is in the window. if s.sequenceNumber.InWindow(h.ackNum, h.rcvWnd) { @@ -340,13 +334,13 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error { return nil } - if s.flagIsSet(header.TCPFlagSyn) && s.sequenceNumber != h.ackNum-1 { + if s.flags.Contains(header.TCPFlagSyn) && s.sequenceNumber != h.ackNum-1 { // We received two SYN segments with different sequence // numbers, so we reset this and restart the whole // process, except that we don't reset the timer. ack := s.sequenceNumber.Add(s.logicalLen()) seq := seqnum.Value(0) - if s.flagIsSet(header.TCPFlagAck) { + if s.flags.Contains(header.TCPFlagAck) { seq = s.ackNumber } h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0) @@ -378,10 +372,10 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error { // We have previously received (and acknowledged) the peer's SYN. If the // peer acknowledges our SYN, the handshake is completed. - if s.flagIsSet(header.TCPFlagAck) { + if s.flags.Contains(header.TCPFlagAck) { // If deferAccept is not zero and this is a bare ACK and the // timeout is not hit then drop the ACK. - if h.deferAccept != 0 && s.data.Size() == 0 && time.Since(h.startTime) < h.deferAccept { + if h.deferAccept != 0 && s.data.Size() == 0 && h.ep.stack.Clock().NowMonotonic().Sub(h.startTime) < h.deferAccept { h.acked = true h.ep.stack.Stats().DroppedPackets.Increment() return nil @@ -412,11 +406,11 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error { h.ep.transitionToStateEstablishedLocked(h) - // If the segment has data then requeue it for the receiver - // to process it again once main loop is started. - if s.data.Size() > 0 { + // Requeue the segment if the ACK completing the handshake has more info + // to be procesed by the newly established endpoint. + if (s.flags.Contains(header.TCPFlagFin) || s.data.Size() > 0) && h.ep.enqueueSegment(s) { s.incRef() - h.ep.enqueueSegment(s) + h.ep.newSegmentWaker.Assert() } return nil } @@ -426,7 +420,7 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error { func (h *handshake) handleSegment(s *segment) tcpip.Error { h.sndWnd = s.window - if !s.flagIsSet(header.TCPFlagSyn) && h.sndWndScale > 0 { + if !s.flags.Contains(header.TCPFlagSyn) && h.sndWndScale > 0 { h.sndWnd <<= uint8(h.sndWndScale) } @@ -474,7 +468,7 @@ func (h *handshake) processSegments() tcpip.Error { // start sends the first SYN/SYN-ACK. It does not block, even if link address // resolution is required. func (h *handshake) start() { - h.startTime = time.Now() + h.startTime = h.ep.stack.Clock().NowMonotonic() h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route) var sackEnabled tcpip.TCPSACKEnabled if err := h.ep.stack.TransportProtocolOption(ProtocolNumber, &sackEnabled); err != nil { @@ -527,7 +521,7 @@ func (h *handshake) complete() tcpip.Error { defer s.Done() // Initialize the resend timer. - timer, err := newBackoffTimer(time.Second, MaxRTO, resendWaker.Assert) + timer, err := newBackoffTimer(h.ep.stack.Clock(), time.Second, MaxRTO, resendWaker.Assert) if err != nil { return err } @@ -552,7 +546,7 @@ func (h *handshake) complete() tcpip.Error { // The last is required to provide a way for the peer to complete // the connection with another ACK or data (as ACKs are never // retransmitted on their own). - if h.active || !h.acked || h.deferAccept != 0 && time.Since(h.startTime) > h.deferAccept { + if h.active || !h.acked || h.deferAccept != 0 && h.ep.stack.Clock().NowMonotonic().Sub(h.startTime) > h.deferAccept { h.ep.sendSynTCP(h.ep.route, tcpFields{ id: h.ep.TransportEndpointInfo.ID, ttl: h.ep.ttl, @@ -608,15 +602,15 @@ func (h *handshake) complete() tcpip.Error { type backoffTimer struct { timeout time.Duration maxTimeout time.Duration - t *time.Timer + t tcpip.Timer } -func newBackoffTimer(timeout, maxTimeout time.Duration, f func()) (*backoffTimer, tcpip.Error) { +func newBackoffTimer(clock tcpip.Clock, timeout, maxTimeout time.Duration, f func()) (*backoffTimer, tcpip.Error) { if timeout > maxTimeout { return nil, &tcpip.ErrTimeout{} } bt := &backoffTimer{timeout: timeout, maxTimeout: maxTimeout} - bt.t = time.AfterFunc(timeout, f) + bt.t = clock.AfterFunc(timeout, f) return bt, nil } @@ -634,7 +628,7 @@ func (bt *backoffTimer) stop() { } func parseSynSegmentOptions(s *segment) header.TCPSynOptions { - synOpts := header.ParseSynOptions(s.options, s.flagIsSet(header.TCPFlagAck)) + synOpts := header.ParseSynOptions(s.options, s.flags.Contains(header.TCPFlagAck)) if synOpts.TS { s.parsedOptions.TSVal = synOpts.TSVal s.parsedOptions.TSEcr = synOpts.TSEcr @@ -915,30 +909,13 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags header.TCPFlags, se return err } -func (e *endpoint) handleWrite() { - e.sndQueueInfo.sndQueueMu.Lock() - next := e.drainSendQueueLocked() - e.sndQueueInfo.sndQueueMu.Unlock() - - e.sendData(next) -} - -// Move packets from send queue to send list. -// -// Precondition: e.sndBufMu must be locked. -func (e *endpoint) drainSendQueueLocked() *segment { - first := e.sndQueueInfo.sndQueue.Front() - if first != nil { - e.snd.writeList.PushBackList(&e.sndQueueInfo.sndQueue) - e.sndQueueInfo.SndBufInQueue = 0 - } - return first -} - // Precondition: e.mu must be locked. func (e *endpoint) sendData(next *segment) { // Initialize the next segment to write if it's currently nil. if e.snd.writeNext == nil { + if next == nil { + return + } e.snd.writeNext = next } @@ -946,17 +923,6 @@ func (e *endpoint) sendData(next *segment) { e.snd.sendData() } -func (e *endpoint) handleClose() { - if !e.EndpointState().connected() { - return - } - // Drain the send queue. - e.handleWrite() - - // Mark send side as closed. - e.snd.Closed = true -} - // resetConnectionLocked puts the endpoint in an error state with the given // error code and sends a RST if and only if the error is not ErrConnectionReset // indicating that the connection is being reset due to receiving a RST. This @@ -1136,7 +1102,7 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err tcpip.Error) { func (e *endpoint) handleSegmentsLocked(fastPath bool) tcpip.Error { checkRequeue := true for i := 0; i < maxSegmentsPerWake; i++ { - if e.EndpointState().closed() { + if state := e.EndpointState(); state.closed() || state == StateTimeWait { return nil } s := e.segmentQueue.dequeue() @@ -1188,11 +1154,11 @@ func (e *endpoint) handleSegmentLocked(s *segment) (cont bool, err tcpip.Error) // the TCPEndpointState after the segment is processed. defer e.probeSegmentLocked() - if s.flagIsSet(header.TCPFlagRst) { + if s.flags.Contains(header.TCPFlagRst) { if ok, err := e.handleReset(s); !ok { return false, err } - } else if s.flagIsSet(header.TCPFlagSyn) { + } else if s.flags.Contains(header.TCPFlagSyn) { // See: https://tools.ietf.org/html/rfc5961#section-4.1 // 1) If the SYN bit is set, irrespective of the sequence number, TCP // MUST send an ACK (also referred to as challenge ACK) to the remote @@ -1216,7 +1182,7 @@ func (e *endpoint) handleSegmentLocked(s *segment) (cont bool, err tcpip.Error) // should then rely on SYN retransmission from the remote end to // re-establish the connection. e.snd.maybeSendOutOfWindowAck(s) - } else if s.flagIsSet(header.TCPFlagAck) { + } else if s.flags.Contains(header.TCPFlagAck) { // Patch the window size in the segment according to the // send window scale. s.window <<= e.snd.SndWndScale @@ -1235,7 +1201,7 @@ func (e *endpoint) handleSegmentLocked(s *segment) (cont bool, err tcpip.Error) // Now check if the received segment has caused us to transition // to a CLOSED state, if yes then terminate processing and do // not invoke the sender. - state := e.state + state := e.EndpointState() if state == StateClose { // When we get into StateClose while processing from the queue, // return immediately and let the protocolMainloop handle it. @@ -1267,7 +1233,7 @@ func (e *endpoint) keepaliveTimerExpired() tcpip.Error { // If a userTimeout is set then abort the connection if it is // exceeded. - if userTimeout != 0 && time.Since(e.rcv.lastRcvdAckTime) >= userTimeout && e.keepalive.unacked > 0 { + if userTimeout != 0 && e.stack.Clock().NowMonotonic().Sub(e.rcv.lastRcvdAckTime) >= userTimeout && e.keepalive.unacked > 0 { e.keepalive.Unlock() e.stack.Stats().TCP.EstablishedTimedout.Increment() return &tcpip.ErrTimeout{} @@ -1322,7 +1288,7 @@ func (e *endpoint) disableKeepaliveTimer() { // segments. func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) tcpip.Error { e.mu.Lock() - var closeTimer *time.Timer + var closeTimer tcpip.Timer var closeWaker sleep.Waker epilogue := func() { @@ -1408,14 +1374,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ { w: &e.sndQueueInfo.sndWaker, f: func() tcpip.Error { - e.handleWrite() - return nil - }, - }, - { - w: &e.sndQueueInfo.sndCloseWaker, - f: func() tcpip.Error { - e.handleClose() + e.sendData(nil /* next */) return nil }, }, @@ -1480,11 +1439,19 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ return &tcpip.ErrConnectionReset{} } - if n¬ifyClose != 0 && closeTimer == nil { - if e.EndpointState() == StateFinWait2 && e.closed { + if n¬ifyClose != 0 && e.closed { + switch e.EndpointState() { + case StateEstablished: + // Perform full shutdown if the endpoint is still + // established. This can occur when notifyClose + // was asserted just before becoming established. + e.shutdownLocked(tcpip.ShutdownWrite | tcpip.ShutdownRead) + case StateFinWait2: // The socket has been closed and we are in FIN_WAIT2 // so start the FIN_WAIT2 timer. - closeTimer = time.AfterFunc(e.tcpLingerTimeout, closeWaker.Assert) + if closeTimer == nil { + closeTimer = e.stack.Clock().AfterFunc(e.tcpLingerTimeout, closeWaker.Assert) + } } } @@ -1721,7 +1688,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) { var timeWaitWaker sleep.Waker s.AddWaker(&timeWaitWaker, timeWaitDone) - timeWaitTimer := time.AfterFunc(timeWaitDuration, timeWaitWaker.Assert) + timeWaitTimer := e.stack.Clock().AfterFunc(timeWaitDuration, timeWaitWaker.Assert) defer timeWaitTimer.Stop() for { diff --git a/pkg/tcpip/transport/tcp/cubic.go b/pkg/tcpip/transport/tcp/cubic.go index 962f1d687..6985194bb 100644 --- a/pkg/tcpip/transport/tcp/cubic.go +++ b/pkg/tcpip/transport/tcp/cubic.go @@ -41,7 +41,7 @@ type cubicState struct { func newCubicCC(s *sender) *cubicState { return &cubicState{ TCPCubicState: stack.TCPCubicState{ - T: time.Now(), + T: s.ep.stack.Clock().NowMonotonic(), Beta: 0.7, C: 0.4, }, @@ -60,7 +60,7 @@ func (c *cubicState) enterCongestionAvoidance() { // https://tools.ietf.org/html/rfc8312#section-4.8 if c.numCongestionEvents == 0 { c.K = 0 - c.T = time.Now() + c.T = c.s.ep.stack.Clock().NowMonotonic() c.WLastMax = c.WMax c.WMax = float64(c.s.SndCwnd) } @@ -115,14 +115,15 @@ func (c *cubicState) cubicCwnd(t float64) float64 { // getCwnd returns the current congestion window as computed by CUBIC. // Refer: https://tools.ietf.org/html/rfc8312#section-4 func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int { - elapsed := time.Since(c.T).Seconds() + elapsed := c.s.ep.stack.Clock().NowMonotonic().Sub(c.T) + elapsedSeconds := elapsed.Seconds() // Compute the window as per Cubic after 'elapsed' time // since last congestion event. - c.WC = c.cubicCwnd(elapsed - c.K) + c.WC = c.cubicCwnd(elapsedSeconds - c.K) // Compute the TCP friendly estimate of the congestion window. - c.WEst = c.WMax*c.Beta + (3.0*((1.0-c.Beta)/(1.0+c.Beta)))*(elapsed/srtt.Seconds()) + c.WEst = c.WMax*c.Beta + (3.0*((1.0-c.Beta)/(1.0+c.Beta)))*(elapsedSeconds/srtt.Seconds()) // Make sure in the TCP friendly region CUBIC performs at least // as well as Reno. @@ -134,7 +135,7 @@ func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int // In Concave/Convex region of CUBIC, calculate what CUBIC window // will be after 1 RTT and use that to grow congestion window // for every ack. - tEst := (time.Since(c.T) + srtt).Seconds() + tEst := (elapsed + srtt).Seconds() wtRtt := c.cubicCwnd(tEst - c.K) // As per 4.3 for each received ACK cwnd must be incremented // by (w_cubic(t+RTT) - cwnd/cwnd. @@ -151,7 +152,7 @@ func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int func (c *cubicState) HandleLossDetected() { // See: https://tools.ietf.org/html/rfc8312#section-4.5 c.numCongestionEvents++ - c.T = time.Now() + c.T = c.s.ep.stack.Clock().NowMonotonic() c.WLastMax = c.WMax c.WMax = float64(c.s.SndCwnd) @@ -162,7 +163,7 @@ func (c *cubicState) HandleLossDetected() { // HandleRTOExpired implements congestionContrl.HandleRTOExpired. func (c *cubicState) HandleRTOExpired() { // See: https://tools.ietf.org/html/rfc8312#section-4.6 - c.T = time.Now() + c.T = c.s.ep.stack.Clock().NowMonotonic() c.numCongestionEvents = 0 c.WLastMax = c.WMax c.WMax = float64(c.s.SndCwnd) @@ -193,7 +194,7 @@ func (c *cubicState) fastConvergence() { // PostRecovery implemements congestionControl.PostRecovery. func (c *cubicState) PostRecovery() { - c.T = time.Now() + c.T = c.s.ep.stack.Clock().NowMonotonic() } // reduceSlowStartThreshold returns new SsThresh as described in diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go index 512053a04..dff7cb89c 100644 --- a/pkg/tcpip/transport/tcp/dispatcher.go +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -16,10 +16,11 @@ package tcp import ( "encoding/binary" + "math/rand" - "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -141,15 +142,16 @@ func (p *processor) start(wg *sync.WaitGroup) { // in-order. type dispatcher struct { processors []processor - seed uint32 - wg sync.WaitGroup + // seed is a random secret for a jenkins hash. + seed uint32 + wg sync.WaitGroup } -func (d *dispatcher) init(nProcessors int) { +func (d *dispatcher) init(rng *rand.Rand, nProcessors int) { d.close() d.wait() d.processors = make([]processor, nProcessors) - d.seed = generateRandUint32() + d.seed = rng.Uint32() for i := range d.processors { p := &d.processors[i] p.sleeper.AddWaker(&p.newEndpointWaker, newEndpointWaker) @@ -172,12 +174,11 @@ func (d *dispatcher) wait() { d.wg.Wait() } -func (d *dispatcher) queuePacket(stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { +func (d *dispatcher) queuePacket(stackEP stack.TransportEndpoint, id stack.TransportEndpointID, clock tcpip.Clock, pkt *stack.PacketBuffer) { ep := stackEP.(*endpoint) - s := newIncomingSegment(id, pkt) + s := newIncomingSegment(id, clock, pkt) if !s.parse(pkt.RXTransportChecksumValidated) { - ep.stack.Stats().MalformedRcvdPackets.Increment() ep.stack.Stats().TCP.InvalidSegmentsReceived.Increment() ep.stats.ReceiveErrors.MalformedPacketsReceived.Increment() s.decRef() @@ -185,7 +186,6 @@ func (d *dispatcher) queuePacket(stackEP stack.TransportEndpoint, id stack.Trans } if !s.csumValid { - ep.stack.Stats().MalformedRcvdPackets.Increment() ep.stack.Stats().TCP.ChecksumErrors.Increment() ep.stats.ReceiveErrors.ChecksumErrors.Increment() s.decRef() @@ -213,14 +213,6 @@ func (d *dispatcher) queuePacket(stackEP stack.TransportEndpoint, id stack.Trans d.selectProcessor(id).queueEndpoint(ep) } -func generateRandUint32() uint32 { - b := make([]byte, 4) - if _, err := rand.Read(b); err != nil { - panic(err) - } - return binary.LittleEndian.Uint32(b) -} - func (d *dispatcher) selectProcessor(id stack.TransportEndpointID) *processor { var payload [4]byte binary.LittleEndian.PutUint16(payload[0:], id.LocalPort) diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index f148d505d..5342aacfd 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -421,7 +421,7 @@ func testV4Accept(t *testing.T, c *context.Context) { r.Reset(data) nep.Write(&r, tcpip.WriteOptions{}) b = c.GetPacket() - tcp = header.TCP(header.IPv4(b).Payload()) + tcp = header.IPv4(b).Payload() if string(tcp.Payload()) != data { t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data) } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 50d39cbad..4acddc959 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -20,12 +20,12 @@ import ( "fmt" "io" "math" + "math/rand" "runtime" "strings" "sync/atomic" "time" - "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -38,19 +38,15 @@ import ( ) // EndpointState represents the state of a TCP endpoint. -type EndpointState uint32 +type EndpointState tcpip.EndpointState // Endpoint states. Note that are represented in a netstack-specific manner and // may not be meaningful externally. Specifically, they need to be translated to // Linux's representation for these states if presented to userspace. const ( - // Endpoint states internal to netstack. These map to the TCP state CLOSED. - StateInitial EndpointState = iota - StateBound - StateConnecting // Connect() called, but the initial SYN hasn't been sent. - StateError - - // TCP protocol states. + _ EndpointState = iota + // TCP protocol states in sync with the definitions in + // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/tcp_states.h#L13 StateEstablished StateSynSent StateSynRecv @@ -62,6 +58,12 @@ const ( StateLastAck StateListen StateClosing + + // Endpoint states internal to netstack. + StateInitial + StateBound + StateConnecting // Connect() called, but the initial SYN hasn't been sent. + StateError ) const ( @@ -97,6 +99,16 @@ func (s EndpointState) connecting() bool { } } +// internal returns true when the state is netstack internal. +func (s EndpointState) internal() bool { + switch s { + case StateInitial, StateBound, StateConnecting, StateError: + return true + default: + return false + } +} + // handshake returns true when s is one of the states representing an endpoint // in the middle of a TCP handshake. func (s EndpointState) handshake() bool { @@ -281,16 +293,9 @@ type sndQueueInfo struct { sndQueueMu sync.Mutex `state:"nosave"` stack.TCPSndBufState - // sndQueue holds segments that are ready to be sent. - sndQueue segmentList `state:"wait"` - - // sndWaker is used to signal the protocol goroutine when segments are - // added to the `sndQueue`. + // sndWaker is used to signal the protocol goroutine when there may be + // segments that need to be sent. sndWaker sleep.Waker `state:"manual"` - - // sndCloseWaker is used to notify the protocol goroutine when the send - // side is closed. - sndCloseWaker sleep.Waker `state:"manual"` } // rcvQueueInfo contains the endpoint's rcvQueue and associated metadata. @@ -422,12 +427,12 @@ type endpoint struct { // state must be read/set using the EndpointState()/setEndpointState() // methods. - state EndpointState `state:".(EndpointState)"` + state uint32 `state:".(EndpointState)"` // origEndpointState is only used during a restore phase to save the // endpoint state at restore time as the socket is moved to it's correct // state. - origEndpointState EndpointState `state:"nosave"` + origEndpointState uint32 `state:"nosave"` isPortReserved bool `state:"manual"` isRegistered bool `state:"manual"` @@ -468,7 +473,7 @@ type endpoint struct { // recentTSTime is the unix time when we last updated // TCPEndpointStateInner.RecentTS. - recentTSTime time.Time `state:".(unixTime)"` + recentTSTime tcpip.MonotonicTime // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags @@ -626,7 +631,7 @@ type endpoint struct { // lastOutOfWindowAckTime is the time at which the an ACK was sent in response // to an out of window segment being received by this endpoint. - lastOutOfWindowAckTime time.Time `state:".(unixTime)"` + lastOutOfWindowAckTime tcpip.MonotonicTime } // UniqueID implements stack.TransportEndpoint.UniqueID. @@ -747,7 +752,7 @@ func (e *endpoint) ResumeWork() { // // Precondition: e.mu must be held to call this method. func (e *endpoint) setEndpointState(state EndpointState) { - oldstate := EndpointState(atomic.LoadUint32((*uint32)(&e.state))) + oldstate := EndpointState(atomic.LoadUint32(&e.state)) switch state { case StateEstablished: e.stack.Stats().TCP.CurrentEstablished.Increment() @@ -764,18 +769,18 @@ func (e *endpoint) setEndpointState(state EndpointState) { e.stack.Stats().TCP.CurrentEstablished.Decrement() } } - atomic.StoreUint32((*uint32)(&e.state), uint32(state)) + atomic.StoreUint32(&e.state, uint32(state)) } // EndpointState returns the current state of the endpoint. func (e *endpoint) EndpointState() EndpointState { - return EndpointState(atomic.LoadUint32((*uint32)(&e.state))) + return EndpointState(atomic.LoadUint32(&e.state)) } // setRecentTimestamp sets the recentTS field to the provided value. func (e *endpoint) setRecentTimestamp(recentTS uint32) { e.RecentTS = recentTS - e.recentTSTime = time.Now() + e.recentTSTime = e.stack.Clock().NowMonotonic() } // recentTimestamp returns the value of the recentTS field. @@ -806,11 +811,11 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue }, sndQueueInfo: sndQueueInfo{ TCPSndBufState: stack.TCPSndBufState{ - SndMTU: int(math.MaxInt32), + SndMTU: math.MaxInt32, }, }, waiterQueue: waiterQueue, - state: StateInitial, + state: uint32(StateInitial), keepalive: keepalive{ // Linux defaults. idle: 2 * time.Hour, @@ -870,9 +875,9 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue } e.segmentQueue.ep = e - e.TSOffset = timeStampOffset() + e.TSOffset = timeStampOffset(e.stack.Rand()) e.acceptCond = sync.NewCond(&e.acceptMu) - e.keepalive.timer.init(&e.keepalive.waker) + e.keepalive.timer.init(e.stack.Clock(), &e.keepalive.waker) return e } @@ -1189,7 +1194,7 @@ func (e *endpoint) ModerateRecvBuf(copied int) { e.rcvQueueInfo.rcvQueueMu.Unlock() return } - now := time.Now() + now := e.stack.Clock().NowMonotonic() if rtt := e.rcvQueueInfo.RcvAutoParams.RTT; rtt == 0 || now.Sub(e.rcvQueueInfo.RcvAutoParams.MeasureTime) < rtt { e.rcvQueueInfo.RcvAutoParams.CopiedBytes += copied e.rcvQueueInfo.rcvQueueMu.Unlock() @@ -1544,12 +1549,11 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp } // Add data to the send queue. - s := newOutgoingSegment(e.TransportEndpointInfo.ID, v) + s := newOutgoingSegment(e.TransportEndpointInfo.ID, e.stack.Clock(), v) e.sndQueueInfo.SndBufUsed += len(v) - e.sndQueueInfo.SndBufInQueue += seqnum.Size(len(v)) - e.sndQueueInfo.sndQueue.PushBack(s) + e.snd.writeList.PushBack(s) - return e.drainSendQueueLocked(), len(v), nil + return s, len(v), nil }() // Return if either we didn't queue anything or if an error occurred while // attempting to queue data. @@ -1956,6 +1960,11 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { func (e *endpoint) getTCPInfo() tcpip.TCPInfoOption { info := tcpip.TCPInfoOption{} e.LockUser() + if state := e.EndpointState(); state.internal() { + info.State = tcpip.EndpointState(StateClose) + } else { + info.State = tcpip.EndpointState(state) + } snd := e.snd if snd != nil { // We do not calculate RTT before sending the data packets. If @@ -2198,7 +2207,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp BindToDevice: bindToDevice, Dest: addr, } - if _, err := e.stack.ReservePort(portRes, nil /* testPort */); err != nil { + if _, err := e.stack.ReservePort(e.stack.Rand(), portRes, nil /* testPort */); err != nil { if _, ok := err.(*tcpip.ErrPortInUse); !ok || !reuse { return false, nil } @@ -2224,7 +2233,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp // If the endpoint is not in TIME-WAIT or if it is in TIME-WAIT but // less than 1 second has elapsed since its recentTS was updated then // we cannot reuse the port. - if tcpEP.EndpointState() != StateTimeWait || time.Since(tcpEP.recentTSTime) < 1*time.Second { + if tcpEP.EndpointState() != StateTimeWait || e.stack.Clock().NowMonotonic().Sub(tcpEP.recentTSTime) < 1*time.Second { tcpEP.UnlockUser() return false, nil } @@ -2245,7 +2254,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp BindToDevice: bindToDevice, Dest: addr, } - if _, err := e.stack.ReservePort(portRes, nil /* testPort */); err != nil { + if _, err := e.stack.ReservePort(e.stack.Rand(), portRes, nil /* testPort */); err != nil { return false, nil } } @@ -2297,7 +2306,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp // connection setting here. if !handshake { e.segmentQueue.mu.Lock() - for _, l := range []segmentList{e.segmentQueue.list, e.sndQueueInfo.sndQueue, e.snd.writeList} { + for _, l := range []segmentList{e.segmentQueue.list, e.snd.writeList} { for s := l.Front(); s != nil; s = s.Next() { s.id = e.TransportEndpointInfo.ID e.sndQueueInfo.sndWaker.Assert() @@ -2355,6 +2364,9 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error { e.notifyProtocolGoroutine(notifyTickleWorker) return nil } + // Wake up any readers that maybe waiting for the stream to become + // readable. + e.waiterQueue.Notify(waiter.ReadableEvents) } // Close for write. @@ -2370,13 +2382,21 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error { } // Queue fin segment. - s := newOutgoingSegment(e.TransportEndpointInfo.ID, nil) - e.sndQueueInfo.sndQueue.PushBack(s) - e.sndQueueInfo.SndBufInQueue++ + s := newOutgoingSegment(e.TransportEndpointInfo.ID, e.stack.Clock(), nil) + e.snd.writeList.PushBack(s) // Mark endpoint as closed. e.sndQueueInfo.SndClosed = true e.sndQueueInfo.sndQueueMu.Unlock() - e.handleClose() + + // Drain the send queue. + e.sendData(s) + + // Mark send side as closed. + e.snd.Closed = true + + // Wake up any writers that maybe waiting for the stream to become + // writable. + e.waiterQueue.Notify(waiter.WritableEvents) } return nil @@ -2581,7 +2601,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) { BindToDevice: bindToDevice, Dest: tcpip.FullAddress{}, } - port, err := e.stack.ReservePort(portRes, func(p uint16) (bool, tcpip.Error) { + port, err := e.stack.ReservePort(e.stack.Rand(), portRes, func(p uint16) (bool, tcpip.Error) { id := e.TransportEndpointInfo.ID id.LocalPort = p // CheckRegisterTransportEndpoint should only return an error if there is a @@ -2731,7 +2751,7 @@ func (e *endpoint) updateSndBufferUsage(v int) { // We only notify when there is half the sendBufferSize available after // a full buffer event occurs. This ensures that we don't wake up // writers to queue just 1-2 segments and go back to sleep. - notify = notify && e.sndQueueInfo.SndBufUsed < int(sendBufferSize)>>1 + notify = notify && e.sndQueueInfo.SndBufUsed < sendBufferSize>>1 e.sndQueueInfo.sndQueueMu.Unlock() if notify { @@ -2848,23 +2868,20 @@ func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) { // timestamp returns the timestamp value to be used in the TSVal field of the // timestamp option for outgoing TCP segments for a given endpoint. func (e *endpoint) timestamp() uint32 { - return tcpTimeStamp(time.Now(), e.TSOffset) + return tcpTimeStamp(e.stack.Clock().NowMonotonic(), e.TSOffset) } // tcpTimeStamp returns a timestamp offset by the provided offset. This is // not inlined above as it's used when SYN cookies are in use and endpoint // is not created at the time when the SYN cookie is sent. -func tcpTimeStamp(curTime time.Time, offset uint32) uint32 { - return uint32(curTime.Unix()*1000+int64(curTime.Nanosecond()/1e6)) + offset +func tcpTimeStamp(curTime tcpip.MonotonicTime, offset uint32) uint32 { + d := curTime.Sub(tcpip.MonotonicTime{}) + return uint32(d.Milliseconds()) + offset } // timeStampOffset returns a randomized timestamp offset to be used when sending // timestamp values in a timestamp option for a TCP segment. -func timeStampOffset() uint32 { - b := make([]byte, 4) - if _, err := rand.Read(b); err != nil { - panic(err) - } +func timeStampOffset(rng *rand.Rand) uint32 { // Initialize a random tsOffset that will be added to the recentTS // everytime the timestamp is sent when the Timestamp option is enabled. // @@ -2874,7 +2891,7 @@ func timeStampOffset() uint32 { // NOTE: This is not completely to spec as normally this should be // initialized in a manner analogous to how sequence numbers are // randomized per connection basis. But for now this is sufficient. - return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24 + return rng.Uint32() } // maybeEnableSACKPermitted marks the SACKPermitted option enabled for this endpoint @@ -2909,7 +2926,7 @@ func (e *endpoint) completeStateLocked() stack.TCPEndpointState { s := stack.TCPEndpointState{ TCPEndpointStateInner: e.TCPEndpointStateInner, ID: stack.TCPEndpointID(e.TransportEndpointInfo.ID), - SegTime: time.Now(), + SegTime: e.stack.Clock().NowMonotonic(), Receiver: e.rcv.TCPReceiverState, Sender: e.snd.TCPSenderState, } @@ -2937,7 +2954,7 @@ func (e *endpoint) completeStateLocked() stack.TCPEndpointState { if cubic, ok := e.snd.cc.(*cubicState); ok { s.Sender.Cubic = cubic.TCPCubicState - s.Sender.Cubic.TimeSinceLastCongestion = time.Since(s.Sender.Cubic.T) + s.Sender.Cubic.TimeSinceLastCongestion = e.stack.Clock().NowMonotonic().Sub(s.Sender.Cubic.T) } s.Sender.RACKState = e.snd.rc.TCPRACKState @@ -3029,14 +3046,16 @@ func GetTCPSendBufferLimits(s tcpip.StackHandler) tcpip.SendBufferSizeOption { // allowOutOfWindowAck returns true if an out-of-window ACK can be sent now. func (e *endpoint) allowOutOfWindowAck() bool { - var limit stack.TCPInvalidRateLimitOption - if err := e.stack.Option(&limit); err != nil { - panic(fmt.Sprintf("e.stack.Option(%+v) failed with error: %s", limit, err)) - } + now := e.stack.Clock().NowMonotonic() - now := time.Now() - if now.Sub(e.lastOutOfWindowAckTime) < time.Duration(limit) { - return false + if e.lastOutOfWindowAckTime != (tcpip.MonotonicTime{}) { + var limit stack.TCPInvalidRateLimitOption + if err := e.stack.Option(&limit); err != nil { + panic(fmt.Sprintf("e.stack.Option(%+v) failed with error: %s", limit, err)) + } + if now.Sub(e.lastOutOfWindowAckTime) < time.Duration(limit) { + return false + } } e.lastOutOfWindowAckTime = now diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 6e9777fe4..952ccacdd 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -154,20 +154,25 @@ func (e *endpoint) afterLoad() { e.origEndpointState = e.state // Restore the endpoint to InitialState as it will be moved to // its origEndpointState during Resume. - e.state = StateInitial + e.state = uint32(StateInitial) // Condition variables and mutexs are not S/R'ed so reinitialize // acceptCond with e.acceptMu. e.acceptCond = sync.NewCond(&e.acceptMu) - e.keepalive.timer.init(&e.keepalive.waker) stack.StackFromEnv.RegisterRestoredEndpoint(e) } // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { + e.keepalive.timer.init(s.Clock(), &e.keepalive.waker) + if snd := e.snd; snd != nil { + snd.resendTimer.init(s.Clock(), &snd.resendWaker) + snd.reorderTimer.init(s.Clock(), &snd.reorderWaker) + snd.probeTimer.init(s.Clock(), &snd.probeWaker) + } e.stack = s e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits, GetTCPReceiveBufferLimits) e.segmentQueue.thaw() - epState := e.origEndpointState + epState := EndpointState(e.origEndpointState) switch epState { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: var ss tcpip.TCPSendBufferSizeRangeOption @@ -281,32 +286,12 @@ func (e *endpoint) Resume(s *stack.Stack) { }() case epState == StateClose: e.isPortReserved = false - e.state = StateClose + e.state = uint32(StateClose) e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) case epState == StateError: - e.state = StateError + e.state = uint32(StateError) e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) } } - -// saveRecentTSTime is invoked by stateify. -func (e *endpoint) saveRecentTSTime() unixTime { - return unixTime{e.recentTSTime.Unix(), e.recentTSTime.UnixNano()} -} - -// loadRecentTSTime is invoked by stateify. -func (e *endpoint) loadRecentTSTime(unix unixTime) { - e.recentTSTime = time.Unix(unix.second, unix.nano) -} - -// saveLastOutOfWindowAckTime is invoked by stateify. -func (e *endpoint) saveLastOutOfWindowAckTime() unixTime { - return unixTime{e.lastOutOfWindowAckTime.Unix(), e.lastOutOfWindowAckTime.UnixNano()} -} - -// loadLastOutOfWindowAckTime is invoked by stateify. -func (e *endpoint) loadLastOutOfWindowAckTime(unix unixTime) { - e.lastOutOfWindowAckTime = time.Unix(unix.second, unix.nano) -} diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 2f9fe7ee0..65c86823a 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -65,7 +65,7 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. func (f *Forwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { - s := newIncomingSegment(id, pkt) + s := newIncomingSegment(id, f.stack.Clock(), pkt) defer s.decRef() // We only care about well-formed SYN packets. diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index a3d1aa1a3..2fc282e73 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -131,7 +131,7 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err tcpip.Error) { // goroutine which is responsible for dequeuing and doing full TCP dispatch of // the packet. func (p *protocol) QueuePacket(ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { - p.dispatcher.queuePacket(ep, id, pkt) + p.dispatcher.queuePacket(ep, id, p.stack.Clock(), pkt) } // HandleUnknownDestinationPacket handles packets targeted at this protocol but @@ -142,14 +142,14 @@ func (p *protocol) QueuePacket(ep stack.TransportEndpoint, id stack.TransportEnd // particular, SYNs addressed to a non-existent connection are rejected by this // means." func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { - s := newIncomingSegment(id, pkt) + s := newIncomingSegment(id, p.stack.Clock(), pkt) defer s.decRef() if !s.parse(pkt.RXTransportChecksumValidated) || !s.csumValid { return stack.UnknownDestinationPacketMalformed } - if !s.flagIsSet(header.TCPFlagRst) { + if !s.flags.Contains(header.TCPFlagRst) { replyWithReset(p.stack, s, stack.DefaultTOS, 0) } @@ -181,7 +181,7 @@ func replyWithReset(st *stack.Stack, s *segment, tos, ttl uint8) tcpip.Error { // reset has sequence number zero and the ACK field is set to the sum // of the sequence number and segment length of the incoming segment. // The connection remains in the CLOSED state. - if s.flagIsSet(header.TCPFlagAck) { + if s.flags.Contains(header.TCPFlagAck) { seq = s.ackNumber } else { flags |= header.TCPFlagAck @@ -401,7 +401,7 @@ func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) tcpip.Er case *tcpip.TCPTimeWaitReuseOption: p.mu.RLock() - *v = tcpip.TCPTimeWaitReuseOption(p.timeWaitReuse) + *v = p.timeWaitReuse p.mu.RUnlock() return nil @@ -481,6 +481,6 @@ func NewProtocol(s *stack.Stack) stack.TransportProtocol { // TODO(gvisor.dev/issue/5243): Set recovery to tcpip.TCPRACKLossDetection. recovery: 0, } - p.dispatcher.init(runtime.GOMAXPROCS(0)) + p.dispatcher.init(s.Rand(), runtime.GOMAXPROCS(0)) return &p } diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go index 9e332dcf7..0da4eafaa 100644 --- a/pkg/tcpip/transport/tcp/rack.go +++ b/pkg/tcpip/transport/tcp/rack.go @@ -79,7 +79,7 @@ func (rc *rackControl) init(snd *sender, iss seqnum.Value) { // update will update the RACK related fields when an ACK has been received. // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-09#section-6.2 func (rc *rackControl) update(seg *segment, ackSeg *segment) { - rtt := time.Now().Sub(seg.xmitTime) + rtt := rc.snd.ep.stack.Clock().NowMonotonic().Sub(seg.xmitTime) tsOffset := rc.snd.ep.TSOffset // If the ACK is for a retransmitted packet, do not update if it is a @@ -115,7 +115,7 @@ func (rc *rackControl) update(seg *segment, ackSeg *segment) { // ending sequence number of the packet which has been acknowledged // most recently. endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - if rc.XmitTime.Before(seg.xmitTime) || (seg.xmitTime.Equal(rc.XmitTime) && rc.EndSequence.LessThan(endSeq)) { + if rc.XmitTime.Before(seg.xmitTime) || (seg.xmitTime == rc.XmitTime && rc.EndSequence.LessThan(endSeq)) { rc.XmitTime = seg.xmitTime rc.EndSequence = endSeq } @@ -174,7 +174,7 @@ func (s *sender) schedulePTO() { } s.rtt.Unlock() - now := time.Now() + now := s.ep.stack.Clock().NowMonotonic() if s.resendTimer.enabled() { if now.Add(pto).After(s.resendTimer.target) { pto = s.resendTimer.target.Sub(now) @@ -279,7 +279,7 @@ func (s *sender) detectTLPRecovery(ack seqnum.Value, rcvdSeg *segment) { // been observed RACK uses reo_wnd of zero during loss recovery, in order to // retransmit quickly, or when the number of DUPACKs exceeds the classic // DUPACKthreshold. -func (rc *rackControl) updateRACKReorderWindow(ackSeg *segment) { +func (rc *rackControl) updateRACKReorderWindow() { dsackSeen := rc.DSACKSeen snd := rc.snd @@ -352,7 +352,7 @@ func (rc *rackControl) exitRecovery() { // detectLoss marks the segment as lost if the reordering window has elapsed // and the ACK is not received. It will also arm the reorder timer. // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 Step 5. -func (rc *rackControl) detectLoss(rcvTime time.Time) int { +func (rc *rackControl) detectLoss(rcvTime tcpip.MonotonicTime) int { var timeout time.Duration numLost := 0 for seg := rc.snd.writeList.Front(); seg != nil && seg.xmitCount != 0; seg = seg.Next() { @@ -366,7 +366,7 @@ func (rc *rackControl) detectLoss(rcvTime time.Time) int { } endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - if seg.xmitTime.Before(rc.XmitTime) || (seg.xmitTime.Equal(rc.XmitTime) && rc.EndSequence.LessThan(endSeq)) { + if seg.xmitTime.Before(rc.XmitTime) || (seg.xmitTime == rc.XmitTime && rc.EndSequence.LessThan(endSeq)) { timeRemaining := seg.xmitTime.Sub(rcvTime) + rc.RTT + rc.ReoWnd if timeRemaining <= 0 { seg.lost = true @@ -392,7 +392,7 @@ func (rc *rackControl) reorderTimerExpired() tcpip.Error { return nil } - numLost := rc.detectLoss(time.Now()) + numLost := rc.detectLoss(rc.snd.ep.stack.Clock().NowMonotonic()) if numLost == 0 { return nil } diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index 133371455..9ce8fcae9 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -17,7 +17,6 @@ package tcp import ( "container/heap" "math" - "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -50,7 +49,7 @@ type receiver struct { pendingRcvdSegments segmentHeap // Time when the last ack was received. - lastRcvdAckTime time.Time `state:".(unixTime)"` + lastRcvdAckTime tcpip.MonotonicTime } func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver { @@ -63,7 +62,7 @@ func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale }, rcvWnd: rcvWnd, rcvWUP: irs + 1, - lastRcvdAckTime: time.Now(), + lastRcvdAckTime: ep.stack.Clock().NowMonotonic(), } } @@ -137,9 +136,9 @@ func (r *receiver) getSendParams() (RcvNxt seqnum.Value, rcvWnd seqnum.Size) { // rcvWUP RcvNxt RcvAcc new RcvAcc // <=====curWnd ===> // <========= newWnd > curWnd ========= > - if r.RcvNxt.Add(seqnum.Size(curWnd)).LessThan(r.RcvNxt.Add(seqnum.Size(newWnd))) && toGrow { + if r.RcvNxt.Add(curWnd).LessThan(r.RcvNxt.Add(newWnd)) && toGrow { // If the new window moves the right edge, then update RcvAcc. - r.RcvAcc = r.RcvNxt.Add(seqnum.Size(newWnd)) + r.RcvAcc = r.RcvNxt.Add(newWnd) } else { if newWnd == 0 { // newWnd is zero but we can't advertise a zero as it would cause window @@ -245,7 +244,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum TrimSACKBlockList(&r.ep.sack, r.RcvNxt) // Handle FIN or FIN-ACK. - if s.flagIsSet(header.TCPFlagFin) { + if s.flags.Contains(header.TCPFlagFin) { r.RcvNxt++ // Send ACK immediately. @@ -261,7 +260,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum case StateEstablished: r.ep.setEndpointState(StateCloseWait) case StateFinWait1: - if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.SndNxt { + if s.flags.Contains(header.TCPFlagAck) && s.ackNumber == r.ep.snd.SndNxt { // FIN-ACK, transition to TIME-WAIT. r.ep.setEndpointState(StateTimeWait) } else { @@ -296,7 +295,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // Handle ACK (not FIN-ACK, which we handled above) during one of the // shutdown states. - if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.SndNxt { + if s.flags.Contains(header.TCPFlagAck) && s.ackNumber == r.ep.snd.SndNxt { switch r.ep.EndpointState() { case StateFinWait1: r.ep.setEndpointState(StateFinWait2) @@ -325,9 +324,9 @@ func (r *receiver) updateRTT() { // is first acknowledged and the receipt of data that is at least one // window beyond the sequence number that was acknowledged. r.ep.rcvQueueInfo.rcvQueueMu.Lock() - if r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime.IsZero() { + if r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime == (tcpip.MonotonicTime{}) { // New measurement. - r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime = time.Now() + r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime = r.ep.stack.Clock().NowMonotonic() r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureSeqNumber = r.RcvNxt.Add(r.rcvWnd) r.ep.rcvQueueInfo.rcvQueueMu.Unlock() return @@ -336,14 +335,14 @@ func (r *receiver) updateRTT() { r.ep.rcvQueueInfo.rcvQueueMu.Unlock() return } - rtt := time.Since(r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime) + rtt := r.ep.stack.Clock().NowMonotonic().Sub(r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime) // We only store the minimum observed RTT here as this is only used in // absence of a SRTT available from either timestamps or a sender // measurement of RTT. if r.ep.rcvQueueInfo.RcvAutoParams.RTT == 0 || rtt < r.ep.rcvQueueInfo.RcvAutoParams.RTT { r.ep.rcvQueueInfo.RcvAutoParams.RTT = rtt } - r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime = time.Now() + r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime = r.ep.stack.Clock().NowMonotonic() r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureSeqNumber = r.RcvNxt.Add(r.rcvWnd) r.ep.rcvQueueInfo.rcvQueueMu.Unlock() } @@ -424,7 +423,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // while the FIN is considered to occur after // the last actual data octet in a segment in // which it occurs. - if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.RcvNxt+1) { + if closed && (!s.flags.Contains(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.RcvNxt+1) { return true, &tcpip.ErrConnectionAborted{} } } @@ -467,11 +466,11 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err tcpip.Error) { } // Store the time of the last ack. - r.lastRcvdAckTime = time.Now() + r.lastRcvdAckTime = r.ep.stack.Clock().NowMonotonic() // Defer segment processing if it can't be consumed now. if !r.consumeSegment(s, segSeq, segLen) { - if segLen > 0 || s.flagIsSet(header.TCPFlagFin) { + if segLen > 0 || s.flags.Contains(header.TCPFlagFin) { // We only store the segment if it's within our buffer size limit. // // Only use 75% of the receive buffer queue for out-of-order @@ -539,7 +538,7 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn // // As we do not yet support PAWS, we are being conservative in ignoring // RSTs by default. - if s.flagIsSet(header.TCPFlagRst) { + if s.flags.Contains(header.TCPFlagRst) { return false, false } @@ -559,13 +558,12 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn // (2) returns to TIME-WAIT state if the SYN turns out // to be an old duplicate". - if s.flagIsSet(header.TCPFlagSyn) && r.RcvNxt.LessThan(segSeq) { - + if s.flags.Contains(header.TCPFlagSyn) && r.RcvNxt.LessThan(segSeq) { return false, true } // Drop the segment if it does not contain an ACK. - if !s.flagIsSet(header.TCPFlagAck) { + if !s.flags.Contains(header.TCPFlagAck) { return false, false } @@ -574,7 +572,7 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn r.ep.updateRecentTimestamp(s.parsedOptions.TSVal, r.ep.snd.MaxSentAck, segSeq) } - if segSeq.Add(1) == r.RcvNxt && s.flagIsSet(header.TCPFlagFin) { + if segSeq.Add(1) == r.RcvNxt && s.flags.Contains(header.TCPFlagFin) { // If it's a FIN-ACK then resetTimeWait and send an ACK, as it // indicates our final ACK could have been lost. r.ep.snd.sendAck() diff --git a/pkg/tcpip/transport/tcp/rcv_state.go b/pkg/tcpip/transport/tcp/rcv_state.go deleted file mode 100644 index 2bf21a2e7..000000000 --- a/pkg/tcpip/transport/tcp/rcv_state.go +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp - -import ( - "time" -) - -// saveLastRcvdAckTime is invoked by stateify. -func (r *receiver) saveLastRcvdAckTime() unixTime { - return unixTime{r.lastRcvdAckTime.Unix(), r.lastRcvdAckTime.UnixNano()} -} - -// loadLastRcvdAckTime is invoked by stateify. -func (r *receiver) loadLastRcvdAckTime(unix unixTime) { - r.lastRcvdAckTime = time.Unix(unix.second, unix.nano) -} diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 7e5ba6ef7..ca78c96f2 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -17,7 +17,6 @@ package tcp import ( "fmt" "sync/atomic" - "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -73,9 +72,9 @@ type segment struct { parsedOptions header.TCPOptions options []byte `state:".([]byte)"` hasNewSACKInfo bool - rcvdTime time.Time `state:".(unixTime)"` + rcvdTime tcpip.MonotonicTime // xmitTime is the last transmit time of this segment. - xmitTime time.Time `state:".(unixTime)"` + xmitTime tcpip.MonotonicTime xmitCount uint32 // acked indicates if the segment has already been SACKed. @@ -88,7 +87,7 @@ type segment struct { lost bool } -func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment { +func newIncomingSegment(id stack.TransportEndpointID, clock tcpip.Clock, pkt *stack.PacketBuffer) *segment { netHdr := pkt.Network() s := &segment{ refCnt: 1, @@ -100,17 +99,17 @@ func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) * } s.data = pkt.Data().ExtractVV().Clone(s.views[:]) s.hdr = header.TCP(pkt.TransportHeader().View()) - s.rcvdTime = time.Now() + s.rcvdTime = clock.NowMonotonic() s.dataMemSize = s.data.Size() return s } -func newOutgoingSegment(id stack.TransportEndpointID, v buffer.View) *segment { +func newOutgoingSegment(id stack.TransportEndpointID, clock tcpip.Clock, v buffer.View) *segment { s := &segment{ refCnt: 1, id: id, } - s.rcvdTime = time.Now() + s.rcvdTime = clock.NowMonotonic() if len(v) != 0 { s.views[0] = v s.data = buffer.NewVectorisedView(len(v), s.views[:1]) @@ -149,16 +148,6 @@ func (s *segment) merge(oth *segment) { oth.dataMemSize = oth.data.Size() } -// flagIsSet checks if at least one flag in flags is set in s.flags. -func (s *segment) flagIsSet(flags header.TCPFlags) bool { - return s.flags&flags != 0 -} - -// flagsAreSet checks if all flags in flags are set in s.flags. -func (s *segment) flagsAreSet(flags header.TCPFlags) bool { - return s.flags&flags == flags -} - // setOwner sets the owning endpoint for this segment. Its required // to be called to ensure memory accounting for receive/send buffer // queues is done properly. @@ -198,10 +187,10 @@ func (s *segment) incRef() { // as the data length plus one for each of the SYN and FIN bits set. func (s *segment) logicalLen() seqnum.Size { l := seqnum.Size(s.data.Size()) - if s.flagIsSet(header.TCPFlagSyn) { + if s.flags.Contains(header.TCPFlagSyn) { l++ } - if s.flagIsSet(header.TCPFlagFin) { + if s.flags.Contains(header.TCPFlagFin) { l++ } return l @@ -243,7 +232,7 @@ func (s *segment) parse(skipChecksumValidation bool) bool { return false } - s.options = []byte(s.hdr[header.TCPMinimumSize:]) + s.options = s.hdr[header.TCPMinimumSize:] s.parsedOptions = header.ParseTCPOptions(s.options) if skipChecksumValidation { s.csumValid = true @@ -262,5 +251,5 @@ func (s *segment) parse(skipChecksumValidation bool) bool { // sackBlock returns a header.SACKBlock that represents this segment. func (s *segment) sackBlock() header.SACKBlock { - return header.SACKBlock{s.sequenceNumber, s.sequenceNumber.Add(s.logicalLen())} + return header.SACKBlock{Start: s.sequenceNumber, End: s.sequenceNumber.Add(s.logicalLen())} } diff --git a/pkg/tcpip/transport/tcp/segment_state.go b/pkg/tcpip/transport/tcp/segment_state.go index 7422d8c02..dcfa80f95 100644 --- a/pkg/tcpip/transport/tcp/segment_state.go +++ b/pkg/tcpip/transport/tcp/segment_state.go @@ -15,8 +15,6 @@ package tcp import ( - "time" - "gvisor.dev/gvisor/pkg/tcpip/buffer" ) @@ -55,23 +53,3 @@ func (s *segment) loadOptions(options []byte) { // allocated so there is no cost here. s.options = options } - -// saveRcvdTime is invoked by stateify. -func (s *segment) saveRcvdTime() unixTime { - return unixTime{s.rcvdTime.Unix(), s.rcvdTime.UnixNano()} -} - -// loadRcvdTime is invoked by stateify. -func (s *segment) loadRcvdTime(unix unixTime) { - s.rcvdTime = time.Unix(unix.second, unix.nano) -} - -// saveXmitTime is invoked by stateify. -func (s *segment) saveXmitTime() unixTime { - return unixTime{s.rcvdTime.Unix(), s.rcvdTime.UnixNano()} -} - -// loadXmitTime is invoked by stateify. -func (s *segment) loadXmitTime(unix unixTime) { - s.rcvdTime = time.Unix(unix.second, unix.nano) -} diff --git a/pkg/tcpip/transport/tcp/segment_test.go b/pkg/tcpip/transport/tcp/segment_test.go index 486016fc0..2e6ea06f5 100644 --- a/pkg/tcpip/transport/tcp/segment_test.go +++ b/pkg/tcpip/transport/tcp/segment_test.go @@ -19,6 +19,7 @@ import ( "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -39,10 +40,11 @@ func checkSegmentSize(t *testing.T, name string, seg *segment, want segmentSizeW } func TestSegmentMerge(t *testing.T) { + var clock faketime.NullClock id := stack.TransportEndpointID{} - seg1 := newOutgoingSegment(id, buffer.NewView(10)) + seg1 := newOutgoingSegment(id, &clock, buffer.NewView(10)) defer seg1.decRef() - seg2 := newOutgoingSegment(id, buffer.NewView(20)) + seg2 := newOutgoingSegment(id, &clock, buffer.NewView(20)) defer seg2.decRef() checkSegmentSize(t, "seg1", seg1, segmentSizeWants{ diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index f43e86677..72d58dcff 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -94,7 +94,7 @@ type sender struct { // firstRetransmittedSegXmitTime is the original transmit time of // the first segment that was retransmitted due to RTO expiration. - firstRetransmittedSegXmitTime time.Time `state:".(unixTime)"` + firstRetransmittedSegXmitTime tcpip.MonotonicTime // zeroWindowProbing is set if the sender is currently probing // for zero receive window. @@ -169,7 +169,7 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint SndUna: iss + 1, SndNxt: iss + 1, RTTMeasureSeqNum: iss + 1, - LastSendTime: time.Now(), + LastSendTime: ep.stack.Clock().NowMonotonic(), MaxPayloadSize: maxPayloadSize, MaxSentAck: irs + 1, FastRecovery: stack.TCPFastRecoveryState{ @@ -197,9 +197,9 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint s.SndWndScale = uint8(sndWndScale) } - s.resendTimer.init(&s.resendWaker) - s.reorderTimer.init(&s.reorderWaker) - s.probeTimer.init(&s.probeWaker) + s.resendTimer.init(s.ep.stack.Clock(), &s.resendWaker) + s.reorderTimer.init(s.ep.stack.Clock(), &s.reorderWaker) + s.probeTimer.init(s.ep.stack.Clock(), &s.probeWaker) s.updateMaxPayloadSize(int(ep.route.MTU()), 0) @@ -441,7 +441,7 @@ func (s *sender) retransmitTimerExpired() bool { // timeout since the first retransmission. uto := s.ep.userTimeout - if s.firstRetransmittedSegXmitTime.IsZero() { + if s.firstRetransmittedSegXmitTime == (tcpip.MonotonicTime{}) { // We store the original xmitTime of the segment that we are // about to retransmit as the retransmission time. This is // required as by the time the retransmitTimer has expired the @@ -450,7 +450,7 @@ func (s *sender) retransmitTimerExpired() bool { s.firstRetransmittedSegXmitTime = s.writeList.Front().xmitTime } - elapsed := time.Since(s.firstRetransmittedSegXmitTime) + elapsed := s.ep.stack.Clock().NowMonotonic().Sub(s.firstRetransmittedSegXmitTime) remaining := s.maxRTO if uto != 0 { // Cap to the user specified timeout if one is specified. @@ -616,7 +616,7 @@ func (s *sender) NextSeg(nextSegHint *segment) (nextSeg, hint *segment, rescueRt // 'S2' that meets the following 3 criteria for determinig // loss, the sequence range of one segment of up to SMSS // octects starting with S2 MUST be returned. - if !s.ep.scoreboard.IsSACKED(header.SACKBlock{segSeq, segSeq.Add(1)}) { + if !s.ep.scoreboard.IsSACKED(header.SACKBlock{Start: segSeq, End: segSeq.Add(1)}) { // NextSeg(): // // (1.a) S2 is greater than HighRxt @@ -866,8 +866,8 @@ func (s *sender) enableZeroWindowProbing() { // We piggyback the probing on the retransmit timer with the // current retranmission interval, as we may start probing while // segment retransmissions. - if s.firstRetransmittedSegXmitTime.IsZero() { - s.firstRetransmittedSegXmitTime = time.Now() + if s.firstRetransmittedSegXmitTime == (tcpip.MonotonicTime{}) { + s.firstRetransmittedSegXmitTime = s.ep.stack.Clock().NowMonotonic() } s.resendTimer.enable(s.RTO) } @@ -875,7 +875,7 @@ func (s *sender) enableZeroWindowProbing() { func (s *sender) disableZeroWindowProbing() { s.zeroWindowProbing = false s.unackZeroWindowProbes = 0 - s.firstRetransmittedSegXmitTime = time.Time{} + s.firstRetransmittedSegXmitTime = tcpip.MonotonicTime{} s.resendTimer.disable() } @@ -925,7 +925,7 @@ func (s *sender) sendData() { // "A TCP SHOULD set cwnd to no more than RW before beginning // transmission if the TCP has not sent data in the interval exceeding // the retrasmission timeout." - if !s.FastRecovery.Active && s.state != tcpip.RTORecovery && time.Now().Sub(s.LastSendTime) > s.RTO { + if !s.FastRecovery.Active && s.state != tcpip.RTORecovery && s.ep.stack.Clock().NowMonotonic().Sub(s.LastSendTime) > s.RTO { if s.SndCwnd > InitialCwnd { s.SndCwnd = InitialCwnd } @@ -1024,7 +1024,7 @@ func (s *sender) SetPipe() { if segEnd.LessThan(endSeq) { endSeq = segEnd } - sb := header.SACKBlock{startSeq, endSeq} + sb := header.SACKBlock{Start: startSeq, End: endSeq} // SetPipe(): // // After initializing pipe to zero, the following steps are @@ -1132,7 +1132,7 @@ func (s *sender) isDupAck(seg *segment) bool { // (b) The incoming acknowledgment carries no data. seg.logicalLen() == 0 && // (c) The SYN and FIN bits are both off. - !seg.flagIsSet(header.TCPFlagFin) && !seg.flagIsSet(header.TCPFlagSyn) && + !seg.flags.Intersects(header.TCPFlagFin|header.TCPFlagSyn) && // (d) the ACK number is equal to the greatest acknowledgment received on // the given connection (TCP.UNA from RFC793). seg.ackNumber == s.SndUna && @@ -1234,7 +1234,7 @@ func checkDSACK(rcvdSeg *segment) bool { func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // Check if we can extract an RTT measurement from this ack. if !rcvdSeg.parsedOptions.TS && s.RTTMeasureSeqNum.LessThan(rcvdSeg.ackNumber) { - s.updateRTO(time.Now().Sub(s.RTTMeasureTime)) + s.updateRTO(s.ep.stack.Clock().NowMonotonic().Sub(s.RTTMeasureTime)) s.RTTMeasureSeqNum = s.SndNxt } @@ -1444,7 +1444,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { if s.SndUna == s.SndNxt { s.Outstanding = 0 // Reset firstRetransmittedSegXmitTime to the zero value. - s.firstRetransmittedSegXmitTime = time.Time{} + s.firstRetransmittedSegXmitTime = tcpip.MonotonicTime{} s.resendTimer.disable() s.probeTimer.disable() } @@ -1455,7 +1455,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 // * Upon receiving an ACK: // * Step 4: Update RACK reordering window - s.rc.updateRACKReorderWindow(rcvdSeg) + s.rc.updateRACKReorderWindow() // After the reorder window is calculated, detect any loss by checking // if the time elapsed after the segments are sent is greater than the @@ -1502,7 +1502,7 @@ func (s *sender) sendSegment(seg *segment) tcpip.Error { s.ep.stack.Stats().TCP.SlowStartRetransmits.Increment() } } - seg.xmitTime = time.Now() + seg.xmitTime = s.ep.stack.Clock().NowMonotonic() seg.xmitCount++ seg.lost = false err := s.sendSegmentFromView(seg.data, seg.flags, seg.sequenceNumber) @@ -1527,7 +1527,7 @@ func (s *sender) sendSegment(seg *segment) tcpip.Error { // sendSegmentFromView sends a new segment containing the given payload, flags // and sequence number. func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags header.TCPFlags, seq seqnum.Value) tcpip.Error { - s.LastSendTime = time.Now() + s.LastSendTime = s.ep.stack.Clock().NowMonotonic() if seq == s.RTTMeasureSeqNum { s.RTTMeasureTime = s.LastSendTime } diff --git a/pkg/tcpip/transport/tcp/snd_state.go b/pkg/tcpip/transport/tcp/snd_state.go deleted file mode 100644 index 2f805d8ce..000000000 --- a/pkg/tcpip/transport/tcp/snd_state.go +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp - -import ( - "time" -) - -// +stateify savable -type unixTime struct { - second int64 - nano int64 -} - -// afterLoad is invoked by stateify. -func (s *sender) afterLoad() { - s.resendTimer.init(&s.resendWaker) - s.reorderTimer.init(&s.reorderWaker) - s.probeTimer.init(&s.probeWaker) -} - -// saveFirstRetransmittedSegXmitTime is invoked by stateify. -func (s *sender) saveFirstRetransmittedSegXmitTime() unixTime { - return unixTime{s.firstRetransmittedSegXmitTime.Unix(), s.firstRetransmittedSegXmitTime.UnixNano()} -} - -// loadFirstRetransmittedSegXmitTime is invoked by stateify. -func (s *sender) loadFirstRetransmittedSegXmitTime(unix unixTime) { - s.firstRetransmittedSegXmitTime = time.Unix(unix.second, unix.nano) -} diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go index c58361bc1..d6cf786a1 100644 --- a/pkg/tcpip/transport/tcp/tcp_rack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go @@ -38,7 +38,7 @@ const ( func setStackRACKPermitted(t *testing.T, c *context.Context) { t.Helper() - opt := tcpip.TCPRecovery(tcpip.TCPRACKLossDetection) + opt := tcpip.TCPRACKLossDetection if err := c.Stack().SetTransportProtocolOption(header.TCPProtocolNumber, &opt); err != nil { t.Fatalf("c.s.SetTransportProtocolOption(%d, &%v(%v)): %s", header.TCPProtocolNumber, opt, opt, err) } @@ -50,7 +50,7 @@ func TestRACKUpdate(t *testing.T) { c := context.New(t, uint32(mtu)) defer c.Cleanup() - var xmitTime time.Time + var xmitTime tcpip.MonotonicTime probeDone := make(chan struct{}) c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { // Validate that the endpoint Sender.RACKState is what we expect. @@ -79,7 +79,7 @@ func TestRACKUpdate(t *testing.T) { } // Write the data. - xmitTime = time.Now() + xmitTime = c.Stack().Clock().NowMonotonic() var r bytes.Reader r.Reset(data) if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 9916182e3..71c4aa85d 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -3451,17 +3451,13 @@ loop: for { switch _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}); err.(type) { case *tcpip.ErrWouldBlock: - select { - case <-ch: - // Expect the state to be StateError and subsequent Reads to fail with HardError. - _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) - if d := cmp.Diff(&tcpip.ErrConnectionReset{}, err); d != "" { - t.Fatalf("c.EP.Read() mismatch (-want +got):\n%s", d) - } - break loop - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for reset to arrive") + <-ch + // Expect the state to be StateError and subsequent Reads to fail with HardError. + _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) + if d := cmp.Diff(&tcpip.ErrConnectionReset{}, err); d != "" { + t.Fatalf("c.EP.Read() mismatch (-want +got):\n%s", d) } + break loop case *tcpip.ErrConnectionReset: break loop default: @@ -3472,14 +3468,27 @@ loop: if tcp.EndpointState(c.EP.State()) != tcp.StateError { t.Fatalf("got EP state is not StateError") } - if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 { - t.Errorf("got stats.TCP.EstablishedResets.Value() = %d, want = 1", got) + + checkValid := func() []error { + var errors []error + if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 { + errors = append(errors, fmt.Errorf("got stats.TCP.EstablishedResets.Value() = %d, want = 1", got)) + } + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + errors = append(errors, fmt.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)) + } + if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { + errors = append(errors, fmt.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)) + } + return errors } - if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) + + start := time.Now() + for time.Since(start) < time.Minute && len(checkValid()) > 0 { + time.Sleep(50 * time.Millisecond) } - if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { - t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) + for _, err := range checkValid() { + t.Error(err) } } @@ -4259,7 +4268,7 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(iss), + SeqNum: iss, AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -5335,7 +5344,7 @@ func TestKeepalive(t *testing.T) { checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(next-1)), + checker.TCPSeqNum(next-1), checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck), ), @@ -5360,12 +5369,7 @@ func TestKeepalive(t *testing.T) { }) checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(next)), - checker.TCPAckNum(uint32(0)), - checker.TCPFlags(header.TCPFlagRst), - ), + checker.TCP(checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(uint32(0)), checker.TCPFlags(header.TCPFlagRst)), ) if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 { @@ -5507,7 +5511,7 @@ func TestListenBacklogFull(t *testing.T) { // Now execute send one more SYN. The stack should not respond as the backlog // is full at this point. c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort + uint16(lastPortOffset), + SrcPort: context.TestPort + lastPortOffset, DstPort: context.StackPort, Flags: header.TCPFlagSyn, SeqNum: seqnum.Value(context.TestInitialSequenceNumber), @@ -5884,7 +5888,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) { r.Reset(data) newEP.Write(&r, tcpip.WriteOptions{}) pkt := c.GetPacket() - tcp = header.TCP(header.IPv4(pkt).Payload()) + tcp = header.IPv4(pkt).Payload() if string(tcp.Payload()) != data { t.Fatalf("unexpected data: got %s, want %s", string(tcp.Payload()), data) } @@ -6073,6 +6077,11 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { // complete the connection to test that the large SEQ num // did not change the state from SYN-RCVD. + // Get setup to be notified about connection establishment. + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.ReadableEvents) + defer c.WQ.EventUnregister(&we) + // Send ACK to move to ESTABLISHED state. c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, @@ -6083,32 +6092,12 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { RcvWnd: 30000, }) + <-ch newEP, _, err := c.EP.Accept(nil) - switch err.(type) { - case nil, *tcpip.ErrWouldBlock: - default: + if err != nil { t.Fatalf("Accept failed: %s", err) } - if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { - // Try to accept the connections in the backlog. - we, ch := waiter.NewChannelEntry(nil) - c.WQ.EventRegister(&we, waiter.ReadableEvents) - defer c.WQ.EventUnregister(&we) - - // Wait for connection to be established. - select { - case <-ch: - newEP, _, err = c.EP.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") - } - } - // Now verify that the TCP socket is usable and in a connected state. data := "Don't panic" var r strings.Reader @@ -6118,7 +6107,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { } pkt := c.GetPacket() - tcpHdr = header.TCP(header.IPv4(pkt).Payload()) + tcpHdr = header.IPv4(pkt).Payload() if string(tcpHdr.Payload()) != data { t.Fatalf("unexpected data: got %s, want %s", string(tcpHdr.Payload()), data) } @@ -6214,12 +6203,26 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { RcvWnd: 30000, }) - time.Sleep(50 * time.Millisecond) - if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want { - t.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %d, want = %d", got, want) + checkValid := func() []error { + var errors []error + if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want { + errors = append(errors, fmt.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %d, want = %d", got, want)) + } + if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want { + errors = append(errors, fmt.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %d, want = %d", got, want)) + } + return errors } - if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want { - t.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %d, want = %d", got, want) + + start := time.Now() + for time.Since(start) < time.Minute && len(checkValid()) > 0 { + time.Sleep(50 * time.Millisecond) + } + for _, err := range checkValid() { + t.Error(err) + } + if t.Failed() { + t.FailNow() } we, ch := waiter.NewChannelEntry(nil) @@ -6230,19 +6233,62 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { _, _, err = c.EP.Accept(nil) if cmp.Equal(&tcpip.ErrWouldBlock{}, err) { // Wait for connection to be established. - select { - case <-ch: - _, _, err = c.EP.Accept(nil) - if err != nil { - t.Fatalf("Accept failed: %s", err) - } - - case <-time.After(1 * time.Second): - t.Fatalf("Timed out waiting for accept") + <-ch + _, _, err = c.EP.Accept(nil) + if err != nil { + t.Fatalf("Accept failed: %s", err) } } } +func TestListenDropIncrement(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + stats := c.Stack().Stats() + c.Create(-1 /*epRcvBuf*/) + + if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + if err := c.EP.Listen(1 /*backlog*/); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + initialDropped := stats.DroppedPackets.Value() + + // Send RST, FIN segments, that are expected to be dropped by the listener. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagRst, + }) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagFin, + }) + + // To ensure that the RST, FIN sent earlier are indeed received and ignored + // by the listener, send a SYN and wait for the SYN to be ACKd. + irs := seqnum.Value(context.TestInitialSequenceNumber) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + }) + checker.IPv4(t, c.GetPacket(), checker.TCP(checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), + checker.TCPAckNum(uint32(irs)+1), + )) + + if got, want := stats.DroppedPackets.Value(), initialDropped+2; got != want { + t.Fatalf("got stats.DroppedPackets.Value() = %d, want = %d", got, want) + } +} + func TestEndpointBindListenAcceptState(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -6375,7 +6421,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // Allocate a large enough payload for the test. payloadSize := receiveBufferSize * 2 - b := make([]byte, int(payloadSize)) + b := make([]byte, payloadSize) worker := (c.EP).(interface { StopWork() @@ -6429,7 +6475,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // ack, 1 for the non-zero window p := c.GetPacket() checker.IPv4(t, p, checker.TCP( - checker.TCPAckNum(uint32(wantAckNum)), + checker.TCPAckNum(wantAckNum), func(t *testing.T, h header.Transport) { tcp, ok := h.(header.TCP) if !ok { @@ -6484,14 +6530,14 @@ func TestReceiveBufferAutoTuning(t *testing.T) { c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize)) rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4}) - tsVal := uint32(rawEP.TSVal) + tsVal := rawEP.TSVal rawEP.NextSeqNum-- rawEP.SendPacketWithTS(nil, tsVal) rawEP.NextSeqNum++ pkt := rawEP.VerifyAndReturnACKWithTS(tsVal) curRcvWnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.WindowScale scaleRcvWnd := func(rcvWnd int) uint16 { - return uint16(rcvWnd >> uint16(c.WindowScale)) + return uint16(rcvWnd >> c.WindowScale) } // Allocate a large array to send to the endpoint. b := make([]byte, receiveBufferSize*48) @@ -6619,19 +6665,16 @@ func TestDelayEnabled(t *testing.T) { defer c.Cleanup() checkDelayOption(t, c, false, false) // Delay is disabled by default. - for _, v := range []struct { - delayEnabled tcpip.TCPDelayEnabled - wantDelayOption bool - }{ - {delayEnabled: false, wantDelayOption: false}, - {delayEnabled: true, wantDelayOption: true}, - } { - c := context.New(t, defaultMTU) - defer c.Cleanup() - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &v.delayEnabled); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, v.delayEnabled, v.delayEnabled, err) - } - checkDelayOption(t, c, v.delayEnabled, v.wantDelayOption) + for _, delayEnabled := range []bool{false, true} { + t.Run(fmt.Sprintf("delayEnabled=%t", delayEnabled), func(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + opt := tcpip.TCPDelayEnabled(delayEnabled) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, delayEnabled, err) + } + checkDelayOption(t, c, opt, delayEnabled) + }) } } @@ -7042,7 +7085,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { // Receive the SYN-ACK reply. b = c.GetPacket() - tcpHdr = header.TCP(header.IPv4(b).Payload()) + tcpHdr = header.IPv4(b).Payload() c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) ackHeaders = &context.Headers{ @@ -7443,7 +7486,7 @@ func TestTCPUserTimeout(t *testing.T) { select { case <-notifyCh: case <-time.After(2 * initRTO): - t.Fatalf("connection still alive after %s, should have been closed after :%s", 2*initRTO, userTimeout) + t.Fatalf("connection still alive after %s, should have been closed after %s", 2*initRTO, userTimeout) } // No packet should be received as the connection should be silently @@ -7467,7 +7510,7 @@ func TestTCPUserTimeout(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(next)), + checker.TCPSeqNum(next), checker.TCPAckNum(uint32(0)), checker.TCPFlags(header.TCPFlagRst), ), @@ -7545,7 +7588,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { DstPort: c.Port, Flags: header.TCPFlagAck, SeqNum: iss, - AckNum: seqnum.Value(c.IRS + 1), + AckNum: c.IRS + 1, RcvWnd: 30000, }) diff --git a/pkg/tcpip/transport/tcp/timer.go b/pkg/tcpip/transport/tcp/timer.go index 38a335840..5645c772e 100644 --- a/pkg/tcpip/transport/tcp/timer.go +++ b/pkg/tcpip/transport/tcp/timer.go @@ -15,21 +15,29 @@ package tcp import ( + "math" "time" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/tcpip" ) type timerState int const ( + // The timer is disabled. timerStateDisabled timerState = iota + // The timer is enabled, but the clock timer may be set to an earlier + // expiration time due to a previous orphaned state. timerStateEnabled + // The timer is disabled, but the clock timer is enabled, which means that + // it will cause a spurious wakeup unless the timer is enabled before the + // clock timer fires. timerStateOrphaned ) // timer is a timer implementation that reduces the interactions with the -// runtime timer infrastructure by letting timers run (and potentially +// clock timer infrastructure by letting timers run (and potentially // eventually expire) even if they are stopped. It makes it cheaper to // disable/reenable timers at the expense of spurious wakes. This is useful for // cases when the same timer is disabled/reenabled repeatedly with relatively @@ -39,44 +47,37 @@ const ( // (currently at least 200ms), and get disabled when acks are received, and // reenabled when new pending segments are sent. // -// It is advantageous to avoid interacting with the runtime because it acquires +// It is advantageous to avoid interacting with the clock because it acquires // a global mutex and performs O(log n) operations, where n is the global number // of timers, whenever a timer is enabled or disabled, and may make a syscall. // // This struct is thread-compatible. type timer struct { - // state is the current state of the timer, it can be one of the - // following values: - // disabled - the timer is disabled. - // orphaned - the timer is disabled, but the runtime timer is - // enabled, which means that it will evetually cause a - // spurious wake (unless it gets enabled again before - // then). - // enabled - the timer is enabled, but the runtime timer may be set - // to an earlier expiration time due to a previous - // orphaned state. state timerState + clock tcpip.Clock + // target is the expiration time of the current timer. It is only // meaningful in the enabled state. - target time.Time + target tcpip.MonotonicTime - // runtimeTarget is the expiration time of the runtime timer. It is + // clockTarget is the expiration time of the clock timer. It is // meaningful in the enabled and orphaned states. - runtimeTarget time.Time + clockTarget tcpip.MonotonicTime - // timer is the runtime timer used to wait on. - timer *time.Timer + // timer is the clock timer used to wait on. + timer tcpip.Timer } // init initializes the timer. Once it expires, it the given waker will be // asserted. -func (t *timer) init(w *sleep.Waker) { +func (t *timer) init(clock tcpip.Clock, w *sleep.Waker) { t.state = timerStateDisabled + t.clock = clock - // Initialize a runtime timer that will assert the waker, then + // Initialize a clock timer that will assert the waker, then // immediately stop it. - t.timer = time.AfterFunc(time.Hour, func() { + t.timer = t.clock.AfterFunc(math.MaxInt64, func() { w.Assert() }) t.timer.Stop() @@ -106,9 +107,9 @@ func (t *timer) checkExpiration() bool { // The timer is enabled, but it may have expired early. Check if that's // the case, and if so, reset the runtime timer to the correct time. - now := time.Now() + now := t.clock.NowMonotonic() if now.Before(t.target) { - t.runtimeTarget = t.target + t.clockTarget = t.target t.timer.Reset(t.target.Sub(now)) return false } @@ -134,11 +135,11 @@ func (t *timer) enabled() bool { // enable enables the timer, programming the runtime timer if necessary. func (t *timer) enable(d time.Duration) { - t.target = time.Now().Add(d) + t.target = t.clock.NowMonotonic().Add(d) // Check if we need to set the runtime timer. - if t.state == timerStateDisabled || t.target.Before(t.runtimeTarget) { - t.runtimeTarget = t.target + if t.state == timerStateDisabled || t.target.Before(t.clockTarget) { + t.clockTarget = t.target t.timer.Reset(d) } diff --git a/pkg/tcpip/transport/tcp/timer_test.go b/pkg/tcpip/transport/tcp/timer_test.go index dbd6dff54..479752de7 100644 --- a/pkg/tcpip/transport/tcp/timer_test.go +++ b/pkg/tcpip/transport/tcp/timer_test.go @@ -19,6 +19,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/tcpip/faketime" ) func TestCleanup(t *testing.T) { @@ -27,9 +28,11 @@ func TestCleanup(t *testing.T) { isAssertedTimeoutSeconds = timerDurationSeconds + 1 ) + clock := faketime.NewManualClock() + tmr := timer{} w := sleep.Waker{} - tmr.init(&w) + tmr.init(clock, &w) tmr.enable(timerDurationSeconds * time.Second) tmr.cleanup() @@ -39,7 +42,7 @@ func TestCleanup(t *testing.T) { // The waker should not be asserted. for i := 0; i < isAssertedTimeoutSeconds; i++ { - time.Sleep(time.Second) + clock.Advance(time.Second) if w.IsAsserted() { t.Fatalf("waker asserted unexpectedly") } diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index dd5c910ae..cdc344ab7 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -49,6 +49,7 @@ go_test( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index f7dd50d35..def9d7186 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -17,6 +17,7 @@ package udp import ( "io" "sync/atomic" + "time" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -34,25 +35,26 @@ type udpPacket struct { destinationAddress tcpip.FullAddress packetInfo tcpip.IPPacketInfo data buffer.VectorisedView `state:".(buffer.VectorisedView)"` - timestamp int64 + receivedAt time.Time `state:".(int64)"` // tos stores either the receiveTOS or receiveTClass value. tos uint8 } // EndpointState represents the state of a UDP endpoint. -type EndpointState uint32 +type EndpointState tcpip.EndpointState // Endpoint states. Note that are represented in a netstack-specific manner and // may not be meaningful externally. Specifically, they need to be translated to // Linux's representation for these states if presented to userspace. const ( - StateInitial EndpointState = iota + _ EndpointState = iota + StateInitial StateBound StateConnected StateClosed ) -// String implements fmt.Stringer.String. +// String implements fmt.Stringer. func (s EndpointState) String() string { switch s { case StateInitial: @@ -98,7 +100,7 @@ type endpoint struct { mu sync.RWMutex `state:"nosave"` // state must be read/set using the EndpointState()/setEndpointState() // methods. - state EndpointState + state uint32 route *stack.Route `state:"manual"` dstPort uint16 ttl uint8 @@ -176,7 +178,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue // Linux defaults to TTL=1. multicastTTL: 1, multicastMemberships: make(map[multicastMembership]struct{}), - state: StateInitial, + state: uint32(StateInitial), uniqueID: s.UniqueID(), } e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) @@ -204,15 +206,15 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue // // Precondition: e.mu must be held to call this method. func (e *endpoint) setEndpointState(state EndpointState) { - atomic.StoreUint32((*uint32)(&e.state), uint32(state)) + atomic.StoreUint32(&e.state, uint32(state)) } // EndpointState() returns the current state of the endpoint. func (e *endpoint) EndpointState() EndpointState { - return EndpointState(atomic.LoadUint32((*uint32)(&e.state))) + return EndpointState(atomic.LoadUint32(&e.state)) } -// UniqueID implements stack.TransportEndpoint.UniqueID. +// UniqueID implements stack.TransportEndpoint. func (e *endpoint) UniqueID() uint64 { return e.uniqueID } @@ -226,14 +228,14 @@ func (e *endpoint) LastError() tcpip.Error { return err } -// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError. +// UpdateLastError implements tcpip.SocketOptionsHandler. func (e *endpoint) UpdateLastError(err tcpip.Error) { e.lastErrorMu.Lock() e.lastError = err e.lastErrorMu.Unlock() } -// Abort implements stack.TransportEndpoint.Abort. +// Abort implements stack.TransportEndpoint. func (e *endpoint) Abort() { e.Close() } @@ -289,10 +291,10 @@ func (e *endpoint) Close() { e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents) } -// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. -func (e *endpoint) ModerateRecvBuf(copied int) {} +// ModerateRecvBuf implements tcpip.Endpoint. +func (*endpoint) ModerateRecvBuf(int) {} -// Read implements tcpip.Endpoint.Read. +// Read implements tcpip.Endpoint. func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) { if err := e.LastError(); err != nil { return tcpip.ReadResult{}, err @@ -320,7 +322,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult // Control Messages cm := tcpip.ControlMessages{ HasTimestamp: true, - Timestamp: p.timestamp, + Timestamp: p.receivedAt.UnixNano(), } if e.ops.GetReceiveTOS() { cm.HasTOS = true @@ -581,21 +583,21 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp return int64(len(v)), nil } -// OnReuseAddressSet implements tcpip.SocketOptionsHandler.OnReuseAddressSet. +// OnReuseAddressSet implements tcpip.SocketOptionsHandler. func (e *endpoint) OnReuseAddressSet(v bool) { e.mu.Lock() e.portFlags.MostRecent = v e.mu.Unlock() } -// OnReusePortSet implements tcpip.SocketOptionsHandler.OnReusePortSet. +// OnReusePortSet implements tcpip.SocketOptionsHandler. func (e *endpoint) OnReusePortSet(v bool) { e.mu.Lock() e.portFlags.LoadBalanced = v e.mu.Unlock() } -// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. +// SetSockOptInt implements tcpip.Endpoint. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { switch opt { case tcpip.MTUDiscoverOption: @@ -629,11 +631,14 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { return nil } +var _ tcpip.SocketOptionsHandler = (*endpoint)(nil) + +// HasNIC implements tcpip.SocketOptionsHandler. func (e *endpoint) HasNIC(id int32) bool { - return id == 0 || e.stack.HasNIC(tcpip.NICID(id)) + return e.stack.HasNIC(tcpip.NICID(id)) } -// SetSockOpt implements tcpip.Endpoint.SetSockOpt. +// SetSockOpt implements tcpip.Endpoint. func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { switch v := opt.(type) { case *tcpip.MulticastInterfaceOption: @@ -749,7 +754,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { return nil } -// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. +// GetSockOptInt implements tcpip.Endpoint. func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { switch opt { case tcpip.IPv4TOSOption: @@ -795,14 +800,14 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { } } -// GetSockOpt implements tcpip.Endpoint.GetSockOpt. +// GetSockOpt implements tcpip.Endpoint. func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { switch o := opt.(type) { case *tcpip.MulticastInterfaceOption: e.mu.Lock() *o = tcpip.MulticastInterfaceOption{ - e.multicastNICID, - e.multicastAddr, + NIC: e.multicastNICID, + InterfaceAddr: e.multicastAddr, } e.mu.Unlock() @@ -872,7 +877,7 @@ func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddres return unwrapped, netProto, nil } -// Disconnect implements tcpip.Endpoint.Disconnect. +// Disconnect implements tcpip.Endpoint. func (e *endpoint) Disconnect() tcpip.Error { e.mu.Lock() defer e.mu.Unlock() @@ -1075,7 +1080,7 @@ func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id BindToDevice: bindToDevice, Dest: tcpip.FullAddress{}, } - port, err := e.stack.ReservePort(portRes, nil /* testPort */) + port, err := e.stack.ReservePort(e.stack.Rand(), portRes, nil /* testPort */) if err != nil { return id, bindToDevice, err } @@ -1301,12 +1306,12 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB senderAddress: tcpip.FullAddress{ NIC: pkt.NICID, Addr: id.RemoteAddress, - Port: header.UDP(hdr).SourcePort(), + Port: hdr.SourcePort(), }, destinationAddress: tcpip.FullAddress{ NIC: pkt.NICID, Addr: id.LocalAddress, - Port: header.UDP(hdr).DestinationPort(), + Port: hdr.DestinationPort(), }, data: pkt.Data().ExtractVV(), } @@ -1328,7 +1333,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB packet.packetInfo.LocalAddr = localAddr packet.packetInfo.DestinationAddr = localAddr packet.packetInfo.NIC = pkt.NICID - packet.timestamp = e.stack.Clock().NowNanoseconds() + packet.receivedAt = e.stack.Clock().Now() e.rcvMu.Unlock() @@ -1386,7 +1391,7 @@ func (e *endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketB } } -// State implements tcpip.Endpoint.State. +// State implements tcpip.Endpoint. func (e *endpoint) State() uint32 { return uint32(e.EndpointState()) } @@ -1405,19 +1410,19 @@ func (e *endpoint) Stats() tcpip.EndpointStats { return &e.stats } -// Wait implements tcpip.Endpoint.Wait. +// Wait implements tcpip.Endpoint. func (*endpoint) Wait() {} func (e *endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, addr tcpip.Address) bool { return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) || e.stack.IsSubnetBroadcast(nicID, netProto, addr) } -// SetOwner implements tcpip.Endpoint.SetOwner. +// SetOwner implements tcpip.Endpoint. func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } -// SocketOptions implements tcpip.Endpoint.SocketOptions. +// SocketOptions implements tcpip.Endpoint. func (e *endpoint) SocketOptions() *tcpip.SocketOptions { return &e.ops } diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 4aba68b21..1f638c3f6 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -15,26 +15,38 @@ package udp import ( + "time" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) +// saveReceivedAt is invoked by stateify. +func (p *udpPacket) saveReceivedAt() int64 { + return p.receivedAt.UnixNano() +} + +// loadReceivedAt is invoked by stateify. +func (p *udpPacket) loadReceivedAt(nsec int64) { + p.receivedAt = time.Unix(0, nsec) +} + // saveData saves udpPacket.data field. -func (u *udpPacket) saveData() buffer.VectorisedView { - // We cannot save u.data directly as u.data.views may alias to u.views, +func (p *udpPacket) saveData() buffer.VectorisedView { + // We cannot save p.data directly as p.data.views may alias to p.views, // which is not allowed by state framework (in-struct pointer). - return u.data.Clone(nil) + return p.data.Clone(nil) } // loadData loads udpPacket.data field. -func (u *udpPacket) loadData(data buffer.VectorisedView) { - // NOTE: We cannot do the u.data = data.Clone(u.views[:]) optimization +func (p *udpPacket) loadData(data buffer.VectorisedView) { + // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization // here because data.views is not guaranteed to be loaded by now. Plus, // data.views will be allocated anyway so there really is little point - // of utilizing u.views for data.views. - u.data = data + // of utilizing p.views for data.views. + p.data = data } // afterLoad is invoked by stateify. diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index 705ad1f64..7c357cb09 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -90,7 +90,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, ep.RegisterNICID = r.pkt.NICID ep.boundPortFlags = ep.portFlags - ep.state = StateConnected + ep.state = uint32(StateConnected) ep.rcvMu.Lock() ep.rcvReady = true diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index dc2e3f493..4008cacf2 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -16,17 +16,16 @@ package udp_test import ( "bytes" - "context" "fmt" "io/ioutil" "math/rand" "testing" - "time" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" @@ -298,16 +297,18 @@ type testContext struct { func newDualTestContext(t *testing.T, mtu uint32) *testContext { t.Helper() - return newDualTestContextWithOptions(t, mtu, stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}, - HandleLocal: true, - }) + return newDualTestContextWithHandleLocal(t, mtu, true) } -func newDualTestContextWithOptions(t *testing.T, mtu uint32, options stack.Options) *testContext { +func newDualTestContextWithHandleLocal(t *testing.T, mtu uint32, handleLocal bool) *testContext { t.Helper() + options := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}, + HandleLocal: handleLocal, + Clock: &faketime.NullClock{}, + } s := stack.New(options) ep := channel.New(256, mtu, "") wep := stack.LinkEndpoint(ep) @@ -378,9 +379,7 @@ func (c *testContext) createEndpointForFlow(flow testFlow) { func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte { c.t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - p, ok := c.linkEP.ReadContext(ctx) + p, ok := c.linkEP.Read() if !ok { c.t.Fatalf("Packet wasn't written out") return nil @@ -534,7 +533,9 @@ func newMinPayload(minSize int) []byte { func TestBindToDeviceOption(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}}) + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + Clock: &faketime.NullClock{}, + }) ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { @@ -606,7 +607,7 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe case <-ch: res, err = c.ep.Read(&buf, tcpip.ReadOptions{NeedRemoteAddr: true}) - case <-time.After(300 * time.Millisecond): + default: if packetShouldBeDropped { return // expected to time out } @@ -820,11 +821,7 @@ func TestV4ReadSelfSource(t *testing.T) { {"NoHandleLocal", true, &tcpip.ErrWouldBlock{}, 1}, } { t.Run(tt.name, func(t *testing.T) { - c := newDualTestContextWithOptions(t, defaultMTU, stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - HandleLocal: tt.handleLocal, - }) + c := newDualTestContextWithHandleLocal(t, defaultMTU, tt.handleLocal) defer c.cleanup() c.createEndpointForFlow(unicastV4) @@ -1034,17 +1031,17 @@ func testWriteAndVerifyInternal(c *testContext, flow testFlow, setDest bool, che payload := testWriteNoVerify(c, flow, setDest) // Received the packet and check the payload. b := c.getPacketAndVerify(flow, checkers...) - var udp header.UDP + var udpH header.UDP if flow.isV4() { - udp = header.UDP(header.IPv4(b).Payload()) + udpH = header.IPv4(b).Payload() } else { - udp = header.UDP(header.IPv6(b).Payload()) + udpH = header.IPv6(b).Payload() } - if !bytes.Equal(payload, udp.Payload()) { - c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload) + if !bytes.Equal(payload, udpH.Payload()) { + c.t.Fatalf("Bad payload: got %x, want %x", udpH.Payload(), payload) } - return udp.SourcePort() + return udpH.SourcePort() } func testDualWrite(c *testContext) uint16 { @@ -1198,7 +1195,7 @@ func TestWriteOnConnectedInvalidPort(t *testing.T) { r.Reset(payload) n, err := c.ep.Write(&r, writeOpts) if err != nil { - c.t.Fatalf("c.ep.Write(...) = %+s, want nil", err) + c.t.Fatalf("c.ep.Write(...) = %s, want nil", err) } if got, want := n, int64(len(payload)); got != want { c.t.Fatalf("c.ep.Write(...) wrote %d bytes, want %d bytes", got, want) @@ -1462,7 +1459,7 @@ func TestReadRecvOriginalDstAddr(t *testing.T) { name: "IPv4 unicast", proto: header.IPv4ProtocolNumber, flow: unicastV4, - expectedOriginalDstAddr: tcpip.FullAddress{1, stackAddr, stackPort}, + expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: stackAddr, Port: stackPort}, }, { name: "IPv4 multicast", @@ -1474,7 +1471,7 @@ func TestReadRecvOriginalDstAddr(t *testing.T) { // behaviour. We still include the test so that once the bug is // resolved, this test will start to fail and the individual tasked // with fixing this bug knows to also fix this test :). - expectedOriginalDstAddr: tcpip.FullAddress{1, multicastAddr, stackPort}, + expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: multicastAddr, Port: stackPort}, }, { name: "IPv4 broadcast", @@ -1486,13 +1483,13 @@ func TestReadRecvOriginalDstAddr(t *testing.T) { // behaviour. We still include the test so that once the bug is // resolved, this test will start to fail and the individual tasked // with fixing this bug knows to also fix this test :). - expectedOriginalDstAddr: tcpip.FullAddress{1, broadcastAddr, stackPort}, + expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: broadcastAddr, Port: stackPort}, }, { name: "IPv6 unicast", proto: header.IPv6ProtocolNumber, flow: unicastV6, - expectedOriginalDstAddr: tcpip.FullAddress{1, stackV6Addr, stackPort}, + expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: stackV6Addr, Port: stackPort}, }, { name: "IPv6 multicast", @@ -1504,7 +1501,7 @@ func TestReadRecvOriginalDstAddr(t *testing.T) { // behaviour. We still include the test so that once the bug is // resolved, this test will start to fail and the individual tasked // with fixing this bug knows to also fix this test :). - expectedOriginalDstAddr: tcpip.FullAddress{1, multicastV6Addr, stackPort}, + expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: multicastV6Addr, Port: stackPort}, }, } @@ -1614,6 +1611,7 @@ func TestTTL(t *testing.T) { } s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{p}, + Clock: &faketime.NullClock{}, }) ep := s.NetworkProtocolInstance(n).NewEndpoint(&testInterface{}, nil) wantTTL = ep.DefaultTTL() @@ -1759,7 +1757,7 @@ func TestReceiveTosTClass(t *testing.T) { c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, v, false) } - want := true + const want = true optionSetter(want) got := optionGetter() @@ -1889,18 +1887,14 @@ func TestV4UnknownDestination(t *testing.T) { } } if !tc.icmpRequired { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - if p, ok := c.linkEP.ReadContext(ctx); ok { + if p, ok := c.linkEP.Read(); ok { t.Fatalf("unexpected packet received: %+v", p) } return } // ICMP required. - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - p, ok := c.linkEP.ReadContext(ctx) + p, ok := c.linkEP.Read() if !ok { t.Fatalf("packet wasn't written out") return @@ -1987,18 +1981,14 @@ func TestV6UnknownDestination(t *testing.T) { } } if !tc.icmpRequired { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - if p, ok := c.linkEP.ReadContext(ctx); ok { + if p, ok := c.linkEP.Read(); ok { t.Fatalf("unexpected packet received: %+v", p) } return } // ICMP required. - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - p, ok := c.linkEP.ReadContext(ctx) + p, ok := c.linkEP.Read() if !ok { t.Fatalf("packet wasn't written out") return @@ -2115,8 +2105,8 @@ func TestShortHeader(t *testing.T) { Data: buf.ToVectorisedView(), })) - if got, want := c.s.Stats().MalformedRcvdPackets.Value(), uint64(1); got != want { - t.Errorf("got c.s.Stats().MalformedRcvdPackets.Value() = %d, want = %d", got, want) + if got, want := c.s.Stats().NICs.MalformedL4RcvdPackets.Value(), uint64(1); got != want { + t.Errorf("got c.s.Stats().NIC.MalformedL4RcvdPackets.Value() = %d, want = %d", got, want) } } @@ -2124,25 +2114,27 @@ func TestShortHeader(t *testing.T) { // global and endpoint stats are incremented. func TestBadChecksumErrors(t *testing.T) { for _, flow := range []testFlow{unicastV4, unicastV6} { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() + t.Run(flow.String(), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() - c.createEndpoint(flow.sockProto()) - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } + c.createEndpoint(flow.sockProto()) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } - payload := newPayload() - c.injectPacket(flow, payload, true /* badChecksum */) + payload := newPayload() + c.injectPacket(flow, payload, true /* badChecksum */) - const want = 1 - if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { - t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) - } - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) - } + const want = 1 + if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { + t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) + } + }) } } @@ -2484,6 +2476,7 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + Clock: &faketime.NullClock{}, }) e := channel.New(0, defaultMTU, "") if err := s.CreateNIC(nicID1, e); err != nil { |