diff options
Diffstat (limited to 'pkg/tcpip')
-rw-r--r-- | pkg/tcpip/buffer/view.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/buffer/view_test.go | 36 | ||||
-rw-r--r-- | pkg/tcpip/header/tcp.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4_test.go | 250 | ||||
-rw-r--r-- | pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go | 11 | ||||
-rw-r--r-- | pkg/tcpip/stack/iptables.go | 42 | ||||
-rw-r--r-- | pkg/tcpip/stack/iptables_types.go | 13 | ||||
-rw-r--r-- | pkg/tcpip/stack/ndp.go | 9 | ||||
-rw-r--r-- | pkg/tcpip/stack/ndp_test.go | 24 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 16 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/rcv.go | 36 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/rcv_test.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/segment.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/snd.go | 208 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_test.go | 4 |
18 files changed, 554 insertions, 123 deletions
diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index f01217c91..9a3c5d6c3 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -59,6 +59,9 @@ func (v *View) Reader() bytes.Reader { // ToVectorisedView returns a VectorisedView containing the receiver. func (v View) ToVectorisedView() VectorisedView { + if len(v) == 0 { + return VectorisedView{} + } return NewVectorisedView(len(v), []View{v}) } @@ -229,6 +232,9 @@ func (vv *VectorisedView) Append(vv2 VectorisedView) { // AppendView appends the given view into this vectorised view. func (vv *VectorisedView) AppendView(v View) { + if len(v) == 0 { + return + } vv.views = append(vv.views, v) vv.size += len(v) } diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go index c56795c7b..726e54de9 100644 --- a/pkg/tcpip/buffer/view_test.go +++ b/pkg/tcpip/buffer/view_test.go @@ -483,3 +483,39 @@ func TestPullUp(t *testing.T) { } } } + +func TestToVectorisedView(t *testing.T) { + testCases := []struct { + in View + want VectorisedView + }{ + {nil, VectorisedView{}}, + {View{}, VectorisedView{}}, + {View{'a'}, VectorisedView{size: 1, views: []View{{'a'}}}}, + } + for _, tc := range testCases { + if got, want := tc.in.ToVectorisedView(), tc.want; !reflect.DeepEqual(got, want) { + t.Errorf("(%v).ToVectorisedView failed got: %+v, want: %+v", tc.in, got, want) + } + } +} + +func TestAppendView(t *testing.T) { + testCases := []struct { + vv VectorisedView + in View + want VectorisedView + }{ + {VectorisedView{}, nil, VectorisedView{}}, + {VectorisedView{}, View{}, VectorisedView{}}, + {VectorisedView{[]View{{'a', 'b', 'c', 'd'}}, 4}, nil, VectorisedView{[]View{{'a', 'b', 'c', 'd'}}, 4}}, + {VectorisedView{[]View{{'a', 'b', 'c', 'd'}}, 4}, View{}, VectorisedView{[]View{{'a', 'b', 'c', 'd'}}, 4}}, + {VectorisedView{[]View{{'a', 'b', 'c', 'd'}}, 4}, View{'e'}, VectorisedView{[]View{{'a', 'b', 'c', 'd'}, {'e'}}, 5}}, + } + for _, tc := range testCases { + tc.vv.AppendView(tc.in) + if got, want := tc.vv, tc.want; !reflect.DeepEqual(got, want) { + t.Errorf("(%v).ToVectorisedView failed got: %+v, want: %+v", tc.in, got, want) + } + } +} diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go index 21581257b..29454c4b9 100644 --- a/pkg/tcpip/header/tcp.go +++ b/pkg/tcpip/header/tcp.go @@ -609,5 +609,8 @@ func Acceptable(segSeq seqnum.Value, segLen seqnum.Size, rcvNxt, rcvAcc seqnum.V } // Page 70 of RFC 793 allows packets that can be made "acceptable" by trimming // the payload, so we'll accept any payload that overlaps the receieve window. - return rcvNxt.LessThan(segSeq.Add(segLen)) && segSeq.LessThan(rcvAcc) + // segSeq < rcvAcc is more correct according to RFC, however, Linux does it + // differently, it uses segSeq <= rcvAcc, we'd want to keep the same behavior + // as Linux. + return rcvNxt.LessThan(segSeq.Add(segLen)) && segSeq.LessThanEq(rcvAcc) } diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 880ea7de2..78420d6e6 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -34,5 +34,6 @@ go_test( "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", "//pkg/waiter", + "@com_github_google_go-cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 9db42b2a4..64046cbbf 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -249,10 +249,11 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) pkt.NetworkHeader = buffer.View(ip) + nicName := e.stack.FindNICNameFromID(e.NICID()) // iptables filtering. All packets that reach here are locally // generated. ipt := e.stack.IPTables() - if ok := ipt.Check(stack.Output, &pkt, gso, r, ""); !ok { + if ok := ipt.Check(stack.Output, &pkt, gso, r, "", nicName); !ok { // iptables is telling us to drop the packet. return nil } @@ -319,10 +320,11 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe pkt = pkt.Next() } + nicName := e.stack.FindNICNameFromID(e.NICID()) // iptables filtering. All packets that reach here are locally // generated. ipt := e.stack.IPTables() - dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r) + dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName) if len(dropped) == 0 && len(natPkts) == 0 { // Fast path: If no packets are to be dropped then we can just invoke the // faster WritePackets API directly. @@ -445,7 +447,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt stack.PacketBuffer) { // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. ipt := e.stack.IPTables() - if ok := ipt.Check(stack.Input, &pkt, nil, nil, ""); !ok { + if ok := ipt.Check(stack.Input, &pkt, nil, nil, "", ""); !ok { // iptables is telling us to drop the packet. return } diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 5a864d832..36035c820 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -20,6 +20,7 @@ import ( "math/rand" "testing" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -473,3 +474,252 @@ func TestInvalidFragments(t *testing.T) { }) } } + +// TestReceiveFragments feeds fragments in through the incoming packet path to +// test reassembly +func TestReceiveFragments(t *testing.T) { + const addr1 = "\x0c\xa8\x00\x01" // 192.168.0.1 + const addr2 = "\x0c\xa8\x00\x02" // 192.168.0.2 + const nicID = 1 + + // Build and return a UDP header containing payload. + udpGen := func(payloadLen int, multiplier uint8) buffer.View { + payload := buffer.NewView(payloadLen) + for i := 0; i < len(payload); i++ { + payload[i] = uint8(i) * multiplier + } + + udpLength := header.UDPMinimumSize + len(payload) + + hdr := buffer.NewPrependable(udpLength) + u := header.UDP(hdr.Prepend(udpLength)) + u.Encode(&header.UDPFields{ + SrcPort: 5555, + DstPort: 80, + Length: uint16(udpLength), + }) + copy(u.Payload(), payload) + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, addr2, uint16(udpLength)) + sum = header.Checksum(payload, sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + return hdr.View() + } + + // UDP header plus a payload of 0..256 + ipv4Payload1 := udpGen(256, 1) + udpPayload1 := ipv4Payload1[header.UDPMinimumSize:] + // UDP header plus a payload of 0..256 in increments of 2. + ipv4Payload2 := udpGen(128, 2) + udpPayload2 := ipv4Payload2[header.UDPMinimumSize:] + + type fragmentData struct { + id uint16 + flags uint8 + fragmentOffset uint16 + payload buffer.View + } + + tests := []struct { + name string + fragments []fragmentData + expectedPayloads [][]byte + }{ + { + name: "No fragmentation", + fragments: []fragmentData{ + { + id: 1, + flags: 0, + fragmentOffset: 0, + payload: ipv4Payload1, + }, + }, + expectedPayloads: [][]byte{udpPayload1}, + }, + { + name: "More fragments without payload", + fragments: []fragmentData{ + { + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 0, + payload: ipv4Payload1, + }, + }, + expectedPayloads: nil, + }, + { + name: "Non-zero fragment offset without payload", + fragments: []fragmentData{ + { + id: 1, + flags: 0, + fragmentOffset: 8, + payload: ipv4Payload1, + }, + }, + expectedPayloads: nil, + }, + { + name: "Two fragments", + fragments: []fragmentData{ + { + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 0, + payload: ipv4Payload1[:64], + }, + { + id: 1, + flags: 0, + fragmentOffset: 64, + payload: ipv4Payload1[64:], + }, + }, + expectedPayloads: [][]byte{udpPayload1}, + }, + { + name: "Second fragment has MoreFlags set", + fragments: []fragmentData{ + { + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 0, + payload: ipv4Payload1[:64], + }, + { + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 64, + payload: ipv4Payload1[64:], + }, + }, + expectedPayloads: nil, + }, + { + name: "Two fragments with different IDs", + fragments: []fragmentData{ + { + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 0, + payload: ipv4Payload1[:64], + }, + { + id: 2, + flags: 0, + fragmentOffset: 64, + payload: ipv4Payload1[64:], + }, + }, + expectedPayloads: nil, + }, + { + name: "Two interleaved fragmented packets", + fragments: []fragmentData{ + { + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 0, + payload: ipv4Payload1[:64], + }, + { + id: 2, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 0, + payload: ipv4Payload2[:64], + }, + { + id: 1, + flags: 0, + fragmentOffset: 64, + payload: ipv4Payload1[64:], + }, + { + id: 2, + flags: 0, + fragmentOffset: 64, + payload: ipv4Payload2[64:], + }, + }, + expectedPayloads: [][]byte{udpPayload1, udpPayload2}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // Setup a stack and endpoint. + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + }) + e := channel.New(0, 1280, tcpip.LinkAddress("\xf0\x00")) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err) + } + + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + ep, err := s.NewEndpoint(udp.ProtocolNumber, header.IPv4ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, header.IPv4ProtocolNumber, err) + } + defer ep.Close() + + bindAddr := tcpip.FullAddress{Addr: addr2, Port: 80} + if err := ep.Bind(bindAddr); err != nil { + t.Fatalf("Bind(%+v): %s", bindAddr, err) + } + + // Prepare and send the fragments. + for _, frag := range test.fragments { + hdr := buffer.NewPrependable(header.IPv4MinimumSize) + + // Serialize IPv4 fixed header. + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: header.IPv4MinimumSize + uint16(len(frag.payload)), + ID: frag.id, + Flags: frag.flags, + FragmentOffset: frag.fragmentOffset, + TTL: 64, + Protocol: uint8(header.UDPProtocolNumber), + SrcAddr: addr1, + DstAddr: addr2, + }) + + vv := hdr.View().ToVectorisedView() + vv.AppendView(frag.payload) + + e.InjectInbound(header.IPv4ProtocolNumber, stack.PacketBuffer{ + Data: vv, + }) + } + + if got, want := s.Stats().UDP.PacketsReceived.Value(), uint64(len(test.expectedPayloads)); got != want { + t.Errorf("got UDP Rx Packets = %d, want = %d", got, want) + } + + for i, expectedPayload := range test.expectedPayloads { + gotPayload, _, err := ep.Read(nil) + if err != nil { + t.Fatalf("(i=%d) Read(nil): %s", i, err) + } + if diff := cmp.Diff(buffer.View(expectedPayload), gotPayload); diff != "" { + t.Errorf("(i=%d) got UDP payload mismatch (-want +got):\n%s", i, diff) + } + } + + if gotPayload, _, err := ep.Read(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("(last) got Read(nil) = (%x, _, %v), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock) + } + }) + } +} diff --git a/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go b/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go index 8b4213eec..d199ded6a 100644 --- a/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go +++ b/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Code generated by "stringer -type=DHCPv6ConfigurationFromNDPRA"; DO NOT EDIT. +// Code generated by "stringer -type DHCPv6ConfigurationFromNDPRA"; DO NOT EDIT. package stack @@ -22,9 +22,9 @@ func _() { // An "invalid array index" compiler error signifies that the constant values have changed. // Re-run the stringer command to generate them again. var x [1]struct{} - _ = x[DHCPv6NoConfiguration-0] - _ = x[DHCPv6ManagedAddress-1] - _ = x[DHCPv6OtherConfigurations-2] + _ = x[DHCPv6NoConfiguration-1] + _ = x[DHCPv6ManagedAddress-2] + _ = x[DHCPv6OtherConfigurations-3] } const _DHCPv6ConfigurationFromNDPRA_name = "DHCPv6NoConfigurationDHCPv6ManagedAddressDHCPv6OtherConfigurations" @@ -32,8 +32,9 @@ const _DHCPv6ConfigurationFromNDPRA_name = "DHCPv6NoConfigurationDHCPv6ManagedAd var _DHCPv6ConfigurationFromNDPRA_index = [...]uint8{0, 21, 41, 66} func (i DHCPv6ConfigurationFromNDPRA) String() string { + i -= 1 if i < 0 || i >= DHCPv6ConfigurationFromNDPRA(len(_DHCPv6ConfigurationFromNDPRA_index)-1) { - return "DHCPv6ConfigurationFromNDPRA(" + strconv.FormatInt(int64(i), 10) + ")" + return "DHCPv6ConfigurationFromNDPRA(" + strconv.FormatInt(int64(i+1), 10) + ")" } return _DHCPv6ConfigurationFromNDPRA_name[_DHCPv6ConfigurationFromNDPRA_index[i]:_DHCPv6ConfigurationFromNDPRA_index[i+1]] } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 7c3c47d50..443423b3c 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -16,6 +16,7 @@ package stack import ( "fmt" + "strings" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -178,7 +179,7 @@ const ( // dropped. // // Precondition: pkt.NetworkHeader is set. -func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, address tcpip.Address) bool { +func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, address tcpip.Address, nicName string) bool { // Packets are manipulated only if connection and matching // NAT rule exists. it.connections.HandlePacket(pkt, hook, gso, r) @@ -187,7 +188,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr for _, tablename := range it.Priorities[hook] { table := it.Tables[tablename] ruleIdx := table.BuiltinChains[hook] - switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address); verdict { + switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict { // If the table returns Accept, move on to the next table. case chainAccept: continue @@ -228,10 +229,10 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr // // NOTE: unlike the Check API the returned map contains packets that should be // dropped. -func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *Route) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { +func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *Route, nicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { if !pkt.NatDone { - if ok := it.Check(hook, pkt, gso, r, ""); !ok { + if ok := it.Check(hook, pkt, gso, r, "", nicName); !ok { if drop == nil { drop = make(map[*PacketBuffer]struct{}) } @@ -251,11 +252,11 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r * // Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. // TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a // precondition. -func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address) chainVerdict { +func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. for ruleIdx < len(table.Rules) { - switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, address); verdict { + switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict { case RuleAccept: return chainAccept @@ -272,7 +273,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId ruleIdx++ continue } - switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, address); verdict { + switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, address, nicName); verdict { case chainAccept: return chainAccept case chainDrop: @@ -298,7 +299,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId // Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. // TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a // precondition. -func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) { +func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) (RuleVerdict, int) { rule := table.Rules[ruleIdx] // If pkt.NetworkHeader hasn't been set yet, it will be contained in @@ -313,7 +314,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx } // Check whether the packet matches the IP header filter. - if !filterMatch(rule.Filter, header.IPv4(pkt.NetworkHeader)) { + if !filterMatch(rule.Filter, header.IPv4(pkt.NetworkHeader), hook, nicName) { // Continue on to the next rule. return RuleJump, ruleIdx + 1 } @@ -335,7 +336,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx return rule.Target.Action(pkt, &it.connections, hook, gso, r, address) } -func filterMatch(filter IPHeaderFilter, hdr header.IPv4) bool { +func filterMatch(filter IPHeaderFilter, hdr header.IPv4, hook Hook, nicName string) bool { // TODO(gvisor.dev/issue/170): Support other fields of the filter. // Check the transport protocol. if filter.Protocol != 0 && filter.Protocol != hdr.TransportProtocol() { @@ -355,5 +356,26 @@ func filterMatch(filter IPHeaderFilter, hdr header.IPv4) bool { return false } + // Check the output interface. + // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING + // hooks after supported. + if hook == Output { + n := len(filter.OutputInterface) + if n == 0 { + return true + } + + // If the interface name ends with '+', any interface which begins + // with the name should be matched. + ifName := filter.OutputInterface + matches = true + if strings.HasSuffix(ifName, "+") { + matches = strings.HasPrefix(nicName, ifName[:n-1]) + } else { + matches = nicName == ifName + } + return filter.OutputInterfaceInvert != matches + } + return true } diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 1bb0ba1bd..fe06007ae 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -158,6 +158,19 @@ type IPHeaderFilter struct { // true the filter will match packets that fail the destination // comparison. DstInvert bool + + // OutputInterface matches the name of the outgoing interface for the + // packet. + OutputInterface string + + // OutputInterfaceMask masks the characters of the interface name when + // comparing with OutputInterface. + OutputInterfaceMask string + + // OutputInterfaceInvert inverts the meaning of outgoing interface check, + // i.e. when true the filter will match packets that fail the outgoing + // interface comparison. + OutputInterfaceInvert bool } // A Matcher is the interface for matching packets. diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index 15343acbc..526c7d6ff 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -199,9 +199,11 @@ var ( type DHCPv6ConfigurationFromNDPRA int const ( + _ DHCPv6ConfigurationFromNDPRA = iota + // DHCPv6NoConfiguration indicates that no configurations are available via // DHCPv6. - DHCPv6NoConfiguration DHCPv6ConfigurationFromNDPRA = iota + DHCPv6NoConfiguration // DHCPv6ManagedAddress indicates that addresses are available via DHCPv6. // @@ -315,9 +317,6 @@ type NDPDispatcher interface { // OnDHCPv6Configuration will be called with an updated configuration that is // available via DHCPv6 for a specified NIC. // - // NDPDispatcher assumes that the initial configuration available by DHCPv6 is - // DHCPv6NoConfiguration. - // // This function is not permitted to block indefinitely. It must not // call functions on the stack itself. OnDHCPv6Configuration(tcpip.NICID, DHCPv6ConfigurationFromNDPRA) @@ -1808,6 +1807,8 @@ func (ndp *ndpState) cleanupState(hostOnly bool) { if got := len(ndp.defaultRouters); got != 0 { panic(fmt.Sprintf("ndp: still have discovered default routers after cleaning up; found = %d", got)) } + + ndp.dhcpv6Configuration = 0 } // startSolicitingRouters starts soliciting routers, as per RFC 4861 section diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 67f012840..b3d174cdd 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -4888,7 +4888,12 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { } } - // The initial DHCPv6 configuration should be stack.DHCPv6NoConfiguration. + // Even if the first RA reports no DHCPv6 configurations are available, the + // dispatcher should get an event. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) + expectDHCPv6Event(stack.DHCPv6NoConfiguration) + // Receiving the same update again should not result in an event to the + // dispatcher. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false)) expectNoDHCPv6Event() @@ -4896,8 +4901,6 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { // Configurations. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) expectDHCPv6Event(stack.DHCPv6OtherConfigurations) - // Receiving the same update again should not result in an event to the - // NDPDispatcher. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) expectNoDHCPv6Event() @@ -4933,6 +4936,21 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) { expectDHCPv6Event(stack.DHCPv6OtherConfigurations) e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) expectNoDHCPv6Event() + + // Cycling the NIC should cause the last DHCPv6 configuration to be cleared. + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + + // Receive an RA that updates the DHCPv6 configuration to Other + // Configurations. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) + expectDHCPv6Event(stack.DHCPv6OtherConfigurations) + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true)) + expectNoDHCPv6Event() } // TestRouterSolicitation tests the initial Router Solicitations that are sent diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 8f4c1fe42..54103fdb3 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -1233,7 +1233,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link // iptables filtering. ipt := n.stack.IPTables() address := n.primaryAddress(protocol) - if ok := ipt.Check(Prerouting, &pkt, nil, nil, address.Address); !ok { + if ok := ipt.Check(Prerouting, &pkt, nil, nil, address.Address, ""); !ok { // iptables is telling us to drop the packet. return } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index e33fae4eb..b39ffa9fb 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1898,9 +1898,23 @@ func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, addres nic.mu.RLock() defer nic.mu.RUnlock() - // An endpoint with this id exists, check if it can be used and return it. + // An endpoint with this id exists, check if it can be + // used and return it. return ref.ep, nil } } return nil, tcpip.ErrBadAddress } + +// FindNICNameFromID returns the name of the nic for the given NICID. +func (s *Stack) FindNICNameFromID(id tcpip.NICID) string { + s.mu.Lock() + defer s.mu.Unlock() + + nic, ok := s.nics[id] + if !ok { + return "" + } + + return nic.Name() +} diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index 6fe97fefd..dd89a292a 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -70,7 +70,16 @@ func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale // acceptable checks if the segment sequence number range is acceptable // according to the table on page 26 of RFC 793. func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool { - return header.Acceptable(segSeq, segLen, r.rcvNxt, r.rcvAcc) + // r.rcvWnd could be much larger than the window size we advertised in our + // outgoing packets, we should use what we have advertised for acceptability + // test. + scaledWindowSize := r.rcvWnd >> r.rcvWndScale + if scaledWindowSize > 0xffff { + // This is what we actually put in the Window field. + scaledWindowSize = 0xffff + } + advertisedWindowSize := scaledWindowSize << r.rcvWndScale + return header.Acceptable(segSeq, segLen, r.rcvNxt, r.rcvNxt.Add(advertisedWindowSize)) } // getSendParams returns the parameters needed by the sender when building @@ -259,7 +268,14 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // If we are in one of the shutdown states then we need to do // additional checks before we try and process the segment. switch state { - case StateCloseWait, StateClosing, StateLastAck: + case StateCloseWait: + // If the ACK acks something not yet sent then we send an ACK. + if r.ep.snd.sndNxt.LessThan(s.ackNumber) { + r.ep.snd.sendAck() + return true, nil + } + fallthrough + case StateClosing, StateLastAck: if !s.sequenceNumber.LessThanEq(r.rcvNxt) { // Just drop the segment as we have // already received a FIN and this @@ -276,7 +292,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // SHUT_RD) then any data past the rcvNxt should // trigger a RST. endDataSeq := s.sequenceNumber.Add(seqnum.Size(s.data.Size())) - if rcvClosed && r.rcvNxt.LessThan(endDataSeq) { + if state != StateCloseWait && rcvClosed && r.rcvNxt.LessThan(endDataSeq) { return true, tcpip.ErrConnectionAborted } if state == StateFinWait1 { @@ -329,13 +345,6 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { state := r.ep.EndpointState() closed := r.ep.closed - if state != StateEstablished { - drop, err := r.handleRcvdSegmentClosing(s, state, closed) - if drop || err != nil { - return drop, err - } - } - segLen := seqnum.Size(s.data.Size()) segSeq := s.sequenceNumber @@ -347,6 +356,13 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { return true, nil } + if state != StateEstablished { + drop, err := r.handleRcvdSegmentClosing(s, state, closed) + if drop || err != nil { + return drop, err + } + } + // Store the time of the last ack. r.lastRcvdAckTime = time.Now() diff --git a/pkg/tcpip/transport/tcp/rcv_test.go b/pkg/tcpip/transport/tcp/rcv_test.go index c9eeff935..8a026ec46 100644 --- a/pkg/tcpip/transport/tcp/rcv_test.go +++ b/pkg/tcpip/transport/tcp/rcv_test.go @@ -30,7 +30,7 @@ func TestAcceptable(t *testing.T) { }{ // The segment is smaller than the window. {105, 2, 100, 104, false}, - {105, 2, 101, 105, false}, + {105, 2, 101, 105, true}, {105, 2, 102, 106, true}, {105, 2, 103, 107, true}, {105, 2, 104, 108, true}, @@ -39,7 +39,7 @@ func TestAcceptable(t *testing.T) { {105, 2, 107, 111, false}, // The segment is larger than the window. - {105, 4, 103, 105, false}, + {105, 4, 103, 105, true}, {105, 4, 104, 106, true}, {105, 4, 105, 107, true}, {105, 4, 106, 108, true}, diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 7712ce652..074edded6 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -96,6 +96,8 @@ func (s *segment) clone() *segment { route: s.route.Clone(), viewToDeliver: s.viewToDeliver, rcvdTime: s.rcvdTime, + xmitTime: s.xmitTime, + xmitCount: s.xmitCount, } t.data = s.data.Clone(t.views[:]) return t diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index a3018914b..9e547a221 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -598,22 +598,34 @@ func (s *sender) splitSeg(seg *segment, size int) { seg.data.CapLength(size) } -// NextSeg implements the RFC6675 NextSeg() operation. It returns segments that -// match rule 1, 3 and 4 of the NextSeg() operation defined in RFC6675. Rule 2 -// is handled by the normal send logic. -func (s *sender) NextSeg() (nextSeg1, nextSeg3, nextSeg4 *segment) { +// NextSeg implements the RFC6675 NextSeg() operation. +// +// NextSeg starts scanning the writeList starting from nextSegHint and returns +// the hint to be passed on the next call to NextSeg. This is required to avoid +// iterating the write list repeatedly when NextSeg is invoked in a loop during +// recovery. The returned hint will be nil if there are no more segments that +// can match rules defined by NextSeg operation in RFC6675. +// +// rescueRtx will be true only if nextSeg is a rescue retransmission as +// described by Step 4) of the NextSeg algorithm. +func (s *sender) NextSeg(nextSegHint *segment) (nextSeg, hint *segment, rescueRtx bool) { var s3 *segment var s4 *segment - smss := s.ep.scoreboard.SMSS() // Step 1. - for seg := s.writeList.Front(); seg != nil; seg = seg.Next() { - if !s.isAssignedSequenceNumber(seg) { + for seg := nextSegHint; seg != nil; seg = seg.Next() { + // Stop iteration if we hit a segment that has never been + // transmitted (i.e. either it has no assigned sequence number + // or if it does have one, it's >= the next sequence number + // to be sent [i.e. >= s.sndNxt]). + if !s.isAssignedSequenceNumber(seg) || s.sndNxt.LessThanEq(seg.sequenceNumber) { + hint = nil break } segSeq := seg.sequenceNumber - if seg.data.Size() > int(smss) { + if smss := s.ep.scoreboard.SMSS(); seg.data.Size() > int(smss) { s.splitSeg(seg, int(smss)) } + // See RFC 6675 Section 4 // // 1. If there exists a smallest unSACKED sequence number @@ -630,8 +642,9 @@ func (s *sender) NextSeg() (nextSeg1, nextSeg3, nextSeg4 *segment) { // NextSeg(): // (1.c) IsLost(S2) returns true. if s.ep.scoreboard.IsLost(segSeq) { - return seg, s3, s4 + return seg, seg.Next(), false } + // NextSeg(): // // (3): If the conditions for rules (1) and (2) @@ -643,6 +656,7 @@ func (s *sender) NextSeg() (nextSeg1, nextSeg3, nextSeg4 *segment) { // SHOULD be returned. if s3 == nil { s3 = seg + hint = seg.Next() } } // NextSeg(): @@ -651,10 +665,12 @@ func (s *sender) NextSeg() (nextSeg1, nextSeg3, nextSeg4 *segment) { // but there exists outstanding unSACKED data, we // provide the opportunity for a single "rescue" // retransmission per entry into loss recovery. If - // HighACK is greater than RescueRxt, the one - // segment of upto SMSS octects that MUST include - // the highest outstanding unSACKed sequence number - // SHOULD be returned. + // HighACK is greater than RescueRxt (or RescueRxt + // is undefined), then one segment of upto SMSS + // octects that MUST include the highest outstanding + // unSACKed sequence number SHOULD be returned, and + // RescueRxt set to RecoveryPoint. HighRxt MUST NOT + // be updated. if s.fr.rescueRxt.LessThan(s.sndUna - 1) { if s4 != nil { if s4.sequenceNumber.LessThan(segSeq) { @@ -663,12 +679,31 @@ func (s *sender) NextSeg() (nextSeg1, nextSeg3, nextSeg4 *segment) { } else { s4 = seg } - s.fr.rescueRxt = s.fr.last } } } - return nil, s3, s4 + // If we got here then no segment matched step (1). + // Step (2): "If no sequence number 'S2' per rule (1) + // exists but there exists available unsent data and the + // receiver's advertised window allows, the sequence + // range of one segment of up to SMSS octets of + // previously unsent data starting with sequence number + // HighData+1 MUST be returned." + for seg := s.writeNext; seg != nil; seg = seg.Next() { + if s.isAssignedSequenceNumber(seg) && seg.sequenceNumber.LessThan(s.sndNxt) { + continue + } + // We do not split the segment here to <= smss as it has + // potentially not been assigned a sequence number yet. + return seg, nil, false + } + + if s3 != nil { + return s3, hint, false + } + + return s4, nil, true } // maybeSendSegment tries to send the specified segment and either coalesces @@ -792,64 +827,47 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se // section 5, step C. func (s *sender) handleSACKRecovery(limit int, end seqnum.Value) (dataSent bool) { s.SetPipe() + + if smss := int(s.ep.scoreboard.SMSS()); limit > smss { + // Cap segment size limit to s.smss as SACK recovery requires + // that all retransmissions or new segments send during recovery + // be of <= SMSS. + limit = smss + } + + nextSegHint := s.writeList.Front() for s.outstanding < s.sndCwnd { - nextSeg, s3, s4 := s.NextSeg() - if nextSeg == nil { - // NextSeg(): - // - // Step (2): "If no sequence number 'S2' per rule (1) - // exists but there exists available unsent data and the - // receiver's advertised window allows, the sequence - // range of one segment of up to SMSS octets of - // previously unsent data starting with sequence number - // HighData+1 MUST be returned." - for seg := s.writeNext; seg != nil; seg = seg.Next() { - if s.isAssignedSequenceNumber(seg) && seg.sequenceNumber.LessThan(s.sndNxt) { - continue - } - // Step C.3 described below is handled by - // maybeSendSegment which increments sndNxt when - // a segment is transmitted. - // - // Step C.3 "If any of the data octets sent in - // (C.1) are above HighData, HighData must be - // updated to reflect the transmission of - // previously unsent data." - if sent := s.maybeSendSegment(seg, limit, end); !sent { - break - } - dataSent = true - s.outstanding++ - s.writeNext = seg.Next() - nextSeg = seg - break - } - if nextSeg != nil { - continue - } - } - rescueRtx := false - if nextSeg == nil && s3 != nil { - nextSeg = s3 - } - if nextSeg == nil && s4 != nil { - nextSeg = s4 - rescueRtx = true - } + var nextSeg *segment + var rescueRtx bool + nextSeg, nextSegHint, rescueRtx = s.NextSeg(nextSegHint) if nextSeg == nil { - break + return dataSent } - segEnd := nextSeg.sequenceNumber.Add(nextSeg.logicalLen()) - if !rescueRtx && nextSeg.sequenceNumber.LessThan(s.sndNxt) { - // RFC 6675, Step C.2 + if !s.isAssignedSequenceNumber(nextSeg) || s.sndNxt.LessThanEq(nextSeg.sequenceNumber) { + // New data being sent. + + // Step C.3 described below is handled by + // maybeSendSegment which increments sndNxt when + // a segment is transmitted. // - // "If any of the data octets sent in (C.1) are below - // HighData, HighRxt MUST be set to the highest sequence - // number of the retransmitted segment unless NextSeg () - // rule (4) was invoked for this retransmission." - s.fr.highRxt = segEnd - 1 + // Step C.3 "If any of the data octets sent in + // (C.1) are above HighData, HighData must be + // updated to reflect the transmission of + // previously unsent data." + // + // We pass s.smss as the limit as the Step 2) requires that + // new data sent should be of size s.smss or less. + if sent := s.maybeSendSegment(nextSeg, limit, end); !sent { + return dataSent + } + dataSent = true + s.outstanding++ + s.writeNext = nextSeg.Next() + continue } + // Now handle the retransmission case where we matched either step 1,3 or 4 + // of the NextSeg algorithm. // RFC 6675, Step C.4. // // "The estimate of the amount of data outstanding in the network @@ -858,6 +876,22 @@ func (s *sender) handleSACKRecovery(limit int, end seqnum.Value) (dataSent bool) s.outstanding++ dataSent = true s.sendSegment(nextSeg) + + segEnd := nextSeg.sequenceNumber.Add(nextSeg.logicalLen()) + if rescueRtx { + // We do the last part of rule (4) of NextSeg here to update + // RescueRxt as until this point we don't know if we are going + // to use the rescue transmission. + s.fr.rescueRxt = s.fr.last + } else { + // RFC 6675, Step C.2 + // + // "If any of the data octets sent in (C.1) are below + // HighData, HighRxt MUST be set to the highest sequence + // number of the retransmitted segment unless NextSeg () + // rule (4) was invoked for this retransmission." + s.fr.highRxt = segEnd - 1 + } } return dataSent } @@ -903,7 +937,7 @@ func (s *sender) sendData() { // "A TCP SHOULD set cwnd to no more than RW before beginning // transmission if the TCP has not sent data in the interval exceeding // the retrasmission timeout." - if !s.fr.active && time.Now().Sub(s.lastSendTime) > s.rto { + if !s.fr.active && s.state != RTORecovery && time.Now().Sub(s.lastSendTime) > s.rto { if s.sndCwnd > InitialCwnd { s.sndCwnd = InitialCwnd } @@ -921,6 +955,9 @@ func (s *sender) sendData() { limit = cwndLimit } if s.isAssignedSequenceNumber(seg) && s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) { + // Move writeNext along so that we don't try and scan data that + // has already been SACKED. + s.writeNext = seg.Next() continue } if sent := s.maybeSendSegment(seg, limit, end); !sent { @@ -966,6 +1003,8 @@ func (s *sender) enterFastRecovery() { s.fr.first = s.sndUna s.fr.last = s.sndNxt - 1 s.fr.maxCwnd = s.sndCwnd + s.outstanding + s.fr.highRxt = s.sndUna + s.fr.rescueRxt = s.sndUna if s.ep.sackPermitted { s.state = SACKRecovery s.ep.stack.Stats().TCP.SACKRecovery.Increment() @@ -1258,6 +1297,7 @@ func (s *sender) handleRcvdSegment(seg *segment) { if s.writeNext == seg { s.writeNext = seg.Next() } + s.writeList.Remove(seg) // if SACK is enabled then Only reduce outstanding if @@ -1329,7 +1369,23 @@ func (s *sender) sendSegment(seg *segment) *tcpip.Error { } seg.xmitTime = time.Now() seg.xmitCount++ - return s.sendSegmentFromView(seg.data, seg.flags, seg.sequenceNumber) + err := s.sendSegmentFromView(seg.data, seg.flags, seg.sequenceNumber) + + // Every time a packet containing data is sent (including a + // retransmission), if SACK is enabled and we are retransmitting data + // then use the conservative timer described in RFC6675 Section 6.0, + // otherwise follow the standard time described in RFC6298 Section 5.1. + if err != nil && seg.data.Size() != 0 { + if s.fr.active && seg.xmitCount > 1 && s.ep.sackPermitted { + s.resendTimer.enable(s.rto) + } else { + if !s.resendTimer.enabled() { + s.resendTimer.enable(s.rto) + } + } + } + + return err } // sendSegmentFromView sends a new segment containing the given payload, flags @@ -1345,19 +1401,5 @@ func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags byte, seq // Remember the max sent ack. s.maxSentAck = rcvNxt - // Every time a packet containing data is sent (including a - // retransmission), if SACK is enabled then use the conservative timer - // described in RFC6675 Section 4.0, otherwise follow the standard time - // described in RFC6298 Section 5.2. - if data.Size() != 0 { - if s.ep.sackPermitted { - s.resendTimer.enable(s.rto) - } else { - if !s.resendTimer.enabled() { - s.resendTimer.enable(s.rto) - } - } - } - return s.ep.sendRaw(data, flags, seq, rcvNxt, rcvWnd) } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 49e4ba214..d2c90ebd5 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -4905,6 +4905,8 @@ func TestListenNoAcceptNonUnicastV4(t *testing.T) { } for _, test := range tests { + test := test // capture range variable + t.Run(test.name, func(t *testing.T) { t.Parallel() @@ -5007,6 +5009,8 @@ func TestListenNoAcceptNonUnicastV6(t *testing.T) { } for _, test := range tests { + test := test // capture range variable + t.Run(test.name, func(t *testing.T) { t.Parallel() |