diff options
Diffstat (limited to 'pkg/tcpip')
109 files changed, 3922 insertions, 2570 deletions
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index ea46c30da..ed4d7e958 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -39,12 +39,30 @@ go_library( deps_test( name = "netstack_deps_test", allowed = [ + # gVisor deps. + "//pkg/atomicbitops", + "//pkg/buffer", + "//pkg/context", + "//pkg/gohacks", + "//pkg/goid", + "//pkg/ilist", + "//pkg/iovec", + "//pkg/linewriter", + "//pkg/log", + "//pkg/rand", + "//pkg/sleep", + "//pkg/state", + "//pkg/state/wire", + "//pkg/sync", + "//pkg/waiter", + + # Other deps. "@com_github_google_btree//:go_default_library", "@org_golang_x_sys//unix:go_default_library", "@org_golang_x_time//rate:go_default_library", ], allowed_prefixes = [ - "//", + "//pkg/tcpip", "@org_golang_x_sys//internal/unsafeheader", ], targets = [ diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 18e6cc3cd..e0dfe5813 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -50,7 +50,7 @@ func IPv4(t *testing.T, b []byte, checkers ...NetworkChecker) { ipv4 := header.IPv4(b) if !ipv4.IsValid(len(b)) { - t.Error("Not a valid IPv4 packet") + t.Fatalf("Not a valid IPv4 packet: %x", ipv4) } if !ipv4.IsChecksumValid() { @@ -72,7 +72,7 @@ func IPv6(t *testing.T, b []byte, checkers ...NetworkChecker) { ipv6 := header.IPv6(b) if !ipv6.IsValid(len(b)) { - t.Error("Not a valid IPv6 packet") + t.Fatalf("Not a valid IPv6 packet: %x", ipv6) } for _, f := range checkers { @@ -701,7 +701,7 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp if !ok { return } - opts := []byte(tcp.Options()) + opts := tcp.Options() limit := len(opts) foundTS := false tsVal := uint32(0) @@ -748,12 +748,6 @@ func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) Transp } } -// TCPNoSACKBlockChecker creates a checker that verifies that the segment does -// not contain any SACK blocks in the TCP options. -func TCPNoSACKBlockChecker() TransportChecker { - return TCPSACKBlockChecker(nil) -} - // TCPSACKBlockChecker creates a checker that verifies that the segment does // contain the specified SACK blocks in the TCP options. func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker { @@ -765,7 +759,7 @@ func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker { } var gotSACKBlocks []header.SACKBlock - opts := []byte(tcp.Options()) + opts := tcp.Options() limit := len(opts) for i := 0; i < limit; { switch opts[i] { diff --git a/pkg/tcpip/faketime/faketime.go b/pkg/tcpip/faketime/faketime.go index fb819d7a8..9f8f51647 100644 --- a/pkg/tcpip/faketime/faketime.go +++ b/pkg/tcpip/faketime/faketime.go @@ -29,14 +29,14 @@ type NullClock struct{} var _ tcpip.Clock = (*NullClock)(nil) -// NowNanoseconds implements tcpip.Clock.NowNanoseconds. -func (*NullClock) NowNanoseconds() int64 { - return 0 +// Now implements tcpip.Clock.Now. +func (*NullClock) Now() time.Time { + return time.Time{} } // NowMonotonic implements tcpip.Clock.NowMonotonic. -func (*NullClock) NowMonotonic() int64 { - return 0 +func (*NullClock) NowMonotonic() tcpip.MonotonicTime { + return tcpip.MonotonicTime{} } // AfterFunc implements tcpip.Clock.AfterFunc. @@ -118,16 +118,17 @@ func NewManualClock() *ManualClock { var _ tcpip.Clock = (*ManualClock)(nil) -// NowNanoseconds implements tcpip.Clock.NowNanoseconds. -func (mc *ManualClock) NowNanoseconds() int64 { +// Now implements tcpip.Clock.Now. +func (mc *ManualClock) Now() time.Time { mc.mu.RLock() defer mc.mu.RUnlock() - return mc.mu.now.UnixNano() + return mc.mu.now } // NowMonotonic implements tcpip.Clock.NowMonotonic. -func (mc *ManualClock) NowMonotonic() int64 { - return mc.NowNanoseconds() +func (mc *ManualClock) NowMonotonic() tcpip.MonotonicTime { + var mt tcpip.MonotonicTime + return mt.Add(mc.Now().Sub(time.Unix(0, 0))) } // AfterFunc implements tcpip.Clock.AfterFunc. @@ -218,6 +219,12 @@ func (mc *ManualClock) stopTimerLocked(mt *manualTimer) { } } +// RunImmediatelyScheduledJobs runs all jobs scheduled to run at the current +// time. +func (mc *ManualClock) RunImmediatelyScheduledJobs() { + mc.Advance(0) +} + // Advance executes all work that have been scheduled to execute within d from // the current time. Blocks until all work has completed execution. func (mc *ManualClock) Advance(d time.Duration) { diff --git a/pkg/tcpip/faketime/faketime_test.go b/pkg/tcpip/faketime/faketime_test.go index c2704df2c..fd2bb470a 100644 --- a/pkg/tcpip/faketime/faketime_test.go +++ b/pkg/tcpip/faketime/faketime_test.go @@ -26,7 +26,7 @@ func TestManualClockAdvance(t *testing.T) { clock := faketime.NewManualClock() start := clock.NowMonotonic() clock.Advance(timeout) - if got, want := time.Duration(clock.NowMonotonic()-start)*time.Nanosecond, timeout; got != want { + if got, want := clock.NowMonotonic().Sub(start), timeout; got != want { t.Errorf("got = %d, want = %d", got, want) } } @@ -87,7 +87,7 @@ func TestManualClockAfterFunc(t *testing.T) { if got, want := counter2, test.wantCounter2; got != want { t.Errorf("got counter2 = %d, want = %d", got, want) } - if got, want := time.Duration(clock.NowMonotonic()-start)*time.Nanosecond, test.advance; got != want { + if got, want := clock.NowMonotonic().Sub(start), test.advance; got != want { t.Errorf("got elapsed = %d, want = %d", got, want) } }) diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index 2be21ec75..e9abbb709 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -17,6 +17,7 @@ package header import ( "encoding/binary" "fmt" + "time" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -181,7 +182,7 @@ const ( // ipv4LinkLocalUnicastSubnet is the IPv4 link local unicast subnet as defined // by RFC 3927 section 1. var ipv4LinkLocalUnicastSubnet = func() tcpip.Subnet { - subnet, err := tcpip.NewSubnet("\xa9\xfe\x00\x00", tcpip.AddressMask("\xff\xff\x00\x00")) + subnet, err := tcpip.NewSubnet("\xa9\xfe\x00\x00", "\xff\xff\x00\x00") if err != nil { panic(err) } @@ -191,7 +192,7 @@ var ipv4LinkLocalUnicastSubnet = func() tcpip.Subnet { // ipv4LinkLocalMulticastSubnet is the IPv4 link local multicast subnet as // defined by RFC 5771 section 4. var ipv4LinkLocalMulticastSubnet = func() tcpip.Subnet { - subnet, err := tcpip.NewSubnet("\xe0\x00\x00\x00", tcpip.AddressMask("\xff\xff\xff\x00")) + subnet, err := tcpip.NewSubnet("\xe0\x00\x00\x00", "\xff\xff\xff\x00") if err != nil { panic(err) } @@ -572,7 +573,7 @@ func (o *IPv4OptionGeneric) Type() IPv4OptionType { func (o *IPv4OptionGeneric) Size() uint8 { return uint8(len(*o)) } // Contents implements IPv4Option. -func (o *IPv4OptionGeneric) Contents() []byte { return []byte(*o) } +func (o *IPv4OptionGeneric) Contents() []byte { return *o } // IPv4OptionIterator is an iterator pointing to a specific IP option // at any point of time. It also holds information as to a new options buffer @@ -610,7 +611,7 @@ func (i *IPv4OptionIterator) InitReplacement(option IPv4Option) IPv4Options { // RemainingBuffer returns the remaining (unused) part of the new option buffer, // into which a new option may be written. func (i *IPv4OptionIterator) RemainingBuffer() IPv4Options { - return IPv4Options(i.newOptions[i.writePoint:]) + return i.newOptions[i.writePoint:] } // ConsumeBuffer marks a portion of the new buffer as used. @@ -813,9 +814,12 @@ const ( // ipv4TimestampTime provides the current time as specified in RFC 791. func ipv4TimestampTime(clock tcpip.Clock) uint32 { - const millisecondsPerDay = 24 * 3600 * 1000 - const nanoPerMilli = 1000000 - return uint32((clock.NowNanoseconds() / nanoPerMilli) % millisecondsPerDay) + // Per RFC 791 page 21: + // The Timestamp is a right-justified, 32-bit timestamp in + // milliseconds since midnight UT. + now := clock.Now().UTC() + midnight := now.Truncate(24 * time.Hour) + return uint32(now.Sub(midnight).Milliseconds()) } // IP Timestamp option fields. @@ -843,7 +847,7 @@ func (ts *IPv4OptionTimestamp) Type() IPv4OptionType { return IPv4OptionTimestam func (ts *IPv4OptionTimestamp) Size() uint8 { return uint8(len(*ts)) } // Contents implements IPv4Option. -func (ts *IPv4OptionTimestamp) Contents() []byte { return []byte(*ts) } +func (ts *IPv4OptionTimestamp) Contents() []byte { return *ts } // Pointer returns the pointer field in the IP Timestamp option. func (ts *IPv4OptionTimestamp) Pointer() uint8 { @@ -947,7 +951,7 @@ func (rr *IPv4OptionRecordRoute) Type() IPv4OptionType { return IPv4OptionRecord func (rr *IPv4OptionRecordRoute) Size() uint8 { return uint8(len(*rr)) } // Contents implements IPv4Option. -func (rr *IPv4OptionRecordRoute) Contents() []byte { return []byte(*rr) } +func (rr *IPv4OptionRecordRoute) Contents() []byte { return *rr } // Router Alert option specific related constants. // @@ -992,7 +996,7 @@ func (*IPv4OptionRouterAlert) Type() IPv4OptionType { return IPv4OptionRouterAle func (ra *IPv4OptionRouterAlert) Size() uint8 { return uint8(len(*ra)) } // Contents implements IPv4Option. -func (ra *IPv4OptionRouterAlert) Contents() []byte { return []byte(*ra) } +func (ra *IPv4OptionRouterAlert) Contents() []byte { return *ra } // Value returns the value of the IPv4OptionRouterAlert. func (ra *IPv4OptionRouterAlert) Value() uint16 { diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go index 3d1bccd15..d6cad3a94 100644 --- a/pkg/tcpip/header/ndp_options.go +++ b/pkg/tcpip/header/ndp_options.go @@ -77,12 +77,12 @@ const ( // ndpPrefixInformationOnLinkFlagMask is the mask of the On-Link Flag // field in the flags byte within an NDPPrefixInformation. - ndpPrefixInformationOnLinkFlagMask = (1 << 7) + ndpPrefixInformationOnLinkFlagMask = 1 << 7 // ndpPrefixInformationAutoAddrConfFlagMask is the mask of the // Autonomous Address-Configuration flag field in the flags byte within // an NDPPrefixInformation. - ndpPrefixInformationAutoAddrConfFlagMask = (1 << 6) + ndpPrefixInformationAutoAddrConfFlagMask = 1 << 6 // ndpPrefixInformationReserved1FlagsMask is the mask of the Reserved1 // field in the flags byte within an NDPPrefixInformation. @@ -451,7 +451,7 @@ func (o NDPNonceOption) String() string { // Nonce returns the nonce value this option holds. func (o NDPNonceOption) Nonce() []byte { - return []byte(o) + return o } // NDPSourceLinkLayerAddressOption is the NDP Source Link Layer Option diff --git a/pkg/tcpip/header/ndp_router_advert.go b/pkg/tcpip/header/ndp_router_advert.go index bf7610863..7e2f0c797 100644 --- a/pkg/tcpip/header/ndp_router_advert.go +++ b/pkg/tcpip/header/ndp_router_advert.go @@ -19,12 +19,72 @@ import ( "time" ) +// NDPRoutePreference is the preference values for default routers or +// more-specific routes. +// +// As per RFC 4191 section 2.1, +// +// Default router preferences and preferences for more-specific routes +// are encoded the same way. +// +// Preference values are encoded as a two-bit signed integer, as +// follows: +// +// 01 High +// 00 Medium (default) +// 11 Low +// 10 Reserved - MUST NOT be sent +// +// Note that implementations can treat the value as a two-bit signed +// integer. +// +// Having just three values reinforces that they are not metrics and +// more values do not appear to be necessary for reasonable scenarios. +type NDPRoutePreference uint8 + +const ( + // HighRoutePreference indicates a high preference, as per + // RFC 4191 section 2.1. + HighRoutePreference NDPRoutePreference = 0b01 + + // MediumRoutePreference indicates a medium preference, as per + // RFC 4191 section 2.1. + // + // This is the default preference value. + MediumRoutePreference = 0b00 + + // LowRoutePreference indicates a low preference, as per + // RFC 4191 section 2.1. + LowRoutePreference = 0b11 + + // ReservedRoutePreference is a reserved preference value, as per + // RFC 4191 section 2.1. + // + // It MUST NOT be sent. + ReservedRoutePreference = 0b10 +) + // NDPRouterAdvert is an NDP Router Advertisement message. It will only contain // the body of an ICMPv6 packet. // -// See RFC 4861 section 4.2 for more details. +// See RFC 4861 section 4.2 and RFC 4191 section 2.2 for more details. type NDPRouterAdvert []byte +// As per RFC 4191 section 2.2, +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type | Code | Checksum | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Cur Hop Limit |M|O|H|Prf|Resvd| Router Lifetime | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Reachable Time | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Retrans Timer | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Options ... +// +-+-+-+-+-+-+-+-+-+-+-+- const ( // NDPRAMinimumSize is the minimum size of a valid NDP Router // Advertisement message (body of an ICMPv6 packet). @@ -47,6 +107,14 @@ const ( // within the bit-field/flags byte of an NDPRouterAdvert. ndpRAOtherConfFlagMask = (1 << 6) + // ndpDefaultRouterPreferenceShift is the shift of the Prf (Default Router + // Preference) field within the flags byte of an NDPRouterAdvert. + ndpDefaultRouterPreferenceShift = 3 + + // ndpDefaultRouterPreferenceMask is the mask of the Prf (Default Router + // Preference) field within the flags byte of an NDPRouterAdvert. + ndpDefaultRouterPreferenceMask = (0b11 << ndpDefaultRouterPreferenceShift) + // ndpRARouterLifetimeOffset is the start of the 2-byte Router Lifetime // field within an NDPRouterAdvert. ndpRARouterLifetimeOffset = 2 @@ -80,6 +148,11 @@ func (b NDPRouterAdvert) OtherConfFlag() bool { return b[ndpRAFlagsOffset]&ndpRAOtherConfFlagMask != 0 } +// DefaultRouterPreference returns the Default Router Preference field. +func (b NDPRouterAdvert) DefaultRouterPreference() NDPRoutePreference { + return NDPRoutePreference((b[ndpRAFlagsOffset] & ndpDefaultRouterPreferenceMask) >> ndpDefaultRouterPreferenceShift) +} + // RouterLifetime returns the lifetime associated with the default router. A // value of 0 means the source of the Router Advertisement is not a default // router and SHOULD NOT appear on the default router list. Note, a value of 0 diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go index 1b5093e58..8fd1f7d13 100644 --- a/pkg/tcpip/header/ndp_test.go +++ b/pkg/tcpip/header/ndp_test.go @@ -126,36 +126,83 @@ func TestNDPNeighborAdvert(t *testing.T) { } func TestNDPRouterAdvert(t *testing.T) { - b := []byte{ - 64, 128, 1, 2, - 3, 4, 5, 6, - 7, 8, 9, 10, + tests := []struct { + hopLimit uint8 + managedFlag, otherConfFlag bool + prf NDPRoutePreference + routerLifetimeS uint16 + reachableTimeMS, retransTimerMS uint32 + }{ + { + hopLimit: 1, + managedFlag: false, + otherConfFlag: true, + prf: HighRoutePreference, + routerLifetimeS: 2, + reachableTimeMS: 3, + retransTimerMS: 4, + }, + { + hopLimit: 64, + managedFlag: true, + otherConfFlag: false, + prf: LowRoutePreference, + routerLifetimeS: 258, + reachableTimeMS: 78492, + retransTimerMS: 13213, + }, } - ra := NDPRouterAdvert(b) + for i, test := range tests { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + flags := uint8(0) + if test.managedFlag { + flags |= 1 << 7 + } + if test.otherConfFlag { + flags |= 1 << 6 + } + flags |= uint8(test.prf) << 3 - if got := ra.CurrHopLimit(); got != 64 { - t.Errorf("got ra.CurrHopLimit = %d, want = 64", got) - } + b := []byte{ + test.hopLimit, flags, 1, 2, + 3, 4, 5, 6, + 7, 8, 9, 10, + } + binary.BigEndian.PutUint16(b[2:], test.routerLifetimeS) + binary.BigEndian.PutUint32(b[4:], test.reachableTimeMS) + binary.BigEndian.PutUint32(b[8:], test.retransTimerMS) - if got := ra.ManagedAddrConfFlag(); !got { - t.Errorf("got ManagedAddrConfFlag = false, want = true") - } + ra := NDPRouterAdvert(b) - if got := ra.OtherConfFlag(); got { - t.Errorf("got OtherConfFlag = true, want = false") - } + if got := ra.CurrHopLimit(); got != test.hopLimit { + t.Errorf("got ra.CurrHopLimit() = %d, want = %d", got, test.hopLimit) + } - if got, want := ra.RouterLifetime(), time.Second*258; got != want { - t.Errorf("got ra.RouterLifetime = %d, want = %d", got, want) - } + if got := ra.ManagedAddrConfFlag(); got != test.managedFlag { + t.Errorf("got ManagedAddrConfFlag() = %t, want = %t", got, test.managedFlag) + } - if got, want := ra.ReachableTime(), time.Millisecond*50595078; got != want { - t.Errorf("got ra.ReachableTime = %d, want = %d", got, want) - } + if got := ra.OtherConfFlag(); got != test.otherConfFlag { + t.Errorf("got OtherConfFlag() = %t, want = %t", got, test.otherConfFlag) + } + + if got := ra.DefaultRouterPreference(); got != test.prf { + t.Errorf("got DefaultRouterPreference() = %d, want = %d", got, test.prf) + } - if got, want := ra.RetransTimer(), time.Millisecond*117967114; got != want { - t.Errorf("got ra.RetransTimer = %d, want = %d", got, want) + if got, want := ra.RouterLifetime(), time.Second*time.Duration(test.routerLifetimeS); got != want { + t.Errorf("got ra.RouterLifetime() = %d, want = %d", got, want) + } + + if got, want := ra.ReachableTime(), time.Millisecond*time.Duration(test.reachableTimeMS); got != want { + t.Errorf("got ra.ReachableTime() = %d, want = %d", got, want) + } + + if got, want := ra.RetransTimer(), time.Millisecond*time.Duration(test.retransTimerMS); got != want { + t.Errorf("got ra.RetransTimer() = %d, want = %d", got, want) + } + }) } } diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go index 0df517000..8dabe3354 100644 --- a/pkg/tcpip/header/tcp.go +++ b/pkg/tcpip/header/tcp.go @@ -48,6 +48,16 @@ const ( // TCPFlags is the dedicated type for TCP flags. type TCPFlags uint8 +// Intersects returns true iff there are flags common to both f and o. +func (f TCPFlags) Intersects(o TCPFlags) bool { + return f&o != 0 +} + +// Contains returns true iff all the flags in o are contained within f. +func (f TCPFlags) Contains(o TCPFlags) bool { + return f&o == o +} + // String implements Stringer.String. func (f TCPFlags) String() string { flagsStr := []byte("FSRPAU") diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index ef9126deb..f26c857eb 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -288,5 +288,5 @@ func (*Endpoint) ARPHardwareType() header.ARPHardwareType { } // AddHeader implements stack.LinkEndpoint.AddHeader. -func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { +func (*Endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) { } diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index bddb1d0a2..735c28da1 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -41,7 +41,6 @@ package fdbased import ( "fmt" - "math" "sync/atomic" "golang.org/x/sys/unix" @@ -196,8 +195,12 @@ type Options struct { // option for an FD with a fanoutID already in use by another FD for a different // NIC will return an EINVAL. // +// Since fanoutID must be unique within the network namespace, we start with +// the PID to avoid collisions. The only way to be sure of avoiding collisions +// is to run in a new network namespace. +// // Must be accessed using atomic operations. -var fanoutID int32 = 0 +var fanoutID int32 = int32(unix.Getpid()) // New creates a new fd-based endpoint. // @@ -292,11 +295,6 @@ func createInboundDispatcher(e *endpoint, fd int, isSocket bool, fID int32) (lin } switch sa.(type) { case *unix.SockaddrLinklayer: - // See: PACKET_FANOUT_MAX in net/packet/internal.h - const packetFanoutMax = 1 << 16 - if fID > packetFanoutMax { - return nil, fmt.Errorf("host fanoutID limit exceeded, fanoutID must be <= %d", math.MaxUint16) - } // Enable PACKET_FANOUT mode if the underlying socket is of type // AF_PACKET. We do not enable PACKET_FANOUT_FLAG_DEFRAG as that will // prevent gvisor from receiving fragmented packets and the host does the @@ -317,7 +315,7 @@ func createInboundDispatcher(e *endpoint, fd int, isSocket bool, fID int32) (lin // // See: https://github.com/torvalds/linux/blob/7acac4b3196caee5e21fb5ea53f8bc124e6a16fc/net/packet/af_packet.c#L3881 const fanoutType = unix.PACKET_FANOUT_HASH - fanoutArg := int(fID) | fanoutType<<16 + fanoutArg := (int(fID) & 0xffff) | fanoutType<<16 if err := unix.SetsockoptInt(fd, unix.SOL_PACKET, unix.PACKET_FANOUT, fanoutArg); err != nil { return nil, fmt.Errorf("failed to enable PACKET_FANOUT option: %v", err) } diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index a72eb1aad..6fa1aee18 100644 --- a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -28,13 +28,13 @@ go_test( ":arp", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/sniffer", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/stack", "//pkg/tcpip/testutil", - "//pkg/tcpip/transport/icmp", "@com_github_google_go_cmp//cmp:go_default_library", "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", ], diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 0efa3a926..6515c31e5 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -278,7 +278,7 @@ func (*protocol) ParseAddresses(buffer.View) (src, dst tcpip.Address) { return "", "" } -func (p *protocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { +func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.TransportDispatcher) stack.NetworkEndpoint { e := &endpoint{ protocol: p, nic: nic, diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 94209b026..5fcbfeaa2 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -15,15 +15,14 @@ package arp_test import ( - "context" "fmt" "testing" - "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" @@ -31,7 +30,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/testutil" - "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" ) const ( @@ -39,15 +37,6 @@ const ( stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c") remoteLinkAddr = tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06") - - defaultChannelSize = 1 - defaultMTU = 65536 - - // eventChanSize defines the size of event channels used by the neighbor - // cache's event dispatcher. The size chosen here needs to be sufficient to - // queue all the events received during tests before consumption. - // If eventChanSize is too small, the tests may deadlock. - eventChanSize = 32 ) var ( @@ -123,24 +112,6 @@ func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.Neighbo d.C <- e } -func (d *arpDispatcher) waitForEvent(ctx context.Context, want eventInfo) error { - select { - case got := <-d.C: - if diff := cmp.Diff(want, got, cmp.AllowUnexported(got), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAtNanos")); diff != "" { - return fmt.Errorf("got invalid event (-want +got):\n%s", diff) - } - case <-ctx.Done(): - return fmt.Errorf("%s for %s", ctx.Err(), want) - } - return nil -} - -func (d *arpDispatcher) waitForEventWithTimeout(want eventInfo, timeout time.Duration) error { - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - return d.waitForEvent(ctx, want) -} - func (d *arpDispatcher) nextEvent() (eventInfo, bool) { select { case event := <-d.C: @@ -153,55 +124,45 @@ func (d *arpDispatcher) nextEvent() (eventInfo, bool) { type testContext struct { s *stack.Stack linkEP *channel.Endpoint - nudDisp *arpDispatcher + nudDisp arpDispatcher } -func newTestContext(t *testing.T) *testContext { - c := stack.DefaultNUDConfigurations() - // Transition from Reachable to Stale almost immediately to test if receiving - // probes refreshes positive reachability. - c.BaseReachableTime = time.Microsecond - - d := arpDispatcher{ - // Create an event channel large enough so the neighbor cache doesn't block - // while dispatching events. Blocking could interfere with the timing of - // NUD transitions. - C: make(chan eventInfo, eventChanSize), +func makeTestContext(t *testing.T, eventDepth int, packetDepth int) testContext { + t.Helper() + + tc := testContext{ + nudDisp: arpDispatcher{ + C: make(chan eventInfo, eventDepth), + }, } - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, - NUDConfigs: c, - NUDDisp: &d, + tc.s = stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol}, + NUDDisp: &tc.nudDisp, + Clock: &faketime.NullClock{}, }) - ep := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr) - ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - wep := stack.LinkEndpoint(ep) + tc.linkEP = channel.New(packetDepth, header.IPv4MinimumMTU, stackLinkAddr) + tc.linkEP.LinkEPCapabilities |= stack.CapabilityResolutionRequired + wep := stack.LinkEndpoint(tc.linkEP) if testing.Verbose() { - wep = sniffer.New(ep) + wep = sniffer.New(wep) } - if err := s.CreateNIC(nicID, wep); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + if err := tc.s.CreateNIC(nicID, wep); err != nil { + t.Fatalf("CreateNIC failed: %s", err) } - if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { - t.Fatalf("AddAddress for ipv4 failed: %v", err) + if err := tc.s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { + t.Fatalf("AddAddress for ipv4 failed: %s", err) } - s.SetRouteTable([]tcpip.Route{{ + tc.s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv4EmptySubnet, NIC: nicID, }}) - return &testContext{ - s: s, - linkEP: ep, - nudDisp: &d, - } + return tc } func (c *testContext) cleanup() { @@ -209,7 +170,7 @@ func (c *testContext) cleanup() { } func TestMalformedPacket(t *testing.T) { - c := newTestContext(t) + c := makeTestContext(t, 0, 0) defer c.cleanup() v := make(buffer.View, header.ARPSize) @@ -228,7 +189,7 @@ func TestMalformedPacket(t *testing.T) { } func TestDisabledEndpoint(t *testing.T) { - c := newTestContext(t) + c := makeTestContext(t, 0, 0) defer c.cleanup() ep, err := c.s.GetNetworkEndpoint(nicID, header.ARPProtocolNumber) @@ -253,7 +214,7 @@ func TestDisabledEndpoint(t *testing.T) { } func TestDirectReply(t *testing.T) { - c := newTestContext(t) + c := makeTestContext(t, 0, 0) defer c.cleanup() const senderMAC = "\x01\x02\x03\x04\x05\x06" @@ -284,7 +245,7 @@ func TestDirectReply(t *testing.T) { } func TestDirectRequest(t *testing.T) { - c := newTestContext(t) + c := makeTestContext(t, 1, 1) defer c.cleanup() tests := []struct { @@ -391,17 +352,21 @@ func TestDirectRequest(t *testing.T) { } // Verify the sender was saved in the neighbor cache. - wantEvent := eventInfo{ - eventType: entryAdded, - nicID: nicID, - entry: stack.NeighborEntry{ - Addr: test.senderAddr, - LinkAddr: tcpip.LinkAddress(test.senderLinkAddr), - State: stack.Stale, - }, - } - if err := c.nudDisp.waitForEventWithTimeout(wantEvent, time.Second); err != nil { - t.Fatal(err) + if got, ok := c.nudDisp.nextEvent(); ok { + want := eventInfo{ + eventType: entryAdded, + nicID: nicID, + entry: stack.NeighborEntry{ + Addr: test.senderAddr, + LinkAddr: test.senderLinkAddr, + State: stack.Stale, + }, + } + if diff := cmp.Diff(want, got, cmp.AllowUnexported(eventInfo{}), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAt")); diff != "" { + t.Errorf("got invalid event (-want +got):\n%s", diff) + } + } else { + t.Fatal("event didn't arrive") } neighbors, err := c.s.Neighbors(nicID, ipv4.ProtocolNumber) @@ -589,7 +554,7 @@ func TestLinkAddressRequest(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, }) - linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr) + linkEP := channel.New(1, header.IPv4MinimumMTU, stackLinkAddr) if err := s.CreateNIC(nicID, &testLinkEndpoint{LinkEndpoint: linkEP, writeErr: test.linkErr}); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } @@ -663,15 +628,16 @@ func TestLinkAddressRequest(t *testing.T) { } func TestDADARPRequestPacket(t *testing.T) { + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocolWithOptions(arp.Options{ DADConfigs: stack.DADConfigurations{ DupAddrDetectTransmits: 1, - RetransmitTimer: time.Second, }, }), ipv4.NewProtocol}, + Clock: clock, }) - e := channel.New(1, defaultMTU, stackLinkAddr) + e := channel.New(1, header.IPv4MinimumMTU, stackLinkAddr) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } @@ -682,7 +648,8 @@ func TestDADARPRequestPacket(t *testing.T) { t.Fatalf("got s.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", nicID, header.IPv4ProtocolNumber, remoteAddr, res, stack.DADStarting) } - pkt, ok := e.ReadContext(context.Background()) + clock.RunImmediatelyScheduledJobs() + pkt, ok := e.Read() if !ok { t.Fatal("expected to send an ARP request") } diff --git a/pkg/tcpip/network/internal/fragmentation/fragmentation.go b/pkg/tcpip/network/internal/fragmentation/fragmentation.go index 5168f5361..1ba4d0d36 100644 --- a/pkg/tcpip/network/internal/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/internal/fragmentation/fragmentation.go @@ -251,7 +251,7 @@ func (f *Fragmentation) releaseReassemblersLocked() { // The list is empty. break } - elapsed := time.Duration(now-r.creationTime) * time.Nanosecond + elapsed := now.Sub(r.createdAt) if f.timeout > elapsed { // If the oldest reassembler has not expired, schedule the release // job so that this function is called back when it has expired. diff --git a/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go b/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go index 7daf64b4a..dadfc28cc 100644 --- a/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go +++ b/pkg/tcpip/network/internal/fragmentation/fragmentation_test.go @@ -275,15 +275,23 @@ func TestMemoryLimits(t *testing.T) { highLimit := 3 * lowLimit // Allow at most 3 such packets. f := NewFragmentation(minBlockSize, highLimit, lowLimit, reassembleTimeout, &faketime.NullClock{}, nil) // Send first fragment with id = 0. - f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, pkt(1, "0")) + if _, _, _, err := f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, pkt(1, "0")); err != nil { + t.Fatal(err) + } // Send first fragment with id = 1. - f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, pkt(1, "1")) + if _, _, _, err := f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, pkt(1, "1")); err != nil { + t.Fatal(err) + } // Send first fragment with id = 2. - f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, pkt(1, "2")) + if _, _, _, err := f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, pkt(1, "2")); err != nil { + t.Fatal(err) + } // Send first fragment with id = 3. This should caused id = 0 and id = 1 to be // evicted. - f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, pkt(1, "3")) + if _, _, _, err := f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, pkt(1, "3")); err != nil { + t.Fatal(err) + } if _, ok := f.reassemblers[FragmentID{ID: 0}]; ok { t.Errorf("Memory limits are not respected: id=0 has not been evicted.") @@ -300,9 +308,13 @@ func TestMemoryLimitsIgnoresDuplicates(t *testing.T) { memSize := pkt(1, "0").MemSize() f := NewFragmentation(minBlockSize, memSize, 0, reassembleTimeout, &faketime.NullClock{}, nil) // Send first fragment with id = 0. - f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")) + if _, _, _, err := f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")); err != nil { + t.Fatal(err) + } // Send the same packet again. - f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")) + if _, _, _, err := f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")); err != nil { + t.Fatal(err) + } if got, want := f.memSize, memSize; got != want { t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want) diff --git a/pkg/tcpip/network/internal/fragmentation/reassembler.go b/pkg/tcpip/network/internal/fragmentation/reassembler.go index 56b76a284..5b7e4b361 100644 --- a/pkg/tcpip/network/internal/fragmentation/reassembler.go +++ b/pkg/tcpip/network/internal/fragmentation/reassembler.go @@ -35,21 +35,21 @@ type hole struct { type reassembler struct { reassemblerEntry - id FragmentID - memSize int - proto uint8 - mu sync.Mutex - holes []hole - filled int - done bool - creationTime int64 - pkt *stack.PacketBuffer + id FragmentID + memSize int + proto uint8 + mu sync.Mutex + holes []hole + filled int + done bool + createdAt tcpip.MonotonicTime + pkt *stack.PacketBuffer } func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler { r := &reassembler{ - id: id, - creationTime: clock.NowMonotonic(), + id: id, + createdAt: clock.NowMonotonic(), } r.holes = append(r.holes, hole{ first: 0, diff --git a/pkg/tcpip/network/internal/ip/duplicate_address_detection.go b/pkg/tcpip/network/internal/ip/duplicate_address_detection.go index eed49f5d2..5123b7d6a 100644 --- a/pkg/tcpip/network/internal/ip/duplicate_address_detection.go +++ b/pkg/tcpip/network/internal/ip/duplicate_address_detection.go @@ -83,6 +83,8 @@ func (d *DAD) Init(protocolMU sync.Locker, configs stack.DADConfigurations, opts panic(fmt.Sprintf("given a non-zero value for NonceSize (%d) but zero for ExtendDADTransmits", opts.NonceSize)) } + configs.Validate() + *d = DAD{ opts: opts, configs: configs, diff --git a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go b/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go index a22b712c6..24687cf06 100644 --- a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go +++ b/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go @@ -133,7 +133,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADDisabled) } // Wait for any initially fired timers to complete. - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() if diff := dad.check(nil); diff != "" { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } @@ -147,7 +147,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) { if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting) } - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() if diff := dad.check([]tcpip.Address{addr1}); diff != "" { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } @@ -156,7 +156,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) { if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADAlreadyRunning { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADAlreadyRunning) } - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() if diff := dad.check(nil); diff != "" { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } @@ -170,7 +170,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) { if res := dad.checkDuplicateAddress(addr2, handler(ch, addr2)); res != stack.DADStarting { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting) } - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() if diff := dad.check([]tcpip.Address{addr2}); diff != "" { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } @@ -208,7 +208,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) { if res := dad.checkDuplicateAddress(addr2, handler(ch, addr2)); res != stack.DADStarting { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting) } - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() if diff := dad.check([]tcpip.Address{addr2, addr2}); diff != "" { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } @@ -247,7 +247,7 @@ func TestDADStop(t *testing.T) { if res := dad.checkDuplicateAddress(addr3, handler(ch, addr3)); res != stack.DADStarting { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr2, res, stack.DADStarting) } - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() if diff := dad.check([]tcpip.Address{addr1, addr2, addr3}); diff != "" { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } @@ -272,7 +272,7 @@ func TestDADStop(t *testing.T) { if res := dad.checkDuplicateAddress(addr1, handler(ch, addr1)); res != stack.DADStarting { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting) } - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() if diff := dad.check([]tcpip.Address{addr1}); diff != "" { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } @@ -347,7 +347,7 @@ func TestNonce(t *testing.T) { t.Errorf("got dad.checkDuplicateAddress(%s, _) = %d, want = %d", addr1, res, stack.DADStarting) } - clock.Advance(0) + clock.RunImmediatelyScheduledJobs() for i, want := range test.expectedResults { if got := dad.extendIfNonceEqual(addr1, test.mockedReceivedNonce); got != want { t.Errorf("(i=%d) got dad.extendIfNonceEqual(%s, _) = %d, want = %d", i, addr1, got, want) diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go index d22974b12..671dfbf32 100644 --- a/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go +++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go @@ -611,7 +611,7 @@ func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddr // If a timer for any address is already running, it is reset to the new // random value only if the requested Maximum Response Delay is less than // the remaining value of the running timer. - now := time.Unix(0 /* seconds */, g.opts.Clock.NowNanoseconds()) + now := g.opts.Clock.Now() if info.state == delayingMember { if info.delayedReportJobFiresAt.IsZero() { panic(fmt.Sprintf("delayed report unscheduled while in the delaying member state; group = %s", groupAddress)) diff --git a/pkg/tcpip/network/internal/ip/stats.go b/pkg/tcpip/network/internal/ip/stats.go index 0c2b62127..40ab21cb6 100644 --- a/pkg/tcpip/network/internal/ip/stats.go +++ b/pkg/tcpip/network/internal/ip/stats.go @@ -42,6 +42,10 @@ type MultiCounterIPForwardingStats struct { // were too big for the outgoing MTU. PacketTooBig tcpip.MultiCounterStat + // HostUnreachable is the number of IP packets received which could not be + // successfully forwarded due to an unresolvable next hop. + HostUnreachable tcpip.MultiCounterStat + // ExtensionHeaderProblem is the number of IP packets which were dropped // because of a problem encountered when processing an IPv6 extension // header. @@ -61,6 +65,7 @@ func (m *MultiCounterIPForwardingStats) Init(a, b *tcpip.IPForwardingStats) { m.ExtensionHeaderProblem.Init(a.ExtensionHeaderProblem, b.ExtensionHeaderProblem) m.PacketTooBig.Init(a.PacketTooBig, b.PacketTooBig) m.ExhaustedTTL.Init(a.ExhaustedTTL, b.ExhaustedTTL) + m.HostUnreachable.Init(a.HostUnreachable, b.HostUnreachable) } // LINT.ThenChange(:MultiCounterIPForwardingStats, ../../../tcpip.go:IPForwardingStats) diff --git a/pkg/tcpip/network/internal/testutil/BUILD b/pkg/tcpip/network/internal/testutil/BUILD index cec3e62c4..a180e5c75 100644 --- a/pkg/tcpip/network/internal/testutil/BUILD +++ b/pkg/tcpip/network/internal/testutil/BUILD @@ -10,6 +10,7 @@ go_library( "//pkg/tcpip/network/internal/fragmentation:__pkg__", "//pkg/tcpip/network/ipv4:__pkg__", "//pkg/tcpip/network/ipv6:__pkg__", + "//pkg/tcpip/tests/integration:__pkg__", ], deps = [ "//pkg/tcpip", diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index d1a82b584..2aa38eb98 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -173,9 +173,8 @@ func (e *endpoint) handleControl(errInfo stack.TransportError, pkt *stack.Packet func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { received := e.stats.icmp.packetsReceived - // TODO(gvisor.dev/issue/170): ICMP packets don't have their - // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a - // full explanation. + // ICMP packets don't have their TransportHeader fields set. See + // icmp/protocol.go:protocol.Parse for a full explanation. v, ok := pkt.Data().PullUp(header.ICMPv4MinimumSize) if !ok { received.invalid.Increment() @@ -222,7 +221,6 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { _ = e.protocol.returnError(&icmpReasonParamProblem{ pointer: optProblem.Pointer, }, pkt) - e.protocol.stack.Stats().MalformedRcvdPackets.Increment() e.stats.ip.MalformedPacketsReceived.Increment() } return @@ -481,6 +479,22 @@ func (*icmpReasonFragmentationNeeded) isForwarding() bool { return true } +// icmpReasonHostUnreachable is an error in which the host specified in the +// internet destination field of the datagram is unreachable. +type icmpReasonHostUnreachable struct{} + +func (*icmpReasonHostUnreachable) isICMPReason() {} +func (*icmpReasonHostUnreachable) isForwarding() bool { + // If we hit a Host Unreachable error, then we know we are operating as a + // router. As per RFC 792 page 5, Destination Unreachable Message, + // + // In addition, in some networks, the gateway may be able to determine + // if the internet destination host is unreachable. Gateways in these + // networks may send destination unreachable messages to the source host + // when the destination host is unreachable. + return true +} + // returnError takes an error descriptor and generates the appropriate ICMP // error packet for IPv4 and sends it back to the remote device that sent // the problematic packet. It incorporates as much of that packet as @@ -537,7 +551,12 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip defer route.Release() p.mu.Lock() - netEP, ok := p.mu.eps[pkt.NICID] + // We retrieve an endpoint using the newly constructed route's NICID rather + // than the packet's NICID. The packet's NICID corresponds to the NIC on + // which it arrived, which isn't necessarily the same as the NIC on which it + // will be transmitted. On the other hand, the route's NIC *is* guaranteed + // to be the NIC on which the packet will be transmitted. + netEP, ok := p.mu.eps[route.NICID()] p.mu.Unlock() if !ok { return &tcpip.ErrNotConnected{} @@ -653,6 +672,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip icmpHdr.SetType(header.ICMPv4DstUnreachable) icmpHdr.SetCode(header.ICMPv4NetUnreachable) counter = sent.dstUnreachable + case *icmpReasonHostUnreachable: + icmpHdr.SetType(header.ICMPv4DstUnreachable) + icmpHdr.SetCode(header.ICMPv4HostUnreachable) + counter = sent.dstUnreachable case *icmpReasonFragmentationNeeded: icmpHdr.SetType(header.ICMPv4DstUnreachable) icmpHdr.SetCode(header.ICMPv4FragmentationNeeded) diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 049811cbb..f08b008ac 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -63,9 +63,15 @@ const ( fragmentblockSize = 8 ) +const ( + forwardingDisabled = 0 + forwardingEnabled = 1 +) + var ipv4BroadcastAddr = header.IPv4Broadcast.WithPrefix() var _ stack.LinkResolvableNetworkEndpoint = (*endpoint)(nil) +var _ stack.ForwardingNetworkEndpoint = (*endpoint)(nil) var _ stack.GroupAddressableEndpoint = (*endpoint)(nil) var _ stack.AddressableEndpoint = (*endpoint)(nil) var _ stack.NetworkEndpoint = (*endpoint)(nil) @@ -76,12 +82,18 @@ type endpoint struct { protocol *protocol stats sharedStats - // enabled is set to 1 when the enpoint is enabled and 0 when it is + // enabled is set to 1 when the endpoint is enabled and 0 when it is // disabled. // // Must be accessed using atomic operations. enabled uint32 + // forwarding is set to forwardingEnabled when the endpoint has forwarding + // enabled and forwardingDisabled when it is disabled. + // + // Must be accessed using atomic operations. + forwarding uint32 + mu struct { sync.RWMutex @@ -92,6 +104,16 @@ type endpoint struct { // HandleLinkResolutionFailure implements stack.LinkResolvableNetworkEndpoint. func (e *endpoint) HandleLinkResolutionFailure(pkt *stack.PacketBuffer) { + // If we are operating as a router, return an ICMP error to the original + // packet's sender. + if pkt.NetworkPacketInfo.IsForwardedPacket { + // TODO(gvisor.dev/issue/6005): Propagate asynchronously generated ICMP + // errors to local endpoints. + e.protocol.returnError(&icmpReasonHostUnreachable{}, pkt) + e.stats.ip.Forwarding.Errors.Increment() + e.stats.ip.Forwarding.HostUnreachable.Increment() + return + } // handleControl expects the entire offending packet to be in the packet // buffer's data field. pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -151,14 +173,32 @@ func (p *protocol) forgetEndpoint(nicID tcpip.NICID) { delete(p.mu.eps, nicID) } -// transitionForwarding transitions the endpoint's forwarding status to -// forwarding. +// Forwarding implements stack.ForwardingNetworkEndpoint. +func (e *endpoint) Forwarding() bool { + return atomic.LoadUint32(&e.forwarding) == forwardingEnabled +} + +// setForwarding sets the forwarding status for the endpoint. // -// Must only be called when the forwarding status changes. -func (e *endpoint) transitionForwarding(forwarding bool) { +// Returns true if the forwarding status was updated. +func (e *endpoint) setForwarding(v bool) bool { + forwarding := uint32(forwardingDisabled) + if v { + forwarding = forwardingEnabled + } + + return atomic.SwapUint32(&e.forwarding, forwarding) != forwarding +} + +// SetForwarding implements stack.ForwardingNetworkEndpoint. +func (e *endpoint) SetForwarding(forwarding bool) { e.mu.Lock() defer e.mu.Unlock() + if !e.setForwarding(forwarding) { + return + } + if forwarding { // There does not seem to be an RFC requirement for a node to join the all // routers multicast address but @@ -292,8 +332,8 @@ func (e *endpoint) DefaultTTL() uint8 { return e.protocol.DefaultTTL() } -// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus -// the network layer max header length. +// MTU implements stack.NetworkEndpoint. It returns the link-layer MTU minus the +// network layer max header length. func (e *endpoint) MTU() uint32 { networkMTU, err := calculateNetworkMTU(e.nic.MTU(), header.IPv4MinimumSize) if err != nil { @@ -308,7 +348,7 @@ func (e *endpoint) MaxHeaderLength() uint16 { return e.nic.MaxHeaderLength() + header.IPv4MaximumHeaderSize } -// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber. +// NetworkProtocolNumber implements stack.NetworkEndpoint. func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return e.protocol.Number() } @@ -323,7 +363,7 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet if hdrLen > header.IPv4MaximumHeaderSize { return &tcpip.ErrMessageTooLong{} } - ip := header.IPv4(pkt.NetworkHeader().Push(hdrLen)) + ipH := header.IPv4(pkt.NetworkHeader().Push(hdrLen)) length := pkt.Size() if length > math.MaxUint16 { return &tcpip.ErrMessageTooLong{} @@ -332,7 +372,7 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet // datagrams. Since the DF bit is never being set here, all datagrams // are non-atomic and need an ID. id := atomic.AddUint32(&e.protocol.ids[hashRoute(srcAddr, dstAddr, params.Protocol, e.protocol.hashIV)%buckets], 1) - ip.Encode(&header.IPv4Fields{ + ipH.Encode(&header.IPv4Fields{ TotalLength: uint16(length), ID: uint16(id), TTL: params.TTL, @@ -342,7 +382,7 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet DstAddr: dstAddr, Options: options, }) - ip.SetChecksum(^ip.CalculateChecksum()) + ipH.SetChecksum(^ipH.CalculateChecksum()) pkt.NetworkProtocolNumber = ProtocolNumber return nil } @@ -351,7 +391,7 @@ func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.Packet // fragment. It returns the number of fragments handled and the number of // fragments left to be processed. The IP header must already be present in the // original packet. -func (e *endpoint) handleFragments(r *stack.Route, networkMTU uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) tcpip.Error) (int, int, tcpip.Error) { +func (e *endpoint) handleFragments(_ *stack.Route, networkMTU uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) tcpip.Error) (int, int, tcpip.Error) { // Round the MTU down to align to 8 bytes. fragmentPayloadSize := networkMTU &^ 7 networkHeader := header.IPv4(pkt.NetworkHeader().View()) @@ -389,9 +429,9 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams, // based on destination address and do not send the packet to link // layer. // - // TODO(gvisor.dev/issue/170): We should do this for every - // packet, rather than only NATted packets, but removing this check - // short circuits broadcasts before they are sent out to other hosts. + // We should do this for every packet, rather than only NATted packets, but + // removing this check short circuits broadcasts before they are sent out to + // other hosts. if pkt.NatDone { netHeader := header.IPv4(pkt.NetworkHeader().View()) if ep := e.protocol.findEndpointWithAddress(netHeader.DestinationAddress()); ep != nil { @@ -460,7 +500,7 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, headerIn return nil } -// WritePackets implements stack.NetworkEndpoint.WritePackets. +// WritePackets implements stack.NetworkEndpoint. func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) { if r.Loop()&stack.PacketLoop != 0 { panic("multiple packets in local loop") @@ -563,34 +603,34 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu if !ok { return &tcpip.ErrMalformedHeader{} } - ip := header.IPv4(h) + ipH := header.IPv4(h) // Always set the total length. pktSize := pkt.Data().Size() - ip.SetTotalLength(uint16(pktSize)) + ipH.SetTotalLength(uint16(pktSize)) // Set the source address when zero. - if ip.SourceAddress() == header.IPv4Any { - ip.SetSourceAddress(r.LocalAddress()) + if ipH.SourceAddress() == header.IPv4Any { + ipH.SetSourceAddress(r.LocalAddress()) } // Set the destination. If the packet already included a destination, it will // be part of the route anyways. - ip.SetDestinationAddress(r.RemoteAddress()) + ipH.SetDestinationAddress(r.RemoteAddress()) // Set the packet ID when zero. - if ip.ID() == 0 { + if ipH.ID() == 0 { // RFC 6864 section 4.3 mandates uniqueness of ID values for // non-atomic datagrams, so assign an ID to all such datagrams // according to the definition given in RFC 6864 section 4. - if ip.Flags()&header.IPv4FlagDontFragment == 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 || ip.FragmentOffset() > 0 { - ip.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r.LocalAddress(), r.RemoteAddress(), 0 /* protocol */, e.protocol.hashIV)%buckets], 1))) + if ipH.Flags()&header.IPv4FlagDontFragment == 0 || ipH.Flags()&header.IPv4FlagMoreFragments != 0 || ipH.FragmentOffset() > 0 { + ipH.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r.LocalAddress(), r.RemoteAddress(), 0 /* protocol */, e.protocol.hashIV)%buckets], 1))) } } // Always set the checksum. - ip.SetChecksum(0) - ip.SetChecksum(^ip.CalculateChecksum()) + ipH.SetChecksum(0) + ipH.SetChecksum(^ipH.CalculateChecksum()) // Populate the packet buffer's network header and don't allow an invalid // packet to be sent. @@ -680,7 +720,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { return nil } - ep.handleValidatedPacket(h, pkt) + // The packet originally arrived on e so provide its NIC as the input NIC. + ep.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */) return nil } @@ -796,7 +837,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { } } - e.handleValidatedPacket(h, pkt) + e.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */) } // handleLocalPacket is like HandlePacket except it does not perform the @@ -815,10 +856,10 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum return } - e.handleValidatedPacket(h, pkt) + e.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */) } -func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) { +func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer, inNICName string) { pkt.NICID = e.nic.ID() stats := e.stats stats.ip.ValidPacketsReceived.Increment() @@ -852,7 +893,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) addressEndpoint.DecRef() pkt.NetworkPacketInfo.LocalAddressBroadcast = subnet.IsBroadcast(dstAddr) || dstAddr == header.IPv4Broadcast } else if !e.IsInGroup(dstAddr) { - if !e.protocol.Forwarding() { + if !e.Forwarding() { stats.ip.InvalidDestinationAddressesReceived.Increment() return } @@ -868,7 +909,6 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) case *ip.ErrNoRoute: stats.ip.Forwarding.Unrouteable.Increment() case *ip.ErrParameterProblem: - e.protocol.stack.Stats().MalformedRcvdPackets.Increment() stats.ip.MalformedPacketsReceived.Increment() case *ip.ErrMessageTooLong: stats.ip.Forwarding.PacketTooBig.Increment() @@ -881,8 +921,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. - inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNICName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. stats.ip.IPTablesInputDropped.Increment() return @@ -905,7 +944,6 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) _ = e.protocol.returnError(&icmpReasonParamProblem{ pointer: optProblem.Pointer, }, pkt) - e.protocol.stack.Stats().MalformedRcvdPackets.Increment() e.stats.ip.MalformedPacketsReceived.Increment() } return @@ -955,7 +993,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) // The reassembler doesn't take care of fixing up the header, so we need // to do it here. - h.SetTotalLength(uint16(pkt.Data().Size() + len((h)))) + h.SetTotalLength(uint16(pkt.Data().Size() + len(h))) h.SetFlagsFragmentOffset(0, 0) } stats.ip.PacketsDelivered.Increment() @@ -978,7 +1016,6 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) _ = e.protocol.returnError(&icmpReasonParamProblem{ pointer: optProblem.Pointer, }, pkt) - e.protocol.stack.Stats().MalformedRcvdPackets.Increment() stats.ip.MalformedPacketsReceived.Increment() } return @@ -1144,7 +1181,6 @@ func (e *endpoint) Stats() stack.NetworkEndpointStats { return &e.stats.localStats } -var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) var _ stack.NetworkProtocol = (*protocol)(nil) var _ fragmentation.TimeoutHandler = (*protocol)(nil) @@ -1165,12 +1201,6 @@ type protocol struct { // Must be accessed using atomic operations. defaultTTL uint32 - // forwarding is set to 1 when the protocol has forwarding enabled and 0 - // when it is disabled. - // - // Must be accessed using atomic operations. - forwarding uint32 - ids []uint32 hashIV uint32 @@ -1194,13 +1224,13 @@ func (p *protocol) DefaultPrefixLen() int { return header.IPv4AddressSize * 8 } -// ParseAddresses implements NetworkProtocol.ParseAddresses. +// ParseAddresses implements stack.NetworkProtocol. func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { h := header.IPv4(v) return h.SourceAddress(), h.DestinationAddress() } -// SetOption implements NetworkProtocol.SetOption. +// SetOption implements stack.NetworkProtocol. func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.DefaultTTLOption: @@ -1211,7 +1241,7 @@ func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.E } } -// Option implements NetworkProtocol.Option. +// Option implements stack.NetworkProtocol. func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.DefaultTTLOption: @@ -1232,10 +1262,10 @@ func (p *protocol) DefaultTTL() uint8 { return uint8(atomic.LoadUint32(&p.defaultTTL)) } -// Close implements stack.TransportProtocol.Close. +// Close implements stack.TransportProtocol. func (*protocol) Close() {} -// Wait implements stack.TransportProtocol.Wait. +// Wait implements stack.TransportProtocol. func (*protocol) Wait() {} // parseAndValidate parses the packet (including its transport layer header) and @@ -1273,7 +1303,7 @@ func (p *protocol) parseAndValidate(pkt *stack.PacketBuffer) (header.IPv4, bool) return h, true } -// Parse implements stack.NetworkProtocol.Parse. +// Parse implements stack.NetworkProtocol. func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { if ok := parse.IPv4(pkt); !ok { return 0, false, false @@ -1283,35 +1313,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu return ipHdr.TransportProtocol(), !ipHdr.More() && ipHdr.FragmentOffset() == 0, true } -// Forwarding implements stack.ForwardingNetworkProtocol. -func (p *protocol) Forwarding() bool { - return uint8(atomic.LoadUint32(&p.forwarding)) == 1 -} - -// setForwarding sets the forwarding status for the protocol. -// -// Returns true if the forwarding status was updated. -func (p *protocol) setForwarding(v bool) bool { - if v { - return atomic.CompareAndSwapUint32(&p.forwarding, 0 /* old */, 1 /* new */) - } - return atomic.CompareAndSwapUint32(&p.forwarding, 1 /* old */, 0 /* new */) -} - -// SetForwarding implements stack.ForwardingNetworkProtocol. -func (p *protocol) SetForwarding(v bool) { - p.mu.Lock() - defer p.mu.Unlock() - - if !p.setForwarding(v) { - return - } - - for _, ep := range p.mu.eps { - ep.transitionForwarding(v) - } -} - // calculateNetworkMTU calculates the network-layer payload MTU based on the // link-layer payload mtu. func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, tcpip.Error) { @@ -1332,7 +1333,7 @@ func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, tcpip.Error networkMTU = MaxTotalSize } - return networkMTU - uint32(networkHeaderSize), nil + return networkMTU - networkHeaderSize, nil } func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32) bool { @@ -1741,9 +1742,8 @@ type optionTracker struct { // // If there were no errors during parsing, the new set of options is returned as // a new buffer. -func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Options, usage optionsUsage) (header.IPv4Options, optionTracker, *header.IPv4OptParameterProblem) { +func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, opts header.IPv4Options, usage optionsUsage) (header.IPv4Options, optionTracker, *header.IPv4OptParameterProblem) { stats := e.stats.ip - opts := header.IPv4Options(orig) optIter := opts.MakeIterator() // Except NOP, each option must only appear at most once (RFC 791 section 3.1, diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 5f45b9ee6..4a4448cf9 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -16,7 +16,6 @@ package ipv4_test import ( "bytes" - "context" "encoding/hex" "fmt" "io/ioutil" @@ -36,10 +35,10 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/network/arp" - "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" + iptestutil "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/stack" - tcptestutil "gvisor.dev/gvisor/pkg/tcpip/testutil" + "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/raw" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" @@ -112,33 +111,29 @@ func TestExcludeBroadcast(t *testing.T) { }) } -type forwardedPacket struct { - fragments []fragmentInfo -} - func TestForwarding(t *testing.T) { const ( - nicID1 = 1 - nicID2 = 2 + incomingNICID = 1 + outgoingNICID = 2 randomSequence = 123 randomIdent = 42 randomTimeOffset = 0x10203040 ) - ipv4Addr1 := tcpip.AddressWithPrefix{ + incomingIPv4Addr := tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()), PrefixLen: 8, } - ipv4Addr2 := tcpip.AddressWithPrefix{ + outgoingIPv4Addr := tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("11.0.0.1").To4()), PrefixLen: 8, } - linkAddr2 := tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") - remoteIPv4Addr1 := tcpip.Address(net.ParseIP("10.0.0.2").To4()) - remoteIPv4Addr2 := tcpip.Address(net.ParseIP("11.0.0.2").To4()) - unreachableIPv4Addr := tcpip.Address(net.ParseIP("12.0.0.2").To4()) - multicastIPv4Addr := tcpip.Address(net.ParseIP("225.0.0.0").To4()) - linkLocalIPv4Addr := tcpip.Address(net.ParseIP("169.254.0.0").To4()) + outgoingLinkAddr := tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") + remoteIPv4Addr1 := testutil.MustParse4("10.0.0.2") + remoteIPv4Addr2 := testutil.MustParse4("11.0.0.2") + unreachableIPv4Addr := testutil.MustParse4("12.0.0.2") + multicastIPv4Addr := testutil.MustParse4("225.0.0.0") + linkLocalIPv4Addr := testutil.MustParse4("169.254.0.0") tests := []struct { name string @@ -345,6 +340,7 @@ func TestForwarding(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { clock := faketime.NewManualClock() + s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, @@ -356,36 +352,36 @@ func TestForwarding(t *testing.T) { clock.Advance(time.Millisecond * randomTimeOffset) // We expect at most a single packet in response to our ICMP Echo Request. - e1 := channel.New(1, test.mtu, "") - if err := s.CreateNIC(nicID1, e1); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + incomingEndpoint := channel.New(1, test.mtu, "") + if err := s.CreateNIC(incomingNICID, incomingEndpoint); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err) } - ipv4ProtoAddr1 := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr1} - if err := s.AddProtocolAddress(nicID1, ipv4ProtoAddr1); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID1, ipv4ProtoAddr1, err) + incomingIPv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: incomingIPv4Addr} + if err := s.AddProtocolAddress(incomingNICID, incomingIPv4ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingIPv4ProtoAddr, err) } expectedEmittedPacketCount := 1 if len(test.expectedFragmentsForwarded) > expectedEmittedPacketCount { expectedEmittedPacketCount = len(test.expectedFragmentsForwarded) } - e2 := channel.New(expectedEmittedPacketCount, test.mtu, linkAddr2) - if err := s.CreateNIC(nicID2, e2); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) + outgoingEndpoint := channel.New(expectedEmittedPacketCount, test.mtu, outgoingLinkAddr) + if err := s.CreateNIC(outgoingNICID, outgoingEndpoint); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err) } - ipv4ProtoAddr2 := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr2} - if err := s.AddProtocolAddress(nicID2, ipv4ProtoAddr2); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID2, ipv4ProtoAddr2, err) + outgoingIPv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: outgoingIPv4Addr} + if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv4ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingIPv4ProtoAddr, err) } s.SetRouteTable([]tcpip.Route{ { - Destination: ipv4Addr1.Subnet(), - NIC: nicID1, + Destination: incomingIPv4Addr.Subnet(), + NIC: incomingNICID, }, { - Destination: ipv4Addr2.Subnet(), - NIC: nicID2, + Destination: outgoingIPv4Addr.Subnet(), + NIC: outgoingNICID, }, }) @@ -401,13 +397,13 @@ func TestForwarding(t *testing.T) { totalLength := ipHeaderLength + icmpHeaderLength + test.payloadLength hdr := buffer.NewPrependable(totalLength) hdr.Prepend(test.payloadLength) - icmp := header.ICMPv4(hdr.Prepend(icmpHeaderLength)) - icmp.SetIdent(randomIdent) - icmp.SetSequence(randomSequence) - icmp.SetType(header.ICMPv4Echo) - icmp.SetCode(header.ICMPv4UnusedCode) - icmp.SetChecksum(0) - icmp.SetChecksum(^header.Checksum(icmp, 0)) + icmpH := header.ICMPv4(hdr.Prepend(icmpHeaderLength)) + icmpH.SetIdent(randomIdent) + icmpH.SetSequence(randomSequence) + icmpH.SetType(header.ICMPv4Echo) + icmpH.SetCode(header.ICMPv4UnusedCode) + icmpH.SetChecksum(0) + icmpH.SetChecksum(^header.Checksum(icmpH, 0)) ip := header.IPv4(hdr.Prepend(ipHeaderLength)) ip.Encode(&header.IPv4Fields{ TotalLength: uint16(totalLength), @@ -431,9 +427,10 @@ func TestForwarding(t *testing.T) { Data: hdr.View().ToVectorisedView(), }) requestPkt.NetworkProtocolNumber = header.IPv4ProtocolNumber - e1.InjectInbound(header.IPv4ProtocolNumber, requestPkt) + incomingEndpoint.InjectInbound(header.IPv4ProtocolNumber, requestPkt) + + reply, ok := incomingEndpoint.Read() - reply, ok := e1.Read() if test.expectErrorICMP { if !ok { t.Fatalf("expected ICMP packet type %d through incoming NIC", test.icmpType) @@ -451,15 +448,15 @@ func TestForwarding(t *testing.T) { return len(hdr.View()) } - checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())), - checker.SrcAddr(ipv4Addr1.Address), + checker.IPv4(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), + checker.SrcAddr(incomingIPv4Addr.Address), checker.DstAddr(test.sourceAddr), checker.TTL(ipv4.DefaultTTL), checker.ICMPv4( checker.ICMPv4Checksum(), checker.ICMPv4Type(test.icmpType), checker.ICMPv4Code(test.icmpCode), - checker.ICMPv4Payload([]byte(hdr.View()[0:expectedICMPPayloadLength()])), + checker.ICMPv4Payload(hdr.View()[:expectedICMPPayloadLength()]), ), ) } else if ok { @@ -468,9 +465,9 @@ func TestForwarding(t *testing.T) { if test.expectPacketForwarded { if len(test.expectedFragmentsForwarded) != 0 { - fragmentedPackets := []*stack.PacketBuffer{} + var fragmentedPackets []*stack.PacketBuffer for i := 0; i < len(test.expectedFragmentsForwarded); i++ { - reply, ok = e2.Read() + reply, ok = outgoingEndpoint.Read() if !ok { t.Fatal("expected ICMP Echo fragment through outgoing NIC") } @@ -485,16 +482,16 @@ func TestForwarding(t *testing.T) { // maximum IP header size and the maximum size allocated for link layer // headers. In this case, no size is allocated for link layer headers. expectedAvailableHeaderBytes := header.IPv4MaximumHeaderSize - if err := compareFragments(fragmentedPackets, requestPkt, uint32(test.mtu), test.expectedFragmentsForwarded, header.ICMPv4ProtocolNumber, true /* withIPHeader */, expectedAvailableHeaderBytes); err != nil { + if err := compareFragments(fragmentedPackets, requestPkt, test.mtu, test.expectedFragmentsForwarded, header.ICMPv4ProtocolNumber, true /* withIPHeader */, expectedAvailableHeaderBytes); err != nil { t.Error(err) } } else { - reply, ok = e2.Read() + reply, ok = outgoingEndpoint.Read() if !ok { t.Fatal("expected ICMP Echo packet through outgoing NIC") } - checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())), + checker.IPv4(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), checker.SrcAddr(test.sourceAddr), checker.DstAddr(test.destAddr), checker.TTL(test.TTL-1), @@ -508,11 +505,10 @@ func TestForwarding(t *testing.T) { ) } } else { - if reply, ok = e2.Read(); ok { + if reply, ok = outgoingEndpoint.Read(); ok { t.Fatalf("expected no ICMP Echo packet through outgoing NIC, instead found: %#v", reply) } } - boolToInt := func(val bool) uint64 { if val { return 1 @@ -1211,15 +1207,15 @@ func TestIPv4Sanity(t *testing.T) { } totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize) hdr := buffer.NewPrependable(int(totalLen)) - icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + icmpH := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) // Specify ident/seq to make sure we get the same in the response. - icmp.SetIdent(randomIdent) - icmp.SetSequence(randomSequence) - icmp.SetType(header.ICMPv4Echo) - icmp.SetCode(header.ICMPv4UnusedCode) - icmp.SetChecksum(0) - icmp.SetChecksum(^header.Checksum(icmp, 0)) + icmpH.SetIdent(randomIdent) + icmpH.SetSequence(randomSequence) + icmpH.SetType(header.ICMPv4Echo) + icmpH.SetCode(header.ICMPv4UnusedCode) + icmpH.SetChecksum(0) + icmpH.SetChecksum(^header.Checksum(icmpH, 0)) ip := header.IPv4(hdr.Prepend(ipHeaderLength)) if test.maxTotalLength < totalLen { totalLen = test.maxTotalLength @@ -1314,7 +1310,7 @@ func TestIPv4Sanity(t *testing.T) { checker.ICMPv4Type(test.ICMPType), checker.ICMPv4Code(test.ICMPCode), checker.ICMPv4Pointer(test.paramProblemPointer), - checker.ICMPv4Payload([]byte(hdr.View())), + checker.ICMPv4Payload(hdr.View()), ), ) return @@ -1333,7 +1329,7 @@ func TestIPv4Sanity(t *testing.T) { checker.ICMPv4( checker.ICMPv4Type(test.ICMPType), checker.ICMPv4Code(test.ICMPCode), - checker.ICMPv4Payload([]byte(hdr.View())), + checker.ICMPv4Payload(hdr.View()), ), ) return @@ -1545,9 +1541,9 @@ func TestFragmentationWritePacket(t *testing.T) { for _, ft := range fragmentationTests { t.Run(ft.description, func(t *testing.T) { - ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) + ep := iptestutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) r := buildRoute(t, ep) - pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) + pkt := iptestutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) source := pkt.Clone() err := r.WritePacket(stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, @@ -1601,7 +1597,7 @@ func TestFragmentationWritePackets(t *testing.T) { insertAfter: 1, }, } - tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv4MinimumSize, []int{1}, header.IPv4ProtocolNumber) + tinyPacket := iptestutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv4MinimumSize, []int{1}, header.IPv4ProtocolNumber) for _, test := range writePacketsTests { t.Run(test.description, func(t *testing.T) { @@ -1611,13 +1607,13 @@ func TestFragmentationWritePackets(t *testing.T) { for i := 0; i < test.insertBefore; i++ { pkts.PushBack(tinyPacket.Clone()) } - pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) + pkt := iptestutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) pkts.PushBack(pkt.Clone()) for i := 0; i < test.insertAfter; i++ { pkts.PushBack(tinyPacket.Clone()) } - ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) + ep := iptestutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) r := buildRoute(t, ep) wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter @@ -1725,8 +1721,8 @@ func TestFragmentationErrors(t *testing.T) { for _, ft := range tests { t.Run(ft.description, func(t *testing.T) { - pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) - ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) + pkt := iptestutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) + ep := iptestutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) r := buildRoute(t, ep) err := r.WritePacket(stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, @@ -2300,7 +2296,7 @@ func TestFragmentReassemblyTimeout(t *testing.T) { checker.ICMPv4Type(header.ICMPv4TimeExceeded), checker.ICMPv4Code(header.ICMPv4ReassemblyTimeout), checker.ICMPv4Checksum(), - checker.ICMPv4Payload([]byte(firstFragmentSent)), + checker.ICMPv4Payload(firstFragmentSent), ), ) }) @@ -2704,7 +2700,7 @@ func TestReceiveFragments(t *testing.T) { TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, RawFactory: raw.EndpointFactory{}, }) - e := channel.New(0, 1280, tcpip.LinkAddress("\xf0\x00")) + e := channel.New(0, 1280, "\xf0\x00") if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -2947,7 +2943,7 @@ func TestWriteStats(t *testing.T) { t.Run(writer.name, func(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets) + ep := iptestutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, &tcpip.ErrInvalidEndpointState{}, test.allowPackets) rt := buildRoute(t, ep) var pkts stack.PacketBufferList @@ -3034,7 +3030,7 @@ func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string, string) return false, false } -func TestPacketQueing(t *testing.T) { +func TestPacketQueuing(t *testing.T) { const nicID = 1 var ( @@ -3073,7 +3069,7 @@ func TestPacketQueing(t *testing.T) { Length: header.UDPMinimumSize, }) sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv4Addr.AddressWithPrefix.Address, host1IPv4Addr.AddressWithPrefix.Address, header.UDPMinimumSize) - sum = header.Checksum(header.UDP([]byte{}), sum) + sum = header.Checksum(nil, sum) u.SetChecksum(^u.CalculateChecksum(sum)) ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) ip.Encode(&header.IPv4Fields{ @@ -3089,7 +3085,7 @@ func TestPacketQueing(t *testing.T) { })) }, checkResp: func(t *testing.T, e *channel.Endpoint) { - p, ok := e.ReadContext(context.Background()) + p, ok := e.Read() if !ok { t.Fatalf("timed out waiting for packet") } @@ -3132,7 +3128,7 @@ func TestPacketQueing(t *testing.T) { })) }, checkResp: func(t *testing.T, e *channel.Endpoint) { - p, ok := e.ReadContext(context.Background()) + p, ok := e.Read() if !ok { t.Fatalf("timed out waiting for packet") } @@ -3156,9 +3152,11 @@ func TestPacketQueing(t *testing.T) { t.Run(test.name, func(t *testing.T) { e := channel.New(1, defaultMTU, host1NICLinkAddr) e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + Clock: clock, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -3181,7 +3179,8 @@ func TestPacketQueing(t *testing.T) { // Wait for a ARP request since link address resolution should be // performed. { - p, ok := e.ReadContext(context.Background()) + clock.RunImmediatelyScheduledJobs() + p, ok := e.Read() if !ok { t.Fatalf("timed out waiting for packet") } @@ -3222,6 +3221,7 @@ func TestPacketQueing(t *testing.T) { } // Expect the response now that the link address has resolved. + clock.RunImmediatelyScheduledJobs() test.checkResp(t, e) // Since link resolution was already performed, it shouldn't be performed @@ -3243,8 +3243,8 @@ func TestCloseLocking(t *testing.T) { ) var ( - src = tcptestutil.MustParse4("16.0.0.1") - dst = tcptestutil.MustParse4("16.0.0.2") + src = testutil.MustParse4("16.0.0.1") + dst = testutil.MustParse4("16.0.0.2") ) s := stack.New(stack.Options{ @@ -3252,7 +3252,7 @@ func TestCloseLocking(t *testing.T) { TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) - // Perform NAT so that the endoint tries to search for a sibling endpoint + // Perform NAT so that the endpoint tries to search for a sibling endpoint // which ends up taking the protocol and endpoint lock (in that order). table := stack.Table{ Rules: []stack.Rule{ diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 4051fda07..94caaae6c 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -285,8 +285,8 @@ func isMLDValid(pkt *stack.PacketBuffer, iph header.IPv6, routerAlert *header.IP func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, routerAlert *header.IPv6RouterAlertOption) { sent := e.stats.icmp.packetsSent received := e.stats.icmp.packetsReceived - // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader - // fields set. See icmp/protocol.go:protocol.Parse for a full explanation. + // ICMP packets don't have their TransportHeader fields set. See + // icmp/protocol.go:protocol.Parse for a full explanation. v, ok := pkt.Data().PullUp(header.ICMPv6HeaderSize) if !ok { received.invalid.Increment() @@ -745,11 +745,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r return } - stack := e.protocol.stack - - // Is the networking stack operating as a router? - if !stack.Forwarding(ProtocolNumber) { - // ... No, silently drop the packet. + if !e.Forwarding() { received.routerOnlyPacketsDroppedByHost.Increment() return } @@ -1033,6 +1029,26 @@ func (*icmpReasonNetUnreachable) respondsToMulticast() bool { return false } +// icmpReasonHostUnreachable is an error in which the host specified in the +// internet destination field of the datagram is unreachable. +type icmpReasonHostUnreachable struct{} + +func (*icmpReasonHostUnreachable) isICMPReason() {} +func (*icmpReasonHostUnreachable) isForwarding() bool { + // If we hit a Host Unreachable error, then we know we are operating as a + // router. As per RFC 4443 page 8, Destination Unreachable Message, + // + // If the reason for the failure to deliver cannot be mapped to any of + // other codes, the Code field is set to 3. Example of such cases are + // an inability to resolve the IPv6 destination address into a + // corresponding link address, or a link-specific problem of some sort. + return true +} + +func (*icmpReasonHostUnreachable) respondsToMulticast() bool { + return false +} + // icmpReasonFragmentationNeeded is an error where a packet is to big to be sent // out through the outgoing MTU, as per RFC 4443 page 9, Packet Too Big Message. type icmpReasonPacketTooBig struct{} @@ -1147,7 +1163,12 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip defer route.Release() p.mu.Lock() - netEP, ok := p.mu.eps[pkt.NICID] + // We retrieve an endpoint using the newly constructed route's NICID rather + // than the packet's NICID. The packet's NICID corresponds to the NIC on + // which it arrived, which isn't necessarily the same as the NIC on which it + // will be transmitted. On the other hand, the route's NIC *is* guaranteed + // to be the NIC on which the packet will be transmitted. + netEP, ok := p.mu.eps[route.NICID()] p.mu.Unlock() if !ok { return &tcpip.ErrNotConnected{} @@ -1226,6 +1247,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip icmpHdr.SetType(header.ICMPv6DstUnreachable) icmpHdr.SetCode(header.ICMPv6NetworkUnreachable) counter = sent.dstUnreachable + case *icmpReasonHostUnreachable: + icmpHdr.SetType(header.ICMPv6DstUnreachable) + icmpHdr.SetCode(header.ICMPv6AddressUnreachable) + counter = sent.dstUnreachable case *icmpReasonPacketTooBig: icmpHdr.SetType(header.ICMPv6PacketTooBig) icmpHdr.SetCode(header.ICMPv6UnusedCode) diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 040cd4bc8..c2e9544c1 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -16,17 +16,16 @@ package ipv6 import ( "bytes" - "context" "net" "reflect" "strings" "testing" - "time" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" @@ -46,16 +45,12 @@ const ( defaultChannelSize = 1 defaultMTU = 65536 - // Extra time to use when waiting for an async event to occur. - defaultAsyncPositiveEventTimeout = 30 * time.Second - arbitraryHopLimit = 42 ) var ( lladdr0 = header.LinkLocalAddr(linkAddr0) lladdr1 = header.LinkLocalAddr(linkAddr1) - lladdr2 = header.LinkLocalAddr(linkAddr2) ) type stubLinkEndpoint struct { @@ -371,6 +366,8 @@ type testContext struct { linkEP0 *channel.Endpoint linkEP1 *channel.Endpoint + + clock *faketime.ManualClock } type endpointWithResolutionCapability struct { @@ -382,15 +379,19 @@ func (e endpointWithResolutionCapability) Capabilities() stack.LinkEndpointCapab } func newTestContext(t *testing.T) *testContext { + clock := faketime.NewManualClock() c := &testContext{ s0: stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, + Clock: clock, }), s1: stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, + Clock: clock, }), + clock: clock, } c.linkEP0 = channel.New(defaultChannelSize, defaultMTU, linkAddr0) @@ -457,10 +458,14 @@ type routeArgs struct { remoteLinkAddr tcpip.LinkAddress } -func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.ICMPv6)) { +func routeICMPv6Packet(t *testing.T, clock *faketime.ManualClock, args routeArgs, fn func(*testing.T, header.ICMPv6)) { t.Helper() - pi, _ := args.src.ReadContext(context.Background()) + clock.RunImmediatelyScheduledJobs() + pi, ok := args.src.Read() + if !ok { + t.Fatal("packet didn't arrive") + } { pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -533,7 +538,7 @@ func TestLinkResolution(t *testing.T) { {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit, remoteLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.SolicitedNodeAddr(lladdr1))}, {src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6NeighborAdvert}, } { - routeICMPv6Packet(t, args, func(t *testing.T, icmpv6 header.ICMPv6) { + routeICMPv6Packet(t, c.clock, args, func(t *testing.T, icmpv6 header.ICMPv6) { if got, want := tcpip.Address(icmpv6[8:][:16]), lladdr1; got != want { t.Errorf("%d: got target = %s, want = %s", icmpv6.Type(), got, want) } @@ -544,7 +549,7 @@ func TestLinkResolution(t *testing.T) { {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6EchoRequest}, {src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6EchoReply}, } { - routeICMPv6Packet(t, args, nil) + routeICMPv6Packet(t, c.clock, args, nil) } } @@ -1309,7 +1314,7 @@ func TestPacketQueing(t *testing.T) { Length: header.UDPMinimumSize, }) sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, header.UDPMinimumSize) - sum = header.Checksum(header.UDP([]byte{}), sum) + sum = header.Checksum(nil, sum) u.SetChecksum(^u.CalculateChecksum(sum)) payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) @@ -1325,7 +1330,7 @@ func TestPacketQueing(t *testing.T) { })) }, checkResp: func(t *testing.T, e *channel.Endpoint) { - p, ok := e.ReadContext(context.Background()) + p, ok := e.Read() if !ok { t.Fatalf("timed out waiting for packet") } @@ -1371,7 +1376,7 @@ func TestPacketQueing(t *testing.T) { })) }, checkResp: func(t *testing.T, e *channel.Endpoint) { - p, ok := e.ReadContext(context.Background()) + p, ok := e.Read() if !ok { t.Fatalf("timed out waiting for packet") } @@ -1396,9 +1401,11 @@ func TestPacketQueing(t *testing.T) { e := channel.New(1, header.IPv6MinimumMTU, host1NICLinkAddr) e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + Clock: clock, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -1421,7 +1428,8 @@ func TestPacketQueing(t *testing.T) { // Wait for a neighbor solicitation since link address resolution should // be performed. { - p, ok := e.ReadContext(context.Background()) + clock.RunImmediatelyScheduledJobs() + p, ok := e.Read() if !ok { t.Fatalf("timed out waiting for packet") } @@ -1475,6 +1483,7 @@ func TestPacketQueing(t *testing.T) { } // Expect the response now that the link address has resolved. + clock.RunImmediatelyScheduledJobs() test.checkResp(t, e) // Since link resolution was already performed, it shouldn't be performed diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index f0e06f86b..8c8fafcda 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -63,6 +63,11 @@ const ( buckets = 2048 ) +const ( + forwardingDisabled = 0 + forwardingEnabled = 1 +) + // policyTable is the default policy table defined in RFC 6724 section 2.1. // // A more human-readable version: @@ -168,6 +173,7 @@ func getLabel(addr tcpip.Address) uint8 { var _ stack.DuplicateAddressDetector = (*endpoint)(nil) var _ stack.LinkAddressResolver = (*endpoint)(nil) var _ stack.LinkResolvableNetworkEndpoint = (*endpoint)(nil) +var _ stack.ForwardingNetworkEndpoint = (*endpoint)(nil) var _ stack.GroupAddressableEndpoint = (*endpoint)(nil) var _ stack.AddressableEndpoint = (*endpoint)(nil) var _ stack.NetworkEndpoint = (*endpoint)(nil) @@ -178,7 +184,6 @@ type endpoint struct { nic stack.NetworkInterface dispatcher stack.TransportDispatcher protocol *protocol - stack *stack.Stack stats sharedStats // enabled is set to 1 when the endpoint is enabled and 0 when it is @@ -187,6 +192,12 @@ type endpoint struct { // Must be accessed using atomic operations. enabled uint32 + // forwarding is set to forwardingEnabled when the endpoint has forwarding + // enabled and forwardingDisabled when it is disabled. + // + // Must be accessed using atomic operations. + forwarding uint32 + mu struct { sync.RWMutex @@ -219,11 +230,11 @@ type endpoint struct { // If the NIC was created with a name, it is passed to NICNameFromID. // // NICNameFromID SHOULD return unique NIC names so unique opaque IIDs are -// generated for the same prefix on differnt NICs. +// generated for the same prefix on different NICs. type NICNameFromID func(tcpip.NICID, string) string // OpaqueInterfaceIdentifierOptions holds the options related to the generation -// of opaque interface indentifiers (IIDs) as defined by RFC 7217. +// of opaque interface identifiers (IIDs) as defined by RFC 7217. type OpaqueInterfaceIdentifierOptions struct { // NICNameFromID is a function that returns a stable name for a specified NIC, // even if the NIC ID changes over time. @@ -270,6 +281,16 @@ func (*endpoint) DuplicateAddressProtocol() tcpip.NetworkProtocolNumber { // HandleLinkResolutionFailure implements stack.LinkResolvableNetworkEndpoint. func (e *endpoint) HandleLinkResolutionFailure(pkt *stack.PacketBuffer) { + // If we are operating as a router, we should return an ICMP error to the + // original packet's sender. + if pkt.NetworkPacketInfo.IsForwardedPacket { + // TODO(gvisor.dev/issue/6005): Propagate asynchronously generated ICMP + // errors to local endpoints. + e.protocol.returnError(&icmpReasonHostUnreachable{}, pkt) + e.stats.ip.Forwarding.Errors.Increment() + e.stats.ip.Forwarding.HostUnreachable.Increment() + return + } // handleControl expects the entire offending packet to be in the packet // buffer's data field. pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -405,20 +426,38 @@ func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address, holderLinkAddr t } } -// transitionForwarding transitions the endpoint's forwarding status to -// forwarding. +// Forwarding implements stack.ForwardingNetworkEndpoint. +func (e *endpoint) Forwarding() bool { + return atomic.LoadUint32(&e.forwarding) == forwardingEnabled +} + +// setForwarding sets the forwarding status for the endpoint. // -// Must only be called when the forwarding status changes. -func (e *endpoint) transitionForwarding(forwarding bool) { +// Returns true if the forwarding status was updated. +func (e *endpoint) setForwarding(v bool) bool { + forwarding := uint32(forwardingDisabled) + if v { + forwarding = forwardingEnabled + } + + return atomic.SwapUint32(&e.forwarding, forwarding) != forwarding +} + +// SetForwarding implements stack.ForwardingNetworkEndpoint. +func (e *endpoint) SetForwarding(forwarding bool) { + e.mu.Lock() + defer e.mu.Unlock() + + if !e.setForwarding(forwarding) { + return + } + allRoutersGroups := [...]tcpip.Address{ header.IPv6AllRoutersInterfaceLocalMulticastAddress, header.IPv6AllRoutersLinkLocalMulticastAddress, header.IPv6AllRoutersSiteLocalMulticastAddress, } - e.mu.Lock() - defer e.mu.Unlock() - if forwarding { // As per RFC 4291 section 2.8: // @@ -506,7 +545,7 @@ func (e *endpoint) Enable() tcpip.Error { // Perform DAD on the all the unicast IPv6 endpoints that are in the permanent // state. // - // Addresses may have aleady completed DAD but in the time since the endpoint + // Addresses may have already completed DAD but in the time since the endpoint // was last enabled, other devices may have acquired the same addresses. var err tcpip.Error e.mu.addressableEndpointState.ForEachEndpoint(func(addressEndpoint stack.AddressEndpoint) bool { @@ -611,8 +650,8 @@ func (e *endpoint) DefaultTTL() uint8 { return e.protocol.DefaultTTL() } -// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus -// the network layer max header length. +// MTU implements stack.NetworkEndpoint. It returns the link-layer MTU minus the +// network layer max header length. func (e *endpoint) MTU() uint32 { networkMTU, err := calculateNetworkMTU(e.nic.MTU(), header.IPv6MinimumSize) if err != nil { @@ -636,8 +675,7 @@ func addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params if length > math.MaxUint16 { return &tcpip.ErrMessageTooLong{} } - ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen)) - ip.Encode(&header.IPv6Fields{ + header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen)).Encode(&header.IPv6Fields{ PayloadLength: uint16(length), TransportProtocol: params.Protocol, HopLimit: params.TTL, @@ -717,9 +755,9 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams, // based on destination address and do not send the packet to link // layer. // - // TODO(gvisor.dev/issue/170): We should do this for every - // packet, rather than only NATted packets, but removing this check - // short circuits broadcasts before they are sent out to other hosts. + // We should do this for every packet, rather than only NATted packets, but + // removing this check short circuits broadcasts before they are sent out to + // other hosts. if pkt.NatDone { netHeader := header.IPv6(pkt.NetworkHeader().View()) if ep := e.protocol.findEndpointWithAddress(netHeader.DestinationAddress()); ep != nil { @@ -788,7 +826,7 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, protocol return nil } -// WritePackets implements stack.NetworkEndpoint.WritePackets. +// WritePackets implements stack.NetworkEndpoint. func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) { if r.Loop()&stack.PacketLoop != 0 { panic("not implemented") @@ -879,20 +917,20 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu if !ok { return &tcpip.ErrMalformedHeader{} } - ip := header.IPv6(h) + ipH := header.IPv6(h) // Always set the payload length. pktSize := pkt.Data().Size() - ip.SetPayloadLength(uint16(pktSize - header.IPv6MinimumSize)) + ipH.SetPayloadLength(uint16(pktSize - header.IPv6MinimumSize)) // Set the source address when zero. - if ip.SourceAddress() == header.IPv6Any { - ip.SetSourceAddress(r.LocalAddress()) + if ipH.SourceAddress() == header.IPv6Any { + ipH.SetSourceAddress(r.LocalAddress()) } // Set the destination. If the packet already included a destination, it will // be part of the route anyways. - ip.SetDestinationAddress(r.RemoteAddress()) + ipH.SetDestinationAddress(r.RemoteAddress()) // Populate the packet buffer's network header and don't allow an invalid // packet to be sent. @@ -953,7 +991,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { return nil } - ep.handleValidatedPacket(h, pkt) + // The packet originally arrived on e so provide its NIC as the input NIC. + ep.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */) return nil } @@ -1066,7 +1105,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { } } - e.handleValidatedPacket(h, pkt) + e.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */) } // handleLocalPacket is like HandlePacket except it does not perform the @@ -1085,10 +1124,10 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum return } - e.handleValidatedPacket(h, pkt) + e.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */) } -func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) { +func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer, inNICName string) { pkt.NICID = e.nic.ID() stats := e.stats.ip stats.ValidPacketsReceived.Increment() @@ -1109,7 +1148,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) if addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint); addressEndpoint != nil { addressEndpoint.DecRef() } else if !e.IsInGroup(dstAddr) { - if !e.protocol.Forwarding() { + if !e.Forwarding() { stats.InvalidDestinationAddressesReceived.Increment() return } @@ -1137,8 +1176,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) // iptables filtering. All packets that reach here are intended for // this machine and need not be forwarded. - inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNICName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. stats.IPTablesInputDropped.Increment() return @@ -1580,7 +1618,7 @@ func (e *endpoint) Close() { e.protocol.forgetEndpoint(e.nic.ID()) } -// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber. +// NetworkProtocolNumber implements stack.NetworkEndpoint. func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return e.protocol.Number() } @@ -1589,8 +1627,8 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) { // TODO(b/169350103): add checks here after making sure we no longer receive // an empty address. - e.mu.RLock() - defer e.mu.RUnlock() + e.mu.Lock() + defer e.mu.Unlock() return e.addAndAcquirePermanentAddressLocked(addr, peb, configType, deprecated) } @@ -1631,8 +1669,8 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre // RemovePermanentAddress implements stack.AddressableEndpoint. func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) tcpip.Error { - e.mu.RLock() - defer e.mu.RUnlock() + e.mu.Lock() + defer e.mu.Unlock() addressEndpoint := e.getAddressRLocked(addr) if addressEndpoint == nil || !addressEndpoint.GetKind().IsPermanent() { @@ -1932,7 +1970,6 @@ func (e *endpoint) Stats() stack.NetworkEndpointStats { return &e.stats.localStats } -var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) var _ stack.NetworkProtocol = (*protocol)(nil) var _ fragmentation.TimeoutHandler = (*protocol)(nil) @@ -1957,12 +1994,6 @@ type protocol struct { // Must be accessed using atomic operations. defaultTTL uint32 - // forwarding is set to 1 when the protocol has forwarding enabled and 0 - // when it is disabled. - // - // Must be accessed using atomic operations. - forwarding uint32 - fragmentation *fragmentation.Fragmentation } @@ -1981,7 +2012,7 @@ func (p *protocol) DefaultPrefixLen() int { return header.IPv6AddressSize * 8 } -// ParseAddresses implements NetworkProtocol.ParseAddresses. +// ParseAddresses implements stack.NetworkProtocol. func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { h := header.IPv6(v) return h.SourceAddress(), h.DestinationAddress() @@ -2058,7 +2089,7 @@ func (p *protocol) forgetEndpoint(nicID tcpip.NICID) { delete(p.mu.eps, nicID) } -// SetOption implements NetworkProtocol.SetOption. +// SetOption implements stack.NetworkProtocol. func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.DefaultTTLOption: @@ -2069,7 +2100,7 @@ func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) tcpip.E } } -// Option implements NetworkProtocol.Option. +// Option implements stack.NetworkProtocol. func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) tcpip.Error { switch v := option.(type) { case *tcpip.DefaultTTLOption: @@ -2090,10 +2121,10 @@ func (p *protocol) DefaultTTL() uint8 { return uint8(atomic.LoadUint32(&p.defaultTTL)) } -// Close implements stack.TransportProtocol.Close. +// Close implements stack.TransportProtocol. func (*protocol) Close() {} -// Wait implements stack.TransportProtocol.Wait. +// Wait implements stack.TransportProtocol. func (*protocol) Wait() {} // parseAndValidate parses the packet (including its transport layer header) and @@ -2127,7 +2158,7 @@ func (p *protocol) parseAndValidate(pkt *stack.PacketBuffer) (header.IPv6, bool) return h, true } -// Parse implements stack.NetworkProtocol.Parse. +// Parse implements stack.NetworkProtocol. func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { proto, _, fragOffset, fragMore, ok := parse.IPv6(pkt) if !ok { @@ -2137,35 +2168,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu return proto, !fragMore && fragOffset == 0, true } -// Forwarding implements stack.ForwardingNetworkProtocol. -func (p *protocol) Forwarding() bool { - return uint8(atomic.LoadUint32(&p.forwarding)) == 1 -} - -// setForwarding sets the forwarding status for the protocol. -// -// Returns true if the forwarding status was updated. -func (p *protocol) setForwarding(v bool) bool { - if v { - return atomic.CompareAndSwapUint32(&p.forwarding, 0 /* old */, 1 /* new */) - } - return atomic.CompareAndSwapUint32(&p.forwarding, 1 /* old */, 0 /* new */) -} - -// SetForwarding implements stack.ForwardingNetworkProtocol. -func (p *protocol) SetForwarding(v bool) { - p.mu.Lock() - defer p.mu.Unlock() - - if !p.setForwarding(v) { - return - } - - for _, ep := range p.mu.eps { - ep.transitionForwarding(v) - } -} - // calculateNetworkMTU calculates the network-layer payload MTU based on the // link-layer payload MTU and the length of every IPv6 header. // Note that this is different than the Payload Length field of the IPv6 header, @@ -2186,7 +2188,7 @@ func calculateNetworkMTU(linkMTU, networkHeadersLen uint32) (uint32, tcpip.Error return 0, &tcpip.ErrMalformedHeader{} } - networkMTU := linkMTU - uint32(networkHeadersLen) + networkMTU := linkMTU - networkHeadersLen if networkMTU > maxPayloadSize { networkMTU = maxPayloadSize } @@ -2205,7 +2207,7 @@ type Options struct { // Note, setting this to true does not mean that a link-local address is // assigned right away, or at all. If Duplicate Address Detection is enabled, // an address is only assigned if it successfully resolves. If it fails, no - // further attempts are made to auto-generate a link-local adddress. + // further attempts are made to auto-generate a link-local address. // // The generated link-local address follows RFC 4291 Appendix A guidelines. AutoGenLinkLocal bool @@ -2221,7 +2223,7 @@ type Options struct { // TempIIDSeed is used to seed the initial temporary interface identifier // history value used to generate IIDs for temporary SLAAC addresses. // - // Temporary SLAAC adresses are short-lived addresses which are unpredictable + // Temporary SLAAC addresses are short-lived addresses which are unpredictable // and random from the perspective of other nodes on the network. It is // recommended that the seed be a random byte buffer of at least // header.IIDSize bytes to make sure that temporary SLAAC addresses are diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 30325160a..d2a23fd4f 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -129,7 +129,7 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, header.UDPMinimumSize) // UDP checksum - sum = header.Checksum(header.UDP([]byte{}), sum) + sum = header.Checksum(nil, sum) u.SetChecksum(^u.CalculateChecksum(sum)) payloadLength := hdr.UsedLength() @@ -402,7 +402,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { }{ { name: "None", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, nextHdr }, + extHdr: func(nextHdr uint8) ([]byte, uint8) { return nil, nextHdr }, shouldAccept: true, expectICMP: false, }, @@ -612,8 +612,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { { name: "No next header", extHdr: func(nextHdr uint8) ([]byte, uint8) { - return []byte{}, - noNextHdrID + return nil, noNextHdrID }, shouldAccept: false, expectICMP: false, @@ -1160,7 +1159,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2), []buffer.View{ // Fragment extension header. - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}, ipv6Payload1Addr1ToAddr2, }, @@ -1180,7 +1179,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { fragmentExtHdrLen+len(ipv6Payload3Addr1ToAddr2), []buffer.View{ // Fragment extension header. - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}, ipv6Payload3Addr1ToAddr2, }, @@ -1202,7 +1201,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[:64], }, @@ -1218,7 +1217,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[64:], }, @@ -1240,7 +1239,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[64:], }, @@ -1256,7 +1255,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[:64], }, @@ -1278,7 +1277,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[:64], }, @@ -1296,7 +1295,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment offset = 8, More = false, ID = 1 // NextHeader value is different than the one in the first fragment, so // this NextHeader should be ignored. - buffer.View([]byte{uint8(header.IPv6NoNextHeaderIdentifier), 0, 0, 64, 0, 0, 0, 1}), + []byte{uint8(header.IPv6NoNextHeaderIdentifier), 0, 0, 64, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[64:], }, @@ -1318,7 +1317,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, ipv6Payload3Addr1ToAddr2[:64], }, @@ -1334,7 +1333,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}, ipv6Payload3Addr1ToAddr2[64:], }, @@ -1356,7 +1355,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, ipv6Payload3Addr1ToAddr2[:63], }, @@ -1372,7 +1371,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}, ipv6Payload3Addr1ToAddr2[63:], }, @@ -1394,7 +1393,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[:64], }, @@ -1410,7 +1409,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 8, More = false, ID = 2 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 2}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 2}, ipv6Payload1Addr1ToAddr2[64:], }, @@ -1432,7 +1431,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, ipv6Payload4Addr1ToAddr2[:udpMaximumSizeMinus15], }, @@ -1448,10 +1447,10 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = udpMaximumSizeMinus15/8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, + []byte{uint8(header.UDPProtocolNumber), 0, udpMaximumSizeMinus15 >> 8, udpMaximumSizeMinus15 & 0xff, - 0, 0, 0, 1}), + 0, 0, 0, 1}, ipv6Payload4Addr1ToAddr2[udpMaximumSizeMinus15:], }, @@ -1473,7 +1472,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, ipv6Payload4Addr1ToAddr2[:udpMaximumSizeMinus15], }, @@ -1489,10 +1488,10 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = udpMaximumSizeMinus15/8, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, + []byte{uint8(header.UDPProtocolNumber), 0, udpMaximumSizeMinus15 >> 8, (udpMaximumSizeMinus15 & 0xff) + 1, - 0, 0, 0, 1}), + 0, 0, 0, 1}, ipv6Payload4Addr1ToAddr2[udpMaximumSizeMinus15:], }, @@ -1514,12 +1513,12 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Routing extension header. // // Segments left = 0. - buffer.View([]byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5}), + []byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5}, // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[:64], }, @@ -1535,12 +1534,12 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Routing extension header. // // Segments left = 0. - buffer.View([]byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5}), + []byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5}, // Fragment extension header. // // Fragment offset = 8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[64:], }, @@ -1562,12 +1561,12 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Routing extension header. // // Segments left = 1. - buffer.View([]byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5}), + []byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5}, // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[:64], }, @@ -1583,12 +1582,12 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Routing extension header. // // Segments left = 1. - buffer.View([]byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5}), + []byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5}, // Fragment extension header. // // Fragment offset = 9, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 72, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 72, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[64:], }, @@ -1610,12 +1609,12 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}), + []byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}, // Routing extension header. // // Segments left = 0. - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 1, 0, 2, 3, 4, 5}), + []byte{uint8(header.UDPProtocolNumber), 0, 1, 0, 2, 3, 4, 5}, ipv6Payload1Addr1ToAddr2[:64], }, @@ -1631,7 +1630,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 9, More = false, ID = 1 - buffer.View([]byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}), + []byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[64:], }, @@ -1653,12 +1652,12 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}), + []byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}, // Routing extension header. // // Segments left = 1. - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 1, 1, 2, 3, 4, 5}), + []byte{uint8(header.UDPProtocolNumber), 0, 1, 1, 2, 3, 4, 5}, ipv6Payload1Addr1ToAddr2[:64], }, @@ -1674,7 +1673,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 9, More = false, ID = 1 - buffer.View([]byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}), + []byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[64:], }, @@ -1699,12 +1698,12 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}), + []byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}, // Routing extension header (part 1) // // Segments left = 0. - buffer.View([]byte{uint8(header.UDPProtocolNumber), 1, 1, 0, 2, 3, 4, 5}), + []byte{uint8(header.UDPProtocolNumber), 1, 1, 0, 2, 3, 4, 5}, }, ), }, @@ -1721,10 +1720,10 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 1, More = false, ID = 1 - buffer.View([]byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1}), + []byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1}, // Routing extension header (part 2) - buffer.View([]byte{6, 7, 8, 9, 10, 11, 12, 13}), + []byte{6, 7, 8, 9, 10, 11, 12, 13}, ipv6Payload1Addr1ToAddr2, }, @@ -1749,12 +1748,12 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}), + []byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}, // Routing extension header (part 1) // // Segments left = 1. - buffer.View([]byte{uint8(header.UDPProtocolNumber), 1, 1, 1, 2, 3, 4, 5}), + []byte{uint8(header.UDPProtocolNumber), 1, 1, 1, 2, 3, 4, 5}, }, ), }, @@ -1771,10 +1770,10 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 1, More = false, ID = 1 - buffer.View([]byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1}), + []byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1}, // Routing extension header (part 2) - buffer.View([]byte{6, 7, 8, 9, 10, 11, 12, 13}), + []byte{6, 7, 8, 9, 10, 11, 12, 13}, ipv6Payload1Addr1ToAddr2, }, @@ -1798,7 +1797,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[:64], }, @@ -1816,7 +1815,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 1}, ipv6Payload2Addr1ToAddr2, }, @@ -1832,7 +1831,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[64:], }, @@ -1854,7 +1853,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[:64], }, @@ -1870,7 +1869,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 2 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 2}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 2}, ipv6Payload2Addr1ToAddr2[:32], }, @@ -1886,7 +1885,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[64:], }, @@ -1902,7 +1901,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 4, More = false, ID = 2 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 2}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 2}, ipv6Payload2Addr1ToAddr2[32:], }, @@ -1924,7 +1923,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[:64], }, @@ -1940,7 +1939,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}, ipv6Payload1Addr3ToAddr2[:32], }, @@ -1956,7 +1955,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 8, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}, ipv6Payload1Addr1ToAddr2[64:], }, @@ -1972,7 +1971,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Fragment extension header. // // Fragment offset = 4, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 1}), + []byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 1}, ipv6Payload1Addr3ToAddr2[32:], }, @@ -2208,7 +2207,7 @@ func TestInvalidIPv6Fragments(t *testing.T) { checker.ICMPv6Type(test.expectICMPType), checker.ICMPv6Code(test.expectICMPCode), checker.ICMPv6TypeSpecific(test.expectICMPTypeSpecific), - checker.ICMPv6Payload([]byte(expectICMPPayload)), + checker.ICMPv6Payload(expectICMPPayload), ), ) }) @@ -2459,7 +2458,7 @@ func TestFragmentReassemblyTimeout(t *testing.T) { checker.ICMPv6( checker.ICMPv6Type(header.ICMPv6TimeExceeded), checker.ICMPv6Code(header.ICMPv6ReassemblyTimeout), - checker.ICMPv6Payload([]byte(firstFragmentSent)), + checker.ICMPv6Payload(firstFragmentSent), ), ) }) @@ -2795,11 +2794,7 @@ var fragmentationTests = []struct { } func TestFragmentationWritePacket(t *testing.T) { - const ( - ttl = 42 - tos = stack.DefaultTOS - transportProto = tcp.ProtocolNumber - ) + const ttl = 42 for _, ft := range fragmentationTests { t.Run(ft.description, func(t *testing.T) { @@ -3004,17 +2999,17 @@ func TestFragmentationErrors(t *testing.T) { func TestForwarding(t *testing.T) { const ( - nicID1 = 1 - nicID2 = 2 + incomingNICID = 1 + outgoingNICID = 2 randomSequence = 123 randomIdent = 42 ) - ipv6Addr1 := tcpip.AddressWithPrefix{ + incomingIPv6Addr := tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("10::1").To16()), PrefixLen: 64, } - ipv6Addr2 := tcpip.AddressWithPrefix{ + outgoingIPv6Addr := tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("11::1").To16()), PrefixLen: 64, } @@ -3022,6 +3017,7 @@ func TestForwarding(t *testing.T) { Address: tcpip.Address(net.ParseIP("ff00::").To16()), PrefixLen: 64, } + remoteIPv6Addr1 := tcpip.Address(net.ParseIP("10::2").To16()) remoteIPv6Addr2 := tcpip.Address(net.ParseIP("11::2").To16()) unreachableIPv6Addr := tcpip.Address(net.ParseIP("12::2").To16()) @@ -3296,36 +3292,36 @@ func TestForwarding(t *testing.T) { TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, }) // We expect at most a single packet in response to our ICMP Echo Request. - e1 := channel.New(1, header.IPv6MinimumMTU, "") - if err := s.CreateNIC(nicID1, e1); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + incomingEndpoint := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(incomingNICID, incomingEndpoint); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err) } - ipv6ProtoAddr1 := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: ipv6Addr1} - if err := s.AddProtocolAddress(nicID1, ipv6ProtoAddr1); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID1, ipv6ProtoAddr1, err) + incomingIPv6ProtoAddr := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: incomingIPv6Addr} + if err := s.AddProtocolAddress(incomingNICID, incomingIPv6ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingIPv6ProtoAddr, err) } - e2 := channel.New(1, header.IPv6MinimumMTU, "") - if err := s.CreateNIC(nicID2, e2); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) + outgoingEndpoint := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(outgoingNICID, outgoingEndpoint); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err) } - ipv6ProtoAddr2 := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: ipv6Addr2} - if err := s.AddProtocolAddress(nicID2, ipv6ProtoAddr2); err != nil { - t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID2, ipv6ProtoAddr2, err) + outgoingIPv6ProtoAddr := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: outgoingIPv6Addr} + if err := s.AddProtocolAddress(outgoingNICID, outgoingIPv6ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingIPv6ProtoAddr, err) } s.SetRouteTable([]tcpip.Route{ { - Destination: ipv6Addr1.Subnet(), - NIC: nicID1, + Destination: incomingIPv6Addr.Subnet(), + NIC: incomingNICID, }, { - Destination: ipv6Addr2.Subnet(), - NIC: nicID2, + Destination: outgoingIPv6Addr.Subnet(), + NIC: outgoingNICID, }, { Destination: multicastIPv6Addr.Subnet(), - NIC: nicID2, + NIC: outgoingNICID, }, }) @@ -3334,7 +3330,7 @@ func TestForwarding(t *testing.T) { } transportProtocol := header.ICMPv6ProtocolNumber - extHdrBytes := []byte{} + var extHdrBytes []byte extHdrChecker := checker.IPv6ExtHdr() if test.extHdr != nil { nextHdrID := hopByHopExtHdrID @@ -3348,15 +3344,15 @@ func TestForwarding(t *testing.T) { totalLength := ipHeaderLength + icmpHeaderLength + test.payloadLength + extHdrLen hdr := buffer.NewPrependable(totalLength) hdr.Prepend(test.payloadLength) - icmp := header.ICMPv6(hdr.Prepend(icmpHeaderLength)) - - icmp.SetIdent(randomIdent) - icmp.SetSequence(randomSequence) - icmp.SetType(header.ICMPv6EchoRequest) - icmp.SetCode(header.ICMPv6UnusedCode) - icmp.SetChecksum(0) - icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: icmp, + icmpH := header.ICMPv6(hdr.Prepend(icmpHeaderLength)) + + icmpH.SetIdent(randomIdent) + icmpH.SetSequence(randomSequence) + icmpH.SetType(header.ICMPv6EchoRequest) + icmpH.SetCode(header.ICMPv6UnusedCode) + icmpH.SetChecksum(0) + icmpH.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmpH, Src: test.sourceAddr, Dst: test.destAddr, })) @@ -3372,9 +3368,10 @@ func TestForwarding(t *testing.T) { requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), }) - e1.InjectInbound(ProtocolNumber, requestPkt) + incomingEndpoint.InjectInbound(ProtocolNumber, requestPkt) + + reply, ok := incomingEndpoint.Read() - reply, ok := e1.Read() if test.expectErrorICMP { if !ok { t.Fatalf("expected ICMP packet type %d through incoming NIC", test.icmpType) @@ -3393,31 +3390,31 @@ func TestForwarding(t *testing.T) { return len(hdr.View()) } - checker.IPv6(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())), - checker.SrcAddr(ipv6Addr1.Address), + checker.IPv6(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), + checker.SrcAddr(incomingIPv6Addr.Address), checker.DstAddr(test.sourceAddr), checker.TTL(DefaultTTL), checker.ICMPv6( checker.ICMPv6Type(test.icmpType), checker.ICMPv6Code(test.icmpCode), - checker.ICMPv6Payload([]byte(hdr.View()[0:expectedICMPPayloadLength()])), + checker.ICMPv6Payload(hdr.View()[:expectedICMPPayloadLength()]), ), ) - if n := e2.Drain(); n != 0 { + if n := outgoingEndpoint.Drain(); n != 0 { t.Fatalf("got e2.Drain() = %d, want = 0", n) } } else if ok { t.Fatalf("expected no ICMP packet through incoming NIC, instead found: %#v", reply) } - reply, ok = e2.Read() + reply, ok = outgoingEndpoint.Read() if test.expectPacketForwarded { if !ok { t.Fatal("expected ICMP Echo Request packet through outgoing NIC") } - checker.IPv6WithExtHdr(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())), + checker.IPv6WithExtHdr(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), checker.SrcAddr(test.sourceAddr), checker.DstAddr(test.destAddr), checker.TTL(test.TTL-1), @@ -3429,7 +3426,7 @@ func TestForwarding(t *testing.T) { ), ) - if n := e1.Drain(); n != 0 { + if n := incomingEndpoint.Drain(); n != 0 { t.Fatalf("got e1.Drain() = %d, want = 0", n) } } else if ok { diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go index 71d1c3e28..bc9cf6999 100644 --- a/pkg/tcpip/network/ipv6/mld_test.go +++ b/pkg/tcpip/network/ipv6/mld_test.go @@ -16,6 +16,7 @@ package ipv6_test import ( "bytes" + "math/rand" "testing" "time" @@ -138,7 +139,8 @@ func TestSendQueuedMLDReports(t *testing.T) { var secureRNG bytes.Reader secureRNG.Reset(secureRNGBytes[:]) s := stack.New(stack.Options{ - SecureRNG: &secureRNG, + SecureRNG: &secureRNG, + RandSource: rand.NewSource(time.Now().UnixNano()), NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ DADConfigs: stack.DADConfigurations{ DupAddrDetectTransmits: test.dadTransmits, diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index b29fed347..ee36ed254 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -16,7 +16,6 @@ package ipv6 import ( "fmt" - "math/rand" "time" "gvisor.dev/gvisor/pkg/sync" @@ -215,28 +214,23 @@ type NDPDispatcher interface { // is also not permitted to call into the stack. OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult) - // OnDefaultRouterDiscovered is called when a new default router is - // discovered. Implementations must return true if the newly discovered - // router should be remembered. + // OnOffLinkRouteUpdated is called when an off-link route is updated. // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. - OnDefaultRouterDiscovered(tcpip.NICID, tcpip.Address) bool + OnOffLinkRouteUpdated(tcpip.NICID, tcpip.Subnet, tcpip.Address) - // OnDefaultRouterInvalidated is called when a discovered default router that - // was remembered is invalidated. + // OnOffLinkRouteInvalidated is called when an off-link route is invalidated. // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. - OnDefaultRouterInvalidated(tcpip.NICID, tcpip.Address) + OnOffLinkRouteInvalidated(tcpip.NICID, tcpip.Subnet, tcpip.Address) // OnOnLinkPrefixDiscovered is called when a new on-link prefix is discovered. - // Implementations must return true if the newly discovered on-link prefix - // should be remembered. // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. - OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool + OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) // OnOnLinkPrefixInvalidated is called when a discovered on-link prefix that // was remembered is invalidated. @@ -246,13 +240,11 @@ type NDPDispatcher interface { OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) // OnAutoGenAddress is called when a new prefix with its autonomous address- - // configuration flag set is received and SLAAC was performed. Implementations - // may prevent the stack from assigning the address to the NIC by returning - // false. + // configuration flag set is received and SLAAC was performed. // // This function is not permitted to block indefinitely. It must not // call functions on the stack itself. - OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool + OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) // OnAutoGenAddressDeprecated is called when an auto-generated address (SLAAC) // is deprecated, but is still considered valid. Note, if an address is @@ -549,7 +541,7 @@ type tempSLAACAddrState struct { // Must not be nil. regenJob *tcpip.Job - createdAt time.Time + createdAt tcpip.MonotonicTime // The address's endpoint. // @@ -573,10 +565,10 @@ type slaacPrefixState struct { invalidationJob *tcpip.Job // Nonzero only when the address is not valid forever. - validUntil time.Time + validUntil tcpip.MonotonicTime // Nonzero only when the address is not preferred forever. - preferredUntil time.Time + preferredUntil tcpip.MonotonicTime // State associated with the stable address generated for the prefix. stableAddr struct { @@ -705,7 +697,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { // per-interface basis; it is a protocol-wide configuration, so we check the // protocol's forwarding flag to determine if the IPv6 endpoint is forwarding // packets. - if !ndp.configs.HandleRAs.enabled(ndp.ep.protocol.Forwarding()) { + if !ndp.configs.HandleRAs.enabled(ndp.ep.Forwarding()) { ndp.ep.stats.localStats.UnhandledRouterAdvertisements.Increment() return } @@ -832,7 +824,7 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { // Let the integrator know a discovered default router is invalidated. if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { - ndpDisp.OnDefaultRouterInvalidated(ndp.ep.nic.ID(), ip) + ndpDisp.OnOffLinkRouteInvalidated(ndp.ep.nic.ID(), header.IPv6EmptySubnet, ip) } } @@ -849,11 +841,7 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { } // Inform the integrator when we discovered a default router. - if !ndpDisp.OnDefaultRouterDiscovered(ndp.ep.nic.ID(), ip) { - // Informed by the integrator to not remember the router, do - // nothing further. - return - } + ndpDisp.OnOffLinkRouteUpdated(ndp.ep.nic.ID(), header.IPv6EmptySubnet, ip) state := defaultRouterState{ invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { @@ -879,11 +867,7 @@ func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) } // Inform the integrator when we discovered an on-link prefix. - if !ndpDisp.OnOnLinkPrefixDiscovered(ndp.ep.nic.ID(), prefix) { - // Informed by the integrator to not remember the prefix, do - // nothing further. - return - } + ndpDisp.OnOnLinkPrefixDiscovered(ndp.ep.nic.ID(), prefix) state := onLinkPrefixState{ invalidationJob: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { @@ -1051,7 +1035,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { maxGenerationAttempts: ndp.configs.AutoGenAddressConflictRetries + 1, } - now := time.Now() + now := ndp.ep.protocol.stack.Clock().NowMonotonic() // The time an address is preferred until is needed to properly generate the // address. @@ -1097,16 +1081,13 @@ func (ndp *ndpState) addAndAcquireSLAACAddr(addr tcpip.AddressWithPrefix, config return nil } - if !ndpDisp.OnAutoGenAddress(ndp.ep.nic.ID(), addr) { - // Informed by the integrator not to add the address. - return nil - } - addressEndpoint, err := ndp.ep.addAndAcquirePermanentAddressLocked(addr, stack.FirstPrimaryEndpoint, configType, deprecated) if err != nil { panic(fmt.Sprintf("ndp: error when adding SLAAC address %+v: %s", addr, err)) } + ndpDisp.OnAutoGenAddress(ndp.ep.nic.ID(), addr) + return addressEndpoint } @@ -1182,7 +1163,7 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt state.stableAddr.localGenerationFailures++ } - if addressEndpoint := ndp.addAndAcquireSLAACAddr(generatedAddr, stack.AddressConfigSlaac, time.Since(state.preferredUntil) >= 0 /* deprecated */); addressEndpoint != nil { + if addressEndpoint := ndp.addAndAcquireSLAACAddr(generatedAddr, stack.AddressConfigSlaac, ndp.ep.protocol.stack.Clock().NowMonotonic().Sub(state.preferredUntil) >= 0 /* deprecated */); addressEndpoint != nil { state.stableAddr.addressEndpoint = addressEndpoint state.generationAttempts++ return true @@ -1237,13 +1218,13 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla } stableAddr := prefixState.stableAddr.addressEndpoint.AddressWithPrefix().Address - now := time.Now() + now := ndp.ep.protocol.stack.Clock().NowMonotonic() // As per RFC 4941 section 3.3 step 4, the valid lifetime of a temporary // address is the lower of the valid lifetime of the stable address or the // maximum temporary address valid lifetime. vl := ndp.configs.MaxTempAddrValidLifetime - if prefixState.validUntil != (time.Time{}) { + if prefixState.validUntil != (tcpip.MonotonicTime{}) { if prefixVL := prefixState.validUntil.Sub(now); vl > prefixVL { vl = prefixVL } @@ -1259,7 +1240,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla // maximum temporary address preferred lifetime - the temporary address desync // factor. pl := ndp.configs.MaxTempAddrPreferredLifetime - ndp.temporaryAddressDesyncFactor - if prefixState.preferredUntil != (time.Time{}) { + if prefixState.preferredUntil != (tcpip.MonotonicTime{}) { if prefixPL := prefixState.preferredUntil.Sub(now); pl > prefixPL { // Respect the preferred lifetime of the prefix, as per RFC 4941 section // 3.3 step 4. @@ -1394,7 +1375,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat // deprecation job so it can be reset. prefixState.deprecationJob.Cancel() - now := time.Now() + now := ndp.ep.protocol.stack.Clock().NowMonotonic() // Schedule the deprecation job if prefix has a finite preferred lifetime. if pl < header.NDPInfiniteLifetime { @@ -1403,7 +1384,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat } prefixState.preferredUntil = now.Add(pl) } else { - prefixState.preferredUntil = time.Time{} + prefixState.preferredUntil = tcpip.MonotonicTime{} } // As per RFC 4862 section 5.5.3.e, update the valid lifetime for prefix: @@ -1421,17 +1402,17 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat // Handle the infinite valid lifetime separately as we do not schedule a // job in this case. prefixState.invalidationJob.Cancel() - prefixState.validUntil = time.Time{} + prefixState.validUntil = tcpip.MonotonicTime{} } else { var effectiveVl time.Duration var rl time.Duration // If the prefix was originally set to be valid forever, assume the // remaining time to be the maximum possible value. - if prefixState.validUntil == (time.Time{}) { + if prefixState.validUntil == (tcpip.MonotonicTime{}) { rl = header.NDPInfiniteLifetime } else { - rl = time.Until(prefixState.validUntil) + rl = prefixState.validUntil.Sub(now) } if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl { @@ -1463,7 +1444,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat // maximum temporary address valid lifetime. Note, the valid lifetime of a // temporary address is relative to the address's creation time. validUntil := tempAddrState.createdAt.Add(ndp.configs.MaxTempAddrValidLifetime) - if prefixState.validUntil != (time.Time{}) && validUntil.Sub(prefixState.validUntil) > 0 { + if prefixState.validUntil != (tcpip.MonotonicTime{}) && validUntil.Sub(prefixState.validUntil) > 0 { validUntil = prefixState.validUntil } @@ -1483,7 +1464,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat // desync factor. Note, the preferred lifetime of a temporary address is // relative to the address's creation time. preferredUntil := tempAddrState.createdAt.Add(ndp.configs.MaxTempAddrPreferredLifetime - ndp.temporaryAddressDesyncFactor) - if prefixState.preferredUntil != (time.Time{}) && preferredUntil.Sub(prefixState.preferredUntil) > 0 { + if prefixState.preferredUntil != (tcpip.MonotonicTime{}) && preferredUntil.Sub(prefixState.preferredUntil) > 0 { preferredUntil = prefixState.preferredUntil } @@ -1710,7 +1691,7 @@ func (ndp *ndpState) startSolicitingRouters() { return } - if !ndp.configs.HandleRAs.enabled(ndp.ep.protocol.Forwarding()) { + if !ndp.configs.HandleRAs.enabled(ndp.ep.Forwarding()) { return } @@ -1718,7 +1699,7 @@ func (ndp *ndpState) startSolicitingRouters() { // 4861 section 6.3.7. var delay time.Duration if ndp.configs.MaxRtrSolicitationDelay > 0 { - delay = time.Duration(rand.Int63n(int64(ndp.configs.MaxRtrSolicitationDelay))) + delay = time.Duration(ndp.ep.protocol.stack.Rand().Int63n(int64(ndp.configs.MaxRtrSolicitationDelay))) } // Protected by ndp.ep.mu. @@ -1754,7 +1735,7 @@ func (ndp *ndpState) startSolicitingRouters() { header.NDPSourceLinkLayerAddressOption(linkAddress), } } - payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + int(optsSerializer.Length()) + payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + optsSerializer.Length() icmpData := header.ICMPv6(buffer.NewView(payloadSize)) icmpData.SetType(header.ICMPv6RouterSolicit) rs := header.NDPRouterSolicit(icmpData.MessageBody()) @@ -1861,7 +1842,7 @@ func (ndp *ndpState) init(ep *endpoint, dadOptions ip.DADOptions) { header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.options.TempIIDSeed, ndp.ep.nic.ID()) if MaxDesyncFactor != 0 { - ndp.temporaryAddressDesyncFactor = time.Duration(rand.Int63n(int64(MaxDesyncFactor))) + ndp.temporaryAddressDesyncFactor = time.Duration(ep.protocol.stack.Rand().Int63n(int64(MaxDesyncFactor))) } } diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 570c6c00c..95d23f200 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -16,7 +16,7 @@ package ipv6 import ( "bytes" - "context" + "math/rand" "strings" "testing" "time" @@ -32,58 +32,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" ) -// setupStackAndEndpoint creates a stack with a single NIC with a link-local -// address llladdr and an IPv6 endpoint to a remote with link-local address -// rlladdr -func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack.Stack, stack.NetworkEndpoint) { - t.Helper() - - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, - }) - - if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - { - subnet, err := tcpip.NewSubnet(rlladdr, tcpip.AddressMask(strings.Repeat("\xff", len(rlladdr)))) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable( - []tcpip.Route{{ - Destination: subnet, - NIC: 1, - }}, - ) - } - - netProto := s.NetworkProtocolInstance(ProtocolNumber) - if netProto == nil { - t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) - } - - ep := netProto.NewEndpoint(&testInterface{}, &stubDispatcher{}) - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } - t.Cleanup(ep.Close) - - addressableEndpoint, ok := ep.(stack.AddressableEndpoint) - if !ok { - t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint") - } - addr := llladdr.WithPrefix() - if addressEP, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { - t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) - } else { - addressEP.DecRef() - } - - return s, ep -} - var _ NDPDispatcher = (*testNDPDispatcher)(nil) // testNDPDispatcher is an NDPDispatcher only allows default router discovery. @@ -94,24 +42,21 @@ type testNDPDispatcher struct { func (*testNDPDispatcher) OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult) { } -func (t *testNDPDispatcher) OnDefaultRouterDiscovered(_ tcpip.NICID, addr tcpip.Address) bool { +func (t *testNDPDispatcher) OnOffLinkRouteUpdated(_ tcpip.NICID, _ tcpip.Subnet, addr tcpip.Address) { t.addr = addr - return true } -func (t *testNDPDispatcher) OnDefaultRouterInvalidated(_ tcpip.NICID, addr tcpip.Address) { +func (t *testNDPDispatcher) OnOffLinkRouteInvalidated(_ tcpip.NICID, _ tcpip.Subnet, addr tcpip.Address) { t.addr = addr } -func (*testNDPDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool { - return false +func (*testNDPDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) { } func (*testNDPDispatcher) OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) { } -func (*testNDPDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool { - return false +func (*testNDPDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) { } func (*testNDPDispatcher) OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) { @@ -163,11 +108,6 @@ func TestStackNDPEndpointInvalidateDefaultRouter(t *testing.T) { } } -type linkResolutionResult struct { - linkAddr tcpip.LinkAddress - ok bool -} - // TestNeighborSolicitationWithSourceLinkLayerOption tests that receiving a // valid NDP NS message with the Source Link Layer Address option results in a // new entry in the link address cache for the sender of the message. @@ -456,8 +396,10 @@ func TestNeighborSolicitationResponse(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + Clock: clock, }) e := channel.New(1, 1280, nicLinkAddr) e.LinkEPCapabilities |= stack.CapabilityResolutionRequired @@ -527,7 +469,8 @@ func TestNeighborSolicitationResponse(t *testing.T) { } if test.performsLinkResolution { - p, got := e.ReadContext(context.Background()) + clock.RunImmediatelyScheduledJobs() + p, got := e.Read() if !got { t.Fatal("expected an NDP NS response") } @@ -582,7 +525,8 @@ func TestNeighborSolicitationResponse(t *testing.T) { })) } - p, got := e.ReadContext(context.Background()) + clock.RunImmediatelyScheduledJobs() + p, got := e.Read() if !got { t.Fatal("expected an NDP NA response") } @@ -732,15 +676,7 @@ func TestNeighborAdvertisementWithTargetLinkLayerOption(t *testing.T) { } func TestNDPValidation(t *testing.T) { - setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint) { - t.Helper() - - // Create a stack with the assigned link-local address lladdr0 - // and an endpoint to lladdr1. - s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1) - - return s, ep - } + const nicID = 1 handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint) { var extHdrs header.IPv6ExtHdrSerializer @@ -865,6 +801,11 @@ func TestNDPValidation(t *testing.T) { }, } + subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr0)))) + if err != nil { + t.Fatal(err) + } + for _, typ := range types { for _, isRouter := range []bool{false, true} { name := typ.name @@ -875,7 +816,10 @@ func TestNDPValidation(t *testing.T) { t.Run(name, func(t *testing.T) { for _, test := range subTests { t.Run(test.name, func(t *testing.T) { - s, ep := setup(t) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, + }) if isRouter { if err := s.SetForwardingDefaultAndAllNICs(ProtocolNumber, true); err != nil { @@ -883,17 +827,35 @@ func TestNDPValidation(t *testing.T) { } } + if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, lladdr0, err) + } + + ep, err := s.GetNetworkEndpoint(nicID, ProtocolNumber) + if err != nil { + t.Fatal("cannot find network endpoint instance for IPv6") + } + + s.SetRouteTable([]tcpip.Route{{ + Destination: subnet, + NIC: nicID, + }}) + stats := s.Stats().ICMP.V6.PacketsReceived invalid := stats.Invalid routerOnly := stats.RouterOnlyPacketsDroppedByHost typStat := typ.statCounter(stats) - icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) - copy(icmp[typ.size:], typ.extraData) - icmp.SetType(typ.typ) - icmp.SetCode(test.code) - icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ - Header: icmp[:typ.size], + icmpH := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) + copy(icmpH[typ.size:], typ.extraData) + icmpH.SetType(typ.typ) + icmpH.SetCode(test.code) + icmpH.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmpH[:typ.size], Src: lladdr0, Dst: lladdr1, PayloadCsum: header.Checksum(typ.extraData /* initial */, 0), @@ -907,19 +869,19 @@ func TestNDPValidation(t *testing.T) { // Invalid count should initially be 0. if got := invalid.Value(); got != 0 { - t.Errorf("got invalid = %d, want = 0", got) + t.Errorf("got invalid.Value() = %d, want = 0", got) } - // RouterOnlyPacketsReceivedByHost count should initially be 0. + // Should initially not have dropped any packets. if got := routerOnly.Value(); got != 0 { - t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got) + t.Errorf("got routerOnly.Value() = %d, want = 0", got) } if t.Failed() { t.FailNow() } - handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep) + handleIPv6Payload(buffer.View(icmpH), test.hopLimit, test.atomicFragment, ep) // Rx count of the NDP packet should have increased. if got := typStat.Value(); got != 1 { @@ -932,18 +894,18 @@ func TestNDPValidation(t *testing.T) { want = 1 } if got := invalid.Value(); got != want { - t.Errorf("got invalid = %d, want = %d", got, want) + t.Errorf("got invalid.Value() = %d, want = %d", got, want) } want = 0 if test.valid && !isRouter && typ.routerOnly { - // RouterOnlyPacketsReceivedByHost count should have increased. + // Router only packets are expected to be dropped when operating + // as a host. want = 1 } if got := routerOnly.Value(); got != want { - t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = %d", got, want) + t.Errorf("got routerOnly.Value() = %d, want = %d", got, want) } - }) } }) @@ -1279,8 +1241,9 @@ func TestCheckDuplicateAddress(t *testing.T) { var secureRNG bytes.Reader secureRNG.Reset(secureRNGBytes[:]) s := stack.New(stack.Options{ - SecureRNG: &secureRNG, - Clock: clock, + Clock: clock, + RandSource: rand.NewSource(time.Now().UnixNano()), + SecureRNG: &secureRNG, NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocolWithOptions(Options{ DADConfigs: dadConfigs, })}, @@ -1297,7 +1260,8 @@ func TestCheckDuplicateAddress(t *testing.T) { snmc := header.SolicitedNodeAddr(lladdr0) remoteLinkAddr := header.EthernetAddressFromMulticastIPv6Address(snmc) checkDADMsg := func() { - p, ok := e.ReadContext(context.Background()) + clock.RunImmediatelyScheduledJobs() + p, ok := e.Read() if !ok { t.Fatalf("expected %d-th DAD message", dadPacketsSent) } @@ -1345,7 +1309,7 @@ func TestCheckDuplicateAddress(t *testing.T) { t.Fatalf("RemoveAddress(%d, %s): %s", nicID, lladdr0, err) } // Should not restart DAD since we already requested DAD above - the handler - // should be called when the original request compeletes so we should not send + // should be called when the original request completes so we should not send // an extra DAD message here. dadRequestsMade++ if res, err := s.CheckDuplicateAddress(nicID, ProtocolNumber, lladdr0, func(r stack.DADResult) { diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go index b5b013b64..854d6a6ba 100644 --- a/pkg/tcpip/ports/ports.go +++ b/pkg/tcpip/ports/ports.go @@ -101,7 +101,7 @@ func (dc destToCounter) intersectionFlags(res Reservation) (BitFlags, int) { // Wildcard destinations affect all destinations for TupleOnly. if dest.addr == anyIPAddress || res.Dest.Addr == anyIPAddress { // Only bitwise and the TupleOnlyFlag. - intersection &= ((^TupleOnlyFlag) | counter.SharedFlags()) + intersection &= (^TupleOnlyFlag) | counter.SharedFlags() count++ } } @@ -238,13 +238,13 @@ type PortTester func(port uint16) (good bool, err tcpip.Error) // possible ephemeral ports, allowing the caller to decide whether a given port // is suitable for its needs, and stopping when a port is found or an error // occurs. -func (pm *PortManager) PickEphemeralPort(testPort PortTester) (port uint16, err tcpip.Error) { +func (pm *PortManager) PickEphemeralPort(rng *rand.Rand, testPort PortTester) (port uint16, err tcpip.Error) { pm.ephemeralMu.RLock() firstEphemeral := pm.firstEphemeral numEphemeral := pm.numEphemeral pm.ephemeralMu.RUnlock() - offset := uint32(rand.Int31n(int32(numEphemeral))) + offset := uint32(rng.Int31n(int32(numEphemeral))) return pickEphemeralPort(offset, firstEphemeral, numEphemeral, testPort) } @@ -303,7 +303,7 @@ func pickEphemeralPort(offset uint32, first, count uint16, testPort PortTester) // An optional PortTester can be passed in which if provided will be used to // test if the picked port can be used. The function should return true if the // port is safe to use, false otherwise. -func (pm *PortManager) ReservePort(res Reservation, testPort PortTester) (reservedPort uint16, err tcpip.Error) { +func (pm *PortManager) ReservePort(rng *rand.Rand, res Reservation, testPort PortTester) (reservedPort uint16, err tcpip.Error) { pm.mu.Lock() defer pm.mu.Unlock() @@ -328,7 +328,7 @@ func (pm *PortManager) ReservePort(res Reservation, testPort PortTester) (reserv } // A port wasn't specified, so try to find one. - return pm.PickEphemeralPort(func(p uint16) (bool, tcpip.Error) { + return pm.PickEphemeralPort(rng, func(p uint16) (bool, tcpip.Error) { res.Port = p if !pm.reserveSpecificPortLocked(res) { return false, nil diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go index 6c4fb8c68..a91b130df 100644 --- a/pkg/tcpip/ports/ports_test.go +++ b/pkg/tcpip/ports/ports_test.go @@ -18,6 +18,7 @@ import ( "math" "math/rand" "testing" + "time" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" @@ -331,6 +332,7 @@ func TestPortReservation(t *testing.T) { t.Run(test.tname, func(t *testing.T) { pm := NewPortManager() net := []tcpip.NetworkProtocolNumber{fakeNetworkNumber} + rng := rand.New(rand.NewSource(time.Now().UnixNano())) for _, test := range test.actions { first, _ := pm.PortRange() @@ -356,7 +358,7 @@ func TestPortReservation(t *testing.T) { BindToDevice: test.device, Dest: test.dest, } - gotPort, err := pm.ReservePort(portRes, nil /* testPort */) + gotPort, err := pm.ReservePort(rng, portRes, nil /* testPort */) if diff := cmp.Diff(test.want, err); diff != "" { t.Fatalf("unexpected error from ReservePort(%+v, _), (-want, +got):\n%s", portRes, diff) } @@ -417,10 +419,11 @@ func TestPickEphemeralPort(t *testing.T) { } { t.Run(test.name, func(t *testing.T) { pm := NewPortManager() + rng := rand.New(rand.NewSource(time.Now().UnixNano())) if err := pm.SetPortRange(firstEphemeral, firstEphemeral+numEphemeralPorts); err != nil { t.Fatalf("failed to set ephemeral port range: %s", err) } - port, err := pm.PickEphemeralPort(test.f) + port, err := pm.PickEphemeralPort(rng, test.f) if diff := cmp.Diff(test.wantErr, err); diff != "" { t.Fatalf("unexpected error from PickEphemeralPort(..), (-want, +got):\n%s", diff) } diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go index b26936b7f..0ea85f9ed 100644 --- a/pkg/tcpip/socketops.go +++ b/pkg/tcpip/socketops.go @@ -222,7 +222,7 @@ type SocketOptions struct { getReceiveBufferLimits GetReceiveBufferLimits `state:"manual"` // receiveBufferSize determines the receive buffer size for this socket. - receiveBufferSize int64 + receiveBufferSize atomicbitops.AlignedAtomicInt64 // mu protects the access to the below fields. mu sync.Mutex `state:"nosave"` @@ -601,9 +601,10 @@ func (so *SocketOptions) GetBindToDevice() int32 { return atomic.LoadInt32(&so.bindToDevice) } -// SetBindToDevice sets value for SO_BINDTODEVICE option. +// SetBindToDevice sets value for SO_BINDTODEVICE option. If bindToDevice is +// zero, the socket device binding is removed. func (so *SocketOptions) SetBindToDevice(bindToDevice int32) Error { - if !so.handler.HasNIC(bindToDevice) { + if bindToDevice != 0 && !so.handler.HasNIC(bindToDevice) { return &ErrUnknownDevice{} } @@ -653,13 +654,13 @@ func (so *SocketOptions) SetSendBufferSize(sendBufferSize int64, notify bool) { // GetReceiveBufferSize gets value for SO_RCVBUF option. func (so *SocketOptions) GetReceiveBufferSize() int64 { - return atomic.LoadInt64(&so.receiveBufferSize) + return so.receiveBufferSize.Load() } // SetReceiveBufferSize sets value for SO_RCVBUF option. func (so *SocketOptions) SetReceiveBufferSize(receiveBufferSize int64, notify bool) { if !notify { - atomic.StoreInt64(&so.receiveBufferSize, receiveBufferSize) + so.receiveBufferSize.Store(receiveBufferSize) return } @@ -684,8 +685,8 @@ func (so *SocketOptions) SetReceiveBufferSize(receiveBufferSize int64, notify bo v = math.MaxInt32 } - oldSz := atomic.LoadInt64(&so.receiveBufferSize) + oldSz := so.receiveBufferSize.Load() // Notify endpoint about change in buffer size. newSz := so.handler.OnSetReceiveBufferSize(v, oldSz) - atomic.StoreInt64(&so.receiveBufferSize, newSz) + so.receiveBufferSize.Store(newSz) } diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 84aa6a9e4..395ff9a07 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -56,6 +56,7 @@ go_library( "neighbor_entry_list.go", "neighborstate_string.go", "nic.go", + "nic_stats.go", "nud.go", "packet_buffer.go", "packet_buffer_list.go", diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 5720e7543..18e0d4374 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -35,7 +35,6 @@ import ( // Currently, only TCP tracking is supported. // Our hash table has 16K buckets. -// TODO(gvisor.dev/issue/170): These should be tunable. const numBuckets = 1 << 14 // Direction of the tuple. @@ -125,6 +124,8 @@ type conn struct { tcb tcpconntrack.TCB // lastUsed is the last time the connection saw a relevant packet, and // is updated by each packet on the connection. It is protected by mu. + // + // TODO(gvisor.dev/issue/5939): do not use the ambient clock. lastUsed time.Time `state:".(unixTime)"` } @@ -163,8 +164,6 @@ func (cn *conn) updateLocked(tcpHeader header.TCP, hook Hook) { // Update the state of tcb. tcb assumes it's always initialized on the // client. However, we only need to know whether the connection is // established or not, so the client/server distinction isn't important. - // TODO(gvisor.dev/issue/170): Add support in tcpconntrack to handle - // other tcp states. if cn.tcb.IsEmpty() { cn.tcb.Init(tcpHeader) } else if hook == cn.tcbHook { @@ -244,8 +243,7 @@ func (ct *ConnTrack) init() { // connFor gets the conn for pkt if it exists, or returns nil // if it does not. It returns an error when pkt does not contain a valid TCP // header. -// TODO(gvisor.dev/issue/170): Only TCP packets are supported. Need to support -// other transport protocols. +// TODO(gvisor.dev/issue/6168): Support UDP. func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) { tid, err := packetToTupleID(pkt) if err != nil { @@ -383,7 +381,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { return false } - // TODO(gvisor.dev/issue/170): Support other transport protocols. + // TODO(gvisor.dev/issue/6168): Support UDP. if pkt.Network().TransportProtocol() != header.TCPProtocolNumber { return false } @@ -464,8 +462,6 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { } // Update the state of tcb. - // TODO(gvisor.dev/issue/170): Add support in tcpcontrack to handle - // other tcp states. conn.mu.Lock() defer conn.mu.Unlock() @@ -542,8 +538,6 @@ func (ct *ConnTrack) bucket(id tupleID) int { // reapUnused returns the next bucket that should be checked and the time after // which it should be called again. func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, time.Duration) { - // TODO(gvisor.dev/issue/170): This can be more finely controlled, as - // it is in Linux via sysctl. const fractionPerReaping = 128 const maxExpiredPct = 50 const maxFullTraversal = 60 * time.Second diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index ff555722e..72f66441f 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -54,6 +55,11 @@ type fwdTestNetworkEndpoint struct { nic NetworkInterface proto *fwdTestNetworkProtocol dispatcher TransportDispatcher + + mu struct { + sync.RWMutex + forwarding bool + } } func (*fwdTestNetworkEndpoint) Enable() tcpip.Error { @@ -109,10 +115,6 @@ func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { return f.nic.MaxHeaderLength() + fwdTestNetHeaderLen } -func (*fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { - return 0 -} - func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return f.proto.Number() } @@ -129,7 +131,7 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, params NetworkHeaderParam } // WritePackets implements LinkEndpoint.WritePackets. -func (*fwdTestNetworkEndpoint) WritePackets(r *Route, pkts PacketBufferList, params NetworkHeaderParams) (int, tcpip.Error) { +func (*fwdTestNetworkEndpoint) WritePackets(*Route, PacketBufferList, NetworkHeaderParams) (int, tcpip.Error) { panic("not implemented") } @@ -169,11 +171,6 @@ type fwdTestNetworkProtocol struct { addrResolveDelay time.Duration onLinkAddressResolved func(*neighborCache, tcpip.Address, tcpip.LinkAddress) onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool) - - mu struct { - sync.RWMutex - forwarding bool - } } func (*fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber { @@ -224,7 +221,7 @@ func (*fwdTestNetworkProtocol) Wait() {} func (f *fwdTestNetworkEndpoint) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error { if fn := f.proto.onLinkAddressResolved; fn != nil { - time.AfterFunc(f.proto.addrResolveDelay, func() { + f.proto.stack.clock.AfterFunc(f.proto.addrResolveDelay, func() { fn(f.proto.neigh, addr, remoteLinkAddr) }) } @@ -242,16 +239,16 @@ func (*fwdTestNetworkEndpoint) LinkAddressProtocol() tcpip.NetworkProtocolNumber return fwdTestNetNumber } -// Forwarding implements stack.ForwardingNetworkProtocol. -func (f *fwdTestNetworkProtocol) Forwarding() bool { +// Forwarding implements stack.ForwardingNetworkEndpoint. +func (f *fwdTestNetworkEndpoint) Forwarding() bool { f.mu.RLock() defer f.mu.RUnlock() return f.mu.forwarding } -// SetForwarding implements stack.ForwardingNetworkProtocol. -func (f *fwdTestNetworkProtocol) SetForwarding(v bool) { +// SetForwarding implements stack.ForwardingNetworkEndpoint. +func (f *fwdTestNetworkEndpoint) SetForwarding(v bool) { f.mu.Lock() defer f.mu.Unlock() f.mu.forwarding = v @@ -319,7 +316,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress { return e.linkAddr } -func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { +func (e fwdTestLinkEndpoint) WritePacket(r RouteInfo, _ tcpip.NetworkProtocolNumber, pkt *PacketBuffer) tcpip.Error { p := fwdTestPacketInfo{ RemoteLinkAddress: r.RemoteLinkAddress, LocalLinkAddress: r.LocalLinkAddress, @@ -354,17 +351,19 @@ func (*fwdTestLinkEndpoint) ARPHardwareType() header.ARPHardwareType { } // AddHeader implements stack.LinkEndpoint.AddHeader. -func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { +func (e *fwdTestLinkEndpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *PacketBuffer) { panic("not implemented") } -func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) { +func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (*faketime.ManualClock, *fwdTestLinkEndpoint, *fwdTestLinkEndpoint) { + clock := faketime.NewManualClock() // Create a stack with the network protocol and two NICs. s := New(Options{ NetworkProtocols: []NetworkProtocolFactory{func(s *Stack) NetworkProtocol { proto.stack = s return proto }}, + Clock: clock, }) protoNum := proto.Number() @@ -373,7 +372,7 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f } // NIC 1 has the link address "a", and added the network address 1. - ep1 = &fwdTestLinkEndpoint{ + ep1 := &fwdTestLinkEndpoint{ C: make(chan fwdTestPacketInfo, 300), mtu: fwdTestNetDefaultMTU, linkAddr: "a", @@ -386,7 +385,7 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f } // NIC 2 has the link address "b", and added the network address 2. - ep2 = &fwdTestLinkEndpoint{ + ep2 := &fwdTestLinkEndpoint{ C: make(chan fwdTestPacketInfo, 300), mtu: fwdTestNetDefaultMTU, linkAddr: "b", @@ -416,7 +415,7 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: 2}}) } - return ep1, ep2 + return clock, ep1, ep2 } func TestForwardingWithStaticResolver(t *testing.T) { @@ -432,7 +431,7 @@ func TestForwardingWithStaticResolver(t *testing.T) { }, } - ep1, ep2 := fwdTestNetFactory(t, proto) + clock, ep1, ep2 := fwdTestNetFactory(t, proto) // Inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. @@ -444,6 +443,7 @@ func TestForwardingWithStaticResolver(t *testing.T) { var p fwdTestPacketInfo + clock.Advance(proto.addrResolveDelay) select { case p = <-ep2.C: default: @@ -475,7 +475,7 @@ func TestForwardingWithFakeResolver(t *testing.T) { }) }, } - ep1, ep2 := fwdTestNetFactory(t, &proto) + clock, ep1, ep2 := fwdTestNetFactory(t, &proto) // Inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. @@ -487,9 +487,10 @@ func TestForwardingWithFakeResolver(t *testing.T) { var p fwdTestPacketInfo + clock.Advance(proto.addrResolveDelay) select { case p = <-ep2.C: - case <-time.After(time.Second): + default: t.Fatal("packet not forwarded") } @@ -508,7 +509,7 @@ func TestForwardingWithNoResolver(t *testing.T) { // Whether or not we use the neighbor cache here does not matter since // neither linkAddrCache nor neighborCache will be used. - ep1, ep2 := fwdTestNetFactory(t, proto) + clock, ep1, ep2 := fwdTestNetFactory(t, proto) // inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. @@ -518,10 +519,11 @@ func TestForwardingWithNoResolver(t *testing.T) { Data: buf.ToVectorisedView(), })) + clock.Advance(proto.addrResolveDelay) select { case <-ep2.C: t.Fatal("Packet should not be forwarded") - case <-time.After(time.Second): + default: } } @@ -533,7 +535,7 @@ func TestForwardingResolutionFailsForQueuedPackets(t *testing.T) { }, } - ep1, ep2 := fwdTestNetFactory(t, proto) + clock, ep1, ep2 := fwdTestNetFactory(t, proto) const numPackets int = 5 // These packets will all be enqueued in the packet queue to wait for link @@ -547,12 +549,12 @@ func TestForwardingResolutionFailsForQueuedPackets(t *testing.T) { } // All packets should fail resolution. - // TODO(gvisor.dev/issue/5141): Use a fake clock. for i := 0; i < numPackets; i++ { + clock.Advance(proto.addrResolveDelay) select { case got := <-ep2.C: t.Fatalf("got %#v; packets should have failed resolution and not been forwarded", got) - case <-time.After(100 * time.Millisecond): + default: } } } @@ -576,7 +578,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { } }, } - ep1, ep2 := fwdTestNetFactory(t, &proto) + clock, ep1, ep2 := fwdTestNetFactory(t, &proto) // Inject an inbound packet to address 4 on NIC 1. This packet should // not be forwarded. @@ -596,9 +598,10 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { var p fwdTestPacketInfo + clock.Advance(proto.addrResolveDelay) select { case p = <-ep2.C: - case <-time.After(time.Second): + default: t.Fatal("packet not forwarded") } @@ -631,7 +634,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { }) }, } - ep1, ep2 := fwdTestNetFactory(t, &proto) + clock, ep1, ep2 := fwdTestNetFactory(t, &proto) // Inject two inbound packets to address 3 on NIC 1. for i := 0; i < 2; i++ { @@ -645,9 +648,10 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { for i := 0; i < 2; i++ { var p fwdTestPacketInfo + clock.Advance(proto.addrResolveDelay) select { case p = <-ep2.C: - case <-time.After(time.Second): + default: t.Fatal("packet not forwarded") } @@ -681,7 +685,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { }) }, } - ep1, ep2 := fwdTestNetFactory(t, &proto) + clock, ep1, ep2 := fwdTestNetFactory(t, &proto) for i := 0; i < maxPendingPacketsPerResolution+5; i++ { // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1. @@ -697,9 +701,10 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { for i := 0; i < maxPendingPacketsPerResolution; i++ { var p fwdTestPacketInfo + clock.Advance(proto.addrResolveDelay) select { case p = <-ep2.C: - case <-time.After(time.Second): + default: t.Fatal("packet not forwarded") } @@ -745,7 +750,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { }) }, } - ep1, ep2 := fwdTestNetFactory(t, &proto) + clock, ep1, ep2 := fwdTestNetFactory(t, &proto) for i := 0; i < maxPendingResolutions+5; i++ { // Inject inbound 'maxPendingResolutions + 5' packets on NIC 1. @@ -761,9 +766,10 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { for i := 0; i < maxPendingResolutions; i++ { var p fwdTestPacketInfo + clock.Advance(proto.addrResolveDelay) select { case p = <-ep2.C: - case <-time.After(time.Second): + default: t.Fatal("packet not forwarded") } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 3670d5995..f152c0d83 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -42,7 +42,7 @@ const reaperDelay = 5 * time.Second // DefaultTables returns a default set of tables. Each chain is set to accept // all packets. -func DefaultTables() *IPTables { +func DefaultTables(seed uint32) *IPTables { return &IPTables{ v4Tables: [NumTables]Table{ NATID: { @@ -182,7 +182,7 @@ func DefaultTables() *IPTables { Postrouting: {MangleID, NATID}, }, connections: ConnTrack{ - seed: generateRandUint32(), + seed: seed, }, reaperDone: make(chan struct{}, 1), } @@ -268,10 +268,6 @@ const ( // should continue traversing the network stack and false when it should be // dropped. // -// TODO(gvisor.dev/issue/170): PacketBuffer should hold the route, from -// which address can be gathered. Currently, address is only needed for -// prerouting. -// // Precondition: pkt.NetworkHeader is set. func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) bool { if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber { @@ -371,6 +367,7 @@ func (it *IPTables) startReaper(interval time.Duration) { select { case <-it.reaperDone: return + // TODO(gvisor.dev/issue/5939): do not use the ambient clock. case <-time.After(interval): bucket, interval = it.connections.reapUnused(bucket, interval) } diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 2812c89aa..91e266de8 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -87,9 +87,6 @@ func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Addre // destination port/IP. Outgoing packets are redirected to the loopback device, // and incoming packets are redirected to the incoming interface (rather than // forwarded). -// -// TODO(gvisor.dev/issue/170): Other flags need to be added after we support -// them. type RedirectTarget struct { // Port indicates port used to redirect. It is immutable. Port uint16 @@ -100,9 +97,6 @@ type RedirectTarget struct { } // Action implements Target.Action. -// TODO(gvisor.dev/issue/170): Parse headers without copying. The current -// implementation only works for Prerouting and calls pkt.Clone(), neither -// of which should be the case. func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, address tcpip.Address) (RuleVerdict, int) { // Sanity check. if rt.NetworkProtocol != pkt.NetworkProtocolNumber { @@ -136,8 +130,6 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r panic("redirect target is supported only on output and prerouting hooks") } - // TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if - // we need to change dest address (for OUTPUT chain) or ports. switch protocol := pkt.TransportProtocolNumber; protocol { case header.UDPProtocolNumber: udpHeader := header.UDP(pkt.TransportHeader().View()) diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 93592e7f5..66e5f22ac 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -242,7 +242,6 @@ type IPHeaderFilter struct { func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, inNicName, outNicName string) bool { // Extract header fields. var ( - // TODO(gvisor.dev/issue/170): Support other filter fields. transProto tcpip.TransportProtocolNumber dstAddr tcpip.Address srcAddr tcpip.Address @@ -291,7 +290,6 @@ func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, inNicName, outNicNa return true case Postrouting: - // TODO(gvisor.dev/issue/170): Add the check for POSTROUTING. return true default: panic(fmt.Sprintf("unknown hook: %d", hook)) diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index d4ac9e1f8..b5c6626d6 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -16,14 +16,14 @@ package stack_test import ( "bytes" - "context" "encoding/binary" "fmt" + "math/rand" "testing" "time" "github.com/google/go-cmp/cmp" - "gvisor.dev/gvisor/pkg/rand" + cryptorand "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" @@ -112,11 +112,12 @@ type ndpDADEvent struct { res stack.DADResult } -type ndpRouterEvent struct { - nicID tcpip.NICID - addr tcpip.Address - // true if router was discovered, false if invalidated. - discovered bool +type ndpOffLinkRouteEvent struct { + nicID tcpip.NICID + subnet tcpip.Subnet + router tcpip.Address + // true if route was updated, false if invalidated. + updated bool } type ndpPrefixEvent struct { @@ -167,10 +168,8 @@ var _ ipv6.NDPDispatcher = (*ndpDispatcher)(nil) // related events happen for test purposes. type ndpDispatcher struct { dadC chan ndpDADEvent - routerC chan ndpRouterEvent - rememberRouter bool + offLinkRouteC chan ndpOffLinkRouteEvent prefixC chan ndpPrefixEvent - rememberPrefix bool autoGenAddrC chan ndpAutoGenAddrEvent rdnssC chan ndpRDNSSEvent dnsslC chan ndpDNSSLEvent @@ -189,32 +188,32 @@ func (n *ndpDispatcher) OnDuplicateAddressDetectionResult(nicID tcpip.NICID, add } } -// Implements ipv6.NDPDispatcher.OnDefaultRouterDiscovered. -func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool { - if c := n.routerC; c != nil { - c <- ndpRouterEvent{ +// Implements ipv6.NDPDispatcher.OnOffLinkRouteUpdated. +func (n *ndpDispatcher) OnOffLinkRouteUpdated(nicID tcpip.NICID, subnet tcpip.Subnet, router tcpip.Address) { + if c := n.offLinkRouteC; c != nil { + c <- ndpOffLinkRouteEvent{ nicID, - addr, + subnet, + router, true, } } - - return n.rememberRouter } -// Implements ipv6.NDPDispatcher.OnDefaultRouterInvalidated. -func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) { - if c := n.routerC; c != nil { - c <- ndpRouterEvent{ +// Implements ipv6.NDPDispatcher.OnOffLinkRouteInvalidated. +func (n *ndpDispatcher) OnOffLinkRouteInvalidated(nicID tcpip.NICID, subnet tcpip.Subnet, router tcpip.Address) { + if c := n.offLinkRouteC; c != nil { + c <- ndpOffLinkRouteEvent{ nicID, - addr, + subnet, + router, false, } } } // Implements ipv6.NDPDispatcher.OnOnLinkPrefixDiscovered. -func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool { +func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) { if c := n.prefixC; c != nil { c <- ndpPrefixEvent{ nicID, @@ -222,8 +221,6 @@ func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip true, } } - - return n.rememberPrefix } // Implements ipv6.NDPDispatcher.OnOnLinkPrefixInvalidated. @@ -237,7 +234,7 @@ func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpi } } -func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) bool { +func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) { if c := n.autoGenAddrC; c != nil { c <- ndpAutoGenAddrEvent{ nicID, @@ -245,7 +242,6 @@ func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWi newAddr, } } - return true } func (n *ndpDispatcher) OnAutoGenAddressDeprecated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) { @@ -481,13 +477,9 @@ func TestDADResolve(t *testing.T) { } for _, test := range tests { - test := test - t.Run(test.name, func(t *testing.T) { - t.Parallel() - ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent), + dadC: make(chan ndpDADEvent, 1), } e := channelLinkWithHeaderLength{ @@ -499,8 +491,11 @@ func TestDADResolve(t *testing.T) { var secureRNG bytes.Reader secureRNG.Reset(secureRNGBytes) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ - SecureRNG: &secureRNG, + Clock: clock, + RandSource: rand.NewSource(time.Now().UnixNano()), + SecureRNG: &secureRNG, NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPDisp: &ndpDisp, DADConfigs: stack.DADConfigurations{ @@ -529,14 +524,10 @@ func TestDADResolve(t *testing.T) { t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err) } - // Address should not be considered bound to the NIC yet (DAD ongoing). - if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { - t.Fatal(err) - } - // Make sure the address does not resolve before the resolution time has // passed. - time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - defaultAsyncNegativeEventTimeout) + const delta = time.Nanosecond + clock.Advance(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - delta) if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { t.Error(err) } @@ -566,13 +557,14 @@ func TestDADResolve(t *testing.T) { } // Wait for DAD to resolve. + clock.Advance(delta) select { - case <-time.After(defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr1, &stack.DADSucceeded{}); diff != "" { t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } + default: + t.Fatalf("expected DAD event for %s on NIC(%d)", addr1, nicID) } if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { t.Error(err) @@ -610,7 +602,10 @@ func TestDADResolve(t *testing.T) { // Validate the sent Neighbor Solicitation messages. for i := uint8(0); i < test.dupAddrDetectTransmits; i++ { - p, _ := e.ReadContext(context.Background()) + p, ok := e.Read() + if !ok { + t.Fatal("packet didn't arrive") + } // Make sure its an IPv6 packet. if p.Proto != header.IPv6ProtocolNumber { @@ -736,11 +731,13 @@ func TestDADFail(t *testing.T) { dadConfigs.RetransmitTimer = time.Second * 2 e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPDisp: &ndpDisp, DADConfigs: dadConfigs, })}, + Clock: clock, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -766,16 +763,17 @@ func TestDADFail(t *testing.T) { // Wait for DAD to fail and make sure the address did // not get resolved. + clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer) select { - case <-time.After(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer + time.Second): - // If we don't get a failure event after the - // expected resolution time + extra 1s buffer, - // something is wrong. - t.Fatal("timed out waiting for DAD failure") case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr1, &stack.DADDupAddrDetected{HolderLinkAddress: test.expectedHolderLinkAddress}); diff != "" { t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } + default: + // If we don't get a failure event after the + // expected resolution time + extra 1s buffer, + // something is wrong. + t.Fatal("timed out waiting for DAD failure") } if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { t.Fatal(err) @@ -844,11 +842,13 @@ func TestDADStop(t *testing.T) { } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPDisp: &ndpDisp, DADConfigs: dadConfigs, })}, + Clock: clock, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) @@ -866,15 +866,16 @@ func TestDADStop(t *testing.T) { test.stopFn(t, s) // Wait for DAD to fail (since the address was removed during DAD). + clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer) select { - case <-time.After(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer + time.Second): - // If we don't get a failure event after the expected resolution - // time + extra 1s buffer, something is wrong. - t.Fatal("timed out waiting for DAD failure") case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr1, &stack.DADAborted{}); diff != "" { t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } + default: + // If we don't get a failure event after the expected resolution + // time + extra 1s buffer, something is wrong. + t.Fatal("timed out waiting for DAD failure") } if !test.skipFinalAddrCheck { @@ -925,10 +926,12 @@ func TestSetNDPConfigurations(t *testing.T) { dadC: make(chan ndpDADEvent, 1), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPDisp: &ndpDisp, })}, + Clock: clock, }) expectDADEvent := func(nicID tcpip.NICID, addr tcpip.Address) { @@ -1007,28 +1010,23 @@ func TestSetNDPConfigurations(t *testing.T) { t.Fatal(err) } - // Sleep until right (500ms before) before resolution to - // make sure the address didn't resolve on NIC(1) yet. - const delta = 500 * time.Millisecond - time.Sleep(time.Duration(test.dupAddrDetectTransmits)*test.expectedRetransmitTimer - delta) + // Sleep until right before resolution to make sure the address didn't + // resolve on NIC(1) yet. + const delta = 1 + clock.Advance(time.Duration(test.dupAddrDetectTransmits)*test.expectedRetransmitTimer - delta) if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { t.Fatal(err) } // Wait for DAD to resolve. + clock.Advance(delta) select { - case <-time.After(2 * delta): - // We should get a resolution event after 500ms - // (delta) since we wait for 500ms less than the - // expected resolution time above to make sure - // that the address did not yet resolve. Waiting - // for 1s (2x delta) without a resolution event - // means something is wrong. - t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID1, addr1, &stack.DADSucceeded{}); diff != "" { t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } + default: + t.Fatal("timed out waiting for DAD resolution") } if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, addrWithPrefix1); err != nil { t.Fatal(err) @@ -1040,7 +1038,7 @@ func TestSetNDPConfigurations(t *testing.T) { // raBufWithOptsAndDHCPv6 returns a valid NDP Router Advertisement with options // and DHCPv6 configurations specified. func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) *stack.PacketBuffer { - icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + int(optSer.Length()) + icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + optSer.Length() hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) pkt := header.ICMPv6(hdr.Prepend(icmpSize)) pkt.SetType(header.ICMPv6RouterAdvert) @@ -1053,13 +1051,13 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo if managedAddress { // The Managed Addresses flag field is the 7th bit of byte #1 (0-indexing) // of the RA payload. - raPayload[1] |= (1 << 7) + raPayload[1] |= 1 << 7 } // Populate the Other Configurations flag field. if otherConfigurations { // The Other Configurations flag field is the 6th bit of byte #1 // (0-indexing) of the RA payload. - raPayload[1] |= (1 << 6) + raPayload[1] |= 1 << 6 } opts := ra.Options() opts.Serialize(optSer) @@ -1203,9 +1201,9 @@ func TestDynamicConfigurationsDisabled(t *testing.T) { t.Run(fmt.Sprintf("HandleRAs(%s), Forwarding(%t), Enabled(%t)", handle, forwarding, enable), func(t *testing.T) { ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - prefixC: make(chan ndpPrefixEvent, 1), - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1), + prefixC: make(chan ndpPrefixEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } ndpConfigs := test.config(enable) ndpConfigs.HandleRAs = handle @@ -1275,8 +1273,8 @@ func TestDynamicConfigurationsDisabled(t *testing.T) { t.Errorf("got v6Stats.UnhandledRouterAdvertisements.Value() = %d, want = %d", got, want) } select { - case e := <-ndpDisp.routerC: - t.Errorf("unexpectedly discovered a router when configured not to: %#v", e) + case e := <-ndpDisp.offLinkRouteC: + t.Errorf("unexpectedly updated an off-link route when configured not to: %#v", e) default: } select { @@ -1303,9 +1301,9 @@ func boolToUint64(v bool) uint64 { } // Check e to make sure that the event is for addr on nic with ID 1, and the -// discovered flag set to discovered. -func checkRouterEvent(e ndpRouterEvent, addr tcpip.Address, discovered bool) string { - return cmp.Diff(ndpRouterEvent{nicID: 1, addr: addr, discovered: discovered}, e, cmp.AllowUnexported(e)) +// update flag set to updated. +func checkOffLinkRouteEvent(e ndpOffLinkRouteEvent, router tcpip.Address, updated bool) string { + return cmp.Diff(ndpOffLinkRouteEvent{nicID: 1, subnet: header.IPv6EmptySubnet, router: router, updated: updated}, e, cmp.AllowUnexported(e)) } func testWithRAs(t *testing.T, f func(*testing.T, ipv6.HandleRAsConfiguration, bool)) { @@ -1338,56 +1336,13 @@ func testWithRAs(t *testing.T, f func(*testing.T, ipv6.HandleRAsConfiguration, b } } -// TestRouterDiscoveryDispatcherNoRemember tests that the stack does not -// remember a discovered router when the dispatcher asks it not to. -func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) { - ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - DiscoverDefaultRouters: true, - }, - NDPDisp: &ndpDisp, - })}, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Receive an RA for a router we should not remember. - const lifetimeSeconds = 1 - e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, lifetimeSeconds)) - select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, llAddr2, true); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected router discovery event") - } - - // Wait for the invalidation time plus some buffer to make sure we do - // not actually receive any invalidation events as we should not have - // remembered the router in the first place. - select { - case <-ndpDisp.routerC: - t.Fatal("should not have received any router events") - case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout): - } -} - func TestRouterDiscovery(t *testing.T) { testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) { ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - rememberRouter: true, + offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -1396,30 +1351,32 @@ func TestRouterDiscovery(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, }) - expectRouterEvent := func(addr tcpip.Address, discovered bool) { + expectOffLinkRouteEvent := func(addr tcpip.Address, updated bool) { t.Helper() select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, addr, discovered); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) + case e := <-ndpDisp.offLinkRouteC: + if diff := checkOffLinkRouteEvent(e, addr, updated); diff != "" { + t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) } default: t.Fatal("expected router discovery event") } } - expectAsyncRouterInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) { + expectAsyncOffLinkRouteInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) { t.Helper() + clock.Advance(timeout) select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, addr, false); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) + case e := <-ndpDisp.offLinkRouteC: + if diff := checkOffLinkRouteEvent(e, addr, false); diff != "" { + t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) } - case <-time.After(timeout): + default: t.Fatal("timed out waiting for router discovery event") } } @@ -1436,26 +1393,26 @@ func TestRouterDiscovery(t *testing.T) { // remembered. e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) select { - case <-ndpDisp.routerC: - t.Fatal("unexpectedly discovered a router with 0 lifetime") + case <-ndpDisp.offLinkRouteC: + t.Fatal("unexpectedly updated an off-link route with 0 lifetime") default: } // Rx an RA from lladdr2 with a huge lifetime. e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - expectRouterEvent(llAddr2, true) + expectOffLinkRouteEvent(llAddr2, true) // Rx an RA from another router (lladdr3) with non-zero lifetime. const l3LifetimeSeconds = 6 e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, l3LifetimeSeconds)) - expectRouterEvent(llAddr3, true) + expectOffLinkRouteEvent(llAddr3, true) // Rx an RA from lladdr2 with lesser lifetime. const l2LifetimeSeconds = 2 e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, l2LifetimeSeconds)) select { - case <-ndpDisp.routerC: - t.Fatal("Should not receive a router event when updating lifetimes for known routers") + case <-ndpDisp.offLinkRouteC: + t.Fatal("should not receive a off-link route event when updating lifetimes for known routers") default: } @@ -1466,15 +1423,15 @@ func TestRouterDiscovery(t *testing.T) { // Wait for the normal lifetime plus an extra bit for the // router to get invalidated. If we don't get an invalidation // event after this time, then something is wrong. - expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout) + expectAsyncOffLinkRouteInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second) // Rx an RA from lladdr2 with huge lifetime. e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000)) - expectRouterEvent(llAddr2, true) + expectOffLinkRouteEvent(llAddr2, true) // Rx an RA from lladdr2 with zero lifetime. It should be invalidated. e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) - expectRouterEvent(llAddr2, false) + expectOffLinkRouteEvent(llAddr2, false) // Wait for lladdr3's router invalidation job to execute. The lifetime // of the router should have been updated to the most recent (smaller) @@ -1483,7 +1440,7 @@ func TestRouterDiscovery(t *testing.T) { // Wait for the normal lifetime plus an extra bit for the // router to get invalidated. If we don't get an invalidation // event after this time, then something is wrong. - expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout) + expectAsyncOffLinkRouteInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second) }) } @@ -1491,8 +1448,7 @@ func TestRouterDiscovery(t *testing.T) { // ipv6.MaxDiscoveredDefaultRouters discovered routers are remembered. func TestRouterDiscoveryMaxRouters(t *testing.T) { ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - rememberRouter: true, + offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1), } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ @@ -1519,9 +1475,9 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { if i <= ipv6.MaxDiscoveredDefaultRouters { select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, llAddr, true); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) + case e := <-ndpDisp.offLinkRouteC: + if diff := checkOffLinkRouteEvent(e, llAddr, true); diff != "" { + t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) } default: t.Fatal("expected router discovery event") @@ -1529,7 +1485,7 @@ func TestRouterDiscoveryMaxRouters(t *testing.T) { } else { select { - case <-ndpDisp.routerC: + case <-ndpDisp.offLinkRouteC: t.Fatal("should not have discovered a new router after we already discovered the max number of routers") default: } @@ -1543,51 +1499,6 @@ func checkPrefixEvent(e ndpPrefixEvent, prefix tcpip.Subnet, discovered bool) st return cmp.Diff(ndpPrefixEvent{nicID: 1, prefix: prefix, discovered: discovered}, e, cmp.AllowUnexported(e)) } -// TestPrefixDiscoveryDispatcherNoRemember tests that the stack does not -// remember a discovered on-link prefix when the dispatcher asks it not to. -func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) { - prefix, subnet, _ := prefixSubnetAddr(0, "") - - ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 1), - } - e := channel.New(0, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - HandleRAs: ipv6.HandlingRAsEnabledWhenForwardingDisabled, - DiscoverOnLinkPrefixes: true, - }, - NDPDisp: &ndpDisp, - })}, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(1) = %s", err) - } - - // Receive an RA with prefix that we should not remember. - const lifetimeSeconds = 1 - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, lifetimeSeconds, 0)) - select { - case e := <-ndpDisp.prefixC: - if diff := checkPrefixEvent(e, subnet, true); diff != "" { - t.Errorf("prefix event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected prefix discovery event") - } - - // Wait for the invalidation time plus some buffer to make sure we do - // not actually receive any invalidation events as we should not have - // remembered the prefix in the first place. - select { - case <-ndpDisp.prefixC: - t.Fatal("should not have received any prefix events") - case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout): - } -} - func TestPrefixDiscovery(t *testing.T) { prefix1, subnet1, _ := prefixSubnetAddr(0, "") prefix2, subnet2, _ := prefixSubnetAddr(1, "") @@ -1595,10 +1506,10 @@ func TestPrefixDiscovery(t *testing.T) { testWithRAs(t, func(t *testing.T, handleRAs ipv6.HandleRAsConfiguration, forwarding bool) { ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 1), - rememberPrefix: true, + prefixC: make(chan ndpPrefixEvent, 1), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -1607,6 +1518,7 @@ func TestPrefixDiscovery(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, }) if err := s.CreateNIC(1, e); err != nil { @@ -1667,12 +1579,13 @@ func TestPrefixDiscovery(t *testing.T) { // Wait for prefix2's most recent invalidation job plus some buffer to // expire. + clock.Advance(time.Duration(lifetime) * time.Second) select { case e := <-ndpDisp.prefixC: if diff := checkPrefixEvent(e, subnet2, false); diff != "" { t.Errorf("prefix event mismatch (-want +got):\n%s", diff) } - case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for prefix discovery event") } @@ -1701,10 +1614,10 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { subnet := prefix.Subnet() ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, 1), - rememberPrefix: true, + prefixC: make(chan ndpPrefixEvent, 1), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -1713,6 +1626,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, }) if err := s.CreateNIC(1, e); err != nil { @@ -1736,21 +1650,23 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { // with infinite valid lifetime which should not get invalidated. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0)) expectPrefixEvent(subnet, true) + clock.Advance(testInfiniteLifetime) select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") - case <-time.After(testInfiniteLifetime + defaultAsyncNegativeEventTimeout): + default: } // Receive an RA with finite lifetime. // The prefix should get invalidated after 1s. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0)) + clock.Advance(testInfiniteLifetime) select { case e := <-ndpDisp.prefixC: if diff := checkPrefixEvent(e, subnet, false); diff != "" { t.Errorf("prefix event mismatch (-want +got):\n%s", diff) } - case <-time.After(testInfiniteLifetime): + default: t.Fatal("timed out waiting for prefix discovery event") } @@ -1761,19 +1677,21 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { // Receive an RA with prefix with an infinite lifetime. // The prefix should not be invalidated. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0)) + clock.Advance(testInfiniteLifetime) select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") - case <-time.After(testInfiniteLifetime + defaultAsyncNegativeEventTimeout): + default: } // Receive an RA with a prefix with a lifetime value greater than the // set infinite lifetime value. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds+1, 0)) + clock.Advance((testInfiniteLifetimeSeconds + 1) * time.Second) select { case <-ndpDisp.prefixC: t.Fatal("unexpectedly invalidated a prefix with infinite lifetime") - case <-time.After((testInfiniteLifetimeSeconds+1)*time.Second + defaultAsyncNegativeEventTimeout): + default: } // Receive an RA with 0 lifetime. @@ -1786,8 +1704,7 @@ func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) { // ipv6.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered. func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) { ndpDisp := ndpDispatcher{ - prefixC: make(chan ndpPrefixEvent, ipv6.MaxDiscoveredOnLinkPrefixes+3), - rememberPrefix: true, + prefixC: make(chan ndpPrefixEvent, ipv6.MaxDiscoveredOnLinkPrefixes+3), } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ @@ -3618,6 +3535,7 @@ func TestAutoGenAddrRemoval(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -3626,6 +3544,7 @@ func TestAutoGenAddrRemoval(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, }) if err := s.CreateNIC(1, e); err != nil { @@ -3659,10 +3578,11 @@ func TestAutoGenAddrRemoval(t *testing.T) { // Wait for the original valid lifetime to make sure the original job got // cancelled/cleaned up. + clock.Advance(lifetimeSeconds * time.Second) select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly received an auto gen addr event") - case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout): + default: } } @@ -3784,6 +3704,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -3792,6 +3713,7 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, }) if err := s.CreateNIC(1, e); err != nil { @@ -3821,30 +3743,36 @@ func TestAutoGenAddrStaticConflict(t *testing.T) { // Should not get an invalidation event after the PI's invalidation // time. + clock.Advance(lifetimeSeconds * time.Second) select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly received an auto gen addr event") - case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout): + default: } if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) { t.Fatalf("Should have %s in the list of addresses", addr1) } } +func makeSecretKey(t *testing.T) []byte { + secretKey := make([]byte, header.OpaqueIIDSecretKeyMinBytes) + n, err := cryptorand.Read(secretKey) + if err != nil { + t.Fatalf("cryptorand.Read(_): %s", err) + } + if l := len(secretKey); n != l { + t.Fatalf("got cryptorand.Read(_) = (%d, nil), want = (%d, nil)", n, l) + } + return secretKey +} + // TestAutoGenAddrWithOpaqueIID tests that SLAAC generated addresses will use // opaque interface identifiers when configured to do so. func TestAutoGenAddrWithOpaqueIID(t *testing.T) { const nicID = 1 const nicName = "nic1" - var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte - secretKey := secretKeyBuf[:] - n, err := rand.Read(secretKey) - if err != nil { - t.Fatalf("rand.Read(_): %s", err) - } - if n != header.OpaqueIIDSecretKeyMinBytes { - t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes) - } + + secretKey := makeSecretKey(t) prefix1, subnet1, _ := prefixSubnetAddr(0, linkAddr1) prefix2, subnet2, _ := prefixSubnetAddr(1, linkAddr1) @@ -3866,6 +3794,7 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -3880,6 +3809,7 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { SecretKey: secretKey, }, })}, + Clock: clock, }) opts := stack.NICOptions{Name: nicName} if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { @@ -3918,12 +3848,13 @@ func TestAutoGenAddrWithOpaqueIID(t *testing.T) { } // Wait for addr of prefix1 to be invalidated. + clock.Advance(validLifetimeSecondPrefix1 * time.Second) select { case e := <-ndpDisp.autoGenAddrC: if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for addr auto gen event") } if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { @@ -3949,15 +3880,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { }() ipv6.MaxDesyncFactor = time.Nanosecond - var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte - secretKey := secretKeyBuf[:] - n, err := rand.Read(secretKey) - if err != nil { - t.Fatalf("rand.Read(_): %s", err) - } - if n != header.OpaqueIIDSecretKeyMinBytes { - t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes) - } + secretKey := makeSecretKey(t) prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1) @@ -4236,13 +4159,12 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { addrType := addrType t.Run(addrType.name, func(t *testing.T) { - t.Parallel() - ndpDisp := ndpDispatcher{ dadC: make(chan ndpDADEvent, 1), autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ AutoGenLinkLocal: addrType.autoGenLinkLocal, @@ -4253,6 +4175,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { RetransmitTimer: retransmitTimer, }, })}, + Clock: clock, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -4297,7 +4220,7 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { select { case e := <-ndpDisp.autoGenAddrC: t.Fatalf("unexpectedly got an auto-generated address event = %+v", e) - case <-time.After(defaultAsyncNegativeEventTimeout): + default: } }) } @@ -4314,15 +4237,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { const maxRetries = 1 const lifetimeSeconds = 5 - var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte - secretKey := secretKeyBuf[:] - n, err := rand.Read(secretKey) - if err != nil { - t.Fatalf("rand.Read(_): %s", err) - } - if n != header.OpaqueIIDSecretKeyMinBytes { - t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes) - } + secretKey := makeSecretKey(t) prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1) @@ -4331,6 +4246,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2), } e := channel.New(0, 1280, linkAddr1) + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ DADConfigs: stack.DADConfigurations{ @@ -4350,6 +4266,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { SecretKey: secretKey, }, })}, + Clock: clock, }) opts := stack.NICOptions{Name: nicName} if err := s.CreateNICWithOptions(nicID, e, opts); err != nil { @@ -4380,7 +4297,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { expectAutoGenAddrEvent(addr, newAddr) // Simulate a DAD conflict after some time has passed. - time.Sleep(failureTimer) + clock.Advance(failureTimer) rxNDPSolicit(e, addr.Address) expectAutoGenAddrEvent(addr, invalidatedAddr) select { @@ -4395,12 +4312,13 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { // Let the next address resolve. addr.Address = tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, 1, secretKey)) expectAutoGenAddrEvent(addr, newAddr) + clock.Advance(dadTransmits * retransmitTimer) select { case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADSucceeded{}); diff != "" { t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } - case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for DAD event") } @@ -4414,6 +4332,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { // // We expect either just the invalidation event or the deprecation event // followed by the invalidation event. + clock.Advance(lifetimeSeconds*time.Second - failureTimer - dadTransmits*retransmitTimer) select { case e := <-ndpDisp.autoGenAddrC: if e.eventType == deprecatedAddr { @@ -4426,7 +4345,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } - case <-time.After(defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for invalidated auto gen addr event after deprecation") } } else { @@ -4434,7 +4353,7 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) } } - case <-time.After(lifetimeSeconds*time.Second - failureTimer - dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): + default: t.Fatal("timed out waiting for auto gen addr event") } } @@ -4696,11 +4615,9 @@ func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) { ) ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, 1), - rememberRouter: true, - prefixC: make(chan ndpPrefixEvent, 1), - rememberPrefix: true, - autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), + offLinkRouteC: make(chan ndpOffLinkRouteEvent, 1), + prefixC: make(chan ndpPrefixEvent, 1), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), } s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ @@ -4743,17 +4660,17 @@ func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) { ), ) select { - case e := <-ndpDisp.routerC: - if diff := checkRouterEvent(e, llAddr3, true /* discovered */); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) + case e := <-ndpDisp.offLinkRouteC: + if diff := checkOffLinkRouteEvent(e, llAddr3, true /* discovered */); diff != "" { + t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) } default: - t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID) + t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr3, nicID) } select { case e := <-ndpDisp.prefixC: if diff := checkPrefixEvent(e, subnet, true /* discovered */); diff != "" { - t.Errorf("router event mismatch (-want +got):\n%s", diff) + t.Errorf("off-link route event mismatch (-want +got):\n%s", diff) } default: t.Errorf("expected prefix event for %s on NIC(%d)", prefix, nicID) @@ -4775,8 +4692,8 @@ func TestNoCleanupNDPStateWhenForwardingEnabled(t *testing.T) { t.Fatalf("SetForwardingDefaultAndAllNICs(%d, %t): %s", ipv6.ProtocolNumber, forwarding, err) } select { - case e := <-ndpDisp.routerC: - t.Errorf("unexpected router event = %#v", e) + case e := <-ndpDisp.offLinkRouteC: + t.Errorf("unexpected off-link route event = %#v", e) default: } select { @@ -4862,12 +4779,11 @@ func TestCleanupNDPState(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { ndpDisp := ndpDispatcher{ - routerC: make(chan ndpRouterEvent, maxRouterAndPrefixEvents), - rememberRouter: true, - prefixC: make(chan ndpPrefixEvent, maxRouterAndPrefixEvents), - rememberPrefix: true, - autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents), + offLinkRouteC: make(chan ndpOffLinkRouteEvent, maxRouterAndPrefixEvents), + prefixC: make(chan ndpPrefixEvent, maxRouterAndPrefixEvents), + autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents), } + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ AutoGenLinkLocal: true, @@ -4879,16 +4795,17 @@ func TestCleanupNDPState(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, }) - expectRouterEvent := func() (bool, ndpRouterEvent) { + expectOffLinkRouteEvent := func() (bool, ndpOffLinkRouteEvent) { select { - case e := <-ndpDisp.routerC: + case e := <-ndpDisp.offLinkRouteC: return true, e default: } - return false, ndpRouterEvent{} + return false, ndpOffLinkRouteEvent{} } expectPrefixEvent := func() (bool, ndpPrefixEvent) { @@ -4933,8 +4850,8 @@ func TestCleanupNDPState(t *testing.T) { // multiple addresses. e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID1) + if ok, _ := expectOffLinkRouteEvent(); !ok { + t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr3, nicID1) } if ok, _ := expectPrefixEvent(); !ok { t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID1) @@ -4944,8 +4861,8 @@ func TestCleanupNDPState(t *testing.T) { } e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID1) + if ok, _ := expectOffLinkRouteEvent(); !ok { + t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr4, nicID1) } if ok, _ := expectPrefixEvent(); !ok { t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID1) @@ -4955,8 +4872,8 @@ func TestCleanupNDPState(t *testing.T) { } e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID2) + if ok, _ := expectOffLinkRouteEvent(); !ok { + t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr3, nicID2) } if ok, _ := expectPrefixEvent(); !ok { t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID2) @@ -4966,8 +4883,8 @@ func TestCleanupNDPState(t *testing.T) { } e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds)) - if ok, _ := expectRouterEvent(); !ok { - t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID2) + if ok, _ := expectOffLinkRouteEvent(); !ok { + t.Errorf("expected off-link route event for %s on NIC(%d)", llAddr4, nicID2) } if ok, _ := expectPrefixEvent(); !ok { t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID2) @@ -5008,14 +4925,14 @@ func TestCleanupNDPState(t *testing.T) { test.cleanupFn(t, s) // Collect invalidation events after having NDP state cleaned up. - gotRouterEvents := make(map[ndpRouterEvent]int) + gotOffLinkRouteEvents := make(map[ndpOffLinkRouteEvent]int) for i := 0; i < maxRouterAndPrefixEvents; i++ { - ok, e := expectRouterEvent() + ok, e := expectOffLinkRouteEvent() if !ok { - t.Errorf("expected %d router events after becoming a router; got = %d", maxRouterAndPrefixEvents, i) + t.Errorf("expected %d off-link route events after becoming a router; got = %d", maxRouterAndPrefixEvents, i) break } - gotRouterEvents[e]++ + gotOffLinkRouteEvents[e]++ } gotPrefixEvents := make(map[ndpPrefixEvent]int) for i := 0; i < maxRouterAndPrefixEvents; i++ { @@ -5042,14 +4959,14 @@ func TestCleanupNDPState(t *testing.T) { t.FailNow() } - expectedRouterEvents := map[ndpRouterEvent]int{ - {nicID: nicID1, addr: llAddr3, discovered: false}: 1, - {nicID: nicID1, addr: llAddr4, discovered: false}: 1, - {nicID: nicID2, addr: llAddr3, discovered: false}: 1, - {nicID: nicID2, addr: llAddr4, discovered: false}: 1, + expectedOffLinkRouteEvents := map[ndpOffLinkRouteEvent]int{ + {nicID: nicID1, subnet: header.IPv6EmptySubnet, router: llAddr3, updated: false}: 1, + {nicID: nicID1, subnet: header.IPv6EmptySubnet, router: llAddr4, updated: false}: 1, + {nicID: nicID2, subnet: header.IPv6EmptySubnet, router: llAddr3, updated: false}: 1, + {nicID: nicID2, subnet: header.IPv6EmptySubnet, router: llAddr4, updated: false}: 1, } - if diff := cmp.Diff(expectedRouterEvents, gotRouterEvents); diff != "" { - t.Errorf("router events mismatch (-want +got):\n%s", diff) + if diff := cmp.Diff(expectedOffLinkRouteEvents, gotOffLinkRouteEvents); diff != "" { + t.Errorf("off-link route events mismatch (-want +got):\n%s", diff) } expectedPrefixEvents := map[ndpPrefixEvent]int{ {nicID: nicID1, prefix: subnet1, discovered: false}: 1, @@ -5111,10 +5028,10 @@ func TestCleanupNDPState(t *testing.T) { // Should not get any more events (invalidation timers should have been // cancelled when the NDP state was cleaned up). - time.Sleep(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout) + clock.Advance(lifetimeSeconds * time.Second) select { - case <-ndpDisp.routerC: - t.Error("unexpected router event") + case <-ndpDisp.offLinkRouteC: + t.Error("unexpected off-link route event") default: } select { @@ -5139,7 +5056,6 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { ndpDisp := ndpDispatcher{ dhcpv6ConfigurationC: make(chan ndpDHCPv6Event, 1), - rememberRouter: true, } e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ @@ -5241,6 +5157,23 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { expectNoDHCPv6Event() } +var _ rand.Source = (*savingRandSource)(nil) + +type savingRandSource struct { + s rand.Source + + lastInt63 int64 +} + +func (d *savingRandSource) Int63() int64 { + i := d.s.Int63() + d.lastInt63 = i + return i +} +func (d *savingRandSource) Seed(seed int64) { + d.s.Seed(seed) +} + // TestRouterSolicitation tests the initial Router Solicitations that are sent // when a NIC newly becomes enabled. func TestRouterSolicitation(t *testing.T) { @@ -5407,6 +5340,9 @@ func TestRouterSolicitation(t *testing.T) { t.Fatalf("unexpectedly got a packet = %#v", p) } } + randSource := savingRandSource{ + s: rand.NewSource(time.Now().UnixNano()), + } s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -5416,8 +5352,10 @@ func TestRouterSolicitation(t *testing.T) { MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, }, })}, - Clock: clock, + Clock: clock, + RandSource: &randSource, }) + if err := s.CreateNIC(nicID, &e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -5430,19 +5368,27 @@ func TestRouterSolicitation(t *testing.T) { // Make sure each RS is sent at the right time. remaining := test.maxRtrSolicit - if remaining > 0 { - waitForPkt(test.effectiveMaxRtrSolicitDelay) + if remaining != 0 { + maxRtrSolicitDelay := test.maxRtrSolicitDelay + if maxRtrSolicitDelay < 0 { + maxRtrSolicitDelay = ipv6.DefaultNDPConfigurations().MaxRtrSolicitationDelay + } + var actualRtrSolicitDelay time.Duration + if maxRtrSolicitDelay != 0 { + actualRtrSolicitDelay = time.Duration(randSource.lastInt63) % maxRtrSolicitDelay + } + waitForPkt(actualRtrSolicitDelay) remaining-- } subTest.afterFirstRS(t, s) - for ; remaining > 0; remaining-- { - if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout { + for ; remaining != 0; remaining-- { + if test.effectiveRtrSolicitInt != 0 { waitForNothing(test.effectiveRtrSolicitInt - time.Nanosecond) waitForPkt(time.Nanosecond) } else { - waitForPkt(test.effectiveRtrSolicitInt) + waitForPkt(0) } } @@ -5538,12 +5484,11 @@ func TestStopStartSolicitingRouters(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { e := channel.New(maxRtrSolicitations, 1280, linkAddr1) - waitForPkt := func(timeout time.Duration) { + waitForPkt := func(clock *faketime.ManualClock, timeout time.Duration) { t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - p, ok := e.ReadContext(ctx) + clock.Advance(timeout) + p, ok := e.Read() if !ok { t.Fatal("timed out waiting for packet") } @@ -5557,6 +5502,7 @@ func TestStopStartSolicitingRouters(t *testing.T) { checker.TTL(header.NDPHopLimit), checker.NDPRS()) } + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ NDPConfigs: ipv6.NDPConfigurations{ @@ -5566,6 +5512,7 @@ func TestStopStartSolicitingRouters(t *testing.T) { MaxRtrSolicitationDelay: delay, }, })}, + Clock: clock, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -5573,13 +5520,11 @@ func TestStopStartSolicitingRouters(t *testing.T) { // Stop soliciting routers. test.stopFn(t, s, true /* first */) - ctx, cancel := context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout) - defer cancel() - if _, ok := e.ReadContext(ctx); ok { + clock.Advance(delay) + if _, ok := e.Read(); ok { // A single RS may have been sent before solicitations were stopped. - ctx, cancel := context.WithTimeout(context.Background(), interval+defaultAsyncNegativeEventTimeout) - defer cancel() - if _, ok = e.ReadContext(ctx); ok { + clock.Advance(interval) + if _, ok = e.Read(); ok { t.Fatal("should not have sent more than one RS message") } } @@ -5587,9 +5532,8 @@ func TestStopStartSolicitingRouters(t *testing.T) { // Stopping router solicitations after it has already been stopped should // do nothing. test.stopFn(t, s, false /* first */) - ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout) - defer cancel() - if _, ok := e.ReadContext(ctx); ok { + clock.Advance(delay) + if _, ok := e.Read(); ok { t.Fatal("unexpectedly got a packet after router solicitation has been stopepd") } @@ -5600,21 +5544,19 @@ func TestStopStartSolicitingRouters(t *testing.T) { // Start soliciting routers. test.startFn(t, s) - waitForPkt(delay + defaultAsyncPositiveEventTimeout) - waitForPkt(interval + defaultAsyncPositiveEventTimeout) - waitForPkt(interval + defaultAsyncPositiveEventTimeout) - ctx, cancel = context.WithTimeout(context.Background(), interval+defaultAsyncNegativeEventTimeout) - defer cancel() - if _, ok := e.ReadContext(ctx); ok { + waitForPkt(clock, delay) + waitForPkt(clock, interval) + waitForPkt(clock, interval) + clock.Advance(interval) + if _, ok := e.Read(); ok { t.Fatal("unexpectedly got an extra packet after sending out the expected RSs") } // Starting router solicitations after it has already completed should do // nothing. test.startFn(t, s) - ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout) - defer cancel() - if _, ok := e.ReadContext(ctx); ok { + clock.Advance(interval) + if _, ok := e.Read(); ok { t.Fatal("unexpectedly got a packet after finishing router solicitations") } }) diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go index 509f5ce5c..08857e1a9 100644 --- a/pkg/tcpip/stack/neighbor_cache.go +++ b/pkg/tcpip/stack/neighbor_cache.go @@ -310,7 +310,7 @@ func (n *neighborCache) handleUpperLevelConfirmation(addr tcpip.Address) { func (n *neighborCache) init(nic *nic, r LinkAddressResolver) { *n = neighborCache{ nic: nic, - state: NewNUDState(nic.stack.nudConfigs, nic.stack.randomGenerator), + state: NewNUDState(nic.stack.nudConfigs, nic.stack.clock, nic.stack.randomGenerator), linkRes: r, } n.mu.Lock() diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index 9821a18d3..7de25fe37 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -15,8 +15,6 @@ package stack import ( - "bytes" - "encoding/binary" "fmt" "math" "math/rand" @@ -48,9 +46,6 @@ const ( // be sent to all nodes. testEntryBroadcastAddr = tcpip.Address("broadcast") - // testEntryLocalAddr is the source address of neighbor probes. - testEntryLocalAddr = tcpip.Address("local_addr") - // testEntryBroadcastLinkAddr is a special link address sent back to // multicast neighbor probes. testEntryBroadcastLinkAddr = tcpip.LinkAddress("mac_broadcast") @@ -95,7 +90,7 @@ func newTestNeighborResolver(nudDisp NUDDispatcher, config NUDConfigurations, cl randomGenerator: rng, }, id: 1, - stats: makeNICStats(), + stats: makeNICStats(tcpip.NICStats{}.FillIn()), }, linkRes) return linkRes } @@ -106,20 +101,24 @@ type testEntryStore struct { entriesMap map[tcpip.Address]NeighborEntry } -func toAddress(i int) tcpip.Address { - buf := new(bytes.Buffer) - binary.Write(buf, binary.BigEndian, uint8(1)) - binary.Write(buf, binary.BigEndian, uint8(0)) - binary.Write(buf, binary.BigEndian, uint16(i)) - return tcpip.Address(buf.String()) +func toAddress(i uint16) tcpip.Address { + return tcpip.Address([]byte{ + 1, + 0, + byte(i >> 8), + byte(i), + }) } -func toLinkAddress(i int) tcpip.LinkAddress { - buf := new(bytes.Buffer) - binary.Write(buf, binary.BigEndian, uint8(1)) - binary.Write(buf, binary.BigEndian, uint8(0)) - binary.Write(buf, binary.BigEndian, uint32(i)) - return tcpip.LinkAddress(buf.String()) +func toLinkAddress(i uint16) tcpip.LinkAddress { + return tcpip.LinkAddress([]byte{ + 1, + 0, + 0, + 0, + byte(i >> 8), + byte(i), + }) } // newTestEntryStore returns a testEntryStore pre-populated with entries. @@ -127,7 +126,7 @@ func newTestEntryStore() *testEntryStore { store := &testEntryStore{ entriesMap: make(map[tcpip.Address]NeighborEntry), } - for i := 0; i < entryStoreSize; i++ { + for i := uint16(0); i < entryStoreSize; i++ { addr := toAddress(i) linkAddr := toLinkAddress(i) @@ -140,15 +139,15 @@ func newTestEntryStore() *testEntryStore { } // size returns the number of entries in the store. -func (s *testEntryStore) size() int { +func (s *testEntryStore) size() uint16 { s.mu.RLock() defer s.mu.RUnlock() - return len(s.entriesMap) + return uint16(len(s.entriesMap)) } // entry returns the entry at index i. Returns an empty entry and false if i is // out of bounds. -func (s *testEntryStore) entry(i int) (NeighborEntry, bool) { +func (s *testEntryStore) entry(i uint16) (NeighborEntry, bool) { return s.entryByAddr(toAddress(i)) } @@ -166,7 +165,7 @@ func (s *testEntryStore) entries() []NeighborEntry { entries := make([]NeighborEntry, 0, len(s.entriesMap)) s.mu.RLock() defer s.mu.RUnlock() - for i := 0; i < entryStoreSize; i++ { + for i := uint16(0); i < entryStoreSize; i++ { addr := toAddress(i) if entry, ok := s.entriesMap[addr]; ok { entries = append(entries, entry) @@ -176,7 +175,7 @@ func (s *testEntryStore) entries() []NeighborEntry { } // set modifies the link addresses of an entry. -func (s *testEntryStore) set(i int, linkAddr tcpip.LinkAddress) { +func (s *testEntryStore) set(i uint16, linkAddr tcpip.LinkAddress) { addr := toAddress(i) s.mu.Lock() defer s.mu.Unlock() @@ -236,13 +235,6 @@ func (*testNeighborResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return 0 } -type entryEvent struct { - nicID tcpip.NICID - address tcpip.Address - linkAddr tcpip.LinkAddress - state NeighborState -} - func TestNeighborCacheGetConfig(t *testing.T) { nudDisp := testNUDDispatcher{} c := DefaultNUDConfigurations() @@ -301,10 +293,10 @@ func addReachableEntryWithRemoved(nudDisp *testNUDDispatcher, clock *faketime.Ma EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: removedEntry.Addr, - LinkAddr: removedEntry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: removedEntry.Addr, + LinkAddr: removedEntry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now(), }, }) } @@ -313,10 +305,10 @@ func addReachableEntryWithRemoved(nudDisp *testNUDDispatcher, clock *faketime.Ma EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: "", - State: Incomplete, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: "", + State: Incomplete, + UpdatedAt: clock.Now(), }, }) @@ -347,10 +339,10 @@ func addReachableEntryWithRemoved(nudDisp *testNUDDispatcher, clock *faketime.Ma EventType: entryTestChanged, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -419,10 +411,10 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -461,7 +453,7 @@ func newTestContext(c NUDConfigurations) testContext { } type overflowOptions struct { - startAtEntryIndex int + startAtEntryIndex uint16 wantStaticEntries []NeighborEntry } @@ -500,12 +492,12 @@ func (c *testContext) overflowCache(opts overflowOptions) error { if !ok { return fmt.Errorf("got c.linkRes.entries.entry(%d) = _, false, want = true", i) } - durationReachableNanos := int64(c.linkRes.entries.size()-i-1) * typicalLatency.Nanoseconds() + durationReachableNanos := time.Duration(c.linkRes.entries.size()-i-1) * typicalLatency wantEntry := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: c.clock.NowNanoseconds() - durationReachableNanos, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: c.clock.Now().Add(-durationReachableNanos), } wantUnorderedEntries = append(wantUnorderedEntries, wantEntry) } @@ -571,10 +563,10 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: c.clock.Now(), }, }, } @@ -616,10 +608,10 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) { EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -663,10 +655,10 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -689,10 +681,10 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) EventType: entryTestChanged, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -733,10 +725,10 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -758,10 +750,10 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -814,20 +806,20 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: c.clock.Now(), }, }, { EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -844,10 +836,10 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { startAtEntryIndex: 1, wantStaticEntries: []NeighborEntry{ { - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -875,10 +867,10 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { t.Errorf("unexpected error from c.linkRes.neigh.entry(%s, \"\", nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), } if diff := cmp.Diff(want, e); diff != "" { t.Errorf("c.linkRes.neigh.entry(%s, \"\", nil) mismatch (-want, +got):\n%s", entry.Addr, diff) @@ -890,10 +882,10 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -910,10 +902,10 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { startAtEntryIndex: 1, wantStaticEntries: []NeighborEntry{ { - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Static, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Static, + UpdatedAt: c.clock.Now(), }, }, } @@ -947,10 +939,10 @@ func TestNeighborCacheClear(t *testing.T) { EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Static, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Static, + UpdatedAt: clock.Now(), }, }, } @@ -973,20 +965,20 @@ func TestNeighborCacheClear(t *testing.T) { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now(), }, }, { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Static, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Static, + UpdatedAt: clock.Now(), }, }, } @@ -1027,10 +1019,10 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { EventType: entryTestRemoved, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: c.clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: c.clock.Now(), }, }, } @@ -1062,13 +1054,13 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { clock := faketime.NewManualClock() linkRes := newTestNeighborResolver(&nudDisp, config, clock) - startedAt := clock.NowNanoseconds() + startedAt := clock.Now() // The following logic is very similar to overflowCache, but // periodically refreshes the frequently used entry. // Fill the neighbor cache to capacity - for i := 0; i < neighborCacheSize; i++ { + for i := uint16(0); i < neighborCacheSize; i++ { entry, ok := linkRes.entries.entry(i) if !ok { t.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i) @@ -1084,7 +1076,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { } // Keep adding more entries - for i := neighborCacheSize; i < linkRes.entries.size(); i++ { + for i := uint16(neighborCacheSize); i < linkRes.entries.size(); i++ { // Periodically refresh the frequently used entry if i%(neighborCacheSize/2) == 0 { if _, _, err := linkRes.neigh.entry(frequentlyUsedEntry.Addr, "", nil); err != nil { @@ -1118,7 +1110,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { State: Reachable, // Can be inferred since the frequently used entry is the first to // be created and transitioned to Reachable. - UpdatedAtNanos: startedAt + typicalLatency.Nanoseconds(), + UpdatedAt: startedAt.Add(typicalLatency), }, } @@ -1127,12 +1119,12 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { if !ok { t.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i) } - durationReachableNanos := int64(linkRes.entries.size()-i-1) * typicalLatency.Nanoseconds() + durationReachableNanos := time.Duration(linkRes.entries.size()-i-1) * typicalLatency wantUnsortedEntries = append(wantUnsortedEntries, NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds() - durationReachableNanos, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now().Add(-durationReachableNanos), }) } @@ -1190,12 +1182,12 @@ func TestNeighborCacheConcurrent(t *testing.T) { if !ok { t.Errorf("got linkRes.entries.entry(%d) = _, false, want = true", i) } - durationReachableNanos := int64(linkRes.entries.size()-i-1) * typicalLatency.Nanoseconds() + durationReachableNanos := time.Duration(linkRes.entries.size()-i-1) * typicalLatency wantUnsortedEntries = append(wantUnsortedEntries, NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds() - durationReachableNanos, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now().Add(-durationReachableNanos), }) } @@ -1244,10 +1236,10 @@ func TestNeighborCacheReplace(t *testing.T) { t.Fatalf("linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: updatedLinkAddr, - State: Delay, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: updatedLinkAddr, + State: Delay, + UpdatedAt: clock.Now(), } if diff := cmp.Diff(want, e); diff != "" { t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) @@ -1263,10 +1255,10 @@ func TestNeighborCacheReplace(t *testing.T) { t.Errorf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: updatedLinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: updatedLinkAddr, + State: Reachable, + UpdatedAt: clock.Now(), } if diff := cmp.Diff(want, e); diff != "" { t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) @@ -1301,10 +1293,10 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { t.Fatalf("unexpected error from linkRes.neigh.entry(%s, '', nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now(), } if diff := cmp.Diff(want, got); diff != "" { t.Errorf("linkRes.neigh.entry(%s, '', nil) mismatch (-want, +got):\n%s", entry.Addr, diff) @@ -1405,10 +1397,10 @@ func TestNeighborCacheRetryResolution(t *testing.T) { EventType: entryTestAdded, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: "", - State: Incomplete, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: "", + State: Incomplete, + UpdatedAt: clock.Now(), }, }, } @@ -1436,10 +1428,10 @@ func TestNeighborCacheRetryResolution(t *testing.T) { EventType: entryTestChanged, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: "", - State: Unreachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: "", + State: Unreachable, + UpdatedAt: clock.Now(), }, }, } @@ -1455,10 +1447,10 @@ func TestNeighborCacheRetryResolution(t *testing.T) { { wantEntries := []NeighborEntry{ { - Addr: entry.Addr, - LinkAddr: "", - State: Unreachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: "", + State: Unreachable, + UpdatedAt: clock.Now(), }, } if diff := cmp.Diff(linkRes.neigh.entries(), wantEntries, unorderedEntriesDiffOpts()...); diff != "" { @@ -1488,10 +1480,10 @@ func TestNeighborCacheRetryResolution(t *testing.T) { EventType: entryTestChanged, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: "", - State: Incomplete, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: "", + State: Incomplete, + UpdatedAt: clock.Now(), }, }, } @@ -1518,10 +1510,10 @@ func TestNeighborCacheRetryResolution(t *testing.T) { EventType: entryTestChanged, NICID: 1, Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -1541,10 +1533,10 @@ func TestNeighborCacheRetryResolution(t *testing.T) { } wantEntry := NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + UpdatedAt: clock.Now(), } if diff := cmp.Diff(gotEntry, wantEntry); diff != "" { t.Fatalf("neighbor entry mismatch (-got, +want):\n%s", diff) @@ -1561,9 +1553,9 @@ func BenchmarkCacheClear(b *testing.B) { linkRes.delay = 0 // Clear for every possible size of the cache - for cacheSize := 0; cacheSize < neighborCacheSize; cacheSize++ { + for cacheSize := uint16(0); cacheSize < neighborCacheSize; cacheSize++ { // Fill the neighbor cache to capacity. - for i := 0; i < cacheSize; i++ { + for i := uint16(0); i < cacheSize; i++ { entry, ok := linkRes.entries.entry(i) if !ok { b.Fatalf("got linkRes.entries.entry(%d) = _, false, want = true", i) diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index 6d95e1664..0a59eecdd 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -31,10 +31,10 @@ const ( // NeighborEntry describes a neighboring device in the local network. type NeighborEntry struct { - Addr tcpip.Address - LinkAddr tcpip.LinkAddress - State NeighborState - UpdatedAtNanos int64 + Addr tcpip.Address + LinkAddr tcpip.LinkAddress + State NeighborState + UpdatedAt time.Time } // NeighborState defines the state of a NeighborEntry within the Neighbor @@ -138,10 +138,10 @@ func newNeighborEntry(cache *neighborCache, remoteAddr tcpip.Address, nudState * // calling `setStateLocked`. func newStaticNeighborEntry(cache *neighborCache, addr tcpip.Address, linkAddr tcpip.LinkAddress, state *NUDState) *neighborEntry { entry := NeighborEntry{ - Addr: addr, - LinkAddr: linkAddr, - State: Static, - UpdatedAtNanos: cache.nic.stack.clock.NowNanoseconds(), + Addr: addr, + LinkAddr: linkAddr, + State: Static, + UpdatedAt: cache.nic.stack.clock.Now(), } n := &neighborEntry{ cache: cache, @@ -166,14 +166,20 @@ func (e *neighborEntry) notifyCompletionLocked(err tcpip.Error) { if ch := e.mu.done; ch != nil { close(ch) e.mu.done = nil - // Dequeue the pending packets in a new goroutine to not hold up the current + // Dequeue the pending packets asynchronously to not hold up the current // goroutine as writing packets may be a costly operation. // // At the time of writing, when writing packets, a neighbor's link address // is resolved (which ends up obtaining the entry's lock) while holding the - // link resolution queue's lock. Dequeuing packets in a new goroutine avoids - // a lock ordering violation. - go e.cache.nic.linkResQueue.dequeue(ch, e.mu.neigh.LinkAddr, err) + // link resolution queue's lock. Dequeuing packets asynchronously avoids a + // lock ordering violation. + // + // NB: this is equivalent to spawning a goroutine directly using the go + // keyword but allows tests that use manual clocks to deterministically + // wait for this work to complete. + e.cache.nic.stack.clock.AfterFunc(0, func() { + e.cache.nic.linkResQueue.dequeue(ch, e.mu.neigh.LinkAddr, err) + }) } } @@ -224,7 +230,7 @@ func (e *neighborEntry) cancelTimerLocked() { // // Precondition: e.mu MUST be locked. func (e *neighborEntry) removeLocked() { - e.mu.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds() + e.mu.neigh.UpdatedAt = e.cache.nic.stack.clock.Now() e.dispatchRemoveEventLocked() e.cancelTimerLocked() // TODO(https://gvisor.dev/issues/5583): test the case where this function is @@ -246,7 +252,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { prev := e.mu.neigh.State e.mu.neigh.State = next - e.mu.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds() + e.mu.neigh.UpdatedAt = e.cache.nic.stack.clock.Now() config := e.nudState.Config() switch next { @@ -307,7 +313,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { // a shared lock. e.mu.timer = timer{ done: &done, - timer: e.cache.nic.stack.Clock().AfterFunc(0, func() { + timer: e.cache.nic.stack.Clock().AfterFunc(immediateDuration, func() { var err tcpip.Error = &tcpip.ErrTimeout{} if remaining != 0 { err = e.cache.linkRes.LinkAddressRequest(addr, "" /* localAddr */, linkAddr) @@ -354,14 +360,14 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { case Unknown, Unreachable: prev := e.mu.neigh.State e.mu.neigh.State = Incomplete - e.mu.neigh.UpdatedAtNanos = e.cache.nic.stack.clock.NowNanoseconds() + e.mu.neigh.UpdatedAt = e.cache.nic.stack.clock.Now() switch prev { case Unknown: e.dispatchAddEventLocked() case Unreachable: e.dispatchChangeEventLocked() - e.cache.nic.stats.Neighbor.UnreachableEntryLookups.Increment() + e.cache.nic.stats.neighbor.unreachableEntryLookups.Increment() } config := e.nudState.Config() @@ -378,7 +384,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { // a shared lock. e.mu.timer = timer{ done: &done, - timer: e.cache.nic.stack.Clock().AfterFunc(0, func() { + timer: e.cache.nic.stack.Clock().AfterFunc(immediateDuration, func() { var err tcpip.Error = &tcpip.ErrTimeout{} if remaining != 0 { // As per RFC 4861 section 7.2.2: diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index 1d39ee73d..59d86d6d4 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -36,11 +36,6 @@ const ( entryTestLinkAddr1 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x01") entryTestLinkAddr2 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x02") - - // entryTestNetDefaultMTU is the MTU, in bytes, used throughout the tests, - // except where another value is explicitly used. It is chosen to match the - // MTU of loopback interfaces on Linux systems. - entryTestNetDefaultMTU = 65536 ) var ( @@ -196,13 +191,13 @@ func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.A // ResolveStaticAddress attempts to resolve address without sending requests. // It either resolves the name immediately or returns the empty LinkAddress. -func (r *entryTestLinkResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { +func (*entryTestLinkResolver) ResolveStaticAddress(tcpip.Address) (tcpip.LinkAddress, bool) { return "", false } // LinkAddressProtocol returns the network protocol of the addresses this // resolver can resolve. -func (r *entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { +func (*entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { return entryTestNetNumber } @@ -219,7 +214,7 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e nudConfigs: c, randomGenerator: rand.New(rand.NewSource(time.Now().UnixNano())), }, - stats: makeNICStats(), + stats: makeNICStats(tcpip.NICStats{}.FillIn()), } netEP := (&testIPv6Protocol{}).NewEndpoint(&nic, nil) nic.networkEndpoints = map[tcpip.NetworkProtocolNumber]NetworkEndpoint{ @@ -354,10 +349,10 @@ func unknownToIncomplete(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes * EventType: entryTestAdded, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + UpdatedAt: clock.Now(), }, }, } @@ -415,10 +410,10 @@ func unknownToStale(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entry EventType: entryTestAdded, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -446,7 +441,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { // UpdatedAt should remain the same during address resolution. e.mu.Lock() - startedAt := e.mu.neigh.UpdatedAtNanos + startedAt := e.mu.neigh.UpdatedAt e.mu.Unlock() // Wait for the rest of the reachability probe transmissions, signifying @@ -470,7 +465,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { } e.mu.Lock() - if got, want := e.mu.neigh.UpdatedAtNanos, startedAt; got != want { + if got, want := e.mu.neigh.UpdatedAt, startedAt; got != want { t.Errorf("got e.mu.neigh.UpdatedAt = %q, want = %q", got, want) } e.mu.Unlock() @@ -485,10 +480,10 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Unreachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Unreachable, + UpdatedAt: clock.Now(), }, }, } @@ -547,10 +542,10 @@ func incompleteToReachableWithFlags(e *neighborEntry, nudDisp *testNUDDispatcher EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -644,10 +639,10 @@ func TestEntryIncompleteToStaleWhenUnsolicitedConfirmation(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -678,10 +673,10 @@ func TestEntryIncompleteToStaleWhenProbe(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -757,10 +752,10 @@ func incompleteToUnreachable(c NUDConfigurations, e *neighborEntry, nudDisp *tes EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Unreachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Unreachable, + UpdatedAt: clock.Now(), }, }, } @@ -943,10 +938,10 @@ func reachableToStale(c NUDConfigurations, e *neighborEntry, nudDisp *testNUDDis EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -998,10 +993,10 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -1050,10 +1045,10 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -1102,10 +1097,10 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -1191,10 +1186,10 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -1243,10 +1238,10 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -1284,10 +1279,10 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -1332,10 +1327,10 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -1391,10 +1386,10 @@ func staleToDelay(e *neighborEntry, nudDisp *testNUDDispatcher, linkRes *entryTe EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + UpdatedAt: clock.Now(), }, }, } @@ -1443,10 +1438,10 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -1498,10 +1493,10 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -1553,10 +1548,10 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -1645,10 +1640,10 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -1697,10 +1692,10 @@ func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -1770,10 +1765,10 @@ func delayToProbe(c NUDConfigurations, e *neighborEntry, nudDisp *testNUDDispatc EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + UpdatedAt: clock.Now(), }, }, } @@ -1827,10 +1822,10 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -1882,10 +1877,10 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + UpdatedAt: clock.Now(), }, }, } @@ -2003,10 +1998,10 @@ func probeToReachableWithFlags(e *neighborEntry, nudDisp *testNUDDispatcher, lin EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: linkAddr, - State: Reachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: linkAddr, + State: Reachable, + UpdatedAt: clock.Now(), }, }, } @@ -2155,10 +2150,10 @@ func probeToUnreachable(c NUDConfigurations, e *neighborEntry, nudDisp *testNUDD EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Unreachable, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Unreachable, + UpdatedAt: clock.Now(), }, }, } @@ -2227,10 +2222,10 @@ func unreachableToIncomplete(e *neighborEntry, nudDisp *testNUDDispatcher, linkR EventType: entryTestChanged, NICID: entryTestNICID, Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - UpdatedAtNanos: clock.NowNanoseconds(), + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + UpdatedAt: clock.Now(), }, }, } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 8d615500f..9cac6bbd1 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -51,7 +51,7 @@ type nic struct { name string context NICContext - stats NICStats + stats sharedStats // The network endpoints themselves may be modified by calling the interface's // methods, but the map reference and entries must be constant. @@ -78,26 +78,13 @@ type nic struct { } } -// NICStats hold statistics for a NIC. -type NICStats struct { - Tx DirectionStats - Rx DirectionStats - - DisabledRx DirectionStats - - Neighbor NeighborStats -} - -func makeNICStats() NICStats { - var s NICStats - tcpip.InitStatCounters(reflect.ValueOf(&s).Elem()) - return s -} - -// DirectionStats includes packet and byte counts. -type DirectionStats struct { - Packets *tcpip.StatCounter - Bytes *tcpip.StatCounter +// makeNICStats initializes the NIC statistics and associates them to the global +// NIC statistics. +func makeNICStats(global tcpip.NICStats) sharedStats { + var stats sharedStats + tcpip.InitStatCounters(reflect.ValueOf(&stats.local).Elem()) + stats.init(&stats.local, &global) + return stats } type packetEndpointList struct { @@ -150,7 +137,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC id: id, name: name, context: ctx, - stats: makeNICStats(), + stats: makeNICStats(stack.Stats().NICs), networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]*linkResolver), duplicateAddressDetectors: make(map[tcpip.NetworkProtocolNumber]DuplicateAddressDetector), @@ -382,8 +369,8 @@ func (n *nic) writePacket(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt return err } - n.stats.Tx.Packets.Increment() - n.stats.Tx.Bytes.IncrementBy(uint64(numBytes)) + n.stats.tx.packets.Increment() + n.stats.tx.bytes.IncrementBy(uint64(numBytes)) return nil } @@ -399,13 +386,13 @@ func (n *nic) writePackets(r RouteInfo, protocol tcpip.NetworkProtocolNumber, pk } writtenPackets, err := n.LinkEndpoint.WritePackets(r, pkts, protocol) - n.stats.Tx.Packets.IncrementBy(uint64(writtenPackets)) + n.stats.tx.packets.IncrementBy(uint64(writtenPackets)) writtenBytes := 0 for i, pb := 0, pkts.Front(); i < writtenPackets && pb != nil; i, pb = i+1, pb.Next() { writtenBytes += pb.Size() } - n.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes)) + n.stats.tx.bytes.IncrementBy(uint64(writtenBytes)) return writtenPackets, err } @@ -718,18 +705,18 @@ func (n *nic) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp if !enabled { n.mu.RUnlock() - n.stats.DisabledRx.Packets.Increment() - n.stats.DisabledRx.Bytes.IncrementBy(uint64(pkt.Data().Size())) + n.stats.disabledRx.packets.Increment() + n.stats.disabledRx.bytes.IncrementBy(uint64(pkt.Data().Size())) return } - n.stats.Rx.Packets.Increment() - n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data().Size())) + n.stats.rx.packets.Increment() + n.stats.rx.bytes.IncrementBy(uint64(pkt.Data().Size())) networkEndpoint, ok := n.networkEndpoints[protocol] if !ok { n.mu.RUnlock() - n.stack.stats.UnknownProtocolRcvdPackets.Increment() + n.stats.unknownL3ProtocolRcvdPackets.Increment() return } @@ -786,7 +773,7 @@ func (n *nic) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition { state, ok := n.stack.transportProtocols[protocol] if !ok { - n.stack.stats.UnknownProtocolRcvdPackets.Increment() + n.stats.unknownL4ProtocolRcvdPackets.Increment() return TransportPacketProtocolUnreachable } @@ -800,27 +787,26 @@ func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt // TransportHeader is empty only when pkt is an ICMP packet or was reassembled // from fragments. if pkt.TransportHeader().View().IsEmpty() { - // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader - // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a - // full explanation. + // ICMP packets don't have their TransportHeader fields set yet, parse it + // here. See icmp/protocol.go:protocol.Parse for a full explanation. if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber { // ICMP packets may be longer, but until icmp.Parse is implemented, here // we parse it using the minimum size. if _, ok := pkt.TransportHeader().Consume(transProto.MinimumPacketSize()); !ok { - n.stack.stats.MalformedRcvdPackets.Increment() + n.stats.malformedL4RcvdPackets.Increment() // We consider a malformed transport packet handled because there is // nothing the caller can do. return TransportPacketHandled } } else if !transProto.Parse(pkt) { - n.stack.stats.MalformedRcvdPackets.Increment() + n.stats.malformedL4RcvdPackets.Increment() return TransportPacketHandled } } srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader().View()) if err != nil { - n.stack.stats.MalformedRcvdPackets.Increment() + n.stats.malformedL4RcvdPackets.Increment() return TransportPacketHandled } @@ -852,7 +838,7 @@ func (n *nic) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt // If it doesn't handle it then we should do so. switch res := transProto.HandleUnknownDestinationPacket(id, pkt); res { case UnknownDestinationPacketMalformed: - n.stack.stats.MalformedRcvdPackets.Increment() + n.stats.malformedL4RcvdPackets.Increment() return TransportPacketHandled case UnknownDestinationPacketUnhandled: return TransportPacketDestinationPortUnreachable @@ -1000,3 +986,32 @@ func (n *nic) checkDuplicateAddress(protocol tcpip.NetworkProtocolNumber, addr t return d.CheckDuplicateAddress(addr, h), nil } + +func (n *nic) setForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) tcpip.Error { + ep := n.getNetworkEndpoint(protocol) + if ep == nil { + return &tcpip.ErrUnknownProtocol{} + } + + forwardingEP, ok := ep.(ForwardingNetworkEndpoint) + if !ok { + return &tcpip.ErrNotSupported{} + } + + forwardingEP.SetForwarding(enable) + return nil +} + +func (n *nic) forwarding(protocol tcpip.NetworkProtocolNumber) (bool, tcpip.Error) { + ep := n.getNetworkEndpoint(protocol) + if ep == nil { + return false, &tcpip.ErrUnknownProtocol{} + } + + forwardingEP, ok := ep.(ForwardingNetworkEndpoint) + if !ok { + return false, &tcpip.ErrNotSupported{} + } + + return forwardingEP.Forwarding(), nil +} diff --git a/pkg/tcpip/stack/nic_stats.go b/pkg/tcpip/stack/nic_stats.go new file mode 100644 index 000000000..1773d5e8d --- /dev/null +++ b/pkg/tcpip/stack/nic_stats.go @@ -0,0 +1,74 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "gvisor.dev/gvisor/pkg/tcpip" +) + +type sharedStats struct { + local tcpip.NICStats + multiCounterNICStats +} + +// LINT.IfChange(multiCounterNICPacketStats) + +type multiCounterNICPacketStats struct { + packets tcpip.MultiCounterStat + bytes tcpip.MultiCounterStat +} + +func (m *multiCounterNICPacketStats) init(a, b *tcpip.NICPacketStats) { + m.packets.Init(a.Packets, b.Packets) + m.bytes.Init(a.Bytes, b.Bytes) +} + +// LINT.ThenChange(../../tcpip.go:NICPacketStats) + +// LINT.IfChange(multiCounterNICNeighborStats) + +type multiCounterNICNeighborStats struct { + unreachableEntryLookups tcpip.MultiCounterStat +} + +func (m *multiCounterNICNeighborStats) init(a, b *tcpip.NICNeighborStats) { + m.unreachableEntryLookups.Init(a.UnreachableEntryLookups, b.UnreachableEntryLookups) +} + +// LINT.ThenChange(../../tcpip.go:NICNeighborStats) + +// LINT.IfChange(multiCounterNICStats) + +type multiCounterNICStats struct { + unknownL3ProtocolRcvdPackets tcpip.MultiCounterStat + unknownL4ProtocolRcvdPackets tcpip.MultiCounterStat + malformedL4RcvdPackets tcpip.MultiCounterStat + tx multiCounterNICPacketStats + rx multiCounterNICPacketStats + disabledRx multiCounterNICPacketStats + neighbor multiCounterNICNeighborStats +} + +func (m *multiCounterNICStats) init(a, b *tcpip.NICStats) { + m.unknownL3ProtocolRcvdPackets.Init(a.UnknownL3ProtocolRcvdPackets, b.UnknownL3ProtocolRcvdPackets) + m.unknownL4ProtocolRcvdPackets.Init(a.UnknownL4ProtocolRcvdPackets, b.UnknownL4ProtocolRcvdPackets) + m.malformedL4RcvdPackets.Init(a.MalformedL4RcvdPackets, b.MalformedL4RcvdPackets) + m.tx.init(&a.Tx, &b.Tx) + m.rx.init(&a.Rx, &b.Rx) + m.disabledRx.init(&a.DisabledRx, &b.DisabledRx) + m.neighbor.init(&a.Neighbor, &b.Neighbor) +} + +// LINT.ThenChange(../../tcpip.go:NICStats) diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index 8a3005295..5cb342f78 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -15,11 +15,13 @@ package stack import ( + "reflect" "testing" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/testutil" ) var _ AddressableEndpoint = (*testIPv6Endpoint)(nil) @@ -171,19 +173,19 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { // When the NIC is disabled, the only field that matters is the stats field. // This test is limited to stats counter checks. nic := nic{ - stats: makeNICStats(), + stats: makeNICStats(tcpip.NICStats{}.FillIn()), } - if got := nic.stats.DisabledRx.Packets.Value(); got != 0 { + if got := nic.stats.local.DisabledRx.Packets.Value(); got != 0 { t.Errorf("got DisabledRx.Packets = %d, want = 0", got) } - if got := nic.stats.DisabledRx.Bytes.Value(); got != 0 { + if got := nic.stats.local.DisabledRx.Bytes.Value(); got != 0 { t.Errorf("got DisabledRx.Bytes = %d, want = 0", got) } - if got := nic.stats.Rx.Packets.Value(); got != 0 { + if got := nic.stats.local.Rx.Packets.Value(); got != 0 { t.Errorf("got Rx.Packets = %d, want = 0", got) } - if got := nic.stats.Rx.Bytes.Value(); got != 0 { + if got := nic.stats.local.Rx.Bytes.Value(); got != 0 { t.Errorf("got Rx.Bytes = %d, want = 0", got) } @@ -195,16 +197,28 @@ func TestDisabledRxStatsWhenNICDisabled(t *testing.T) { Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView(), })) - if got := nic.stats.DisabledRx.Packets.Value(); got != 1 { + if got := nic.stats.local.DisabledRx.Packets.Value(); got != 1 { t.Errorf("got DisabledRx.Packets = %d, want = 1", got) } - if got := nic.stats.DisabledRx.Bytes.Value(); got != 4 { + if got := nic.stats.local.DisabledRx.Bytes.Value(); got != 4 { t.Errorf("got DisabledRx.Bytes = %d, want = 4", got) } - if got := nic.stats.Rx.Packets.Value(); got != 0 { + if got := nic.stats.local.Rx.Packets.Value(); got != 0 { t.Errorf("got Rx.Packets = %d, want = 0", got) } - if got := nic.stats.Rx.Bytes.Value(); got != 0 { + if got := nic.stats.local.Rx.Bytes.Value(); got != 0 { t.Errorf("got Rx.Bytes = %d, want = 0", got) } } + +func TestMultiCounterStatsInitialization(t *testing.T) { + global := tcpip.NICStats{}.FillIn() + nic := nic{ + stats: makeNICStats(global), + } + multi := nic.stats.multiCounterNICStats + local := nic.stats.local + if err := testutil.ValidateMultiCounterStats(reflect.ValueOf(&multi).Elem(), []reflect.Value{reflect.ValueOf(&local).Elem(), reflect.ValueOf(&global).Elem()}); err != nil { + t.Error(err) + } +} diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go index 5a94e9ac6..ca9822bca 100644 --- a/pkg/tcpip/stack/nud.go +++ b/pkg/tcpip/stack/nud.go @@ -16,6 +16,7 @@ package stack import ( "math" + "math/rand" "sync" "time" @@ -313,45 +314,36 @@ func calcMaxRandomFactor(minRandomFactor float32) float32 { return defaultMaxRandomFactor } -// A Rand is a source of random numbers. -type Rand interface { - // Float32 returns, as a float32, a pseudo-random number in [0.0,1.0). - Float32() float32 -} - // NUDState stores states needed for calculating reachable time. type NUDState struct { - rng Rand + clock tcpip.Clock + rng *rand.Rand - // mu protects the fields below. - // - // It is necessary for NUDState to handle its own locking since neighbor - // entries may access the NUD state from within the goroutine spawned by - // time.AfterFunc(). This goroutine may run concurrently with the main - // process for controlling the neighbor cache and would otherwise introduce - // race conditions if NUDState was not locked properly. - mu sync.RWMutex + mu struct { + sync.RWMutex - config NUDConfigurations + config NUDConfigurations - // reachableTime is the duration to wait for a REACHABLE entry to - // transition into STALE after inactivity. This value is calculated with - // the algorithm defined in RFC 4861 section 6.3.2. - reachableTime time.Duration + // reachableTime is the duration to wait for a REACHABLE entry to + // transition into STALE after inactivity. This value is calculated with + // the algorithm defined in RFC 4861 section 6.3.2. + reachableTime time.Duration - expiration time.Time - prevBaseReachableTime time.Duration - prevMinRandomFactor float32 - prevMaxRandomFactor float32 + expiration time.Time + prevBaseReachableTime time.Duration + prevMinRandomFactor float32 + prevMaxRandomFactor float32 + } } // NewNUDState returns new NUDState using c as configuration and the specified // random number generator for use in recomputing ReachableTime. -func NewNUDState(c NUDConfigurations, rng Rand) *NUDState { +func NewNUDState(c NUDConfigurations, clock tcpip.Clock, rng *rand.Rand) *NUDState { s := &NUDState{ - rng: rng, + clock: clock, + rng: rng, } - s.config = c + s.mu.config = c return s } @@ -359,14 +351,14 @@ func NewNUDState(c NUDConfigurations, rng Rand) *NUDState { func (s *NUDState) Config() NUDConfigurations { s.mu.RLock() defer s.mu.RUnlock() - return s.config + return s.mu.config } // SetConfig replaces the existing NUD configurations with c. func (s *NUDState) SetConfig(c NUDConfigurations) { s.mu.Lock() defer s.mu.Unlock() - s.config = c + s.mu.config = c } // ReachableTime returns the duration to wait for a REACHABLE entry to @@ -377,13 +369,13 @@ func (s *NUDState) ReachableTime() time.Duration { s.mu.Lock() defer s.mu.Unlock() - if time.Now().After(s.expiration) || - s.config.BaseReachableTime != s.prevBaseReachableTime || - s.config.MinRandomFactor != s.prevMinRandomFactor || - s.config.MaxRandomFactor != s.prevMaxRandomFactor { + if s.clock.Now().After(s.mu.expiration) || + s.mu.config.BaseReachableTime != s.mu.prevBaseReachableTime || + s.mu.config.MinRandomFactor != s.mu.prevMinRandomFactor || + s.mu.config.MaxRandomFactor != s.mu.prevMaxRandomFactor { s.recomputeReachableTimeLocked() } - return s.reachableTime + return s.mu.reachableTime } // recomputeReachableTimeLocked forces a recalculation of ReachableTime using @@ -408,23 +400,23 @@ func (s *NUDState) ReachableTime() time.Duration { // // s.mu MUST be locked for writing. func (s *NUDState) recomputeReachableTimeLocked() { - s.prevBaseReachableTime = s.config.BaseReachableTime - s.prevMinRandomFactor = s.config.MinRandomFactor - s.prevMaxRandomFactor = s.config.MaxRandomFactor + s.mu.prevBaseReachableTime = s.mu.config.BaseReachableTime + s.mu.prevMinRandomFactor = s.mu.config.MinRandomFactor + s.mu.prevMaxRandomFactor = s.mu.config.MaxRandomFactor - randomFactor := s.config.MinRandomFactor + s.rng.Float32()*(s.config.MaxRandomFactor-s.config.MinRandomFactor) + randomFactor := s.mu.config.MinRandomFactor + s.rng.Float32()*(s.mu.config.MaxRandomFactor-s.mu.config.MinRandomFactor) // Check for overflow, given that minRandomFactor and maxRandomFactor are // guaranteed to be positive numbers. - if float32(math.MaxInt64)/randomFactor < float32(s.config.BaseReachableTime) { - s.reachableTime = time.Duration(math.MaxInt64) + if math.MaxInt64/randomFactor < float32(s.mu.config.BaseReachableTime) { + s.mu.reachableTime = time.Duration(math.MaxInt64) } else if randomFactor == 1 { // Avoid loss of precision when a large base reachable time is used. - s.reachableTime = s.config.BaseReachableTime + s.mu.reachableTime = s.mu.config.BaseReachableTime } else { - reachableTime := int64(float32(s.config.BaseReachableTime) * randomFactor) - s.reachableTime = time.Duration(reachableTime) + reachableTime := int64(float32(s.mu.config.BaseReachableTime) * randomFactor) + s.mu.reachableTime = time.Duration(reachableTime) } - s.expiration = time.Now().Add(2 * time.Hour) + s.mu.expiration = s.clock.Now().Add(2 * time.Hour) } diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go index e1253f310..1aeb2f8a5 100644 --- a/pkg/tcpip/stack/nud_test.go +++ b/pkg/tcpip/stack/nud_test.go @@ -16,6 +16,7 @@ package stack_test import ( "math" + "math/rand" "testing" "time" @@ -28,17 +29,15 @@ import ( ) const ( - defaultBaseReachableTime = 30 * time.Second - minimumBaseReachableTime = time.Millisecond - defaultMinRandomFactor = 0.5 - defaultMaxRandomFactor = 1.5 - defaultRetransmitTimer = time.Second - minimumRetransmitTimer = time.Millisecond - defaultDelayFirstProbeTime = 5 * time.Second - defaultMaxMulticastProbes = 3 - defaultMaxUnicastProbes = 3 - defaultMaxAnycastDelayTime = time.Second - defaultMaxReachbilityConfirmations = 3 + defaultBaseReachableTime = 30 * time.Second + minimumBaseReachableTime = time.Millisecond + defaultMinRandomFactor = 0.5 + defaultMaxRandomFactor = 1.5 + defaultRetransmitTimer = time.Second + minimumRetransmitTimer = time.Millisecond + defaultDelayFirstProbeTime = 5 * time.Second + defaultMaxMulticastProbes = 3 + defaultMaxUnicastProbes = 3 defaultFakeRandomNum = 0.5 ) @@ -48,12 +47,14 @@ type fakeRand struct { num float32 } -var _ stack.Rand = (*fakeRand)(nil) +var _ rand.Source = (*fakeRand)(nil) -func (f *fakeRand) Float32() float32 { - return f.num +func (f *fakeRand) Int63() int64 { + return int64(f.num * float32(1<<63)) } +func (*fakeRand) Seed(int64) {} + func TestNUDFunctions(t *testing.T) { const nicID = 1 @@ -169,7 +170,7 @@ func TestNUDFunctions(t *testing.T) { t.Errorf("s.Neigbors(%d, %d) error mismatch (-want +got):\n%s", test.nicID, ipv6.ProtocolNumber, diff) } else if test.expectedErr == nil { if diff := cmp.Diff( - []stack.NeighborEntry{{Addr: llAddr2, LinkAddr: linkAddr1, State: stack.Static, UpdatedAtNanos: clock.NowNanoseconds()}}, + []stack.NeighborEntry{{Addr: llAddr2, LinkAddr: linkAddr1, State: stack.Static, UpdatedAt: clock.Now()}}, neighbors, ); diff != "" { t.Errorf("neighbors mismatch (-want +got):\n%s", diff) @@ -710,7 +711,8 @@ func TestNUDStateReachableTime(t *testing.T) { rng := fakeRand{ num: defaultFakeRandomNum, } - s := stack.NewNUDState(c, &rng) + var clock faketime.NullClock + s := stack.NewNUDState(c, &clock, rand.New(&rng)) if got, want := s.ReachableTime(), test.want; got != want { t.Errorf("got ReachableTime = %q, want = %q", got, want) } @@ -782,7 +784,8 @@ func TestNUDStateRecomputeReachableTime(t *testing.T) { rng := fakeRand{ num: defaultFakeRandomNum, } - s := stack.NewNUDState(c, &rng) + var clock faketime.NullClock + s := stack.NewNUDState(c, &clock, rand.New(&rng)) old := s.ReachableTime() if got, want := s.ReachableTime(), old; got != want { diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index e2e073091..9192d8433 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -134,7 +134,7 @@ type PacketBuffer struct { // https://www.man7.org/linux/man-pages/man7/packet.7.html. PktType tcpip.PacketType - // NICID is the ID of the interface the network packet was received at. + // NICID is the ID of the last interface the network packet was handled at. NICID tcpip.NICID // RXTransportChecksumValidated indicates that transport checksum verification @@ -245,10 +245,10 @@ func (pk *PacketBuffer) dataOffset() int { func (pk *PacketBuffer) push(typ headerType, size int) tcpipbuffer.View { h := &pk.headers[typ] if h.length > 0 { - panic(fmt.Sprintf("push must not be called twice: type %s", typ)) + panic(fmt.Sprintf("push(%s, %d) called after previous push", typ, size)) } if pk.pushed+size > pk.reserved { - panic("not enough headroom reserved") + panic(fmt.Sprintf("push(%s, %d) overflows; pushed=%d reserved=%d", typ, size, pk.pushed, pk.reserved)) } pk.pushed += size h.offset = -pk.pushed @@ -261,7 +261,7 @@ func (pk *PacketBuffer) consume(typ headerType, size int) (v tcpipbuffer.View, c if h.length > 0 { panic(fmt.Sprintf("consume must not be called twice: type %s", typ)) } - if pk.headerOffset()+pk.consumed+size > int(pk.buf.Size()) { + if pk.reserved+pk.consumed+size > int(pk.buf.Size()) { return nil, false } h.offset = pk.consumed diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go index 1c1aeb950..a8da34992 100644 --- a/pkg/tcpip/stack/packet_buffer_test.go +++ b/pkg/tcpip/stack/packet_buffer_test.go @@ -259,6 +259,37 @@ func TestPacketHeaderPushConsumeMixed(t *testing.T) { }) } +func TestPacketHeaderPushConsumeMixedTooLong(t *testing.T) { + link := makeView(10) + network := makeView(20) + data := makeView(30) + + initData := concatViews(network, data) + pk := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: len(link), + Data: buffer.NewViewFromBytes(initData).ToVectorisedView(), + }) + + // 1. Push link header + copy(pk.LinkHeader().Push(len(link)), link) + + checkPacketContents(t, "" /* prefix */, pk, packetContents{ + link: link, + data: initData, + }) + + // 2. Consume network header, with a number of bytes too large. + gotNetwork, ok := pk.NetworkHeader().Consume(len(initData) + 1) + if ok { + t.Fatalf("pk.NetworkHeader().Consume(%d) = %q, true; want _, false", len(initData)+1, gotNetwork) + } + + checkPacketContents(t, "" /* prefix */, pk, packetContents{ + link: link, + data: initData, + }) +} + func TestPacketHeaderPushCalledAtMostOnce(t *testing.T) { const headerSize = 10 diff --git a/pkg/tcpip/stack/rand.go b/pkg/tcpip/stack/rand.go index 421fb5c15..c8294eb6e 100644 --- a/pkg/tcpip/stack/rand.go +++ b/pkg/tcpip/stack/rand.go @@ -15,7 +15,7 @@ package stack import ( - mathrand "math/rand" + "math/rand" "gvisor.dev/gvisor/pkg/sync" ) @@ -23,7 +23,7 @@ import ( // lockedRandomSource provides a threadsafe rand.Source. type lockedRandomSource struct { mu sync.Mutex - src mathrand.Source + src rand.Source } func (r *lockedRandomSource) Int63() (n int64) { diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index a82c807b4..a038389e0 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -420,7 +420,7 @@ const ( PermanentExpired // Temporary is an endpoint, created on a one-off basis to temporarily - // consider the NIC bound an an address that it is not explictiy bound to + // consider the NIC bound an an address that it is not explicitly bound to // (such as a permanent address). Its reference count must not be biased by 1 // so that the address is removed immediately when references to it are no // longer held. @@ -630,7 +630,7 @@ type NetworkEndpoint interface { // HandlePacket takes ownership of pkt. HandlePacket(pkt *PacketBuffer) - // Close is called when the endpoint is reomved from a stack. + // Close is called when the endpoint is removed from a stack. Close() // NetworkProtocolNumber returns the tcpip.NetworkProtocolNumber for @@ -658,9 +658,9 @@ type IPNetworkEndpointStats interface { IPStats() *tcpip.IPStats } -// ForwardingNetworkProtocol is a NetworkProtocol that may forward packets. -type ForwardingNetworkProtocol interface { - NetworkProtocol +// ForwardingNetworkEndpoint is a network endpoint that may forward packets. +type ForwardingNetworkEndpoint interface { + NetworkEndpoint // Forwarding returns the forwarding configuration. Forwarding() bool @@ -968,7 +968,7 @@ type DuplicateAddressDetector interface { // called with the result of the original DAD request. CheckDuplicateAddress(tcpip.Address, DADCompletionHandler) DADCheckAddressDisposition - // SetDADConfiguations sets the configurations for DAD. + // SetDADConfigurations sets the configurations for DAD. SetDADConfigurations(c DADConfigurations) // DuplicateAddressProtocol returns the network protocol the receiver can @@ -979,7 +979,7 @@ type DuplicateAddressDetector interface { // LinkAddressResolver handles link address resolution for a network protocol. type LinkAddressResolver interface { // LinkAddressRequest sends a request for the link address of the target - // address. The request is broadcasted on the local network if a remote link + // address. The request is broadcast on the local network if a remote link // address is not provided. LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) tcpip.Error @@ -1072,4 +1072,4 @@ type GSOEndpoint interface { // SoftwareGSOMaxSize is a maximum allowed size of a software GSO segment. // This isn't a hard limit, because it is never set into packet headers. -const SoftwareGSOMaxSize = (1 << 16) +const SoftwareGSOMaxSize = 1 << 16 diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 8a044c073..f17c04277 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -446,7 +446,7 @@ func (r *Route) isValidForOutgoingRLocked() bool { // If the source NIC and outgoing NIC are different, make sure the stack has // forwarding enabled, or the packet will be handled locally. - if r.outgoingNIC != r.localAddressNIC && !r.outgoingNIC.stack.Forwarding(r.NetProto()) && (!r.outgoingNIC.stack.handleLocal || !r.outgoingNIC.hasAddress(r.NetProto(), r.RemoteAddress())) { + if r.outgoingNIC != r.localAddressNIC && !isNICForwarding(r.localAddressNIC, r.NetProto()) && (!r.outgoingNIC.stack.handleLocal || !r.outgoingNIC.hasAddress(r.NetProto(), r.RemoteAddress())) { return false } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 483a960c8..81fabe29a 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -20,17 +20,16 @@ package stack import ( - "bytes" "encoding/binary" "fmt" "io" - mathrand "math/rand" + "math/rand" "sync/atomic" "time" "golang.org/x/time/rate" "gvisor.dev/gvisor/pkg/atomicbitops" - "gvisor.dev/gvisor/pkg/rand" + cryptorand "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -40,13 +39,6 @@ import ( ) const ( - // ageLimit is set to the same cache stale time used in Linux. - ageLimit = 1 * time.Minute - // resolutionTimeout is set to the same ARP timeout used in Linux. - resolutionTimeout = 1 * time.Second - // resolutionAttempts is set to the same ARP retries used in Linux. - resolutionAttempts = 3 - // DefaultTOS is the default type of service value for network endpoints. DefaultTOS = 0 ) @@ -95,8 +87,9 @@ type Stack struct { } } - mu sync.RWMutex - nics map[tcpip.NICID]*nic + mu sync.RWMutex + nics map[tcpip.NICID]*nic + defaultForwardingEnabled map[tcpip.NetworkProtocolNumber]struct{} // cleanupEndpointsMu protects cleanupEndpoints. cleanupEndpointsMu sync.Mutex @@ -115,7 +108,7 @@ type Stack struct { handleLocal bool // tables are the iptables packet filtering and manipulation rules. - // TODO(gvisor.dev/issue/170): S/R this field. + // TODO(gvisor.dev/issue/4595): S/R this field. tables *IPTables // resumableEndpoints is a list of endpoints that need to be resumed if the @@ -144,7 +137,7 @@ type Stack struct { // randomGenerator is an injectable pseudo random generator that can be // used when a random number is required. - randomGenerator *mathrand.Rand + randomGenerator *rand.Rand // secureRNG is a cryptographically secure random number generator. secureRNG io.Reader @@ -195,9 +188,9 @@ type Options struct { // TransportProtocols lists the transport protocols to enable. TransportProtocols []TransportProtocolFactory - // Clock is an optional clock source used for timestampping packets. + // Clock is an optional clock used for timekeeping. // - // If no Clock is specified, the clock source will be time.Now. + // If Clock is nil, tcpip.NewStdClock() will be used. Clock tcpip.Clock // Stats are optional statistic counters. @@ -224,15 +217,21 @@ type Options struct { // RandSource is an optional source to use to generate random // numbers. If omitted it defaults to a Source seeded by the data - // returned by rand.Read(). + // returned by the stack secure RNG. // // RandSource must be thread-safe. - RandSource mathrand.Source + RandSource rand.Source - // IPTables are the initial iptables rules. If nil, iptables will allow + // IPTables are the initial iptables rules. If nil, DefaultIPTables will be + // used to construct the initial iptables rules. // all traffic. IPTables *IPTables + // DefaultIPTables is an optional iptables rules constructor that is called + // if IPTables is nil. If both fields are nil, iptables will allow all + // traffic. + DefaultIPTables func(uint32) *IPTables + // SecureRNG is a cryptographically secure random number generator. SecureRNG io.Reader } @@ -330,40 +329,50 @@ func New(opts Options) *Stack { opts.UniqueID = new(uniqueIDGenerator) } + if opts.SecureRNG == nil { + opts.SecureRNG = cryptorand.Reader + } + randSrc := opts.RandSource if randSrc == nil { - // Source provided by mathrand.NewSource is not thread-safe so + var v int64 + if err := binary.Read(opts.SecureRNG, binary.LittleEndian, &v); err != nil { + panic(err) + } + // Source provided by rand.NewSource is not thread-safe so // we wrap it in a simple thread-safe version. - randSrc = &lockedRandomSource{src: mathrand.NewSource(generateRandInt64())} + randSrc = &lockedRandomSource{src: rand.NewSource(v)} } + randomGenerator := rand.New(randSrc) + seed := randomGenerator.Uint32() if opts.IPTables == nil { - opts.IPTables = DefaultTables() + if opts.DefaultIPTables == nil { + opts.DefaultIPTables = DefaultTables + } + opts.IPTables = opts.DefaultIPTables(seed) } opts.NUDConfigs.resetInvalidFields() - if opts.SecureRNG == nil { - opts.SecureRNG = rand.Reader - } - s := &Stack{ - transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), - networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), - nics: make(map[tcpip.NICID]*nic), - cleanupEndpoints: make(map[TransportEndpoint]struct{}), - PortManager: ports.NewPortManager(), - clock: clock, - stats: opts.Stats.FillIn(), - handleLocal: opts.HandleLocal, - tables: opts.IPTables, - icmpRateLimiter: NewICMPRateLimiter(), - seed: generateRandUint32(), - nudConfigs: opts.NUDConfigs, - uniqueIDGenerator: opts.UniqueID, - nudDisp: opts.NUDDisp, - randomGenerator: mathrand.New(randSrc), - secureRNG: opts.SecureRNG, + transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), + networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), + nics: make(map[tcpip.NICID]*nic), + defaultForwardingEnabled: make(map[tcpip.NetworkProtocolNumber]struct{}), + cleanupEndpoints: make(map[TransportEndpoint]struct{}), + PortManager: ports.NewPortManager(), + clock: clock, + stats: opts.Stats.FillIn(), + handleLocal: opts.HandleLocal, + tables: opts.IPTables, + icmpRateLimiter: NewICMPRateLimiter(), + seed: seed, + nudConfigs: opts.NUDConfigs, + uniqueIDGenerator: opts.UniqueID, + nudDisp: opts.NUDDisp, + randomGenerator: randomGenerator, + secureRNG: opts.SecureRNG, sendBufferSize: tcpip.SendBufferSizeOption{ Min: MinBufferSize, Default: DefaultBufferSize, @@ -492,37 +501,61 @@ func (s *Stack) Stats() tcpip.Stats { return s.stats } -// SetForwardingDefaultAndAllNICs sets packet forwarding for all NICs for the -// passed protocol and sets the default setting for newly created NICs. -func (s *Stack) SetForwardingDefaultAndAllNICs(protocolNum tcpip.NetworkProtocolNumber, enable bool) tcpip.Error { - protocol, ok := s.networkProtocols[protocolNum] +// SetNICForwarding enables or disables packet forwarding on the specified NIC +// for the passed protocol. +func (s *Stack) SetNICForwarding(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, enable bool) tcpip.Error { + s.mu.RLock() + defer s.mu.RUnlock() + + nic, ok := s.nics[id] if !ok { - return &tcpip.ErrUnknownProtocol{} + return &tcpip.ErrUnknownNICID{} } - forwardingProtocol, ok := protocol.(ForwardingNetworkProtocol) + return nic.setForwarding(protocol, enable) +} + +// NICForwarding returns the forwarding configuration for the specified NIC. +func (s *Stack) NICForwarding(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (bool, tcpip.Error) { + s.mu.RLock() + defer s.mu.RUnlock() + + nic, ok := s.nics[id] if !ok { - return &tcpip.ErrNotSupported{} + return false, &tcpip.ErrUnknownNICID{} } - forwardingProtocol.SetForwarding(enable) - return nil + return nic.forwarding(protocol) } -// Forwarding returns true if packet forwarding between NICs is enabled for the -// passed protocol. -func (s *Stack) Forwarding(protocolNum tcpip.NetworkProtocolNumber) bool { - protocol, ok := s.networkProtocols[protocolNum] - if !ok { - return false +// SetForwardingDefaultAndAllNICs sets packet forwarding for all NICs for the +// passed protocol and sets the default setting for newly created NICs. +func (s *Stack) SetForwardingDefaultAndAllNICs(protocol tcpip.NetworkProtocolNumber, enable bool) tcpip.Error { + s.mu.Lock() + defer s.mu.Unlock() + + doneOnce := false + for id, nic := range s.nics { + if err := nic.setForwarding(protocol, enable); err != nil { + // Expect forwarding to be settable on all interfaces if it was set on + // one. + if doneOnce { + panic(fmt.Sprintf("nic(id=%d).setForwarding(%d, %t): %s", id, protocol, enable, err)) + } + + return err + } + + doneOnce = true } - forwardingProtocol, ok := protocol.(ForwardingNetworkProtocol) - if !ok { - return false + if enable { + s.defaultForwardingEnabled[protocol] = struct{}{} + } else { + delete(s.defaultForwardingEnabled, protocol) } - return forwardingProtocol.Forwarding() + return nil } // PortRange returns the UDP and TCP inclusive range of ephemeral ports used in @@ -659,6 +692,11 @@ func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOp } n := newNIC(s, id, opts.Name, ep, opts.Context) + for proto := range s.defaultForwardingEnabled { + if err := n.setForwarding(proto, true); err != nil { + panic(fmt.Sprintf("newNIC(%d, ...).setForwarding(%d, true): %s", id, proto, err)) + } + } s.nics[id] = n if !opts.Disabled { return n.enable() @@ -773,7 +811,7 @@ type NICInfo struct { // MTU is the maximum transmission unit. MTU uint32 - Stats NICStats + Stats tcpip.NICStats // NetworkStats holds the stats of each NetworkEndpoint bound to the NIC. NetworkStats map[tcpip.NetworkProtocolNumber]NetworkEndpointStats @@ -786,6 +824,10 @@ type NICInfo struct { // value sent in haType field of an ARP Request sent by this NIC and the // value expected in the haType field of an ARP response. ARPHardwareType header.ARPHardwareType + + // Forwarding holds the forwarding status for each network endpoint that + // supports forwarding. + Forwarding map[tcpip.NetworkProtocolNumber]bool } // HasNIC returns true if the NICID is defined in the stack. @@ -815,17 +857,33 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { netStats[proto] = netEP.Stats() } - nics[id] = NICInfo{ + info := NICInfo{ Name: nic.name, LinkAddress: nic.LinkEndpoint.LinkAddress(), ProtocolAddresses: nic.primaryAddresses(), Flags: flags, MTU: nic.LinkEndpoint.MTU(), - Stats: nic.stats, + Stats: nic.stats.local, NetworkStats: netStats, Context: nic.context, ARPHardwareType: nic.LinkEndpoint.ARPHardwareType(), + Forwarding: make(map[tcpip.NetworkProtocolNumber]bool), } + + for proto := range s.networkProtocols { + switch forwarding, err := nic.forwarding(proto); err.(type) { + case nil: + info.Forwarding[proto] = forwarding + case *tcpip.ErrUnknownProtocol: + panic(fmt.Sprintf("expected network protocol %d to be available on NIC %d", proto, nic.ID())) + case *tcpip.ErrNotSupported: + // Not all network protocols support forwarding. + default: + panic(fmt.Sprintf("nic(id=%d).forwarding(%d): %s", nic.ID(), proto, err)) + } + } + + nics[id] = info } return nics } @@ -1029,6 +1087,20 @@ func (s *Stack) HandleLocal() bool { return s.handleLocal } +func isNICForwarding(nic *nic, proto tcpip.NetworkProtocolNumber) bool { + switch forwarding, err := nic.forwarding(proto); err.(type) { + case nil: + return forwarding + case *tcpip.ErrUnknownProtocol: + panic(fmt.Sprintf("expected network protocol %d to be available on NIC %d", proto, nic.ID())) + case *tcpip.ErrNotSupported: + // Not all network protocols support forwarding. + return false + default: + panic(fmt.Sprintf("nic(id=%d).forwarding(%d): %s", nic.ID(), proto, err)) + } +} + // FindRoute creates a route to the given destination address, leaving through // the given NIC and local address (if provided). // @@ -1081,7 +1153,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n return nil, &tcpip.ErrNetworkUnreachable{} } - canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalUnicastAddress(localAddr) && !isLinkLocal + onlyGlobalAddresses := !header.IsV6LinkLocalUnicastAddress(localAddr) && !isLinkLocal // Find a route to the remote with the route table. var chosenRoute tcpip.Route @@ -1120,7 +1192,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n // requirement to do this from any RFC but simply a choice made to better // follow a strong host model which the netstack follows at the time of // writing. - if canForward && chosenRoute == (tcpip.Route{}) { + if onlyGlobalAddresses && chosenRoute == (tcpip.Route{}) && isNICForwarding(nic, netProto) { chosenRoute = route } } @@ -1754,7 +1826,7 @@ func (s *Stack) Seed() uint32 { // Rand returns a reference to a pseudo random generator that can be used // to generate random numbers as required. -func (s *Stack) Rand() *mathrand.Rand { +func (s *Stack) Rand() *rand.Rand { return s.randomGenerator } @@ -1764,27 +1836,6 @@ func (s *Stack) SecureRNG() io.Reader { return s.secureRNG } -func generateRandUint32() uint32 { - b := make([]byte, 4) - if _, err := rand.Read(b); err != nil { - panic(err) - } - return binary.LittleEndian.Uint32(b) -} - -func generateRandInt64() int64 { - b := make([]byte, 8) - if _, err := rand.Read(b); err != nil { - panic(err) - } - buf := bytes.NewReader(b) - var v int64 - if err := binary.Read(buf, binary.LittleEndian, &v); err != nil { - panic(err) - } - return v -} - // FindNICNameFromID returns the name of the NIC for the given NICID. func (s *Stack) FindNICNameFromID(id tcpip.NICID) string { s.mu.RLock() @@ -1821,9 +1872,8 @@ const ( // ParsePacketBufferTransport parses the provided packet buffer's transport // header. func (s *Stack) ParsePacketBufferTransport(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) ParseResult { - // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader - // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a - // full explanation. + // ICMP packets don't have their TransportHeader fields set yet, parse it + // here. See icmp/protocol.go:protocol.Parse for a full explanation. if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber { return ParsedOK } diff --git a/pkg/tcpip/stack/stack_global_state.go b/pkg/tcpip/stack/stack_global_state.go index 33824afd0..dfec4258a 100644 --- a/pkg/tcpip/stack/stack_global_state.go +++ b/pkg/tcpip/stack/stack_global_state.go @@ -14,78 +14,6 @@ package stack -import "time" - // StackFromEnv is the global stack created in restore run. // FIXME(b/36201077) var StackFromEnv *Stack - -// saveT is invoked by stateify. -func (t *TCPCubicState) saveT() unixTime { - return unixTime{t.T.Unix(), t.T.UnixNano()} -} - -// loadT is invoked by stateify. -func (t *TCPCubicState) loadT(unix unixTime) { - t.T = time.Unix(unix.second, unix.nano) -} - -// saveXmitTime is invoked by stateify. -func (t *TCPRACKState) saveXmitTime() unixTime { - return unixTime{t.XmitTime.Unix(), t.XmitTime.UnixNano()} -} - -// loadXmitTime is invoked by stateify. -func (t *TCPRACKState) loadXmitTime(unix unixTime) { - t.XmitTime = time.Unix(unix.second, unix.nano) -} - -// saveLastSendTime is invoked by stateify. -func (t *TCPSenderState) saveLastSendTime() unixTime { - return unixTime{t.LastSendTime.Unix(), t.LastSendTime.UnixNano()} -} - -// loadLastSendTime is invoked by stateify. -func (t *TCPSenderState) loadLastSendTime(unix unixTime) { - t.LastSendTime = time.Unix(unix.second, unix.nano) -} - -// saveRTTMeasureTime is invoked by stateify. -func (t *TCPSenderState) saveRTTMeasureTime() unixTime { - return unixTime{t.RTTMeasureTime.Unix(), t.RTTMeasureTime.UnixNano()} -} - -// loadRTTMeasureTime is invoked by stateify. -func (t *TCPSenderState) loadRTTMeasureTime(unix unixTime) { - t.RTTMeasureTime = time.Unix(unix.second, unix.nano) -} - -// saveMeasureTime is invoked by stateify. -func (r *RcvBufAutoTuneParams) saveMeasureTime() unixTime { - return unixTime{r.MeasureTime.Unix(), r.MeasureTime.UnixNano()} -} - -// loadMeasureTime is invoked by stateify. -func (r *RcvBufAutoTuneParams) loadMeasureTime(unix unixTime) { - r.MeasureTime = time.Unix(unix.second, unix.nano) -} - -// saveRTTMeasureTime is invoked by stateify. -func (r *RcvBufAutoTuneParams) saveRTTMeasureTime() unixTime { - return unixTime{r.RTTMeasureTime.Unix(), r.RTTMeasureTime.UnixNano()} -} - -// loadRTTMeasureTime is invoked by stateify. -func (r *RcvBufAutoTuneParams) loadRTTMeasureTime(unix unixTime) { - r.RTTMeasureTime = time.Unix(unix.second, unix.nano) -} - -// saveSegTime is invoked by stateify. -func (t *TCPEndpointState) saveSegTime() unixTime { - return unixTime{t.SegTime.Unix(), t.SegTime.UnixNano()} -} - -// loadSegTime is invoked by stateify. -func (t *TCPEndpointState) loadSegTime(unix unixTime) { - t.SegTime = time.Unix(unix.second, unix.nano) -} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index ff88b1bd3..21951d05a 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -84,7 +84,8 @@ type fakeNetworkEndpoint struct { mu struct { sync.RWMutex - enabled bool + enabled bool + forwarding bool } nic stack.NetworkInterface @@ -165,10 +166,6 @@ func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { return f.nic.MaxHeaderLength() + fakeNetHeaderLen } -func (*fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { - return 0 -} - func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return f.proto.Number() } @@ -196,11 +193,11 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, params stack.NetworkHe } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (*fakeNetworkEndpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, tcpip.Error) { +func (*fakeNetworkEndpoint) WritePackets(*stack.Route, stack.PacketBufferList, stack.NetworkHeaderParams) (int, tcpip.Error) { panic("not implemented") } -func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) tcpip.Error { +func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(*stack.Route, *stack.PacketBuffer) tcpip.Error { return &tcpip.ErrNotSupported{} } @@ -227,11 +224,6 @@ type fakeNetworkProtocol struct { packetCount [10]int sendPacketCount [10]int defaultTTL uint8 - - mu struct { - sync.RWMutex - forwarding bool - } } func (*fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber { @@ -300,15 +292,15 @@ func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProto return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true } -// Forwarding implements stack.ForwardingNetworkProtocol. -func (f *fakeNetworkProtocol) Forwarding() bool { +// Forwarding implements stack.ForwardingNetworkEndpoint. +func (f *fakeNetworkEndpoint) Forwarding() bool { f.mu.RLock() defer f.mu.RUnlock() return f.mu.forwarding } -// SetForwarding implements stack.ForwardingNetworkProtocol. -func (f *fakeNetworkProtocol) SetForwarding(v bool) { +// SetForwarding implements stack.ForwardingNetworkEndpoint. +func (f *fakeNetworkEndpoint) SetForwarding(v bool) { f.mu.Lock() defer f.mu.Unlock() f.mu.forwarding = v @@ -467,14 +459,14 @@ func testSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer } } -func testFailingSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr tcpip.Error) { +func testFailingSend(t *testing.T, r *stack.Route, payload buffer.View, wantErr tcpip.Error) { t.Helper() if gotErr := send(r, payload); gotErr != wantErr { t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr) } } -func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View, wantErr tcpip.Error) { +func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, payload buffer.View, wantErr tcpip.Error) { t.Helper() if gotErr := sendTo(s, addr, payload); gotErr != wantErr { t.Errorf("sendto failed: got = %s, want = %s ", gotErr, wantErr) @@ -924,15 +916,15 @@ func TestRouteWithDownNIC(t *testing.T) { if err := test.downFn(s, nicID1); err != nil { t.Fatalf("test.downFn(_, %d): %s", nicID1, err) } - testFailingSend(t, r1, ep1, buf, &tcpip.ErrInvalidEndpointState{}) + testFailingSend(t, r1, buf, &tcpip.ErrInvalidEndpointState{}) testSend(t, r2, ep2, buf) // Writes with Routes that use NIC2 after being brought down should fail. if err := test.downFn(s, nicID2); err != nil { t.Fatalf("test.downFn(_, %d): %s", nicID2, err) } - testFailingSend(t, r1, ep1, buf, &tcpip.ErrInvalidEndpointState{}) - testFailingSend(t, r2, ep2, buf, &tcpip.ErrInvalidEndpointState{}) + testFailingSend(t, r1, buf, &tcpip.ErrInvalidEndpointState{}) + testFailingSend(t, r2, buf, &tcpip.ErrInvalidEndpointState{}) if upFn := test.upFn; upFn != nil { // Writes with Routes that use NIC1 after being brought up should @@ -945,7 +937,7 @@ func TestRouteWithDownNIC(t *testing.T) { t.Fatalf("test.upFn(_, %d): %s", nicID1, err) } testSend(t, r1, ep1, buf) - testFailingSend(t, r2, ep2, buf, &tcpip.ErrInvalidEndpointState{}) + testFailingSend(t, r2, buf, &tcpip.ErrInvalidEndpointState{}) } }) } @@ -1070,7 +1062,7 @@ func TestAddressRemoval(t *testing.T) { t.Fatal("RemoveAddress failed:", err) } testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) // Check that removing the same address fails. err := s.RemoveAddress(1, localAddr) @@ -1122,8 +1114,8 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { t.Fatal("RemoveAddress failed:", err) } testFailingRecv(t, fakeNet, localAddrByte, ep, buf) - testFailingSend(t, r, ep, nil, &tcpip.ErrInvalidEndpointState{}) - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSend(t, r, nil, &tcpip.ErrInvalidEndpointState{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) // Check that removing the same address fails. { @@ -1144,7 +1136,7 @@ func verifyAddress(t *testing.T, s *stack.Stack, nicID tcpip.NICID, addr tcpip.A // No address given, verify that there is no address assigned to the NIC. for _, a := range info.ProtocolAddresses { if a.Protocol == fakeNetNumber && a.AddressWithPrefix != (tcpip.AddressWithPrefix{}) { - t.Errorf("verify no-address: got = %s, want = %s", a.AddressWithPrefix, (tcpip.AddressWithPrefix{})) + t.Errorf("verify no-address: got = %s, want = %s", a.AddressWithPrefix, tcpip.AddressWithPrefix{}) } } return @@ -1224,7 +1216,7 @@ func TestEndpointExpiration(t *testing.T) { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) } // 2. Add Address, everything should work. @@ -1252,7 +1244,7 @@ func TestEndpointExpiration(t *testing.T) { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) } // 4. Add Address back, everything should work again. @@ -1291,8 +1283,8 @@ func TestEndpointExpiration(t *testing.T) { testSend(t, r, ep, nil) testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSend(t, r, ep, nil, &tcpip.ErrInvalidEndpointState{}) - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSend(t, r, nil, &tcpip.ErrInvalidEndpointState{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) } // 7. Add Address back, everything should work again. @@ -1328,7 +1320,7 @@ func TestEndpointExpiration(t *testing.T) { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, remoteAddr, nil, &tcpip.ErrNoRoute{}) } }) } @@ -1578,7 +1570,7 @@ func TestSpoofingNoAddress(t *testing.T) { t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) } // Sending a packet fails. - testFailingSendTo(t, s, dstAddr, ep, nil, &tcpip.ErrNoRoute{}) + testFailingSendTo(t, s, dstAddr, nil, &tcpip.ErrNoRoute{}) // With address spoofing enabled, FindRoute permits any address to be used // as the source. @@ -1619,7 +1611,7 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { } } - protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Any, 0}} + protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{Address: header.IPv4Any}} if err := s.AddProtocolAddress(1, protoAddr); err != nil { t.Fatalf("AddProtocolAddress(1, %v) failed: %v", protoAddr, err) } @@ -1645,12 +1637,12 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { } func TestOutgoingBroadcastWithRouteTable(t *testing.T) { - defaultAddr := tcpip.AddressWithPrefix{header.IPv4Any, 0} + defaultAddr := tcpip.AddressWithPrefix{Address: header.IPv4Any} // Local subnet on NIC1: 192.168.1.58/24, gateway 192.168.1.1. - nic1Addr := tcpip.AddressWithPrefix{"\xc0\xa8\x01\x3a", 24} + nic1Addr := tcpip.AddressWithPrefix{Address: "\xc0\xa8\x01\x3a", PrefixLen: 24} nic1Gateway := testutil.MustParse4("192.168.1.1") // Local subnet on NIC2: 10.10.10.5/24, gateway 10.10.10.1. - nic2Addr := tcpip.AddressWithPrefix{"\x0a\x0a\x0a\x05", 24} + nic2Addr := tcpip.AddressWithPrefix{Address: "\x0a\x0a\x0a\x05", PrefixLen: 24} nic2Gateway := testutil.MustParse4("10.10.10.1") // Create a new stack with two NICs. @@ -1664,12 +1656,12 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { if err := s.CreateNIC(2, ep); err != nil { t.Fatalf("CreateNIC failed: %s", err) } - nic1ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic1Addr} + nic1ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic1Addr} if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil { t.Fatalf("AddProtocolAddress(1, %v) failed: %v", nic1ProtoAddr, err) } - nic2ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic2Addr} + nic2ProtoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: nic2Addr} if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil { t.Fatalf("AddAddress(2, %v) failed: %v", nic2ProtoAddr, err) } @@ -1713,7 +1705,7 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { // 2. Case: Having an explicit route for broadcast will select that one. rt = append( []tcpip.Route{ - {Destination: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize}.Subnet(), NIC: 1}, + {Destination: tcpip.AddressWithPrefix{Address: header.IPv4Broadcast, PrefixLen: 8 * header.IPv4AddressSize}.Subnet(), NIC: 1}, }, rt..., ) @@ -2053,7 +2045,7 @@ func TestAddAddress(t *testing.T) { } expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{address, fakeDefaultPrefixLen}, + AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen}, }) } @@ -2117,7 +2109,7 @@ func TestAddAddressWithOptions(t *testing.T) { } expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{ Protocol: fakeNetNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{address, fakeDefaultPrefixLen}, + AddressWithPrefix: tcpip.AddressWithPrefix{Address: address, PrefixLen: fakeDefaultPrefixLen}, }) } } @@ -2238,7 +2230,7 @@ func TestCreateNICWithOptions(t *testing.T) { for _, test := range tests { t.Run(test.desc, func(t *testing.T) { s := stack.New(stack.Options{}) - ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")) + ep := channel.New(0, 0, "\x00\x00\x00\x00\x00\x00") for _, call := range test.calls { if got, want := s.CreateNICWithOptions(call.nicID, ep, call.opts), call.err; got != want { t.Fatalf("CreateNICWithOptions(%v, _, %+v) = %v, want %v", call.nicID, call.opts, got, want) @@ -2252,46 +2244,87 @@ func TestNICStats(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, }) - ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatal("CreateNIC failed: ", err) - } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatal("AddAddress failed:", err) + + nics := []struct { + addr tcpip.Address + txByteCount int + rxByteCount int + }{ + { + addr: "\x01", + txByteCount: 30, + rxByteCount: 10, + }, + { + addr: "\x02", + txByteCount: 50, + rxByteCount: 20, + }, } - // Route all packets for address \x01 to NIC 1. - { - subnet, err := tcpip.NewSubnet("\x01", "\xff") - if err != nil { - t.Fatal(err) + + var txBytesTotal, rxBytesTotal, txPacketsTotal, rxPacketsTotal int + for i, nic := range nics { + nicid := tcpip.NICID(i) + ep := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicid, ep); err != nil { + t.Fatal("CreateNIC failed: ", err) + } + if err := s.AddAddress(nicid, fakeNetNumber, nic.addr); err != nil { + t.Fatal("AddAddress failed:", err) } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) - } - // Send a packet to address 1. - buf := buffer.NewView(30) - ep1.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want { - t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) - } + { + subnet, err := tcpip.NewSubnet(nic.addr, "\xff") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicid}}) + } - if got, want := s.NICInfo()[1].Stats.Rx.Bytes.Value(), uint64(len(buf)); got != want { - t.Errorf("got Rx.Bytes.Value() = %d, want = %d", got, want) + nicStats := s.NICInfo()[nicid].Stats + + // Inbound packet. + rxBuffer := buffer.NewView(nic.rxByteCount) + ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: rxBuffer.ToVectorisedView(), + })) + if got, want := nicStats.Rx.Packets.Value(), uint64(1); got != want { + t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) + } + if got, want := nicStats.Rx.Bytes.Value(), uint64(nic.rxByteCount); got != want { + t.Errorf("got Rx.Bytes.Value() = %d, want = %d", got, want) + } + rxPacketsTotal++ + rxBytesTotal += nic.rxByteCount + + // Outbound packet. + txBuffer := buffer.NewView(nic.txByteCount) + actualTxLength := nic.txByteCount + fakeNetHeaderLen + if err := sendTo(s, nic.addr, txBuffer); err != nil { + t.Fatal("sendTo failed: ", err) + } + want := ep.Drain() + if got := nicStats.Tx.Packets.Value(); got != uint64(want) { + t.Errorf("got Tx.Packets.Value() = %d, ep.Drain() = %d", got, want) + } + if got, want := nicStats.Tx.Bytes.Value(), uint64(actualTxLength); got != want { + t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) + } + txPacketsTotal += want + txBytesTotal += actualTxLength } - payload := buffer.NewView(10) - // Write a packet out via the address for NIC 1 - if err := sendTo(s, "\x01", payload); err != nil { - t.Fatal("sendTo failed: ", err) + // Now verify that each NIC stats was correctly aggregated at the stack level. + if got, want := s.Stats().NICs.Rx.Packets.Value(), uint64(rxPacketsTotal); got != want { + t.Errorf("got s.Stats().NIC.Rx.Packets.Value() = %d, want = %d", got, want) } - want := uint64(ep1.Drain()) - if got := s.NICInfo()[1].Stats.Tx.Packets.Value(); got != want { - t.Errorf("got Tx.Packets.Value() = %d, ep1.Drain() = %d", got, want) + if got, want := s.Stats().NICs.Rx.Bytes.Value(), uint64(rxBytesTotal); got != want { + t.Errorf("got s.Stats().Rx.Bytes.Value() = %d, want = %d", got, want) } - - if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)+fakeNetHeaderLen); got != want { + if got, want := s.Stats().NICs.Tx.Packets.Value(), uint64(txPacketsTotal); got != want { + t.Errorf("got Tx.Packets.Value() = %d, ep.Drain() = %d", got, want) + } + if got, want := s.Stats().NICs.Tx.Bytes.Value(), uint64(txBytesTotal); got != want { t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) } } @@ -2320,7 +2353,7 @@ func TestNICContextPreservation(t *testing.T) { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{}) id := tcpip.NICID(1) - ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")) + ep := channel.New(0, 0, "\x00\x00\x00\x00\x00\x00") if err := s.CreateNICWithOptions(id, ep, test.opts); err != nil { t.Fatalf("got stack.CreateNICWithOptions(%d, %+v, %+v) = %s, want nil", id, ep, test.opts, err) } @@ -2607,15 +2640,17 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { const nicID = 1 ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent), + dadC: make(chan ndpDADEvent, 1), } dadConfigs := stack.DefaultDADConfigurations() + clock := faketime.NewManualClock() opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ AutoGenLinkLocal: true, NDPDisp: &ndpDisp, DADConfigs: dadConfigs, })}, + Clock: clock, } e := channel.New(int(dadConfigs.DupAddrDetectTransmits), 1280, linkAddr1) @@ -2633,17 +2668,18 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { linkLocalAddr := header.LinkLocalAddr(linkAddr1) // Wait for DAD to resolve. + clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer) select { - case <-time.After(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer + time.Second): + case e := <-ndpDisp.dadC: + if diff := checkDADEvent(e, nicID, linkLocalAddr, &stack.DADSucceeded{}); diff != "" { + t.Errorf("dad event mismatch (-want +got):\n%s", diff) + } + default: // We should get a resolution event after 1s (default time to // resolve as per default NDP configurations). Waiting for that // resolution time + an extra 1s without a resolution event // means something is wrong. t.Fatal("timed out waiting for DAD resolution") - case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, linkLocalAddr, &stack.DADSucceeded{}); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) - } } if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); err != nil { t.Fatal(err) @@ -3274,8 +3310,9 @@ func TestDoDADWhenNICEnabled(t *testing.T) { const nicID = 1 ndpDisp := ndpDispatcher{ - dadC: make(chan ndpDADEvent), + dadC: make(chan ndpDADEvent, 1), } + clock := faketime.NewManualClock() opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ DADConfigs: stack.DADConfigurations{ @@ -3284,6 +3321,7 @@ func TestDoDADWhenNICEnabled(t *testing.T) { }, NDPDisp: &ndpDisp, })}, + Clock: clock, } e := channel.New(dadTransmits, 1280, linkAddr1) @@ -3328,13 +3366,14 @@ func TestDoDADWhenNICEnabled(t *testing.T) { } // Wait for DAD to resolve. + clock.Advance(dadTransmits * retransmitTimer) select { - case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: if diff := checkDADEvent(e, nicID, addr.AddressWithPrefix.Address, &stack.DADSucceeded{}); diff != "" { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } + default: + t.Fatal("timed out waiting for DAD resolution") } if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) { t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr) @@ -3841,8 +3880,6 @@ func TestGetMainNICAddressWhenNICDisabled(t *testing.T) { // TestAddRoute tests Stack.AddRoute func TestAddRoute(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{}) subnet1, err := tcpip.NewSubnet("\x00", "\x00") @@ -3879,8 +3916,6 @@ func TestAddRoute(t *testing.T) { // TestRemoveRoutes tests Stack.RemoveRoutes func TestRemoveRoutes(t *testing.T) { - const nicID = 1 - s := stack.New(stack.Options{}) addressToRemove := tcpip.Address("\x01") @@ -4227,7 +4262,7 @@ func TestFindRouteWithForwarding(t *testing.T) { s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}}) r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */) - if r != nil { + if err == nil { defer r.Release() } if diff := cmp.Diff(test.findRouteErr, err); diff != "" { @@ -4398,7 +4433,7 @@ func TestClearNeighborCacheOnNICDisable(t *testing.T) { if neighbors, err := s.Neighbors(nicID, addr.proto); err != nil { t.Fatalf("s.Neighbors(%d, %d): %s", nicID, addr.proto, err) } else if diff := cmp.Diff( - []stack.NeighborEntry{{Addr: addr.addr, LinkAddr: linkAddr, State: stack.Static, UpdatedAtNanos: clock.NowNanoseconds()}}, + []stack.NeighborEntry{{Addr: addr.addr, LinkAddr: linkAddr, State: stack.Static, UpdatedAt: clock.Now()}}, neighbors, ); diff != "" { t.Fatalf("proto=%d neighbors mismatch (-want +got):\n%s", addr.proto, diff) diff --git a/pkg/tcpip/stack/tcp.go b/pkg/tcpip/stack/tcp.go index ddff6e2d6..e90c1a770 100644 --- a/pkg/tcpip/stack/tcp.go +++ b/pkg/tcpip/stack/tcp.go @@ -39,7 +39,7 @@ type TCPCubicState struct { WMax float64 // T is the time when the current congestion avoidance was entered. - T time.Time `state:".(unixTime)"` + T tcpip.MonotonicTime // TimeSinceLastCongestion denotes the time since the current // congestion avoidance was entered. @@ -78,7 +78,7 @@ type TCPCubicState struct { type TCPRACKState struct { // XmitTime is the transmission timestamp of the most recent // acknowledged segment. - XmitTime time.Time `state:".(unixTime)"` + XmitTime tcpip.MonotonicTime // EndSequence is the ending TCP sequence number of the most recent // acknowledged segment. @@ -216,7 +216,7 @@ type TCPRTTState struct { // +stateify savable type TCPSenderState struct { // LastSendTime is the timestamp at which we sent the last segment. - LastSendTime time.Time `state:".(unixTime)"` + LastSendTime tcpip.MonotonicTime // DupAckCount is the number of Duplicate ACKs received. It is used for // fast retransmit. @@ -256,7 +256,7 @@ type TCPSenderState struct { RTTMeasureSeqNum seqnum.Value // RTTMeasureTime is the time when the RTTMeasureSeqNum was sent. - RTTMeasureTime time.Time `state:".(unixTime)"` + RTTMeasureTime tcpip.MonotonicTime // Closed indicates that the caller has closed the endpoint for // sending. @@ -313,7 +313,7 @@ type TCPSACKInfo struct { type RcvBufAutoTuneParams struct { // MeasureTime is the time at which the current measurement was // started. - MeasureTime time.Time `state:".(unixTime)"` + MeasureTime tcpip.MonotonicTime // CopiedBytes is the number of bytes copied to user space since this // measure began. @@ -341,7 +341,7 @@ type RcvBufAutoTuneParams struct { // RTTMeasureTime is the absolute time at which the current RTT // measurement period began. - RTTMeasureTime time.Time `state:".(unixTime)"` + RTTMeasureTime tcpip.MonotonicTime // Disabled is true if an explicit receive buffer is set for the // endpoint. @@ -429,7 +429,7 @@ type TCPEndpointState struct { ID TCPEndpointID // SegTime denotes the absolute time when this segment was received. - SegTime time.Time `state:".(unixTime)"` + SegTime tcpip.MonotonicTime // RcvBufState contains information about the state of the endpoint's // receive socket buffer. diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 80ad1a9d4..8a8454a6a 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -16,8 +16,6 @@ package stack import ( "fmt" - "math/rand" - "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" @@ -223,7 +221,7 @@ func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto t return multiPortEp.singleRegisterEndpoint(t, flags) } -func (epsByNIC *endpointsByNIC) checkEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { +func (epsByNIC *endpointsByNIC) checkEndpoint(flags ports.Flags, bindToDevice tcpip.NICID) tcpip.Error { epsByNIC.mu.RLock() defer epsByNIC.mu.RUnlock() @@ -475,7 +473,7 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol if !ok { epsByNIC = &endpointsByNIC{ endpoints: make(map[tcpip.NICID]*multiPortEndpoint), - seed: rand.Uint32(), + seed: d.stack.Seed(), } eps.endpoints[id] = epsByNIC } @@ -502,7 +500,7 @@ func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNum return nil } - return epsByNIC.checkEndpoint(d, netProto, protocol, flags, bindToDevice) + return epsByNIC.checkEndpoint(flags, bindToDevice) } // unregisterEndpoint unregisters the endpoint with the given id such that it diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 4848495c9..0972c94de 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -18,6 +18,7 @@ import ( "io/ioutil" "math" "math/rand" + "strconv" "testing" "gvisor.dev/gvisor/pkg/tcpip" @@ -84,7 +85,8 @@ func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICI } type headers struct { - srcPort, dstPort uint16 + srcPort uint16 + dstPort uint16 } func newPayload() []byte { @@ -208,7 +210,7 @@ func TestBindToDeviceDistribution(t *testing.T) { reuse bool bindToDevice tcpip.NICID } - for _, test := range []struct { + tcs := []struct { name string // endpoints will received the inject packets. endpoints []endpointSockopts @@ -217,29 +219,29 @@ func TestBindToDeviceDistribution(t *testing.T) { wantDistributions map[tcpip.NICID][]float64 }{ { - "BindPortReuse", + name: "BindPortReuse", // 5 endpoints that all have reuse set. - []endpointSockopts{ + endpoints: []endpointSockopts{ {reuse: true, bindToDevice: 0}, {reuse: true, bindToDevice: 0}, {reuse: true, bindToDevice: 0}, {reuse: true, bindToDevice: 0}, {reuse: true, bindToDevice: 0}, }, - map[tcpip.NICID][]float64{ + wantDistributions: map[tcpip.NICID][]float64{ // Injected packets on dev0 get distributed evenly. 1: {0.2, 0.2, 0.2, 0.2, 0.2}, }, }, { - "BindToDevice", + name: "BindToDevice", // 3 endpoints with various bindings. - []endpointSockopts{ + endpoints: []endpointSockopts{ {reuse: false, bindToDevice: 1}, {reuse: false, bindToDevice: 2}, {reuse: false, bindToDevice: 3}, }, - map[tcpip.NICID][]float64{ + wantDistributions: map[tcpip.NICID][]float64{ // Injected packets on dev0 go only to the endpoint bound to dev0. 1: {1, 0, 0}, // Injected packets on dev1 go only to the endpoint bound to dev1. @@ -249,9 +251,9 @@ func TestBindToDeviceDistribution(t *testing.T) { }, }, { - "ReuseAndBindToDevice", + name: "ReuseAndBindToDevice", // 6 endpoints with various bindings. - []endpointSockopts{ + endpoints: []endpointSockopts{ {reuse: true, bindToDevice: 1}, {reuse: true, bindToDevice: 1}, {reuse: true, bindToDevice: 2}, @@ -259,7 +261,7 @@ func TestBindToDeviceDistribution(t *testing.T) { {reuse: true, bindToDevice: 2}, {reuse: true, bindToDevice: 0}, }, - map[tcpip.NICID][]float64{ + wantDistributions: map[tcpip.NICID][]float64{ // Injected packets on dev0 get distributed among endpoints bound to // dev0. 1: {0.5, 0.5, 0, 0, 0, 0}, @@ -270,35 +272,42 @@ func TestBindToDeviceDistribution(t *testing.T) { 1000: {0, 0, 0, 0, 0, 1}, }, }, - } { - for protoName, netProtoNum := range map[string]tcpip.NetworkProtocolNumber{ - "IPv4": ipv4.ProtocolNumber, - "IPv6": ipv6.ProtocolNumber, - } { + } + protos := map[string]tcpip.NetworkProtocolNumber{ + "IPv4": ipv4.ProtocolNumber, + "IPv6": ipv6.ProtocolNumber, + } + + for _, test := range tcs { + for protoName, protoNum := range protos { for device, wantDistribution := range test.wantDistributions { - t.Run(test.name+protoName+string(device), func(t *testing.T) { + t.Run(test.name+protoName+"-"+strconv.Itoa(int(device)), func(t *testing.T) { + // Create the NICs. var devices []tcpip.NICID for d := range test.wantDistributions { devices = append(devices, d) } c := newDualTestContextMultiNIC(t, defaultMTU, devices) + // Create endpoints and bind each to a NIC, sometimes reusing ports. eps := make(map[tcpip.Endpoint]int) - pollChannel := make(chan tcpip.Endpoint) for i, endpoint := range test.endpoints { // Try to receive the data. wq := waiter.Queue{} we, ch := waiter.NewChannelEntry(nil) wq.EventRegister(&we, waiter.ReadableEvents) - defer wq.EventUnregister(&we) - defer close(ch) + t.Cleanup(func() { + wq.EventUnregister(&we) + close(ch) + }) var err tcpip.Error - ep, err := c.s.NewEndpoint(udp.ProtocolNumber, netProtoNum, &wq) + ep, err := c.s.NewEndpoint(udp.ProtocolNumber, protoNum, &wq) if err != nil { t.Fatalf("NewEndpoint failed: %s", err) } + t.Cleanup(ep.Close) eps[ep] = i go func(ep tcpip.Endpoint) { @@ -307,32 +316,34 @@ func TestBindToDeviceDistribution(t *testing.T) { } }(ep) - defer ep.Close() ep.SocketOptions().SetReusePort(endpoint.reuse) if err := ep.SocketOptions().SetBindToDevice(int32(endpoint.bindToDevice)); err != nil { t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", endpoint.bindToDevice, endpoint.bindToDevice, i, err) } var dstAddr tcpip.Address - switch netProtoNum { + switch protoNum { case ipv4.ProtocolNumber: dstAddr = testDstAddrV4 case ipv6.ProtocolNumber: dstAddr = testDstAddrV6 default: - t.Fatalf("unexpected protocol number: %d", netProtoNum) + t.Fatalf("unexpected protocol number: %d", protoNum) } if err := ep.Bind(tcpip.FullAddress{Addr: dstAddr, Port: testDstPort}); err != nil { t.Fatalf("ep.Bind(...) on endpoint %d failed: %s", i, err) } } - npackets := 100000 - nports := 10000 + // Send packets across a range of ports, checking that packets from + // the same source port are always demultiplexed to the same + // destination endpoint. + npackets := 10_000 + nports := 1_000 if got, want := len(test.endpoints), len(wantDistribution); got != want { t.Fatalf("got len(test.endpoints) = %d, want %d", got, want) } - ports := make(map[uint16]tcpip.Endpoint) + endpoints := make(map[uint16]tcpip.Endpoint) stats := make(map[tcpip.Endpoint]int) for i := 0; i < npackets; i++ { // Send a packet. @@ -342,13 +353,13 @@ func TestBindToDeviceDistribution(t *testing.T) { srcPort: testSrcPort + port, dstPort: testDstPort, } - switch netProtoNum { + switch protoNum { case ipv4.ProtocolNumber: c.sendV4Packet(payload, hdrs, device) case ipv6.ProtocolNumber: c.sendV6Packet(payload, hdrs, device) default: - t.Fatalf("unexpected protocol number: %d", netProtoNum) + t.Fatalf("unexpected protocol number: %d", protoNum) } ep := <-pollChannel @@ -357,11 +368,11 @@ func TestBindToDeviceDistribution(t *testing.T) { } stats[ep]++ if i < nports { - ports[uint16(i)] = ep + endpoints[uint16(i)] = ep } else { // Check that all packets from one client are handled by the same // socket. - if want, got := ports[port], ep; want != got { + if want, got := endpoints[port], ep; want != got { t.Fatalf("Packet sent on port %d expected on endpoint %d but received on endpoint %d", port, eps[want], eps[got]) } } diff --git a/pkg/tcpip/stdclock.go b/pkg/tcpip/stdclock.go index 7ce43a68e..371da2f40 100644 --- a/pkg/tcpip/stdclock.go +++ b/pkg/tcpip/stdclock.go @@ -60,11 +60,11 @@ type stdClock struct { // monotonicOffset is assigned maxMonotonic after restore so that the // monotonic time will continue from where it "left off" before saving as part // of S/R. - monotonicOffset int64 `state:"nosave"` + monotonicOffset MonotonicTime `state:"nosave"` // monotonicMU protects maxMonotonic. monotonicMU sync.Mutex `state:"nosave"` - maxMonotonic int64 + maxMonotonic MonotonicTime } // NewStdClock returns an instance of a clock that uses the time package. @@ -76,25 +76,25 @@ func NewStdClock() Clock { var _ Clock = (*stdClock)(nil) -// NowNanoseconds implements Clock.NowNanoseconds. -func (*stdClock) NowNanoseconds() int64 { - return time.Now().UnixNano() +// Now implements Clock.Now. +func (*stdClock) Now() time.Time { + return time.Now() } // NowMonotonic implements Clock.NowMonotonic. -func (s *stdClock) NowMonotonic() int64 { +func (s *stdClock) NowMonotonic() MonotonicTime { sinceBase := time.Since(s.baseTime) if sinceBase < 0 { panic(fmt.Sprintf("got negative duration = %s since base time = %s", sinceBase, s.baseTime)) } - monotonicValue := sinceBase.Nanoseconds() + s.monotonicOffset + monotonicValue := s.monotonicOffset.Add(sinceBase) s.monotonicMU.Lock() defer s.monotonicMU.Unlock() // Monotonic time values must never decrease. - if monotonicValue > s.maxMonotonic { + if s.maxMonotonic.Before(monotonicValue) { s.maxMonotonic = monotonicValue } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 797778e08..8f2658f64 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -64,17 +64,47 @@ func (e *ErrSaveRejection) Error() string { return "save rejected due to unsupported networking state: " + e.Err.Error() } +// MonotonicTime is a monotonic clock reading. +// +// +stateify savable +type MonotonicTime struct { + nanoseconds int64 +} + +// Before reports whether the monotonic clock reading mt is before u. +func (mt MonotonicTime) Before(u MonotonicTime) bool { + return mt.nanoseconds < u.nanoseconds +} + +// After reports whether the monotonic clock reading mt is after u. +func (mt MonotonicTime) After(u MonotonicTime) bool { + return mt.nanoseconds > u.nanoseconds +} + +// Add returns the monotonic clock reading mt+d. +func (mt MonotonicTime) Add(d time.Duration) MonotonicTime { + return MonotonicTime{ + nanoseconds: time.Unix(0, mt.nanoseconds).Add(d).Sub(time.Unix(0, 0)).Nanoseconds(), + } +} + +// Sub returns the duration mt-u. If the result exceeds the maximum (or minimum) +// value that can be stored in a Duration, the maximum (or minimum) duration +// will be returned. To compute t-d for a duration d, use t.Add(-d). +func (mt MonotonicTime) Sub(u MonotonicTime) time.Duration { + return time.Unix(0, mt.nanoseconds).Sub(time.Unix(0, u.nanoseconds)) +} + // A Clock provides the current time and schedules work for execution. // // Times returned by a Clock should always be used for application-visible // time. Only monotonic times should be used for netstack internal timekeeping. type Clock interface { - // NowNanoseconds returns the current real time as a number of - // nanoseconds since the Unix epoch. - NowNanoseconds() int64 + // Now returns the current local time. + Now() time.Time - // NowMonotonic returns a monotonic time value at nanosecond resolution. - NowMonotonic() int64 + // NowMonotonic returns the current monotonic clock reading. + NowMonotonic() MonotonicTime // AfterFunc waits for the duration to elapse and then calls f in its own // goroutine. It returns a Timer that can be used to cancel the call using @@ -435,11 +465,11 @@ type ControlMessages struct { // PacketOwner is used to get UID and GID of the packet. type PacketOwner interface { - // UID returns UID of the packet. - UID() uint32 + // UID returns KUID of the packet. + KUID() uint32 - // GID returns GID of the packet. - GID() uint32 + // GID returns KGID of the packet. + KGID() uint32 } // ReadOptions contains options for Endpoint.Read. @@ -861,6 +891,9 @@ type SettableSocketOption interface { isSettableSocketOption() } +// EndpointState represents the state of an endpoint. +type EndpointState uint8 + // CongestionControlState indicates the current congestion control state for // TCP sender. type CongestionControlState int @@ -897,6 +930,9 @@ type TCPInfoOption struct { // RTO is the retransmission timeout for the endpoint. RTO time.Duration + // State is the current endpoint protocol state. + State EndpointState + // CcState is the congestion control state. CcState CongestionControlState @@ -1552,6 +1588,10 @@ type IPForwardingStats struct { // were too big for the outgoing MTU. PacketTooBig *StatCounter + // HostUnreachable is the number of IP packets received which could not be + // successfully forwarded due to an unresolvable next hop. + HostUnreachable *StatCounter + // ExtensionHeaderProblem is the number of IP packets which were dropped // because of a problem encountered when processing an IPv6 extension // header. @@ -1835,37 +1875,104 @@ type UDPStats struct { ChecksumErrors *StatCounter } +// NICNeighborStats holds metrics for the neighbor table. +type NICNeighborStats struct { + // LINT.IfChange(NICNeighborStats) + + // UnreachableEntryLookups counts the number of lookups performed on an + // entry in Unreachable state. + UnreachableEntryLookups *StatCounter + + // LINT.ThenChange(stack/nic_stats.go:multiCounterNICNeighborStats) +} + +// NICPacketStats holds basic packet statistics. +type NICPacketStats struct { + // LINT.IfChange(NICPacketStats) + + // Packets is the number of packets counted. + Packets *StatCounter + + // Bytes is the number of bytes counted. + Bytes *StatCounter + + // LINT.ThenChange(stack/nic_stats.go:multiCounterNICPacketStats) +} + +// NICStats holds NIC statistics. +type NICStats struct { + // LINT.IfChange(NICStats) + + // UnknownL3ProtocolRcvdPackets is the number of packets received that were + // for an unknown or unsupported network protocol. + UnknownL3ProtocolRcvdPackets *StatCounter + + // UnknownL4ProtocolRcvdPackets is the number of packets received that were + // for an unknown or unsupported transport protocol. + UnknownL4ProtocolRcvdPackets *StatCounter + + // MalformedL4RcvdPackets is the number of packets received by a NIC that + // could not be delivered to a transport endpoint because the L4 header could + // not be parsed. + MalformedL4RcvdPackets *StatCounter + + // Tx contains statistics about transmitted packets. + Tx NICPacketStats + + // Rx contains statistics about received packets. + Rx NICPacketStats + + // DisabledRx contains statistics about received packets on disabled NICs. + DisabledRx NICPacketStats + + // Neighbor contains statistics about neighbor entries. + Neighbor NICNeighborStats + + // LINT.ThenChange(stack/nic_stats.go:multiCounterNICStats) +} + +// FillIn returns a copy of s with nil fields initialized to new StatCounters. +func (s NICStats) FillIn() NICStats { + InitStatCounters(reflect.ValueOf(&s).Elem()) + return s +} + // Stats holds statistics about the networking stack. -// -// All fields are optional. type Stats struct { - // UnknownProtocolRcvdPackets is the number of packets received by the - // stack that were for an unknown or unsupported protocol. - UnknownProtocolRcvdPackets *StatCounter + // TODO(https://gvisor.dev/issues/5986): Make the DroppedPackets stat less + // ambiguous. - // MalformedRcvdPackets is the number of packets received by the stack - // that were deemed malformed. - MalformedRcvdPackets *StatCounter - - // DroppedPackets is the number of packets dropped due to full queues. + // DroppedPackets is the number of packets dropped at the transport layer. DroppedPackets *StatCounter - // ICMP breaks out ICMP-specific stats (both v4 and v6). + // NICs is an aggregation of every NIC's statistics. These should not be + // incremented using this field, but using the relevant NIC multicounters. + NICs NICStats + + // ICMP is an aggregation of every NetworkEndpoint's ICMP statistics (both v4 + // and v6). These should not be incremented using this field, but using the + // relevant NetworkEndpoint ICMP multicounters. ICMP ICMPStats - // IGMP breaks out IGMP-specific stats. + // IGMP is an aggregation of every NetworkEndpoint's IGMP statistics. These + // should not be incremented using this field, but using the relevant + // NetworkEndpoint IGMP multicounters. IGMP IGMPStats - // IP breaks out IP-specific stats (both v4 and v6). + // IP is an aggregation of every NetworkEndpoint's IP statistics. These should + // not be incremented using this field, but using the relevant NetworkEndpoint + // IP multicounters. IP IPStats - // ARP breaks out ARP-specific stats. + // ARP is an aggregation of every NetworkEndpoint's ARP statistics. These + // should not be incremented using this field, but using the relevant + // NetworkEndpoint ARP multicounters. ARP ARPStats - // TCP breaks out TCP-specific stats. + // TCP holds TCP-specific stats. TCP TCPStats - // UDP breaks out UDP-specific stats. + // UDP holds UDP-specific stats. UDP UDPStats } diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go index 269081ff8..c96ae2f02 100644 --- a/pkg/tcpip/tcpip_test.go +++ b/pkg/tcpip/tcpip_test.go @@ -19,7 +19,6 @@ import ( "fmt" "io" "net" - "strings" "testing" "github.com/google/go-cmp/cmp" @@ -210,26 +209,6 @@ func TestAddressString(t *testing.T) { } } -func TestStatsString(t *testing.T) { - got := fmt.Sprintf("%+v", Stats{}.FillIn()) - - matchers := []string{ - // Print root-level stats correctly. - "UnknownProtocolRcvdPackets:0", - // Print protocol-specific stats correctly. - "TCP:{ActiveConnectionOpenings:0", - } - - for _, m := range matchers { - if !strings.Contains(got, m) { - t.Errorf("string.Contains(got, %q) = false", m) - } - } - if t.Failed() { - t.Logf(`got = fmt.Sprintf("%%+v", Stats{}.FillIn()) = %q`, got) - } -} - func TestAddressWithPrefixSubnet(t *testing.T) { tests := []struct { addr Address diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index ab2dab60c..181ef799e 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -53,12 +53,14 @@ go_test( "//pkg/tcpip/checker", "//pkg/tcpip/faketime", "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", "//pkg/tcpip/link/pipe", "//pkg/tcpip/network/arp", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", "//pkg/tcpip/tests/utils", + "//pkg/tcpip/testutil", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go index 42bc53328..92fa6257d 100644 --- a/pkg/tcpip/tests/integration/forward_test.go +++ b/pkg/tcpip/tests/integration/forward_test.go @@ -16,6 +16,7 @@ package forward_test import ( "bytes" + "fmt" "testing" "github.com/google/go-cmp/cmp" @@ -34,6 +35,39 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +const ttl = 64 + +var ( + ipv4GlobalMulticastAddr = testutil.MustParse4("224.0.1.10") + ipv6GlobalMulticastAddr = testutil.MustParse6("ff0e::a") +) + +func rxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) { + utils.RxICMPv4EchoRequest(e, src, dst, ttl) +} + +func rxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) { + utils.RxICMPv6EchoRequest(e, src, dst, ttl) +} + +func forwardedICMPv4EchoRequestChecker(t *testing.T, b []byte, src, dst tcpip.Address) { + checker.IPv4(t, b, + checker.SrcAddr(src), + checker.DstAddr(dst), + checker.TTL(ttl-1), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4Echo))) +} + +func forwardedICMPv6EchoRequestChecker(t *testing.T, b []byte, src, dst tcpip.Address) { + checker.IPv6(t, b, + checker.SrcAddr(src), + checker.DstAddr(dst), + checker.TTL(ttl-1), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6EchoRequest))) +} + func TestForwarding(t *testing.T) { const listenPort = 8080 @@ -320,45 +354,16 @@ func TestMulticastForwarding(t *testing.T) { const ( nicID1 = 1 nicID2 = 2 - ttl = 64 ) var ( ipv4LinkLocalUnicastAddr = testutil.MustParse4("169.254.0.10") ipv4LinkLocalMulticastAddr = testutil.MustParse4("224.0.0.10") - ipv4GlobalMulticastAddr = testutil.MustParse4("224.0.1.10") ipv6LinkLocalUnicastAddr = testutil.MustParse6("fe80::a") ipv6LinkLocalMulticastAddr = testutil.MustParse6("ff02::a") - ipv6GlobalMulticastAddr = testutil.MustParse6("ff0e::a") ) - rxICMPv4EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) { - utils.RxICMPv4EchoRequest(e, src, dst, ttl) - } - - rxICMPv6EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) { - utils.RxICMPv6EchoRequest(e, src, dst, ttl) - } - - v4Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) { - checker.IPv4(t, b, - checker.SrcAddr(src), - checker.DstAddr(dst), - checker.TTL(ttl-1), - checker.ICMPv4( - checker.ICMPv4Type(header.ICMPv4Echo))) - } - - v6Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) { - checker.IPv6(t, b, - checker.SrcAddr(src), - checker.DstAddr(dst), - checker.TTL(ttl-1), - checker.ICMPv6( - checker.ICMPv6Type(header.ICMPv6EchoRequest))) - } - tests := []struct { name string srcAddr, dstAddr tcpip.Address @@ -394,7 +399,7 @@ func TestMulticastForwarding(t *testing.T) { rx: rxICMPv4EchoRequest, expectForward: true, checker: func(t *testing.T, b []byte) { - v4Checker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address) + forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address) }, }, { @@ -404,7 +409,7 @@ func TestMulticastForwarding(t *testing.T) { rx: rxICMPv4EchoRequest, expectForward: true, checker: func(t *testing.T, b []byte) { - v4Checker(t, b, utils.RemoteIPv4Addr, ipv4GlobalMulticastAddr) + forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, ipv4GlobalMulticastAddr) }, }, @@ -436,7 +441,7 @@ func TestMulticastForwarding(t *testing.T) { rx: rxICMPv6EchoRequest, expectForward: true, checker: func(t *testing.T, b []byte) { - v6Checker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address) + forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address) }, }, { @@ -446,7 +451,7 @@ func TestMulticastForwarding(t *testing.T) { rx: rxICMPv6EchoRequest, expectForward: true, checker: func(t *testing.T, b []byte) { - v6Checker(t, b, utils.RemoteIPv6Addr, ipv6GlobalMulticastAddr) + forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, ipv6GlobalMulticastAddr) }, }, } @@ -506,3 +511,180 @@ func TestMulticastForwarding(t *testing.T) { }) } } + +func TestPerInterfaceForwarding(t *testing.T) { + const ( + nicID1 = 1 + nicID2 = 2 + ) + + tests := []struct { + name string + srcAddr, dstAddr tcpip.Address + rx func(*channel.Endpoint, tcpip.Address, tcpip.Address) + checker func(*testing.T, []byte) + }{ + { + name: "IPv4 unicast", + srcAddr: utils.RemoteIPv4Addr, + dstAddr: utils.Ipv4Addr2.AddressWithPrefix.Address, + rx: rxICMPv4EchoRequest, + checker: func(t *testing.T, b []byte) { + forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address) + }, + }, + { + name: "IPv4 multicast", + srcAddr: utils.RemoteIPv4Addr, + dstAddr: ipv4GlobalMulticastAddr, + rx: rxICMPv4EchoRequest, + checker: func(t *testing.T, b []byte) { + forwardedICMPv4EchoRequestChecker(t, b, utils.RemoteIPv4Addr, ipv4GlobalMulticastAddr) + }, + }, + + { + name: "IPv6 unicast", + srcAddr: utils.RemoteIPv6Addr, + dstAddr: utils.Ipv6Addr2.AddressWithPrefix.Address, + rx: rxICMPv6EchoRequest, + checker: func(t *testing.T, b []byte) { + forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address) + }, + }, + { + name: "IPv6 multicast", + srcAddr: utils.RemoteIPv6Addr, + dstAddr: ipv6GlobalMulticastAddr, + rx: rxICMPv6EchoRequest, + checker: func(t *testing.T, b []byte) { + forwardedICMPv6EchoRequestChecker(t, b, utils.RemoteIPv6Addr, ipv6GlobalMulticastAddr) + }, + }, + } + + netProtos := [...]tcpip.NetworkProtocolNumber{ipv4.ProtocolNumber, ipv6.ProtocolNumber} + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ + // ARP is not used in this test but it is a network protocol that does + // not support forwarding. We install the protocol to make sure that + // forwarding information for a NIC is only reported for network + // protocols that support forwarding. + arp.NewProtocol, + + ipv4.NewProtocol, + ipv6.NewProtocol, + }, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + }) + + e1 := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID1, e1); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID1, err) + } + + e2 := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID2, e2); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID2, err) + } + + for _, add := range [...]struct { + nicID tcpip.NICID + addr tcpip.ProtocolAddress + }{ + { + nicID: nicID1, + addr: utils.RouterNIC1IPv4Addr, + }, + { + nicID: nicID1, + addr: utils.RouterNIC1IPv6Addr, + }, + { + nicID: nicID2, + addr: utils.RouterNIC2IPv4Addr, + }, + { + nicID: nicID2, + addr: utils.RouterNIC2IPv6Addr, + }, + } { + if err := s.AddProtocolAddress(add.nicID, add.addr); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", add.nicID, add.addr, err) + } + } + + // Only enable forwarding on NIC1 and make sure that only packets arriving + // on NIC1 are forwarded. + for _, netProto := range netProtos { + if err := s.SetNICForwarding(nicID1, netProto, true); err != nil { + t.Fatalf("s.SetNICForwarding(%d, %d, true): %s", nicID1, netProtos, err) + } + } + + nicsInfo := s.NICInfo() + for _, subTest := range [...]struct { + nicID tcpip.NICID + nicEP *channel.Endpoint + otherNICID tcpip.NICID + otherNICEP *channel.Endpoint + expectForwarding bool + }{ + { + nicID: nicID1, + nicEP: e1, + otherNICID: nicID2, + otherNICEP: e2, + expectForwarding: true, + }, + { + nicID: nicID2, + nicEP: e2, + otherNICID: nicID2, + otherNICEP: e1, + expectForwarding: false, + }, + } { + t.Run(fmt.Sprintf("Packet arriving at NIC%d", subTest.nicID), func(t *testing.T) { + nicInfo, ok := nicsInfo[subTest.nicID] + if !ok { + t.Errorf("expected NIC info for NIC %d; got = %#v", subTest.nicID, nicsInfo) + } else { + forwarding := make(map[tcpip.NetworkProtocolNumber]bool) + for _, netProto := range netProtos { + forwarding[netProto] = subTest.expectForwarding + } + + if diff := cmp.Diff(forwarding, nicInfo.Forwarding); diff != "" { + t.Errorf("nicsInfo[%d].Forwarding mismatch (-want +got):\n%s", subTest.nicID, diff) + } + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv4EmptySubnet, + NIC: subTest.otherNICID, + }, + { + Destination: header.IPv6EmptySubnet, + NIC: subTest.otherNICID, + }, + }) + + test.rx(subTest.nicEP, test.srcAddr, test.dstAddr) + if p, ok := subTest.nicEP.Read(); ok { + t.Errorf("unexpectedly got a response from the interface the packet arrived on: %#v", p) + } + if p, ok := subTest.otherNICEP.Read(); ok != subTest.expectForwarding { + t.Errorf("got otherNICEP.Read() = (%#v, %t), want = (_, %t)", p, ok, subTest.expectForwarding) + } else if subTest.expectForwarding { + test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader())) + } + }) + } + }) + } +} diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go index 07ba2b837..f9ab7d0af 100644 --- a/pkg/tcpip/tests/integration/iptables_test.go +++ b/pkg/tcpip/tests/integration/iptables_test.go @@ -166,7 +166,7 @@ func TestIPTablesStatsForInput(t *testing.T) { // Make sure the packet is not dropped by the next rule. filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err) } }, genPacket: genPacketV6, @@ -187,7 +187,7 @@ func TestIPTablesStatsForInput(t *testing.T) { filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{nicName}} filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err) } }, genPacket: genPacketV4, @@ -207,7 +207,7 @@ func TestIPTablesStatsForInput(t *testing.T) { filter.Rules[ruleIdx].Target = &stack.DropTarget{} filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err) } }, genPacket: genPacketV6, @@ -227,7 +227,7 @@ func TestIPTablesStatsForInput(t *testing.T) { filter.Rules[ruleIdx].Target = &stack.DropTarget{} filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err) } }, genPacket: genPacketV4, @@ -250,7 +250,7 @@ func TestIPTablesStatsForInput(t *testing.T) { filter.Rules[ruleIdx].Target = &stack.DropTarget{} filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err) } }, genPacket: genPacketV6, @@ -273,7 +273,7 @@ func TestIPTablesStatsForInput(t *testing.T) { filter.Rules[ruleIdx].Target = &stack.DropTarget{} filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err) } }, genPacket: genPacketV4, @@ -293,7 +293,7 @@ func TestIPTablesStatsForInput(t *testing.T) { filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}} filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err) } }, genPacket: genPacketV6, @@ -313,7 +313,7 @@ func TestIPTablesStatsForInput(t *testing.T) { filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}} filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err) } }, genPacket: genPacketV4, @@ -465,7 +465,7 @@ func TestIPTableWritePackets(t *testing.T) { } if err := s.IPTables().ReplaceTable(stack.FilterID, table, false /* ipv4 */); err != nil { - t.Fatalf("RelaceTable(%d, _, false): %s", stack.FilterID, err) + t.Fatalf("ReplaceTable(%d, _, false): %s", stack.FilterID, err) } }, genPacket: func(r *stack.Route) stack.PacketBufferList { @@ -556,7 +556,7 @@ func TestIPTableWritePackets(t *testing.T) { } if err := s.IPTables().ReplaceTable(stack.FilterID, table, true /* ipv6 */); err != nil { - t.Fatalf("RelaceTable(%d, _, true): %s", stack.FilterID, err) + t.Fatalf("ReplaceTable(%d, _, true): %s", stack.FilterID, err) } }, genPacket: func(r *stack.Route) stack.PacketBufferList { @@ -681,6 +681,32 @@ func forwardedICMPv6EchoReplyChecker(t *testing.T, b []byte, src, dst tcpip.Addr checker.ICMPv6Type(header.ICMPv6EchoReply))) } +func boolToInt(v bool) uint64 { + if v { + return 1 + } + return 0 +} + +func setupDropFilter(hook stack.Hook, f stack.IPHeaderFilter) func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { + return func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber) { + t.Helper() + + ipv6 := netProto == ipv6.ProtocolNumber + + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, ipv6) + ruleIdx := filter.BuiltinChains[hook] + filter.Rules[ruleIdx].Filter = f + filter.Rules[ruleIdx].Target = &stack.DropTarget{NetworkProtocol: netProto} + // Make sure the packet is not dropped by the next rule. + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{NetworkProtocol: netProto} + if err := ipt.ReplaceTable(stack.FilterID, filter, ipv6); err != nil { + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, ipv6, err) + } + } +} + func TestForwardingHook(t *testing.T) { const ( nicID1 = 1 @@ -740,32 +766,6 @@ func TestForwardingHook(t *testing.T) { }, } - setupDropFilter := func(f stack.IPHeaderFilter) func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { - return func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber) { - t.Helper() - - ipv6 := netProto == ipv6.ProtocolNumber - - ipt := s.IPTables() - filter := ipt.GetTable(stack.FilterID, ipv6) - ruleIdx := filter.BuiltinChains[stack.Forward] - filter.Rules[ruleIdx].Filter = f - filter.Rules[ruleIdx].Target = &stack.DropTarget{NetworkProtocol: netProto} - // Make sure the packet is not dropped by the next rule. - filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{NetworkProtocol: netProto} - if err := ipt.ReplaceTable(stack.FilterID, filter, ipv6); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, ipv6, err) - } - } - } - - boolToInt := func(v bool) uint64 { - if v { - return 1 - } - return 0 - } - subTests := []struct { name string setupFilter func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) @@ -779,59 +779,59 @@ func TestForwardingHook(t *testing.T) { { name: "Drop", - setupFilter: setupDropFilter(stack.IPHeaderFilter{}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{}), expectForward: false, }, { name: "Drop with input NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name}), expectForward: false, }, { name: "Drop with output NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: nic2Name}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: nic2Name}), expectForward: false, }, { name: "Drop with input and output NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: nic2Name}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: nic2Name}), expectForward: false, }, { name: "Drop with other input NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName}), expectForward: true, }, { name: "Drop with other output NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: otherNICName}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: otherNICName}), expectForward: true, }, { name: "Drop with other input and output NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: nic2Name}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: nic2Name}), expectForward: true, }, { name: "Drop with input and other output NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: otherNICName}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: otherNICName}), expectForward: true, }, { name: "Drop with other input and other output NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: otherNICName}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: otherNICName}), expectForward: true, }, { name: "Drop with inverted input NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, InputInterfaceInvert: true}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, InputInterfaceInvert: true}), expectForward: true, }, { name: "Drop with inverted output NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: nic2Name, OutputInterfaceInvert: true}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: nic2Name, OutputInterfaceInvert: true}), expectForward: true, }, } @@ -941,3 +941,194 @@ func TestForwardingHook(t *testing.T) { }) } } + +func TestInputHookWithLocalForwarding(t *testing.T) { + const ( + nicID1 = 1 + nicID2 = 2 + + nic1Name = "nic1" + nic2Name = "nic2" + + otherNICName = "otherNIC" + ) + + tests := []struct { + name string + netProto tcpip.NetworkProtocolNumber + rx func(*channel.Endpoint) + checker func(*testing.T, []byte) + }{ + { + name: "IPv4", + netProto: ipv4.ProtocolNumber, + rx: func(e *channel.Endpoint) { + utils.RxICMPv4EchoRequest(e, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address, ttl) + }, + checker: func(t *testing.T, b []byte) { + checker.IPv4(t, b, + checker.SrcAddr(utils.Ipv4Addr2.AddressWithPrefix.Address), + checker.DstAddr(utils.RemoteIPv4Addr), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4EchoReply))) + }, + }, + { + name: "IPv6", + netProto: ipv6.ProtocolNumber, + rx: func(e *channel.Endpoint) { + utils.RxICMPv6EchoRequest(e, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address, ttl) + }, + checker: func(t *testing.T, b []byte) { + checker.IPv6(t, b, + checker.SrcAddr(utils.Ipv6Addr2.AddressWithPrefix.Address), + checker.DstAddr(utils.RemoteIPv6Addr), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6EchoReply))) + }, + }, + } + + subTests := []struct { + name string + setupFilter func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) + expectDrop bool + }{ + { + name: "Accept", + setupFilter: func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { /* no filter */ }, + expectDrop: false, + }, + + { + name: "Drop", + setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{}), + expectDrop: true, + }, + { + name: "Drop with input NIC filtering on arrival NIC", + setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: nic1Name}), + expectDrop: true, + }, + { + name: "Drop with input NIC filtering on delivered NIC", + setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: nic2Name}), + expectDrop: false, + }, + + { + name: "Drop with input NIC filtering on other NIC", + setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: otherNICName}), + expectDrop: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, subTest := range subTests { + t.Run(subTest.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + }) + + subTest.setupFilter(t, s, test.netProto) + + e1 := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNICWithOptions(nicID1, e1, stack.NICOptions{Name: nic1Name}); err != nil { + t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID1, err) + } + if err := s.AddProtocolAddress(nicID1, utils.Ipv4Addr1); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID1, utils.Ipv4Addr1, err) + } + if err := s.AddProtocolAddress(nicID1, utils.Ipv6Addr1); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID1, utils.Ipv6Addr1, err) + } + + e2 := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNICWithOptions(nicID2, e2, stack.NICOptions{Name: nic2Name}); err != nil { + t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err) + } + if err := s.AddProtocolAddress(nicID2, utils.Ipv4Addr2); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID2, utils.Ipv4Addr2, err) + } + if err := s.AddProtocolAddress(nicID2, utils.Ipv6Addr2); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID2, utils.Ipv6Addr2, err) + } + + if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err) + } + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv4EmptySubnet, + NIC: nicID1, + }, + { + Destination: header.IPv6EmptySubnet, + NIC: nicID1, + }, + }) + + test.rx(e1) + + ep1, err := s.GetNetworkEndpoint(nicID1, test.netProto) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, test.netProto, err) + } + ep1Stats := ep1.Stats() + ipEP1Stats, ok := ep1Stats.(stack.IPNetworkEndpointStats) + if !ok { + t.Fatalf("got ep1Stats = %T, want = stack.IPNetworkEndpointStats", ep1Stats) + } + ip1Stats := ipEP1Stats.IPStats() + + if got := ip1Stats.PacketsReceived.Value(); got != 1 { + t.Errorf("got ip1Stats.PacketsReceived.Value() = %d, want = 1", got) + } + if got := ip1Stats.ValidPacketsReceived.Value(); got != 1 { + t.Errorf("got ip1Stats.ValidPacketsReceived.Value() = %d, want = 1", got) + } + if got, want := ip1Stats.PacketsSent.Value(), boolToInt(!subTest.expectDrop); got != want { + t.Errorf("got ip1Stats.PacketsSent.Value() = %d, want = %d", got, want) + } + + ep2, err := s.GetNetworkEndpoint(nicID2, test.netProto) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID2, test.netProto, err) + } + ep2Stats := ep2.Stats() + ipEP2Stats, ok := ep2Stats.(stack.IPNetworkEndpointStats) + if !ok { + t.Fatalf("got ep2Stats = %T, want = stack.IPNetworkEndpointStats", ep2Stats) + } + ip2Stats := ipEP2Stats.IPStats() + if got := ip2Stats.PacketsReceived.Value(); got != 0 { + t.Errorf("got ip2Stats.PacketsReceived.Value() = %d, want = 0", got) + } + if got := ip2Stats.ValidPacketsReceived.Value(); got != 1 { + t.Errorf("got ip2Stats.ValidPacketsReceived.Value() = %d, want = 1", got) + } + if got, want := ip2Stats.IPTablesInputDropped.Value(), boolToInt(subTest.expectDrop); got != want { + t.Errorf("got ip2Stats.IPTablesInputDropped.Value() = %d, want = %d", got, want) + } + if got := ip2Stats.PacketsSent.Value(); got != 0 { + t.Errorf("got ip2Stats.PacketsSent.Value() = %d, want = 0", got) + } + + if p, ok := e1.Read(); ok == subTest.expectDrop { + t.Errorf("got e1.Read() = (%#v, %t), want = (_, %t)", p, ok, !subTest.expectDrop) + } else if !subTest.expectDrop { + test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader())) + } + if p, ok := e2.Read(); ok { + t.Errorf("got e1.Read() = (%#v, true), want = (_, false)", p) + } + }) + } + }) + } +} diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index c657714ba..27caa0c28 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -17,6 +17,8 @@ package link_resolution_test import ( "bytes" "fmt" + "net" + "runtime" "testing" "time" @@ -27,12 +29,14 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/pipe" "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/tests/utils" + tcptestutil "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" @@ -283,9 +287,11 @@ func TestTCPLinkResolutionFailure(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() stackOpts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + Clock: clock, } host1Stack, host2Stack := setupStack(t, stackOpts, host1NICID, host2NICID) @@ -329,7 +335,17 @@ func TestTCPLinkResolutionFailure(t *testing.T) { // Wait for an error due to link resolution failing, or the endpoint to be // writable. + if test.expectedWriteErr != nil { + nudConfigs, err := host1Stack.NUDConfigurations(host1NICID, test.netProto) + if err != nil { + t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", host1NICID, test.netProto, err) + } + clock.Advance(time.Duration(nudConfigs.MaxMulticastProbes) * nudConfigs.RetransmitTimer) + } else { + clock.RunImmediatelyScheduledJobs() + } <-ch + { var r bytes.Reader r.Reset([]byte{0}) @@ -395,6 +411,242 @@ func TestTCPLinkResolutionFailure(t *testing.T) { } } +func TestForwardingWithLinkResolutionFailure(t *testing.T) { + const ( + incomingNICID = 1 + outgoingNICID = 2 + ttl = 2 + expectedHostUnreachableErrorCount = 1 + ) + outgoingLinkAddr := tcptestutil.MustParseLink("02:03:03:04:05:06") + + rxICMPv4EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) { + utils.RxICMPv4EchoRequest(e, src, dst, ttl) + } + + rxICMPv6EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) { + utils.RxICMPv6EchoRequest(e, src, dst, ttl) + } + + arpChecker := func(t *testing.T, request channel.PacketInfo, src, dst tcpip.Address) { + if request.Proto != arp.ProtocolNumber { + t.Errorf("got request.Proto = %d, want = %d", request.Proto, arp.ProtocolNumber) + } + if request.Route.RemoteLinkAddress != header.EthernetBroadcastAddress { + t.Errorf("got request.Route.RemoteLinkAddress = %s, want = %s", request.Route.RemoteLinkAddress, header.EthernetBroadcastAddress) + } + rep := header.ARP(request.Pkt.NetworkHeader().View()) + if got := rep.Op(); got != header.ARPRequest { + t.Errorf("got Op() = %d, want = %d", got, header.ARPRequest) + } + if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != outgoingLinkAddr { + t.Errorf("got HardwareAddressSender = %s, want = %s", got, outgoingLinkAddr) + } + if got := tcpip.Address(rep.ProtocolAddressSender()); got != src { + t.Errorf("got ProtocolAddressSender = %s, want = %s", got, src) + } + if got := tcpip.Address(rep.ProtocolAddressTarget()); got != dst { + t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, dst) + } + } + + ndpChecker := func(t *testing.T, request channel.PacketInfo, src, dst tcpip.Address) { + if request.Proto != header.IPv6ProtocolNumber { + t.Fatalf("got Proto = %d, want = %d", request.Proto, header.IPv6ProtocolNumber) + } + + snmc := header.SolicitedNodeAddr(dst) + if want := header.EthernetAddressFromMulticastIPv6Address(snmc); request.Route.RemoteLinkAddress != want { + t.Errorf("got remote link address = %s, want = %s", request.Route.RemoteLinkAddress, want) + } + + checker.IPv6(t, stack.PayloadSince(request.Pkt.NetworkHeader()), + checker.SrcAddr(src), + checker.DstAddr(snmc), + checker.TTL(header.NDPHopLimit), + checker.NDPNS( + checker.NDPNSTargetAddress(dst), + )) + } + + icmpv4Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) { + checker.IPv4(t, b, + checker.SrcAddr(src), + checker.DstAddr(dst), + checker.TTL(ipv4.DefaultTTL), + checker.ICMPv4( + checker.ICMPv4Checksum(), + checker.ICMPv4Type(header.ICMPv4DstUnreachable), + checker.ICMPv4Code(header.ICMPv4HostUnreachable), + ), + ) + } + + icmpv6Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) { + checker.IPv6(t, b, + checker.SrcAddr(src), + checker.DstAddr(dst), + checker.TTL(ipv6.DefaultTTL), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6DstUnreachable), + checker.ICMPv6Code(header.ICMPv6AddressUnreachable), + ), + ) + } + + tests := []struct { + name string + networkProtocolFactory []stack.NetworkProtocolFactory + networkProtocolNumber tcpip.NetworkProtocolNumber + sourceAddr tcpip.Address + destAddr tcpip.Address + incomingAddr tcpip.AddressWithPrefix + outgoingAddr tcpip.AddressWithPrefix + transportProtocol func(*stack.Stack) stack.TransportProtocol + rx func(*channel.Endpoint, tcpip.Address, tcpip.Address) + linkResolutionRequestChecker func(*testing.T, channel.PacketInfo, tcpip.Address, tcpip.Address) + icmpReplyChecker func(*testing.T, []byte, tcpip.Address, tcpip.Address) + mtu uint32 + }{ + { + name: "IPv4 Host unreachable", + networkProtocolFactory: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, + networkProtocolNumber: header.IPv4ProtocolNumber, + sourceAddr: tcptestutil.MustParse4("10.0.0.2"), + destAddr: tcptestutil.MustParse4("11.0.0.2"), + incomingAddr: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()), + PrefixLen: 8, + }, + outgoingAddr: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("11.0.0.1").To4()), + PrefixLen: 8, + }, + transportProtocol: icmp.NewProtocol4, + linkResolutionRequestChecker: arpChecker, + icmpReplyChecker: icmpv4Checker, + rx: rxICMPv4EchoRequest, + mtu: ipv4.MaxTotalSize, + }, + { + name: "IPv6 Host unreachable", + networkProtocolFactory: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, + networkProtocolNumber: header.IPv6ProtocolNumber, + sourceAddr: tcptestutil.MustParse6("10::2"), + destAddr: tcptestutil.MustParse6("11::2"), + incomingAddr: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("10::1").To16()), + PrefixLen: 64, + }, + outgoingAddr: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("11::1").To16()), + PrefixLen: 64, + }, + transportProtocol: icmp.NewProtocol6, + linkResolutionRequestChecker: ndpChecker, + icmpReplyChecker: icmpv6Checker, + rx: rxICMPv6EchoRequest, + mtu: header.IPv6MinimumMTU, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + + s := stack.New(stack.Options{ + NetworkProtocols: test.networkProtocolFactory, + TransportProtocols: []stack.TransportProtocolFactory{test.transportProtocol}, + Clock: clock, + }) + + // Set up endpoint through which we will receive packets. + incomingEndpoint := channel.New(1, test.mtu, "") + if err := s.CreateNIC(incomingNICID, incomingEndpoint); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", incomingNICID, err) + } + incomingProtoAddr := tcpip.ProtocolAddress{ + Protocol: test.networkProtocolNumber, + AddressWithPrefix: test.incomingAddr, + } + if err := s.AddProtocolAddress(incomingNICID, incomingProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", incomingNICID, incomingProtoAddr, err) + } + + // Set up endpoint through which we will attempt to forward packets. + outgoingEndpoint := channel.New(1, test.mtu, outgoingLinkAddr) + outgoingEndpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired + if err := s.CreateNIC(outgoingNICID, outgoingEndpoint); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", outgoingNICID, err) + } + outgoingProtoAddr := tcpip.ProtocolAddress{ + Protocol: test.networkProtocolNumber, + AddressWithPrefix: test.outgoingAddr, + } + if err := s.AddProtocolAddress(outgoingNICID, outgoingProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", outgoingNICID, outgoingProtoAddr, err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: test.incomingAddr.Subnet(), + NIC: incomingNICID, + }, + { + Destination: test.outgoingAddr.Subnet(), + NIC: outgoingNICID, + }, + }) + + if err := s.SetForwardingDefaultAndAllNICs(test.networkProtocolNumber, true); err != nil { + t.Fatalf("SetForwardingDefaultAndAllNICs(%d, true): %s", test.networkProtocolNumber, err) + } + + test.rx(incomingEndpoint, test.sourceAddr, test.destAddr) + + nudConfigs, err := s.NUDConfigurations(outgoingNICID, test.networkProtocolNumber) + if err != nil { + t.Fatalf("s.NUDConfigurations(%d, %d): %s", outgoingNICID, test.networkProtocolNumber, err) + } + // Trigger the first packet on the endpoint. + clock.RunImmediatelyScheduledJobs() + + for i := 0; i < int(nudConfigs.MaxMulticastProbes); i++ { + request, ok := outgoingEndpoint.Read() + if !ok { + t.Fatal("expected ARP packet through outgoing NIC") + } + + test.linkResolutionRequestChecker(t, request, test.outgoingAddr.Address, test.destAddr) + + // Advance the clock the span of one request timeout. + clock.Advance(nudConfigs.RetransmitTimer) + } + + // Next, we make a blocking read to retrieve the error packet. This is + // necessary because outgoing packets are dequeued asynchronously when + // link resolution fails, and this dequeue is what triggers the ICMP + // error. + reply, ok := incomingEndpoint.Read() + if !ok { + t.Fatal("expected ICMP packet through incoming NIC") + } + + test.icmpReplyChecker(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), test.incomingAddr.Address, test.sourceAddr) + + // Since link resolution failed, we don't expect the packet to be + // forwarded. + forwardedPacket, ok := outgoingEndpoint.Read() + if ok { + t.Fatalf("expected no ICMP Echo packet through outgoing NIC, instead found: %#v", forwardedPacket) + } + + if got, want := s.Stats().IP.Forwarding.HostUnreachable.Value(), expectedHostUnreachableErrorCount; int(got) != want { + t.Errorf("got rt.Stats().IP.Forwarding.HostUnreachable.Value() = %d, want = %d", got, want) + } + }) + } +} + func TestGetLinkAddress(t *testing.T) { const ( host1NICID = 1 @@ -449,8 +701,10 @@ func TestGetLinkAddress(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() stackOpts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + Clock: clock, } host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID) @@ -466,8 +720,20 @@ func TestGetLinkAddress(t *testing.T) { if test.expectedErr == nil { wantRes.LinkAddress = utils.LinkAddr2 } - if diff := cmp.Diff(wantRes, <-ch); diff != "" { - t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff) + + nudConfigs, err := host1Stack.NUDConfigurations(host1NICID, test.netProto) + if err != nil { + t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", host1NICID, test.netProto, err) + } + + clock.Advance(time.Duration(nudConfigs.MaxMulticastProbes) * nudConfigs.RetransmitTimer) + select { + case got := <-ch: + if diff := cmp.Diff(wantRes, got); diff != "" { + t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("event didn't arrive") } }) } @@ -544,8 +810,10 @@ func TestRouteResolvedFields(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() stackOpts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + Clock: clock, } host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID) @@ -575,8 +843,20 @@ func TestRouteResolvedFields(t *testing.T) { if _, ok := err.(*tcpip.ErrWouldBlock); !ok { t.Errorf("got r.ResolvedFields(_) = %s, want = %s", err, &tcpip.ErrWouldBlock{}) } - if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Err: test.expectedErr}, <-ch, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" { - t.Errorf("route resolve result mismatch (-want +got):\n%s", diff) + + nudConfigs, err := host1Stack.NUDConfigurations(host1NICID, test.netProto) + if err != nil { + t.Fatalf("host1Stack.NUDConfigurations(%d, %d): %s", host1NICID, test.netProto, err) + } + clock.Advance(time.Duration(nudConfigs.MaxMulticastProbes) * nudConfigs.RetransmitTimer) + + select { + case got := <-ch: + if diff := cmp.Diff(stack.ResolvedFieldsResult{RouteInfo: wantRouteInfo, Err: test.expectedErr}, got, cmp.AllowUnexported(stack.RouteInfo{})); diff != "" { + t.Errorf("route resolve result mismatch (-want +got):\n%s", diff) + } + default: + t.Fatalf("event didn't arrive") } if test.expectedErr != nil { @@ -789,11 +1069,16 @@ func (d *nudDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.Neighbo d.c <- e } -func (d *nudDispatcher) waitForEvent(want eventInfo) error { - if diff := cmp.Diff(want, <-d.c, cmp.AllowUnexported(eventInfo{}), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAtNanos")); diff != "" { - return fmt.Errorf("got invalid event (-want +got):\n%s", diff) +func (d *nudDispatcher) expectEvent(want eventInfo) error { + select { + case got := <-d.c: + if diff := cmp.Diff(want, got, cmp.AllowUnexported(eventInfo{}), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAt")); diff != "" { + return fmt.Errorf("got invalid event (-want +got):\n%s", diff) + } + return nil + default: + return fmt.Errorf("event didn't arrive") } - return nil } // TestTCPConfirmNeighborReachability tests that TCP informs layers beneath it @@ -804,7 +1089,7 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { netProto tcpip.NetworkProtocolNumber remoteAddr tcpip.Address neighborAddr tcpip.Address - getEndpoints func(*testing.T, *stack.Stack, *stack.Stack, *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) + getEndpoints func(*testing.T, *stack.Stack, *stack.Stack, *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) isHost1Listener bool }{ { @@ -812,23 +1097,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { netProto: ipv4.ProtocolNumber, remoteAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address, neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { var listenerWQ waiter.Queue + listenerWE, listenerCH := waiter.NewChannelEntry(nil) + listenerWQ.EventRegister(&listenerWE, waiter.EventIn) listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) if err != nil { t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) } + t.Cleanup(listenerEP.Close) var clientWQ waiter.Queue clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.WritableEvents) + clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) if err != nil { - listenerEP.Close() t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) } - return listenerEP, clientEP, clientCH + return listenerEP, listenerCH, clientEP, clientCH }, }, { @@ -836,23 +1123,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { netProto: ipv6.ProtocolNumber, remoteAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address, neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { var listenerWQ waiter.Queue + listenerWE, listenerCH := waiter.NewChannelEntry(nil) + listenerWQ.EventRegister(&listenerWE, waiter.EventIn) listenerEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) if err != nil { t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) } + t.Cleanup(listenerEP.Close) var clientWQ waiter.Queue clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.WritableEvents) + clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) if err != nil { - listenerEP.Close() t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) } - return listenerEP, clientEP, clientCH + return listenerEP, listenerCH, clientEP, clientCH }, }, { @@ -860,23 +1149,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { netProto: ipv4.ProtocolNumber, remoteAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { var listenerWQ waiter.Queue + listenerWE, listenerCH := waiter.NewChannelEntry(nil) + listenerWQ.EventRegister(&listenerWE, waiter.EventIn) listenerEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) if err != nil { t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) } + t.Cleanup(listenerEP.Close) var clientWQ waiter.Queue clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.WritableEvents) + clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) if err != nil { - listenerEP.Close() t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) } - return listenerEP, clientEP, clientCH + return listenerEP, listenerCH, clientEP, clientCH }, }, { @@ -884,23 +1175,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { netProto: ipv6.ProtocolNumber, remoteAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { var listenerWQ waiter.Queue + listenerWE, listenerCH := waiter.NewChannelEntry(nil) + listenerWQ.EventRegister(&listenerWE, waiter.EventIn) listenerEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) if err != nil { t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) } + t.Cleanup(listenerEP.Close) var clientWQ waiter.Queue clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.WritableEvents) + clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) clientEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) if err != nil { - listenerEP.Close() t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) } - return listenerEP, clientEP, clientCH + return listenerEP, listenerCH, clientEP, clientCH }, }, { @@ -908,23 +1201,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { netProto: ipv4.ProtocolNumber, remoteAddr: utils.Host1IPv4Addr.AddressWithPrefix.Address, neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { var listenerWQ waiter.Queue + listenerWE, listenerCH := waiter.NewChannelEntry(nil) + listenerWQ.EventRegister(&listenerWE, waiter.EventIn) listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) if err != nil { t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) } + t.Cleanup(listenerEP.Close) var clientWQ waiter.Queue clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.WritableEvents) + clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) clientEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) if err != nil { - listenerEP.Close() t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) } - return listenerEP, clientEP, clientCH + return listenerEP, listenerCH, clientEP, clientCH }, isHost1Listener: true, }, @@ -933,23 +1228,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { netProto: ipv6.ProtocolNumber, remoteAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address, neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + getEndpoints: func(t *testing.T, host1Stack, routerStack, _ *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { var listenerWQ waiter.Queue + listenerWE, listenerCH := waiter.NewChannelEntry(nil) + listenerWQ.EventRegister(&listenerWE, waiter.EventIn) listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) if err != nil { t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) } + t.Cleanup(listenerEP.Close) var clientWQ waiter.Queue clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.WritableEvents) + clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) clientEP, err := routerStack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) if err != nil { - listenerEP.Close() t.Fatalf("routerStack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) } - return listenerEP, clientEP, clientCH + return listenerEP, listenerCH, clientEP, clientCH }, isHost1Listener: true, }, @@ -958,23 +1255,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { netProto: ipv4.ProtocolNumber, remoteAddr: utils.Host1IPv4Addr.AddressWithPrefix.Address, neighborAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { var listenerWQ waiter.Queue + listenerWE, listenerCH := waiter.NewChannelEntry(nil) + listenerWQ.EventRegister(&listenerWE, waiter.EventIn) listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) if err != nil { t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) } + t.Cleanup(listenerEP.Close) var clientWQ waiter.Queue clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.WritableEvents) + clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) clientEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) if err != nil { - listenerEP.Close() t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv4.ProtocolNumber, err) } - return listenerEP, clientEP, clientCH + return listenerEP, listenerCH, clientEP, clientCH }, isHost1Listener: true, }, @@ -983,23 +1282,25 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { netProto: ipv6.ProtocolNumber, remoteAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address, neighborAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, - getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, tcpip.Endpoint, <-chan struct{}) { + getEndpoints: func(t *testing.T, host1Stack, _, host2Stack *stack.Stack) (tcpip.Endpoint, <-chan struct{}, tcpip.Endpoint, <-chan struct{}) { var listenerWQ waiter.Queue + listenerWE, listenerCH := waiter.NewChannelEntry(nil) + listenerWQ.EventRegister(&listenerWE, waiter.EventIn) listenerEP, err := host1Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &listenerWQ) if err != nil { t.Fatalf("host1Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) } + t.Cleanup(listenerEP.Close) var clientWQ waiter.Queue clientWE, clientCH := waiter.NewChannelEntry(nil) - clientWQ.EventRegister(&clientWE, waiter.WritableEvents) + clientWQ.EventRegister(&clientWE, waiter.ReadableEvents|waiter.WritableEvents) clientEP, err := host2Stack.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &clientWQ) if err != nil { - listenerEP.Close() t.Fatalf("host2Stack.NewEndpoint(%d, %d, _): %s", tcp.ProtocolNumber, ipv6.ProtocolNumber, err) } - return listenerEP, clientEP, clientCH + return listenerEP, listenerCH, clientEP, clientCH }, isHost1Listener: true, }, @@ -1037,14 +1338,14 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { t.Fatalf("link resolution mismatch (-want +got):\n%s", diff) } } - if err := nudDisp.waitForEvent(eventInfo{ + if err := nudDisp.expectEvent(eventInfo{ eventType: entryAdded, nicID: utils.Host1NICID, entry: stack.NeighborEntry{State: stack.Incomplete, Addr: test.neighborAddr}, }); err != nil { t.Fatalf("error waiting for initial NUD event: %s", err) } - if err := nudDisp.waitForEvent(eventInfo{ + if err := nudDisp.expectEvent(eventInfo{ eventType: entryChanged, nicID: utils.Host1NICID, entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, @@ -1064,7 +1365,7 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { // See NUDConfigurations.BaseReachableTime for more information. maxReachableTime := time.Duration(float32(nudConfigs.BaseReachableTime) * nudConfigs.MaxRandomFactor) clock.Advance(maxReachableTime) - if err := nudDisp.waitForEvent(eventInfo{ + if err := nudDisp.expectEvent(eventInfo{ eventType: entryChanged, nicID: utils.Host1NICID, entry: stack.NeighborEntry{State: stack.Stale, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, @@ -1072,8 +1373,7 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { t.Fatalf("error waiting for stale NUD event: %s", err) } - listenerEP, clientEP, clientCH := test.getEndpoints(t, host1Stack, routerStack, host2Stack) - defer listenerEP.Close() + listenerEP, listenerCH, clientEP, clientCH := test.getEndpoints(t, host1Stack, routerStack, host2Stack) defer clientEP.Close() listenerAddr := tcpip.FullAddress{Addr: test.remoteAddr, Port: 1234} if err := listenerEP.Bind(listenerAddr); err != nil { @@ -1094,14 +1394,15 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { // with confirmation that the neighbor is reachable (indicated by a // successful 3-way handshake). <-clientCH - if err := nudDisp.waitForEvent(eventInfo{ + if err := nudDisp.expectEvent(eventInfo{ eventType: entryChanged, nicID: utils.Host1NICID, entry: stack.NeighborEntry{State: stack.Delay, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, }); err != nil { t.Fatalf("error waiting for delay NUD event: %s", err) } - if err := nudDisp.waitForEvent(eventInfo{ + <-listenerCH + if err := nudDisp.expectEvent(eventInfo{ eventType: entryChanged, nicID: utils.Host1NICID, entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, @@ -1109,26 +1410,55 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { t.Fatalf("error waiting for reachable NUD event: %s", err) } + peerEP, peerWQ, err := listenerEP.Accept(nil) + if err != nil { + t.Fatalf("listenerEP.Accept(): %s", err) + } + defer peerEP.Close() + peerWE, peerCH := waiter.NewChannelEntry(nil) + peerWQ.EventRegister(&peerWE, waiter.ReadableEvents) + // Wait for the neighbor to be stale again then send data to the remote. // // On successful transmission, the neighbor should become reachable // without probing the neighbor as a TCP ACK would be received which is an // indication of the neighbor being reachable. clock.Advance(maxReachableTime) - if err := nudDisp.waitForEvent(eventInfo{ + if err := nudDisp.expectEvent(eventInfo{ eventType: entryChanged, nicID: utils.Host1NICID, entry: stack.NeighborEntry{State: stack.Stale, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, }); err != nil { t.Fatalf("error waiting for stale NUD event: %s", err) } - var r bytes.Reader - r.Reset([]byte{0}) - var wOpts tcpip.WriteOptions - if _, err := clientEP.Write(&r, wOpts); err != nil { - t.Errorf("clientEP.Write(_, %#v): %s", wOpts, err) + { + var r bytes.Reader + r.Reset([]byte{0}) + var wOpts tcpip.WriteOptions + if _, err := clientEP.Write(&r, wOpts); err != nil { + t.Errorf("clientEP.Write(_, %#v): %s", wOpts, err) + } + } + // Heads up, there is a race here. + // + // Incoming TCP segments are handled in + // tcp.(*endpoint).handleSegmentLocked: + // + // - tcp.(*endpoint).rcv.handleRcvdSegment puts the segment on the + // segment queue and notifies waiting readers (such as this channel) + // + // - tcp.(*endpoint).snd.handleRcvdSegment sends an ACK for the segment + // and notifies the NUD machinery that the peer is reachable + // + // Thus we must permit a delay between the readable signal and the + // expected NUD event. + // + // At the time of writing, this race is reliably hit with gotsan. + <-peerCH + for len(nudDisp.c) == 0 { + runtime.Gosched() } - if err := nudDisp.waitForEvent(eventInfo{ + if err := nudDisp.expectEvent(eventInfo{ eventType: entryChanged, nicID: utils.Host1NICID, entry: stack.NeighborEntry{State: stack.Delay, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, @@ -1141,7 +1471,7 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { // TCP should not mark the route reachable and NUD should go through the // probe state. clock.Advance(nudConfigs.DelayFirstProbeTime) - if err := nudDisp.waitForEvent(eventInfo{ + if err := nudDisp.expectEvent(eventInfo{ eventType: entryChanged, nicID: utils.Host1NICID, entry: stack.NeighborEntry{State: stack.Probe, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, @@ -1149,7 +1479,16 @@ func TestTCPConfirmNeighborReachability(t *testing.T) { t.Fatalf("error waiting for probe NUD event: %s", err) } } - if err := nudDisp.waitForEvent(eventInfo{ + { + var r bytes.Reader + r.Reset([]byte{0}) + var wOpts tcpip.WriteOptions + if _, err := peerEP.Write(&r, wOpts); err != nil { + t.Errorf("peerEP.Write(_, %#v): %s", wOpts, err) + } + } + <-clientCH + if err := nudDisp.expectEvent(eventInfo{ eventType: entryChanged, nicID: utils.Host1NICID, entry: stack.NeighborEntry{State: stack.Reachable, Addr: test.neighborAddr, LinkAddr: utils.LinkAddr2}, diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index 87d36e1dd..b4b2ec723 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -44,20 +44,17 @@ type ndpDispatcher struct{} func (*ndpDispatcher) OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult) { } -func (*ndpDispatcher) OnDefaultRouterDiscovered(tcpip.NICID, tcpip.Address) bool { - return false +func (*ndpDispatcher) OnOffLinkRouteUpdated(tcpip.NICID, tcpip.Subnet, tcpip.Address) { } -func (*ndpDispatcher) OnDefaultRouterInvalidated(tcpip.NICID, tcpip.Address) {} +func (*ndpDispatcher) OnOffLinkRouteInvalidated(tcpip.NICID, tcpip.Subnet, tcpip.Address) {} -func (*ndpDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool { - return false +func (*ndpDispatcher) OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) { } func (*ndpDispatcher) OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) {} -func (*ndpDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool { - return true +func (*ndpDispatcher) OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) { } func (*ndpDispatcher) OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix) {} diff --git a/pkg/tcpip/testutil/testutil.go b/pkg/tcpip/testutil/testutil.go index f84d399fb..94b580a70 100644 --- a/pkg/tcpip/testutil/testutil.go +++ b/pkg/tcpip/testutil/testutil.go @@ -109,3 +109,15 @@ func ValidateMultiCounterStats(multi reflect.Value, counters []reflect.Value) er return nil } + +// MustParseLink parses a Link string into a tcpip.LinkAddress, panicking on +// error. +// +// The string must be in the format aa:bb:cc:dd:ee:ff or aa-bb-cc-dd-ee-ff. +func MustParseLink(addr string) tcpip.LinkAddress { + parsed, err := tcpip.ParseMACAddress(addr) + if err != nil { + panic(fmt.Sprintf("tcpip.ParseMACAddress(%s): %s", addr, err)) + } + return parsed +} diff --git a/pkg/tcpip/time.s b/pkg/tcpip/time.s deleted file mode 100644 index fb37360ac..000000000 --- a/pkg/tcpip/time.s +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Empty assembly file so empty func definitions work. diff --git a/pkg/tcpip/timer_test.go b/pkg/tcpip/timer_test.go index 1633d0aeb..ed1ed8ac6 100644 --- a/pkg/tcpip/timer_test.go +++ b/pkg/tcpip/timer_test.go @@ -15,6 +15,7 @@ package tcpip_test import ( + "math" "sync" "testing" "time" @@ -22,10 +23,85 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" ) +func TestMonotonicTimeBefore(t *testing.T) { + var mt tcpip.MonotonicTime + if mt.Before(mt) { + t.Errorf("%#v.Before(%#v)", mt, mt) + } + + one := mt.Add(1) + if one.Before(mt) { + t.Errorf("%#v.Before(%#v)", one, mt) + } + if !mt.Before(one) { + t.Errorf("!%#v.Before(%#v)", mt, one) + } +} + +func TestMonotonicTimeAfter(t *testing.T) { + var mt tcpip.MonotonicTime + if mt.After(mt) { + t.Errorf("%#v.After(%#v)", mt, mt) + } + + one := mt.Add(1) + if mt.After(one) { + t.Errorf("%#v.After(%#v)", mt, one) + } + if !one.After(mt) { + t.Errorf("!%#v.After(%#v)", one, mt) + } +} + +func TestMonotonicTimeAddSub(t *testing.T) { + var mt tcpip.MonotonicTime + if one, two := mt.Add(2), mt.Add(1).Add(1); one != two { + t.Errorf("mt.Add(2) != mt.Add(1).Add(1) (%#v != %#v)", one, two) + } + + min := mt.Add(math.MinInt64) + max := mt.Add(math.MaxInt64) + + if overflow := mt.Add(1).Add(math.MaxInt64); overflow != max { + t.Errorf("mt.Add(math.MaxInt64) != mt.Add(1).Add(math.MaxInt64) (%#v != %#v)", max, overflow) + } + if underflow := mt.Add(-1).Add(math.MinInt64); underflow != min { + t.Errorf("mt.Add(math.MinInt64) != mt.Add(-1).Add(math.MinInt64) (%#v != %#v)", min, underflow) + } + + if got, want := min.Sub(min), time.Duration(0); want != got { + t.Errorf("got min.Sub(min) = %d, want %d", got, want) + } + if got, want := max.Sub(max), time.Duration(0); want != got { + t.Errorf("got max.Sub(max) = %d, want %d", got, want) + } + + if overflow, want := max.Sub(min), time.Duration(math.MaxInt64); overflow != want { + t.Errorf("mt.Add(math.MaxInt64).Sub(mt.Add(math.MinInt64) != %s (%#v)", want, overflow) + } + if underflow, want := min.Sub(max), time.Duration(math.MinInt64); underflow != want { + t.Errorf("mt.Add(math.MinInt64).Sub(mt.Add(math.MaxInt64) != %s (%#v)", want, underflow) + } +} + +func TestMonotonicTimeSub(t *testing.T) { + var mt tcpip.MonotonicTime + + if one, two := mt.Add(2), mt.Add(1).Add(1); one != two { + t.Errorf("mt.Add(2) != mt.Add(1).Add(1) (%#v != %#v)", one, two) + } + + if max, overflow := mt.Add(math.MaxInt64), mt.Add(1).Add(math.MaxInt64); max != overflow { + t.Errorf("mt.Add(math.MaxInt64) != mt.Add(1).Add(math.MaxInt64) (%#v != %#v)", max, overflow) + } + if max, underflow := mt.Add(math.MinInt64), mt.Add(-1).Add(math.MinInt64); max != underflow { + t.Errorf("mt.Add(math.MinInt64) != mt.Add(-1).Add(math.MinInt64) (%#v != %#v)", max, underflow) + } +} + const ( shortDuration = 1 * time.Nanosecond middleDuration = 100 * time.Millisecond - longDuration = 1 * time.Second ) func TestJobReschedule(t *testing.T) { @@ -53,10 +129,14 @@ func TestJobReschedule(t *testing.T) { wg.Wait() } +func stdClockWithAfter() (tcpip.Clock, func(time.Duration) <-chan time.Time) { + return tcpip.NewStdClock(), time.After +} + func TestJobExecution(t *testing.T) { t.Parallel() - clock := tcpip.NewStdClock() + clock, after := stdClockWithAfter() var lock sync.Mutex ch := make(chan struct{}) @@ -68,7 +148,7 @@ func TestJobExecution(t *testing.T) { // Wait for timer to fire. select { case <-ch: - case <-time.After(middleDuration): + case <-after(middleDuration): t.Fatal("timed out waiting for timer to fire") } @@ -76,14 +156,14 @@ func TestJobExecution(t *testing.T) { select { case <-ch: t.Fatal("no other timers should have fired") - case <-time.After(middleDuration): + case <-after(middleDuration): } } func TestCancellableTimerResetFromLongDuration(t *testing.T) { t.Parallel() - clock := tcpip.NewStdClock() + clock, after := stdClockWithAfter() var lock sync.Mutex ch := make(chan struct{}) @@ -99,7 +179,7 @@ func TestCancellableTimerResetFromLongDuration(t *testing.T) { // Wait for timer to fire. select { case <-ch: - case <-time.After(middleDuration): + case <-after(middleDuration): t.Fatal("timed out waiting for timer to fire") } @@ -107,14 +187,14 @@ func TestCancellableTimerResetFromLongDuration(t *testing.T) { select { case <-ch: t.Fatal("no other timers should have fired") - case <-time.After(middleDuration): + case <-after(middleDuration): } } func TestJobRescheduleFromShortDuration(t *testing.T) { t.Parallel() - clock := tcpip.NewStdClock() + clock, after := stdClockWithAfter() var lock sync.Mutex ch := make(chan struct{}) @@ -128,7 +208,7 @@ func TestJobRescheduleFromShortDuration(t *testing.T) { select { case <-ch: t.Fatal("timer fired after being stopped") - case <-time.After(middleDuration): + case <-after(middleDuration): } job.Schedule(shortDuration) @@ -136,7 +216,7 @@ func TestJobRescheduleFromShortDuration(t *testing.T) { // Wait for timer to fire. select { case <-ch: - case <-time.After(middleDuration): + case <-after(middleDuration): t.Fatal("timed out waiting for timer to fire") } @@ -144,14 +224,14 @@ func TestJobRescheduleFromShortDuration(t *testing.T) { select { case <-ch: t.Fatal("no other timers should have fired") - case <-time.After(middleDuration): + case <-after(middleDuration): } } func TestJobImmediatelyCancel(t *testing.T) { t.Parallel() - clock := tcpip.NewStdClock() + clock, after := stdClockWithAfter() var lock sync.Mutex ch := make(chan struct{}) @@ -167,14 +247,19 @@ func TestJobImmediatelyCancel(t *testing.T) { select { case <-ch: t.Fatal("timer fired after being stopped") - case <-time.After(middleDuration): + case <-after(middleDuration): } } +func stdClockWithAfterAndSleep() (tcpip.Clock, func(time.Duration) <-chan time.Time, func(time.Duration)) { + clock, after := stdClockWithAfter() + return clock, after, time.Sleep +} + func TestJobCancelledRescheduleWithoutLock(t *testing.T) { t.Parallel() - clock := tcpip.NewStdClock() + clock, after, sleep := stdClockWithAfterAndSleep() var lock sync.Mutex ch := make(chan struct{}) @@ -189,7 +274,7 @@ func TestJobCancelledRescheduleWithoutLock(t *testing.T) { lock.Lock() // Sleep until the timer fires and gets blocked trying to take the lock. - time.Sleep(middleDuration * 2) + sleep(middleDuration * 2) job.Cancel() lock.Unlock() } @@ -199,14 +284,14 @@ func TestJobCancelledRescheduleWithoutLock(t *testing.T) { select { case <-ch: t.Fatal("timer fired after being stopped") - case <-time.After(middleDuration * 2): + case <-after(middleDuration * 2): } } func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) { t.Parallel() - clock := tcpip.NewStdClock() + clock, after, sleep := stdClockWithAfterAndSleep() var lock sync.Mutex ch := make(chan struct{}) @@ -215,7 +300,7 @@ func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) { job.Schedule(shortDuration) for i := 0; i < 10; i++ { // Sleep until the timer fires and gets blocked trying to take the lock. - time.Sleep(middleDuration) + sleep(middleDuration) job.Cancel() job.Schedule(shortDuration) } @@ -224,7 +309,7 @@ func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) { // Wait for double the duration for the last timer to fire. select { case <-ch: - case <-time.After(middleDuration): + case <-after(middleDuration): t.Fatal("timed out waiting for timer to fire") } @@ -232,14 +317,14 @@ func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) { select { case <-ch: t.Fatal("no other timers should have fired") - case <-time.After(middleDuration): + case <-after(middleDuration): } } func TestManyJobReschedulesUnderLock(t *testing.T) { t.Parallel() - clock := tcpip.NewStdClock() + clock, after := stdClockWithAfter() var lock sync.Mutex ch := make(chan struct{}) @@ -255,7 +340,7 @@ func TestManyJobReschedulesUnderLock(t *testing.T) { // Wait for double the duration for the last timer to fire. select { case <-ch: - case <-time.After(middleDuration): + case <-after(middleDuration): t.Fatal("timed out waiting for timer to fire") } @@ -263,6 +348,6 @@ func TestManyJobReschedulesUnderLock(t *testing.T) { select { case <-ch: t.Fatal("no other timers should have fired") - case <-time.After(middleDuration): + case <-after(middleDuration): } } diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD index 7e5c79776..bbc0e3ecc 100644 --- a/pkg/tcpip/transport/icmp/BUILD +++ b/pkg/tcpip/transport/icmp/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) @@ -38,3 +38,22 @@ go_library( "//pkg/waiter", ], ) + +go_test( + name = "icmp_x_test", + size = "small", + srcs = ["icmp_test.go"], + deps = [ + ":icmp", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/checker", + "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", + "//pkg/tcpip/link/sniffer", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/stack", + "//pkg/tcpip/testutil", + "//pkg/waiter", + ], +) diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 8afde7fca..cb316d27a 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -16,6 +16,7 @@ package icmp import ( "io" + "time" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -26,14 +27,12 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -// TODO(https://gvisor.dev/issues/5623): Unit test this package. - // +stateify savable type icmpPacket struct { icmpPacketEntry senderAddress tcpip.FullAddress data buffer.VectorisedView `state:".(buffer.VectorisedView)"` - timestamp int64 + receivedAt time.Time `state:".(int64)"` } type endpointState int @@ -133,7 +132,8 @@ func (e *endpoint) Close() { e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite switch e.state { case stateBound, stateConnected: - e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, 0 /* bindToDevice */) + bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) + e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, bindToDevice) } // Close the receive list and drain it. @@ -160,7 +160,7 @@ func (e *endpoint) Close() { } // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. -func (e *endpoint) ModerateRecvBuf(copied int) {} +func (*endpoint) ModerateRecvBuf(int) {} // SetOwner implements tcpip.Endpoint.SetOwner. func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { @@ -193,7 +193,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult Total: p.data.Size(), ControlMessages: tcpip.ControlMessages{ HasTimestamp: true, - Timestamp: p.timestamp, + Timestamp: p.receivedAt.UnixNano(), }, } if opts.NeedRemoteAddr { @@ -304,6 +304,9 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp // Reject destination address if it goes through a different // NIC than the endpoint was bound to. nicID := to.NIC + if nicID == 0 { + nicID = tcpip.NICID(e.ops.GetBindToDevice()) + } if e.BindNICID != 0 { if nicID != 0 && nicID != e.BindNICID { return 0, &tcpip.ErrNoRoute{} @@ -348,8 +351,15 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp return int64(len(v)), nil } +var _ tcpip.SocketOptionsHandler = (*endpoint)(nil) + +// HasNIC implements tcpip.SocketOptionsHandler. +func (e *endpoint) HasNIC(id int32) bool { + return e.stack.HasNIC(tcpip.NICID(id)) +} + // SetSockOpt sets a socket option. -func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { +func (*endpoint) SetSockOpt(tcpip.SettableSocketOption) tcpip.Error { return nil } @@ -390,7 +400,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { +func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error { return &tcpip.ErrUnknownProtocolOption{} } @@ -606,18 +616,19 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi return nil, nil, &tcpip.ErrNotSupported{} } -func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) { +func (e *endpoint) registerWithStack(_ tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) { + bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) if id.LocalPort != 0 { // The endpoint already has a local port, just attempt to // register it. - err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindToDevice */) + err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, bindToDevice) return id, err } // We need to find a port for the endpoint. - _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, tcpip.Error) { + _, err := e.stack.PickEphemeralPort(e.stack.Rand(), func(p uint16) (bool, tcpip.Error) { id.LocalPort = p - err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindtodevice */) + err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, bindToDevice) switch err.(type) { case nil: return true, nil @@ -747,8 +758,6 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB switch e.NetProto { case header.IPv4ProtocolNumber: h := header.ICMPv4(pkt.TransportHeader().View()) - // TODO(gvisor.dev/issue/170): Determine if len(h) check is still needed - // after early parsing. if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() @@ -756,8 +765,6 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB } case header.IPv6ProtocolNumber: h := header.ICMPv6(pkt.TransportHeader().View()) - // TODO(gvisor.dev/issue/170): Determine if len(h) check is still needed - // after early parsing. if len(h) < header.ICMPv6MinimumSize || h.Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() @@ -800,7 +807,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB e.rcvList.PushBack(packet) e.rcvBufSize += packet.data.Size() - packet.timestamp = e.stack.Clock().NowNanoseconds() + packet.receivedAt = e.stack.Clock().Now() e.rcvMu.Unlock() e.stats.PacketsReceived.Increment() diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go index 28a56a2d5..b8b839e4a 100644 --- a/pkg/tcpip/transport/icmp/endpoint_state.go +++ b/pkg/tcpip/transport/icmp/endpoint_state.go @@ -15,11 +15,23 @@ package icmp import ( + "time" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" ) +// saveReceivedAt is invoked by stateify. +func (p *icmpPacket) saveReceivedAt() int64 { + return p.receivedAt.UnixNano() +} + +// loadReceivedAt is invoked by stateify. +func (p *icmpPacket) loadReceivedAt(nsec int64) { + p.receivedAt = time.Unix(0, nsec) +} + // saveData saves icmpPacket.data field. func (p *icmpPacket) saveData() buffer.VectorisedView { // We cannot save p.data directly as p.data.views may alias to p.views, diff --git a/pkg/tcpip/transport/icmp/icmp_test.go b/pkg/tcpip/transport/icmp/icmp_test.go new file mode 100644 index 000000000..cc950cbde --- /dev/null +++ b/pkg/tcpip/transport/icmp/icmp_test.go @@ -0,0 +1,235 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package icmp_test + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/testutil" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/waiter" +) + +// TODO(https://gvisor.dev/issues/5623): Finish unit testing the icmp package. +// See the issue for remaining areas of work. + +var ( + localV4Addr1 = testutil.MustParse4("10.0.0.1") + localV4Addr2 = testutil.MustParse4("10.0.0.2") + remoteV4Addr = testutil.MustParse4("10.0.0.3") +) + +func addNICWithDefaultRoute(t *testing.T, s *stack.Stack, id tcpip.NICID, name string, addrV4 tcpip.Address) *channel.Endpoint { + t.Helper() + + ep := channel.New(1 /* size */, header.IPv4MinimumMTU, "" /* linkAddr */) + t.Cleanup(ep.Close) + + wep := stack.LinkEndpoint(ep) + if testing.Verbose() { + wep = sniffer.New(ep) + } + + opts := stack.NICOptions{Name: name} + if err := s.CreateNICWithOptions(id, wep, opts); err != nil { + t.Fatalf("s.CreateNIC(%d, _) = %s", id, err) + } + + if err := s.AddAddress(id, ipv4.ProtocolNumber, addrV4); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s) = %s", id, ipv4.ProtocolNumber, addrV4, err) + } + + s.AddRoute(tcpip.Route{ + Destination: header.IPv4EmptySubnet, + NIC: id, + }) + + return ep +} + +func writePayload(buf []byte) { + for i := range buf { + buf[i] = byte(i) + } +} + +func newICMPv4EchoRequest(payloadSize uint32) buffer.View { + buf := buffer.NewView(header.ICMPv4MinimumSize + int(payloadSize)) + writePayload(buf[header.ICMPv4MinimumSize:]) + + icmp := header.ICMPv4(buf) + icmp.SetType(header.ICMPv4Echo) + // No need to set the checksum; it is reset by the socket before the packet + // is sent. + + return buf +} + +// TestWriteUnboundWithBindToDevice exercises writing to an unbound ICMP socket +// when SO_BINDTODEVICE is set to the non-default NIC for that subnet. +// +// Only IPv4 is tested. The logic to determine which NIC to use is agnostic to +// the version of IP. +func TestWriteUnboundWithBindToDevice(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, + HandleLocal: true, + }) + + // Add two NICs, both with default routes on the same subnet. The first NIC + // added will be the default NIC for that subnet. + defaultEP := addNICWithDefaultRoute(t, s, 1, "default", localV4Addr1) + alternateEP := addNICWithDefaultRoute(t, s, 2, "alternate", localV4Addr2) + + socket, err := s.NewEndpoint(icmp.ProtocolNumber4, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _) = %s", icmp.ProtocolNumber4, ipv4.ProtocolNumber, err) + } + defer socket.Close() + + echoPayloadSize := defaultEP.MTU() - header.IPv4MinimumSize - header.ICMPv4MinimumSize + + // Send a packet without SO_BINDTODEVICE. This verifies that the first NIC + // to be added is the default NIC to send packets when not explicitly bound. + { + buf := newICMPv4EchoRequest(echoPayloadSize) + r := buf.Reader() + n, err := socket.Write(&r, tcpip.WriteOptions{ + To: &tcpip.FullAddress{Addr: remoteV4Addr}, + }) + if err != nil { + t.Fatalf("socket.Write(_, {To:%s}) = %s", remoteV4Addr, err) + } + if n != int64(len(buf)) { + t.Fatalf("got n = %d, want n = %d", n, len(buf)) + } + + // Verify the packet was sent out the default NIC. + p, ok := defaultEP.Read() + if !ok { + t.Fatalf("got defaultEP.Read(_) = _, false; want = _, true (packet wasn't written out)") + } + + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + b := vv.ToView() + + checker.IPv4(t, b, []checker.NetworkChecker{ + checker.SrcAddr(localV4Addr1), + checker.DstAddr(remoteV4Addr), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4Echo), + checker.ICMPv4Payload(buf[header.ICMPv4MinimumSize:]), + ), + }...) + + // Verify the packet was not sent out the alternate NIC. + if p, ok := alternateEP.Read(); ok { + t.Fatalf("got alternateEP.Read(_) = %+v, true; want = _, false", p) + } + } + + // Send a packet with SO_BINDTODEVICE. This exercises reliance on + // SO_BINDTODEVICE to route the packet to the alternate NIC. + { + // Use SO_BINDTODEVICE to send over the alternate NIC by default. + socket.SocketOptions().SetBindToDevice(2) + + buf := newICMPv4EchoRequest(echoPayloadSize) + r := buf.Reader() + n, err := socket.Write(&r, tcpip.WriteOptions{ + To: &tcpip.FullAddress{Addr: remoteV4Addr}, + }) + if err != nil { + t.Fatalf("socket.Write(_, {To:%s}) = %s", tcpip.Address(remoteV4Addr), err) + } + if n != int64(len(buf)) { + t.Fatalf("got n = %d, want n = %d", n, len(buf)) + } + + // Verify the packet was not sent out the default NIC. + if p, ok := defaultEP.Read(); ok { + t.Fatalf("got defaultEP.Read(_) = %+v, true; want = _, false", p) + } + + // Verify the packet was sent out the alternate NIC. + p, ok := alternateEP.Read() + if !ok { + t.Fatalf("got alternateEP.Read(_) = _, false; want = _, true (packet wasn't written out)") + } + + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + b := vv.ToView() + + checker.IPv4(t, b, []checker.NetworkChecker{ + checker.SrcAddr(localV4Addr2), + checker.DstAddr(remoteV4Addr), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4Echo), + checker.ICMPv4Payload(buf[header.ICMPv4MinimumSize:]), + ), + }...) + } + + // Send a packet with SO_BINDTODEVICE cleared. This verifies that clearing + // the device binding will fallback to using the default NIC to send + // packets. + { + socket.SocketOptions().SetBindToDevice(0) + + buf := newICMPv4EchoRequest(echoPayloadSize) + r := buf.Reader() + n, err := socket.Write(&r, tcpip.WriteOptions{ + To: &tcpip.FullAddress{Addr: remoteV4Addr}, + }) + if err != nil { + t.Fatalf("socket.Write(_, {To:%s}) = %s", tcpip.Address(remoteV4Addr), err) + } + if n != int64(len(buf)) { + t.Fatalf("got n = %d, want n = %d", n, len(buf)) + } + + // Verify the packet was sent out the default NIC. + p, ok := defaultEP.Read() + if !ok { + t.Fatalf("got defaultEP.Read(_) = _, false; want = _, true (packet wasn't written out)") + } + + vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views()) + b := vv.ToView() + + checker.IPv4(t, b, []checker.NetworkChecker{ + checker.SrcAddr(localV4Addr1), + checker.DstAddr(remoteV4Addr), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4Echo), + checker.ICMPv4Payload(buf[header.ICMPv4MinimumSize:]), + ), + }...) + + // Verify the packet was not sent out the alternate NIC. + if p, ok := alternateEP.Read(); ok { + t.Fatalf("got alternateEP.Read(_) = %+v, true; want = _, false", p) + } + } +} diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 47f7dd1cb..fa82affc1 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -123,8 +123,6 @@ func (*protocol) Wait() {} // Parse implements stack.TransportProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) bool { - // TODO(gvisor.dev/issue/170): Implement parsing of ICMP. - // // Right now, the Parse() method is tied to enabled protocols passed into // stack.New. This works for UDP and TCP, but we handle ICMP traffic even // when netstack users don't pass ICMP as a supported protocol. diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 496eca581..cd8c99d41 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -27,6 +27,7 @@ package packet import ( "fmt" "io" + "time" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -41,9 +42,8 @@ type packet struct { packetEntry // data holds the actual packet data, including any headers and // payload. - data buffer.VectorisedView `state:".(buffer.VectorisedView)"` - // timestampNS is the unix time at which the packet was received. - timestampNS int64 + data buffer.VectorisedView `state:".(buffer.VectorisedView)"` + receivedAt time.Time `state:".(int64)"` // senderAddr is the network address of the sender. senderAddr tcpip.FullAddress // packetInfo holds additional information like the protocol @@ -159,7 +159,7 @@ func (ep *endpoint) Close() { } // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. -func (ep *endpoint) ModerateRecvBuf(copied int) {} +func (*endpoint) ModerateRecvBuf(int) {} // Read implements tcpip.Endpoint.Read. func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) { @@ -189,7 +189,7 @@ func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResul Total: packet.data.Size(), ControlMessages: tcpip.ControlMessages{ HasTimestamp: true, - Timestamp: packet.timestampNS, + Timestamp: packet.receivedAt.UnixNano(), }, } if opts.NeedRemoteAddr { @@ -220,19 +220,19 @@ func (*endpoint) Disconnect() tcpip.Error { // Connect implements tcpip.Endpoint.Connect. Packet sockets cannot be // connected, and this function always returnes *tcpip.ErrNotSupported. -func (*endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { +func (*endpoint) Connect(tcpip.FullAddress) tcpip.Error { return &tcpip.ErrNotSupported{} } // Shutdown implements tcpip.Endpoint.Shutdown. Packet sockets cannot be used // with Shutdown, and this function always returns *tcpip.ErrNotSupported. -func (*endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { +func (*endpoint) Shutdown(tcpip.ShutdownFlags) tcpip.Error { return &tcpip.ErrNotSupported{} } // Listen implements tcpip.Endpoint.Listen. Packet sockets cannot be used with // Listen, and this function always returns *tcpip.ErrNotSupported. -func (*endpoint) Listen(backlog int) tcpip.Error { +func (*endpoint) Listen(int) tcpip.Error { return &tcpip.ErrNotSupported{} } @@ -318,7 +318,7 @@ func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { } // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. -func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { +func (*endpoint) SetSockOptInt(tcpip.SockOptInt, int) tcpip.Error { return &tcpip.ErrUnknownProtocolOption{} } @@ -339,7 +339,7 @@ func (ep *endpoint) UpdateLastError(err tcpip.Error) { } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { +func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error { return &tcpip.ErrNotSupported{} } @@ -451,7 +451,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, packet.data = buffer.NewVectorisedView(pkt.Size(), pkt.Views()) } } - packet.timestampNS = ep.stack.Clock().NowNanoseconds() + packet.receivedAt = ep.stack.Clock().Now() ep.rcvList.PushBack(&packet) ep.rcvBufSize += packet.data.Size() @@ -484,7 +484,7 @@ func (ep *endpoint) Stats() tcpip.EndpointStats { } // SetOwner implements tcpip.Endpoint.SetOwner. -func (ep *endpoint) SetOwner(owner tcpip.PacketOwner) {} +func (*endpoint) SetOwner(tcpip.PacketOwner) {} // SocketOptions implements tcpip.Endpoint.SocketOptions. func (ep *endpoint) SocketOptions() *tcpip.SocketOptions { diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go index 5bd860d20..e729921db 100644 --- a/pkg/tcpip/transport/packet/endpoint_state.go +++ b/pkg/tcpip/transport/packet/endpoint_state.go @@ -15,11 +15,23 @@ package packet import ( + "time" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" ) +// saveReceivedAt is invoked by stateify. +func (p *packet) saveReceivedAt() int64 { + return p.receivedAt.UnixNano() +} + +// loadReceivedAt is invoked by stateify. +func (p *packet) loadReceivedAt(nsec int64) { + p.receivedAt = time.Unix(0, nsec) +} + // saveData saves packet.data field. func (p *packet) saveData() buffer.VectorisedView { // We cannot save p.data directly as p.data.views may alias to p.views, diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index bcec3d2e7..1bce2769a 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -27,6 +27,7 @@ package raw import ( "io" + "time" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -41,9 +42,8 @@ type rawPacket struct { rawPacketEntry // data holds the actual packet data, including any headers and // payload. - data buffer.VectorisedView `state:".(buffer.VectorisedView)"` - // timestampNS is the unix time at which the packet was received. - timestampNS int64 + data buffer.VectorisedView `state:".(buffer.VectorisedView)"` + receivedAt time.Time `state:".(int64)"` // senderAddr is the network address of the sender. senderAddr tcpip.FullAddress } @@ -183,7 +183,7 @@ func (e *endpoint) Close() { } // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. -func (e *endpoint) ModerateRecvBuf(copied int) {} +func (*endpoint) ModerateRecvBuf(int) {} func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.mu.Lock() @@ -219,7 +219,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult Total: pkt.data.Size(), ControlMessages: tcpip.ControlMessages{ HasTimestamp: true, - Timestamp: pkt.timestampNS, + Timestamp: pkt.receivedAt.UnixNano(), }, } if opts.NeedRemoteAddr { @@ -402,7 +402,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { } // Find a route to the destination. - route, err := e.stack.FindRoute(nic, tcpip.Address(""), addr.Addr, e.NetProto, false) + route, err := e.stack.FindRoute(nic, "", addr.Addr, e.NetProto, false) if err != nil { return err } @@ -428,7 +428,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { } // Shutdown implements tcpip.Endpoint.Shutdown. It's a noop for raw sockets. -func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { +func (e *endpoint) Shutdown(tcpip.ShutdownFlags) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() @@ -439,7 +439,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { } // Listen implements tcpip.Endpoint.Listen. -func (*endpoint) Listen(backlog int) tcpip.Error { +func (*endpoint) Listen(int) tcpip.Error { return &tcpip.ErrNotSupported{} } @@ -513,12 +513,12 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { } } -func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { +func (*endpoint) SetSockOptInt(tcpip.SockOptInt, int) tcpip.Error { return &tcpip.ErrUnknownProtocolOption{} } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { +func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error { return &tcpip.ErrUnknownProtocolOption{} } @@ -621,7 +621,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { } combinedVV.Append(pkt.Data().ExtractVV()) packet.data = combinedVV - packet.timestampNS = e.stack.Clock().NowNanoseconds() + packet.receivedAt = e.stack.Clock().Now() e.rcvList.PushBack(packet) e.rcvBufSize += packet.data.Size() diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go index 5d6f2709c..39669b445 100644 --- a/pkg/tcpip/transport/raw/endpoint_state.go +++ b/pkg/tcpip/transport/raw/endpoint_state.go @@ -15,11 +15,23 @@ package raw import ( + "time" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" ) +// saveReceivedAt is invoked by stateify. +func (p *rawPacket) saveReceivedAt() int64 { + return p.receivedAt.UnixNano() +} + +// loadReceivedAt is invoked by stateify. +func (p *rawPacket) loadReceivedAt(nsec int64) { + p.receivedAt = time.Unix(0, nsec) +} + // saveData saves rawPacket.data field. func (p *rawPacket) saveData() buffer.VectorisedView { // We cannot save p.data directly as p.data.views may alias to p.views, diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 0f20d3856..8436d2cf0 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -41,7 +41,6 @@ go_library( "protocol.go", "rack.go", "rcv.go", - "rcv_state.go", "reno.go", "reno_recovery.go", "sack.go", @@ -53,7 +52,6 @@ go_library( "segment_state.go", "segment_unsafe.go", "snd.go", - "snd_state.go", "tcp_endpoint_list.go", "tcp_segment_list.go", "timer.go", @@ -134,6 +132,7 @@ go_test( deps = [ "//pkg/sleep", "//pkg/tcpip/buffer", + "//pkg/tcpip/faketime", "//pkg/tcpip/stack", "@com_github_google_go_cmp//cmp:go_default_library", ], diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index d4bd4e80e..d807b13b7 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -114,8 +114,8 @@ type listenContext struct { } // timeStamp returns an 8-bit timestamp with a granularity of 64 seconds. -func timeStamp() uint32 { - return uint32(time.Now().Unix()>>6) & tsMask +func timeStamp(clock tcpip.Clock) uint32 { + return uint32(clock.NowMonotonic().Sub(tcpip.MonotonicTime{}).Seconds()) >> 6 & tsMask } // newListenContext creates a new listen context. @@ -171,7 +171,7 @@ func (l *listenContext) cookieHash(id stack.TransportEndpointID, ts uint32, nonc // createCookie creates a SYN cookie for the given id and incoming sequence // number. func (l *listenContext) createCookie(id stack.TransportEndpointID, seq seqnum.Value, data uint32) seqnum.Value { - ts := timeStamp() + ts := timeStamp(l.stack.Clock()) v := l.cookieHash(id, 0, 0) + uint32(seq) + (ts << tsOffset) v += (l.cookieHash(id, ts, 1) + data) & hashMask return seqnum.Value(v) @@ -181,7 +181,7 @@ func (l *listenContext) createCookie(id stack.TransportEndpointID, seq seqnum.Va // sequence number. If it is, it also returns the data originally encoded in the // cookie when createCookie was called. func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnum.Value, seq seqnum.Value) (uint32, bool) { - ts := timeStamp() + ts := timeStamp(l.stack.Clock()) v := uint32(cookie) - l.cookieHash(id, 0, 0) - uint32(seq) cookieTS := v >> tsOffset if ((ts - cookieTS) & tsMask) > maxTSDiff { @@ -247,7 +247,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts *header func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*handshake, tcpip.Error) { // Create new endpoint. irs := s.sequenceNumber - isn := generateSecureISN(s.id, l.stack.Seed()) + isn := generateSecureISN(s.id, l.stack.Clock(), l.stack.Seed()) ep, err := l.createConnectingEndpoint(s, opts, queue) if err != nil { return nil, err @@ -550,7 +550,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err e.rcvQueueInfo.rcvQueueMu.Lock() rcvClosed := e.rcvQueueInfo.RcvClosed e.rcvQueueInfo.rcvQueueMu.Unlock() - if rcvClosed || s.flagsAreSet(header.TCPFlagSyn|header.TCPFlagAck) { + if rcvClosed || s.flags.Contains(header.TCPFlagSyn|header.TCPFlagAck) { // If the endpoint is shutdown, reply with reset. // // RFC 793 section 3.4 page 35 (figure 12) outlines that a RST @@ -560,6 +560,10 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err } switch { + case s.flags.Contains(header.TCPFlagRst): + e.stack.Stats().DroppedPackets.Increment() + return nil + case s.flags == header.TCPFlagSyn: if e.acceptQueueIsFull() { e.stack.Stats().TCP.ListenOverflowSynDrop.Increment() @@ -591,7 +595,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err synOpts := header.TCPSynOptions{ WS: -1, TS: opts.TS, - TSVal: tcpTimeStamp(time.Now(), timeStampOffset()), + TSVal: tcpTimeStamp(e.stack.Clock().NowMonotonic(), timeStampOffset(e.stack.Rand())), TSEcr: opts.TSVal, MSS: calculateAdvertisedMSS(e.userMSS, route), } @@ -611,7 +615,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment() return nil - case (s.flags & header.TCPFlagAck) != 0: + case s.flags.Contains(header.TCPFlagAck): if e.acceptQueueIsFull() { // Silently drop the ack as the application can't accept // the connection at this point. The ack will be @@ -736,6 +740,13 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err mss: rcvdSynOptions.MSS, }) + // Requeue the segment if the ACK completing the handshake has more info + // to be procesed by the newly established endpoint. + if (s.flags.Contains(header.TCPFlagFin) || s.data.Size() > 0) && n.enqueueSegment(s) { + s.incRef() + n.newSegmentWaker.Assert() + } + // Do the delivery in a separate goroutine so // that we don't block the listen loop in case // the application is slow to accept or stops @@ -753,6 +764,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err return nil default: + e.stack.Stats().DroppedPackets.Increment() return nil } } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 5e03e7715..2137ebc25 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -19,7 +19,6 @@ import ( "math" "time" - "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -92,7 +91,7 @@ type handshake struct { rcvWndScale int // startTime is the time at which the first SYN/SYN-ACK was sent. - startTime time.Time + startTime tcpip.MonotonicTime // deferAccept if non-zero will drop the final ACK for a passive // handshake till an ACK segment with data is received or the timeout is @@ -147,21 +146,16 @@ func FindWndScale(wnd seqnum.Size) int { // resetState resets the state of the handshake object such that it becomes // ready for a new 3-way handshake. func (h *handshake) resetState() { - b := make([]byte, 4) - if _, err := rand.Read(b); err != nil { - panic(err) - } - h.state = handshakeSynSent h.flags = header.TCPFlagSyn h.ackNum = 0 h.mss = 0 - h.iss = generateSecureISN(h.ep.TransportEndpointInfo.ID, h.ep.stack.Seed()) + h.iss = generateSecureISN(h.ep.TransportEndpointInfo.ID, h.ep.stack.Clock(), h.ep.stack.Seed()) } // generateSecureISN generates a secure Initial Sequence number based on the // recommendation here https://tools.ietf.org/html/rfc6528#page-3. -func generateSecureISN(id stack.TransportEndpointID, seed uint32) seqnum.Value { +func generateSecureISN(id stack.TransportEndpointID, clock tcpip.Clock, seed uint32) seqnum.Value { isnHasher := jenkins.Sum32(seed) isnHasher.Write([]byte(id.LocalAddress)) isnHasher.Write([]byte(id.RemoteAddress)) @@ -180,7 +174,7 @@ func generateSecureISN(id stack.TransportEndpointID, seed uint32) seqnum.Value { // // Which sort of guarantees that we won't reuse the ISN for a new // connection for the same tuple for at least 274s. - isn := isnHasher.Sum32() + uint32(time.Now().UnixNano()>>6) + isn := isnHasher.Sum32() + uint32(clock.NowMonotonic().Sub(tcpip.MonotonicTime{}).Nanoseconds()>>6) return seqnum.Value(isn) } @@ -212,7 +206,7 @@ func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *hea // a TCP 3-way handshake is valid. If it's not, a RST segment is sent back in // response. func (h *handshake) checkAck(s *segment) bool { - if s.flagIsSet(header.TCPFlagAck) && s.ackNumber != h.iss+1 { + if s.flags.Contains(header.TCPFlagAck) && s.ackNumber != h.iss+1 { // RFC 793, page 36, states that a reset must be generated when // the connection is in any non-synchronized state and an // incoming segment acknowledges something not yet sent. The @@ -230,8 +224,8 @@ func (h *handshake) checkAck(s *segment) bool { func (h *handshake) synSentState(s *segment) tcpip.Error { // RFC 793, page 37, states that in the SYN-SENT state, a reset is // acceptable if the ack field acknowledges the SYN. - if s.flagIsSet(header.TCPFlagRst) { - if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == h.iss+1 { + if s.flags.Contains(header.TCPFlagRst) { + if s.flags.Contains(header.TCPFlagAck) && s.ackNumber == h.iss+1 { // RFC 793, page 67, states that "If the RST bit is set [and] If the ACK // was acceptable then signal the user "error: connection reset", drop // the segment, enter CLOSED state, delete TCB, and return." @@ -249,7 +243,7 @@ func (h *handshake) synSentState(s *segment) tcpip.Error { // We are in the SYN-SENT state. We only care about segments that have // the SYN flag. - if !s.flagIsSet(header.TCPFlagSyn) { + if !s.flags.Contains(header.TCPFlagSyn) { return nil } @@ -270,7 +264,7 @@ func (h *handshake) synSentState(s *segment) tcpip.Error { // If this is a SYN ACK response, we only need to acknowledge the SYN // and the handshake is completed. - if s.flagIsSet(header.TCPFlagAck) { + if s.flags.Contains(header.TCPFlagAck) { h.state = handshakeCompleted h.ep.transitionToStateEstablishedLocked(h) @@ -316,7 +310,7 @@ func (h *handshake) synSentState(s *segment) tcpip.Error { // synRcvdState handles a segment received when the TCP 3-way handshake is in // the SYN-RCVD state. func (h *handshake) synRcvdState(s *segment) tcpip.Error { - if s.flagIsSet(header.TCPFlagRst) { + if s.flags.Contains(header.TCPFlagRst) { // RFC 793, page 37, states that in the SYN-RCVD state, a reset // is acceptable if the sequence number is in the window. if s.sequenceNumber.InWindow(h.ackNum, h.rcvWnd) { @@ -340,13 +334,13 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error { return nil } - if s.flagIsSet(header.TCPFlagSyn) && s.sequenceNumber != h.ackNum-1 { + if s.flags.Contains(header.TCPFlagSyn) && s.sequenceNumber != h.ackNum-1 { // We received two SYN segments with different sequence // numbers, so we reset this and restart the whole // process, except that we don't reset the timer. ack := s.sequenceNumber.Add(s.logicalLen()) seq := seqnum.Value(0) - if s.flagIsSet(header.TCPFlagAck) { + if s.flags.Contains(header.TCPFlagAck) { seq = s.ackNumber } h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0) @@ -378,10 +372,10 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error { // We have previously received (and acknowledged) the peer's SYN. If the // peer acknowledges our SYN, the handshake is completed. - if s.flagIsSet(header.TCPFlagAck) { + if s.flags.Contains(header.TCPFlagAck) { // If deferAccept is not zero and this is a bare ACK and the // timeout is not hit then drop the ACK. - if h.deferAccept != 0 && s.data.Size() == 0 && time.Since(h.startTime) < h.deferAccept { + if h.deferAccept != 0 && s.data.Size() == 0 && h.ep.stack.Clock().NowMonotonic().Sub(h.startTime) < h.deferAccept { h.acked = true h.ep.stack.Stats().DroppedPackets.Increment() return nil @@ -412,11 +406,11 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error { h.ep.transitionToStateEstablishedLocked(h) - // If the segment has data then requeue it for the receiver - // to process it again once main loop is started. - if s.data.Size() > 0 { + // Requeue the segment if the ACK completing the handshake has more info + // to be procesed by the newly established endpoint. + if (s.flags.Contains(header.TCPFlagFin) || s.data.Size() > 0) && h.ep.enqueueSegment(s) { s.incRef() - h.ep.enqueueSegment(s) + h.ep.newSegmentWaker.Assert() } return nil } @@ -426,7 +420,7 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error { func (h *handshake) handleSegment(s *segment) tcpip.Error { h.sndWnd = s.window - if !s.flagIsSet(header.TCPFlagSyn) && h.sndWndScale > 0 { + if !s.flags.Contains(header.TCPFlagSyn) && h.sndWndScale > 0 { h.sndWnd <<= uint8(h.sndWndScale) } @@ -474,7 +468,7 @@ func (h *handshake) processSegments() tcpip.Error { // start sends the first SYN/SYN-ACK. It does not block, even if link address // resolution is required. func (h *handshake) start() { - h.startTime = time.Now() + h.startTime = h.ep.stack.Clock().NowMonotonic() h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route) var sackEnabled tcpip.TCPSACKEnabled if err := h.ep.stack.TransportProtocolOption(ProtocolNumber, &sackEnabled); err != nil { @@ -527,7 +521,7 @@ func (h *handshake) complete() tcpip.Error { defer s.Done() // Initialize the resend timer. - timer, err := newBackoffTimer(time.Second, MaxRTO, resendWaker.Assert) + timer, err := newBackoffTimer(h.ep.stack.Clock(), time.Second, MaxRTO, resendWaker.Assert) if err != nil { return err } @@ -552,7 +546,7 @@ func (h *handshake) complete() tcpip.Error { // The last is required to provide a way for the peer to complete // the connection with another ACK or data (as ACKs are never // retransmitted on their own). - if h.active || !h.acked || h.deferAccept != 0 && time.Since(h.startTime) > h.deferAccept { + if h.active || !h.acked || h.deferAccept != 0 && h.ep.stack.Clock().NowMonotonic().Sub(h.startTime) > h.deferAccept { h.ep.sendSynTCP(h.ep.route, tcpFields{ id: h.ep.TransportEndpointInfo.ID, ttl: h.ep.ttl, @@ -608,15 +602,15 @@ func (h *handshake) complete() tcpip.Error { type backoffTimer struct { timeout time.Duration maxTimeout time.Duration - t *time.Timer + t tcpip.Timer } -func newBackoffTimer(timeout, maxTimeout time.Duration, f func()) (*backoffTimer, tcpip.Error) { +func newBackoffTimer(clock tcpip.Clock, timeout, maxTimeout time.Duration, f func()) (*backoffTimer, tcpip.Error) { if timeout > maxTimeout { return nil, &tcpip.ErrTimeout{} } bt := &backoffTimer{timeout: timeout, maxTimeout: maxTimeout} - bt.t = time.AfterFunc(timeout, f) + bt.t = clock.AfterFunc(timeout, f) return bt, nil } @@ -634,7 +628,7 @@ func (bt *backoffTimer) stop() { } func parseSynSegmentOptions(s *segment) header.TCPSynOptions { - synOpts := header.ParseSynOptions(s.options, s.flagIsSet(header.TCPFlagAck)) + synOpts := header.ParseSynOptions(s.options, s.flags.Contains(header.TCPFlagAck)) if synOpts.TS { s.parsedOptions.TSVal = synOpts.TSVal s.parsedOptions.TSEcr = synOpts.TSEcr @@ -1188,11 +1182,11 @@ func (e *endpoint) handleSegmentLocked(s *segment) (cont bool, err tcpip.Error) // the TCPEndpointState after the segment is processed. defer e.probeSegmentLocked() - if s.flagIsSet(header.TCPFlagRst) { + if s.flags.Contains(header.TCPFlagRst) { if ok, err := e.handleReset(s); !ok { return false, err } - } else if s.flagIsSet(header.TCPFlagSyn) { + } else if s.flags.Contains(header.TCPFlagSyn) { // See: https://tools.ietf.org/html/rfc5961#section-4.1 // 1) If the SYN bit is set, irrespective of the sequence number, TCP // MUST send an ACK (also referred to as challenge ACK) to the remote @@ -1216,7 +1210,7 @@ func (e *endpoint) handleSegmentLocked(s *segment) (cont bool, err tcpip.Error) // should then rely on SYN retransmission from the remote end to // re-establish the connection. e.snd.maybeSendOutOfWindowAck(s) - } else if s.flagIsSet(header.TCPFlagAck) { + } else if s.flags.Contains(header.TCPFlagAck) { // Patch the window size in the segment according to the // send window scale. s.window <<= e.snd.SndWndScale @@ -1235,7 +1229,7 @@ func (e *endpoint) handleSegmentLocked(s *segment) (cont bool, err tcpip.Error) // Now check if the received segment has caused us to transition // to a CLOSED state, if yes then terminate processing and do // not invoke the sender. - state := e.state + state := e.EndpointState() if state == StateClose { // When we get into StateClose while processing from the queue, // return immediately and let the protocolMainloop handle it. @@ -1267,7 +1261,7 @@ func (e *endpoint) keepaliveTimerExpired() tcpip.Error { // If a userTimeout is set then abort the connection if it is // exceeded. - if userTimeout != 0 && time.Since(e.rcv.lastRcvdAckTime) >= userTimeout && e.keepalive.unacked > 0 { + if userTimeout != 0 && e.stack.Clock().NowMonotonic().Sub(e.rcv.lastRcvdAckTime) >= userTimeout && e.keepalive.unacked > 0 { e.keepalive.Unlock() e.stack.Stats().TCP.EstablishedTimedout.Increment() return &tcpip.ErrTimeout{} @@ -1322,7 +1316,7 @@ func (e *endpoint) disableKeepaliveTimer() { // segments. func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) tcpip.Error { e.mu.Lock() - var closeTimer *time.Timer + var closeTimer tcpip.Timer var closeWaker sleep.Waker epilogue := func() { @@ -1480,11 +1474,19 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ return &tcpip.ErrConnectionReset{} } - if n¬ifyClose != 0 && closeTimer == nil { - if e.EndpointState() == StateFinWait2 && e.closed { + if n¬ifyClose != 0 && e.closed { + switch e.EndpointState() { + case StateEstablished: + // Perform full shutdown if the endpoint is still + // established. This can occur when notifyClose + // was asserted just before becoming established. + e.shutdownLocked(tcpip.ShutdownWrite | tcpip.ShutdownRead) + case StateFinWait2: // The socket has been closed and we are in FIN_WAIT2 // so start the FIN_WAIT2 timer. - closeTimer = time.AfterFunc(e.tcpLingerTimeout, closeWaker.Assert) + if closeTimer == nil { + closeTimer = e.stack.Clock().AfterFunc(e.tcpLingerTimeout, closeWaker.Assert) + } } } @@ -1721,7 +1723,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) { var timeWaitWaker sleep.Waker s.AddWaker(&timeWaitWaker, timeWaitDone) - timeWaitTimer := time.AfterFunc(timeWaitDuration, timeWaitWaker.Assert) + timeWaitTimer := e.stack.Clock().AfterFunc(timeWaitDuration, timeWaitWaker.Assert) defer timeWaitTimer.Stop() for { diff --git a/pkg/tcpip/transport/tcp/cubic.go b/pkg/tcpip/transport/tcp/cubic.go index 962f1d687..6985194bb 100644 --- a/pkg/tcpip/transport/tcp/cubic.go +++ b/pkg/tcpip/transport/tcp/cubic.go @@ -41,7 +41,7 @@ type cubicState struct { func newCubicCC(s *sender) *cubicState { return &cubicState{ TCPCubicState: stack.TCPCubicState{ - T: time.Now(), + T: s.ep.stack.Clock().NowMonotonic(), Beta: 0.7, C: 0.4, }, @@ -60,7 +60,7 @@ func (c *cubicState) enterCongestionAvoidance() { // https://tools.ietf.org/html/rfc8312#section-4.8 if c.numCongestionEvents == 0 { c.K = 0 - c.T = time.Now() + c.T = c.s.ep.stack.Clock().NowMonotonic() c.WLastMax = c.WMax c.WMax = float64(c.s.SndCwnd) } @@ -115,14 +115,15 @@ func (c *cubicState) cubicCwnd(t float64) float64 { // getCwnd returns the current congestion window as computed by CUBIC. // Refer: https://tools.ietf.org/html/rfc8312#section-4 func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int { - elapsed := time.Since(c.T).Seconds() + elapsed := c.s.ep.stack.Clock().NowMonotonic().Sub(c.T) + elapsedSeconds := elapsed.Seconds() // Compute the window as per Cubic after 'elapsed' time // since last congestion event. - c.WC = c.cubicCwnd(elapsed - c.K) + c.WC = c.cubicCwnd(elapsedSeconds - c.K) // Compute the TCP friendly estimate of the congestion window. - c.WEst = c.WMax*c.Beta + (3.0*((1.0-c.Beta)/(1.0+c.Beta)))*(elapsed/srtt.Seconds()) + c.WEst = c.WMax*c.Beta + (3.0*((1.0-c.Beta)/(1.0+c.Beta)))*(elapsedSeconds/srtt.Seconds()) // Make sure in the TCP friendly region CUBIC performs at least // as well as Reno. @@ -134,7 +135,7 @@ func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int // In Concave/Convex region of CUBIC, calculate what CUBIC window // will be after 1 RTT and use that to grow congestion window // for every ack. - tEst := (time.Since(c.T) + srtt).Seconds() + tEst := (elapsed + srtt).Seconds() wtRtt := c.cubicCwnd(tEst - c.K) // As per 4.3 for each received ACK cwnd must be incremented // by (w_cubic(t+RTT) - cwnd/cwnd. @@ -151,7 +152,7 @@ func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int func (c *cubicState) HandleLossDetected() { // See: https://tools.ietf.org/html/rfc8312#section-4.5 c.numCongestionEvents++ - c.T = time.Now() + c.T = c.s.ep.stack.Clock().NowMonotonic() c.WLastMax = c.WMax c.WMax = float64(c.s.SndCwnd) @@ -162,7 +163,7 @@ func (c *cubicState) HandleLossDetected() { // HandleRTOExpired implements congestionContrl.HandleRTOExpired. func (c *cubicState) HandleRTOExpired() { // See: https://tools.ietf.org/html/rfc8312#section-4.6 - c.T = time.Now() + c.T = c.s.ep.stack.Clock().NowMonotonic() c.numCongestionEvents = 0 c.WLastMax = c.WMax c.WMax = float64(c.s.SndCwnd) @@ -193,7 +194,7 @@ func (c *cubicState) fastConvergence() { // PostRecovery implemements congestionControl.PostRecovery. func (c *cubicState) PostRecovery() { - c.T = time.Now() + c.T = c.s.ep.stack.Clock().NowMonotonic() } // reduceSlowStartThreshold returns new SsThresh as described in diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go index 512053a04..dff7cb89c 100644 --- a/pkg/tcpip/transport/tcp/dispatcher.go +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -16,10 +16,11 @@ package tcp import ( "encoding/binary" + "math/rand" - "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -141,15 +142,16 @@ func (p *processor) start(wg *sync.WaitGroup) { // in-order. type dispatcher struct { processors []processor - seed uint32 - wg sync.WaitGroup + // seed is a random secret for a jenkins hash. + seed uint32 + wg sync.WaitGroup } -func (d *dispatcher) init(nProcessors int) { +func (d *dispatcher) init(rng *rand.Rand, nProcessors int) { d.close() d.wait() d.processors = make([]processor, nProcessors) - d.seed = generateRandUint32() + d.seed = rng.Uint32() for i := range d.processors { p := &d.processors[i] p.sleeper.AddWaker(&p.newEndpointWaker, newEndpointWaker) @@ -172,12 +174,11 @@ func (d *dispatcher) wait() { d.wg.Wait() } -func (d *dispatcher) queuePacket(stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { +func (d *dispatcher) queuePacket(stackEP stack.TransportEndpoint, id stack.TransportEndpointID, clock tcpip.Clock, pkt *stack.PacketBuffer) { ep := stackEP.(*endpoint) - s := newIncomingSegment(id, pkt) + s := newIncomingSegment(id, clock, pkt) if !s.parse(pkt.RXTransportChecksumValidated) { - ep.stack.Stats().MalformedRcvdPackets.Increment() ep.stack.Stats().TCP.InvalidSegmentsReceived.Increment() ep.stats.ReceiveErrors.MalformedPacketsReceived.Increment() s.decRef() @@ -185,7 +186,6 @@ func (d *dispatcher) queuePacket(stackEP stack.TransportEndpoint, id stack.Trans } if !s.csumValid { - ep.stack.Stats().MalformedRcvdPackets.Increment() ep.stack.Stats().TCP.ChecksumErrors.Increment() ep.stats.ReceiveErrors.ChecksumErrors.Increment() s.decRef() @@ -213,14 +213,6 @@ func (d *dispatcher) queuePacket(stackEP stack.TransportEndpoint, id stack.Trans d.selectProcessor(id).queueEndpoint(ep) } -func generateRandUint32() uint32 { - b := make([]byte, 4) - if _, err := rand.Read(b); err != nil { - panic(err) - } - return binary.LittleEndian.Uint32(b) -} - func (d *dispatcher) selectProcessor(id stack.TransportEndpointID) *processor { var payload [4]byte binary.LittleEndian.PutUint16(payload[0:], id.LocalPort) diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index f148d505d..5342aacfd 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -421,7 +421,7 @@ func testV4Accept(t *testing.T, c *context.Context) { r.Reset(data) nep.Write(&r, tcpip.WriteOptions{}) b = c.GetPacket() - tcp = header.TCP(header.IPv4(b).Payload()) + tcp = header.IPv4(b).Payload() if string(tcp.Payload()) != data { t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data) } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 50d39cbad..a27e2110b 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -20,12 +20,12 @@ import ( "fmt" "io" "math" + "math/rand" "runtime" "strings" "sync/atomic" "time" - "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -38,19 +38,15 @@ import ( ) // EndpointState represents the state of a TCP endpoint. -type EndpointState uint32 +type EndpointState tcpip.EndpointState // Endpoint states. Note that are represented in a netstack-specific manner and // may not be meaningful externally. Specifically, they need to be translated to // Linux's representation for these states if presented to userspace. const ( - // Endpoint states internal to netstack. These map to the TCP state CLOSED. - StateInitial EndpointState = iota - StateBound - StateConnecting // Connect() called, but the initial SYN hasn't been sent. - StateError - - // TCP protocol states. + _ EndpointState = iota + // TCP protocol states in sync with the definitions in + // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/tcp_states.h#L13 StateEstablished StateSynSent StateSynRecv @@ -62,6 +58,12 @@ const ( StateLastAck StateListen StateClosing + + // Endpoint states internal to netstack. + StateInitial + StateBound + StateConnecting // Connect() called, but the initial SYN hasn't been sent. + StateError ) const ( @@ -97,6 +99,16 @@ func (s EndpointState) connecting() bool { } } +// internal returns true when the state is netstack internal. +func (s EndpointState) internal() bool { + switch s { + case StateInitial, StateBound, StateConnecting, StateError: + return true + default: + return false + } +} + // handshake returns true when s is one of the states representing an endpoint // in the middle of a TCP handshake. func (s EndpointState) handshake() bool { @@ -422,12 +434,12 @@ type endpoint struct { // state must be read/set using the EndpointState()/setEndpointState() // methods. - state EndpointState `state:".(EndpointState)"` + state uint32 `state:".(EndpointState)"` // origEndpointState is only used during a restore phase to save the // endpoint state at restore time as the socket is moved to it's correct // state. - origEndpointState EndpointState `state:"nosave"` + origEndpointState uint32 `state:"nosave"` isPortReserved bool `state:"manual"` isRegistered bool `state:"manual"` @@ -468,7 +480,7 @@ type endpoint struct { // recentTSTime is the unix time when we last updated // TCPEndpointStateInner.RecentTS. - recentTSTime time.Time `state:".(unixTime)"` + recentTSTime tcpip.MonotonicTime // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags @@ -626,7 +638,7 @@ type endpoint struct { // lastOutOfWindowAckTime is the time at which the an ACK was sent in response // to an out of window segment being received by this endpoint. - lastOutOfWindowAckTime time.Time `state:".(unixTime)"` + lastOutOfWindowAckTime tcpip.MonotonicTime } // UniqueID implements stack.TransportEndpoint.UniqueID. @@ -747,7 +759,7 @@ func (e *endpoint) ResumeWork() { // // Precondition: e.mu must be held to call this method. func (e *endpoint) setEndpointState(state EndpointState) { - oldstate := EndpointState(atomic.LoadUint32((*uint32)(&e.state))) + oldstate := EndpointState(atomic.LoadUint32(&e.state)) switch state { case StateEstablished: e.stack.Stats().TCP.CurrentEstablished.Increment() @@ -764,18 +776,18 @@ func (e *endpoint) setEndpointState(state EndpointState) { e.stack.Stats().TCP.CurrentEstablished.Decrement() } } - atomic.StoreUint32((*uint32)(&e.state), uint32(state)) + atomic.StoreUint32(&e.state, uint32(state)) } // EndpointState returns the current state of the endpoint. func (e *endpoint) EndpointState() EndpointState { - return EndpointState(atomic.LoadUint32((*uint32)(&e.state))) + return EndpointState(atomic.LoadUint32(&e.state)) } // setRecentTimestamp sets the recentTS field to the provided value. func (e *endpoint) setRecentTimestamp(recentTS uint32) { e.RecentTS = recentTS - e.recentTSTime = time.Now() + e.recentTSTime = e.stack.Clock().NowMonotonic() } // recentTimestamp returns the value of the recentTS field. @@ -806,11 +818,11 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue }, sndQueueInfo: sndQueueInfo{ TCPSndBufState: stack.TCPSndBufState{ - SndMTU: int(math.MaxInt32), + SndMTU: math.MaxInt32, }, }, waiterQueue: waiterQueue, - state: StateInitial, + state: uint32(StateInitial), keepalive: keepalive{ // Linux defaults. idle: 2 * time.Hour, @@ -870,9 +882,9 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue } e.segmentQueue.ep = e - e.TSOffset = timeStampOffset() + e.TSOffset = timeStampOffset(e.stack.Rand()) e.acceptCond = sync.NewCond(&e.acceptMu) - e.keepalive.timer.init(&e.keepalive.waker) + e.keepalive.timer.init(e.stack.Clock(), &e.keepalive.waker) return e } @@ -1189,7 +1201,7 @@ func (e *endpoint) ModerateRecvBuf(copied int) { e.rcvQueueInfo.rcvQueueMu.Unlock() return } - now := time.Now() + now := e.stack.Clock().NowMonotonic() if rtt := e.rcvQueueInfo.RcvAutoParams.RTT; rtt == 0 || now.Sub(e.rcvQueueInfo.RcvAutoParams.MeasureTime) < rtt { e.rcvQueueInfo.RcvAutoParams.CopiedBytes += copied e.rcvQueueInfo.rcvQueueMu.Unlock() @@ -1544,7 +1556,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp } // Add data to the send queue. - s := newOutgoingSegment(e.TransportEndpointInfo.ID, v) + s := newOutgoingSegment(e.TransportEndpointInfo.ID, e.stack.Clock(), v) e.sndQueueInfo.SndBufUsed += len(v) e.sndQueueInfo.SndBufInQueue += seqnum.Size(len(v)) e.sndQueueInfo.sndQueue.PushBack(s) @@ -1956,6 +1968,11 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { func (e *endpoint) getTCPInfo() tcpip.TCPInfoOption { info := tcpip.TCPInfoOption{} e.LockUser() + if state := e.EndpointState(); state.internal() { + info.State = tcpip.EndpointState(StateClose) + } else { + info.State = tcpip.EndpointState(state) + } snd := e.snd if snd != nil { // We do not calculate RTT before sending the data packets. If @@ -2198,7 +2215,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp BindToDevice: bindToDevice, Dest: addr, } - if _, err := e.stack.ReservePort(portRes, nil /* testPort */); err != nil { + if _, err := e.stack.ReservePort(e.stack.Rand(), portRes, nil /* testPort */); err != nil { if _, ok := err.(*tcpip.ErrPortInUse); !ok || !reuse { return false, nil } @@ -2224,7 +2241,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp // If the endpoint is not in TIME-WAIT or if it is in TIME-WAIT but // less than 1 second has elapsed since its recentTS was updated then // we cannot reuse the port. - if tcpEP.EndpointState() != StateTimeWait || time.Since(tcpEP.recentTSTime) < 1*time.Second { + if tcpEP.EndpointState() != StateTimeWait || e.stack.Clock().NowMonotonic().Sub(tcpEP.recentTSTime) < 1*time.Second { tcpEP.UnlockUser() return false, nil } @@ -2245,7 +2262,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp BindToDevice: bindToDevice, Dest: addr, } - if _, err := e.stack.ReservePort(portRes, nil /* testPort */); err != nil { + if _, err := e.stack.ReservePort(e.stack.Rand(), portRes, nil /* testPort */); err != nil { return false, nil } } @@ -2370,7 +2387,7 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error { } // Queue fin segment. - s := newOutgoingSegment(e.TransportEndpointInfo.ID, nil) + s := newOutgoingSegment(e.TransportEndpointInfo.ID, e.stack.Clock(), nil) e.sndQueueInfo.sndQueue.PushBack(s) e.sndQueueInfo.SndBufInQueue++ // Mark endpoint as closed. @@ -2581,7 +2598,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) { BindToDevice: bindToDevice, Dest: tcpip.FullAddress{}, } - port, err := e.stack.ReservePort(portRes, func(p uint16) (bool, tcpip.Error) { + port, err := e.stack.ReservePort(e.stack.Rand(), portRes, func(p uint16) (bool, tcpip.Error) { id := e.TransportEndpointInfo.ID id.LocalPort = p // CheckRegisterTransportEndpoint should only return an error if there is a @@ -2731,7 +2748,7 @@ func (e *endpoint) updateSndBufferUsage(v int) { // We only notify when there is half the sendBufferSize available after // a full buffer event occurs. This ensures that we don't wake up // writers to queue just 1-2 segments and go back to sleep. - notify = notify && e.sndQueueInfo.SndBufUsed < int(sendBufferSize)>>1 + notify = notify && e.sndQueueInfo.SndBufUsed < sendBufferSize>>1 e.sndQueueInfo.sndQueueMu.Unlock() if notify { @@ -2848,23 +2865,20 @@ func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) { // timestamp returns the timestamp value to be used in the TSVal field of the // timestamp option for outgoing TCP segments for a given endpoint. func (e *endpoint) timestamp() uint32 { - return tcpTimeStamp(time.Now(), e.TSOffset) + return tcpTimeStamp(e.stack.Clock().NowMonotonic(), e.TSOffset) } // tcpTimeStamp returns a timestamp offset by the provided offset. This is // not inlined above as it's used when SYN cookies are in use and endpoint // is not created at the time when the SYN cookie is sent. -func tcpTimeStamp(curTime time.Time, offset uint32) uint32 { - return uint32(curTime.Unix()*1000+int64(curTime.Nanosecond()/1e6)) + offset +func tcpTimeStamp(curTime tcpip.MonotonicTime, offset uint32) uint32 { + d := curTime.Sub(tcpip.MonotonicTime{}) + return uint32(d.Milliseconds()) + offset } // timeStampOffset returns a randomized timestamp offset to be used when sending // timestamp values in a timestamp option for a TCP segment. -func timeStampOffset() uint32 { - b := make([]byte, 4) - if _, err := rand.Read(b); err != nil { - panic(err) - } +func timeStampOffset(rng *rand.Rand) uint32 { // Initialize a random tsOffset that will be added to the recentTS // everytime the timestamp is sent when the Timestamp option is enabled. // @@ -2874,7 +2888,7 @@ func timeStampOffset() uint32 { // NOTE: This is not completely to spec as normally this should be // initialized in a manner analogous to how sequence numbers are // randomized per connection basis. But for now this is sufficient. - return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24 + return rng.Uint32() } // maybeEnableSACKPermitted marks the SACKPermitted option enabled for this endpoint @@ -2909,7 +2923,7 @@ func (e *endpoint) completeStateLocked() stack.TCPEndpointState { s := stack.TCPEndpointState{ TCPEndpointStateInner: e.TCPEndpointStateInner, ID: stack.TCPEndpointID(e.TransportEndpointInfo.ID), - SegTime: time.Now(), + SegTime: e.stack.Clock().NowMonotonic(), Receiver: e.rcv.TCPReceiverState, Sender: e.snd.TCPSenderState, } @@ -2937,7 +2951,7 @@ func (e *endpoint) completeStateLocked() stack.TCPEndpointState { if cubic, ok := e.snd.cc.(*cubicState); ok { s.Sender.Cubic = cubic.TCPCubicState - s.Sender.Cubic.TimeSinceLastCongestion = time.Since(s.Sender.Cubic.T) + s.Sender.Cubic.TimeSinceLastCongestion = e.stack.Clock().NowMonotonic().Sub(s.Sender.Cubic.T) } s.Sender.RACKState = e.snd.rc.TCPRACKState @@ -3029,14 +3043,16 @@ func GetTCPSendBufferLimits(s tcpip.StackHandler) tcpip.SendBufferSizeOption { // allowOutOfWindowAck returns true if an out-of-window ACK can be sent now. func (e *endpoint) allowOutOfWindowAck() bool { - var limit stack.TCPInvalidRateLimitOption - if err := e.stack.Option(&limit); err != nil { - panic(fmt.Sprintf("e.stack.Option(%+v) failed with error: %s", limit, err)) - } + now := e.stack.Clock().NowMonotonic() - now := time.Now() - if now.Sub(e.lastOutOfWindowAckTime) < time.Duration(limit) { - return false + if e.lastOutOfWindowAckTime != (tcpip.MonotonicTime{}) { + var limit stack.TCPInvalidRateLimitOption + if err := e.stack.Option(&limit); err != nil { + panic(fmt.Sprintf("e.stack.Option(%+v) failed with error: %s", limit, err)) + } + if now.Sub(e.lastOutOfWindowAckTime) < time.Duration(limit) { + return false + } } e.lastOutOfWindowAckTime = now diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 6e9777fe4..952ccacdd 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -154,20 +154,25 @@ func (e *endpoint) afterLoad() { e.origEndpointState = e.state // Restore the endpoint to InitialState as it will be moved to // its origEndpointState during Resume. - e.state = StateInitial + e.state = uint32(StateInitial) // Condition variables and mutexs are not S/R'ed so reinitialize // acceptCond with e.acceptMu. e.acceptCond = sync.NewCond(&e.acceptMu) - e.keepalive.timer.init(&e.keepalive.waker) stack.StackFromEnv.RegisterRestoredEndpoint(e) } // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { + e.keepalive.timer.init(s.Clock(), &e.keepalive.waker) + if snd := e.snd; snd != nil { + snd.resendTimer.init(s.Clock(), &snd.resendWaker) + snd.reorderTimer.init(s.Clock(), &snd.reorderWaker) + snd.probeTimer.init(s.Clock(), &snd.probeWaker) + } e.stack = s e.ops.InitHandler(e, e.stack, GetTCPSendBufferLimits, GetTCPReceiveBufferLimits) e.segmentQueue.thaw() - epState := e.origEndpointState + epState := EndpointState(e.origEndpointState) switch epState { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: var ss tcpip.TCPSendBufferSizeRangeOption @@ -281,32 +286,12 @@ func (e *endpoint) Resume(s *stack.Stack) { }() case epState == StateClose: e.isPortReserved = false - e.state = StateClose + e.state = uint32(StateClose) e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) case epState == StateError: - e.state = StateError + e.state = uint32(StateError) e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) } } - -// saveRecentTSTime is invoked by stateify. -func (e *endpoint) saveRecentTSTime() unixTime { - return unixTime{e.recentTSTime.Unix(), e.recentTSTime.UnixNano()} -} - -// loadRecentTSTime is invoked by stateify. -func (e *endpoint) loadRecentTSTime(unix unixTime) { - e.recentTSTime = time.Unix(unix.second, unix.nano) -} - -// saveLastOutOfWindowAckTime is invoked by stateify. -func (e *endpoint) saveLastOutOfWindowAckTime() unixTime { - return unixTime{e.lastOutOfWindowAckTime.Unix(), e.lastOutOfWindowAckTime.UnixNano()} -} - -// loadLastOutOfWindowAckTime is invoked by stateify. -func (e *endpoint) loadLastOutOfWindowAckTime(unix unixTime) { - e.lastOutOfWindowAckTime = time.Unix(unix.second, unix.nano) -} diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 2f9fe7ee0..65c86823a 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -65,7 +65,7 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. func (f *Forwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { - s := newIncomingSegment(id, pkt) + s := newIncomingSegment(id, f.stack.Clock(), pkt) defer s.decRef() // We only care about well-formed SYN packets. diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index a3d1aa1a3..2fc282e73 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -131,7 +131,7 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err tcpip.Error) { // goroutine which is responsible for dequeuing and doing full TCP dispatch of // the packet. func (p *protocol) QueuePacket(ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { - p.dispatcher.queuePacket(ep, id, pkt) + p.dispatcher.queuePacket(ep, id, p.stack.Clock(), pkt) } // HandleUnknownDestinationPacket handles packets targeted at this protocol but @@ -142,14 +142,14 @@ func (p *protocol) QueuePacket(ep stack.TransportEndpoint, id stack.TransportEnd // particular, SYNs addressed to a non-existent connection are rejected by this // means." func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { - s := newIncomingSegment(id, pkt) + s := newIncomingSegment(id, p.stack.Clock(), pkt) defer s.decRef() if !s.parse(pkt.RXTransportChecksumValidated) || !s.csumValid { return stack.UnknownDestinationPacketMalformed } - if !s.flagIsSet(header.TCPFlagRst) { + if !s.flags.Contains(header.TCPFlagRst) { replyWithReset(p.stack, s, stack.DefaultTOS, 0) } @@ -181,7 +181,7 @@ func replyWithReset(st *stack.Stack, s *segment, tos, ttl uint8) tcpip.Error { // reset has sequence number zero and the ACK field is set to the sum // of the sequence number and segment length of the incoming segment. // The connection remains in the CLOSED state. - if s.flagIsSet(header.TCPFlagAck) { + if s.flags.Contains(header.TCPFlagAck) { seq = s.ackNumber } else { flags |= header.TCPFlagAck @@ -401,7 +401,7 @@ func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) tcpip.Er case *tcpip.TCPTimeWaitReuseOption: p.mu.RLock() - *v = tcpip.TCPTimeWaitReuseOption(p.timeWaitReuse) + *v = p.timeWaitReuse p.mu.RUnlock() return nil @@ -481,6 +481,6 @@ func NewProtocol(s *stack.Stack) stack.TransportProtocol { // TODO(gvisor.dev/issue/5243): Set recovery to tcpip.TCPRACKLossDetection. recovery: 0, } - p.dispatcher.init(runtime.GOMAXPROCS(0)) + p.dispatcher.init(s.Rand(), runtime.GOMAXPROCS(0)) return &p } diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go index 9e332dcf7..0da4eafaa 100644 --- a/pkg/tcpip/transport/tcp/rack.go +++ b/pkg/tcpip/transport/tcp/rack.go @@ -79,7 +79,7 @@ func (rc *rackControl) init(snd *sender, iss seqnum.Value) { // update will update the RACK related fields when an ACK has been received. // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-09#section-6.2 func (rc *rackControl) update(seg *segment, ackSeg *segment) { - rtt := time.Now().Sub(seg.xmitTime) + rtt := rc.snd.ep.stack.Clock().NowMonotonic().Sub(seg.xmitTime) tsOffset := rc.snd.ep.TSOffset // If the ACK is for a retransmitted packet, do not update if it is a @@ -115,7 +115,7 @@ func (rc *rackControl) update(seg *segment, ackSeg *segment) { // ending sequence number of the packet which has been acknowledged // most recently. endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - if rc.XmitTime.Before(seg.xmitTime) || (seg.xmitTime.Equal(rc.XmitTime) && rc.EndSequence.LessThan(endSeq)) { + if rc.XmitTime.Before(seg.xmitTime) || (seg.xmitTime == rc.XmitTime && rc.EndSequence.LessThan(endSeq)) { rc.XmitTime = seg.xmitTime rc.EndSequence = endSeq } @@ -174,7 +174,7 @@ func (s *sender) schedulePTO() { } s.rtt.Unlock() - now := time.Now() + now := s.ep.stack.Clock().NowMonotonic() if s.resendTimer.enabled() { if now.Add(pto).After(s.resendTimer.target) { pto = s.resendTimer.target.Sub(now) @@ -279,7 +279,7 @@ func (s *sender) detectTLPRecovery(ack seqnum.Value, rcvdSeg *segment) { // been observed RACK uses reo_wnd of zero during loss recovery, in order to // retransmit quickly, or when the number of DUPACKs exceeds the classic // DUPACKthreshold. -func (rc *rackControl) updateRACKReorderWindow(ackSeg *segment) { +func (rc *rackControl) updateRACKReorderWindow() { dsackSeen := rc.DSACKSeen snd := rc.snd @@ -352,7 +352,7 @@ func (rc *rackControl) exitRecovery() { // detectLoss marks the segment as lost if the reordering window has elapsed // and the ACK is not received. It will also arm the reorder timer. // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 Step 5. -func (rc *rackControl) detectLoss(rcvTime time.Time) int { +func (rc *rackControl) detectLoss(rcvTime tcpip.MonotonicTime) int { var timeout time.Duration numLost := 0 for seg := rc.snd.writeList.Front(); seg != nil && seg.xmitCount != 0; seg = seg.Next() { @@ -366,7 +366,7 @@ func (rc *rackControl) detectLoss(rcvTime time.Time) int { } endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - if seg.xmitTime.Before(rc.XmitTime) || (seg.xmitTime.Equal(rc.XmitTime) && rc.EndSequence.LessThan(endSeq)) { + if seg.xmitTime.Before(rc.XmitTime) || (seg.xmitTime == rc.XmitTime && rc.EndSequence.LessThan(endSeq)) { timeRemaining := seg.xmitTime.Sub(rcvTime) + rc.RTT + rc.ReoWnd if timeRemaining <= 0 { seg.lost = true @@ -392,7 +392,7 @@ func (rc *rackControl) reorderTimerExpired() tcpip.Error { return nil } - numLost := rc.detectLoss(time.Now()) + numLost := rc.detectLoss(rc.snd.ep.stack.Clock().NowMonotonic()) if numLost == 0 { return nil } diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index 133371455..661ca604a 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -17,7 +17,6 @@ package tcp import ( "container/heap" "math" - "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -50,7 +49,7 @@ type receiver struct { pendingRcvdSegments segmentHeap // Time when the last ack was received. - lastRcvdAckTime time.Time `state:".(unixTime)"` + lastRcvdAckTime tcpip.MonotonicTime } func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver { @@ -63,7 +62,7 @@ func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale }, rcvWnd: rcvWnd, rcvWUP: irs + 1, - lastRcvdAckTime: time.Now(), + lastRcvdAckTime: ep.stack.Clock().NowMonotonic(), } } @@ -137,9 +136,9 @@ func (r *receiver) getSendParams() (RcvNxt seqnum.Value, rcvWnd seqnum.Size) { // rcvWUP RcvNxt RcvAcc new RcvAcc // <=====curWnd ===> // <========= newWnd > curWnd ========= > - if r.RcvNxt.Add(seqnum.Size(curWnd)).LessThan(r.RcvNxt.Add(seqnum.Size(newWnd))) && toGrow { + if r.RcvNxt.Add(curWnd).LessThan(r.RcvNxt.Add(newWnd)) && toGrow { // If the new window moves the right edge, then update RcvAcc. - r.RcvAcc = r.RcvNxt.Add(seqnum.Size(newWnd)) + r.RcvAcc = r.RcvNxt.Add(newWnd) } else { if newWnd == 0 { // newWnd is zero but we can't advertise a zero as it would cause window @@ -245,7 +244,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum TrimSACKBlockList(&r.ep.sack, r.RcvNxt) // Handle FIN or FIN-ACK. - if s.flagIsSet(header.TCPFlagFin) { + if s.flags.Contains(header.TCPFlagFin) { r.RcvNxt++ // Send ACK immediately. @@ -261,7 +260,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum case StateEstablished: r.ep.setEndpointState(StateCloseWait) case StateFinWait1: - if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.SndNxt { + if s.flags.Contains(header.TCPFlagAck) && s.ackNumber == r.ep.snd.SndNxt { // FIN-ACK, transition to TIME-WAIT. r.ep.setEndpointState(StateTimeWait) } else { @@ -296,7 +295,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // Handle ACK (not FIN-ACK, which we handled above) during one of the // shutdown states. - if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.SndNxt { + if s.flags.Contains(header.TCPFlagAck) && s.ackNumber == r.ep.snd.SndNxt { switch r.ep.EndpointState() { case StateFinWait1: r.ep.setEndpointState(StateFinWait2) @@ -325,9 +324,9 @@ func (r *receiver) updateRTT() { // is first acknowledged and the receipt of data that is at least one // window beyond the sequence number that was acknowledged. r.ep.rcvQueueInfo.rcvQueueMu.Lock() - if r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime.IsZero() { + if r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime == (tcpip.MonotonicTime{}) { // New measurement. - r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime = time.Now() + r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime = r.ep.stack.Clock().NowMonotonic() r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureSeqNumber = r.RcvNxt.Add(r.rcvWnd) r.ep.rcvQueueInfo.rcvQueueMu.Unlock() return @@ -336,14 +335,14 @@ func (r *receiver) updateRTT() { r.ep.rcvQueueInfo.rcvQueueMu.Unlock() return } - rtt := time.Since(r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime) + rtt := r.ep.stack.Clock().NowMonotonic().Sub(r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime) // We only store the minimum observed RTT here as this is only used in // absence of a SRTT available from either timestamps or a sender // measurement of RTT. if r.ep.rcvQueueInfo.RcvAutoParams.RTT == 0 || rtt < r.ep.rcvQueueInfo.RcvAutoParams.RTT { r.ep.rcvQueueInfo.RcvAutoParams.RTT = rtt } - r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime = time.Now() + r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureTime = r.ep.stack.Clock().NowMonotonic() r.ep.rcvQueueInfo.RcvAutoParams.RTTMeasureSeqNumber = r.RcvNxt.Add(r.rcvWnd) r.ep.rcvQueueInfo.rcvQueueMu.Unlock() } @@ -424,7 +423,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // while the FIN is considered to occur after // the last actual data octet in a segment in // which it occurs. - if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.RcvNxt+1) { + if closed && (!s.flags.Contains(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.RcvNxt+1) { return true, &tcpip.ErrConnectionAborted{} } } @@ -467,11 +466,11 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err tcpip.Error) { } // Store the time of the last ack. - r.lastRcvdAckTime = time.Now() + r.lastRcvdAckTime = r.ep.stack.Clock().NowMonotonic() // Defer segment processing if it can't be consumed now. if !r.consumeSegment(s, segSeq, segLen) { - if segLen > 0 || s.flagIsSet(header.TCPFlagFin) { + if segLen > 0 || s.flags.Contains(header.TCPFlagFin) { // We only store the segment if it's within our buffer size limit. // // Only use 75% of the receive buffer queue for out-of-order @@ -539,7 +538,7 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn // // As we do not yet support PAWS, we are being conservative in ignoring // RSTs by default. - if s.flagIsSet(header.TCPFlagRst) { + if s.flags.Contains(header.TCPFlagRst) { return false, false } @@ -559,13 +558,13 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn // (2) returns to TIME-WAIT state if the SYN turns out // to be an old duplicate". - if s.flagIsSet(header.TCPFlagSyn) && r.RcvNxt.LessThan(segSeq) { + if s.flags.Contains(header.TCPFlagSyn) && r.RcvNxt.LessThan(segSeq) { return false, true } // Drop the segment if it does not contain an ACK. - if !s.flagIsSet(header.TCPFlagAck) { + if !s.flags.Contains(header.TCPFlagAck) { return false, false } @@ -574,7 +573,7 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn r.ep.updateRecentTimestamp(s.parsedOptions.TSVal, r.ep.snd.MaxSentAck, segSeq) } - if segSeq.Add(1) == r.RcvNxt && s.flagIsSet(header.TCPFlagFin) { + if segSeq.Add(1) == r.RcvNxt && s.flags.Contains(header.TCPFlagFin) { // If it's a FIN-ACK then resetTimeWait and send an ACK, as it // indicates our final ACK could have been lost. r.ep.snd.sendAck() diff --git a/pkg/tcpip/transport/tcp/rcv_state.go b/pkg/tcpip/transport/tcp/rcv_state.go deleted file mode 100644 index 2bf21a2e7..000000000 --- a/pkg/tcpip/transport/tcp/rcv_state.go +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp - -import ( - "time" -) - -// saveLastRcvdAckTime is invoked by stateify. -func (r *receiver) saveLastRcvdAckTime() unixTime { - return unixTime{r.lastRcvdAckTime.Unix(), r.lastRcvdAckTime.UnixNano()} -} - -// loadLastRcvdAckTime is invoked by stateify. -func (r *receiver) loadLastRcvdAckTime(unix unixTime) { - r.lastRcvdAckTime = time.Unix(unix.second, unix.nano) -} diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 7e5ba6ef7..ca78c96f2 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -17,7 +17,6 @@ package tcp import ( "fmt" "sync/atomic" - "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -73,9 +72,9 @@ type segment struct { parsedOptions header.TCPOptions options []byte `state:".([]byte)"` hasNewSACKInfo bool - rcvdTime time.Time `state:".(unixTime)"` + rcvdTime tcpip.MonotonicTime // xmitTime is the last transmit time of this segment. - xmitTime time.Time `state:".(unixTime)"` + xmitTime tcpip.MonotonicTime xmitCount uint32 // acked indicates if the segment has already been SACKed. @@ -88,7 +87,7 @@ type segment struct { lost bool } -func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment { +func newIncomingSegment(id stack.TransportEndpointID, clock tcpip.Clock, pkt *stack.PacketBuffer) *segment { netHdr := pkt.Network() s := &segment{ refCnt: 1, @@ -100,17 +99,17 @@ func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) * } s.data = pkt.Data().ExtractVV().Clone(s.views[:]) s.hdr = header.TCP(pkt.TransportHeader().View()) - s.rcvdTime = time.Now() + s.rcvdTime = clock.NowMonotonic() s.dataMemSize = s.data.Size() return s } -func newOutgoingSegment(id stack.TransportEndpointID, v buffer.View) *segment { +func newOutgoingSegment(id stack.TransportEndpointID, clock tcpip.Clock, v buffer.View) *segment { s := &segment{ refCnt: 1, id: id, } - s.rcvdTime = time.Now() + s.rcvdTime = clock.NowMonotonic() if len(v) != 0 { s.views[0] = v s.data = buffer.NewVectorisedView(len(v), s.views[:1]) @@ -149,16 +148,6 @@ func (s *segment) merge(oth *segment) { oth.dataMemSize = oth.data.Size() } -// flagIsSet checks if at least one flag in flags is set in s.flags. -func (s *segment) flagIsSet(flags header.TCPFlags) bool { - return s.flags&flags != 0 -} - -// flagsAreSet checks if all flags in flags are set in s.flags. -func (s *segment) flagsAreSet(flags header.TCPFlags) bool { - return s.flags&flags == flags -} - // setOwner sets the owning endpoint for this segment. Its required // to be called to ensure memory accounting for receive/send buffer // queues is done properly. @@ -198,10 +187,10 @@ func (s *segment) incRef() { // as the data length plus one for each of the SYN and FIN bits set. func (s *segment) logicalLen() seqnum.Size { l := seqnum.Size(s.data.Size()) - if s.flagIsSet(header.TCPFlagSyn) { + if s.flags.Contains(header.TCPFlagSyn) { l++ } - if s.flagIsSet(header.TCPFlagFin) { + if s.flags.Contains(header.TCPFlagFin) { l++ } return l @@ -243,7 +232,7 @@ func (s *segment) parse(skipChecksumValidation bool) bool { return false } - s.options = []byte(s.hdr[header.TCPMinimumSize:]) + s.options = s.hdr[header.TCPMinimumSize:] s.parsedOptions = header.ParseTCPOptions(s.options) if skipChecksumValidation { s.csumValid = true @@ -262,5 +251,5 @@ func (s *segment) parse(skipChecksumValidation bool) bool { // sackBlock returns a header.SACKBlock that represents this segment. func (s *segment) sackBlock() header.SACKBlock { - return header.SACKBlock{s.sequenceNumber, s.sequenceNumber.Add(s.logicalLen())} + return header.SACKBlock{Start: s.sequenceNumber, End: s.sequenceNumber.Add(s.logicalLen())} } diff --git a/pkg/tcpip/transport/tcp/segment_state.go b/pkg/tcpip/transport/tcp/segment_state.go index 7422d8c02..dcfa80f95 100644 --- a/pkg/tcpip/transport/tcp/segment_state.go +++ b/pkg/tcpip/transport/tcp/segment_state.go @@ -15,8 +15,6 @@ package tcp import ( - "time" - "gvisor.dev/gvisor/pkg/tcpip/buffer" ) @@ -55,23 +53,3 @@ func (s *segment) loadOptions(options []byte) { // allocated so there is no cost here. s.options = options } - -// saveRcvdTime is invoked by stateify. -func (s *segment) saveRcvdTime() unixTime { - return unixTime{s.rcvdTime.Unix(), s.rcvdTime.UnixNano()} -} - -// loadRcvdTime is invoked by stateify. -func (s *segment) loadRcvdTime(unix unixTime) { - s.rcvdTime = time.Unix(unix.second, unix.nano) -} - -// saveXmitTime is invoked by stateify. -func (s *segment) saveXmitTime() unixTime { - return unixTime{s.rcvdTime.Unix(), s.rcvdTime.UnixNano()} -} - -// loadXmitTime is invoked by stateify. -func (s *segment) loadXmitTime(unix unixTime) { - s.rcvdTime = time.Unix(unix.second, unix.nano) -} diff --git a/pkg/tcpip/transport/tcp/segment_test.go b/pkg/tcpip/transport/tcp/segment_test.go index 486016fc0..2e6ea06f5 100644 --- a/pkg/tcpip/transport/tcp/segment_test.go +++ b/pkg/tcpip/transport/tcp/segment_test.go @@ -19,6 +19,7 @@ import ( "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -39,10 +40,11 @@ func checkSegmentSize(t *testing.T, name string, seg *segment, want segmentSizeW } func TestSegmentMerge(t *testing.T) { + var clock faketime.NullClock id := stack.TransportEndpointID{} - seg1 := newOutgoingSegment(id, buffer.NewView(10)) + seg1 := newOutgoingSegment(id, &clock, buffer.NewView(10)) defer seg1.decRef() - seg2 := newOutgoingSegment(id, buffer.NewView(20)) + seg2 := newOutgoingSegment(id, &clock, buffer.NewView(20)) defer seg2.decRef() checkSegmentSize(t, "seg1", seg1, segmentSizeWants{ diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index f43e86677..72d58dcff 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -94,7 +94,7 @@ type sender struct { // firstRetransmittedSegXmitTime is the original transmit time of // the first segment that was retransmitted due to RTO expiration. - firstRetransmittedSegXmitTime time.Time `state:".(unixTime)"` + firstRetransmittedSegXmitTime tcpip.MonotonicTime // zeroWindowProbing is set if the sender is currently probing // for zero receive window. @@ -169,7 +169,7 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint SndUna: iss + 1, SndNxt: iss + 1, RTTMeasureSeqNum: iss + 1, - LastSendTime: time.Now(), + LastSendTime: ep.stack.Clock().NowMonotonic(), MaxPayloadSize: maxPayloadSize, MaxSentAck: irs + 1, FastRecovery: stack.TCPFastRecoveryState{ @@ -197,9 +197,9 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint s.SndWndScale = uint8(sndWndScale) } - s.resendTimer.init(&s.resendWaker) - s.reorderTimer.init(&s.reorderWaker) - s.probeTimer.init(&s.probeWaker) + s.resendTimer.init(s.ep.stack.Clock(), &s.resendWaker) + s.reorderTimer.init(s.ep.stack.Clock(), &s.reorderWaker) + s.probeTimer.init(s.ep.stack.Clock(), &s.probeWaker) s.updateMaxPayloadSize(int(ep.route.MTU()), 0) @@ -441,7 +441,7 @@ func (s *sender) retransmitTimerExpired() bool { // timeout since the first retransmission. uto := s.ep.userTimeout - if s.firstRetransmittedSegXmitTime.IsZero() { + if s.firstRetransmittedSegXmitTime == (tcpip.MonotonicTime{}) { // We store the original xmitTime of the segment that we are // about to retransmit as the retransmission time. This is // required as by the time the retransmitTimer has expired the @@ -450,7 +450,7 @@ func (s *sender) retransmitTimerExpired() bool { s.firstRetransmittedSegXmitTime = s.writeList.Front().xmitTime } - elapsed := time.Since(s.firstRetransmittedSegXmitTime) + elapsed := s.ep.stack.Clock().NowMonotonic().Sub(s.firstRetransmittedSegXmitTime) remaining := s.maxRTO if uto != 0 { // Cap to the user specified timeout if one is specified. @@ -616,7 +616,7 @@ func (s *sender) NextSeg(nextSegHint *segment) (nextSeg, hint *segment, rescueRt // 'S2' that meets the following 3 criteria for determinig // loss, the sequence range of one segment of up to SMSS // octects starting with S2 MUST be returned. - if !s.ep.scoreboard.IsSACKED(header.SACKBlock{segSeq, segSeq.Add(1)}) { + if !s.ep.scoreboard.IsSACKED(header.SACKBlock{Start: segSeq, End: segSeq.Add(1)}) { // NextSeg(): // // (1.a) S2 is greater than HighRxt @@ -866,8 +866,8 @@ func (s *sender) enableZeroWindowProbing() { // We piggyback the probing on the retransmit timer with the // current retranmission interval, as we may start probing while // segment retransmissions. - if s.firstRetransmittedSegXmitTime.IsZero() { - s.firstRetransmittedSegXmitTime = time.Now() + if s.firstRetransmittedSegXmitTime == (tcpip.MonotonicTime{}) { + s.firstRetransmittedSegXmitTime = s.ep.stack.Clock().NowMonotonic() } s.resendTimer.enable(s.RTO) } @@ -875,7 +875,7 @@ func (s *sender) enableZeroWindowProbing() { func (s *sender) disableZeroWindowProbing() { s.zeroWindowProbing = false s.unackZeroWindowProbes = 0 - s.firstRetransmittedSegXmitTime = time.Time{} + s.firstRetransmittedSegXmitTime = tcpip.MonotonicTime{} s.resendTimer.disable() } @@ -925,7 +925,7 @@ func (s *sender) sendData() { // "A TCP SHOULD set cwnd to no more than RW before beginning // transmission if the TCP has not sent data in the interval exceeding // the retrasmission timeout." - if !s.FastRecovery.Active && s.state != tcpip.RTORecovery && time.Now().Sub(s.LastSendTime) > s.RTO { + if !s.FastRecovery.Active && s.state != tcpip.RTORecovery && s.ep.stack.Clock().NowMonotonic().Sub(s.LastSendTime) > s.RTO { if s.SndCwnd > InitialCwnd { s.SndCwnd = InitialCwnd } @@ -1024,7 +1024,7 @@ func (s *sender) SetPipe() { if segEnd.LessThan(endSeq) { endSeq = segEnd } - sb := header.SACKBlock{startSeq, endSeq} + sb := header.SACKBlock{Start: startSeq, End: endSeq} // SetPipe(): // // After initializing pipe to zero, the following steps are @@ -1132,7 +1132,7 @@ func (s *sender) isDupAck(seg *segment) bool { // (b) The incoming acknowledgment carries no data. seg.logicalLen() == 0 && // (c) The SYN and FIN bits are both off. - !seg.flagIsSet(header.TCPFlagFin) && !seg.flagIsSet(header.TCPFlagSyn) && + !seg.flags.Intersects(header.TCPFlagFin|header.TCPFlagSyn) && // (d) the ACK number is equal to the greatest acknowledgment received on // the given connection (TCP.UNA from RFC793). seg.ackNumber == s.SndUna && @@ -1234,7 +1234,7 @@ func checkDSACK(rcvdSeg *segment) bool { func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // Check if we can extract an RTT measurement from this ack. if !rcvdSeg.parsedOptions.TS && s.RTTMeasureSeqNum.LessThan(rcvdSeg.ackNumber) { - s.updateRTO(time.Now().Sub(s.RTTMeasureTime)) + s.updateRTO(s.ep.stack.Clock().NowMonotonic().Sub(s.RTTMeasureTime)) s.RTTMeasureSeqNum = s.SndNxt } @@ -1444,7 +1444,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { if s.SndUna == s.SndNxt { s.Outstanding = 0 // Reset firstRetransmittedSegXmitTime to the zero value. - s.firstRetransmittedSegXmitTime = time.Time{} + s.firstRetransmittedSegXmitTime = tcpip.MonotonicTime{} s.resendTimer.disable() s.probeTimer.disable() } @@ -1455,7 +1455,7 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 // * Upon receiving an ACK: // * Step 4: Update RACK reordering window - s.rc.updateRACKReorderWindow(rcvdSeg) + s.rc.updateRACKReorderWindow() // After the reorder window is calculated, detect any loss by checking // if the time elapsed after the segments are sent is greater than the @@ -1502,7 +1502,7 @@ func (s *sender) sendSegment(seg *segment) tcpip.Error { s.ep.stack.Stats().TCP.SlowStartRetransmits.Increment() } } - seg.xmitTime = time.Now() + seg.xmitTime = s.ep.stack.Clock().NowMonotonic() seg.xmitCount++ seg.lost = false err := s.sendSegmentFromView(seg.data, seg.flags, seg.sequenceNumber) @@ -1527,7 +1527,7 @@ func (s *sender) sendSegment(seg *segment) tcpip.Error { // sendSegmentFromView sends a new segment containing the given payload, flags // and sequence number. func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags header.TCPFlags, seq seqnum.Value) tcpip.Error { - s.LastSendTime = time.Now() + s.LastSendTime = s.ep.stack.Clock().NowMonotonic() if seq == s.RTTMeasureSeqNum { s.RTTMeasureTime = s.LastSendTime } diff --git a/pkg/tcpip/transport/tcp/snd_state.go b/pkg/tcpip/transport/tcp/snd_state.go deleted file mode 100644 index 2f805d8ce..000000000 --- a/pkg/tcpip/transport/tcp/snd_state.go +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp - -import ( - "time" -) - -// +stateify savable -type unixTime struct { - second int64 - nano int64 -} - -// afterLoad is invoked by stateify. -func (s *sender) afterLoad() { - s.resendTimer.init(&s.resendWaker) - s.reorderTimer.init(&s.reorderWaker) - s.probeTimer.init(&s.probeWaker) -} - -// saveFirstRetransmittedSegXmitTime is invoked by stateify. -func (s *sender) saveFirstRetransmittedSegXmitTime() unixTime { - return unixTime{s.firstRetransmittedSegXmitTime.Unix(), s.firstRetransmittedSegXmitTime.UnixNano()} -} - -// loadFirstRetransmittedSegXmitTime is invoked by stateify. -func (s *sender) loadFirstRetransmittedSegXmitTime(unix unixTime) { - s.firstRetransmittedSegXmitTime = time.Unix(unix.second, unix.nano) -} diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go index c58361bc1..d6cf786a1 100644 --- a/pkg/tcpip/transport/tcp/tcp_rack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go @@ -38,7 +38,7 @@ const ( func setStackRACKPermitted(t *testing.T, c *context.Context) { t.Helper() - opt := tcpip.TCPRecovery(tcpip.TCPRACKLossDetection) + opt := tcpip.TCPRACKLossDetection if err := c.Stack().SetTransportProtocolOption(header.TCPProtocolNumber, &opt); err != nil { t.Fatalf("c.s.SetTransportProtocolOption(%d, &%v(%v)): %s", header.TCPProtocolNumber, opt, opt, err) } @@ -50,7 +50,7 @@ func TestRACKUpdate(t *testing.T) { c := context.New(t, uint32(mtu)) defer c.Cleanup() - var xmitTime time.Time + var xmitTime tcpip.MonotonicTime probeDone := make(chan struct{}) c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { // Validate that the endpoint Sender.RACKState is what we expect. @@ -79,7 +79,7 @@ func TestRACKUpdate(t *testing.T) { } // Write the data. - xmitTime = time.Now() + xmitTime = c.Stack().Clock().NowMonotonic() var r bytes.Reader r.Reset(data) if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 9916182e3..9bbe9bc3e 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -4259,7 +4259,7 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { SrcPort: context.TestPort, DstPort: c.Port, Flags: header.TCPFlagAck, - SeqNum: seqnum.Value(iss), + SeqNum: iss, AckNum: c.IRS.Add(1), RcvWnd: 30000, }) @@ -5335,7 +5335,7 @@ func TestKeepalive(t *testing.T) { checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(next-1)), + checker.TCPSeqNum(next-1), checker.TCPAckNum(uint32(iss)), checker.TCPFlags(header.TCPFlagAck), ), @@ -5360,12 +5360,7 @@ func TestKeepalive(t *testing.T) { }) checker.IPv4(t, c.GetPacket(), - checker.TCP( - checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(next)), - checker.TCPAckNum(uint32(0)), - checker.TCPFlags(header.TCPFlagRst), - ), + checker.TCP(checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(uint32(0)), checker.TCPFlags(header.TCPFlagRst)), ) if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 { @@ -5507,7 +5502,7 @@ func TestListenBacklogFull(t *testing.T) { // Now execute send one more SYN. The stack should not respond as the backlog // is full at this point. c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort + uint16(lastPortOffset), + SrcPort: context.TestPort + lastPortOffset, DstPort: context.StackPort, Flags: header.TCPFlagSyn, SeqNum: seqnum.Value(context.TestInitialSequenceNumber), @@ -5884,7 +5879,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) { r.Reset(data) newEP.Write(&r, tcpip.WriteOptions{}) pkt := c.GetPacket() - tcp = header.TCP(header.IPv4(pkt).Payload()) + tcp = header.IPv4(pkt).Payload() if string(tcp.Payload()) != data { t.Fatalf("unexpected data: got %s, want %s", string(tcp.Payload()), data) } @@ -6118,7 +6113,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { } pkt := c.GetPacket() - tcpHdr = header.TCP(header.IPv4(pkt).Payload()) + tcpHdr = header.IPv4(pkt).Payload() if string(tcpHdr.Payload()) != data { t.Fatalf("unexpected data: got %s, want %s", string(tcpHdr.Payload()), data) } @@ -6243,6 +6238,54 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { } } +func TestListenDropIncrement(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + stats := c.Stack().Stats() + c.Create(-1 /*epRcvBuf*/) + + if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + if err := c.EP.Listen(1 /*backlog*/); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + initialDropped := stats.DroppedPackets.Value() + + // Send RST, FIN segments, that are expected to be dropped by the listener. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagRst, + }) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagFin, + }) + + // To ensure that the RST, FIN sent earlier are indeed received and ignored + // by the listener, send a SYN and wait for the SYN to be ACKd. + irs := seqnum.Value(context.TestInitialSequenceNumber) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + }) + checker.IPv4(t, c.GetPacket(), checker.TCP(checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), + checker.TCPAckNum(uint32(irs)+1), + )) + + if got, want := stats.DroppedPackets.Value(), initialDropped+2; got != want { + t.Fatalf("got stats.DroppedPackets.Value() = %d, want = %d", got, want) + } +} + func TestEndpointBindListenAcceptState(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -6375,7 +6418,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // Allocate a large enough payload for the test. payloadSize := receiveBufferSize * 2 - b := make([]byte, int(payloadSize)) + b := make([]byte, payloadSize) worker := (c.EP).(interface { StopWork() @@ -6429,7 +6472,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // ack, 1 for the non-zero window p := c.GetPacket() checker.IPv4(t, p, checker.TCP( - checker.TCPAckNum(uint32(wantAckNum)), + checker.TCPAckNum(wantAckNum), func(t *testing.T, h header.Transport) { tcp, ok := h.(header.TCP) if !ok { @@ -6484,14 +6527,14 @@ func TestReceiveBufferAutoTuning(t *testing.T) { c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize)) rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4}) - tsVal := uint32(rawEP.TSVal) + tsVal := rawEP.TSVal rawEP.NextSeqNum-- rawEP.SendPacketWithTS(nil, tsVal) rawEP.NextSeqNum++ pkt := rawEP.VerifyAndReturnACKWithTS(tsVal) curRcvWnd := int(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.WindowScale scaleRcvWnd := func(rcvWnd int) uint16 { - return uint16(rcvWnd >> uint16(c.WindowScale)) + return uint16(rcvWnd >> c.WindowScale) } // Allocate a large array to send to the endpoint. b := make([]byte, receiveBufferSize*48) @@ -6619,19 +6662,16 @@ func TestDelayEnabled(t *testing.T) { defer c.Cleanup() checkDelayOption(t, c, false, false) // Delay is disabled by default. - for _, v := range []struct { - delayEnabled tcpip.TCPDelayEnabled - wantDelayOption bool - }{ - {delayEnabled: false, wantDelayOption: false}, - {delayEnabled: true, wantDelayOption: true}, - } { - c := context.New(t, defaultMTU) - defer c.Cleanup() - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &v.delayEnabled); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, v.delayEnabled, v.delayEnabled, err) - } - checkDelayOption(t, c, v.delayEnabled, v.wantDelayOption) + for _, delayEnabled := range []bool{false, true} { + t.Run(fmt.Sprintf("delayEnabled=%t", delayEnabled), func(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + opt := tcpip.TCPDelayEnabled(delayEnabled) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, delayEnabled, err) + } + checkDelayOption(t, c, opt, delayEnabled) + }) } } @@ -7042,7 +7082,7 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { // Receive the SYN-ACK reply. b = c.GetPacket() - tcpHdr = header.TCP(header.IPv4(b).Payload()) + tcpHdr = header.IPv4(b).Payload() c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) ackHeaders = &context.Headers{ @@ -7467,7 +7507,7 @@ func TestTCPUserTimeout(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.TCPSeqNum(uint32(next)), + checker.TCPSeqNum(next), checker.TCPAckNum(uint32(0)), checker.TCPFlags(header.TCPFlagRst), ), @@ -7545,7 +7585,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { DstPort: c.Port, Flags: header.TCPFlagAck, SeqNum: iss, - AckNum: seqnum.Value(c.IRS + 1), + AckNum: c.IRS + 1, RcvWnd: 30000, }) diff --git a/pkg/tcpip/transport/tcp/timer.go b/pkg/tcpip/transport/tcp/timer.go index 38a335840..5645c772e 100644 --- a/pkg/tcpip/transport/tcp/timer.go +++ b/pkg/tcpip/transport/tcp/timer.go @@ -15,21 +15,29 @@ package tcp import ( + "math" "time" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/tcpip" ) type timerState int const ( + // The timer is disabled. timerStateDisabled timerState = iota + // The timer is enabled, but the clock timer may be set to an earlier + // expiration time due to a previous orphaned state. timerStateEnabled + // The timer is disabled, but the clock timer is enabled, which means that + // it will cause a spurious wakeup unless the timer is enabled before the + // clock timer fires. timerStateOrphaned ) // timer is a timer implementation that reduces the interactions with the -// runtime timer infrastructure by letting timers run (and potentially +// clock timer infrastructure by letting timers run (and potentially // eventually expire) even if they are stopped. It makes it cheaper to // disable/reenable timers at the expense of spurious wakes. This is useful for // cases when the same timer is disabled/reenabled repeatedly with relatively @@ -39,44 +47,37 @@ const ( // (currently at least 200ms), and get disabled when acks are received, and // reenabled when new pending segments are sent. // -// It is advantageous to avoid interacting with the runtime because it acquires +// It is advantageous to avoid interacting with the clock because it acquires // a global mutex and performs O(log n) operations, where n is the global number // of timers, whenever a timer is enabled or disabled, and may make a syscall. // // This struct is thread-compatible. type timer struct { - // state is the current state of the timer, it can be one of the - // following values: - // disabled - the timer is disabled. - // orphaned - the timer is disabled, but the runtime timer is - // enabled, which means that it will evetually cause a - // spurious wake (unless it gets enabled again before - // then). - // enabled - the timer is enabled, but the runtime timer may be set - // to an earlier expiration time due to a previous - // orphaned state. state timerState + clock tcpip.Clock + // target is the expiration time of the current timer. It is only // meaningful in the enabled state. - target time.Time + target tcpip.MonotonicTime - // runtimeTarget is the expiration time of the runtime timer. It is + // clockTarget is the expiration time of the clock timer. It is // meaningful in the enabled and orphaned states. - runtimeTarget time.Time + clockTarget tcpip.MonotonicTime - // timer is the runtime timer used to wait on. - timer *time.Timer + // timer is the clock timer used to wait on. + timer tcpip.Timer } // init initializes the timer. Once it expires, it the given waker will be // asserted. -func (t *timer) init(w *sleep.Waker) { +func (t *timer) init(clock tcpip.Clock, w *sleep.Waker) { t.state = timerStateDisabled + t.clock = clock - // Initialize a runtime timer that will assert the waker, then + // Initialize a clock timer that will assert the waker, then // immediately stop it. - t.timer = time.AfterFunc(time.Hour, func() { + t.timer = t.clock.AfterFunc(math.MaxInt64, func() { w.Assert() }) t.timer.Stop() @@ -106,9 +107,9 @@ func (t *timer) checkExpiration() bool { // The timer is enabled, but it may have expired early. Check if that's // the case, and if so, reset the runtime timer to the correct time. - now := time.Now() + now := t.clock.NowMonotonic() if now.Before(t.target) { - t.runtimeTarget = t.target + t.clockTarget = t.target t.timer.Reset(t.target.Sub(now)) return false } @@ -134,11 +135,11 @@ func (t *timer) enabled() bool { // enable enables the timer, programming the runtime timer if necessary. func (t *timer) enable(d time.Duration) { - t.target = time.Now().Add(d) + t.target = t.clock.NowMonotonic().Add(d) // Check if we need to set the runtime timer. - if t.state == timerStateDisabled || t.target.Before(t.runtimeTarget) { - t.runtimeTarget = t.target + if t.state == timerStateDisabled || t.target.Before(t.clockTarget) { + t.clockTarget = t.target t.timer.Reset(d) } diff --git a/pkg/tcpip/transport/tcp/timer_test.go b/pkg/tcpip/transport/tcp/timer_test.go index dbd6dff54..479752de7 100644 --- a/pkg/tcpip/transport/tcp/timer_test.go +++ b/pkg/tcpip/transport/tcp/timer_test.go @@ -19,6 +19,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/tcpip/faketime" ) func TestCleanup(t *testing.T) { @@ -27,9 +28,11 @@ func TestCleanup(t *testing.T) { isAssertedTimeoutSeconds = timerDurationSeconds + 1 ) + clock := faketime.NewManualClock() + tmr := timer{} w := sleep.Waker{} - tmr.init(&w) + tmr.init(clock, &w) tmr.enable(timerDurationSeconds * time.Second) tmr.cleanup() @@ -39,7 +42,7 @@ func TestCleanup(t *testing.T) { // The waker should not be asserted. for i := 0; i < isAssertedTimeoutSeconds; i++ { - time.Sleep(time.Second) + clock.Advance(time.Second) if w.IsAsserted() { t.Fatalf("waker asserted unexpectedly") } diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index dd5c910ae..cdc344ab7 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -49,6 +49,7 @@ go_test( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index f7dd50d35..def9d7186 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -17,6 +17,7 @@ package udp import ( "io" "sync/atomic" + "time" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -34,25 +35,26 @@ type udpPacket struct { destinationAddress tcpip.FullAddress packetInfo tcpip.IPPacketInfo data buffer.VectorisedView `state:".(buffer.VectorisedView)"` - timestamp int64 + receivedAt time.Time `state:".(int64)"` // tos stores either the receiveTOS or receiveTClass value. tos uint8 } // EndpointState represents the state of a UDP endpoint. -type EndpointState uint32 +type EndpointState tcpip.EndpointState // Endpoint states. Note that are represented in a netstack-specific manner and // may not be meaningful externally. Specifically, they need to be translated to // Linux's representation for these states if presented to userspace. const ( - StateInitial EndpointState = iota + _ EndpointState = iota + StateInitial StateBound StateConnected StateClosed ) -// String implements fmt.Stringer.String. +// String implements fmt.Stringer. func (s EndpointState) String() string { switch s { case StateInitial: @@ -98,7 +100,7 @@ type endpoint struct { mu sync.RWMutex `state:"nosave"` // state must be read/set using the EndpointState()/setEndpointState() // methods. - state EndpointState + state uint32 route *stack.Route `state:"manual"` dstPort uint16 ttl uint8 @@ -176,7 +178,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue // Linux defaults to TTL=1. multicastTTL: 1, multicastMemberships: make(map[multicastMembership]struct{}), - state: StateInitial, + state: uint32(StateInitial), uniqueID: s.UniqueID(), } e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) @@ -204,15 +206,15 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue // // Precondition: e.mu must be held to call this method. func (e *endpoint) setEndpointState(state EndpointState) { - atomic.StoreUint32((*uint32)(&e.state), uint32(state)) + atomic.StoreUint32(&e.state, uint32(state)) } // EndpointState() returns the current state of the endpoint. func (e *endpoint) EndpointState() EndpointState { - return EndpointState(atomic.LoadUint32((*uint32)(&e.state))) + return EndpointState(atomic.LoadUint32(&e.state)) } -// UniqueID implements stack.TransportEndpoint.UniqueID. +// UniqueID implements stack.TransportEndpoint. func (e *endpoint) UniqueID() uint64 { return e.uniqueID } @@ -226,14 +228,14 @@ func (e *endpoint) LastError() tcpip.Error { return err } -// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError. +// UpdateLastError implements tcpip.SocketOptionsHandler. func (e *endpoint) UpdateLastError(err tcpip.Error) { e.lastErrorMu.Lock() e.lastError = err e.lastErrorMu.Unlock() } -// Abort implements stack.TransportEndpoint.Abort. +// Abort implements stack.TransportEndpoint. func (e *endpoint) Abort() { e.Close() } @@ -289,10 +291,10 @@ func (e *endpoint) Close() { e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents) } -// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. -func (e *endpoint) ModerateRecvBuf(copied int) {} +// ModerateRecvBuf implements tcpip.Endpoint. +func (*endpoint) ModerateRecvBuf(int) {} -// Read implements tcpip.Endpoint.Read. +// Read implements tcpip.Endpoint. func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) { if err := e.LastError(); err != nil { return tcpip.ReadResult{}, err @@ -320,7 +322,7 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult // Control Messages cm := tcpip.ControlMessages{ HasTimestamp: true, - Timestamp: p.timestamp, + Timestamp: p.receivedAt.UnixNano(), } if e.ops.GetReceiveTOS() { cm.HasTOS = true @@ -581,21 +583,21 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp return int64(len(v)), nil } -// OnReuseAddressSet implements tcpip.SocketOptionsHandler.OnReuseAddressSet. +// OnReuseAddressSet implements tcpip.SocketOptionsHandler. func (e *endpoint) OnReuseAddressSet(v bool) { e.mu.Lock() e.portFlags.MostRecent = v e.mu.Unlock() } -// OnReusePortSet implements tcpip.SocketOptionsHandler.OnReusePortSet. +// OnReusePortSet implements tcpip.SocketOptionsHandler. func (e *endpoint) OnReusePortSet(v bool) { e.mu.Lock() e.portFlags.LoadBalanced = v e.mu.Unlock() } -// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. +// SetSockOptInt implements tcpip.Endpoint. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { switch opt { case tcpip.MTUDiscoverOption: @@ -629,11 +631,14 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { return nil } +var _ tcpip.SocketOptionsHandler = (*endpoint)(nil) + +// HasNIC implements tcpip.SocketOptionsHandler. func (e *endpoint) HasNIC(id int32) bool { - return id == 0 || e.stack.HasNIC(tcpip.NICID(id)) + return e.stack.HasNIC(tcpip.NICID(id)) } -// SetSockOpt implements tcpip.Endpoint.SetSockOpt. +// SetSockOpt implements tcpip.Endpoint. func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { switch v := opt.(type) { case *tcpip.MulticastInterfaceOption: @@ -749,7 +754,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { return nil } -// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. +// GetSockOptInt implements tcpip.Endpoint. func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { switch opt { case tcpip.IPv4TOSOption: @@ -795,14 +800,14 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { } } -// GetSockOpt implements tcpip.Endpoint.GetSockOpt. +// GetSockOpt implements tcpip.Endpoint. func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { switch o := opt.(type) { case *tcpip.MulticastInterfaceOption: e.mu.Lock() *o = tcpip.MulticastInterfaceOption{ - e.multicastNICID, - e.multicastAddr, + NIC: e.multicastNICID, + InterfaceAddr: e.multicastAddr, } e.mu.Unlock() @@ -872,7 +877,7 @@ func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddres return unwrapped, netProto, nil } -// Disconnect implements tcpip.Endpoint.Disconnect. +// Disconnect implements tcpip.Endpoint. func (e *endpoint) Disconnect() tcpip.Error { e.mu.Lock() defer e.mu.Unlock() @@ -1075,7 +1080,7 @@ func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id BindToDevice: bindToDevice, Dest: tcpip.FullAddress{}, } - port, err := e.stack.ReservePort(portRes, nil /* testPort */) + port, err := e.stack.ReservePort(e.stack.Rand(), portRes, nil /* testPort */) if err != nil { return id, bindToDevice, err } @@ -1301,12 +1306,12 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB senderAddress: tcpip.FullAddress{ NIC: pkt.NICID, Addr: id.RemoteAddress, - Port: header.UDP(hdr).SourcePort(), + Port: hdr.SourcePort(), }, destinationAddress: tcpip.FullAddress{ NIC: pkt.NICID, Addr: id.LocalAddress, - Port: header.UDP(hdr).DestinationPort(), + Port: hdr.DestinationPort(), }, data: pkt.Data().ExtractVV(), } @@ -1328,7 +1333,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB packet.packetInfo.LocalAddr = localAddr packet.packetInfo.DestinationAddr = localAddr packet.packetInfo.NIC = pkt.NICID - packet.timestamp = e.stack.Clock().NowNanoseconds() + packet.receivedAt = e.stack.Clock().Now() e.rcvMu.Unlock() @@ -1386,7 +1391,7 @@ func (e *endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketB } } -// State implements tcpip.Endpoint.State. +// State implements tcpip.Endpoint. func (e *endpoint) State() uint32 { return uint32(e.EndpointState()) } @@ -1405,19 +1410,19 @@ func (e *endpoint) Stats() tcpip.EndpointStats { return &e.stats } -// Wait implements tcpip.Endpoint.Wait. +// Wait implements tcpip.Endpoint. func (*endpoint) Wait() {} func (e *endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, addr tcpip.Address) bool { return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) || e.stack.IsSubnetBroadcast(nicID, netProto, addr) } -// SetOwner implements tcpip.Endpoint.SetOwner. +// SetOwner implements tcpip.Endpoint. func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } -// SocketOptions implements tcpip.Endpoint.SocketOptions. +// SocketOptions implements tcpip.Endpoint. func (e *endpoint) SocketOptions() *tcpip.SocketOptions { return &e.ops } diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 4aba68b21..1f638c3f6 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -15,26 +15,38 @@ package udp import ( + "time" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) +// saveReceivedAt is invoked by stateify. +func (p *udpPacket) saveReceivedAt() int64 { + return p.receivedAt.UnixNano() +} + +// loadReceivedAt is invoked by stateify. +func (p *udpPacket) loadReceivedAt(nsec int64) { + p.receivedAt = time.Unix(0, nsec) +} + // saveData saves udpPacket.data field. -func (u *udpPacket) saveData() buffer.VectorisedView { - // We cannot save u.data directly as u.data.views may alias to u.views, +func (p *udpPacket) saveData() buffer.VectorisedView { + // We cannot save p.data directly as p.data.views may alias to p.views, // which is not allowed by state framework (in-struct pointer). - return u.data.Clone(nil) + return p.data.Clone(nil) } // loadData loads udpPacket.data field. -func (u *udpPacket) loadData(data buffer.VectorisedView) { - // NOTE: We cannot do the u.data = data.Clone(u.views[:]) optimization +func (p *udpPacket) loadData(data buffer.VectorisedView) { + // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization // here because data.views is not guaranteed to be loaded by now. Plus, // data.views will be allocated anyway so there really is little point - // of utilizing u.views for data.views. - u.data = data + // of utilizing p.views for data.views. + p.data = data } // afterLoad is invoked by stateify. diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index 705ad1f64..7c357cb09 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -90,7 +90,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, ep.RegisterNICID = r.pkt.NICID ep.boundPortFlags = ep.portFlags - ep.state = StateConnected + ep.state = uint32(StateConnected) ep.rcvMu.Lock() ep.rcvReady = true diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index dc2e3f493..4008cacf2 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -16,17 +16,16 @@ package udp_test import ( "bytes" - "context" "fmt" "io/ioutil" "math/rand" "testing" - "time" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" @@ -298,16 +297,18 @@ type testContext struct { func newDualTestContext(t *testing.T, mtu uint32) *testContext { t.Helper() - return newDualTestContextWithOptions(t, mtu, stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}, - HandleLocal: true, - }) + return newDualTestContextWithHandleLocal(t, mtu, true) } -func newDualTestContextWithOptions(t *testing.T, mtu uint32, options stack.Options) *testContext { +func newDualTestContextWithHandleLocal(t *testing.T, mtu uint32, handleLocal bool) *testContext { t.Helper() + options := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}, + HandleLocal: handleLocal, + Clock: &faketime.NullClock{}, + } s := stack.New(options) ep := channel.New(256, mtu, "") wep := stack.LinkEndpoint(ep) @@ -378,9 +379,7 @@ func (c *testContext) createEndpointForFlow(flow testFlow) { func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte { c.t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - p, ok := c.linkEP.ReadContext(ctx) + p, ok := c.linkEP.Read() if !ok { c.t.Fatalf("Packet wasn't written out") return nil @@ -534,7 +533,9 @@ func newMinPayload(minSize int) []byte { func TestBindToDeviceOption(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}}) + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + Clock: &faketime.NullClock{}, + }) ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { @@ -606,7 +607,7 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe case <-ch: res, err = c.ep.Read(&buf, tcpip.ReadOptions{NeedRemoteAddr: true}) - case <-time.After(300 * time.Millisecond): + default: if packetShouldBeDropped { return // expected to time out } @@ -820,11 +821,7 @@ func TestV4ReadSelfSource(t *testing.T) { {"NoHandleLocal", true, &tcpip.ErrWouldBlock{}, 1}, } { t.Run(tt.name, func(t *testing.T) { - c := newDualTestContextWithOptions(t, defaultMTU, stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - HandleLocal: tt.handleLocal, - }) + c := newDualTestContextWithHandleLocal(t, defaultMTU, tt.handleLocal) defer c.cleanup() c.createEndpointForFlow(unicastV4) @@ -1034,17 +1031,17 @@ func testWriteAndVerifyInternal(c *testContext, flow testFlow, setDest bool, che payload := testWriteNoVerify(c, flow, setDest) // Received the packet and check the payload. b := c.getPacketAndVerify(flow, checkers...) - var udp header.UDP + var udpH header.UDP if flow.isV4() { - udp = header.UDP(header.IPv4(b).Payload()) + udpH = header.IPv4(b).Payload() } else { - udp = header.UDP(header.IPv6(b).Payload()) + udpH = header.IPv6(b).Payload() } - if !bytes.Equal(payload, udp.Payload()) { - c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload) + if !bytes.Equal(payload, udpH.Payload()) { + c.t.Fatalf("Bad payload: got %x, want %x", udpH.Payload(), payload) } - return udp.SourcePort() + return udpH.SourcePort() } func testDualWrite(c *testContext) uint16 { @@ -1198,7 +1195,7 @@ func TestWriteOnConnectedInvalidPort(t *testing.T) { r.Reset(payload) n, err := c.ep.Write(&r, writeOpts) if err != nil { - c.t.Fatalf("c.ep.Write(...) = %+s, want nil", err) + c.t.Fatalf("c.ep.Write(...) = %s, want nil", err) } if got, want := n, int64(len(payload)); got != want { c.t.Fatalf("c.ep.Write(...) wrote %d bytes, want %d bytes", got, want) @@ -1462,7 +1459,7 @@ func TestReadRecvOriginalDstAddr(t *testing.T) { name: "IPv4 unicast", proto: header.IPv4ProtocolNumber, flow: unicastV4, - expectedOriginalDstAddr: tcpip.FullAddress{1, stackAddr, stackPort}, + expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: stackAddr, Port: stackPort}, }, { name: "IPv4 multicast", @@ -1474,7 +1471,7 @@ func TestReadRecvOriginalDstAddr(t *testing.T) { // behaviour. We still include the test so that once the bug is // resolved, this test will start to fail and the individual tasked // with fixing this bug knows to also fix this test :). - expectedOriginalDstAddr: tcpip.FullAddress{1, multicastAddr, stackPort}, + expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: multicastAddr, Port: stackPort}, }, { name: "IPv4 broadcast", @@ -1486,13 +1483,13 @@ func TestReadRecvOriginalDstAddr(t *testing.T) { // behaviour. We still include the test so that once the bug is // resolved, this test will start to fail and the individual tasked // with fixing this bug knows to also fix this test :). - expectedOriginalDstAddr: tcpip.FullAddress{1, broadcastAddr, stackPort}, + expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: broadcastAddr, Port: stackPort}, }, { name: "IPv6 unicast", proto: header.IPv6ProtocolNumber, flow: unicastV6, - expectedOriginalDstAddr: tcpip.FullAddress{1, stackV6Addr, stackPort}, + expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: stackV6Addr, Port: stackPort}, }, { name: "IPv6 multicast", @@ -1504,7 +1501,7 @@ func TestReadRecvOriginalDstAddr(t *testing.T) { // behaviour. We still include the test so that once the bug is // resolved, this test will start to fail and the individual tasked // with fixing this bug knows to also fix this test :). - expectedOriginalDstAddr: tcpip.FullAddress{1, multicastV6Addr, stackPort}, + expectedOriginalDstAddr: tcpip.FullAddress{NIC: 1, Addr: multicastV6Addr, Port: stackPort}, }, } @@ -1614,6 +1611,7 @@ func TestTTL(t *testing.T) { } s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{p}, + Clock: &faketime.NullClock{}, }) ep := s.NetworkProtocolInstance(n).NewEndpoint(&testInterface{}, nil) wantTTL = ep.DefaultTTL() @@ -1759,7 +1757,7 @@ func TestReceiveTosTClass(t *testing.T) { c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, v, false) } - want := true + const want = true optionSetter(want) got := optionGetter() @@ -1889,18 +1887,14 @@ func TestV4UnknownDestination(t *testing.T) { } } if !tc.icmpRequired { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - if p, ok := c.linkEP.ReadContext(ctx); ok { + if p, ok := c.linkEP.Read(); ok { t.Fatalf("unexpected packet received: %+v", p) } return } // ICMP required. - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - p, ok := c.linkEP.ReadContext(ctx) + p, ok := c.linkEP.Read() if !ok { t.Fatalf("packet wasn't written out") return @@ -1987,18 +1981,14 @@ func TestV6UnknownDestination(t *testing.T) { } } if !tc.icmpRequired { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - if p, ok := c.linkEP.ReadContext(ctx); ok { + if p, ok := c.linkEP.Read(); ok { t.Fatalf("unexpected packet received: %+v", p) } return } // ICMP required. - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - p, ok := c.linkEP.ReadContext(ctx) + p, ok := c.linkEP.Read() if !ok { t.Fatalf("packet wasn't written out") return @@ -2115,8 +2105,8 @@ func TestShortHeader(t *testing.T) { Data: buf.ToVectorisedView(), })) - if got, want := c.s.Stats().MalformedRcvdPackets.Value(), uint64(1); got != want { - t.Errorf("got c.s.Stats().MalformedRcvdPackets.Value() = %d, want = %d", got, want) + if got, want := c.s.Stats().NICs.MalformedL4RcvdPackets.Value(), uint64(1); got != want { + t.Errorf("got c.s.Stats().NIC.MalformedL4RcvdPackets.Value() = %d, want = %d", got, want) } } @@ -2124,25 +2114,27 @@ func TestShortHeader(t *testing.T) { // global and endpoint stats are incremented. func TestBadChecksumErrors(t *testing.T) { for _, flow := range []testFlow{unicastV4, unicastV6} { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() + t.Run(flow.String(), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() - c.createEndpoint(flow.sockProto()) - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } + c.createEndpoint(flow.sockProto()) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } - payload := newPayload() - c.injectPacket(flow, payload, true /* badChecksum */) + payload := newPayload() + c.injectPacket(flow, payload, true /* badChecksum */) - const want = 1 - if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { - t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) - } - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) - } + const want = 1 + if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { + t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) + } + }) } } @@ -2484,6 +2476,7 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + Clock: &faketime.NullClock{}, }) e := channel.New(0, defaultMTU, "") if err := s.CreateNIC(nicID1, e); err != nil { |