diff options
Diffstat (limited to 'pkg/tcpip')
-rw-r--r-- | pkg/tcpip/header/ipv4.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 43 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 12 | ||||
-rw-r--r-- | pkg/tcpip/stack/ndp_test.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 60 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 21 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 193 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/BUILD | 21 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/multicast_broadcast_test.go | 274 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/BUILD | 3 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 13 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 15 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/protocol.go | 33 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/rack.go | 82 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/rack_state.go | 29 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/snd.go | 40 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_rack_test.go | 74 |
18 files changed, 801 insertions, 127 deletions
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index 62ac932bb..d0d1efd0d 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -101,6 +101,11 @@ const ( // IPv4Version is the version of the ipv4 protocol. IPv4Version = 4 + // IPv4AllSystems is the all systems IPv4 multicast address as per + // IANA's IPv4 Multicast Address Space Registry. See + // https://www.iana.org/assignments/multicast-addresses/multicast-addresses.xhtml. + IPv4AllSystems tcpip.Address = "\xe0\x00\x00\x01" + // IPv4Broadcast is the broadcast address of the IPv4 procotol. IPv4Broadcast tcpip.Address = "\xff\xff\xff\xff" diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index d5f5d38f7..6c4f0ae3e 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -52,27 +52,25 @@ const ( ) type endpoint struct { - nicID tcpip.NICID - id stack.NetworkEndpointID - prefixLen int - linkEP stack.LinkEndpoint - dispatcher stack.TransportDispatcher - fragmentation *fragmentation.Fragmentation - protocol *protocol - stack *stack.Stack + nicID tcpip.NICID + id stack.NetworkEndpointID + prefixLen int + linkEP stack.LinkEndpoint + dispatcher stack.TransportDispatcher + protocol *protocol + stack *stack.Stack } // NewEndpoint creates a new ipv4 endpoint. func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) { e := &endpoint{ - nicID: nicID, - id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, - prefixLen: addrWithPrefix.PrefixLen, - linkEP: linkEP, - dispatcher: dispatcher, - fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout), - protocol: p, - stack: st, + nicID: nicID, + id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, + prefixLen: addrWithPrefix.PrefixLen, + linkEP: linkEP, + dispatcher: dispatcher, + protocol: p, + stack: st, } return e, nil @@ -442,7 +440,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { } var ready bool var err error - pkt.Data, ready, err = e.fragmentation.Process( + pkt.Data, ready, err = e.protocol.fragmentation.Process( + // As per RFC 791 section 2.3, the identification value is unique + // for a source-destination pair and protocol. fragmentation.FragmentID{ Source: h.SourceAddress(), Destination: h.DestinationAddress(), @@ -484,6 +484,8 @@ type protocol struct { // uint8 portion of it is meaningful and it must be accessed // atomically. defaultTTL uint32 + + fragmentation *fragmentation.Fragmentation } // Number returns the ipv4 protocol number. @@ -605,5 +607,10 @@ func NewProtocol() stack.NetworkProtocol { } hashIV := r[buckets] - return &protocol{ids: ids, hashIV: hashIV, defaultTTL: DefaultTTL} + return &protocol{ + ids: ids, + hashIV: hashIV, + defaultTTL: DefaultTTL, + fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout), + } } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index a0a5c9c01..4a0b53c45 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -51,7 +51,6 @@ type endpoint struct { linkEP stack.LinkEndpoint linkAddrCache stack.LinkAddressCache dispatcher stack.TransportDispatcher - fragmentation *fragmentation.Fragmentation protocol *protocol } @@ -342,7 +341,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { var ready bool // Note that pkt doesn't have its transport header set after reassembly, // and won't until DeliverNetworkPacket sets it. - pkt.Data, ready, err = e.fragmentation.Process( + pkt.Data, ready, err = e.protocol.fragmentation.Process( // IPv6 ignores the Protocol field since the ID only needs to be unique // across source-destination pairs, as per RFC 8200 section 4.5. fragmentation.FragmentID{ @@ -445,7 +444,8 @@ type protocol struct { // defaultTTL is the current default TTL for the protocol. Only the // uint8 portion of it is meaningful and it must be accessed // atomically. - defaultTTL uint32 + defaultTTL uint32 + fragmentation *fragmentation.Fragmentation } // Number returns the ipv6 protocol number. @@ -478,7 +478,6 @@ func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWi linkEP: linkEP, linkAddrCache: linkAddrCache, dispatcher: dispatcher, - fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout), protocol: p, }, nil } @@ -606,5 +605,8 @@ func calculateMTU(mtu uint32) uint32 { // NewProtocol returns an IPv6 network protocol. func NewProtocol() stack.NetworkProtocol { - return &protocol{defaultTTL: DefaultTTL} + return &protocol{ + defaultTTL: DefaultTTL, + fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout), + } } diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 644ba7c33..5d286ccbc 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -1689,13 +1689,7 @@ func containsV6Addr(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) AddressWithPrefix: item, } - for _, i := range list { - if i == protocolAddress { - return true - } - } - - return false + return containsAddr(list, protocolAddress) } // TestNoAutoGenAddr tests that SLAAC is not performed when configured not to. diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index f21066fce..eaaf756cd 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -217,6 +217,11 @@ func (n *NIC) disableLocked() *tcpip.Error { } if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok { + // The NIC may have already left the multicast group. + if err := n.leaveGroupLocked(header.IPv4AllSystems, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress { + return err + } + // The address may have already been removed. if err := n.removePermanentAddressLocked(ipv4BroadcastAddr.AddressWithPrefix.Address); err != nil && err != tcpip.ErrBadLocalAddress { return err @@ -255,6 +260,13 @@ func (n *NIC) enable() *tcpip.Error { if _, err := n.addAddressLocked(ipv4BroadcastAddr, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil { return err } + + // As per RFC 1122 section 3.3.7, all hosts should join the all-hosts + // multicast group. Note, the IANA calls the all-hosts multicast group the + // all-systems multicast group. + if err := n.joinGroupLocked(header.IPv4ProtocolNumber, header.IPv4AllSystems); err != nil { + return err + } } // Join the IPv6 All-Nodes Multicast group if the stack is configured to @@ -609,6 +621,9 @@ func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.A // If none exists a temporary one may be created if we are in promiscuous mode // or spoofing. Promiscuous mode will only be checked if promiscuous is true. // Similarly, spoofing will only be checked if spoofing is true. +// +// If the address is the IPv4 broadcast address for an endpoint's network, that +// endpoint will be returned. func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, tempRef getRefBehaviour) *referencedNetworkEndpoint { n.mu.RLock() @@ -633,6 +648,16 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t } } + // Check if address is a broadcast address for the endpoint's network. + // + // Only IPv4 has a notion of broadcast addresses. + if protocol == header.IPv4ProtocolNumber { + if ref := n.getRefForBroadcastRLocked(address); ref != nil { + n.mu.RUnlock() + return ref + } + } + // A usable reference was not found, create a temporary one if requested by // the caller or if the address is found in the NIC's subnets. createTempEP := spoofingOrPromiscuous @@ -670,8 +695,34 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t return ref } +// getRefForBroadcastLocked returns an endpoint where address is the IPv4 +// broadcast address for the endpoint's network. +// +// n.mu MUST be read locked. +func (n *NIC) getRefForBroadcastRLocked(address tcpip.Address) *referencedNetworkEndpoint { + for _, ref := range n.mu.endpoints { + // Only IPv4 has a notion of broadcast addresses. + if ref.protocol != header.IPv4ProtocolNumber { + continue + } + + addr := ref.addrWithPrefix() + subnet := addr.Subnet() + if subnet.IsBroadcast(address) && ref.tryIncRef() { + return ref + } + } + + return nil +} + /// getRefOrCreateTempLocked returns an existing endpoint for address or creates /// and returns a temporary endpoint. +// +// If the address is the IPv4 broadcast address for an endpoint's network, that +// endpoint will be returned. +// +// n.mu must be write locked. func (n *NIC) getRefOrCreateTempLocked(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint { if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok { // No need to check the type as we are ok with expired endpoints at this @@ -685,6 +736,15 @@ func (n *NIC) getRefOrCreateTempLocked(protocol tcpip.NetworkProtocolNumber, add n.removeEndpointLocked(ref) } + // Check if address is a broadcast address for an endpoint's network. + // + // Only IPv4 has a notion of broadcast addresses. + if protocol == header.IPv4ProtocolNumber { + if ref := n.getRefForBroadcastRLocked(address); ref != nil { + return ref + } + } + // Add a new temporary endpoint. netProto, ok := n.stack.networkProtocols[protocol] if !ok { diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 3f07e4159..5b19c5d59 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -73,6 +73,16 @@ type TCPCubicState struct { WEst float64 } +// TCPRACKState is used to hold a copy of the internal RACK state when the +// TCPProbeFunc is invoked. +type TCPRACKState struct { + XmitTime time.Time + EndSequence seqnum.Value + FACK seqnum.Value + RTT time.Duration + Reord bool +} + // TCPEndpointID is the unique 4 tuple that identifies a given endpoint. type TCPEndpointID struct { // LocalPort is the local port associated with the endpoint. @@ -212,6 +222,9 @@ type TCPSenderState struct { // Cubic holds the state related to CUBIC congestion control. Cubic TCPCubicState + + // RACKState holds the state related to RACK loss detection algorithm. + RACKState TCPRACKState } // TCPSACKInfo holds TCP SACK related information for a given TCP endpoint. @@ -1972,8 +1985,8 @@ func generateRandInt64() int64 { // FindNetworkEndpoint returns the network endpoint for the given address. func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, address tcpip.Address) (NetworkEndpoint, *tcpip.Error) { - s.mu.Lock() - defer s.mu.Unlock() + s.mu.RLock() + defer s.mu.RUnlock() for _, nic := range s.nics { id := NetworkEndpointID{address} @@ -1992,8 +2005,8 @@ func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, addres // FindNICNameFromID returns the name of the nic for the given NICID. func (s *Stack) FindNICNameFromID(id tcpip.NICID) string { - s.mu.Lock() - defer s.mu.Unlock() + s.mu.RLock() + defer s.mu.RUnlock() nic, ok := s.nics[id] if !ok { diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index f22062889..0b6deda02 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -277,6 +277,17 @@ func (l *linkEPWithMockedAttach) isAttached() bool { return l.attached } +// Checks to see if list contains an address. +func containsAddr(list []tcpip.ProtocolAddress, item tcpip.ProtocolAddress) bool { + for _, i := range list { + if i == item { + return true + } + } + + return false +} + func TestNetworkReceive(t *testing.T) { // Create a stack with the fake network protocol, one nic, and two // addresses attached to it: 1 & 2. @@ -1704,7 +1715,7 @@ func testNicForAddressRange(t *testing.T, nicID tcpip.NICID, s *stack.Stack, sub // Trying the next address should always fail since it is outside the range. if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addrBytes)); gotNicID != 0 { - t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addrBytes), gotNicID, 0) + t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = 0", fakeNetNumber, tcpip.Address(addrBytes), gotNicID) } } @@ -3089,6 +3100,13 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { func TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable(t *testing.T) { const nicID = 1 + broadcastAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: header.IPv4Broadcast, + PrefixLen: 32, + }, + } e := loopback.New() s := stack.New(stack.Options{ @@ -3099,49 +3117,41 @@ func TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable(t *testing.T) { t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err) } - allStackAddrs := s.AllAddresses() - allNICAddrs, ok := allStackAddrs[nicID] - if !ok { - t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) - } - if l := len(allNICAddrs); l != 0 { - t.Fatalf("got len(allNICAddrs) = %d, want = 0", l) + { + allStackAddrs := s.AllAddresses() + if allNICAddrs, ok := allStackAddrs[nicID]; !ok { + t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) + } else if containsAddr(allNICAddrs, broadcastAddr) { + t.Fatalf("got allNICAddrs = %+v, don't want = %+v", allNICAddrs, broadcastAddr) + } } // Enabling the NIC should add the IPv4 broadcast address. if err := s.EnableNIC(nicID); err != nil { t.Fatalf("s.EnableNIC(%d): %s", nicID, err) } - allStackAddrs = s.AllAddresses() - allNICAddrs, ok = allStackAddrs[nicID] - if !ok { - t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) - } - if l := len(allNICAddrs); l != 1 { - t.Fatalf("got len(allNICAddrs) = %d, want = 1", l) - } - want := tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: header.IPv4Broadcast, - PrefixLen: 32, - }, - } - if allNICAddrs[0] != want { - t.Fatalf("got allNICAddrs[0] = %+v, want = %+v", allNICAddrs[0], want) + + { + allStackAddrs := s.AllAddresses() + if allNICAddrs, ok := allStackAddrs[nicID]; !ok { + t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) + } else if !containsAddr(allNICAddrs, broadcastAddr) { + t.Fatalf("got allNICAddrs = %+v, want = %+v", allNICAddrs, broadcastAddr) + } } // Disabling the NIC should remove the IPv4 broadcast address. if err := s.DisableNIC(nicID); err != nil { t.Fatalf("s.DisableNIC(%d): %s", nicID, err) } - allStackAddrs = s.AllAddresses() - allNICAddrs, ok = allStackAddrs[nicID] - if !ok { - t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) - } - if l := len(allNICAddrs); l != 0 { - t.Fatalf("got len(allNICAddrs) = %d, want = 0", l) + + { + allStackAddrs := s.AllAddresses() + if allNICAddrs, ok := allStackAddrs[nicID]; !ok { + t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) + } else if containsAddr(allNICAddrs, broadcastAddr) { + t.Fatalf("got allNICAddrs = %+v, don't want = %+v", allNICAddrs, broadcastAddr) + } } } @@ -3189,50 +3199,93 @@ func TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval(t *testing.T) { } } -func TestJoinLeaveAllNodesMulticastOnNICEnableDisable(t *testing.T) { +func TestJoinLeaveMulticastOnNICEnableDisable(t *testing.T) { const nicID = 1 - e := loopback.New() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - }) - nicOpts := stack.NICOptions{Disabled: true} - if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { - t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err) + tests := []struct { + name string + proto tcpip.NetworkProtocolNumber + addr tcpip.Address + }{ + { + name: "IPv6 All-Nodes", + proto: header.IPv6ProtocolNumber, + addr: header.IPv6AllNodesMulticastAddress, + }, + { + name: "IPv4 All-Systems", + proto: header.IPv4ProtocolNumber, + addr: header.IPv4AllSystems, + }, } - // Should not be in the IPv6 all-nodes multicast group yet because the NIC has - // not been enabled yet. - isInGroup, err := s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress) - if err != nil { - t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err) - } - if isInGroup { - t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, header.IPv6AllNodesMulticastAddress) - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e := loopback.New() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + }) + nicOpts := stack.NICOptions{Disabled: true} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err) + } - // The all-nodes multicast group should be joined when the NIC is enabled. - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - isInGroup, err = s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress) - if err != nil { - t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err) - } - if !isInGroup { - t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, header.IPv6AllNodesMulticastAddress) - } + // Should not be in the multicast group yet because the NIC has not been + // enabled yet. + if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) + } else if isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, test.addr) + } - // The all-nodes multicast group should be left when the NIC is disabled. - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID, err) - } - isInGroup, err = s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress) - if err != nil { - t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err) - } - if isInGroup { - t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, header.IPv6AllNodesMulticastAddress) + // The all-nodes multicast group should be joined when the NIC is enabled. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + + if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) + } else if !isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, test.addr) + } + + // The multicast group should be left when the NIC is disabled. + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + + if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) + } else if isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, test.addr) + } + + // The all-nodes multicast group should be joined when the NIC is enabled. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + + if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) + } else if !isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, test.addr) + } + + // Leaving the group before disabling the NIC should not cause an error. + if err := s.LeaveGroup(test.proto, nicID, test.addr); err != nil { + t.Fatalf("s.LeaveGroup(%d, %d, %s): %s", test.proto, nicID, test.addr, err) + } + + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + + if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) + } else if isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, test.addr) + } + }) } } diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD new file mode 100644 index 000000000..7fff30462 --- /dev/null +++ b/pkg/tcpip/tests/integration/BUILD @@ -0,0 +1,21 @@ +load("//tools:defs.bzl", "go_test") + +package(licenses = ["notice"]) + +go_test( + name = "integration_test", + size = "small", + srcs = ["multicast_broadcast_test.go"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/network/ipv6", + "//pkg/tcpip/stack", + "//pkg/tcpip/transport/udp", + "//pkg/waiter", + "@com_github_google_go_cmp//cmp:go_default_library", + ], +) diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go new file mode 100644 index 000000000..d9b2d147a --- /dev/null +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -0,0 +1,274 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package integration_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" +) + +const defaultMTU = 1280 + +// TestIncomingMulticastAndBroadcast tests receiving a packet destined to some +// multicast or broadcast address. +func TestIncomingMulticastAndBroadcast(t *testing.T) { + const ( + nicID = 1 + remotePort = 5555 + localPort = 80 + ttl = 255 + ) + + data := []byte{1, 2, 3, 4} + + // Local IPv4 subnet: 192.168.1.58/24 + ipv4Addr := tcpip.AddressWithPrefix{ + Address: "\xc0\xa8\x01\x3a", + PrefixLen: 24, + } + ipv4Subnet := ipv4Addr.Subnet() + ipv4SubnetBcast := ipv4Subnet.Broadcast() + + // Local IPv6 subnet: 200a::1/64 + ipv6Addr := tcpip.AddressWithPrefix{ + Address: "\x20\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + PrefixLen: 64, + } + ipv6Subnet := ipv6Addr.Subnet() + ipv6SubnetBcast := ipv6Subnet.Broadcast() + + // Remote addrs. + remoteIPv4Addr := tcpip.Address("\x64\x0a\x7b\x18") + remoteIPv6Addr := tcpip.Address("\x20\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + + rxIPv4UDP := func(e *channel.Endpoint, dst tcpip.Address) { + payloadLen := header.UDPMinimumSize + len(data) + totalLen := header.IPv4MinimumSize + payloadLen + hdr := buffer.NewPrependable(totalLen) + u := header.UDP(hdr.Prepend(payloadLen)) + u.Encode(&header.UDPFields{ + SrcPort: remotePort, + DstPort: localPort, + Length: uint16(payloadLen), + }) + copy(u.Payload(), data) + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, remoteIPv4Addr, dst, uint16(payloadLen)) + sum = header.Checksum(data, sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: uint16(totalLen), + Protocol: uint8(udp.ProtocolNumber), + TTL: ttl, + SrcAddr: remoteIPv4Addr, + DstAddr: dst, + }) + + e.InjectInbound(header.IPv4ProtocolNumber, &stack.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + } + + rxIPv6UDP := func(e *channel.Endpoint, dst tcpip.Address) { + payloadLen := header.UDPMinimumSize + len(data) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadLen) + u := header.UDP(hdr.Prepend(payloadLen)) + u.Encode(&header.UDPFields{ + SrcPort: remotePort, + DstPort: localPort, + Length: uint16(payloadLen), + }) + copy(u.Payload(), data) + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, remoteIPv6Addr, dst, uint16(payloadLen)) + sum = header.Checksum(data, sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLen), + NextHeader: uint8(udp.ProtocolNumber), + HopLimit: ttl, + SrcAddr: remoteIPv6Addr, + DstAddr: dst, + }) + + e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + } + + tests := []struct { + name string + bindAddr tcpip.Address + dstAddr tcpip.Address + expectRx bool + }{ + { + name: "IPv4 unicast binding to unicast", + bindAddr: ipv4Addr.Address, + dstAddr: ipv4Addr.Address, + expectRx: true, + }, + { + name: "IPv4 unicast binding to broadcast", + bindAddr: header.IPv4Broadcast, + dstAddr: ipv4Addr.Address, + expectRx: false, + }, + { + name: "IPv4 unicast binding to wildcard", + dstAddr: ipv4Addr.Address, + expectRx: true, + }, + + { + name: "IPv4 directed broadcast binding to subnet broadcast", + bindAddr: ipv4SubnetBcast, + dstAddr: ipv4SubnetBcast, + expectRx: true, + }, + { + name: "IPv4 directed broadcast binding to broadcast", + bindAddr: header.IPv4Broadcast, + dstAddr: ipv4SubnetBcast, + expectRx: false, + }, + { + name: "IPv4 directed broadcast binding to wildcard", + dstAddr: ipv4SubnetBcast, + expectRx: true, + }, + + { + name: "IPv4 broadcast binding to broadcast", + bindAddr: header.IPv4Broadcast, + dstAddr: header.IPv4Broadcast, + expectRx: true, + }, + { + name: "IPv4 broadcast binding to subnet broadcast", + bindAddr: ipv4SubnetBcast, + dstAddr: header.IPv4Broadcast, + expectRx: false, + }, + { + name: "IPv4 broadcast binding to wildcard", + dstAddr: ipv4SubnetBcast, + expectRx: true, + }, + + { + name: "IPv4 all-systems multicast binding to all-systems multicast", + bindAddr: header.IPv4AllSystems, + dstAddr: header.IPv4AllSystems, + expectRx: true, + }, + { + name: "IPv4 all-systems multicast binding to wildcard", + dstAddr: header.IPv4AllSystems, + expectRx: true, + }, + { + name: "IPv4 all-systems multicast binding to unicast", + bindAddr: ipv4Addr.Address, + dstAddr: header.IPv4AllSystems, + expectRx: false, + }, + + // IPv6 has no notion of a broadcast. + { + name: "IPv6 unicast binding to wildcard", + dstAddr: ipv6Addr.Address, + expectRx: true, + }, + { + name: "IPv6 broadcast-like address binding to wildcard", + dstAddr: ipv6SubnetBcast, + expectRx: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + }) + e := channel.New(0, defaultMTU, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr} + if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv4ProtoAddr, err) + } + ipv6ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: ipv6Addr} + if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv6ProtoAddr, err) + } + + var netproto tcpip.NetworkProtocolNumber + var rxUDP func(*channel.Endpoint, tcpip.Address) + switch l := len(test.dstAddr); l { + case header.IPv4AddressSize: + netproto = header.IPv4ProtocolNumber + rxUDP = rxIPv4UDP + case header.IPv6AddressSize: + netproto = header.IPv6ProtocolNumber + rxUDP = rxIPv6UDP + default: + t.Fatalf("got unexpected address length = %d bytes", l) + } + + wq := waiter.Queue{} + ep, err := s.NewEndpoint(udp.ProtocolNumber, netproto, &wq) + if err != nil { + t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, netproto, err) + } + defer ep.Close() + + bindAddr := tcpip.FullAddress{Addr: test.bindAddr, Port: localPort} + if err := ep.Bind(bindAddr); err != nil { + t.Fatalf("ep.Bind(%+v): %s", bindAddr, err) + } + + rxUDP(e, test.dstAddr) + if gotPayload, _, err := ep.Read(nil); test.expectRx { + if err != nil { + t.Fatalf("Read(nil): %s", err) + } + if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" { + t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) + } + } else { + if err != tcpip.ErrWouldBlock { + t.Fatalf("got Read(nil) = (%x, _, %v), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock) + } + } + }) + } +} diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index e860ee484..234fb95ce 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -40,6 +40,8 @@ go_library( "endpoint_state.go", "forwarder.go", "protocol.go", + "rack.go", + "rack_state.go", "rcv.go", "rcv_state.go", "reno.go", @@ -83,6 +85,7 @@ go_test( "dual_stack_test.go", "sack_scoreboard_test.go", "tcp_noracedetector_test.go", + "tcp_rack_test.go", "tcp_sack_test.go", "tcp_test.go", "tcp_timestamp_test.go", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 6e00e5526..913ea6535 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -521,7 +521,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { synOpts := header.TCPSynOptions{ WS: -1, TS: opts.TS, - TSVal: tcpTimeStamp(timeStampOffset()), + TSVal: tcpTimeStamp(time.Now(), timeStampOffset()), TSEcr: opts.TSVal, MSS: mssForRoute(&s.route), } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 6e5e55b6f..8dd759ba2 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -1166,13 +1166,18 @@ func (e *endpoint) handleSegments(fastPath bool) *tcpip.Error { return nil } -// handleSegment handles a given segment and notifies the worker goroutine if -// if the connection should be terminated. -func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) { - // Invoke the tcp probe if installed. +func (e *endpoint) probeSegment() { if e.probe != nil { e.probe(e.completeState()) } +} + +// handleSegment handles a given segment and notifies the worker goroutine if +// if the connection should be terminated. +func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) { + // Invoke the tcp probe if installed. The tcp probe function will update + // the TCPEndpointState after the segment is processed. + defer e.probeSegment() if s.flagIsSet(header.TCPFlagRst) { if ok, err := e.handleReset(s); !ok { diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 682687ebe..39ea38fe6 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -2692,15 +2692,14 @@ func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) { // timestamp returns the timestamp value to be used in the TSVal field of the // timestamp option for outgoing TCP segments for a given endpoint. func (e *endpoint) timestamp() uint32 { - return tcpTimeStamp(e.tsOffset) + return tcpTimeStamp(time.Now(), e.tsOffset) } // tcpTimeStamp returns a timestamp offset by the provided offset. This is // not inlined above as it's used when SYN cookies are in use and endpoint // is not created at the time when the SYN cookie is sent. -func tcpTimeStamp(offset uint32) uint32 { - now := time.Now() - return uint32(now.Unix()*1000+int64(now.Nanosecond()/1e6)) + offset +func tcpTimeStamp(curTime time.Time, offset uint32) uint32 { + return uint32(curTime.Unix()*1000+int64(curTime.Nanosecond()/1e6)) + offset } // timeStampOffset returns a randomized timestamp offset to be used when sending @@ -2843,6 +2842,14 @@ func (e *endpoint) completeState() stack.TCPEndpointState { WEst: cubic.wEst, } } + + rc := e.snd.rc + s.Sender.RACKState = stack.TCPRACKState{ + XmitTime: rc.xmitTime, + EndSequence: rc.endSequence, + FACK: rc.fack, + RTT: rc.rtt, + } return s } diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index b34e47bbd..d9abb8d94 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -80,6 +80,25 @@ const ( // enable/disable SACK support in TCP. See: https://tools.ietf.org/html/rfc2018. type SACKEnabled bool +// Recovery is used by stack.(*Stack).TransportProtocolOption to +// set loss detection algorithm in TCP. +type Recovery int32 + +const ( + // RACKLossDetection indicates RACK is used for loss detection and + // recovery. + RACKLossDetection Recovery = 1 << iota + + // RACKStaticReoWnd indicates the reordering window should not be + // adjusted when DSACK is received. + RACKStaticReoWnd + + // RACKNoDupTh indicates RACK should not consider the classic three + // duplicate acknowledgements rule to mark the segments as lost. This + // is used when reordering is not detected. + RACKNoDupTh +) + // DelayEnabled is used by stack.(Stack*).TransportProtocolOption to // enable/disable Nagle's algorithm in TCP. type DelayEnabled bool @@ -161,6 +180,7 @@ func (s *synRcvdCounter) Threshold() uint64 { type protocol struct { mu sync.RWMutex sackEnabled bool + recovery Recovery delayEnabled bool sendBufferSize SendBufferSizeOption recvBufferSize ReceiveBufferSizeOption @@ -280,6 +300,12 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { p.mu.Unlock() return nil + case Recovery: + p.mu.Lock() + p.recovery = Recovery(v) + p.mu.Unlock() + return nil + case DelayEnabled: p.mu.Lock() p.delayEnabled = bool(v) @@ -394,6 +420,12 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { p.mu.RUnlock() return nil + case *Recovery: + p.mu.RLock() + *v = Recovery(p.recovery) + p.mu.RUnlock() + return nil + case *DelayEnabled: p.mu.RLock() *v = DelayEnabled(p.delayEnabled) @@ -535,6 +567,7 @@ func NewProtocol() stack.TransportProtocol { minRTO: MinRTO, maxRTO: MaxRTO, maxRetries: MaxRetries, + recovery: RACKLossDetection, } p.dispatcher.init(runtime.GOMAXPROCS(0)) return &p diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go new file mode 100644 index 000000000..d969ca23a --- /dev/null +++ b/pkg/tcpip/transport/tcp/rack.go @@ -0,0 +1,82 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "time" + + "gvisor.dev/gvisor/pkg/tcpip/seqnum" +) + +// RACK is a loss detection algorithm used in TCP to detect packet loss and +// reordering using transmission timestamp of the packets instead of packet or +// sequence counts. To use RACK, SACK should be enabled on the connection. + +// rackControl stores the rack related fields. +// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-6.1 +// +// +stateify savable +type rackControl struct { + // xmitTime is the latest transmission timestamp of rackControl.seg. + xmitTime time.Time `state:".(unixTime)"` + + // endSequence is the ending TCP sequence number of rackControl.seg. + endSequence seqnum.Value + + // fack is the highest selectively or cumulatively acknowledged + // sequence. + fack seqnum.Value + + // rtt is the RTT of the most recently delivered packet on the + // connection (either cumulatively acknowledged or selectively + // acknowledged) that was not marked invalid as a possible spurious + // retransmission. + rtt time.Duration +} + +// Update will update the RACK related fields when an ACK has been received. +// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 +func (rc *rackControl) Update(seg *segment, ackSeg *segment, srtt time.Duration, offset uint32) { + rtt := time.Now().Sub(seg.xmitTime) + + // If the ACK is for a retransmitted packet, do not update if it is a + // spurious inference which is determined by below checks: + // 1. When Timestamping option is available, if the TSVal is less than the + // transmit time of the most recent retransmitted packet. + // 2. When RTT calculated for the packet is less than the smoothed RTT + // for the connection. + // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 + // step 2 + if seg.xmitCount > 1 { + if ackSeg.parsedOptions.TS && ackSeg.parsedOptions.TSEcr != 0 { + if ackSeg.parsedOptions.TSEcr < tcpTimeStamp(seg.xmitTime, offset) { + return + } + } + if rtt < srtt { + return + } + } + + rc.rtt = rtt + // Update rc.xmitTime and rc.endSequence to the transmit time and + // ending sequence number of the packet which has been acknowledged + // most recently. + endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) + if rc.xmitTime.Before(seg.xmitTime) || (seg.xmitTime.Equal(rc.xmitTime) && rc.endSequence.LessThan(endSeq)) { + rc.xmitTime = seg.xmitTime + rc.endSequence = endSeq + } +} diff --git a/pkg/tcpip/transport/tcp/rack_state.go b/pkg/tcpip/transport/tcp/rack_state.go new file mode 100644 index 000000000..c9dc7e773 --- /dev/null +++ b/pkg/tcpip/transport/tcp/rack_state.go @@ -0,0 +1,29 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "time" +) + +// saveXmitTime is invoked by stateify. +func (rc *rackControl) saveXmitTime() unixTime { + return unixTime{rc.xmitTime.Unix(), rc.xmitTime.UnixNano()} +} + +// loadXmitTime is invoked by stateify. +func (rc *rackControl) loadXmitTime(unix unixTime) { + rc.xmitTime = time.Unix(unix.second, unix.nano) +} diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 5862c32f2..c55589c45 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -191,6 +191,10 @@ type sender struct { // cc is the congestion control algorithm in use for this sender. cc congestionControl + + // rc has the fields needed for implementing RACK loss detection + // algorithm. + rc rackControl } // rtt is a synchronization wrapper used to appease stateify. See the comment @@ -1272,21 +1276,21 @@ func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) { // handleRcvdSegment is called when a segment is received; it is responsible for // updating the send-related state. -func (s *sender) handleRcvdSegment(seg *segment) { +func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // Check if we can extract an RTT measurement from this ack. - if !seg.parsedOptions.TS && s.rttMeasureSeqNum.LessThan(seg.ackNumber) { + if !rcvdSeg.parsedOptions.TS && s.rttMeasureSeqNum.LessThan(rcvdSeg.ackNumber) { s.updateRTO(time.Now().Sub(s.rttMeasureTime)) s.rttMeasureSeqNum = s.sndNxt } // Update Timestamp if required. See RFC7323, section-4.3. - if s.ep.sendTSOk && seg.parsedOptions.TS { - s.ep.updateRecentTimestamp(seg.parsedOptions.TSVal, s.maxSentAck, seg.sequenceNumber) + if s.ep.sendTSOk && rcvdSeg.parsedOptions.TS { + s.ep.updateRecentTimestamp(rcvdSeg.parsedOptions.TSVal, s.maxSentAck, rcvdSeg.sequenceNumber) } // Insert SACKBlock information into our scoreboard. if s.ep.sackPermitted { - for _, sb := range seg.parsedOptions.SACKBlocks { + for _, sb := range rcvdSeg.parsedOptions.SACKBlocks { // Only insert the SACK block if the following holds // true: // * SACK block acks data after the ack number in the @@ -1299,27 +1303,27 @@ func (s *sender) handleRcvdSegment(seg *segment) { // NOTE: This check specifically excludes DSACK blocks // which have start/end before sndUna and are used to // indicate spurious retransmissions. - if seg.ackNumber.LessThan(sb.Start) && s.sndUna.LessThan(sb.Start) && sb.End.LessThanEq(s.sndNxt) && !s.ep.scoreboard.IsSACKED(sb) { + if rcvdSeg.ackNumber.LessThan(sb.Start) && s.sndUna.LessThan(sb.Start) && sb.End.LessThanEq(s.sndNxt) && !s.ep.scoreboard.IsSACKED(sb) { s.ep.scoreboard.Insert(sb) - seg.hasNewSACKInfo = true + rcvdSeg.hasNewSACKInfo = true } } s.SetPipe() } // Count the duplicates and do the fast retransmit if needed. - rtx := s.checkDuplicateAck(seg) + rtx := s.checkDuplicateAck(rcvdSeg) // Stash away the current window size. - s.sndWnd = seg.window + s.sndWnd = rcvdSeg.window - ack := seg.ackNumber + ack := rcvdSeg.ackNumber // Disable zero window probing if remote advertizes a non-zero receive // window. This can be with an ACK to the zero window probe (where the // acknumber refers to the already acknowledged byte) OR to any previously // unacknowledged segment. - if s.zeroWindowProbing && seg.window > 0 && + if s.zeroWindowProbing && rcvdSeg.window > 0 && (ack == s.sndUna || (ack-1).InRange(s.sndUna, s.sndNxt)) { s.disableZeroWindowProbing() } @@ -1344,10 +1348,10 @@ func (s *sender) handleRcvdSegment(seg *segment) { // averaged RTT measurement only if the segment acknowledges // some new data, i.e., only if it advances the left edge of // the send window. - if s.ep.sendTSOk && seg.parsedOptions.TSEcr != 0 { + if s.ep.sendTSOk && rcvdSeg.parsedOptions.TSEcr != 0 { // TSVal/Ecr values sent by Netstack are at a millisecond // granularity. - elapsed := time.Duration(s.ep.timestamp()-seg.parsedOptions.TSEcr) * time.Millisecond + elapsed := time.Duration(s.ep.timestamp()-rcvdSeg.parsedOptions.TSEcr) * time.Millisecond s.updateRTO(elapsed) } @@ -1361,6 +1365,9 @@ func (s *sender) handleRcvdSegment(seg *segment) { ackLeft := acked originalOutstanding := s.outstanding + s.rtt.Lock() + srtt := s.rtt.srtt + s.rtt.Unlock() for ackLeft > 0 { // We use logicalLen here because we can have FIN // segments (which are always at the end of list) that @@ -1380,6 +1387,11 @@ func (s *sender) handleRcvdSegment(seg *segment) { s.writeNext = seg.Next() } + // Update the RACK fields if SACK is enabled. + if s.ep.sackPermitted { + s.rc.Update(seg, rcvdSeg, srtt, s.ep.tsOffset) + } + s.writeList.Remove(seg) // if SACK is enabled then Only reduce outstanding if @@ -1435,7 +1447,7 @@ func (s *sender) handleRcvdSegment(seg *segment) { // that the window opened up, or the congestion window was inflated due // to a duplicate ack during fast recovery. This will also re-enable // the retransmit timer if needed. - if !s.ep.sackPermitted || s.fr.active || s.dupAckCount == 0 || seg.hasNewSACKInfo { + if !s.ep.sackPermitted || s.fr.active || s.dupAckCount == 0 || rcvdSeg.hasNewSACKInfo { s.sendData() } } diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go new file mode 100644 index 000000000..e03f101e8 --- /dev/null +++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go @@ -0,0 +1,74 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_test + +import ( + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" +) + +// TestRACKUpdate tests the RACK related fields are updated when an ACK is +// received on a SACK enabled connection. +func TestRACKUpdate(t *testing.T) { + const maxPayload = 10 + const tsOptionSize = 12 + const maxTCPOptionSize = 40 + + c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxTCPOptionSize+maxPayload)) + defer c.Cleanup() + + var xmitTime time.Time + c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { + // Validate that the endpoint Sender.RACKState is what we expect. + if state.Sender.RACKState.XmitTime.Before(xmitTime) { + t.Fatalf("RACK transmit time failed to update when an ACK is received") + } + + gotSeq := state.Sender.RACKState.EndSequence + wantSeq := state.Sender.SndNxt + if !gotSeq.LessThanEq(wantSeq) || gotSeq.LessThan(wantSeq) { + t.Fatalf("RACK sequence number failed to update, got: %v, but want: %v", gotSeq, wantSeq) + } + + if state.Sender.RACKState.RTT == 0 { + t.Fatalf("RACK RTT failed to update when an ACK is received") + } + }) + setStackSACKPermitted(t, c, true) + createConnectedWithSACKAndTS(c) + + data := buffer.NewView(maxPayload) + for i := range data { + data[i] = byte(i) + } + + // Write the data. + xmitTime = time.Now() + if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + bytesRead := 0 + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) + bytesRead += maxPayload + c.SendAck(790, bytesRead) + time.Sleep(200 * time.Millisecond) +} |