From 7bf1d426d5f8ba5a3ddfc8f0cd73cc12e8a83c16 Mon Sep 17 00:00:00 2001 From: Chris Kuiper Date: Wed, 4 Sep 2019 14:18:02 -0700 Subject: Handle subnet and broadcast addresses correctly with NIC.subnets This also renames "subnet" to "addressRange" to avoid any more confusion with an interface IP's subnet. Lastly, this also removes the Stack.ContainsSubnet(..) API since it isn't used by anyone. Plus the same information can be obtained from Stack.NICAddressRanges(). PiperOrigin-RevId: 267229843 --- pkg/tcpip/stack/nic.go | 63 +++++++++++----------- pkg/tcpip/stack/stack.go | 32 ++++------- pkg/tcpip/stack/stack_test.go | 121 ++++++++++++++++++++++++++---------------- pkg/tcpip/tcpip.go | 9 ++++ 4 files changed, 129 insertions(+), 96 deletions(-) (limited to 'pkg/tcpip') diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index ae56e0ffd..43719085e 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -36,13 +36,13 @@ type NIC struct { demux *transportDemuxer - mu sync.RWMutex - spoofing bool - promiscuous bool - primary map[tcpip.NetworkProtocolNumber]*ilist.List - endpoints map[NetworkEndpointID]*referencedNetworkEndpoint - subnets []tcpip.Subnet - mcastJoins map[NetworkEndpointID]int32 + mu sync.RWMutex + spoofing bool + promiscuous bool + primary map[tcpip.NetworkProtocolNumber]*ilist.List + endpoints map[NetworkEndpointID]*referencedNetworkEndpoint + addressRanges []tcpip.Subnet + mcastJoins map[NetworkEndpointID]int32 stats NICStats } @@ -224,7 +224,17 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t // the caller or if the address is found in the NIC's subnets. createTempEP := spoofingOrPromiscuous if !createTempEP { - for _, sn := range n.subnets { + for _, sn := range n.addressRanges { + // Skip the subnet address. + if address == sn.ID() { + continue + } + // For now just skip the broadcast address, until we support it. + // FIXME(b/137608825): Add support for sending/receiving directed + // (subnet) broadcast. + if address == sn.Broadcast() { + continue + } if sn.Contains(address) { createTempEP = true break @@ -381,45 +391,38 @@ func (n *NIC) Addresses() []tcpip.ProtocolAddress { return addrs } -// AddSubnet adds a new subnet to n, so that it starts accepting packets -// targeted at the given address and network protocol. -func (n *NIC) AddSubnet(protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) { +// AddAddressRange adds a range of addresses to n, so that it starts accepting +// packets targeted at the given addresses and network protocol. The range is +// given by a subnet address, and all addresses contained in the subnet are +// used except for the subnet address itself and the subnet's broadcast +// address. +func (n *NIC) AddAddressRange(protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) { n.mu.Lock() - n.subnets = append(n.subnets, subnet) + n.addressRanges = append(n.addressRanges, subnet) n.mu.Unlock() } -// RemoveSubnet removes the given subnet from n. -func (n *NIC) RemoveSubnet(subnet tcpip.Subnet) { +// RemoveAddressRange removes the given address range from n. +func (n *NIC) RemoveAddressRange(subnet tcpip.Subnet) { n.mu.Lock() // Use the same underlying array. - tmp := n.subnets[:0] - for _, sub := range n.subnets { + tmp := n.addressRanges[:0] + for _, sub := range n.addressRanges { if sub != subnet { tmp = append(tmp, sub) } } - n.subnets = tmp + n.addressRanges = tmp n.mu.Unlock() } -// ContainsSubnet reports whether this NIC contains the given subnet. -func (n *NIC) ContainsSubnet(subnet tcpip.Subnet) bool { - for _, s := range n.Subnets() { - if s == subnet { - return true - } - } - return false -} - // Subnets returns the Subnets associated with this NIC. -func (n *NIC) Subnets() []tcpip.Subnet { +func (n *NIC) AddressRanges() []tcpip.Subnet { n.mu.RLock() defer n.mu.RUnlock() - sns := make([]tcpip.Subnet, 0, len(n.subnets)+len(n.endpoints)) + sns := make([]tcpip.Subnet, 0, len(n.addressRanges)+len(n.endpoints)) for nid := range n.endpoints { sn, err := tcpip.NewSubnet(nid.LocalAddress, tcpip.AddressMask(strings.Repeat("\xff", len(nid.LocalAddress)))) if err != nil { @@ -429,7 +432,7 @@ func (n *NIC) Subnets() []tcpip.Subnet { } sns = append(sns, sn) } - return append(sns, n.subnets...) + return append(sns, n.addressRanges...) } func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) { diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 1d5e84a8b..6beca6ae8 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -702,14 +702,14 @@ func (s *Stack) CheckNIC(id tcpip.NICID) bool { } // NICSubnets returns a map of NICIDs to their associated subnets. -func (s *Stack) NICSubnets() map[tcpip.NICID][]tcpip.Subnet { +func (s *Stack) NICAddressRanges() map[tcpip.NICID][]tcpip.Subnet { s.mu.RLock() defer s.mu.RUnlock() nics := map[tcpip.NICID][]tcpip.Subnet{} for id, nic := range s.nics { - nics[id] = append(nics[id], nic.Subnets()...) + nics[id] = append(nics[id], nic.AddressRanges()...) } return nics } @@ -810,45 +810,35 @@ func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tc return nic.AddAddress(protocolAddress, peb) } -// AddSubnet adds a subnet range to the specified NIC. -func (s *Stack) AddSubnet(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) *tcpip.Error { +// AddAddressRange adds a range of addresses to the specified NIC. The range is +// given by a subnet address, and all addresses contained in the subnet are +// used except for the subnet address itself and the subnet's broadcast +// address. +func (s *Stack) AddAddressRange(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) *tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() if nic, ok := s.nics[id]; ok { - nic.AddSubnet(protocol, subnet) + nic.AddAddressRange(protocol, subnet) return nil } return tcpip.ErrUnknownNICID } -// RemoveSubnet removes the subnet range from the specified NIC. -func (s *Stack) RemoveSubnet(id tcpip.NICID, subnet tcpip.Subnet) *tcpip.Error { +// RemoveAddressRange removes the range of addresses from the specified NIC. +func (s *Stack) RemoveAddressRange(id tcpip.NICID, subnet tcpip.Subnet) *tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() if nic, ok := s.nics[id]; ok { - nic.RemoveSubnet(subnet) + nic.RemoveAddressRange(subnet) return nil } return tcpip.ErrUnknownNICID } -// ContainsSubnet reports whether the specified NIC contains the specified -// subnet. -func (s *Stack) ContainsSubnet(id tcpip.NICID, subnet tcpip.Subnet) (bool, *tcpip.Error) { - s.mu.RLock() - defer s.mu.RUnlock() - - if nic, ok := s.nics[id]; ok { - return nic.ContainsSubnet(subnet), nil - } - - return false, tcpip.ErrUnknownNICID -} - // RemoveAddress removes an existing network-layer address from the specified // NIC. func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error { diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 4debd1eec..c6a8160af 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -1128,8 +1128,8 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { } } -// Set the subnet, then check that packet is delivered. -func TestSubnetAcceptsMatchingPacket(t *testing.T) { +// Add a range of addresses, then check that a packet is delivered. +func TestAddressRangeAcceptsMatchingPacket(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) id, linkEP := channel.New(10, defaultMTU, "") @@ -1155,14 +1155,45 @@ func TestSubnetAcceptsMatchingPacket(t *testing.T) { if err != nil { t.Fatal("NewSubnet failed:", err) } - if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil { - t.Fatal("AddSubnet failed:", err) + if err := s.AddAddressRange(1, fakeNetNumber, subnet); err != nil { + t.Fatal("AddAddressRange failed:", err) } testRecv(t, fakeNet, localAddrByte, linkEP, buf) } -// Set the subnet, then check that CheckLocalAddress returns the correct NIC. +func testNicForAddressRange(t *testing.T, nicID tcpip.NICID, s *stack.Stack, subnet tcpip.Subnet, rangeExists bool) { + t.Helper() + + // Loop over all addresses and check them. + numOfAddresses := 1 << uint(8-subnet.Prefix()) + if numOfAddresses < 1 || numOfAddresses > 255 { + t.Fatalf("got numOfAddresses = %d, want = [1 .. 255] (subnet=%s)", numOfAddresses, subnet) + } + + addrBytes := []byte(subnet.ID()) + for i := 0; i < numOfAddresses; i++ { + addr := tcpip.Address(addrBytes) + wantNicID := nicID + // The subnet and broadcast addresses are skipped. + if !rangeExists || addr == subnet.ID() || addr == subnet.Broadcast() { + wantNicID = 0 + } + if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, addr); gotNicID != wantNicID { + t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, addr, gotNicID, wantNicID) + } + addrBytes[0]++ + } + + // Trying the next address should always fail since it is outside the range. + if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addrBytes)); gotNicID != 0 { + t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addrBytes), gotNicID, 0) + } +} + +// Set a range of addresses, then remove it again, and check at each step that +// CheckLocalAddress returns the correct NIC for each address or zero if not +// existent. func TestCheckLocalAddressForSubnet(t *testing.T) { const nicID tcpip.NICID = 1 s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) @@ -1181,35 +1212,28 @@ func TestCheckLocalAddressForSubnet(t *testing.T) { } subnet, err := tcpip.NewSubnet(tcpip.Address("\xa0"), tcpip.AddressMask("\xf0")) - if err != nil { t.Fatal("NewSubnet failed:", err) } - if err := s.AddSubnet(nicID, fakeNetNumber, subnet); err != nil { - t.Fatal("AddSubnet failed:", err) - } - // Loop over all subnet addresses and check them. - numOfAddresses := 1 << uint(8-subnet.Prefix()) - if numOfAddresses < 1 || numOfAddresses > 255 { - t.Fatalf("got numOfAddresses = %d, want = [1 .. 255] (subnet=%s)", numOfAddresses, subnet) - } - addr := []byte(subnet.ID()) - for i := 0; i < numOfAddresses; i++ { - if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addr)); gotNicID != nicID { - t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addr), gotNicID, nicID) - } - addr[0]++ + testNicForAddressRange(t, nicID, s, subnet, false /* rangeExists */) + + if err := s.AddAddressRange(nicID, fakeNetNumber, subnet); err != nil { + t.Fatal("AddAddressRange failed:", err) } - // Trying the next address should fail since it is outside the subnet range. - if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addr)); gotNicID != 0 { - t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addr), gotNicID, 0) + testNicForAddressRange(t, nicID, s, subnet, true /* rangeExists */) + + if err := s.RemoveAddressRange(nicID, subnet); err != nil { + t.Fatal("RemoveAddressRange failed:", err) } + + testNicForAddressRange(t, nicID, s, subnet, false /* rangeExists */) } -// Set destination outside the subnet, then check it doesn't get delivered. -func TestSubnetRejectsNonmatchingPacket(t *testing.T) { +// Set a range of addresses, then send a packet to a destination outside the +// range and then check it doesn't get delivered. +func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) id, linkEP := channel.New(10, defaultMTU, "") @@ -1235,8 +1259,8 @@ func TestSubnetRejectsNonmatchingPacket(t *testing.T) { if err != nil { t.Fatal("NewSubnet failed:", err) } - if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil { - t.Fatal("AddSubnet failed:", err) + if err := s.AddAddressRange(1, fakeNetNumber, subnet); err != nil { + t.Fatal("AddAddressRange failed:", err) } testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) } @@ -1281,7 +1305,20 @@ func TestNetworkOptions(t *testing.T) { } } -func TestSubnetAddRemove(t *testing.T) { +func stackContainsAddressRange(s *stack.Stack, id tcpip.NICID, addrRange tcpip.Subnet) bool { + ranges, ok := s.NICAddressRanges()[id] + if !ok { + return false + } + for _, r := range ranges { + if r == addrRange { + return true + } + } + return false +} + +func TestAddresRangeAddRemove(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) id, _ := channel.New(10, defaultMTU, "") if err := s.CreateNIC(1, id); err != nil { @@ -1290,35 +1327,29 @@ func TestSubnetAddRemove(t *testing.T) { addr := tcpip.Address("\x01\x01\x01\x01") mask := tcpip.AddressMask(strings.Repeat("\xff", len(addr))) - subnet, err := tcpip.NewSubnet(addr, mask) + addrRange, err := tcpip.NewSubnet(addr, mask) if err != nil { t.Fatal("NewSubnet failed:", err) } - if contained, err := s.ContainsSubnet(1, subnet); err != nil { - t.Fatal("ContainsSubnet failed:", err) - } else if contained { - t.Fatal("got s.ContainsSubnet(...) = true, want = false") + if got, want := stackContainsAddressRange(s, 1, addrRange), false; got != want { + t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want) } - if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil { - t.Fatal("AddSubnet failed:", err) + if err := s.AddAddressRange(1, fakeNetNumber, addrRange); err != nil { + t.Fatal("AddAddressRange failed:", err) } - if contained, err := s.ContainsSubnet(1, subnet); err != nil { - t.Fatal("ContainsSubnet failed:", err) - } else if !contained { - t.Fatal("got s.ContainsSubnet(...) = false, want = true") + if got, want := stackContainsAddressRange(s, 1, addrRange), true; got != want { + t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want) } - if err := s.RemoveSubnet(1, subnet); err != nil { - t.Fatal("RemoveSubnet failed:", err) + if err := s.RemoveAddressRange(1, addrRange); err != nil { + t.Fatal("RemoveAddressRange failed:", err) } - if contained, err := s.ContainsSubnet(1, subnet); err != nil { - t.Fatal("ContainsSubnet failed:", err) - } else if contained { - t.Fatal("got s.ContainsSubnet(...) = true, want = false") + if got, want := stackContainsAddressRange(s, 1, addrRange), false; got != want { + t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want) } } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 05aa42c98..418e771d2 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -219,6 +219,15 @@ func (s *Subnet) Mask() AddressMask { return s.mask } +// Broadcast returns the subnet's broadcast address. +func (s *Subnet) Broadcast() Address { + addr := []byte(s.address) + for i := range addr { + addr[i] |= ^s.mask[i] + } + return Address(addr) +} + // NICID is a number that uniquely identifies a NIC. type NICID int32 -- cgit v1.2.3 From 3dc3cffb2ddb09291822fc37d00f2e62f41956b5 Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Wed, 4 Sep 2019 14:58:32 -0700 Subject: Fix RST generation bugs. There are a few cases addressed by this change - We no longer generate a RST in response to a RST packet. - When we receive a RST we cleanup and release all reservations immediately as the connection is now aborted. - An ACK received by a listening socket generates a RST when SYN cookies are not in-use. The only reason an ACK should land at the listening socket is if we are using SYN cookies otherwise the goroutine for the handshake in progress should have gotten the packet and it should never have arrived at the listening endpoint. - Also fixes the error returned when a connection times out due to a Keepalive timer expiration from ECONNRESET to a ETIMEDOUT. PiperOrigin-RevId: 267238427 --- pkg/tcpip/transport/tcp/accept.go | 30 ++++++++ pkg/tcpip/transport/tcp/connect.go | 19 +++-- pkg/tcpip/transport/tcp/tcp_test.go | 138 ++++++++++++++++++++++++++++++++++-- 3 files changed, 176 insertions(+), 11 deletions(-) (limited to 'pkg/tcpip') diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index e9c5099ea..0802e984e 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -143,6 +143,15 @@ func decSynRcvdCount() { synRcvdCount.Unlock() } +// synCookiesInUse() returns true if the synRcvdCount is greater than +// SynRcvdCountThreshold. +func synCookiesInUse() bool { + synRcvdCount.Lock() + v := synRcvdCount.value + synRcvdCount.Unlock() + return v >= SynRcvdCountThreshold +} + // newListenContext creates a new listen context. func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext { l := &listenContext{ @@ -446,6 +455,27 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { return } + if !synCookiesInUse() { + // Send a reset as this is an ACK for which there is no + // half open connections and we are not using cookies + // yet. + // + // The only time we should reach here when a connection + // was opened and closed really quickly and a delayed + // ACK was received from the sender. + replyWithReset(s) + return + } + + // Since SYN cookies are in use this is potentially an ACK to a + // SYN-ACK we sent but don't have a half open connection state + // as cookies are being used to protect against a potential SYN + // flood. In such cases validate the cookie and if valid create + // a fully connected endpoint and deliver to the accept queue. + // + // If not, silently drop the ACK to avoid leaking information + // when under a potential syn flood attack. + // // Validate the cookie. data, ok := ctx.isCookieValid(s.id, s.ackNumber-1, s.sequenceNumber-1) if !ok || int(data) >= len(mssTable) { diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 00d2ae524..21038a65a 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -720,13 +720,18 @@ func (e *endpoint) handleClose() *tcpip.Error { return nil } -// resetConnectionLocked sends a RST segment and puts the endpoint in an error -// state with the given error code. This method must only be called from the -// protocol goroutine. +// resetConnectionLocked puts the endpoint in an error state with the given +// error code and sends a RST if and only if the error is not ErrConnectionReset +// indicating that the connection is being reset due to receiving a RST. This +// method must only be called from the protocol goroutine. func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { - e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, e.snd.sndUna, e.rcv.rcvNxt, 0) + // Only send a reset if the connection is being aborted for a reason + // other than receiving a reset. e.state = StateError e.hardError = err + if err != tcpip.ErrConnectionReset { + e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, e.snd.sndUna, e.rcv.rcvNxt, 0) + } } // completeWorkerLocked is called by the worker goroutine when it's about to @@ -806,7 +811,7 @@ func (e *endpoint) keepaliveTimerExpired() *tcpip.Error { if e.keepalive.unacked >= e.keepalive.count { e.keepalive.Unlock() - return tcpip.ErrConnectionReset + return tcpip.ErrTimeout } // RFC1122 4.2.3.6: TCP keepalive is a dataless ACK with @@ -1068,6 +1073,10 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { e.workMu.Lock() if err := funcs[v].f(); err != nil { e.mu.Lock() + // Ensure we release all endpoint registration and route + // references as the connection is now in an error + // state. + e.workerCleanup = true e.resetConnectionLocked(err) // Lock released below. epilogue() diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index f79b8ec5f..32bb45224 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -190,21 +190,146 @@ func TestTCPResetsSentIncrement(t *testing.T) { } } +// TestTCPResetSentForACKWhenNotUsingSynCookies checks that the stack generates +// a RST if an ACK is received on the listening socket for which there is no +// active handshake in progress and we are not using SYN cookies. +func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %v", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %v", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + c.EP.Close() + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + c.GetPacket() + + // Now resend the same ACK, this ACK should generate a RST as there + // should be no endpoint in SYN-RCVD state and we are not using + // syn-cookies yet. The reason we send the same ACK is we need a valid + // cookie(IRS) generated by the netstack without which the ACK will be + // rejected. + c.SendPacket(nil, ackHeaders) + + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck))) +} + func TestTCPResetsReceivedIncrement(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() stats := c.Stack().Stats() want := stats.TCP.ResetsReceived.Value() + 1 - ackNum := seqnum.Value(789) + iss := seqnum.Value(789) rcvWnd := seqnum.Size(30000) - c.CreateConnected(ackNum, rcvWnd, nil) + c.CreateConnected(iss, rcvWnd, nil) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, - SeqNum: c.IRS.Add(2), - AckNum: ackNum.Add(2), + SeqNum: iss.Add(1), + AckNum: c.IRS.Add(1), + RcvWnd: rcvWnd, + Flags: header.TCPFlagRst, + }) + + if got := stats.TCP.ResetsReceived.Value(); got != want { + t.Errorf("got stats.TCP.ResetsReceived.Value() = %v, want = %v", got, want) + } +} + +func TestTCPResetsDoNotGenerateResets(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + stats := c.Stack().Stats() + want := stats.TCP.ResetsReceived.Value() + 1 + iss := seqnum.Value(789) + rcvWnd := seqnum.Size(30000) + c.CreateConnected(iss, rcvWnd, nil) + + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + SeqNum: iss.Add(1), + AckNum: c.IRS.Add(1), RcvWnd: rcvWnd, Flags: header.TCPFlagRst, }) @@ -212,6 +337,7 @@ func TestTCPResetsReceivedIncrement(t *testing.T) { if got := stats.TCP.ResetsReceived.Value(); got != want { t.Errorf("got stats.TCP.ResetsReceived.Value() = %v, want = %v", got, want) } + c.CheckNoPacketTimeout("got an unexpected packet", 100*time.Millisecond) } func TestActiveHandshake(t *testing.T) { @@ -3459,8 +3585,8 @@ func TestKeepalive(t *testing.T) { ), ) - if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset { - t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset) + if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout { + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout) } } -- cgit v1.2.3