diff options
Diffstat (limited to 'pkg/tcpip')
146 files changed, 14700 insertions, 5936 deletions
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index 454e07662..89b765f1b 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -1,10 +1,25 @@ load("//tools:defs.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) +go_template_instance( + name = "sock_err_list", + out = "sock_err_list.go", + package = "tcpip", + prefix = "sockError", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*SockError", + "Linker": "*SockError", + }, +) + go_library( name = "tcpip", srcs = [ + "sock_err_list.go", + "socketops.go", "tcpip.go", "time_unsafe.go", "timer.go", diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index 4f551cd92..7193f56ad 100644 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -286,45 +286,47 @@ type opErrorer interface { // commonRead implements the common logic between net.Conn.Read and // net.PacketConn.ReadFrom. -func commonRead(ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan struct{}, addr *tcpip.FullAddress, errorer opErrorer, dontWait bool) ([]byte, error) { +func commonRead(b []byte, ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan struct{}, addr *tcpip.FullAddress, errorer opErrorer) (int, error) { select { case <-deadline: - return nil, errorer.newOpError("read", &timeoutError{}) + return 0, errorer.newOpError("read", &timeoutError{}) default: } - read, _, err := ep.Read(addr) + w := tcpip.SliceWriter(b) + opts := tcpip.ReadOptions{NeedRemoteAddr: addr != nil} + res, err := ep.Read(&w, len(b), opts) if err == tcpip.ErrWouldBlock { - if dontWait { - return nil, errWouldBlock - } // Create wait queue entry that notifies a channel. waitEntry, notifyCh := waiter.NewChannelEntry(nil) wq.EventRegister(&waitEntry, waiter.EventIn) defer wq.EventUnregister(&waitEntry) for { - read, _, err = ep.Read(addr) + res, err = ep.Read(&w, len(b), opts) if err != tcpip.ErrWouldBlock { break } select { case <-deadline: - return nil, errorer.newOpError("read", &timeoutError{}) + return 0, errorer.newOpError("read", &timeoutError{}) case <-notifyCh: } } } if err == tcpip.ErrClosedForReceive { - return nil, io.EOF + return 0, io.EOF } if err != nil { - return nil, errorer.newOpError("read", errors.New(err.String())) + return 0, errorer.newOpError("read", errors.New(err.String())) } - return read, nil + if addr != nil { + *addr = res.RemoteAddr + } + return res.Count, nil } // Read implements net.Conn.Read. @@ -334,31 +336,11 @@ func (c *TCPConn) Read(b []byte) (int, error) { deadline := c.readCancel() - numRead := 0 - defer func() { - if numRead != 0 { - c.ep.ModerateRecvBuf(numRead) - } - }() - for numRead != len(b) { - if len(c.read) == 0 { - var err error - c.read, err = commonRead(c.ep, c.wq, deadline, nil, c, numRead != 0) - if err != nil { - if numRead != 0 { - return numRead, nil - } - return numRead, err - } - } - n := copy(b[numRead:], c.read) - c.read.TrimFront(n) - numRead += n - if len(c.read) == 0 { - c.read = nil - } + n, err := commonRead(b, c.ep, c.wq, deadline, nil, c) + if n != 0 { + c.ep.ModerateRecvBuf(n) } - return numRead, nil + return n, err } // Write implements net.Conn.Write. @@ -652,12 +634,11 @@ func (c *UDPConn) ReadFrom(b []byte) (int, net.Addr, error) { deadline := c.readCancel() var addr tcpip.FullAddress - read, err := commonRead(c.ep, c.wq, deadline, &addr, c, false) + n, err := commonRead(b, c.ep, c.wq, deadline, &addr, c) if err != nil { return 0, nil, err } - - return copy(b, read), fullToUDPAddr(addr), nil + return n, fullToUDPAddr(addr), nil } func (c *UDPConn) Write(b []byte) (int, error) { diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index 8db70a700..5dd1b1b6b 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -105,18 +105,18 @@ func (vv *VectorisedView) TrimFront(count int) { } // Read implements io.Reader. -func (vv *VectorisedView) Read(v View) (copied int, err error) { - count := len(v) +func (vv *VectorisedView) Read(b []byte) (copied int, err error) { + count := len(b) for count > 0 && len(vv.views) > 0 { if count < len(vv.views[0]) { vv.size -= count - copy(v[copied:], vv.views[0][:count]) + copy(b[copied:], vv.views[0][:count]) vv.views[0].TrimFront(count) copied += count return copied, nil } count -= len(vv.views[0]) - copy(v[copied:], vv.views[0]) + copy(b[copied:], vv.views[0]) copied += len(vv.views[0]) vv.removeFirst() } @@ -145,6 +145,35 @@ func (vv *VectorisedView) ReadToVV(dstVV *VectorisedView, count int) (copied int return copied } +// ReadTo reads up to count bytes from vv to dst. It also removes them from vv +// unless peek is true. +func (vv *VectorisedView) ReadTo(dst io.Writer, count int, peek bool) (int, error) { + var err error + done := 0 + for _, v := range vv.Views() { + remaining := count - done + if remaining <= 0 { + break + } + if len(v) > remaining { + v = v[:remaining] + } + + var n int + n, err = dst.Write(v) + if n > 0 { + done += n + } + if err != nil { + break + } + } + if !peek { + vv.TrimFront(done) + } + return done, err +} + // CapLength irreversibly reduces the length of the vectorised view. func (vv *VectorisedView) CapLength(length int) { if length < 0 { diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go index 726e54de9..e0ef8a94d 100644 --- a/pkg/tcpip/buffer/view_test.go +++ b/pkg/tcpip/buffer/view_test.go @@ -235,14 +235,16 @@ func TestToClone(t *testing.T) { } } -func TestVVReadToVV(t *testing.T) { - testCases := []struct { - comment string - vv VectorisedView - bytesToRead int - wantBytes string - leftVV VectorisedView - }{ +type readToTestCases struct { + comment string + vv VectorisedView + bytesToRead int + wantBytes string + leftVV VectorisedView +} + +func createReadToTestCases() []readToTestCases { + return []readToTestCases{ { comment: "large VV, short read", vv: vv(30, "012345678901234567890123456789"), @@ -279,8 +281,10 @@ func TestVVReadToVV(t *testing.T) { leftVV: vv(0, ""), }, } +} - for _, tc := range testCases { +func TestVVReadToVV(t *testing.T) { + for _, tc := range createReadToTestCases() { t.Run(tc.comment, func(t *testing.T) { var readTo VectorisedView inSize := tc.vv.Size() @@ -301,6 +305,52 @@ func TestVVReadToVV(t *testing.T) { } } +func TestVVReadTo(t *testing.T) { + for _, tc := range createReadToTestCases() { + t.Run(tc.comment, func(t *testing.T) { + var dst bytes.Buffer + origSize := tc.vv.Size() + copied, err := tc.vv.ReadTo(&dst, tc.bytesToRead, false /* peek */) + if got, want := copied, len(tc.wantBytes); err != nil || got != want { + t.Errorf("got ReadTo(&dst, %d, false) = %d, %v; want %d, nil", tc.bytesToRead, got, err, want) + } + if got, want := string(dst.Bytes()), tc.wantBytes; got != want { + t.Errorf("got dst = %q, want %q", got, want) + } + if got, want := tc.vv.Size(), origSize-copied; got != want { + t.Errorf("got after-read tc.vv.Size() = %d, want %d", got, want) + } + if got, want := string(tc.vv.ToView()), string(tc.leftVV.ToView()); got != want { + t.Errorf("got after-read data in tc.vv = %q, want %q", got, want) + } + }) + } +} + +func TestVVReadToPeek(t *testing.T) { + for _, tc := range createReadToTestCases() { + t.Run(tc.comment, func(t *testing.T) { + var dst bytes.Buffer + origSize := tc.vv.Size() + origData := string(tc.vv.ToView()) + copied, err := tc.vv.ReadTo(&dst, tc.bytesToRead, true /* peek */) + if got, want := copied, len(tc.wantBytes); err != nil || got != want { + t.Errorf("got ReadTo(&dst, %d, false) = %d, %v; want %d, nil", tc.bytesToRead, got, err, want) + } + if got, want := string(dst.Bytes()), tc.wantBytes; got != want { + t.Errorf("got dst = %q, want %q", got, want) + } + // Expect tc.vv is unchanged. + if got, want := tc.vv.Size(), origSize; got != want { + t.Errorf("got after-read tc.vv.Size() = %d, want %d", got, want) + } + if got, want := string(tc.vv.ToView()), origData; got != want { + t.Errorf("got after-read data in tc.vv = %q, want %q", got, want) + } + }) + } +} + func TestVVRead(t *testing.T) { testCases := []struct { comment string diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 8868cf4e3..0ac2000ca 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -20,6 +20,7 @@ import ( "encoding/binary" "reflect" "testing" + "time" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" @@ -116,6 +117,10 @@ func TTL(ttl uint8) NetworkChecker { v = ip.TTL() case header.IPv6: v = ip.HopLimit() + case *ipv6HeaderWithExtHdr: + v = ip.HopLimit() + default: + t.Fatalf("unrecognized header type %T for TTL evaluation", ip) } if v != ttl { t.Fatalf("Bad TTL, got = %d, want = %d", v, ttl) @@ -216,6 +221,42 @@ func IPv4Options(want header.IPv4Options) NetworkChecker { } } +// IPv4RouterAlert returns a checker that checks that the RouterAlert option is +// set in an IPv4 packet. +func IPv4RouterAlert() NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + ip, ok := h[0].(header.IPv4) + if !ok { + t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", h[0]) + } + iterator := ip.Options().MakeIterator() + for { + opt, done, err := iterator.Next() + if err != nil { + t.Fatalf("error acquiring next IPv4 option %s", err) + } + if done { + break + } + if opt.Type() != header.IPv4OptionRouterAlertType { + continue + } + want := [header.IPv4OptionRouterAlertLength]byte{ + byte(header.IPv4OptionRouterAlertType), + header.IPv4OptionRouterAlertLength, + header.IPv4OptionRouterAlertValue, + header.IPv4OptionRouterAlertValue, + } + if diff := cmp.Diff(want[:], opt.Contents()); diff != "" { + t.Errorf("router alert option mismatch (-want +got):\n%s", diff) + } + return + } + t.Errorf("failed to find router alert option in %v", ip.Options()) + } +} + // FragmentOffset creates a checker that checks the FragmentOffset field. func FragmentOffset(offset uint16) NetworkChecker { return func(t *testing.T, h []header.Network) { @@ -284,6 +325,19 @@ func ReceiveIPPacketInfo(want tcpip.IPPacketInfo) ControlMessagesChecker { } } +// ReceiveOriginalDstAddr creates a checker that checks the OriginalDstAddress +// field in ControlMessages. +func ReceiveOriginalDstAddr(want tcpip.FullAddress) ControlMessagesChecker { + return func(t *testing.T, cm tcpip.ControlMessages) { + t.Helper() + if !cm.HasOriginalDstAddress { + t.Errorf("got cm.HasOriginalDstAddress = %t, want = true", cm.HasOriginalDstAddress) + } else if diff := cmp.Diff(want, cm.OriginalDstAddress); diff != "" { + t.Errorf("OriginalDstAddress mismatch (-want +got):\n%s", diff) + } + } +} + // TOS creates a checker that checks the TOS field. func TOS(tos uint8, label uint32) NetworkChecker { return func(t *testing.T, h []header.Network) { @@ -904,6 +958,12 @@ func ICMPv4Payload(want []byte) TransportChecker { t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) } payload := icmpv4.Payload() + + // cmp.Diff does not consider nil slices equal to empty slices, but we do. + if len(want) == 0 && len(payload) == 0 { + return + } + if diff := cmp.Diff(want, payload); diff != "" { t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff) } @@ -994,12 +1054,86 @@ func ICMPv6Payload(want []byte) TransportChecker { t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h) } payload := icmpv6.Payload() + + // cmp.Diff does not consider nil slices equal to empty slices, but we do. + if len(want) == 0 && len(payload) == 0 { + return + } + if diff := cmp.Diff(want, payload); diff != "" { t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff) } } } +// MLD creates a checker that checks that the packet contains a valid MLD +// message for type of mldType, with potentially additional checks specified by +// checkers. +// +// Checkers may assume that a valid ICMPv6 is passed to it containing a valid +// MLD message as far as the size of the message (minSize) is concerned. The +// values within the message are up to checkers to validate. +func MLD(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + // Check normal ICMPv6 first. + ICMPv6( + ICMPv6Type(msgType), + ICMPv6Code(0))(t, h) + + last := h[len(h)-1] + + icmp := header.ICMPv6(last.Payload()) + if got := len(icmp.MessageBody()); got < minSize { + t.Fatalf("ICMPv6 MLD (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize) + } + + for _, f := range checkers { + f(t, icmp) + } + if t.Failed() { + t.FailNow() + } + } +} + +// MLDMaxRespDelay creates a checker that checks the Maximum Response Delay +// field of a MLD message. +// +// The returned TransportChecker assumes that a valid ICMPv6 is passed to it +// containing a valid MLD message as far as the size is concerned. +func MLDMaxRespDelay(want time.Duration) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmp := h.(header.ICMPv6) + ns := header.MLD(icmp.MessageBody()) + + if got := ns.MaximumResponseDelay(); got != want { + t.Errorf("got %T.MaximumResponseDelay() = %s, want = %s", ns, got, want) + } + } +} + +// MLDMulticastAddress creates a checker that checks the Multicast Address +// field of a MLD message. +// +// The returned TransportChecker assumes that a valid ICMPv6 is passed to it +// containing a valid MLD message as far as the size is concerned. +func MLDMulticastAddress(want tcpip.Address) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmp := h.(header.ICMPv6) + ns := header.MLD(icmp.MessageBody()) + + if got := ns.MulticastAddress(); got != want { + t.Errorf("got %T.MulticastAddress() = %s, want = %s", ns, got, want) + } + } +} + // NDP creates a checker that checks that the packet contains a valid NDP // message for type of ty, with potentially additional checks specified by // checkers. @@ -1019,7 +1153,7 @@ func NDP(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) N last := h[len(h)-1] icmp := header.ICMPv6(last.Payload()) - if got := len(icmp.NDPPayload()); got < minSize { + if got := len(icmp.MessageBody()); got < minSize { t.Fatalf("ICMPv6 NDP (type = %d) payload size of %d is less than the minimum size of %d", msgType, got, minSize) } @@ -1053,7 +1187,7 @@ func NDPNSTargetAddress(want tcpip.Address) TransportChecker { t.Helper() icmp := h.(header.ICMPv6) - ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ns := header.NDPNeighborSolicit(icmp.MessageBody()) if got := ns.TargetAddress(); got != want { t.Errorf("got %T.TargetAddress() = %s, want = %s", ns, got, want) @@ -1082,7 +1216,7 @@ func NDPNATargetAddress(want tcpip.Address) TransportChecker { t.Helper() icmp := h.(header.ICMPv6) - na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na := header.NDPNeighborAdvert(icmp.MessageBody()) if got := na.TargetAddress(); got != want { t.Errorf("got %T.TargetAddress() = %s, want = %s", na, got, want) @@ -1100,7 +1234,7 @@ func NDPNASolicitedFlag(want bool) TransportChecker { t.Helper() icmp := h.(header.ICMPv6) - na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na := header.NDPNeighborAdvert(icmp.MessageBody()) if got := na.SolicitedFlag(); got != want { t.Errorf("got %T.SolicitedFlag = %t, want = %t", na, got, want) @@ -1171,7 +1305,7 @@ func NDPNAOptions(opts []header.NDPOption) TransportChecker { t.Helper() icmp := h.(header.ICMPv6) - na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na := header.NDPNeighborAdvert(icmp.MessageBody()) ndpOptions(t, na.Options(), opts) } } @@ -1186,7 +1320,7 @@ func NDPNSOptions(opts []header.NDPOption) TransportChecker { t.Helper() icmp := h.(header.ICMPv6) - ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ns := header.NDPNeighborSolicit(icmp.MessageBody()) ndpOptions(t, ns.Options(), opts) } } @@ -1211,7 +1345,273 @@ func NDPRSOptions(opts []header.NDPOption) TransportChecker { t.Helper() icmp := h.(header.ICMPv6) - rs := header.NDPRouterSolicit(icmp.NDPPayload()) + rs := header.NDPRouterSolicit(icmp.MessageBody()) ndpOptions(t, rs.Options(), opts) } } + +// IGMP checks the validity and properties of the given IGMP packet. It is +// expected to be used in conjunction with other IGMP transport checkers for +// specific properties. +func IGMP(checkers ...TransportChecker) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + last := h[len(h)-1] + + if p := last.TransportProtocol(); p != header.IGMPProtocolNumber { + t.Fatalf("Bad protocol, got %d, want %d", p, header.IGMPProtocolNumber) + } + + igmp := header.IGMP(last.Payload()) + for _, f := range checkers { + f(t, igmp) + } + if t.Failed() { + t.FailNow() + } + } +} + +// IGMPType creates a checker that checks the IGMP Type field. +func IGMPType(want header.IGMPType) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + igmp, ok := h.(header.IGMP) + if !ok { + t.Fatalf("got transport header = %T, want = header.IGMP", h) + } + if got := igmp.Type(); got != want { + t.Errorf("got igmp.Type() = %d, want = %d", got, want) + } + } +} + +// IGMPMaxRespTime creates a checker that checks the IGMP Max Resp Time field. +func IGMPMaxRespTime(want time.Duration) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + igmp, ok := h.(header.IGMP) + if !ok { + t.Fatalf("got transport header = %T, want = header.IGMP", h) + } + if got := igmp.MaxRespTime(); got != want { + t.Errorf("got igmp.MaxRespTime() = %s, want = %s", got, want) + } + } +} + +// IGMPGroupAddress creates a checker that checks the IGMP Group Address field. +func IGMPGroupAddress(want tcpip.Address) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + igmp, ok := h.(header.IGMP) + if !ok { + t.Fatalf("got transport header = %T, want = header.IGMP", h) + } + if got := igmp.GroupAddress(); got != want { + t.Errorf("got igmp.GroupAddress() = %s, want = %s", got, want) + } + } +} + +// IPv6ExtHdrChecker is a function to check an extension header. +type IPv6ExtHdrChecker func(*testing.T, header.IPv6PayloadHeader) + +// IPv6WithExtHdr is like IPv6 but allows IPv6 packets with extension headers. +func IPv6WithExtHdr(t *testing.T, b []byte, checkers ...NetworkChecker) { + t.Helper() + + ipv6 := header.IPv6(b) + if !ipv6.IsValid(len(b)) { + t.Error("not a valid IPv6 packet") + return + } + + payloadIterator := header.MakeIPv6PayloadIterator( + header.IPv6ExtensionHeaderIdentifier(ipv6.NextHeader()), + buffer.View(ipv6.Payload()).ToVectorisedView(), + ) + + var rawPayloadHeader header.IPv6RawPayloadHeader + for { + h, done, err := payloadIterator.Next() + if err != nil { + t.Errorf("payloadIterator.Next(): %s", err) + return + } + if done { + t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, true, _)", h, done) + return + } + r, ok := h.(header.IPv6RawPayloadHeader) + if ok { + rawPayloadHeader = r + break + } + } + + networkHeader := ipv6HeaderWithExtHdr{ + IPv6: ipv6, + transport: tcpip.TransportProtocolNumber(rawPayloadHeader.Identifier), + payload: rawPayloadHeader.Buf.ToView(), + } + + for _, checker := range checkers { + checker(t, []header.Network{&networkHeader}) + } +} + +// IPv6ExtHdr checks for the presence of extension headers. +// +// All the extension headers in headers will be checked exhaustively in the +// order provided. +func IPv6ExtHdr(headers ...IPv6ExtHdrChecker) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + extHdrs, ok := h[0].(*ipv6HeaderWithExtHdr) + if !ok { + t.Errorf("got network header = %T, want = *ipv6HeaderWithExtHdr", h[0]) + return + } + + payloadIterator := header.MakeIPv6PayloadIterator( + header.IPv6ExtensionHeaderIdentifier(extHdrs.IPv6.NextHeader()), + buffer.View(extHdrs.IPv6.Payload()).ToVectorisedView(), + ) + + for _, check := range headers { + h, done, err := payloadIterator.Next() + if err != nil { + t.Errorf("payloadIterator.Next(): %s", err) + return + } + if done { + t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, false, _)", h, done) + return + } + check(t, h) + } + // Validate we consumed all headers. + // + // The next one over should be a raw payload and then iterator should + // terminate. + wantDone := false + for { + h, done, err := payloadIterator.Next() + if err != nil { + t.Errorf("payloadIterator.Next(): %s", err) + return + } + if done != wantDone { + t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, %t, _)", h, done, wantDone) + return + } + if done { + break + } + if _, ok := h.(header.IPv6RawPayloadHeader); !ok { + t.Errorf("got payloadIterator.Next() = (%T, _, _), want = (header.IPv6RawPayloadHeader, _, _)", h) + continue + } + wantDone = true + } + } +} + +var _ header.Network = (*ipv6HeaderWithExtHdr)(nil) + +// ipv6HeaderWithExtHdr provides a header.Network implementation that takes +// extension headers into consideration, which is not the case with vanilla +// header.IPv6. +type ipv6HeaderWithExtHdr struct { + header.IPv6 + transport tcpip.TransportProtocolNumber + payload []byte +} + +// TransportProtocol implements header.Network. +func (h *ipv6HeaderWithExtHdr) TransportProtocol() tcpip.TransportProtocolNumber { + return h.transport +} + +// Payload implements header.Network. +func (h *ipv6HeaderWithExtHdr) Payload() []byte { + return h.payload +} + +// IPv6ExtHdrOptionChecker is a function to check an extension header option. +type IPv6ExtHdrOptionChecker func(*testing.T, header.IPv6ExtHdrOption) + +// IPv6HopByHopExtensionHeader checks the extension header is a Hop by Hop +// extension header and validates the containing options with checkers. +// +// checkers must exhaustively contain all the expected options. +func IPv6HopByHopExtensionHeader(checkers ...IPv6ExtHdrOptionChecker) IPv6ExtHdrChecker { + return func(t *testing.T, payloadHeader header.IPv6PayloadHeader) { + t.Helper() + + hbh, ok := payloadHeader.(header.IPv6HopByHopOptionsExtHdr) + if !ok { + t.Errorf("unexpected IPv6 payload header, got = %T, want = header.IPv6HopByHopOptionsExtHdr", payloadHeader) + return + } + optionsIterator := hbh.Iter() + for _, f := range checkers { + opt, done, err := optionsIterator.Next() + if err != nil { + t.Errorf("optionsIterator.Next(): %s", err) + return + } + if done { + t.Errorf("got optionsIterator.Next() = (%T, %t, _), want = (_, false, _)", opt, done) + } + f(t, opt) + } + // Validate all options were consumed. + for { + opt, done, err := optionsIterator.Next() + if err != nil { + t.Errorf("optionsIterator.Next(): %s", err) + return + } + if !done { + t.Errorf("got optionsIterator.Next() = (%T, %t, _), want = (_, true, _)", opt, done) + } + if done { + break + } + } + } +} + +// IPv6RouterAlert validates that an extension header option is the RouterAlert +// option and matches on its value. +func IPv6RouterAlert(want header.IPv6RouterAlertValue) IPv6ExtHdrOptionChecker { + return func(t *testing.T, opt header.IPv6ExtHdrOption) { + routerAlert, ok := opt.(*header.IPv6RouterAlertOption) + if !ok { + t.Errorf("unexpected extension header option, got = %T, want = header.IPv6RouterAlertOption", opt) + return + } + if routerAlert.Value != want { + t.Errorf("got routerAlert.Value = %d, want = %d", routerAlert.Value, want) + } + } +} + +// IgnoreCmpPath returns a cmp.Option that ignores listed field paths. +func IgnoreCmpPath(paths ...string) cmp.Option { + ignores := map[string]struct{}{} + for _, path := range paths { + ignores[path] = struct{}{} + } + return cmp.FilterPath(func(path cmp.Path) bool { + _, ok := ignores[path.String()] + return ok + }, cmp.Ignore()) +} diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD index d87797617..0bdc12d53 100644 --- a/pkg/tcpip/header/BUILD +++ b/pkg/tcpip/header/BUILD @@ -11,11 +11,13 @@ go_library( "gue.go", "icmpv4.go", "icmpv6.go", + "igmp.go", "interfaces.go", "ipv4.go", "ipv6.go", "ipv6_extension_headers.go", "ipv6_fragment.go", + "mld.go", "ndp_neighbor_advert.go", "ndp_neighbor_solicit.go", "ndp_options.go", @@ -39,6 +41,8 @@ go_test( size = "small", srcs = [ "checksum_test.go", + "igmp_test.go", + "ipv4_test.go", "ipv6_test.go", "ipversion_test.go", "tcp_test.go", @@ -58,6 +62,7 @@ go_test( srcs = [ "eth_test.go", "ipv6_extension_headers_test.go", + "mld_test.go", "ndp_test.go", ], library = ":header", diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go index 309403482..5ab20ee86 100644 --- a/pkg/tcpip/header/checksum_test.go +++ b/pkg/tcpip/header/checksum_test.go @@ -19,6 +19,7 @@ package header_test import ( "fmt" "math/rand" + "sync" "testing" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -169,3 +170,96 @@ func BenchmarkChecksum(b *testing.B) { } } } + +func testICMPChecksum(t *testing.T, headerChecksum func() uint16, icmpChecksum func() uint16, want uint16, pktStr string) { + // icmpChecksum should not do any modifications of the header to + // calculate its checksum. Let's call it from a few go-routines and the + // race detector will trigger a warning if there are any concurrent + // read/write accesses. + + const concurrency = 5 + start := make(chan int) + ready := make(chan bool, concurrency) + var wg sync.WaitGroup + wg.Add(concurrency) + defer wg.Wait() + + for i := 0; i < concurrency; i++ { + go func() { + defer wg.Done() + + ready <- true + <-start + + if got := headerChecksum(); want != got { + t.Errorf("new checksum for %s does not match old got: %x, want: %x", pktStr, got, want) + } + if got := icmpChecksum(); want != got { + t.Errorf("new checksum for %s does not match old got: %x, want: %x", pktStr, got, want) + } + }() + } + for i := 0; i < concurrency; i++ { + <-ready + } + close(start) +} + +func TestICMPv4Checksum(t *testing.T) { + rnd := rand.New(rand.NewSource(42)) + + h := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize)) + if _, err := rnd.Read(h); err != nil { + t.Fatalf("rnd.Read failed: %v", err) + } + h.SetChecksum(0) + + buf := make([]byte, 13) + if _, err := rnd.Read(buf); err != nil { + t.Fatalf("rnd.Read failed: %v", err) + } + vv := buffer.NewVectorisedView(len(buf), []buffer.View{ + buffer.NewViewFromBytes(buf[:5]), + buffer.NewViewFromBytes(buf[5:]), + }) + + want := header.Checksum(vv.ToView(), 0) + want = ^header.Checksum(h, want) + h.SetChecksum(want) + + testICMPChecksum(t, h.Checksum, func() uint16 { + return header.ICMPv4Checksum(h, vv) + }, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView())) +} + +func TestICMPv6Checksum(t *testing.T) { + rnd := rand.New(rand.NewSource(42)) + + h := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize)) + if _, err := rnd.Read(h); err != nil { + t.Fatalf("rnd.Read failed: %v", err) + } + h.SetChecksum(0) + + buf := make([]byte, 13) + if _, err := rnd.Read(buf); err != nil { + t.Fatalf("rnd.Read failed: %v", err) + } + vv := buffer.NewVectorisedView(len(buf), []buffer.View{ + buffer.NewViewFromBytes(buf[:7]), + buffer.NewViewFromBytes(buf[7:10]), + buffer.NewViewFromBytes(buf[10:]), + }) + + dst := header.IPv6Loopback + src := header.IPv6Loopback + + want := header.PseudoHeaderChecksum(header.ICMPv6ProtocolNumber, src, dst, uint16(len(h)+vv.Size())) + want = header.Checksum(vv.ToView(), want) + want = ^header.Checksum(h, want) + h.SetChecksum(want) + + testICMPChecksum(t, h.Checksum, func() uint16 { + return header.ICMPv6Checksum(h, src, dst, vv) + }, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView())) +} diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go index 2f13dea6a..5f9b8e9e2 100644 --- a/pkg/tcpip/header/icmpv4.go +++ b/pkg/tcpip/header/icmpv4.go @@ -16,6 +16,7 @@ package header import ( "encoding/binary" + "fmt" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -199,17 +200,24 @@ func (b ICMPv4) SetSequence(sequence uint16) { // ICMPv4Checksum calculates the ICMP checksum over the provided ICMP header, // and payload. func ICMPv4Checksum(h ICMPv4, vv buffer.VectorisedView) uint16 { - // Calculate the IPv6 pseudo-header upper-layer checksum. - xsum := uint16(0) - for _, v := range vv.Views() { - xsum = Checksum(v, xsum) - } + xsum := ChecksumVV(vv, 0) + + // h[2:4] is the checksum itself, skip it to avoid checksumming the checksum. + xsum = Checksum(h[:2], xsum) + xsum = Checksum(h[4:], xsum) - // h[2:4] is the checksum itself, set it aside to avoid checksumming the checksum. - h2, h3 := h[2], h[3] - h[2], h[3] = 0, 0 - xsum = ^Checksum(h, xsum) - h[2], h[3] = h2, h3 + return ^xsum +} - return xsum +// ICMPOriginFromNetProto returns the appropriate SockErrOrigin to use when +// a packet having a `net` header causing an ICMP error. +func ICMPOriginFromNetProto(net tcpip.NetworkProtocolNumber) tcpip.SockErrOrigin { + switch net { + case IPv4ProtocolNumber: + return tcpip.SockExtErrorOriginICMP + case IPv6ProtocolNumber: + return tcpip.SockExtErrorOriginICMP6 + default: + panic(fmt.Sprintf("unsupported net proto to extract ICMP error origin: %d", net)) + } } diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go index 4303fc5d5..eca9750ab 100644 --- a/pkg/tcpip/header/icmpv6.go +++ b/pkg/tcpip/header/icmpv6.go @@ -115,6 +115,12 @@ const ( ICMPv6NeighborSolicit ICMPv6Type = 135 ICMPv6NeighborAdvert ICMPv6Type = 136 ICMPv6RedirectMsg ICMPv6Type = 137 + + // Multicast Listener Discovery (MLD) messages, see RFC 2710. + + ICMPv6MulticastListenerQuery ICMPv6Type = 130 + ICMPv6MulticastListenerReport ICMPv6Type = 131 + ICMPv6MulticastListenerDone ICMPv6Type = 132 ) // IsErrorType returns true if the receiver is an ICMP error type. @@ -245,10 +251,9 @@ func (b ICMPv6) SetSequence(sequence uint16) { binary.BigEndian.PutUint16(b[icmpv6SequenceOffset:], sequence) } -// NDPPayload returns the NDP payload buffer. That is, it returns the ICMPv6 -// packet's message body as defined by RFC 4443 section 2.1; the portion of the -// ICMPv6 buffer after the first ICMPv6HeaderSize bytes. -func (b ICMPv6) NDPPayload() []byte { +// MessageBody returns the message body as defined by RFC 4443 section 2.1; the +// portion of the ICMPv6 buffer after the first ICMPv6HeaderSize bytes. +func (b ICMPv6) MessageBody() []byte { return b[ICMPv6HeaderSize:] } @@ -260,22 +265,13 @@ func (b ICMPv6) Payload() []byte { // ICMPv6Checksum calculates the ICMP checksum over the provided ICMPv6 header, // IPv6 src/dst addresses and the payload. func ICMPv6Checksum(h ICMPv6, src, dst tcpip.Address, vv buffer.VectorisedView) uint16 { - // Calculate the IPv6 pseudo-header upper-layer checksum. - xsum := Checksum([]byte(src), 0) - xsum = Checksum([]byte(dst), xsum) - var upperLayerLength [4]byte - binary.BigEndian.PutUint32(upperLayerLength[:], uint32(len(h)+vv.Size())) - xsum = Checksum(upperLayerLength[:], xsum) - xsum = Checksum([]byte{0, 0, 0, uint8(ICMPv6ProtocolNumber)}, xsum) - for _, v := range vv.Views() { - xsum = Checksum(v, xsum) - } - - // h[2:4] is the checksum itself, set it aside to avoid checksumming the checksum. - h2, h3 := h[2], h[3] - h[2], h[3] = 0, 0 - xsum = ^Checksum(h, xsum) - h[2], h[3] = h2, h3 - - return xsum + xsum := PseudoHeaderChecksum(ICMPv6ProtocolNumber, src, dst, uint16(len(h)+vv.Size())) + + xsum = ChecksumVV(vv, xsum) + + // h[2:4] is the checksum itself, skip it to avoid checksumming the checksum. + xsum = Checksum(h[:2], xsum) + xsum = Checksum(h[4:], xsum) + + return ^xsum } diff --git a/pkg/tcpip/header/igmp.go b/pkg/tcpip/header/igmp.go new file mode 100644 index 000000000..5c5be1b9d --- /dev/null +++ b/pkg/tcpip/header/igmp.go @@ -0,0 +1,181 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "encoding/binary" + "fmt" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +// IGMP represents an IGMP header stored in a byte array. +type IGMP []byte + +// IGMP implements `Transport`. +var _ Transport = (*IGMP)(nil) + +const ( + // IGMPMinimumSize is the minimum size of a valid IGMP packet in bytes, + // as per RFC 2236, Section 2, Page 2. + IGMPMinimumSize = 8 + + // IGMPQueryMinimumSize is the minimum size of a valid Membership Query + // Message in bytes, as per RFC 2236, Section 2, Page 2. + IGMPQueryMinimumSize = 8 + + // IGMPReportMinimumSize is the minimum size of a valid Report Message in + // bytes, as per RFC 2236, Section 2, Page 2. + IGMPReportMinimumSize = 8 + + // IGMPLeaveMessageMinimumSize is the minimum size of a valid Leave Message + // in bytes, as per RFC 2236, Section 2, Page 2. + IGMPLeaveMessageMinimumSize = 8 + + // IGMPTTL is the TTL for all IGMP messages, as per RFC 2236, Section 3, Page + // 3. + IGMPTTL = 1 + + // igmpTypeOffset defines the offset of the type field in an IGMP message. + igmpTypeOffset = 0 + + // igmpMaxRespTimeOffset defines the offset of the MaxRespTime field in an + // IGMP message. + igmpMaxRespTimeOffset = 1 + + // igmpChecksumOffset defines the offset of the checksum field in an IGMP + // message. + igmpChecksumOffset = 2 + + // igmpGroupAddressOffset defines the offset of the Group Address field in an + // IGMP message. + igmpGroupAddressOffset = 4 + + // IGMPProtocolNumber is IGMP's transport protocol number. + IGMPProtocolNumber tcpip.TransportProtocolNumber = 2 +) + +// IGMPType is the IGMP type field as per RFC 2236. +type IGMPType byte + +// Values for the IGMP Type described in RFC 2236 Section 2.1, Page 2. +// Descriptions below come from there. +const ( + // IGMPMembershipQuery indicates that the message type is Membership Query. + // "There are two sub-types of Membership Query messages: + // - General Query, used to learn which groups have members on an + // attached network. + // - Group-Specific Query, used to learn if a particular group + // has any members on an attached network. + // These two messages are differentiated by the Group Address, as + // described in section 1.4 ." + IGMPMembershipQuery IGMPType = 0x11 + // IGMPv1MembershipReport indicates that the message is a Membership Report + // generated by a host using the IGMPv1 protocol: "an additional type of + // message, for backwards-compatibility with IGMPv1" + IGMPv1MembershipReport IGMPType = 0x12 + // IGMPv2MembershipReport indicates that the Message type is a Membership + // Report generated by a host using the IGMPv2 protocol. + IGMPv2MembershipReport IGMPType = 0x16 + // IGMPLeaveGroup indicates that the message type is a Leave Group + // notification message. + IGMPLeaveGroup IGMPType = 0x17 +) + +// Type is the IGMP type field. +func (b IGMP) Type() IGMPType { return IGMPType(b[igmpTypeOffset]) } + +// SetType sets the IGMP type field. +func (b IGMP) SetType(t IGMPType) { b[igmpTypeOffset] = byte(t) } + +// MaxRespTime gets the MaxRespTimeField. This is meaningful only in Membership +// Query messages, in other cases it is set to 0 by the sender and ignored by +// the receiver. +func (b IGMP) MaxRespTime() time.Duration { + // As per RFC 2236 section 2.2, + // + // The Max Response Time field is meaningful only in Membership Query + // messages, and specifies the maximum allowed time before sending a + // responding report in units of 1/10 second. In all other messages, it + // is set to zero by the sender and ignored by receivers. + return DecisecondToDuration(b[igmpMaxRespTimeOffset]) +} + +// SetMaxRespTime sets the MaxRespTimeField. +func (b IGMP) SetMaxRespTime(m byte) { b[igmpMaxRespTimeOffset] = m } + +// Checksum is the IGMP checksum field. +func (b IGMP) Checksum() uint16 { + return binary.BigEndian.Uint16(b[igmpChecksumOffset:]) +} + +// SetChecksum sets the IGMP checksum field. +func (b IGMP) SetChecksum(checksum uint16) { + binary.BigEndian.PutUint16(b[igmpChecksumOffset:], checksum) +} + +// GroupAddress gets the Group Address field. +func (b IGMP) GroupAddress() tcpip.Address { + return tcpip.Address(b[igmpGroupAddressOffset:][:IPv4AddressSize]) +} + +// SetGroupAddress sets the Group Address field. +func (b IGMP) SetGroupAddress(address tcpip.Address) { + if n := copy(b[igmpGroupAddressOffset:], address); n != IPv4AddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d", n, IPv4AddressSize)) + } +} + +// SourcePort implements Transport.SourcePort. +func (IGMP) SourcePort() uint16 { + return 0 +} + +// DestinationPort implements Transport.DestinationPort. +func (IGMP) DestinationPort() uint16 { + return 0 +} + +// SetSourcePort implements Transport.SetSourcePort. +func (IGMP) SetSourcePort(uint16) { +} + +// SetDestinationPort implements Transport.SetDestinationPort. +func (IGMP) SetDestinationPort(uint16) { +} + +// Payload implements Transport.Payload. +func (IGMP) Payload() []byte { + return nil +} + +// IGMPCalculateChecksum calculates the IGMP checksum over the provided IGMP +// header. +func IGMPCalculateChecksum(h IGMP) uint16 { + // The header contains a checksum itself, set it aside to avoid checksumming + // the checksum and replace it afterwards. + existingXsum := h.Checksum() + h.SetChecksum(0) + xsum := ^Checksum(h, 0) + h.SetChecksum(existingXsum) + return xsum +} + +// DecisecondToDuration converts a value representing deci-seconds to a +// time.Duration. +func DecisecondToDuration(ds uint8) time.Duration { + return time.Duration(ds) * time.Second / 10 +} diff --git a/pkg/tcpip/header/igmp_test.go b/pkg/tcpip/header/igmp_test.go new file mode 100644 index 000000000..b6126d29a --- /dev/null +++ b/pkg/tcpip/header/igmp_test.go @@ -0,0 +1,110 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header_test + +import ( + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +// TestIGMPHeader tests the functions within header.igmp +func TestIGMPHeader(t *testing.T) { + const maxRespTimeTenthSec = 0xF0 + b := []byte{ + 0x11, // IGMP Type, Membership Query + maxRespTimeTenthSec, // Maximum Response Time + 0xC0, 0xC0, // Checksum + 0x01, 0x02, 0x03, 0x04, // Group Address + } + + igmpHeader := header.IGMP(b) + + if got, want := igmpHeader.Type(), header.IGMPMembershipQuery; got != want { + t.Errorf("got igmpHeader.Type() = %x, want = %x", got, want) + } + + if got, want := igmpHeader.MaxRespTime(), header.DecisecondToDuration(maxRespTimeTenthSec); got != want { + t.Errorf("got igmpHeader.MaxRespTime() = %s, want = %s", got, want) + } + + if got, want := igmpHeader.Checksum(), uint16(0xC0C0); got != want { + t.Errorf("got igmpHeader.Checksum() = %x, want = %x", got, want) + } + + if got, want := igmpHeader.GroupAddress(), tcpip.Address("\x01\x02\x03\x04"); got != want { + t.Errorf("got igmpHeader.GroupAddress() = %s, want = %s", got, want) + } + + igmpType := header.IGMPv2MembershipReport + igmpHeader.SetType(igmpType) + if got := igmpHeader.Type(); got != igmpType { + t.Errorf("got igmpHeader.Type() = %x, want = %x", got, igmpType) + } + if got := header.IGMPType(b[0]); got != igmpType { + t.Errorf("got IGMPtype in backing buffer = %x, want %x", got, igmpType) + } + + respTime := byte(0x02) + igmpHeader.SetMaxRespTime(respTime) + if got, want := igmpHeader.MaxRespTime(), header.DecisecondToDuration(respTime); got != want { + t.Errorf("got igmpHeader.MaxRespTime() = %s, want = %s", got, want) + } + + checksum := uint16(0x0102) + igmpHeader.SetChecksum(checksum) + if got := igmpHeader.Checksum(); got != checksum { + t.Errorf("got igmpHeader.Checksum() = %x, want = %x", got, checksum) + } + + groupAddress := tcpip.Address("\x04\x03\x02\x01") + igmpHeader.SetGroupAddress(groupAddress) + if got := igmpHeader.GroupAddress(); got != groupAddress { + t.Errorf("got igmpHeader.GroupAddress() = %s, want = %s", got, groupAddress) + } +} + +// TestIGMPChecksum ensures that the checksum calculator produces the expected +// checksum. +func TestIGMPChecksum(t *testing.T) { + b := []byte{ + 0x11, // IGMP Type, Membership Query + 0xF0, // Maximum Response Time + 0xC0, 0xC0, // Checksum + 0x01, 0x02, 0x03, 0x04, // Group Address + } + + igmpHeader := header.IGMP(b) + + // Calculate the initial checksum after setting the checksum temporarily to 0 + // to avoid checksumming the checksum. + initialChecksum := igmpHeader.Checksum() + igmpHeader.SetChecksum(0) + checksum := ^header.Checksum(b, 0) + igmpHeader.SetChecksum(initialChecksum) + + if got := header.IGMPCalculateChecksum(igmpHeader); got != checksum { + t.Errorf("got IGMPCalculateChecksum = %x, want %x", got, checksum) + } +} + +func TestDecisecondToDuration(t *testing.T) { + const valueInDeciseconds = 5 + if got, want := header.DecisecondToDuration(valueInDeciseconds), valueInDeciseconds*time.Second/10; got != want { + t.Fatalf("got header.DecisecondToDuration(%d) = %s, want = %s", valueInDeciseconds, got, want) + } +} diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index 7e32b31b4..e6103f4bc 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -89,8 +89,18 @@ type IPv4Fields struct { // DstAddr is the "destination ip address" of an IPv4 packet. DstAddr tcpip.Address - // Options is between 0 and 40 bytes or nil if empty. - Options IPv4Options + // Options must be 40 bytes or less as they must fit along with the + // rest of the IPv4 header into the maximum size describable in the + // IHL field. RFC 791 section 3.1 says: + // IHL: 4 bits + // + // Internet Header Length is the length of the internet header in 32 + // bit words, and thus points to the beginning of the data. Note that + // the minimum value for a correct header is 5. + // + // That leaves ten 32 bit (4 byte) fields for options. An attempt to encode + // more will fail. + Options IPv4OptionsSerializer } // IPv4 is an IPv4 header. @@ -147,6 +157,9 @@ const ( // IPv4Any is the non-routable IPv4 "any" meta address. IPv4Any tcpip.Address = "\x00\x00\x00\x00" + // IPv4AllRoutersGroup is a multicast address for all routers. + IPv4AllRoutersGroup tcpip.Address = "\xe0\x00\x00\x02" + // IPv4MinimumProcessableDatagramSize is the minimum size of an IP // packet that every IPv4 capable host must be able to // process/reassemble. @@ -272,25 +285,21 @@ func (b IPv4) DestinationAddress() tcpip.Address { return tcpip.Address(b[dstAddr : dstAddr+IPv4AddressSize]) } -// IPv4Options is a buffer that holds all the raw IP options. -type IPv4Options []byte - -// AllocationSize implements stack.NetOptions. -// It reports the size to allocate for the Options. RFC 791 page 23 (end of -// section 3.1) says of the padding at the end of the options: +// padIPv4OptionsLength returns the total length for IPv4 options of length l +// after applying padding according to RFC 791: // The internet header padding is used to ensure that the internet // header ends on a 32 bit boundary. -func (o IPv4Options) AllocationSize() int { - return (len(o) + IPv4IHLStride - 1) & ^(IPv4IHLStride - 1) +func padIPv4OptionsLength(length uint8) uint8 { + return (length + IPv4IHLStride - 1) & ^uint8(IPv4IHLStride-1) } -// Options returns a buffer holding the options or nil. +// IPv4Options is a buffer that holds all the raw IP options. +type IPv4Options []byte + +// Options returns a buffer holding the options. func (b IPv4) Options() IPv4Options { hdrLen := b.HeaderLength() - if hdrLen > IPv4MinimumSize { - return IPv4Options(b[options:hdrLen:hdrLen]) - } - return nil + return IPv4Options(b[options:hdrLen:hdrLen]) } // TransportProtocol implements Network.TransportProtocol. @@ -365,20 +374,16 @@ func (b IPv4) CalculateChecksum() uint16 { func (b IPv4) Encode(i *IPv4Fields) { // The size of the options defines the size of the whole header and thus the // IHL field. Options are rare and this is a heavily used function so it is - // worth a bit of optimisation here to keep the copy out of the fast path. - hdrLen := IPv4MinimumSize + // worth a bit of optimisation here to keep the serializer out of the fast + // path. + hdrLen := uint8(IPv4MinimumSize) if len(i.Options) != 0 { - // AllocationSize is always >= len(i.Options). - aLen := i.Options.AllocationSize() - hdrLen += aLen - if hdrLen > len(b) { - panic(fmt.Sprintf("encode received %d bytes, wanted >= %d", len(b), hdrLen)) - } - if aLen != copy(b[options:], i.Options) { - _ = copy(b[options+len(i.Options):options+aLen], []byte{0, 0, 0, 0}) - } + hdrLen += i.Options.Serialize(b[options:]) } - b.SetHeaderLength(uint8(hdrLen)) + if hdrLen > IPv4MaximumHeaderSize { + panic(fmt.Sprintf("%d is larger than maximum IPv4 header size of %d", hdrLen, IPv4MaximumHeaderSize)) + } + b.SetHeaderLength(hdrLen) b[tos] = i.TOS b.SetTotalLength(i.TotalLength) binary.BigEndian.PutUint16(b[id:], i.ID) @@ -458,6 +463,10 @@ const ( // options and may appear multiple times. IPv4OptionNOPType IPv4OptionType = 1 + // IPv4OptionRouterAlertType is the option type for the Router Alert option, + // defined in RFC 2113 Section 2.1. + IPv4OptionRouterAlertType IPv4OptionType = 20 | 0x80 + // IPv4OptionRecordRouteType is used by each router on the path of the packet // to record its path. It is carried over to an Echo Reply. IPv4OptionRecordRouteType IPv4OptionType = 7 @@ -858,3 +867,162 @@ func (rr *IPv4OptionRecordRoute) Size() uint8 { return uint8(len(*rr)) } // Contents implements IPv4Option. func (rr *IPv4OptionRecordRoute) Contents() []byte { return []byte(*rr) } + +// Router Alert option specific related constants. +// +// from RFC 2113 section 2.1: +// +// +--------+--------+--------+--------+ +// |10010100|00000100| 2 octet value | +// +--------+--------+--------+--------+ +// +// Type: +// Copied flag: 1 (all fragments must carry the option) +// Option class: 0 (control) +// Option number: 20 (decimal) +// +// Length: 4 +// +// Value: A two octet code with the following values: +// 0 - Router shall examine packet +// 1-65535 - Reserved +const ( + // IPv4OptionRouterAlertLength is the length of a Router Alert option. + IPv4OptionRouterAlertLength = 4 + + // IPv4OptionRouterAlertValue is the only permissible value of the 16 bit + // payload of the router alert option. + IPv4OptionRouterAlertValue = 0 + + // iPv4OptionRouterAlertValueOffset is the offset for the value of a + // RouterAlert option. + iPv4OptionRouterAlertValueOffset = 2 +) + +// IPv4SerializableOption is an interface to represent serializable IPv4 option +// types. +type IPv4SerializableOption interface { + // optionType returns the type identifier of the option. + optionType() IPv4OptionType +} + +// IPv4SerializableOptionPayload is an interface providing serialization of the +// payload of an IPv4 option. +type IPv4SerializableOptionPayload interface { + // length returns the size of the payload. + length() uint8 + + // serializeInto serializes the payload into the provided byte buffer. + // + // Note, the caller MUST provide a byte buffer with size of at least + // Length. Implementers of this function may assume that the byte buffer + // is of sufficient size. serializeInto MUST panic if the provided byte + // buffer is not of sufficient size. + // + // serializeInto will return the number of bytes that was used to + // serialize the receiver. Implementers must only use the number of + // bytes required to serialize the receiver. Callers MAY provide a + // larger buffer than required to serialize into. + serializeInto(buffer []byte) uint8 +} + +// IPv4OptionsSerializer is a serializer for IPv4 options. +type IPv4OptionsSerializer []IPv4SerializableOption + +// Length returns the total number of bytes required to serialize the options. +func (s IPv4OptionsSerializer) Length() uint8 { + var total uint8 + for _, opt := range s { + total++ + if withPayload, ok := opt.(IPv4SerializableOptionPayload); ok { + // Add 1 to reported length to account for the length byte. + total += 1 + withPayload.length() + } + } + return padIPv4OptionsLength(total) +} + +// Serialize serializes the provided list of IPV4 options into b. +// +// Note, b must be of sufficient size to hold all the options in s. See +// IPv4OptionsSerializer.Length for details on the getting the total size +// of a serialized IPv4OptionsSerializer. +// +// Serialize panics if b is not of sufficient size to hold all the options in s. +func (s IPv4OptionsSerializer) Serialize(b []byte) uint8 { + var total uint8 + for _, opt := range s { + ty := opt.optionType() + if withPayload, ok := opt.(IPv4SerializableOptionPayload); ok { + // Serialize first to reduce bounds checks. + l := 2 + withPayload.serializeInto(b[2:]) + b[0] = byte(ty) + b[1] = l + b = b[l:] + total += l + continue + } + // Options without payload consist only of the type field. + // + // NB: Repeating code from the branch above is intentional to minimize + // bounds checks. + b[0] = byte(ty) + b = b[1:] + total++ + } + + // According to RFC 791: + // + // The internet header padding is used to ensure that the internet + // header ends on a 32 bit boundary. The padding is zero. + padded := padIPv4OptionsLength(total) + b = b[:padded-total] + for i := range b { + b[i] = 0 + } + return padded +} + +var _ IPv4SerializableOptionPayload = (*IPv4SerializableRouterAlertOption)(nil) +var _ IPv4SerializableOption = (*IPv4SerializableRouterAlertOption)(nil) + +// IPv4SerializableRouterAlertOption provides serialization of the Router Alert +// IPv4 option according to RFC 2113. +type IPv4SerializableRouterAlertOption struct{} + +// Type implements IPv4SerializableOption. +func (*IPv4SerializableRouterAlertOption) optionType() IPv4OptionType { + return IPv4OptionRouterAlertType +} + +// Length implements IPv4SerializableOption. +func (*IPv4SerializableRouterAlertOption) length() uint8 { + return IPv4OptionRouterAlertLength - iPv4OptionRouterAlertValueOffset +} + +// SerializeInto implements IPv4SerializableOption. +func (o *IPv4SerializableRouterAlertOption) serializeInto(buffer []byte) uint8 { + binary.BigEndian.PutUint16(buffer, IPv4OptionRouterAlertValue) + return o.length() +} + +var _ IPv4SerializableOption = (*IPv4SerializableNOPOption)(nil) + +// IPv4SerializableNOPOption provides serialization for the IPv4 no-op option. +type IPv4SerializableNOPOption struct{} + +// Type implements IPv4SerializableOption. +func (*IPv4SerializableNOPOption) optionType() IPv4OptionType { + return IPv4OptionNOPType +} + +var _ IPv4SerializableOption = (*IPv4SerializableListEndOption)(nil) + +// IPv4SerializableListEndOption provides serialization for the IPv4 List End +// option. +type IPv4SerializableListEndOption struct{} + +// Type implements IPv4SerializableOption. +func (*IPv4SerializableListEndOption) optionType() IPv4OptionType { + return IPv4OptionListEndType +} diff --git a/pkg/tcpip/header/ipv4_test.go b/pkg/tcpip/header/ipv4_test.go new file mode 100644 index 000000000..6475cd694 --- /dev/null +++ b/pkg/tcpip/header/ipv4_test.go @@ -0,0 +1,179 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +func TestIPv4OptionsSerializer(t *testing.T) { + optCases := []struct { + name string + option []header.IPv4SerializableOption + expect []byte + }{ + { + name: "NOP", + option: []header.IPv4SerializableOption{ + &header.IPv4SerializableNOPOption{}, + }, + expect: []byte{1, 0, 0, 0}, + }, + { + name: "ListEnd", + option: []header.IPv4SerializableOption{ + &header.IPv4SerializableListEndOption{}, + }, + expect: []byte{0, 0, 0, 0}, + }, + { + name: "RouterAlert", + option: []header.IPv4SerializableOption{ + &header.IPv4SerializableRouterAlertOption{}, + }, + expect: []byte{148, 4, 0, 0}, + }, { + name: "NOP and RouterAlert", + option: []header.IPv4SerializableOption{ + &header.IPv4SerializableNOPOption{}, + &header.IPv4SerializableRouterAlertOption{}, + }, + expect: []byte{1, 148, 4, 0, 0, 0, 0, 0}, + }, + } + + for _, opt := range optCases { + t.Run(opt.name, func(t *testing.T) { + s := header.IPv4OptionsSerializer(opt.option) + l := s.Length() + if got := len(opt.expect); got != int(l) { + t.Fatalf("s.Length() = %d, want = %d", got, l) + } + b := make([]byte, l) + for i := range b { + // Fill the buffer with full bytes to ensure padding is being set + // correctly. + b[i] = 0xFF + } + if serializedLength := s.Serialize(b); serializedLength != l { + t.Fatalf("s.Serialize(_) = %d, want %d", serializedLength, l) + } + if diff := cmp.Diff(opt.expect, b); diff != "" { + t.Errorf("mismatched serialized option (-want +got):\n%s", diff) + } + }) + } +} + +// TestIPv4Encode checks that ipv4.Encode correctly fills out the requested +// fields when options are supplied. +func TestIPv4EncodeOptions(t *testing.T) { + tests := []struct { + name string + numberOfNops int + encodedOptions header.IPv4Options // reply should look like this + wantIHL int + }{ + { + name: "valid no options", + wantIHL: header.IPv4MinimumSize, + }, + { + name: "one byte options", + numberOfNops: 1, + encodedOptions: header.IPv4Options{1, 0, 0, 0}, + wantIHL: header.IPv4MinimumSize + 4, + }, + { + name: "two byte options", + numberOfNops: 2, + encodedOptions: header.IPv4Options{1, 1, 0, 0}, + wantIHL: header.IPv4MinimumSize + 4, + }, + { + name: "three byte options", + numberOfNops: 3, + encodedOptions: header.IPv4Options{1, 1, 1, 0}, + wantIHL: header.IPv4MinimumSize + 4, + }, + { + name: "four byte options", + numberOfNops: 4, + encodedOptions: header.IPv4Options{1, 1, 1, 1}, + wantIHL: header.IPv4MinimumSize + 4, + }, + { + name: "five byte options", + numberOfNops: 5, + encodedOptions: header.IPv4Options{1, 1, 1, 1, 1, 0, 0, 0}, + wantIHL: header.IPv4MinimumSize + 8, + }, + { + name: "thirty nine byte options", + numberOfNops: 39, + encodedOptions: header.IPv4Options{ + 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 0, + }, + wantIHL: header.IPv4MinimumSize + 40, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + serializeOpts := header.IPv4OptionsSerializer(make([]header.IPv4SerializableOption, test.numberOfNops)) + for i := range serializeOpts { + serializeOpts[i] = &header.IPv4SerializableNOPOption{} + } + paddedOptionLength := serializeOpts.Length() + ipHeaderLength := int(header.IPv4MinimumSize + paddedOptionLength) + if ipHeaderLength > header.IPv4MaximumHeaderSize { + t.Fatalf("IP header length too large: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize) + } + totalLen := uint16(ipHeaderLength) + hdr := buffer.NewPrependable(int(totalLen)) + ip := header.IPv4(hdr.Prepend(ipHeaderLength)) + // To check the padding works, poison the last byte of the options space. + if paddedOptionLength != serializeOpts.Length() { + ip.SetHeaderLength(uint8(ipHeaderLength)) + ip.Options()[paddedOptionLength-1] = 0xff + ip.SetHeaderLength(0) + } + ip.Encode(&header.IPv4Fields{ + Options: serializeOpts, + }) + options := ip.Options() + wantOptions := test.encodedOptions + if got, want := int(ip.HeaderLength()), test.wantIHL; got != want { + t.Errorf("got IHL of %d, want %d", got, want) + } + + // cmp.Diff does not consider nil slices equal to empty slices, but we do. + if len(wantOptions) == 0 && len(options) == 0 { + return + } + + if diff := cmp.Diff(wantOptions, options); diff != "" { + t.Errorf("options mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index 4e7e5f76a..5580d6a78 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -18,7 +18,6 @@ import ( "crypto/sha256" "encoding/binary" "fmt" - "strings" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -48,13 +47,15 @@ type IPv6Fields struct { // FlowLabel is the "flow label" field of an IPv6 packet. FlowLabel uint32 - // PayloadLength is the "payload length" field of an IPv6 packet. + // PayloadLength is the "payload length" field of an IPv6 packet, including + // the length of all extension headers. PayloadLength uint16 - // NextHeader is the "next header" field of an IPv6 packet. - NextHeader uint8 + // TransportProtocol is the transport layer protocol number. Serialized in the + // last "next header" field of the IPv6 header + extension headers. + TransportProtocol tcpip.TransportProtocolNumber - // HopLimit is the "hop limit" field of an IPv6 packet. + // HopLimit is the "Hop Limit" field of an IPv6 packet. HopLimit uint8 // SrcAddr is the "source ip address" of an IPv6 packet. @@ -62,6 +63,9 @@ type IPv6Fields struct { // DstAddr is the "destination ip address" of an IPv6 packet. DstAddr tcpip.Address + + // ExtensionHeaders are the extension headers following the IPv6 header. + ExtensionHeaders IPv6ExtHdrSerializer } // IPv6 represents an ipv6 header stored in a byte array. @@ -148,13 +152,17 @@ const ( // IPv6EmptySubnet is the empty IPv6 subnet. It may also be known as the // catch-all or wildcard subnet. That is, all IPv6 addresses are considered to // be contained within this subnet. -var IPv6EmptySubnet = func() tcpip.Subnet { - subnet, err := tcpip.NewSubnet(IPv6Any, tcpip.AddressMask(IPv6Any)) - if err != nil { - panic(err) - } - return subnet -}() +var IPv6EmptySubnet = tcpip.AddressWithPrefix{ + Address: IPv6Any, + PrefixLen: 0, +}.Subnet() + +// IPv4MappedIPv6Subnet is the prefix for an IPv4 mapped IPv6 address as defined +// by RFC 4291 section 2.5.5. +var IPv4MappedIPv6Subnet = tcpip.AddressWithPrefix{ + Address: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00", + PrefixLen: 96, +}.Subnet() // IPv6LinkLocalPrefix is the prefix for IPv6 link-local addresses, as defined // by RFC 4291 section 2.5.6. @@ -171,7 +179,7 @@ func (b IPv6) PayloadLength() uint16 { return binary.BigEndian.Uint16(b[IPv6PayloadLenOffset:]) } -// HopLimit returns the value of the "hop limit" field of the ipv6 header. +// HopLimit returns the value of the "Hop Limit" field of the ipv6 header. func (b IPv6) HopLimit() uint8 { return b[hopLimit] } @@ -236,6 +244,11 @@ func (b IPv6) SetDestinationAddress(addr tcpip.Address) { copy(b[v6DstAddr:][:IPv6AddressSize], addr) } +// SetHopLimit sets the value of the "Hop Limit" field. +func (b IPv6) SetHopLimit(v uint8) { + b[hopLimit] = v +} + // SetNextHeader sets the value of the "next header" field of the ipv6 header. func (b IPv6) SetNextHeader(v uint8) { b[IPv6NextHeaderOffset] = v @@ -248,12 +261,14 @@ func (IPv6) SetChecksum(uint16) { // Encode encodes all the fields of the ipv6 header. func (b IPv6) Encode(i *IPv6Fields) { + extHdr := b[IPv6MinimumSize:] b.SetTOS(i.TrafficClass, i.FlowLabel) b.SetPayloadLength(i.PayloadLength) - b[IPv6NextHeaderOffset] = i.NextHeader b[hopLimit] = i.HopLimit b.SetSourceAddress(i.SrcAddr) b.SetDestinationAddress(i.DstAddr) + nextHeader, _ := i.ExtensionHeaders.Serialize(i.TransportProtocol, extHdr) + b[IPv6NextHeaderOffset] = nextHeader } // IsValid performs basic validation on the packet. @@ -281,7 +296,7 @@ func IsV4MappedAddress(addr tcpip.Address) bool { return false } - return strings.HasPrefix(string(addr), "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff") + return IPv4MappedIPv6Subnet.Contains(addr) } // IsV6MulticastAddress determines if the provided address is an IPv6 @@ -387,17 +402,6 @@ func IsV6LinkLocalMulticastAddress(addr tcpip.Address) bool { return IsV6MulticastAddress(addr) && addr[ipv6MulticastAddressScopeByteIdx]&ipv6MulticastAddressScopeMask == ipv6LinkLocalMulticastScope } -// IsV6UniqueLocalAddress determines if the provided address is an IPv6 -// unique-local address (within the prefix FC00::/7). -func IsV6UniqueLocalAddress(addr tcpip.Address) bool { - if len(addr) != IPv6AddressSize { - return false - } - // According to RFC 4193 section 3.1, a unique local address has the prefix - // FC00::/7. - return (addr[0] & 0xfe) == 0xfc -} - // AppendOpaqueInterfaceIdentifier appends a 64 bit opaque interface identifier // (IID) to buf as outlined by RFC 7217 and returns the extended buffer. // @@ -444,9 +448,6 @@ const ( // LinkLocalScope indicates a link-local address. LinkLocalScope IPv6AddressScope = iota - // UniqueLocalScope indicates a unique-local address. - UniqueLocalScope - // GlobalScope indicates a global address. GlobalScope ) @@ -464,9 +465,6 @@ func ScopeForIPv6Address(addr tcpip.Address) (IPv6AddressScope, *tcpip.Error) { case IsV6LinkLocalAddress(addr): return LinkLocalScope, nil - case IsV6UniqueLocalAddress(addr): - return UniqueLocalScope, nil - default: return GlobalScope, nil } diff --git a/pkg/tcpip/header/ipv6_extension_headers.go b/pkg/tcpip/header/ipv6_extension_headers.go index 583c2c5d3..f18981332 100644 --- a/pkg/tcpip/header/ipv6_extension_headers.go +++ b/pkg/tcpip/header/ipv6_extension_headers.go @@ -18,9 +18,12 @@ import ( "bufio" "bytes" "encoding/binary" + "errors" "fmt" "io" + "math" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" ) @@ -47,6 +50,11 @@ const ( // IPv6NoNextHeaderIdentifier is the header identifier used to signify the end // of an IPv6 payload, as per RFC 8200 section 4.7. IPv6NoNextHeaderIdentifier IPv6ExtensionHeaderIdentifier = 59 + + // IPv6UnknownExtHdrIdentifier is reserved by IANA. + // https://www.iana.org/assignments/ipv6-parameters/ipv6-parameters.xhtml#extension-header + // "254 Use for experimentation and testing [RFC3692][RFC4727]" + IPv6UnknownExtHdrIdentifier IPv6ExtensionHeaderIdentifier = 254 ) const ( @@ -70,8 +78,8 @@ const ( // Fragment Offset field within an IPv6FragmentExtHdr. ipv6FragmentExtHdrFragmentOffsetOffset = 0 - // ipv6FragmentExtHdrFragmentOffsetShift is the least significant bits to - // discard from the Fragment Offset. + // ipv6FragmentExtHdrFragmentOffsetShift is the bit offset of the Fragment + // Offset field within an IPv6FragmentExtHdr. ipv6FragmentExtHdrFragmentOffsetShift = 3 // ipv6FragmentExtHdrFlagsIdx is the index to the flags field within an @@ -109,6 +117,37 @@ const ( IPv6FragmentExtHdrFragmentOffsetBytesPerUnit = 8 ) +// padIPv6OptionsLength returns the total length for IPv6 options of length l +// considering the 8-octet alignment as stated in RFC 8200 Section 4.2. +func padIPv6OptionsLength(length int) int { + return (length + ipv6ExtHdrLenBytesPerUnit - 1) & ^(ipv6ExtHdrLenBytesPerUnit - 1) +} + +// padIPv6Option fills b with the appropriate padding options depending on its +// length. +func padIPv6Option(b []byte) { + switch len(b) { + case 0: // No padding needed. + case 1: // Pad with Pad1. + b[ipv6ExtHdrOptionTypeOffset] = uint8(ipv6Pad1ExtHdrOptionIdentifier) + default: // Pad with PadN. + s := b[ipv6ExtHdrOptionPayloadOffset:] + for i := range s { + s[i] = 0 + } + b[ipv6ExtHdrOptionTypeOffset] = uint8(ipv6PadNExtHdrOptionIdentifier) + b[ipv6ExtHdrOptionLengthOffset] = uint8(len(s)) + } +} + +// ipv6OptionsAlignmentPadding returns the number of padding bytes needed to +// serialize an option at headerOffset with alignment requirements +// [align]n + alignOffset. +func ipv6OptionsAlignmentPadding(headerOffset int, align int, alignOffset int) int { + padLen := headerOffset - alignOffset + return ((padLen + align - 1) & ^(align - 1)) - padLen +} + // IPv6PayloadHeader is implemented by the various headers that can be found // in an IPv6 payload. // @@ -201,29 +240,55 @@ type IPv6ExtHdrOption interface { isIPv6ExtHdrOption() } -// IPv6ExtHdrOptionIndentifier is an IPv6 extension header option identifier. -type IPv6ExtHdrOptionIndentifier uint8 +// IPv6ExtHdrOptionIdentifier is an IPv6 extension header option identifier. +type IPv6ExtHdrOptionIdentifier uint8 const ( // ipv6Pad1ExtHdrOptionIdentifier is the identifier for a padding option that // provides 1 byte padding, as outlined in RFC 8200 section 4.2. - ipv6Pad1ExtHdrOptionIdentifier IPv6ExtHdrOptionIndentifier = 0 + ipv6Pad1ExtHdrOptionIdentifier IPv6ExtHdrOptionIdentifier = 0 // ipv6PadBExtHdrOptionIdentifier is the identifier for a padding option that // provides variable length byte padding, as outlined in RFC 8200 section 4.2. - ipv6PadNExtHdrOptionIdentifier IPv6ExtHdrOptionIndentifier = 1 + ipv6PadNExtHdrOptionIdentifier IPv6ExtHdrOptionIdentifier = 1 + + // ipv6RouterAlertHopByHopOptionIdentifier is the identifier for the Router + // Alert Hop by Hop option as defined in RFC 2711 section 2.1. + ipv6RouterAlertHopByHopOptionIdentifier IPv6ExtHdrOptionIdentifier = 5 + + // ipv6ExtHdrOptionTypeOffset is the option type offset in an extension header + // option as defined in RFC 8200 section 4.2. + ipv6ExtHdrOptionTypeOffset = 0 + + // ipv6ExtHdrOptionLengthOffset is the option length offset in an extension + // header option as defined in RFC 8200 section 4.2. + ipv6ExtHdrOptionLengthOffset = 1 + + // ipv6ExtHdrOptionPayloadOffset is the option payload offset in an extension + // header option as defined in RFC 8200 section 4.2. + ipv6ExtHdrOptionPayloadOffset = 2 ) +// ipv6UnknownActionFromIdentifier maps an extension header option's +// identifier's high bits to the action to take when the identifier is unknown. +func ipv6UnknownActionFromIdentifier(id IPv6ExtHdrOptionIdentifier) IPv6OptionUnknownAction { + return IPv6OptionUnknownAction((id & ipv6UnknownExtHdrOptionActionMask) >> ipv6UnknownExtHdrOptionActionShift) +} + +// ErrMalformedIPv6ExtHdrOption indicates that an IPv6 extension header option +// is malformed. +var ErrMalformedIPv6ExtHdrOption = errors.New("malformed IPv6 extension header option") + // IPv6UnknownExtHdrOption holds the identifier and data for an IPv6 extension // header option that is unknown by the parsing utilities. type IPv6UnknownExtHdrOption struct { - Identifier IPv6ExtHdrOptionIndentifier + Identifier IPv6ExtHdrOptionIdentifier Data []byte } // UnknownAction implements IPv6OptionUnknownAction.UnknownAction. func (o *IPv6UnknownExtHdrOption) UnknownAction() IPv6OptionUnknownAction { - return IPv6OptionUnknownAction((o.Identifier & ipv6UnknownExtHdrOptionActionMask) >> ipv6UnknownExtHdrOptionActionShift) + return ipv6UnknownActionFromIdentifier(o.Identifier) } // isIPv6ExtHdrOption implements IPv6ExtHdrOption.isIPv6ExtHdrOption. @@ -246,7 +311,7 @@ func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error // options buffer has been exhausted and we are done iterating. return nil, true, nil } - id := IPv6ExtHdrOptionIndentifier(temp) + id := IPv6ExtHdrOptionIdentifier(temp) // If the option identifier indicates the option is a Pad1 option, then we // know the option does not have Length and Data fields. End processing of @@ -289,6 +354,19 @@ func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error panic(fmt.Sprintf("error when skipping PadN (N = %d) option's data bytes: %s", length, err)) } continue + case ipv6RouterAlertHopByHopOptionIdentifier: + var routerAlertValue [ipv6RouterAlertPayloadLength]byte + if n, err := io.ReadFull(&i.reader, routerAlertValue[:]); err != nil { + switch err { + case io.EOF, io.ErrUnexpectedEOF: + return nil, true, fmt.Errorf("got invalid length (%d) for router alert option (want = %d): %w", length, ipv6RouterAlertPayloadLength, ErrMalformedIPv6ExtHdrOption) + default: + return nil, true, fmt.Errorf("read %d out of %d option data bytes for router alert option: %w", n, ipv6RouterAlertPayloadLength, err) + } + } else if n != int(length) { + return nil, true, fmt.Errorf("got invalid length (%d) for router alert option (want = %d): %w", length, ipv6RouterAlertPayloadLength, ErrMalformedIPv6ExtHdrOption) + } + return &IPv6RouterAlertOption{Value: IPv6RouterAlertValue(binary.BigEndian.Uint16(routerAlertValue[:]))}, false, nil default: bytes := make([]byte, length) if n, err := io.ReadFull(&i.reader, bytes); err != nil { @@ -452,9 +530,11 @@ func (i *IPv6PayloadIterator) AsRawHeader(consume bool) IPv6RawPayloadHeader { // Since we consume the iterator, we return the payload as is. buf = i.payload - // Mark i as done. + // Mark i as done, but keep track of where we were for error reporting. *i = IPv6PayloadIterator{ nextHdrIdentifier: IPv6NoNextHeaderIdentifier, + headerOffset: i.headerOffset, + nextOffset: i.nextOffset, } } else { buf = i.payload.Clone(nil) @@ -602,3 +682,248 @@ func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IP return IPv6ExtensionHeaderIdentifier(nextHdrIdentifier), bytes, nil } + +// IPv6SerializableExtHdr provides serialization for IPv6 extension +// headers. +type IPv6SerializableExtHdr interface { + // identifier returns the assigned IPv6 header identifier for this extension + // header. + identifier() IPv6ExtensionHeaderIdentifier + + // length returns the total serialized length in bytes of this extension + // header, including the common next header and length fields. + length() int + + // serializeInto serializes the receiver into the provided byte + // buffer and with the provided nextHeader value. + // + // Note, the caller MUST provide a byte buffer with size of at least + // length. Implementers of this function may assume that the byte buffer + // is of sufficient size. serializeInto MAY panic if the provided byte + // buffer is not of sufficient size. + // + // serializeInto returns the number of bytes that was used to serialize the + // receiver. Implementers must only use the number of bytes required to + // serialize the receiver. Callers MAY provide a larger buffer than required + // to serialize into. + serializeInto(nextHeader uint8, b []byte) int +} + +var _ IPv6SerializableExtHdr = (*IPv6SerializableHopByHopExtHdr)(nil) + +// IPv6SerializableHopByHopExtHdr implements serialization of the Hop by Hop +// options extension header. +type IPv6SerializableHopByHopExtHdr []IPv6SerializableHopByHopOption + +const ( + // ipv6HopByHopExtHdrNextHeaderOffset is the offset of the next header field + // in a hop by hop extension header as defined in RFC 8200 section 4.3. + ipv6HopByHopExtHdrNextHeaderOffset = 0 + + // ipv6HopByHopExtHdrLengthOffset is the offset of the length field in a hop + // by hop extension header as defined in RFC 8200 section 4.3. + ipv6HopByHopExtHdrLengthOffset = 1 + + // ipv6HopByHopExtHdrPayloadOffset is the offset of the options in a hop by + // hop extension header as defined in RFC 8200 section 4.3. + ipv6HopByHopExtHdrOptionsOffset = 2 + + // ipv6HopByHopExtHdrUnaccountedLenWords is the implicit number of 8-octet + // words in a hop by hop extension header's length field, as stated in RFC + // 8200 section 4.3: + // Length of the Hop-by-Hop Options header in 8-octet units, + // not including the first 8 octets. + ipv6HopByHopExtHdrUnaccountedLenWords = 1 +) + +// identifier implements IPv6SerializableExtHdr. +func (IPv6SerializableHopByHopExtHdr) identifier() IPv6ExtensionHeaderIdentifier { + return IPv6HopByHopOptionsExtHdrIdentifier +} + +// length implements IPv6SerializableExtHdr. +func (h IPv6SerializableHopByHopExtHdr) length() int { + var total int + for _, opt := range h { + align, alignOffset := opt.alignment() + total += ipv6OptionsAlignmentPadding(total, align, alignOffset) + total += ipv6ExtHdrOptionPayloadOffset + int(opt.length()) + } + // Account for next header and total length fields and add padding. + return padIPv6OptionsLength(ipv6HopByHopExtHdrOptionsOffset + total) +} + +// serializeInto implements IPv6SerializableExtHdr. +func (h IPv6SerializableHopByHopExtHdr) serializeInto(nextHeader uint8, b []byte) int { + optBuffer := b[ipv6HopByHopExtHdrOptionsOffset:] + totalLength := ipv6HopByHopExtHdrOptionsOffset + for _, opt := range h { + // Calculate alignment requirements and pad buffer if necessary. + align, alignOffset := opt.alignment() + padLen := ipv6OptionsAlignmentPadding(totalLength, align, alignOffset) + if padLen != 0 { + padIPv6Option(optBuffer[:padLen]) + totalLength += padLen + optBuffer = optBuffer[padLen:] + } + + l := opt.serializeInto(optBuffer[ipv6ExtHdrOptionPayloadOffset:]) + optBuffer[ipv6ExtHdrOptionTypeOffset] = uint8(opt.identifier()) + optBuffer[ipv6ExtHdrOptionLengthOffset] = l + l += ipv6ExtHdrOptionPayloadOffset + totalLength += int(l) + optBuffer = optBuffer[l:] + } + padded := padIPv6OptionsLength(totalLength) + if padded != totalLength { + padIPv6Option(optBuffer[:padded-totalLength]) + totalLength = padded + } + wordsLen := totalLength/ipv6ExtHdrLenBytesPerUnit - ipv6HopByHopExtHdrUnaccountedLenWords + if wordsLen > math.MaxUint8 { + panic(fmt.Sprintf("IPv6 hop by hop options too large: %d+1 64-bit words", wordsLen)) + } + b[ipv6HopByHopExtHdrNextHeaderOffset] = nextHeader + b[ipv6HopByHopExtHdrLengthOffset] = uint8(wordsLen) + return totalLength +} + +// IPv6SerializableHopByHopOption provides serialization for hop by hop options. +type IPv6SerializableHopByHopOption interface { + // identifier returns the option identifier of this Hop by Hop option. + identifier() IPv6ExtHdrOptionIdentifier + + // length returns the *payload* size of the option (not considering the type + // and length fields). + length() uint8 + + // alignment returns the alignment requirements from this option. + // + // Alignment requirements take the form [align]n + offset as specified in + // RFC 8200 section 4.2. The alignment requirement is on the offset between + // the option type byte and the start of the hop by hop header. + // + // align must be a power of 2. + alignment() (align int, offset int) + + // serializeInto serializes the receiver into the provided byte + // buffer. + // + // Note, the caller MUST provide a byte buffer with size of at least + // length. Implementers of this function may assume that the byte buffer + // is of sufficient size. serializeInto MAY panic if the provided byte + // buffer is not of sufficient size. + // + // serializeInto will return the number of bytes that was used to + // serialize the receiver. Implementers must only use the number of + // bytes required to serialize the receiver. Callers MAY provide a + // larger buffer than required to serialize into. + serializeInto([]byte) uint8 +} + +var _ IPv6SerializableHopByHopOption = (*IPv6RouterAlertOption)(nil) + +// IPv6RouterAlertOption is the IPv6 Router alert Hop by Hop option defined in +// RFC 2711 section 2.1. +type IPv6RouterAlertOption struct { + Value IPv6RouterAlertValue +} + +// IPv6RouterAlertValue is the payload of an IPv6 Router Alert option. +type IPv6RouterAlertValue uint16 + +const ( + // IPv6RouterAlertMLD indicates a datagram containing a Multicast Listener + // Discovery message as defined in RFC 2711 section 2.1. + IPv6RouterAlertMLD IPv6RouterAlertValue = 0 + // IPv6RouterAlertRSVP indicates a datagram containing an RSVP message as + // defined in RFC 2711 section 2.1. + IPv6RouterAlertRSVP IPv6RouterAlertValue = 1 + // IPv6RouterAlertActiveNetworks indicates a datagram containing an Active + // Networks message as defined in RFC 2711 section 2.1. + IPv6RouterAlertActiveNetworks IPv6RouterAlertValue = 2 + + // ipv6RouterAlertPayloadLength is the length of the Router Alert payload + // as defined in RFC 2711. + ipv6RouterAlertPayloadLength = 2 + + // ipv6RouterAlertAlignmentRequirement is the alignment requirement for the + // Router Alert option defined as 2n+0 in RFC 2711. + ipv6RouterAlertAlignmentRequirement = 2 + + // ipv6RouterAlertAlignmentOffsetRequirement is the alignment offset + // requirement for the Router Alert option defined as 2n+0 in RFC 2711 section + // 2.1. + ipv6RouterAlertAlignmentOffsetRequirement = 0 +) + +// UnknownAction implements IPv6ExtHdrOption. +func (*IPv6RouterAlertOption) UnknownAction() IPv6OptionUnknownAction { + return ipv6UnknownActionFromIdentifier(ipv6RouterAlertHopByHopOptionIdentifier) +} + +// isIPv6ExtHdrOption implements IPv6ExtHdrOption. +func (*IPv6RouterAlertOption) isIPv6ExtHdrOption() {} + +// identifier implements IPv6SerializableHopByHopOption. +func (*IPv6RouterAlertOption) identifier() IPv6ExtHdrOptionIdentifier { + return ipv6RouterAlertHopByHopOptionIdentifier +} + +// length implements IPv6SerializableHopByHopOption. +func (*IPv6RouterAlertOption) length() uint8 { + return ipv6RouterAlertPayloadLength +} + +// alignment implements IPv6SerializableHopByHopOption. +func (*IPv6RouterAlertOption) alignment() (int, int) { + // From RFC 2711 section 2.1: + // Alignment requirement: 2n+0. + return ipv6RouterAlertAlignmentRequirement, ipv6RouterAlertAlignmentOffsetRequirement +} + +// serializeInto implements IPv6SerializableHopByHopOption. +func (o *IPv6RouterAlertOption) serializeInto(b []byte) uint8 { + binary.BigEndian.PutUint16(b, uint16(o.Value)) + return ipv6RouterAlertPayloadLength +} + +// IPv6ExtHdrSerializer provides serialization of IPv6 extension headers. +type IPv6ExtHdrSerializer []IPv6SerializableExtHdr + +// Serialize serializes the provided list of IPv6 extension headers into b. +// +// Note, b must be of sufficient size to hold all the headers in s. See +// IPv6ExtHdrSerializer.Length for details on the getting the total size of a +// serialized IPv6ExtHdrSerializer. +// +// Serialize may panic if b is not of sufficient size to hold all the options +// in s. +// +// Serialize takes the transportProtocol value to be used as the last extension +// header's Next Header value and returns the header identifier of the first +// serialized extension header and the total serialized length. +func (s IPv6ExtHdrSerializer) Serialize(transportProtocol tcpip.TransportProtocolNumber, b []byte) (uint8, int) { + nextHeader := uint8(transportProtocol) + if len(s) == 0 { + return nextHeader, 0 + } + var totalLength int + for i, h := range s[:len(s)-1] { + length := h.serializeInto(uint8(s[i+1].identifier()), b) + b = b[length:] + totalLength += length + } + totalLength += s[len(s)-1].serializeInto(nextHeader, b) + return uint8(s[0].identifier()), totalLength +} + +// Length returns the total number of bytes required to serialize the extension +// headers. +func (s IPv6ExtHdrSerializer) Length() int { + var totalLength int + for _, h := range s { + totalLength += h.length() + } + return totalLength +} diff --git a/pkg/tcpip/header/ipv6_extension_headers_test.go b/pkg/tcpip/header/ipv6_extension_headers_test.go index ab20c5f37..65adc6250 100644 --- a/pkg/tcpip/header/ipv6_extension_headers_test.go +++ b/pkg/tcpip/header/ipv6_extension_headers_test.go @@ -21,6 +21,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" ) @@ -59,7 +60,7 @@ func (a IPv6DestinationOptionsExtHdr) Equal(b IPv6DestinationOptionsExtHdr) bool func TestIPv6UnknownExtHdrOption(t *testing.T) { tests := []struct { name string - identifier IPv6ExtHdrOptionIndentifier + identifier IPv6ExtHdrOptionIdentifier expectedUnknownAction IPv6OptionUnknownAction }{ { @@ -211,6 +212,31 @@ func TestIPv6OptionsExtHdrIterErr(t *testing.T) { bytes: []byte{1, 3}, err: io.ErrUnexpectedEOF, }, + { + name: "Router alert without data", + bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 0}, + err: ErrMalformedIPv6ExtHdrOption, + }, + { + name: "Router alert with partial data", + bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 1, 1}, + err: ErrMalformedIPv6ExtHdrOption, + }, + { + name: "Router alert with partial data and Pad1", + bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 1, 1, 0}, + err: ErrMalformedIPv6ExtHdrOption, + }, + { + name: "Router alert with extra data", + bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 3, 1, 2, 3}, + err: ErrMalformedIPv6ExtHdrOption, + }, + { + name: "Router alert with missing data", + bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 1}, + err: io.ErrUnexpectedEOF, + }, } check := func(t *testing.T, it IPv6OptionsExtHdrOptionsIterator, expectedErr error) { @@ -990,3 +1016,331 @@ func TestIPv6ExtHdrIter(t *testing.T) { }) } } + +var _ IPv6SerializableHopByHopOption = (*dummyHbHOptionSerializer)(nil) + +// dummyHbHOptionSerializer provides a generic implementation of +// IPv6SerializableHopByHopOption for use in tests. +type dummyHbHOptionSerializer struct { + id IPv6ExtHdrOptionIdentifier + payload []byte + align int + alignOffset int +} + +// identifier implements IPv6SerializableHopByHopOption. +func (s *dummyHbHOptionSerializer) identifier() IPv6ExtHdrOptionIdentifier { + return s.id +} + +// length implements IPv6SerializableHopByHopOption. +func (s *dummyHbHOptionSerializer) length() uint8 { + return uint8(len(s.payload)) +} + +// alignment implements IPv6SerializableHopByHopOption. +func (s *dummyHbHOptionSerializer) alignment() (int, int) { + align := 1 + if s.align != 0 { + align = s.align + } + return align, s.alignOffset +} + +// serializeInto implements IPv6SerializableHopByHopOption. +func (s *dummyHbHOptionSerializer) serializeInto(b []byte) uint8 { + return uint8(copy(b, s.payload)) +} + +func TestIPv6HopByHopSerializer(t *testing.T) { + validateDummies := func(t *testing.T, serializable IPv6SerializableHopByHopOption, deserialized IPv6ExtHdrOption) { + t.Helper() + dummy, ok := serializable.(*dummyHbHOptionSerializer) + if !ok { + t.Fatalf("got serializable = %T, want = *dummyHbHOptionSerializer", serializable) + } + unknown, ok := deserialized.(*IPv6UnknownExtHdrOption) + if !ok { + t.Fatalf("got deserialized = %T, want = %T", deserialized, &IPv6UnknownExtHdrOption{}) + } + if dummy.id != unknown.Identifier { + t.Errorf("got deserialized identifier = %d, want = %d", unknown.Identifier, dummy.id) + } + if diff := cmp.Diff(dummy.payload, unknown.Data); diff != "" { + t.Errorf("option payload deserialization mismatch (-want +got):\n%s", diff) + } + } + tests := []struct { + name string + nextHeader uint8 + options []IPv6SerializableHopByHopOption + expect []byte + validate func(*testing.T, IPv6SerializableHopByHopOption, IPv6ExtHdrOption) + }{ + { + name: "single option", + nextHeader: 13, + options: []IPv6SerializableHopByHopOption{ + &dummyHbHOptionSerializer{ + id: 15, + payload: []byte{9, 8, 7, 6}, + }, + }, + expect: []byte{13, 0, 15, 4, 9, 8, 7, 6}, + validate: validateDummies, + }, + { + name: "short option padN zero", + nextHeader: 88, + options: []IPv6SerializableHopByHopOption{ + &dummyHbHOptionSerializer{ + id: 22, + payload: []byte{4, 5}, + }, + }, + expect: []byte{88, 0, 22, 2, 4, 5, 1, 0}, + validate: validateDummies, + }, + { + name: "short option pad1", + nextHeader: 11, + options: []IPv6SerializableHopByHopOption{ + &dummyHbHOptionSerializer{ + id: 33, + payload: []byte{1, 2, 3}, + }, + }, + expect: []byte{11, 0, 33, 3, 1, 2, 3, 0}, + validate: validateDummies, + }, + { + name: "long option padN", + nextHeader: 55, + options: []IPv6SerializableHopByHopOption{ + &dummyHbHOptionSerializer{ + id: 77, + payload: []byte{1, 2, 3, 4, 5, 6, 7, 8}, + }, + }, + expect: []byte{55, 1, 77, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 0, 0}, + validate: validateDummies, + }, + { + name: "two options", + nextHeader: 33, + options: []IPv6SerializableHopByHopOption{ + &dummyHbHOptionSerializer{ + id: 11, + payload: []byte{1, 2, 3}, + }, + &dummyHbHOptionSerializer{ + id: 22, + payload: []byte{4, 5, 6}, + }, + }, + expect: []byte{33, 1, 11, 3, 1, 2, 3, 22, 3, 4, 5, 6, 1, 2, 0, 0}, + validate: validateDummies, + }, + { + name: "two options align 2n", + nextHeader: 33, + options: []IPv6SerializableHopByHopOption{ + &dummyHbHOptionSerializer{ + id: 11, + payload: []byte{1, 2, 3}, + }, + &dummyHbHOptionSerializer{ + id: 22, + payload: []byte{4, 5, 6}, + align: 2, + }, + }, + expect: []byte{33, 1, 11, 3, 1, 2, 3, 0, 22, 3, 4, 5, 6, 1, 1, 0}, + validate: validateDummies, + }, + { + name: "two options align 8n+1", + nextHeader: 33, + options: []IPv6SerializableHopByHopOption{ + &dummyHbHOptionSerializer{ + id: 11, + payload: []byte{1, 2}, + }, + &dummyHbHOptionSerializer{ + id: 22, + payload: []byte{4, 5, 6}, + align: 8, + alignOffset: 1, + }, + }, + expect: []byte{33, 1, 11, 2, 1, 2, 1, 1, 0, 22, 3, 4, 5, 6, 1, 0}, + validate: validateDummies, + }, + { + name: "no options", + nextHeader: 33, + options: []IPv6SerializableHopByHopOption{}, + expect: []byte{33, 0, 1, 4, 0, 0, 0, 0}, + }, + { + name: "Router Alert", + nextHeader: 33, + options: []IPv6SerializableHopByHopOption{&IPv6RouterAlertOption{Value: IPv6RouterAlertMLD}}, + expect: []byte{33, 0, 5, 2, 0, 0, 1, 0}, + validate: func(t *testing.T, _ IPv6SerializableHopByHopOption, deserialized IPv6ExtHdrOption) { + t.Helper() + routerAlert, ok := deserialized.(*IPv6RouterAlertOption) + if !ok { + t.Fatalf("got deserialized = %T, want = *IPv6RouterAlertOption", deserialized) + } + if routerAlert.Value != IPv6RouterAlertMLD { + t.Errorf("got routerAlert.Value = %d, want = %d", routerAlert.Value, IPv6RouterAlertMLD) + } + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := IPv6SerializableHopByHopExtHdr(test.options) + length := s.length() + if length != len(test.expect) { + t.Fatalf("got s.length() = %d, want = %d", length, len(test.expect)) + } + b := make([]byte, length) + for i := range b { + // Fill the buffer with ones to ensure all padding is correctly set. + b[i] = 0xFF + } + if got := s.serializeInto(test.nextHeader, b); got != length { + t.Fatalf("got s.serializeInto(..) = %d, want = %d", got, length) + } + if diff := cmp.Diff(test.expect, b); diff != "" { + t.Fatalf("serialization mismatch (-want +got):\n%s", diff) + } + + // Deserialize the options and verify them. + optLen := (b[ipv6HopByHopExtHdrLengthOffset] + ipv6HopByHopExtHdrUnaccountedLenWords) * ipv6ExtHdrLenBytesPerUnit + iter := ipv6OptionsExtHdr(b[ipv6HopByHopExtHdrOptionsOffset:optLen]).Iter() + for _, testOpt := range test.options { + opt, done, err := iter.Next() + if err != nil { + t.Fatalf("iter.Next(): %s", err) + } + if done { + t.Fatalf("got iter.Next() = (%T, %t, _), want = (_, false, _)", opt, done) + } + test.validate(t, testOpt, opt) + } + opt, done, err := iter.Next() + if err != nil { + t.Fatalf("iter.Next(): %s", err) + } + if !done { + t.Fatalf("got iter.Next() = (%T, %t, _), want = (_, true, _)", opt, done) + } + }) + } +} + +var _ IPv6SerializableExtHdr = (*dummyIPv6ExtHdrSerializer)(nil) + +// dummyIPv6ExtHdrSerializer provides a generic implementation of +// IPv6SerializableExtHdr for use in tests. +// +// The dummy header always carries the nextHeader value in the first byte. +type dummyIPv6ExtHdrSerializer struct { + id IPv6ExtensionHeaderIdentifier + headerContents []byte +} + +// identifier implements IPv6SerializableExtHdr. +func (s *dummyIPv6ExtHdrSerializer) identifier() IPv6ExtensionHeaderIdentifier { + return s.id +} + +// length implements IPv6SerializableExtHdr. +func (s *dummyIPv6ExtHdrSerializer) length() int { + return len(s.headerContents) + 1 +} + +// serializeInto implements IPv6SerializableExtHdr. +func (s *dummyIPv6ExtHdrSerializer) serializeInto(nextHeader uint8, b []byte) int { + b[0] = nextHeader + return copy(b[1:], s.headerContents) + 1 +} + +func TestIPv6ExtHdrSerializer(t *testing.T) { + tests := []struct { + name string + headers []IPv6SerializableExtHdr + nextHeader tcpip.TransportProtocolNumber + expectSerialized []byte + expectNextHeader uint8 + }{ + { + name: "one header", + headers: []IPv6SerializableExtHdr{ + &dummyIPv6ExtHdrSerializer{ + id: 15, + headerContents: []byte{1, 2, 3, 4}, + }, + }, + nextHeader: TCPProtocolNumber, + expectSerialized: []byte{byte(TCPProtocolNumber), 1, 2, 3, 4}, + expectNextHeader: 15, + }, + { + name: "two headers", + headers: []IPv6SerializableExtHdr{ + &dummyIPv6ExtHdrSerializer{ + id: 22, + headerContents: []byte{1, 2, 3}, + }, + &dummyIPv6ExtHdrSerializer{ + id: 23, + headerContents: []byte{4, 5, 6}, + }, + }, + nextHeader: ICMPv6ProtocolNumber, + expectSerialized: []byte{ + 23, 1, 2, 3, + byte(ICMPv6ProtocolNumber), 4, 5, 6, + }, + expectNextHeader: 22, + }, + { + name: "no headers", + headers: []IPv6SerializableExtHdr{}, + nextHeader: UDPProtocolNumber, + expectSerialized: []byte{}, + expectNextHeader: byte(UDPProtocolNumber), + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := IPv6ExtHdrSerializer(test.headers) + l := s.Length() + if got, want := l, len(test.expectSerialized); got != want { + t.Fatalf("got serialized length = %d, want = %d", got, want) + } + b := make([]byte, l) + for i := range b { + // Fill the buffer with garbage to make sure we're writing to all bytes. + b[i] = 0xFF + } + nextHeader, serializedLen := s.Serialize(test.nextHeader, b) + if serializedLen != len(test.expectSerialized) || nextHeader != test.expectNextHeader { + t.Errorf( + "got s.Serialize(..) = (%d, %d), want = (%d, %d)", + nextHeader, + serializedLen, + test.expectNextHeader, + len(test.expectSerialized), + ) + } + if diff := cmp.Diff(test.expectSerialized, b); diff != "" { + t.Errorf("serialization mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/pkg/tcpip/header/ipv6_fragment.go b/pkg/tcpip/header/ipv6_fragment.go index 018555a26..9d09f32eb 100644 --- a/pkg/tcpip/header/ipv6_fragment.go +++ b/pkg/tcpip/header/ipv6_fragment.go @@ -27,12 +27,11 @@ const ( idV6 = 4 ) -// IPv6FragmentFields contains the fields of an IPv6 fragment. It is used to describe the -// fields of a packet that needs to be encoded. -type IPv6FragmentFields struct { - // NextHeader is the "next header" field of an IPv6 fragment. - NextHeader uint8 +var _ IPv6SerializableExtHdr = (*IPv6SerializableFragmentExtHdr)(nil) +// IPv6SerializableFragmentExtHdr is used to serialize an IPv6 fragment +// extension header as defined in RFC 8200 section 4.5. +type IPv6SerializableFragmentExtHdr struct { // FragmentOffset is the "fragment offset" field of an IPv6 fragment. FragmentOffset uint16 @@ -43,6 +42,29 @@ type IPv6FragmentFields struct { Identification uint32 } +// identifier implements IPv6SerializableFragmentExtHdr. +func (h *IPv6SerializableFragmentExtHdr) identifier() IPv6ExtensionHeaderIdentifier { + return IPv6FragmentHeader +} + +// length implements IPv6SerializableFragmentExtHdr. +func (h *IPv6SerializableFragmentExtHdr) length() int { + return IPv6FragmentHeaderSize +} + +// serializeInto implements IPv6SerializableFragmentExtHdr. +func (h *IPv6SerializableFragmentExtHdr) serializeInto(nextHeader uint8, b []byte) int { + // Prevent too many bounds checks. + _ = b[IPv6FragmentHeaderSize:] + binary.BigEndian.PutUint32(b[idV6:], h.Identification) + binary.BigEndian.PutUint16(b[fragOff:], h.FragmentOffset<<ipv6FragmentExtHdrFragmentOffsetShift) + b[nextHdrFrag] = nextHeader + if h.M { + b[more] |= ipv6FragmentExtHdrMFlagMask + } + return IPv6FragmentHeaderSize +} + // IPv6Fragment represents an ipv6 fragment header stored in a byte array. // Most of the methods of IPv6Fragment access to the underlying slice without // checking the boundaries and could panic because of 'index out of range'. @@ -58,16 +80,6 @@ const ( IPv6FragmentHeaderSize = 8 ) -// Encode encodes all the fields of the ipv6 fragment. -func (b IPv6Fragment) Encode(i *IPv6FragmentFields) { - b[nextHdrFrag] = i.NextHeader - binary.BigEndian.PutUint16(b[fragOff:], i.FragmentOffset<<3) - if i.M { - b[more] |= 1 - } - binary.BigEndian.PutUint32(b[idV6:], i.Identification) -} - // IsValid performs basic validation on the fragment header. func (b IPv6Fragment) IsValid() bool { return len(b) >= IPv6FragmentHeaderSize diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go index 426a873b1..e3fbd64f3 100644 --- a/pkg/tcpip/header/ipv6_test.go +++ b/pkg/tcpip/header/ipv6_test.go @@ -215,48 +215,6 @@ func TestLinkLocalAddrWithOpaqueIID(t *testing.T) { } } -func TestIsV6UniqueLocalAddress(t *testing.T) { - tests := []struct { - name string - addr tcpip.Address - expected bool - }{ - { - name: "Valid Unique 1", - addr: uniqueLocalAddr1, - expected: true, - }, - { - name: "Valid Unique 2", - addr: uniqueLocalAddr1, - expected: true, - }, - { - name: "Link Local", - addr: linkLocalAddr, - expected: false, - }, - { - name: "Global", - addr: globalAddr, - expected: false, - }, - { - name: "IPv4", - addr: "\x01\x02\x03\x04", - expected: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if got := header.IsV6UniqueLocalAddress(test.addr); got != test.expected { - t.Errorf("got header.IsV6UniqueLocalAddress(%s) = %t, want = %t", test.addr, got, test.expected) - } - }) - } -} - func TestIsV6LinkLocalMulticastAddress(t *testing.T) { tests := []struct { name string @@ -346,7 +304,7 @@ func TestScopeForIPv6Address(t *testing.T) { { name: "Unique Local", addr: uniqueLocalAddr1, - scope: header.UniqueLocalScope, + scope: header.GlobalScope, err: nil, }, { diff --git a/pkg/tcpip/header/mld.go b/pkg/tcpip/header/mld.go new file mode 100644 index 000000000..ffe03c76a --- /dev/null +++ b/pkg/tcpip/header/mld.go @@ -0,0 +1,103 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "encoding/binary" + "fmt" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +const ( + // MLDMinimumSize is the minimum size for an MLD message. + MLDMinimumSize = 20 + + // MLDHopLimit is the Hop Limit for all IPv6 packets with an MLD message, as + // per RFC 2710 section 3. + MLDHopLimit = 1 + + // mldMaximumResponseDelayOffset is the offset to the Maximum Response Delay + // field within MLD. + mldMaximumResponseDelayOffset = 0 + + // mldMulticastAddressOffset is the offset to the Multicast Address field + // within MLD. + mldMulticastAddressOffset = 4 +) + +// MLD is a Multicast Listener Discovery message in an ICMPv6 packet. +// +// MLD will only contain the body of an ICMPv6 packet. +// +// As per RFC 2710 section 3, MLD messages have the following format (MLD only +// holds the bytes after the first four bytes in the diagram below): +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Type | Code | Checksum | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Maximum Response Delay | Reserved | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | | +// + + +// | | +// + Multicast Address + +// | | +// + + +// | | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +type MLD []byte + +// MaximumResponseDelay returns the Maximum Response Delay. +func (m MLD) MaximumResponseDelay() time.Duration { + // As per RFC 2710 section 3.4: + // + // The Maximum Response Delay field is meaningful only in Query + // messages, and specifies the maximum allowed delay before sending a + // responding Report, in units of milliseconds. In all other messages, + // it is set to zero by the sender and ignored by receivers. + return time.Duration(binary.BigEndian.Uint16(m[mldMaximumResponseDelayOffset:])) * time.Millisecond +} + +// SetMaximumResponseDelay sets the Maximum Response Delay field. +// +// maxRespDelayMS is the value in milliseconds. +func (m MLD) SetMaximumResponseDelay(maxRespDelayMS uint16) { + binary.BigEndian.PutUint16(m[mldMaximumResponseDelayOffset:], maxRespDelayMS) +} + +// MulticastAddress returns the Multicast Address. +func (m MLD) MulticastAddress() tcpip.Address { + // As per RFC 2710 section 3.5: + // + // In a Query message, the Multicast Address field is set to zero when + // sending a General Query, and set to a specific IPv6 multicast address + // when sending a Multicast-Address-Specific Query. + // + // In a Report or Done message, the Multicast Address field holds a + // specific IPv6 multicast address to which the message sender is + // listening or is ceasing to listen, respectively. + return tcpip.Address(m[mldMulticastAddressOffset:][:IPv6AddressSize]) +} + +// SetMulticastAddress sets the Multicast Address field. +func (m MLD) SetMulticastAddress(multicastAddress tcpip.Address) { + if n := copy(m[mldMulticastAddressOffset:], multicastAddress); n != IPv6AddressSize { + panic(fmt.Sprintf("copied %d bytes, expected to copy %d bytes", n, IPv6AddressSize)) + } +} diff --git a/pkg/tcpip/header/mld_test.go b/pkg/tcpip/header/mld_test.go new file mode 100644 index 000000000..0cecf10d4 --- /dev/null +++ b/pkg/tcpip/header/mld_test.go @@ -0,0 +1,61 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "encoding/binary" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +func TestMLD(t *testing.T) { + b := []byte{ + // Maximum Response Delay + 0, 0, + + // Reserved + 0, 0, + + // MulticastAddress + 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, + } + + const maxRespDelay = 513 + binary.BigEndian.PutUint16(b, maxRespDelay) + + mld := MLD(b) + + if got, want := mld.MaximumResponseDelay(), maxRespDelay*time.Millisecond; got != want { + t.Errorf("got mld.MaximumResponseDelay() = %s, want = %s", got, want) + } + + const newMaxRespDelay = 1234 + mld.SetMaximumResponseDelay(newMaxRespDelay) + if got, want := mld.MaximumResponseDelay(), newMaxRespDelay*time.Millisecond; got != want { + t.Errorf("got mld.MaximumResponseDelay() = %s, want = %s", got, want) + } + + if got, want := mld.MulticastAddress(), tcpip.Address([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6}); got != want { + t.Errorf("got mld.MulticastAddress() = %s, want = %s", got, want) + } + + multicastAddress := tcpip.Address([]byte{15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}) + mld.SetMulticastAddress(multicastAddress) + if got := mld.MulticastAddress(); got != multicastAddress { + t.Errorf("got mld.MulticastAddress() = %s, want = %s", got, multicastAddress) + } +} diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go index 5d3975c56..554242f0c 100644 --- a/pkg/tcpip/header/ndp_options.go +++ b/pkg/tcpip/header/ndp_options.go @@ -298,7 +298,7 @@ func (b NDPOptions) Iter(check bool) (NDPOptionIterator, error) { return it, nil } -// Serialize serializes the provided list of NDP options into o. +// Serialize serializes the provided list of NDP options into b. // // Note, b must be of sufficient size to hold all the options in s. See // NDPOptionsSerializer.Length for details on the getting the total size diff --git a/pkg/tcpip/header/parse/parse.go b/pkg/tcpip/header/parse/parse.go index 5ca75c834..2042f214a 100644 --- a/pkg/tcpip/header/parse/parse.go +++ b/pkg/tcpip/header/parse/parse.go @@ -109,6 +109,9 @@ traverseExtensions: fragOffset = extHdr.FragmentOffset() fragMore = extHdr.More() } + rawPayload := it.AsRawHeader(true /* consume */) + extensionsSize = dataClone.Size() - rawPayload.Buf.Size() + break traverseExtensions case header.IPv6RawPayloadHeader: // We've found the payload after any extensions. diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD index 39ca774ef..973f06cbc 100644 --- a/pkg/tcpip/link/channel/BUILD +++ b/pkg/tcpip/link/channel/BUILD @@ -9,7 +9,6 @@ go_library( deps = [ "//pkg/sync", "//pkg/tcpip", - "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/stack", ], diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index c95aef63c..d9f8e3b35 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -22,7 +22,6 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -32,7 +31,7 @@ type PacketInfo struct { Pkt *stack.PacketBuffer Proto tcpip.NetworkProtocolNumber GSO *stack.GSO - Route stack.Route + Route stack.RouteInfo } // Notification is the interface for receiving notification from the packet @@ -231,15 +230,11 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { // WritePacket stores outbound packets into the channel. func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - // Clone r then release its resource so we only get the relevant fields from - // stack.Route without holding a reference to a NIC's endpoint. - route := r.Clone() - route.Release() p := PacketInfo{ Pkt: pkt, Proto: protocol, GSO: gso, - Route: route, + Route: r.GetFields(), } e.q.Write(p) @@ -249,17 +244,13 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne // WritePackets stores outbound packets into the channel. func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - // Clone r then release its resource so we only get the relevant fields from - // stack.Route without holding a reference to a NIC's endpoint. - route := r.Clone() - route.Release() n := 0 for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { p := PacketInfo{ Pkt: pkt, Proto: protocol, GSO: gso, - Route: route, + Route: r.GetFields(), } if !e.q.Write(p) { @@ -271,21 +262,6 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return n, nil } -// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - p := PacketInfo{ - Pkt: stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - }), - Proto: 0, - GSO: nil, - } - - e.q.Write(p) - - return nil -} - // Wait implements stack.LinkEndpoint.Wait. func (*Endpoint) Wait() {} diff --git a/pkg/tcpip/link/ethernet/BUILD b/pkg/tcpip/link/ethernet/BUILD index ec92ed623..0ae0d201a 100644 --- a/pkg/tcpip/link/ethernet/BUILD +++ b/pkg/tcpip/link/ethernet/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -13,3 +13,17 @@ go_library( "//pkg/tcpip/stack", ], ) + +go_test( + name = "ethernet_test", + size = "small", + srcs = ["ethernet_test.go"], + deps = [ + ":ethernet", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/link/ethernet/ethernet.go b/pkg/tcpip/link/ethernet/ethernet.go index 3eef7cd56..89e3e6164 100644 --- a/pkg/tcpip/link/ethernet/ethernet.go +++ b/pkg/tcpip/link/ethernet/ethernet.go @@ -49,10 +49,10 @@ func (e *Endpoint) DeliverNetworkPacket(_, _ tcpip.LinkAddress, _ tcpip.NetworkP return } + // Note, there is no need to check the destination link address here since + // the ethernet hardware filters frames based on their destination addresses. eth := header.Ethernet(hdr) - if dst := eth.DestinationAddress(); dst == e.Endpoint.LinkAddress() || dst == header.EthernetBroadcastAddress || header.IsMulticastEthernetAddress(dst) { - e.Endpoint.DeliverNetworkPacket(eth.SourceAddress() /* remote */, dst /* local */, eth.Type() /* protocol */, pkt) - } + e.Endpoint.DeliverNetworkPacket(eth.SourceAddress() /* remote */, eth.DestinationAddress() /* local */, eth.Type() /* protocol */, pkt) } // Capabilities implements stack.LinkEndpoint. @@ -62,7 +62,7 @@ func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities { // WritePacket implements stack.LinkEndpoint. func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - e.AddHeader(e.Endpoint.LinkAddress(), r.RemoteLinkAddress, proto, pkt) + e.AddHeader(e.Endpoint.LinkAddress(), r.RemoteLinkAddress(), proto, pkt) return e.Endpoint.WritePacket(r, gso, proto, pkt) } @@ -71,7 +71,7 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe linkAddr := e.Endpoint.LinkAddress() for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - e.AddHeader(linkAddr, r.RemoteLinkAddress, proto, pkt) + e.AddHeader(linkAddr, r.RemoteLinkAddress(), proto, pkt) } return e.Endpoint.WritePackets(r, gso, pkts, proto) diff --git a/pkg/tcpip/link/ethernet/ethernet_test.go b/pkg/tcpip/link/ethernet/ethernet_test.go new file mode 100644 index 000000000..08a7f1ce1 --- /dev/null +++ b/pkg/tcpip/link/ethernet/ethernet_test.go @@ -0,0 +1,71 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ethernet_test + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/ethernet" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +var _ stack.NetworkDispatcher = (*testNetworkDispatcher)(nil) + +type testNetworkDispatcher struct { + networkPackets int +} + +func (t *testNetworkDispatcher) DeliverNetworkPacket(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) { + t.networkPackets++ +} + +func (*testNetworkDispatcher) DeliverOutboundPacket(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) { +} + +func TestDeliverNetworkPacket(t *testing.T) { + const ( + linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") + otherLinkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07") + otherLinkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08") + ) + + e := ethernet.New(channel.New(0, 0, linkAddr)) + var networkDispatcher testNetworkDispatcher + e.Attach(&networkDispatcher) + + if networkDispatcher.networkPackets != 0 { + t.Fatalf("got networkDispatcher.networkPackets = %d, want = 0", networkDispatcher.networkPackets) + } + + // An ethernet frame with a destination link address that is not assigned to + // our ethernet link endpoint should still be delivered to the network + // dispatcher since the ethernet endpoint is not expected to filter frames. + eth := buffer.NewView(header.EthernetMinimumSize) + header.Ethernet(eth).Encode(&header.EthernetFields{ + SrcAddr: otherLinkAddr1, + DstAddr: otherLinkAddr2, + Type: header.IPv4ProtocolNumber, + }) + e.DeliverNetworkPacket("", "", 0, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: eth.ToVectorisedView(), + })) + if networkDispatcher.networkPackets != 1 { + t.Fatalf("got networkDispatcher.networkPackets = %d, want = 1", networkDispatcher.networkPackets) + } +} diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index 975309fc8..cb94cbea6 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -284,9 +284,12 @@ func createInboundDispatcher(e *endpoint, fd int, isSocket bool) (linkDispatcher } switch sa.(type) { case *unix.SockaddrLinklayer: - // enable PACKET_FANOUT mode is the underlying socket is - // of type AF_PACKET. - const fanoutType = 0x8000 // PACKET_FANOUT_HASH | PACKET_FANOUT_FLAG_DEFRAG + // Enable PACKET_FANOUT mode if the underlying socket is of type + // AF_PACKET. We do not enable PACKET_FANOUT_FLAG_DEFRAG as that will + // prevent gvisor from receiving fragmented packets and the host does the + // reassembly on our behalf before delivering the fragments. This makes it + // hard to test fragmentation reassembly code in Netstack. + const fanoutType = unix.PACKET_FANOUT_HASH fanoutArg := fanoutID | fanoutType<<16 if err := syscall.SetsockoptInt(fd, syscall.SOL_PACKET, unix.PACKET_FANOUT, fanoutArg); err != nil { return nil, fmt.Errorf("failed to enable PACKET_FANOUT option: %v", err) @@ -410,7 +413,7 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net // currently writable, the packet is dropped. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { if e.hdrSize > 0 { - e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) + e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress(), protocol, pkt) } var builder iovec.Builder @@ -453,7 +456,7 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tc mmsgHdrs := make([]rawfile.MMsgHdr, 0, len(batch)) for _, pkt := range batch { if e.hdrSize > 0 { - e.AddHeader(pkt.EgressRoute.LocalLinkAddress, pkt.EgressRoute.RemoteLinkAddress, pkt.NetworkProtocolNumber, pkt) + e.AddHeader(pkt.EgressRoute.LocalLinkAddress, pkt.EgressRoute.RemoteLinkAddress(), pkt.NetworkProtocolNumber, pkt) } var vnetHdrBuf []byte @@ -558,11 +561,6 @@ func viewsEqual(vs1, vs2 []buffer.View) bool { return len(vs1) == len(vs2) && (len(vs1) == 0 || &vs1[0] == &vs2[0]) } -// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - return rawfile.NonBlockingWrite(e.fds[0], vv.ToView()) -} - // InjectOutobund implements stack.InjectableEndpoint.InjectOutbound. func (e *endpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error { return rawfile.NonBlockingWrite(e.fds[0], packet) diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index 709f829c8..987a34226 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -183,9 +183,8 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash u c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth, GSOMaxSize: gsoMaxSize}) defer c.cleanup() - r := &stack.Route{ - RemoteLinkAddress: raddr, - } + var r stack.Route + r.ResolveWith(raddr) // Build payload. payload := buffer.NewView(plen) @@ -220,7 +219,7 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash u L3HdrLen: header.IPv4MaximumHeaderSize, } } - if err := c.ep.WritePacket(r, gso, proto, pkt); err != nil { + if err := c.ep.WritePacket(&r, gso, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -324,10 +323,9 @@ func TestPreserveSrcAddress(t *testing.T) { defer c.cleanup() // Set LocalLinkAddress in route to the value of the bridged address. - r := &stack.Route{ - RemoteLinkAddress: raddr, - LocalLinkAddress: baddr, - } + var r stack.Route + r.LocalLinkAddress = baddr + r.ResolveWith(raddr) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ // WritePacket panics given a prependable with anything less than @@ -336,7 +334,7 @@ func TestPreserveSrcAddress(t *testing.T) { ReserveHeaderBytes: header.EthernetMinimumSize, Data: buffer.VectorisedView{}, }) - if err := c.ep.WritePacket(r, nil /* gso */, proto, pkt); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } @@ -503,7 +501,7 @@ func TestRecvMMsgDispatcherCapLength(t *testing.T) { msgHdrs: make([]rawfile.MMsgHdr, 1), } - for i, _ := range d.views { + for i := range d.views { d.views[i] = make([]buffer.View, len(c.config)) } for i := range d.iovecs { diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 38aa694e4..edca57e4e 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -96,23 +96,6 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList panic("not implemented") } -// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - }) - // There should be an ethernet header at the beginning of vv. - hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize) - if !ok { - // Reject the packet if it's shorter than an ethernet header. - return tcpip.ErrBadAddress - } - linkHeader := header.Ethernet(hdr) - e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, linkHeader.Type(), pkt) - - return nil -} - // ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. func (*endpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareLoopback diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD index e7493e5c5..cbda59775 100644 --- a/pkg/tcpip/link/muxed/BUILD +++ b/pkg/tcpip/link/muxed/BUILD @@ -8,7 +8,6 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/tcpip", - "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/stack", ], diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index 56a611825..22e79ce3a 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -17,7 +17,6 @@ package muxed import ( "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -106,13 +105,6 @@ func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, protoco return tcpip.ErrNoRoute } -// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (m *InjectableEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { - // WriteRawPacket doesn't get a route or network address, so there's - // nowhere to write this. - return tcpip.ErrNoRoute -} - // InjectOutbound writes outbound packets to the appropriate // LinkInjectableEndpoint based on the dest address. func (m *InjectableEndpoint) InjectOutbound(dest tcpip.Address, packet []byte) *tcpip.Error { diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go index 3e4afcdad..b511d3a31 100644 --- a/pkg/tcpip/link/muxed/injectable_test.go +++ b/pkg/tcpip/link/muxed/injectable_test.go @@ -51,7 +51,8 @@ func TestInjectableEndpointDispatch(t *testing.T) { Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(), }) pkt.TransportHeader().Push(1)[0] = 0xFA - packetRoute := stack.Route{RemoteAddress: dstIP} + var packetRoute stack.Route + packetRoute.RemoteAddress = dstIP endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, pkt) @@ -73,7 +74,8 @@ func TestInjectableEndpointDispatchHdrOnly(t *testing.T) { Data: buffer.NewView(0).ToVectorisedView(), }) pkt.TransportHeader().Push(1)[0] = 0xFA - packetRoute := stack.Route{RemoteAddress: dstIP} + var packetRoute stack.Route + packetRoute.RemoteAddress = dstIP endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, pkt) buf := make([]byte, 6500) bytesRead, err := sock.Read(buf) diff --git a/pkg/tcpip/link/nested/BUILD b/pkg/tcpip/link/nested/BUILD index 2cdb23475..00b42b924 100644 --- a/pkg/tcpip/link/nested/BUILD +++ b/pkg/tcpip/link/nested/BUILD @@ -11,7 +11,6 @@ go_library( deps = [ "//pkg/sync", "//pkg/tcpip", - "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/stack", ], diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go index d40de54df..0ee54c3d5 100644 --- a/pkg/tcpip/link/nested/nested.go +++ b/pkg/tcpip/link/nested/nested.go @@ -19,7 +19,6 @@ package nested import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -123,11 +122,6 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return e.child.WritePackets(r, gso, pkts, protocol) } -// WriteRawPacket implements stack.LinkEndpoint. -func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - return e.child.WriteRawPacket(vv) -} - // Wait implements stack.LinkEndpoint. func (e *Endpoint) Wait() { e.child.Wait() diff --git a/pkg/tcpip/link/packetsocket/endpoint.go b/pkg/tcpip/link/packetsocket/endpoint.go index 3922c2a04..9a1b0c0c2 100644 --- a/pkg/tcpip/link/packetsocket/endpoint.go +++ b/pkg/tcpip/link/packetsocket/endpoint.go @@ -36,14 +36,14 @@ func New(lower stack.LinkEndpoint) stack.LinkEndpoint { // WritePacket implements stack.LinkEndpoint.WritePacket. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, protocol, pkt) + e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress(), r.LocalLinkAddress, protocol, pkt) return e.Endpoint.WritePacket(r, gso, protocol, pkt) } // WritePackets implements stack.LinkEndpoint.WritePackets. func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - e.Endpoint.DeliverOutboundPacket(pkt.EgressRoute.RemoteLinkAddress, pkt.EgressRoute.LocalLinkAddress, pkt.NetworkProtocolNumber, pkt) + e.Endpoint.DeliverOutboundPacket(pkt.EgressRoute.RemoteLinkAddress(), pkt.EgressRoute.LocalLinkAddress, pkt.NetworkProtocolNumber, pkt) } return e.Endpoint.WritePackets(r, gso, pkts, proto) diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go index 523b0d24b..25c364391 100644 --- a/pkg/tcpip/link/pipe/pipe.go +++ b/pkg/tcpip/link/pipe/pipe.go @@ -55,7 +55,7 @@ func (e *Endpoint) WritePacket(r *stack.Route, _ *stack.GSO, proto tcpip.Network // remote address from the perspective of the other end of the pipe // (e.linked). Similarly, the remote address from the perspective of this // endpoint is the local address on the other end. - e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress /* local */, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{ + e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress() /* local */, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), })) @@ -67,11 +67,6 @@ func (*Endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, panic("not implemented") } -// WriteRawPacket implements stack.LinkEndpoint. -func (*Endpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { - panic("not implemented") -} - // Attach implements stack.LinkEndpoint. func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) { e.dispatcher = dispatcher diff --git a/pkg/tcpip/link/qdisc/fifo/BUILD b/pkg/tcpip/link/qdisc/fifo/BUILD index 1d0079bd6..5bea598eb 100644 --- a/pkg/tcpip/link/qdisc/fifo/BUILD +++ b/pkg/tcpip/link/qdisc/fifo/BUILD @@ -13,7 +13,6 @@ go_library( "//pkg/sleep", "//pkg/sync", "//pkg/tcpip", - "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/stack", ], diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go index fc1e34fc7..b7458b620 100644 --- a/pkg/tcpip/link/qdisc/fifo/endpoint.go +++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go @@ -21,7 +21,6 @@ import ( "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -155,8 +154,7 @@ func (e *endpoint) GSOMaxSize() uint32 { func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { // WritePacket caller's do not set the following fields in PacketBuffer // so we populate them here. - newRoute := r.Clone() - pkt.EgressRoute = &newRoute + pkt.EgressRoute = r pkt.GSOOptions = gso pkt.NetworkProtocolNumber = protocol d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)] @@ -179,11 +177,6 @@ func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketB for pkt := pkts.Front(); pkt != nil; { d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)] nxt := pkt.Next() - // Since qdisc can hold onto a packet for long we should Clone - // the route here to ensure it doesn't get released while the - // packet is still in our queue. - newRoute := pkt.EgressRoute.Clone() - pkt.EgressRoute = &newRoute if !d.q.enqueue(pkt) { if enqueued > 0 { d.newPacketWaker.Assert() @@ -197,13 +190,6 @@ func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketB return enqueued, nil } -// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - // TODO(gvisor.dev/issue/3267): Queue these packets as well once - // WriteRawPacket takes PacketBuffer instead of VectorisedView. - return e.lower.WriteRawPacket(vv) -} - // Wait implements stack.LinkEndpoint.Wait. func (e *endpoint) Wait() { e.lower.Wait() diff --git a/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go b/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go index eb5abb906..45adcbccb 100644 --- a/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go +++ b/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go @@ -61,6 +61,7 @@ func (q *packetBufferQueue) enqueue(s *stack.PacketBuffer) bool { q.mu.Lock() r := q.used < q.limit if r { + s.EgressRoute.Acquire() q.list.PushBack(s) q.used++ } diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 7fb8a6c49..5660418fa 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -204,7 +204,7 @@ func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net // WritePacket writes outbound packets to the file descriptor. If it is not // currently writable, the packet is dropped. func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) + e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress(), protocol, pkt) views := pkt.Views() // Transmit the packet. @@ -224,21 +224,6 @@ func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts stack.PacketB panic("not implemented") } -// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - views := vv.Views() - // Transmit the packet. - e.mu.Lock() - ok := e.tx.transmit(views...) - e.mu.Unlock() - - if !ok { - return tcpip.ErrWouldBlock - } - - return nil -} - // dispatchLoop reads packets from the rx queue in a loop and dispatches them // to the network stack. func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) { diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 22d5c97f1..dd2e1a125 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -260,9 +260,8 @@ func TestSimpleSend(t *testing.T) { defer c.cleanup() // Prepare route. - r := stack.Route{ - RemoteLinkAddress: remoteLinkAddr, - } + var r stack.Route + r.ResolveWith(remoteLinkAddr) for iters := 1000; iters > 0; iters-- { func() { @@ -341,10 +340,9 @@ func TestPreserveSrcAddressInSend(t *testing.T) { newLocalLinkAddress := tcpip.LinkAddress(strings.Repeat("0xFE", 6)) // Set both remote and local link address in route. - r := stack.Route{ - RemoteLinkAddress: remoteLinkAddr, - LocalLinkAddress: newLocalLinkAddress, - } + var r stack.Route + r.LocalLinkAddress = newLocalLinkAddress + r.ResolveWith(remoteLinkAddr) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ // WritePacket panics given a prependable with anything less than @@ -395,9 +393,8 @@ func TestFillTxQueue(t *testing.T) { defer c.cleanup() // Prepare to send a packet. - r := stack.Route{ - RemoteLinkAddress: remoteLinkAddr, - } + var r stack.Route + r.ResolveWith(remoteLinkAddr) buf := buffer.NewView(100) @@ -444,9 +441,8 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) { c.txq.rx.Flush() // Prepare to send a packet. - r := stack.Route{ - RemoteLinkAddress: remoteLinkAddr, - } + var r stack.Route + r.ResolveWith(remoteLinkAddr) buf := buffer.NewView(100) @@ -509,9 +505,8 @@ func TestFillTxMemory(t *testing.T) { defer c.cleanup() // Prepare to send a packet. - r := stack.Route{ - RemoteLinkAddress: remoteLinkAddr, - } + var r stack.Route + r.ResolveWith(remoteLinkAddr) buf := buffer.NewView(100) @@ -557,9 +552,8 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) { defer c.cleanup() // Prepare to send a packet. - r := stack.Route{ - RemoteLinkAddress: remoteLinkAddr, - } + var r stack.Route + r.ResolveWith(remoteLinkAddr) buf := buffer.NewView(100) diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index b3e8c4b92..1a2cc39eb 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -53,16 +53,35 @@ type endpoint struct { nested.Endpoint writer io.Writer maxPCAPLen uint32 + logPrefix string } var _ stack.GSOEndpoint = (*endpoint)(nil) var _ stack.LinkEndpoint = (*endpoint)(nil) var _ stack.NetworkDispatcher = (*endpoint)(nil) +type direction int + +const ( + directionSend = iota + directionRecv +) + // New creates a new sniffer link-layer endpoint. It wraps around another // endpoint and logs packets and they traverse the endpoint. func New(lower stack.LinkEndpoint) stack.LinkEndpoint { - sniffer := &endpoint{} + return NewWithPrefix(lower, "") +} + +// NewWithPrefix creates a new sniffer link-layer endpoint. It wraps around +// another endpoint and logs packets prefixed with logPrefix as they traverse +// the endpoint. +// +// logPrefix is prepended to the log line without any separators. +// E.g. logPrefix = "NIC:en0/" will produce log lines like +// "NIC:en0/send udp [...]". +func NewWithPrefix(lower stack.LinkEndpoint, logPrefix string) stack.LinkEndpoint { + sniffer := &endpoint{logPrefix: logPrefix} sniffer.Endpoint.Init(lower, sniffer) return sniffer } @@ -120,7 +139,7 @@ func NewWithWriter(lower stack.LinkEndpoint, writer io.Writer, snapLen uint32) ( // called by the link-layer endpoint being wrapped when a packet arrives, and // logs the packet before forwarding to the actual dispatcher. func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { - e.dumpPacket("recv", nil, protocol, pkt) + e.dumpPacket(directionRecv, nil, protocol, pkt) e.Endpoint.DeliverNetworkPacket(remote, local, protocol, pkt) } @@ -129,10 +148,10 @@ func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protoc e.Endpoint.DeliverOutboundPacket(remote, local, protocol, pkt) } -func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { +func (e *endpoint) dumpPacket(dir direction, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { writer := e.writer if writer == nil && atomic.LoadUint32(&LogPackets) == 1 { - logPacket(prefix, protocol, pkt, gso) + logPacket(e.logPrefix, dir, protocol, pkt, gso) } if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 { totalLength := pkt.Size() @@ -169,7 +188,7 @@ func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.Netw // higher-level protocols to write packets; it just logs the packet and // forwards the request to the lower endpoint. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - e.dumpPacket("send", gso, protocol, pkt) + e.dumpPacket(directionSend, gso, protocol, pkt) return e.Endpoint.WritePacket(r, gso, protocol, pkt) } @@ -178,20 +197,12 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne // forwards the request to the lower endpoint. func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - e.dumpPacket("send", gso, protocol, pkt) + e.dumpPacket(directionSend, gso, protocol, pkt) } return e.Endpoint.WritePackets(r, gso, pkts, protocol) } -// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - e.dumpPacket("send", nil, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - })) - return e.Endpoint.WriteRawPacket(vv) -} - -func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) { +func logPacket(prefix string, dir direction, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) { // Figure out the network layer info. var transProto uint8 src := tcpip.Address("unknown") @@ -201,6 +212,16 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P var fragmentOffset uint16 var moreFragments bool + var directionPrefix string + switch dir { + case directionSend: + directionPrefix = "send" + case directionRecv: + directionPrefix = "recv" + default: + panic(fmt.Sprintf("unrecognized direction: %d", dir)) + } + // Clone the packet buffer to not modify the original. // // We don't clone the original packet buffer so that the new packet buffer @@ -242,21 +263,22 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P fragmentOffset = fragOffset case header.ARPProtocolNumber: - if parse.ARP(pkt) { + if !parse.ARP(pkt) { return } arp := header.ARP(pkt.NetworkHeader().View()) log.Infof( - "%s arp %s (%s) -> %s (%s) valid:%t", + "%s%s arp %s (%s) -> %s (%s) valid:%t", prefix, + directionPrefix, tcpip.Address(arp.ProtocolAddressSender()), tcpip.LinkAddress(arp.HardwareAddressSender()), tcpip.Address(arp.ProtocolAddressTarget()), tcpip.LinkAddress(arp.HardwareAddressTarget()), arp.IsValid(), ) return default: - log.Infof("%s unknown network protocol", prefix) + log.Infof("%s%s unknown network protocol", prefix, directionPrefix) return } @@ -300,7 +322,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P icmpType = "info reply" } } - log.Infof("%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code()) + log.Infof("%s%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, directionPrefix, transName, src, dst, icmpType, size, id, icmp.Code()) return case header.ICMPv6ProtocolNumber: @@ -335,7 +357,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P case header.ICMPv6RedirectMsg: icmpType = "redirect message" } - log.Infof("%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code()) + log.Infof("%s%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, directionPrefix, transName, src, dst, icmpType, size, id, icmp.Code()) return case header.UDPProtocolNumber: @@ -391,7 +413,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P } default: - log.Infof("%s %s -> %s unknown transport protocol: %d", prefix, src, dst, transProto) + log.Infof("%s%s %s -> %s unknown transport protocol: %d", prefix, directionPrefix, src, dst, transProto) return } @@ -399,5 +421,5 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P details += fmt.Sprintf(" gso: %+v", gso) } - log.Infof("%s %s %s:%d -> %s:%d len:%d id:%04x %s", prefix, transName, src, srcPort, dst, dstPort, size, id, details) + log.Infof("%s%s %s %s:%d -> %s:%d len:%d id:%04x %s", prefix, directionPrefix, transName, src, srcPort, dst, dstPort, size, id, details) } diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index 4c14f55d3..bfac358f4 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -76,29 +76,13 @@ func (d *Device) Release(ctx context.Context) { } } -// NICID returns the NIC ID of the device. -// -// Must only be called after the device has been attached to an endpoint. -func (d *Device) NICID() tcpip.NICID { - d.mu.RLock() - defer d.mu.RUnlock() - - if d.endpoint == nil { - panic("called NICID on a device that has not been attached") - } - - return d.endpoint.nicID -} - // SetIff services TUNSETIFF ioctl(2) request. -// -// Returns true if a new NIC was created; false if an existing one was attached. -func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) (bool, error) { +func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error { d.mu.Lock() defer d.mu.Unlock() if d.endpoint != nil { - return false, syserror.EINVAL + return syserror.EINVAL } // Input validations. @@ -106,7 +90,7 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) (bool, error) isTap := flags&linux.IFF_TAP != 0 supportedFlags := uint16(linux.IFF_TUN | linux.IFF_TAP | linux.IFF_NO_PI) if isTap && isTun || !isTap && !isTun || flags&^supportedFlags != 0 { - return false, syserror.EINVAL + return syserror.EINVAL } prefix := "tun" @@ -119,18 +103,18 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) (bool, error) linkCaps |= stack.CapabilityResolutionRequired } - endpoint, created, err := attachOrCreateNIC(s, name, prefix, linkCaps) + endpoint, err := attachOrCreateNIC(s, name, prefix, linkCaps) if err != nil { - return false, syserror.EINVAL + return syserror.EINVAL } d.endpoint = endpoint d.notifyHandle = d.endpoint.AddNotify(d) d.flags = flags - return created, nil + return nil } -func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkEndpointCapabilities) (*tunEndpoint, bool, error) { +func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkEndpointCapabilities) (*tunEndpoint, error) { for { // 1. Try to attach to an existing NIC. if name != "" { @@ -138,13 +122,13 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE endpoint, ok := linkEP.(*tunEndpoint) if !ok { // Not a NIC created by tun device. - return nil, false, syserror.EOPNOTSUPP + return nil, syserror.EOPNOTSUPP } if !endpoint.TryIncRef() { // Race detected: NIC got deleted in between. continue } - return endpoint, false, nil + return endpoint, nil } } @@ -167,12 +151,12 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE }) switch err { case nil: - return endpoint, true, nil + return endpoint, nil case tcpip.ErrDuplicateNICID: // Race detected: A NIC has been created in between. continue default: - return nil, false, syserror.EINVAL + return nil, syserror.EINVAL } } } @@ -280,7 +264,7 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) { // If the packet does not already have link layer header, and the route // does not exist, we can't compute it. This is possibly a raw packet, tun // device doesn't support this at the moment. - if info.Pkt.LinkHeader().View().IsEmpty() && info.Route.RemoteLinkAddress == "" { + if info.Pkt.LinkHeader().View().IsEmpty() && len(info.Route.RemoteLinkAddress) == 0 { return nil, false } diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD index ee84c3d96..9b4602c1b 100644 --- a/pkg/tcpip/link/waitable/BUILD +++ b/pkg/tcpip/link/waitable/BUILD @@ -11,7 +11,6 @@ go_library( deps = [ "//pkg/gate", "//pkg/tcpip", - "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/stack", ], @@ -25,7 +24,6 @@ go_test( library = ":waitable", deps = [ "//pkg/tcpip", - "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/stack", ], diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index b152a0f26..cf0077f43 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -24,7 +24,6 @@ package waitable import ( "gvisor.dev/gvisor/pkg/gate" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -132,17 +131,6 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return n, err } -// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - if !e.writeGate.Enter() { - return nil - } - - err := e.lower.WriteRawPacket(vv) - e.writeGate.Leave() - return err -} - // WaitWrite prevents new calls to WritePacket from reaching the lower endpoint, // and waits for inflight ones to finish before returning. func (e *Endpoint) WaitWrite() { diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index 94827fc56..cf7fb5126 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -18,7 +18,6 @@ import ( "testing" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -81,11 +80,6 @@ func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts stack. return pkts.Len(), nil } -func (e *countedEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { - e.writeCount++ - return nil -} - // ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. func (*countedEndpoint) ARPHardwareType() header.ARPHardwareType { panic("unimplemented") diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD index c118a2929..9ebf31b78 100644 --- a/pkg/tcpip/network/BUILD +++ b/pkg/tcpip/network/BUILD @@ -7,13 +7,16 @@ go_test( size = "small", srcs = [ "ip_test.go", + "multicast_group_test.go", ], deps = [ "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", + "//pkg/tcpip/header/parse", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", "//pkg/tcpip/network/ipv4", diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 33a4a0720..3d5c0d270 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -31,17 +31,15 @@ import ( const ( // ProtocolNumber is the ARP protocol number. ProtocolNumber = header.ARPProtocolNumber - - // ProtocolAddress is the address expected by the ARP endpoint. - ProtocolAddress = tcpip.Address("arp") ) -var _ stack.AddressableEndpoint = (*endpoint)(nil) +// ARP endpoints need to implement stack.NetworkEndpoint because the stack +// considers the layer above the link-layer a network layer; the only +// facility provided by the stack to deliver packets to a layer above +// the link-layer is via stack.NetworkEndpoint.HandlePacket. var _ stack.NetworkEndpoint = (*endpoint)(nil) type endpoint struct { - stack.AddressableEndpointState - protocol *protocol // enabled is set to 1 when the NIC is enabled and 0 when it is disabled. @@ -87,7 +85,7 @@ func (e *endpoint) Disable() { } // DefaultTTL is unused for ARP. It implements stack.NetworkEndpoint. -func (e *endpoint) DefaultTTL() uint8 { +func (*endpoint) DefaultTTL() uint8 { return 0 } @@ -100,25 +98,23 @@ func (e *endpoint) MaxHeaderLength() uint16 { return e.nic.MaxHeaderLength() + header.ARPSize } -func (e *endpoint) Close() { - e.AddressableEndpointState.Cleanup() -} +func (*endpoint) Close() {} -func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) *tcpip.Error { +func (*endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } // NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber. -func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { +func (*endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return ProtocolNumber } // WritePackets implements stack.NetworkEndpoint.WritePackets. -func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, stack.NetworkHeaderParams) (int, *tcpip.Error) { +func (*endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, stack.NetworkHeaderParams) (int, *tcpip.Error) { return 0, tcpip.ErrNotSupported } -func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { +func (*endpoint) WriteHeaderIncludedPacket(*stack.Route, *stack.PacketBuffer) *tcpip.Error { return tcpip.ErrNotSupported } @@ -216,9 +212,8 @@ func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber func (p *protocol) MinimumPacketSize() int { return header.ARPSize } func (p *protocol) DefaultPrefixLen() int { return 0 } -func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { - h := header.ARP(v) - return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress +func (*protocol) ParseAddresses(buffer.View) (src, dst tcpip.Address) { + return "", "" } func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher) stack.NetworkEndpoint { @@ -228,7 +223,6 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L linkAddrCache: linkAddrCache, nud: nud, } - e.AddressableEndpointState.Init(e) return e } @@ -311,10 +305,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu } // NewProtocol returns an ARP network protocol. -// -// Note, to make sure that the ARP endpoint receives ARP packets, the "arp" -// address must be added to every NIC that should respond to ARP requests. See -// ProtocolAddress for more details. func NewProtocol(s *stack.Stack) stack.NetworkProtocol { return &protocol{stack: s} } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 087ee9c66..a25cba513 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -200,9 +200,6 @@ func newTestContext(t *testing.T, useNeighborCache bool) *testContext { t.Fatalf("AddAddress for ipv4 failed: %v", err) } } - if err := s.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { - t.Fatalf("AddAddress for arp failed: %v", err) - } s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv4EmptySubnet, @@ -322,9 +319,9 @@ func TestDirectRequestWithNeighborCache(t *testing.T) { copy(h.HardwareAddressSender(), test.senderLinkAddr) copy(h.ProtocolAddressSender(), test.senderAddr) copy(h.ProtocolAddressTarget(), test.targetAddr) - c.linkEP.InjectInbound(arp.ProtocolNumber, &stack.PacketBuffer{ + c.linkEP.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: v.ToVectorisedView(), - }) + })) if !test.isValid { // No packets should be sent after receiving an invalid ARP request. @@ -439,11 +436,14 @@ func (*testInterface) Enabled() bool { return true } +func (*testInterface) Promiscuous() bool { + return false +} + func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - r := stack.Route{ - NetProto: protocol, - RemoteLinkAddress: remoteLinkAddr, - } + var r stack.Route + r.NetProto = protocol + r.ResolveWith(remoteLinkAddr) return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt) } diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD index 47fb63290..429af69ee 100644 --- a/pkg/tcpip/network/fragmentation/BUILD +++ b/pkg/tcpip/network/fragmentation/BUILD @@ -18,7 +18,6 @@ go_template_instance( go_library( name = "fragmentation", srcs = [ - "frag_heap.go", "fragmentation.go", "reassembler.go", "reassembler_list.go", @@ -38,7 +37,6 @@ go_test( name = "fragmentation_test", size = "small", srcs = [ - "frag_heap_test.go", "fragmentation_test.go", "reassembler_test.go", ], @@ -47,6 +45,7 @@ go_test( "//pkg/tcpip/buffer", "//pkg/tcpip/faketime", "//pkg/tcpip/network/testutil", + "//pkg/tcpip/stack", "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/network/fragmentation/frag_heap.go b/pkg/tcpip/network/fragmentation/frag_heap.go deleted file mode 100644 index 0b570d25a..000000000 --- a/pkg/tcpip/network/fragmentation/frag_heap.go +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fragmentation - -import ( - "container/heap" - "fmt" - - "gvisor.dev/gvisor/pkg/tcpip/buffer" -) - -type fragment struct { - offset uint16 - vv buffer.VectorisedView -} - -type fragHeap []fragment - -func (h *fragHeap) Len() int { - return len(*h) -} - -func (h *fragHeap) Less(i, j int) bool { - return (*h)[i].offset < (*h)[j].offset -} - -func (h *fragHeap) Swap(i, j int) { - (*h)[i], (*h)[j] = (*h)[j], (*h)[i] -} - -func (h *fragHeap) Push(x interface{}) { - *h = append(*h, x.(fragment)) -} - -func (h *fragHeap) Pop() interface{} { - old := *h - n := len(old) - x := old[n-1] - *h = old[:n-1] - return x -} - -// reassamble empties the heap and returns a VectorisedView -// containing a reassambled version of the fragments inside the heap. -func (h *fragHeap) reassemble() (buffer.VectorisedView, error) { - curr := heap.Pop(h).(fragment) - views := curr.vv.Views() - size := curr.vv.Size() - - if curr.offset != 0 { - return buffer.VectorisedView{}, fmt.Errorf("offset of the first packet is != 0 (%d)", curr.offset) - } - - for h.Len() > 0 { - curr := heap.Pop(h).(fragment) - if int(curr.offset) < size { - curr.vv.TrimFront(size - int(curr.offset)) - } else if int(curr.offset) > size { - return buffer.VectorisedView{}, fmt.Errorf("packet has a hole, expected offset %d, got %d", size, curr.offset) - } - size += curr.vv.Size() - views = append(views, curr.vv.Views()...) - } - return buffer.NewVectorisedView(size, views), nil -} diff --git a/pkg/tcpip/network/fragmentation/frag_heap_test.go b/pkg/tcpip/network/fragmentation/frag_heap_test.go deleted file mode 100644 index 9ececcb9f..000000000 --- a/pkg/tcpip/network/fragmentation/frag_heap_test.go +++ /dev/null @@ -1,126 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fragmentation - -import ( - "container/heap" - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip/buffer" -) - -var reassambleTestCases = []struct { - comment string - in []fragment - want buffer.VectorisedView -}{ - { - comment: "Non-overlapping in-order", - in: []fragment{ - {offset: 0, vv: vv(1, "0")}, - {offset: 1, vv: vv(1, "1")}, - }, - want: vv(2, "0", "1"), - }, - { - comment: "Non-overlapping out-of-order", - in: []fragment{ - {offset: 1, vv: vv(1, "1")}, - {offset: 0, vv: vv(1, "0")}, - }, - want: vv(2, "0", "1"), - }, - { - comment: "Duplicated packets", - in: []fragment{ - {offset: 0, vv: vv(1, "0")}, - {offset: 0, vv: vv(1, "0")}, - }, - want: vv(1, "0"), - }, - { - comment: "Overlapping in-order", - in: []fragment{ - {offset: 0, vv: vv(2, "01")}, - {offset: 1, vv: vv(2, "12")}, - }, - want: vv(3, "01", "2"), - }, - { - comment: "Overlapping out-of-order", - in: []fragment{ - {offset: 1, vv: vv(2, "12")}, - {offset: 0, vv: vv(2, "01")}, - }, - want: vv(3, "01", "2"), - }, - { - comment: "Overlapping subset in-order", - in: []fragment{ - {offset: 0, vv: vv(3, "012")}, - {offset: 1, vv: vv(1, "1")}, - }, - want: vv(3, "012"), - }, - { - comment: "Overlapping subset out-of-order", - in: []fragment{ - {offset: 1, vv: vv(1, "1")}, - {offset: 0, vv: vv(3, "012")}, - }, - want: vv(3, "012"), - }, -} - -func TestReassamble(t *testing.T) { - for _, c := range reassambleTestCases { - t.Run(c.comment, func(t *testing.T) { - h := make(fragHeap, 0, 8) - heap.Init(&h) - for _, f := range c.in { - heap.Push(&h, f) - } - got, err := h.reassemble() - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(got, c.want) { - t.Errorf("got reassemble(%+v) = %v, want = %v", c.in, got, c.want) - } - }) - } -} - -func TestReassambleFailsForNonZeroOffset(t *testing.T) { - h := make(fragHeap, 0, 8) - heap.Init(&h) - heap.Push(&h, fragment{offset: 1, vv: vv(1, "0")}) - _, err := h.reassemble() - if err == nil { - t.Errorf("reassemble() did not fail when the first packet had offset != 0") - } -} - -func TestReassambleFailsForHoles(t *testing.T) { - h := make(fragHeap, 0, 8) - heap.Init(&h) - heap.Push(&h, fragment{offset: 0, vv: vv(1, "0")}) - heap.Push(&h, fragment{offset: 2, vv: vv(1, "1")}) - _, err := h.reassemble() - if err == nil { - t.Errorf("reassemble() did not fail when there was a hole in the packet") - } -} diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go index 936601287..1af87d713 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/fragmentation/fragmentation.go @@ -46,9 +46,17 @@ const ( ) var ( - // ErrInvalidArgs indicates to the caller that that an invalid argument was + // ErrInvalidArgs indicates to the caller that an invalid argument was // provided. ErrInvalidArgs = errors.New("invalid args") + + // ErrFragmentOverlap indicates that, during reassembly, a fragment overlaps + // with another one. + ErrFragmentOverlap = errors.New("overlapping fragments") + + // ErrFragmentConflict indicates that, during reassembly, some fragments are + // in conflict with one another. + ErrFragmentConflict = errors.New("conflicting fragments") ) // FragmentID is the identifier for a fragment. @@ -71,16 +79,25 @@ type FragmentID struct { // Fragmentation is the main structure that other modules // of the stack should use to implement IP Fragmentation. type Fragmentation struct { - mu sync.Mutex - highLimit int - lowLimit int - reassemblers map[FragmentID]*reassembler - rList reassemblerList - size int - timeout time.Duration - blockSize uint16 - clock tcpip.Clock - releaseJob *tcpip.Job + mu sync.Mutex + highLimit int + lowLimit int + reassemblers map[FragmentID]*reassembler + rList reassemblerList + size int + timeout time.Duration + blockSize uint16 + clock tcpip.Clock + releaseJob *tcpip.Job + timeoutHandler TimeoutHandler +} + +// TimeoutHandler is consulted if a packet reassembly has timed out. +type TimeoutHandler interface { + // OnReassemblyTimeout will be called with the first fragment (or nil, if the + // first fragment has not been received) of a packet whose reassembly has + // timed out. + OnReassemblyTimeout(pkt *stack.PacketBuffer) } // NewFragmentation creates a new Fragmentation. @@ -97,7 +114,7 @@ type Fragmentation struct { // reassemblingTimeout specifies the maximum time allowed to reassemble a packet. // Fragments are lazily evicted only when a new a packet with an // already existing fragmentation-id arrives after the timeout. -func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration, clock tcpip.Clock) *Fragmentation { +func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration, clock tcpip.Clock, timeoutHandler TimeoutHandler) *Fragmentation { if lowMemoryLimit >= highMemoryLimit { lowMemoryLimit = highMemoryLimit } @@ -111,12 +128,13 @@ func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, rea } f := &Fragmentation{ - reassemblers: make(map[FragmentID]*reassembler), - highLimit: highMemoryLimit, - lowLimit: lowMemoryLimit, - timeout: reassemblingTimeout, - blockSize: blockSize, - clock: clock, + reassemblers: make(map[FragmentID]*reassembler), + highLimit: highMemoryLimit, + lowLimit: lowMemoryLimit, + timeout: reassemblingTimeout, + blockSize: blockSize, + clock: clock, + timeoutHandler: timeoutHandler, } f.releaseJob = tcpip.NewJob(f.clock, &f.mu, f.releaseReassemblersLocked) @@ -136,16 +154,8 @@ func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, rea // proto is the protocol number marked in the fragment being processed. It has // to be given here outside of the FragmentID struct because IPv6 should not use // the protocol to identify a fragment. -// -// releaseCB is a callback that will run when the fragment reassembly of a -// packet is complete or cancelled. releaseCB take a a boolean argument which is -// true iff the reassembly is cancelled due to timeout. releaseCB should be -// passed only with the first fragment of a packet. If more than one releaseCB -// are passed for the same packet, only the first releaseCB will be saved for -// the packet and the succeeding ones will be dropped by running them -// immediately with a false argument. func (f *Fragmentation) Process( - id FragmentID, first, last uint16, more bool, proto uint8, vv buffer.VectorisedView, releaseCB func(bool)) ( + id FragmentID, first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) ( buffer.VectorisedView, uint8, bool, error) { if first > last { return buffer.VectorisedView{}, 0, false, fmt.Errorf("first=%d is greater than last=%d: %w", first, last, ErrInvalidArgs) @@ -160,10 +170,9 @@ func (f *Fragmentation) Process( return buffer.VectorisedView{}, 0, false, fmt.Errorf("fragment size=%d bytes is not a multiple of block size=%d on non-final fragment: %w", fragmentSize, f.blockSize, ErrInvalidArgs) } - if l := vv.Size(); l < int(fragmentSize) { - return buffer.VectorisedView{}, 0, false, fmt.Errorf("got fragment size=%d bytes less than the expected fragment size=%d bytes (first=%d last=%d): %w", l, fragmentSize, first, last, ErrInvalidArgs) + if l := pkt.Data.Size(); l != int(fragmentSize) { + return buffer.VectorisedView{}, 0, false, fmt.Errorf("got fragment size=%d bytes not equal to the expected fragment size=%d bytes (first=%d last=%d): %w", l, fragmentSize, first, last, ErrInvalidArgs) } - vv.CapLength(int(fragmentSize)) f.mu.Lock() r, ok := f.reassemblers[id] @@ -179,15 +188,9 @@ func (f *Fragmentation) Process( f.releaseReassemblersLocked() } } - if releaseCB != nil { - if !r.setCallback(releaseCB) { - // We got a duplicate callback. Release it immediately. - releaseCB(false /* timedOut */) - } - } f.mu.Unlock() - res, firstFragmentProto, done, consumed, err := r.process(first, last, more, proto, vv) + res, firstFragmentProto, done, consumed, err := r.process(first, last, more, proto, pkt) if err != nil { // We probably got an invalid sequence of fragments. Just // discard the reassembler and move on. @@ -231,7 +234,9 @@ func (f *Fragmentation) release(r *reassembler, timedOut bool) { f.size = 0 } - r.release(timedOut) // releaseCB may run. + if h := f.timeoutHandler; timedOut && h != nil { + h.OnReassemblyTimeout(r.pkt) + } } // releaseReassemblersLocked releases already-expired reassemblers, then diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go index 5dcd10730..3a79688a8 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation_test.go +++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/network/testutil" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) // reassembleTimeout is dummy timeout used for testing, where the clock never @@ -40,13 +41,19 @@ func vv(size int, pieces ...string) buffer.VectorisedView { return buffer.NewVectorisedView(size, views) } +func pkt(size int, pieces ...string) *stack.PacketBuffer { + return stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv(size, pieces...), + }) +} + type processInput struct { id FragmentID first uint16 last uint16 more bool proto uint8 - vv buffer.VectorisedView + pkt *stack.PacketBuffer } type processOutput struct { @@ -63,8 +70,8 @@ var processTestCases = []struct { { comment: "One ID", in: []processInput{ - {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, vv: vv(2, "01")}, - {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, vv: vv(2, "23")}, + {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, pkt: pkt(2, "01")}, + {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, pkt: pkt(2, "23")}, }, out: []processOutput{ {vv: buffer.VectorisedView{}, done: false}, @@ -74,8 +81,8 @@ var processTestCases = []struct { { comment: "Next Header protocol mismatch", in: []processInput{ - {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, proto: 6, vv: vv(2, "01")}, - {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, proto: 17, vv: vv(2, "23")}, + {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, proto: 6, pkt: pkt(2, "01")}, + {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, proto: 17, pkt: pkt(2, "23")}, }, out: []processOutput{ {vv: buffer.VectorisedView{}, done: false}, @@ -85,10 +92,10 @@ var processTestCases = []struct { { comment: "Two IDs", in: []processInput{ - {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, vv: vv(2, "01")}, - {id: FragmentID{ID: 1}, first: 0, last: 1, more: true, vv: vv(2, "ab")}, - {id: FragmentID{ID: 1}, first: 2, last: 3, more: false, vv: vv(2, "cd")}, - {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, vv: vv(2, "23")}, + {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, pkt: pkt(2, "01")}, + {id: FragmentID{ID: 1}, first: 0, last: 1, more: true, pkt: pkt(2, "ab")}, + {id: FragmentID{ID: 1}, first: 2, last: 3, more: false, pkt: pkt(2, "cd")}, + {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, pkt: pkt(2, "23")}, }, out: []processOutput{ {vv: buffer.VectorisedView{}, done: false}, @@ -102,17 +109,17 @@ var processTestCases = []struct { func TestFragmentationProcess(t *testing.T) { for _, c := range processTestCases { t.Run(c.comment, func(t *testing.T) { - f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{}) + f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{}, nil) firstFragmentProto := c.in[0].proto for i, in := range c.in { - vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.vv, nil) + vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.pkt) if err != nil { - t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %X) failed: %s", - in.id, in.first, in.last, in.more, in.proto, in.vv.ToView(), err) + t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %#v) failed: %s", + in.id, in.first, in.last, in.more, in.proto, in.pkt, err) } if !reflect.DeepEqual(vv, c.out[i].vv) { - t.Errorf("got Process(%+v, %d, %d, %t, %d, %X) = (%X, _, _, _), want = (%X, _, _, _)", - in.id, in.first, in.last, in.more, in.proto, in.vv.ToView(), vv.ToView(), c.out[i].vv.ToView()) + t.Errorf("got Process(%+v, %d, %d, %t, %d, %#v) = (%X, _, _, _), want = (%X, _, _, _)", + in.id, in.first, in.last, in.more, in.proto, in.pkt, vv.ToView(), c.out[i].vv.ToView()) } if done != c.out[i].done { t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, _, %t, _), want = (_, _, %t, _)", @@ -236,11 +243,11 @@ func TestReassemblingTimeout(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { clock := faketime.NewManualClock() - f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassemblyTimeout, clock) + f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassemblyTimeout, clock, nil) for _, event := range test.events { clock.Advance(event.clockAdvance) if frag := event.fragment; frag != nil { - _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, vv(len(frag.data), frag.data), nil) + _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, pkt(len(frag.data), frag.data)) if err != nil { t.Fatalf("%s: f.Process failed: %s", event.name, err) } @@ -257,17 +264,17 @@ func TestReassemblingTimeout(t *testing.T) { } func TestMemoryLimits(t *testing.T) { - f := NewFragmentation(minBlockSize, 3, 1, reassembleTimeout, &faketime.NullClock{}) + f := NewFragmentation(minBlockSize, 3, 1, reassembleTimeout, &faketime.NullClock{}, nil) // Send first fragment with id = 0. - f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, vv(1, "0"), nil) + f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, pkt(1, "0")) // Send first fragment with id = 1. - f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, vv(1, "1"), nil) + f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, pkt(1, "1")) // Send first fragment with id = 2. - f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, vv(1, "2"), nil) + f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, pkt(1, "2")) // Send first fragment with id = 3. This should caused id = 0 and id = 1 to be // evicted. - f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, vv(1, "3"), nil) + f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, pkt(1, "3")) if _, ok := f.reassemblers[FragmentID{ID: 0}]; ok { t.Errorf("Memory limits are not respected: id=0 has not been evicted.") @@ -281,11 +288,11 @@ func TestMemoryLimits(t *testing.T) { } func TestMemoryLimitsIgnoresDuplicates(t *testing.T) { - f := NewFragmentation(minBlockSize, 1, 0, reassembleTimeout, &faketime.NullClock{}) + f := NewFragmentation(minBlockSize, 1, 0, reassembleTimeout, &faketime.NullClock{}, nil) // Send first fragment with id = 0. - f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"), nil) + f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")) // Send the same packet again. - f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"), nil) + f.Process(FragmentID{}, 0, 0, true, 0xFF, pkt(1, "0")) got := f.size want := 1 @@ -327,6 +334,7 @@ func TestErrors(t *testing.T) { last: 3, more: true, data: "012", + err: ErrInvalidArgs, }, { name: "exact block size with more and too little data", @@ -376,8 +384,8 @@ func TestErrors(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}) - _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, vv(len(test.data), test.data), nil) + f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}, nil) + _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, pkt(len(test.data), test.data)) if !errors.Is(err, test.err) { t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, _, %v), want = (_, _, _, %v)", test.first, test.last, test.more, test.data, err, test.err) } @@ -498,57 +506,92 @@ func TestPacketFragmenter(t *testing.T) { } } -func TestReleaseCallback(t *testing.T) { +type testTimeoutHandler struct { + pkt *stack.PacketBuffer +} + +func (h *testTimeoutHandler) OnReassemblyTimeout(pkt *stack.PacketBuffer) { + h.pkt = pkt +} + +func TestTimeoutHandler(t *testing.T) { const ( proto = 99 ) - var result int - var callbackReasonIsTimeout bool - cb1 := func(timedOut bool) { result = 1; callbackReasonIsTimeout = timedOut } - cb2 := func(timedOut bool) { result = 2; callbackReasonIsTimeout = timedOut } + pk1 := pkt(1, "1") + pk2 := pkt(1, "2") + + type processParam struct { + first uint16 + last uint16 + more bool + pkt *stack.PacketBuffer + } tests := []struct { - name string - callbacks []func(bool) - timeout bool - wantResult int - wantCallbackReasonIsTimeout bool + name string + params []processParam + wantError bool + wantPkt *stack.PacketBuffer }{ { - name: "callback runs on release", - callbacks: []func(bool){cb1}, - timeout: false, - wantResult: 1, - wantCallbackReasonIsTimeout: false, - }, - { - name: "first callback is nil", - callbacks: []func(bool){nil, cb2}, - timeout: false, - wantResult: 2, - wantCallbackReasonIsTimeout: false, + name: "onTimeout runs", + params: []processParam{ + { + first: 0, + last: 0, + more: true, + pkt: pk1, + }, + }, + wantError: false, + wantPkt: pk1, }, { - name: "two callbacks - first one is set", - callbacks: []func(bool){cb1, cb2}, - timeout: false, - wantResult: 1, - wantCallbackReasonIsTimeout: false, + name: "no first fragment", + params: []processParam{ + { + first: 1, + last: 1, + more: true, + pkt: pk1, + }, + }, + wantError: false, + wantPkt: nil, }, { - name: "callback runs on timeout", - callbacks: []func(bool){cb1}, - timeout: true, - wantResult: 1, - wantCallbackReasonIsTimeout: true, + name: "second pkt is ignored", + params: []processParam{ + { + first: 0, + last: 0, + more: true, + pkt: pk1, + }, + { + first: 0, + last: 0, + more: true, + pkt: pk2, + }, + }, + wantError: false, + wantPkt: pk1, }, { - name: "no callbacks", - callbacks: []func(bool){nil}, - timeout: false, - wantResult: 0, - wantCallbackReasonIsTimeout: false, + name: "invalid args - first is greater than last", + params: []processParam{ + { + first: 1, + last: 0, + more: true, + pkt: pk1, + }, + }, + wantError: true, + wantPkt: nil, }, } @@ -556,29 +599,31 @@ func TestReleaseCallback(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - result = 0 - callbackReasonIsTimeout = false + handler := &testTimeoutHandler{pkt: nil} - f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}) + f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}, handler) - for i, cb := range test.callbacks { - _, _, _, err := f.Process(id, uint16(i), uint16(i), true, proto, vv(1, "0"), cb) - if err != nil { + for _, p := range test.params { + if _, _, _, err := f.Process(id, p.first, p.last, p.more, proto, p.pkt); err != nil && !test.wantError { t.Errorf("f.Process error = %s", err) } } - - r, ok := f.reassemblers[id] - if !ok { - t.Fatalf("Reassemberr not found") - } - f.release(r, test.timeout) - - if result != test.wantResult { - t.Errorf("got result = %d, want = %d", result, test.wantResult) + if !test.wantError { + r, ok := f.reassemblers[id] + if !ok { + t.Fatal("Reassembler not found") + } + f.release(r, true) } - if callbackReasonIsTimeout != test.wantCallbackReasonIsTimeout { - t.Errorf("got callbackReasonIsTimeout = %t, want = %t", callbackReasonIsTimeout, test.wantCallbackReasonIsTimeout) + switch { + case handler.pkt != nil && test.wantPkt == nil: + t.Errorf("got handler.pkt = not nil (pkt.Data = %x), want = nil", handler.pkt.Data.ToView()) + case handler.pkt == nil && test.wantPkt != nil: + t.Errorf("got handler.pkt = nil, want = not nil (pkt.Data = %x)", test.wantPkt.Data.ToView()) + case handler.pkt != nil && test.wantPkt != nil: + if diff := cmp.Diff(test.wantPkt.Data.ToView(), handler.pkt.Data.ToView()); diff != "" { + t.Errorf("pkt.Data mismatch (-want, +got):\n%s", diff) + } } }) } diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go index c0cc0bde0..9b20bb1d8 100644 --- a/pkg/tcpip/network/fragmentation/reassembler.go +++ b/pkg/tcpip/network/fragmentation/reassembler.go @@ -15,19 +15,21 @@ package fragmentation import ( - "container/heap" - "fmt" "math" + "sort" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) type hole struct { - first uint16 - last uint16 - deleted bool + first uint16 + last uint16 + filled bool + final bool + data buffer.View } type reassembler struct { @@ -37,84 +39,139 @@ type reassembler struct { proto uint8 mu sync.Mutex holes []hole - deleted int - heap fragHeap + filled int done bool creationTime int64 - callback func(bool) + pkt *stack.PacketBuffer } func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler { r := &reassembler{ id: id, - holes: make([]hole, 0, 16), - heap: make(fragHeap, 0, 8), creationTime: clock.NowMonotonic(), } r.holes = append(r.holes, hole{ - first: 0, - last: math.MaxUint16, - deleted: false}) + first: 0, + last: math.MaxUint16, + filled: false, + final: true, + }) return r } -// updateHoles updates the list of holes for an incoming fragment and -// returns true iff the fragment filled at least part of an existing hole. -func (r *reassembler) updateHoles(first, last uint16, more bool) bool { - used := false - for i := range r.holes { - if r.holes[i].deleted || first > r.holes[i].last || last < r.holes[i].first { - continue - } - used = true - r.deleted++ - r.holes[i].deleted = true - if first > r.holes[i].first { - r.holes = append(r.holes, hole{r.holes[i].first, first - 1, false}) - } - if last < r.holes[i].last && more { - r.holes = append(r.holes, hole{last + 1, r.holes[i].last, false}) - } - } - return used -} - -func (r *reassembler) process(first, last uint16, more bool, proto uint8, vv buffer.VectorisedView) (buffer.VectorisedView, uint8, bool, int, error) { +func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) (buffer.VectorisedView, uint8, bool, int, error) { r.mu.Lock() defer r.mu.Unlock() - consumed := 0 if r.done { // A concurrent goroutine might have already reassembled // the packet and emptied the heap while this goroutine // was waiting on the mutex. We don't have to do anything in this case. - return buffer.VectorisedView{}, 0, false, consumed, nil + return buffer.VectorisedView{}, 0, false, 0, nil } - // For IPv6, it is possible to have different Protocol values between - // fragments of a packet (because, unlike IPv4, the Protocol is not used to - // identify a fragment). In this case, only the Protocol of the first - // fragment must be used as per RFC 8200 Section 4.5. - // - // TODO(gvisor.dev/issue/3648): The entire first IP header should be recorded - // here (instead of just the protocol) because most IP options should be - // derived from the first fragment. - if first == 0 { - r.proto = proto - } - if r.updateHoles(first, last, more) { - // We store the incoming packet only if it filled some holes. - heap.Push(&r.heap, fragment{offset: first, vv: vv.Clone(nil)}) - consumed = vv.Size() + + var holeFound bool + var consumed int + for i := range r.holes { + currentHole := &r.holes[i] + + if last < currentHole.first || currentHole.last < first { + continue + } + // For IPv6, overlaps with an existing fragment are explicitly forbidden by + // RFC 8200 section 4.5: + // If any of the fragments being reassembled overlap with any other + // fragments being reassembled for the same packet, reassembly of that + // packet must be abandoned and all the fragments that have been received + // for that packet must be discarded, and no ICMP error messages should be + // sent. + // + // It is not explicitly forbidden for IPv4, but to keep parity with Linux we + // disallow it as well: + // https://github.com/torvalds/linux/blob/38525c6/net/ipv4/inet_fragment.c#L349 + if first < currentHole.first || currentHole.last < last { + // Incoming fragment only partially fits in the free hole. + return buffer.VectorisedView{}, 0, false, 0, ErrFragmentOverlap + } + if !more { + if !currentHole.final || currentHole.filled && currentHole.last != last { + // We have another final fragment, which does not perfectly overlap. + return buffer.VectorisedView{}, 0, false, 0, ErrFragmentConflict + } + } + + holeFound = true + if currentHole.filled { + // Incoming fragment is a duplicate. + continue + } + + // We are populating the current hole with the payload and creating a new + // hole for any unfilled ranges on either end. + if first > currentHole.first { + r.holes = append(r.holes, hole{ + first: currentHole.first, + last: first - 1, + filled: false, + final: false, + }) + } + if last < currentHole.last && more { + r.holes = append(r.holes, hole{ + first: last + 1, + last: currentHole.last, + filled: false, + final: currentHole.final, + }) + currentHole.final = false + } + v := pkt.Data.ToOwnedView() + consumed = v.Size() r.size += consumed + // Update the current hole to precisely match the incoming fragment. + r.holes[i] = hole{ + first: first, + last: last, + filled: true, + final: currentHole.final, + data: v, + } + r.filled++ + // For IPv6, it is possible to have different Protocol values between + // fragments of a packet (because, unlike IPv4, the Protocol is not used to + // identify a fragment). In this case, only the Protocol of the first + // fragment must be used as per RFC 8200 Section 4.5. + // + // TODO(gvisor.dev/issue/3648): During reassembly of an IPv6 packet, IP + // options received in the first fragment should be used - and they should + // override options from following fragments. + if first == 0 { + r.pkt = pkt + r.proto = proto + } + + break } - // Check if all the holes have been deleted and we are ready to reassamble. - if r.deleted < len(r.holes) { + if !holeFound { + // Incoming fragment is beyond end. + return buffer.VectorisedView{}, 0, false, 0, ErrFragmentConflict + } + + // Check if all the holes have been filled and we are ready to reassemble. + if r.filled < len(r.holes) { return buffer.VectorisedView{}, 0, false, consumed, nil } - res, err := r.heap.reassemble() - if err != nil { - return buffer.VectorisedView{}, 0, false, consumed, fmt.Errorf("fragment reassembly failed: %w", err) + + sort.Slice(r.holes, func(i, j int) bool { + return r.holes[i].first < r.holes[j].first + }) + + var size int + views := make([]buffer.View, 0, len(r.holes)) + for _, hole := range r.holes { + views = append(views, hole.data) + size += hole.data.Size() } - return res, r.proto, true, consumed, nil + return buffer.NewVectorisedView(size, views), r.proto, true, consumed, nil } func (r *reassembler) checkDoneOrMark() bool { @@ -124,24 +181,3 @@ func (r *reassembler) checkDoneOrMark() bool { r.mu.Unlock() return prev } - -func (r *reassembler) setCallback(c func(bool)) bool { - r.mu.Lock() - defer r.mu.Unlock() - if r.callback != nil { - return false - } - r.callback = c - return true -} - -func (r *reassembler) release(timedOut bool) { - r.mu.Lock() - callback := r.callback - r.callback = nil - r.mu.Unlock() - - if callback != nil { - callback(timedOut) - } -} diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go index fa2a70dc8..2ff03eeeb 100644 --- a/pkg/tcpip/network/fragmentation/reassembler_test.go +++ b/pkg/tcpip/network/fragmentation/reassembler_test.go @@ -16,115 +16,175 @@ package fragmentation import ( "math" - "reflect" "testing" + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/faketime" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) -type updateHolesInput struct { - first uint16 - last uint16 - more bool +type processParams struct { + first uint16 + last uint16 + more bool + pkt *stack.PacketBuffer + wantDone bool + wantError error } -var holesTestCases = []struct { - comment string - in []updateHolesInput - want []hole -}{ - { - comment: "No fragments. Expected holes: {[0 -> inf]}.", - in: []updateHolesInput{}, - want: []hole{{first: 0, last: math.MaxUint16, deleted: false}}, - }, - { - comment: "One fragment at beginning. Expected holes: {[2, inf]}.", - in: []updateHolesInput{{first: 0, last: 1, more: true}}, - want: []hole{ - {first: 0, last: math.MaxUint16, deleted: true}, - {first: 2, last: math.MaxUint16, deleted: false}, +func TestReassemblerProcess(t *testing.T) { + const proto = 99 + + v := func(size int) buffer.View { + payload := buffer.NewView(size) + for i := 1; i < size; i++ { + payload[i] = uint8(i) * 3 + } + return payload + } + + pkt := func(size int) *stack.PacketBuffer { + return stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: v(size).ToVectorisedView(), + }) + } + + var tests = []struct { + name string + params []processParams + want []hole + }{ + { + name: "No fragments", + params: nil, + want: []hole{{first: 0, last: math.MaxUint16, filled: false, final: true}}, }, - }, - { - comment: "One fragment in the middle. Expected holes: {[0, 0], [3, inf]}.", - in: []updateHolesInput{{first: 1, last: 2, more: true}}, - want: []hole{ - {first: 0, last: math.MaxUint16, deleted: true}, - {first: 0, last: 0, deleted: false}, - {first: 3, last: math.MaxUint16, deleted: false}, + { + name: "One fragment at beginning", + params: []processParams{{first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}}, + want: []hole{ + {first: 0, last: 1, filled: true, final: false, data: v(2)}, + {first: 2, last: math.MaxUint16, filled: false, final: true}, + }, }, - }, - { - comment: "One fragment at the end. Expected holes: {[0, 0]}.", - in: []updateHolesInput{{first: 1, last: 2, more: false}}, - want: []hole{ - {first: 0, last: math.MaxUint16, deleted: true}, - {first: 0, last: 0, deleted: false}, + { + name: "One fragment in the middle", + params: []processParams{{first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil}}, + want: []hole{ + {first: 1, last: 2, filled: true, final: false, data: v(2)}, + {first: 0, last: 0, filled: false, final: false}, + {first: 3, last: math.MaxUint16, filled: false, final: true}, + }, }, - }, - { - comment: "One fragment completing a packet. Expected holes: {}.", - in: []updateHolesInput{{first: 0, last: 1, more: false}}, - want: []hole{ - {first: 0, last: math.MaxUint16, deleted: true}, + { + name: "One fragment at the end", + params: []processParams{{first: 1, last: 2, more: false, pkt: pkt(2), wantDone: false, wantError: nil}}, + want: []hole{ + {first: 1, last: 2, filled: true, final: true, data: v(2)}, + {first: 0, last: 0, filled: false}, + }, }, - }, - { - comment: "Two non-overlapping fragments completing a packet. Expected holes: {}.", - in: []updateHolesInput{ - {first: 0, last: 1, more: true}, - {first: 2, last: 3, more: false}, + { + name: "One fragment completing a packet", + params: []processParams{{first: 0, last: 1, more: false, pkt: pkt(2), wantDone: true, wantError: nil}}, + want: []hole{ + {first: 0, last: 1, filled: true, final: true, data: v(2)}, + }, }, - want: []hole{ - {first: 0, last: math.MaxUint16, deleted: true}, - {first: 2, last: math.MaxUint16, deleted: true}, + { + name: "Two fragments completing a packet", + params: []processParams{ + {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}, + {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil}, + }, + want: []hole{ + {first: 0, last: 1, filled: true, final: false, data: v(2)}, + {first: 2, last: 3, filled: true, final: true, data: v(2)}, + }, }, - }, - { - comment: "Two overlapping fragments completing a packet. Expected holes: {}.", - in: []updateHolesInput{ - {first: 0, last: 2, more: true}, - {first: 2, last: 3, more: false}, + { + name: "Two fragments completing a packet with a duplicate", + params: []processParams{ + {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}, + {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}, + {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil}, + }, + want: []hole{ + {first: 0, last: 1, filled: true, final: false, data: v(2)}, + {first: 2, last: 3, filled: true, final: true, data: v(2)}, + }, }, - want: []hole{ - {first: 0, last: math.MaxUint16, deleted: true}, - {first: 3, last: math.MaxUint16, deleted: true}, + { + name: "Two fragments completing a packet with a partial duplicate", + params: []processParams{ + {first: 0, last: 3, more: true, pkt: pkt(4), wantDone: false, wantError: nil}, + {first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil}, + {first: 4, last: 5, more: false, pkt: pkt(2), wantDone: true, wantError: nil}, + }, + want: []hole{ + {first: 0, last: 3, filled: true, final: false, data: v(4)}, + {first: 4, last: 5, filled: true, final: true, data: v(2)}, + }, + }, + { + name: "Two overlapping fragments", + params: []processParams{ + {first: 0, last: 10, more: true, pkt: pkt(11), wantDone: false, wantError: nil}, + {first: 5, last: 15, more: false, pkt: pkt(11), wantDone: false, wantError: ErrFragmentOverlap}, + }, + want: []hole{ + {first: 0, last: 10, filled: true, final: false, data: v(11)}, + {first: 11, last: math.MaxUint16, filled: false, final: true}, + }, + }, + { + name: "Two final fragments with different ends", + params: []processParams{ + {first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil}, + {first: 0, last: 9, more: false, pkt: pkt(10), wantDone: false, wantError: ErrFragmentConflict}, + }, + want: []hole{ + {first: 10, last: 14, filled: true, final: true, data: v(5)}, + {first: 0, last: 9, filled: false, final: false}, + }, + }, + { + name: "Two final fragments - duplicate", + params: []processParams{ + {first: 5, last: 14, more: false, pkt: pkt(10), wantDone: false, wantError: nil}, + {first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil}, + }, + want: []hole{ + {first: 5, last: 14, filled: true, final: true, data: v(10)}, + {first: 0, last: 4, filled: false, final: false}, + }, + }, + { + name: "Two final fragments - duplicate, with different ends", + params: []processParams{ + {first: 5, last: 14, more: false, pkt: pkt(10), wantDone: false, wantError: nil}, + {first: 10, last: 13, more: false, pkt: pkt(4), wantDone: false, wantError: ErrFragmentConflict}, + }, + want: []hole{ + {first: 5, last: 14, filled: true, final: true, data: v(10)}, + {first: 0, last: 4, filled: false, final: false}, + }, }, - }, -} - -func TestUpdateHoles(t *testing.T) { - for _, c := range holesTestCases { - r := newReassembler(FragmentID{}, &faketime.NullClock{}) - for _, i := range c.in { - r.updateHoles(i.first, i.last, i.more) - } - if !reflect.DeepEqual(r.holes, c.want) { - t.Errorf("Test \"%s\" produced unexepetced holes. Got %v. Want %v", c.comment, r.holes, c.want) - } } -} -func TestSetCallback(t *testing.T) { - result := 0 - reasonTimeout := false - - cb1 := func(timedOut bool) { result = 1; reasonTimeout = timedOut } - cb2 := func(timedOut bool) { result = 2; reasonTimeout = timedOut } - - r := newReassembler(FragmentID{}, &faketime.NullClock{}) - if !r.setCallback(cb1) { - t.Errorf("setCallback failed") - } - if r.setCallback(cb2) { - t.Errorf("setCallback should fail if one is already set") - } - r.release(true) - if result != 1 { - t.Errorf("got result = %d, want = 1", result) - } - if !reasonTimeout { - t.Errorf("got reasonTimeout = %t, want = true", reasonTimeout) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + r := newReassembler(FragmentID{}, &faketime.NullClock{}) + for _, param := range test.params { + _, _, done, _, err := r.process(param.first, param.last, param.more, proto, param.pkt) + if done != param.wantDone || err != param.wantError { + t.Errorf("got r.process(%d, %d, %t, %d, _) = (_, _, %t, _, %v), want = (%t, %v)", param.first, param.last, param.more, proto, done, err, param.wantDone, param.wantError) + } + } + if diff := cmp.Diff(test.want, r.holes, cmp.AllowUnexported(hole{})); diff != "" { + t.Errorf("r.holes mismatch (-want +got):\n%s", diff) + } + }) } } diff --git a/pkg/tcpip/network/ip/BUILD b/pkg/tcpip/network/ip/BUILD new file mode 100644 index 000000000..ca1247c1e --- /dev/null +++ b/pkg/tcpip/network/ip/BUILD @@ -0,0 +1,26 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "ip", + srcs = ["generic_multicast_protocol.go"], + visibility = ["//visibility:public"], + deps = [ + "//pkg/sync", + "//pkg/tcpip", + ], +) + +go_test( + name = "ip_test", + size = "small", + srcs = ["generic_multicast_protocol_test.go"], + deps = [ + ":ip", + "//pkg/sync", + "//pkg/tcpip", + "//pkg/tcpip/faketime", + "@com_github_google_go_cmp//cmp:go_default_library", + ], +) diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol.go b/pkg/tcpip/network/ip/generic_multicast_protocol.go new file mode 100644 index 000000000..f2f0e069c --- /dev/null +++ b/pkg/tcpip/network/ip/generic_multicast_protocol.go @@ -0,0 +1,676 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package ip holds IPv4/IPv6 common utilities. +package ip + +import ( + "fmt" + "math/rand" + "time" + + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" +) + +// hostState is the state a host may be in for a multicast group. +type hostState int + +// The states below are generic across IGMPv2 (RFC 2236 section 6) and MLDv1 +// (RFC 2710 section 5). Even though the states are generic across both IGMPv2 +// and MLDv1, IGMPv2 terminology will be used. +// +// ______________receive query______________ +// | | +// | _____send or receive report_____ | +// | | | | +// V | V | +// +-------+ +-----------+ +------------+ +-------------------+ +--------+ | +// | Non-M | | Pending-M | | Delaying-M | | Queued Delaying-M | | Idle-M | - +// +-------+ +-----------+ +------------+ +-------------------+ +--------+ +// | ^ | ^ | ^ | ^ +// | | | | | | | | +// ---------- ------- ---------- ------------- +// initialize new send inital fail to send send or receive +// group membership report delayed report report +// +// Not shown in the diagram above, but any state may transition into the non +// member state when a group is left. +const ( + // nonMember is the "'Non-Member' state, when the host does not belong to the + // group on the interface. This is the initial state for all memberships on + // all network interfaces; it requires no storage in the host." + // + // 'Non-Listener' is the MLDv1 term used to describe this state. + // + // This state is used to keep track of groups that have been joined locally, + // but without advertising the membership to the network. + nonMember hostState = iota + + // pendingMember is a newly joined member that is waiting to successfully send + // the initial set of reports. + // + // This is not an RFC defined state; it is an implementation specific state to + // track that the initial report needs to be sent. + // + // MAY NOT transition to the idle member state from this state. + pendingMember + + // delayingMember is the "'Delaying Member' state, when the host belongs to + // the group on the interface and has a report delay timer running for that + // membership." + // + // 'Delaying Listener' is the MLDv1 term used to describe this state. + delayingMember + + // queuedDelayingMember is a delayingMember that failed to send a report after + // its delayed report timer fired. Hosts in this state are waiting to attempt + // retransmission of the delayed report. + // + // This is not an RFC defined state; it is an implementation specific state to + // track that the delayed report needs to be sent. + // + // May transition to idle member if a report is received for a group. + queuedDelayingMember + + // idleMember is the "Idle Member" state, when the host belongs to the group + // on the interface and does not have a report delay timer running for that + // membership. + // + // 'Idle Listener' is the MLDv1 term used to describe this state. + idleMember +) + +func (s hostState) isDelayingMember() bool { + switch s { + case nonMember, pendingMember, idleMember: + return false + case delayingMember, queuedDelayingMember: + return true + default: + panic(fmt.Sprintf("unrecognized host state = %d", s)) + } +} + +// multicastGroupState holds the Generic Multicast Protocol state for a +// multicast group. +type multicastGroupState struct { + // joins is the number of times the group has been joined. + joins uint64 + + // state holds the host's state for the group. + state hostState + + // lastToSendReport is true if we sent the last report for the group. It is + // used to track whether there are other hosts on the subnet that are also + // members of the group. + // + // Defined in RFC 2236 section 6 page 9 for IGMPv2 and RFC 2710 section 5 page + // 8 for MLDv1. + lastToSendReport bool + + // delayedReportJob is used to delay sending responses to membership report + // messages in order to reduce duplicate reports from multiple hosts on the + // interface. + // + // Must not be nil. + delayedReportJob *tcpip.Job +} + +// GenericMulticastProtocolOptions holds options for the generic multicast +// protocol. +type GenericMulticastProtocolOptions struct { + // Rand is the source of random numbers. + Rand *rand.Rand + + // Clock is the clock used to create timers. + Clock tcpip.Clock + + // Protocol is the implementation of the variant of multicast group protocol + // in use. + Protocol MulticastGroupProtocol + + // MaxUnsolicitedReportDelay is the maximum amount of time to wait between + // transmitting unsolicited reports. + // + // Unsolicited reports are transmitted when a group is newly joined. + MaxUnsolicitedReportDelay time.Duration + + // AllNodesAddress is a multicast address that all nodes on a network should + // be a member of. + // + // This address will not have the generic multicast protocol performed on it; + // it will be left in the non member/listener state, and packets will never + // be sent for it. + AllNodesAddress tcpip.Address +} + +// MulticastGroupProtocol is a multicast group protocol whose core state machine +// can be represented by GenericMulticastProtocolState. +type MulticastGroupProtocol interface { + // Enabled indicates whether the generic multicast protocol will be + // performed. + // + // When enabled, the protocol may transmit report and leave messages when + // joining and leaving multicast groups respectively, and handle incoming + // packets. + // + // When disabled, the protocol will still keep track of locally joined groups, + // it just won't transmit and handle packets, or update groups' state. + Enabled() bool + + // SendReport sends a multicast report for the specified group address. + // + // Returns false if the caller should queue the report to be sent later. Note, + // returning false does not mean that the receiver hit an error. + SendReport(groupAddress tcpip.Address) (sent bool, err *tcpip.Error) + + // SendLeave sends a multicast leave for the specified group address. + SendLeave(groupAddress tcpip.Address) *tcpip.Error +} + +// GenericMulticastProtocolState is the per interface generic multicast protocol +// state. +// +// There is actually no protocol named "Generic Multicast Protocol". Instead, +// the term used to refer to a generic multicast protocol that applies to both +// IPv4 and IPv6. Specifically, Generic Multicast Protocol is the core state +// machine of IGMPv2 as defined by RFC 2236 and MLDv1 as defined by RFC 2710. +// +// Callers must synchronize accesses to the generic multicast protocol state; +// GenericMulticastProtocolState obtains no locks in any of its methods. The +// only exception to this is GenericMulticastProtocolState's timer/job callbacks +// which will obtain the lock provided to the GenericMulticastProtocolState when +// it is initialized. +// +// GenericMulticastProtocolState.Init MUST be called before calling any of +// the methods on GenericMulticastProtocolState. +// +// GenericMulticastProtocolState.MakeAllNonMemberLocked MUST be called when the +// multicast group protocol is disabled so that leave messages may be sent. +type GenericMulticastProtocolState struct { + // Do not allow overwriting this state. + _ sync.NoCopy + + opts GenericMulticastProtocolOptions + + // memberships holds group addresses and their associated state. + memberships map[tcpip.Address]multicastGroupState + + // protocolMU is the mutex used to protect the protocol. + protocolMU *sync.RWMutex +} + +// Init initializes the Generic Multicast Protocol state. +// +// Must only be called once for the lifetime of g; Init will panic if it is +// called twice. +// +// The GenericMulticastProtocolState will only grab the lock when timers/jobs +// fire. +// +// Note: the methods on opts.Protocol will always be called while protocolMU is +// held. +func (g *GenericMulticastProtocolState) Init(protocolMU *sync.RWMutex, opts GenericMulticastProtocolOptions) { + if g.memberships != nil { + panic("attempted to initialize generic membership protocol state twice") + } + + *g = GenericMulticastProtocolState{ + opts: opts, + memberships: make(map[tcpip.Address]multicastGroupState), + protocolMU: protocolMU, + } +} + +// MakeAllNonMemberLocked transitions all groups to the non-member state. +// +// The groups will still be considered joined locally. +// +// MUST be called when the multicast group protocol is disabled. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) MakeAllNonMemberLocked() { + if !g.opts.Protocol.Enabled() { + return + } + + for groupAddress, info := range g.memberships { + g.transitionToNonMemberLocked(groupAddress, &info) + g.memberships[groupAddress] = info + } +} + +// InitializeGroupsLocked initializes each group, as if they were newly joined +// but without affecting the groups' join count. +// +// Must only be called after calling MakeAllNonMember as a group should not be +// initialized while it is not in the non-member state. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) InitializeGroupsLocked() { + if !g.opts.Protocol.Enabled() { + return + } + + for groupAddress, info := range g.memberships { + g.initializeNewMemberLocked(groupAddress, &info) + g.memberships[groupAddress] = info + } +} + +// SendQueuedReportsLocked attempts to send reports for groups that failed to +// send reports during their last attempt. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) SendQueuedReportsLocked() { + for groupAddress, info := range g.memberships { + switch info.state { + case nonMember, delayingMember, idleMember: + case pendingMember: + // pendingMembers failed to send their initial unsolicited report so try + // to send the report and queue the extra unsolicited reports. + g.maybeSendInitialReportLocked(groupAddress, &info) + case queuedDelayingMember: + // queuedDelayingMembers failed to send their delayed reports so try to + // send the report and transition them to the idle state. + g.maybeSendDelayedReportLocked(groupAddress, &info) + default: + panic(fmt.Sprintf("unrecognized host state = %d", info.state)) + } + g.memberships[groupAddress] = info + } +} + +// JoinGroupLocked handles joining a new group. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) JoinGroupLocked(groupAddress tcpip.Address) { + if info, ok := g.memberships[groupAddress]; ok { + // The group has already been joined. + info.joins++ + g.memberships[groupAddress] = info + return + } + + info := multicastGroupState{ + // Since we just joined the group, its count is 1. + joins: 1, + // The state will be updated below, if required. + state: nonMember, + lastToSendReport: false, + delayedReportJob: tcpip.NewJob(g.opts.Clock, g.protocolMU, func() { + if !g.opts.Protocol.Enabled() { + panic(fmt.Sprintf("delayed report job fired for group %s while the multicast group protocol is disabled", groupAddress)) + } + + info, ok := g.memberships[groupAddress] + if !ok { + panic(fmt.Sprintf("expected to find group state for group = %s", groupAddress)) + } + + g.maybeSendDelayedReportLocked(groupAddress, &info) + g.memberships[groupAddress] = info + }), + } + + if g.opts.Protocol.Enabled() { + g.initializeNewMemberLocked(groupAddress, &info) + } + + g.memberships[groupAddress] = info +} + +// IsLocallyJoinedRLocked returns true if the group is locally joined. +// +// Precondition: g.protocolMU must be read locked. +func (g *GenericMulticastProtocolState) IsLocallyJoinedRLocked(groupAddress tcpip.Address) bool { + _, ok := g.memberships[groupAddress] + return ok +} + +// LeaveGroupLocked handles leaving the group. +// +// Returns false if the group is not currently joined. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) LeaveGroupLocked(groupAddress tcpip.Address) bool { + info, ok := g.memberships[groupAddress] + if !ok { + return false + } + + if info.joins == 0 { + panic(fmt.Sprintf("tried to leave group %s with a join count of 0", groupAddress)) + } + info.joins-- + if info.joins != 0 { + // If we still have outstanding joins, then do nothing further. + g.memberships[groupAddress] = info + return true + } + + g.transitionToNonMemberLocked(groupAddress, &info) + delete(g.memberships, groupAddress) + return true +} + +// HandleQueryLocked handles a query message with the specified maximum response +// time. +// +// If the group address is unspecified, then reports will be scheduled for all +// joined groups. +// +// Report(s) will be scheduled to be sent after a random duration between 0 and +// the maximum response time. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) HandleQueryLocked(groupAddress tcpip.Address, maxResponseTime time.Duration) { + if !g.opts.Protocol.Enabled() { + return + } + + // As per RFC 2236 section 2.4 (for IGMPv2), + // + // In a Membership Query message, the group address field is set to zero + // when sending a General Query, and set to the group address being + // queried when sending a Group-Specific Query. + // + // As per RFC 2710 section 3.6 (for MLDv1), + // + // In a Query message, the Multicast Address field is set to zero when + // sending a General Query, and set to a specific IPv6 multicast address + // when sending a Multicast-Address-Specific Query. + if groupAddress.Unspecified() { + // This is a general query as the group address is unspecified. + for groupAddress, info := range g.memberships { + g.setDelayTimerForAddressRLocked(groupAddress, &info, maxResponseTime) + g.memberships[groupAddress] = info + } + } else if info, ok := g.memberships[groupAddress]; ok { + g.setDelayTimerForAddressRLocked(groupAddress, &info, maxResponseTime) + g.memberships[groupAddress] = info + } +} + +// HandleReportLocked handles a report message. +// +// If the report is for a joined group, any active delayed report will be +// cancelled and the host state for the group transitions to idle. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) HandleReportLocked(groupAddress tcpip.Address) { + if !g.opts.Protocol.Enabled() { + return + } + + // As per RFC 2236 section 3 pages 3-4 (for IGMPv2), + // + // If the host receives another host's Report (version 1 or 2) while it has + // a timer running, it stops its timer for the specified group and does not + // send a Report + // + // As per RFC 2710 section 4 page 6 (for MLDv1), + // + // If a node receives another node's Report from an interface for a + // multicast address while it has a timer running for that same address + // on that interface, it stops its timer and does not send a Report for + // that address, thus suppressing duplicate reports on the link. + if info, ok := g.memberships[groupAddress]; ok && info.state.isDelayingMember() { + info.delayedReportJob.Cancel() + info.lastToSendReport = false + info.state = idleMember + g.memberships[groupAddress] = info + } +} + +// initializeNewMemberLocked initializes a new group membership. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) { + if info.state != nonMember { + panic(fmt.Sprintf("host must be in non-member state to be initialized; group = %s, state = %d", groupAddress, info.state)) + } + + info.lastToSendReport = false + + if groupAddress == g.opts.AllNodesAddress { + // As per RFC 2236 section 6 page 10 (for IGMPv2), + // + // The all-systems group (address 224.0.0.1) is handled as a special + // case. The host starts in Idle Member state for that group on every + // interface, never transitions to another state, and never sends a + // report for that group. + // + // As per RFC 2710 section 5 page 10 (for MLDv1), + // + // The link-scope all-nodes address (FF02::1) is handled as a special + // case. The node starts in Idle Listener state for that address on + // every interface, never transitions to another state, and never sends + // a Report or Done for that address. + info.state = idleMember + return + } + + info.state = pendingMember + g.maybeSendInitialReportLocked(groupAddress, info) +} + +// maybeSendInitialReportLocked attempts to start transmission of the initial +// set of reports after newly joining a group. +// +// Host must be in pending member state. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) maybeSendInitialReportLocked(groupAddress tcpip.Address, info *multicastGroupState) { + if info.state != pendingMember { + panic(fmt.Sprintf("host must be in pending member state to send initial reports; group = %s, state = %d", groupAddress, info.state)) + } + + // As per RFC 2236 section 3 page 5 (for IGMPv2), + // + // When a host joins a multicast group, it should immediately transmit an + // unsolicited Version 2 Membership Report for that group" ... "it is + // recommended that it be repeated". + // + // As per RFC 2710 section 4 page 6 (for MLDv1), + // + // When a node starts listening to a multicast address on an interface, + // it should immediately transmit an unsolicited Report for that address + // on that interface, in case it is the first listener on the link. To + // cover the possibility of the initial Report being lost or damaged, it + // is recommended that it be repeated once or twice after short delays + // [Unsolicited Report Interval]. + // + // TODO(gvisor.dev/issue/4901): Support a configurable number of initial + // unsolicited reports. + sent, err := g.opts.Protocol.SendReport(groupAddress) + if err == nil && sent { + info.lastToSendReport = true + g.setDelayTimerForAddressRLocked(groupAddress, info, g.opts.MaxUnsolicitedReportDelay) + } +} + +// maybeSendDelayedReportLocked attempts to send the delayed report. +// +// Host must be in pending, delaying or queued delaying member state. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) maybeSendDelayedReportLocked(groupAddress tcpip.Address, info *multicastGroupState) { + if !info.state.isDelayingMember() { + panic(fmt.Sprintf("host must be in delaying or queued delaying member state to send delayed reports; group = %s, state = %d", groupAddress, info.state)) + } + + sent, err := g.opts.Protocol.SendReport(groupAddress) + if err == nil && sent { + info.lastToSendReport = true + info.state = idleMember + } else { + info.state = queuedDelayingMember + } +} + +// maybeSendLeave attempts to send a leave message. +func (g *GenericMulticastProtocolState) maybeSendLeave(groupAddress tcpip.Address, lastToSendReport bool) { + if !g.opts.Protocol.Enabled() || !lastToSendReport { + return + } + + if groupAddress == g.opts.AllNodesAddress { + // As per RFC 2236 section 6 page 10 (for IGMPv2), + // + // The all-systems group (address 224.0.0.1) is handled as a special + // case. The host starts in Idle Member state for that group on every + // interface, never transitions to another state, and never sends a + // report for that group. + // + // As per RFC 2710 section 5 page 10 (for MLDv1), + // + // The link-scope all-nodes address (FF02::1) is handled as a special + // case. The node starts in Idle Listener state for that address on + // every interface, never transitions to another state, and never sends + // a Report or Done for that address. + return + } + + // Okay to ignore the error here as if packet write failed, the multicast + // routers will eventually drop our membership anyways. If the interface is + // being disabled or removed, the generic multicast protocol's should be + // cleared eventually. + // + // As per RFC 2236 section 3 page 5 (for IGMPv2), + // + // When a router receives a Report, it adds the group being reported to + // the list of multicast group memberships on the network on which it + // received the Report and sets the timer for the membership to the + // [Group Membership Interval]. Repeated Reports refresh the timer. If + // no Reports are received for a particular group before this timer has + // expired, the router assumes that the group has no local members and + // that it need not forward remotely-originated multicasts for that + // group onto the attached network. + // + // As per RFC 2710 section 4 page 5 (for MLDv1), + // + // When a router receives a Report from a link, if the reported address + // is not already present in the router's list of multicast address + // having listeners on that link, the reported address is added to the + // list, its timer is set to [Multicast Listener Interval], and its + // appearance is made known to the router's multicast routing component. + // If a Report is received for a multicast address that is already + // present in the router's list, the timer for that address is reset to + // [Multicast Listener Interval]. If an address's timer expires, it is + // assumed that there are no longer any listeners for that address + // present on the link, so it is deleted from the list and its + // disappearance is made known to the multicast routing component. + // + // The requirement to send a leave message is also optional (it MAY be + // skipped): + // + // As per RFC 2236 section 6 page 8 (for IGMPv2), + // + // "send leave" for the group on the interface. If the interface + // state says the Querier is running IGMPv1, this action SHOULD be + // skipped. If the flag saying we were the last host to report is + // cleared, this action MAY be skipped. The Leave Message is sent to + // the ALL-ROUTERS group (224.0.0.2). + // + // As per RFC 2710 section 5 page 8 (for MLDv1), + // + // "send done" for the address on the interface. If the flag saying + // we were the last node to report is cleared, this action MAY be + // skipped. The Done message is sent to the link-scope all-routers + // address (FF02::2). + _ = g.opts.Protocol.SendLeave(groupAddress) +} + +// transitionToNonMemberLocked transitions the given multicast group the the +// non-member/listener state. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) transitionToNonMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) { + if info.state == nonMember { + return + } + + info.delayedReportJob.Cancel() + g.maybeSendLeave(groupAddress, info.lastToSendReport) + info.lastToSendReport = false + info.state = nonMember +} + +// setDelayTimerForAddressRLocked sets timer to send a delay report. +// +// Precondition: g.protocolMU MUST be read locked. +func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddress tcpip.Address, info *multicastGroupState, maxResponseTime time.Duration) { + if info.state == nonMember { + return + } + + if groupAddress == g.opts.AllNodesAddress { + // As per RFC 2236 section 6 page 10 (for IGMPv2), + // + // The all-systems group (address 224.0.0.1) is handled as a special + // case. The host starts in Idle Member state for that group on every + // interface, never transitions to another state, and never sends a + // report for that group. + // + // As per RFC 2710 section 5 page 10 (for MLDv1), + // + // The link-scope all-nodes address (FF02::1) is handled as a special + // case. The node starts in Idle Listener state for that address on + // every interface, never transitions to another state, and never sends + // a Report or Done for that address. + return + } + + // As per RFC 2236 section 3 page 3 (for IGMPv2), + // + // If a timer for the group is already unning, it is reset to the random + // value only if the requested Max Response Time is less than the remaining + // value of the running timer. + // + // As per RFC 2710 section 4 page 5 (for MLDv1), + // + // If a timer for any address is already running, it is reset to the new + // random value only if the requested Maximum Response Delay is less than + // the remaining value of the running timer. + if info.state == delayingMember { + // TODO: Reset the timer if time remaining is greater than maxResponseTime. + return + } + + info.state = delayingMember + info.delayedReportJob.Cancel() + info.delayedReportJob.Schedule(g.calculateDelayTimerDuration(maxResponseTime)) +} + +// calculateDelayTimerDuration returns a random time between (0, maxRespTime]. +func (g *GenericMulticastProtocolState) calculateDelayTimerDuration(maxRespTime time.Duration) time.Duration { + // As per RFC 2236 section 3 page 3 (for IGMPv2), + // + // When a host receives a Group-Specific Query, it sets a delay timer to a + // random value selected from the range (0, Max Response Time]... + // + // As per RFC 2710 section 4 page 6 (for MLDv1), + // + // When a node receives a Multicast-Address-Specific Query, if it is + // listening to the queried Multicast Address on the interface from + // which the Query was received, it sets a delay timer for that address + // to a random value selected from the range [0, Maximum Response Delay], + // as above. + if maxRespTime == 0 { + return 0 + } + return time.Duration(g.opts.Rand.Int63n(int64(maxRespTime))) +} diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go new file mode 100644 index 000000000..85593f211 --- /dev/null +++ b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go @@ -0,0 +1,806 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ip_test + +import ( + "math/rand" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/faketime" + "gvisor.dev/gvisor/pkg/tcpip/network/ip" +) + +const ( + addr1 = tcpip.Address("\x01") + addr2 = tcpip.Address("\x02") + addr3 = tcpip.Address("\x03") + addr4 = tcpip.Address("\x04") + + maxUnsolicitedReportDelay = time.Second +) + +var _ ip.MulticastGroupProtocol = (*mockMulticastGroupProtocol)(nil) + +type mockMulticastGroupProtocolProtectedFields struct { + sync.RWMutex + + genericMulticastGroup ip.GenericMulticastProtocolState + sendReportGroupAddrCount map[tcpip.Address]int + sendLeaveGroupAddrCount map[tcpip.Address]int + makeQueuePackets bool + disabled bool +} + +type mockMulticastGroupProtocol struct { + t *testing.T + + mu mockMulticastGroupProtocolProtectedFields +} + +func (m *mockMulticastGroupProtocol) init(opts ip.GenericMulticastProtocolOptions) { + m.mu.Lock() + defer m.mu.Unlock() + m.initLocked() + opts.Protocol = m + m.mu.genericMulticastGroup.Init(&m.mu.RWMutex, opts) +} + +func (m *mockMulticastGroupProtocol) initLocked() { + m.mu.sendReportGroupAddrCount = make(map[tcpip.Address]int) + m.mu.sendLeaveGroupAddrCount = make(map[tcpip.Address]int) +} + +func (m *mockMulticastGroupProtocol) setEnabled(v bool) { + m.mu.Lock() + defer m.mu.Unlock() + m.mu.disabled = !v +} + +func (m *mockMulticastGroupProtocol) setQueuePackets(v bool) { + m.mu.Lock() + defer m.mu.Unlock() + m.mu.makeQueuePackets = v +} + +func (m *mockMulticastGroupProtocol) joinGroup(addr tcpip.Address) { + m.mu.Lock() + defer m.mu.Unlock() + m.mu.genericMulticastGroup.JoinGroupLocked(addr) +} + +func (m *mockMulticastGroupProtocol) leaveGroup(addr tcpip.Address) bool { + m.mu.Lock() + defer m.mu.Unlock() + return m.mu.genericMulticastGroup.LeaveGroupLocked(addr) +} + +func (m *mockMulticastGroupProtocol) handleReport(addr tcpip.Address) { + m.mu.Lock() + defer m.mu.Unlock() + m.mu.genericMulticastGroup.HandleReportLocked(addr) +} + +func (m *mockMulticastGroupProtocol) handleQuery(addr tcpip.Address, maxRespTime time.Duration) { + m.mu.Lock() + defer m.mu.Unlock() + m.mu.genericMulticastGroup.HandleQueryLocked(addr, maxRespTime) +} + +func (m *mockMulticastGroupProtocol) isLocallyJoined(addr tcpip.Address) bool { + m.mu.RLock() + defer m.mu.RUnlock() + return m.mu.genericMulticastGroup.IsLocallyJoinedRLocked(addr) +} + +func (m *mockMulticastGroupProtocol) makeAllNonMember() { + m.mu.Lock() + defer m.mu.Unlock() + m.mu.genericMulticastGroup.MakeAllNonMemberLocked() +} + +func (m *mockMulticastGroupProtocol) initializeGroups() { + m.mu.Lock() + defer m.mu.Unlock() + m.mu.genericMulticastGroup.InitializeGroupsLocked() +} + +func (m *mockMulticastGroupProtocol) sendQueuedReports() { + m.mu.Lock() + defer m.mu.Unlock() + m.mu.genericMulticastGroup.SendQueuedReportsLocked() +} + +// Enabled implements ip.MulticastGroupProtocol. +// +// Precondition: m.mu must be read locked. +func (m *mockMulticastGroupProtocol) Enabled() bool { + if m.mu.TryLock() { + m.mu.Unlock() + m.t.Fatal("got write lock, expected to not take the lock; generic multicast protocol must take the read or write lock before calling Enabled") + } + + return !m.mu.disabled +} + +// SendReport implements ip.MulticastGroupProtocol. +// +// Precondition: m.mu must be locked. +func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) { + if m.mu.TryLock() { + m.mu.Unlock() + m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress) + } + if m.mu.TryRLock() { + m.mu.RUnlock() + m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress) + } + + m.mu.sendReportGroupAddrCount[groupAddress]++ + return !m.mu.makeQueuePackets, nil +} + +// SendLeave implements ip.MulticastGroupProtocol. +// +// Precondition: m.mu must be locked. +func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) *tcpip.Error { + if m.mu.TryLock() { + m.mu.Unlock() + m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress) + } + if m.mu.TryRLock() { + m.mu.RUnlock() + m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress) + } + + m.mu.sendLeaveGroupAddrCount[groupAddress]++ + return nil +} + +func (m *mockMulticastGroupProtocol) check(sendReportGroupAddresses []tcpip.Address, sendLeaveGroupAddresses []tcpip.Address) string { + m.mu.Lock() + defer m.mu.Unlock() + + sendReportGroupAddrCount := make(map[tcpip.Address]int) + for _, a := range sendReportGroupAddresses { + sendReportGroupAddrCount[a] = 1 + } + + sendLeaveGroupAddrCount := make(map[tcpip.Address]int) + for _, a := range sendLeaveGroupAddresses { + sendLeaveGroupAddrCount[a] = 1 + } + + diff := cmp.Diff( + &mockMulticastGroupProtocol{ + mu: mockMulticastGroupProtocolProtectedFields{ + sendReportGroupAddrCount: sendReportGroupAddrCount, + sendLeaveGroupAddrCount: sendLeaveGroupAddrCount, + }, + }, + m, + cmp.AllowUnexported(mockMulticastGroupProtocol{}), + cmp.AllowUnexported(mockMulticastGroupProtocolProtectedFields{}), + // ignore mockMulticastGroupProtocol.mu and mockMulticastGroupProtocol.t + cmp.FilterPath( + func(p cmp.Path) bool { + switch p.Last().String() { + case ".RWMutex", ".t", ".makeQueuePackets", ".disabled", ".genericMulticastGroup": + return true + } + return false + }, + cmp.Ignore(), + ), + ) + m.initLocked() + return diff +} + +func TestJoinGroup(t *testing.T) { + tests := []struct { + name string + addr tcpip.Address + shouldSendReports bool + }{ + { + name: "Normal group", + addr: addr1, + shouldSendReports: true, + }, + { + name: "All-nodes group", + addr: addr2, + shouldSendReports: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + mgp := mockMulticastGroupProtocol{t: t} + clock := faketime.NewManualClock() + + mgp.init(ip.GenericMulticastProtocolOptions{ + Rand: rand.New(rand.NewSource(0)), + Clock: clock, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + AllNodesAddress: addr2, + }) + + // Joining a group should send a report immediately and another after + // a random interval between 0 and the maximum unsolicited report delay. + mgp.joinGroup(test.addr) + if test.shouldSendReports { + if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Generic multicast protocol timers are expected to take the job mutex. + clock.Advance(maxUnsolicitedReportDelay) + if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + } + + // Should have no more messages to send. + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestLeaveGroup(t *testing.T) { + tests := []struct { + name string + addr tcpip.Address + shouldSendMessages bool + }{ + { + name: "Normal group", + addr: addr1, + shouldSendMessages: true, + }, + { + name: "All-nodes group", + addr: addr2, + shouldSendMessages: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + mgp := mockMulticastGroupProtocol{t: t} + clock := faketime.NewManualClock() + + mgp.init(ip.GenericMulticastProtocolOptions{ + Rand: rand.New(rand.NewSource(1)), + Clock: clock, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + AllNodesAddress: addr2, + }) + + mgp.joinGroup(test.addr) + if test.shouldSendMessages { + if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + } + + // Leaving a group should send a leave report immediately and cancel any + // delayed reports. + { + + if !mgp.leaveGroup(test.addr) { + t.Fatalf("got mgp.leaveGroup(%s) = false, want = true", test.addr) + } + } + if test.shouldSendMessages { + if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{test.addr} /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + } + + // Should have no more messages to send. + // + // Generic multicast protocol timers are expected to take the job mutex. + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestHandleReport(t *testing.T) { + tests := []struct { + name string + reportAddr tcpip.Address + expectReportsFor []tcpip.Address + }{ + { + name: "Unpecified empty", + reportAddr: "", + expectReportsFor: []tcpip.Address{addr1, addr2}, + }, + { + name: "Unpecified any", + reportAddr: "\x00", + expectReportsFor: []tcpip.Address{addr1, addr2}, + }, + { + name: "Specified", + reportAddr: addr1, + expectReportsFor: []tcpip.Address{addr2}, + }, + { + name: "Specified all-nodes", + reportAddr: addr3, + expectReportsFor: []tcpip.Address{addr1, addr2}, + }, + { + name: "Specified other", + reportAddr: addr4, + expectReportsFor: []tcpip.Address{addr1, addr2}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + mgp := mockMulticastGroupProtocol{t: t} + clock := faketime.NewManualClock() + + mgp.init(ip.GenericMulticastProtocolOptions{ + Rand: rand.New(rand.NewSource(2)), + Clock: clock, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + AllNodesAddress: addr3, + }) + + mgp.joinGroup(addr1) + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + mgp.joinGroup(addr2) + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + mgp.joinGroup(addr3) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Receiving a report for a group we have a timer scheduled for should + // cancel our delayed report timer for the group. + mgp.handleReport(test.reportAddr) + if len(test.expectReportsFor) != 0 { + // Generic multicast protocol timers are expected to take the job mutex. + clock.Advance(maxUnsolicitedReportDelay) + if diff := mgp.check(test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + } + + // Should have no more messages to send. + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestHandleQuery(t *testing.T) { + tests := []struct { + name string + queryAddr tcpip.Address + maxDelay time.Duration + expectReportsFor []tcpip.Address + }{ + { + name: "Unpecified empty", + queryAddr: "", + maxDelay: 0, + expectReportsFor: []tcpip.Address{addr1, addr2}, + }, + { + name: "Unpecified any", + queryAddr: "\x00", + maxDelay: 1, + expectReportsFor: []tcpip.Address{addr1, addr2}, + }, + { + name: "Specified", + queryAddr: addr1, + maxDelay: 2, + expectReportsFor: []tcpip.Address{addr1}, + }, + { + name: "Specified all-nodes", + queryAddr: addr3, + maxDelay: 3, + expectReportsFor: nil, + }, + { + name: "Specified other", + queryAddr: addr4, + maxDelay: 4, + expectReportsFor: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + mgp := mockMulticastGroupProtocol{t: t} + clock := faketime.NewManualClock() + + mgp.init(ip.GenericMulticastProtocolOptions{ + Rand: rand.New(rand.NewSource(3)), + Clock: clock, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + AllNodesAddress: addr3, + }) + + mgp.joinGroup(addr1) + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + mgp.joinGroup(addr2) + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + mgp.joinGroup(addr3) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + // Generic multicast protocol timers are expected to take the job mutex. + clock.Advance(maxUnsolicitedReportDelay) + if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Receiving a query should make us schedule a new delayed report if it + // is a query directed at us or a general query. + mgp.handleQuery(test.queryAddr, test.maxDelay) + if len(test.expectReportsFor) != 0 { + clock.Advance(test.maxDelay) + if diff := mgp.check(test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + } + + // Should have no more messages to send. + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestJoinCount(t *testing.T) { + mgp := mockMulticastGroupProtocol{t: t} + clock := faketime.NewManualClock() + + mgp.init(ip.GenericMulticastProtocolOptions{ + Rand: rand.New(rand.NewSource(4)), + Clock: clock, + MaxUnsolicitedReportDelay: time.Second, + }) + + // Set the join count to 2 for a group. + mgp.joinGroup(addr1) + if !mgp.isLocallyJoined(addr1) { + t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr1) + } + // Only the first join should trigger a report to be sent. + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + mgp.joinGroup(addr1) + if !mgp.isLocallyJoined(addr1) { + t.Errorf("got mgp.isLocallyJoined(%s) = false, want = true", addr1) + } + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + if t.Failed() { + t.FailNow() + } + + // Group should still be considered joined after leaving once. + if !mgp.leaveGroup(addr1) { + t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr1) + } + if !mgp.isLocallyJoined(addr1) { + t.Errorf("got mgp.isLocallyJoined(%s) = false, want = true", addr1) + } + // A leave report should only be sent once the join count reaches 0. + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + if t.Failed() { + t.FailNow() + } + + // Leaving once more should actually remove us from the group. + if !mgp.leaveGroup(addr1) { + t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr1) + } + if mgp.isLocallyJoined(addr1) { + t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr1) + } + if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1} /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + if t.Failed() { + t.FailNow() + } + + // Group should no longer be joined so we should not have anything to + // leave. + if mgp.leaveGroup(addr1) { + t.Errorf("got mgp.leaveGroup(%s) = true, want = false", addr1) + } + if mgp.isLocallyJoined(addr1) { + t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr1) + } + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Should have no more messages to send. + // + // Generic multicast protocol timers are expected to take the job mutex. + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } +} + +func TestMakeAllNonMemberAndInitialize(t *testing.T) { + mgp := mockMulticastGroupProtocol{t: t} + clock := faketime.NewManualClock() + + mgp.init(ip.GenericMulticastProtocolOptions{ + Rand: rand.New(rand.NewSource(3)), + Clock: clock, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + AllNodesAddress: addr3, + }) + + mgp.joinGroup(addr1) + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + mgp.joinGroup(addr2) + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + mgp.joinGroup(addr3) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Should send the leave reports for each but still consider them locally + // joined. + mgp.makeAllNonMember() + if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1, addr2} /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + // Generic multicast protocol timers are expected to take the job mutex. + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + for _, group := range []tcpip.Address{addr1, addr2, addr3} { + if !mgp.isLocallyJoined(group) { + t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", group) + } + } + + // Should send the initial set of unsolcited reports. + mgp.initializeGroups() + if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + clock.Advance(maxUnsolicitedReportDelay) + if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Should have no more messages to send. + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } +} + +// TestGroupStateNonMember tests that groups do not send packets when in the +// non-member state, but are still considered locally joined. +func TestGroupStateNonMember(t *testing.T) { + mgp := mockMulticastGroupProtocol{t: t} + clock := faketime.NewManualClock() + + mgp.init(ip.GenericMulticastProtocolOptions{ + Rand: rand.New(rand.NewSource(3)), + Clock: clock, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + }) + mgp.setEnabled(false) + + // Joining groups should not send any reports. + mgp.joinGroup(addr1) + if !mgp.isLocallyJoined(addr1) { + t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr1) + } + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + mgp.joinGroup(addr2) + if !mgp.isLocallyJoined(addr1) { + t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr2) + } + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Receiving a query should not send any reports. + mgp.handleQuery(addr1, time.Nanosecond) + // Generic multicast protocol timers are expected to take the job mutex. + clock.Advance(time.Nanosecond) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Leaving groups should not send any leave messages. + if !mgp.leaveGroup(addr1) { + t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr2) + } + if mgp.isLocallyJoined(addr1) { + t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr2) + } + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } +} + +func TestQueuedPackets(t *testing.T) { + clock := faketime.NewManualClock() + mgp := mockMulticastGroupProtocol{t: t} + mgp.init(ip.GenericMulticastProtocolOptions{ + Rand: rand.New(rand.NewSource(4)), + Clock: clock, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + }) + + // Joining should trigger a SendReport, but mgp should report that we did not + // send the packet. + mgp.setQueuePackets(true) + mgp.joinGroup(addr1) + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // The delayed report timer should have been cancelled since we did not send + // the initial report earlier. + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Mock being able to successfully send the report. + mgp.setQueuePackets(false) + mgp.sendQueuedReports() + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // The delayed report (sent after the initial report) should now be sent. + clock.Advance(maxUnsolicitedReportDelay) + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Should not have anything else to send (we should be idle). + mgp.sendQueuedReports() + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Receive a query but mock being unable to send reports again. + mgp.setQueuePackets(true) + mgp.handleQuery(addr1, time.Nanosecond) + clock.Advance(time.Nanosecond) + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Mock being able to send reports again - we should have a packet queued to + // send. + mgp.setQueuePackets(false) + mgp.sendQueuedReports() + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Should not have anything else to send. + mgp.sendQueuedReports() + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Receive a query again, but mock being unable to send reports. + mgp.setQueuePackets(true) + mgp.handleQuery(addr1, time.Nanosecond) + clock.Advance(time.Nanosecond) + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Receiving a report should should transition us into the idle member state, + // even if we had a packet queued. We should no longer have any packets to + // send. + mgp.handleReport(addr1) + mgp.sendQueuedReports() + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // When we fail to send the initial set of reports, incoming reports should + // not affect a newly joined group's reports from being sent. + mgp.setQueuePackets(true) + mgp.joinGroup(addr2) + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + mgp.handleReport(addr2) + // Attempting to send queued reports while still unable to send reports should + // not change the host state. + mgp.sendQueuedReports() + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + // Mock being able to successfully send the report. + mgp.setQueuePackets(false) + mgp.sendQueuedReports() + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + // The delayed report (sent after the initial report) should now be sent. + clock.Advance(maxUnsolicitedReportDelay) + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Should not have anything else to send. + mgp.sendQueuedReports() + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } +} diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index c7d26e14f..3005973d7 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" @@ -34,16 +35,16 @@ import ( ) const ( - localIPv4Addr = "\x0a\x00\x00\x01" - remoteIPv4Addr = "\x0a\x00\x00\x02" - ipv4SubnetAddr = "\x0a\x00\x00\x00" - ipv4SubnetMask = "\xff\xff\xff\x00" - ipv4Gateway = "\x0a\x00\x00\x03" - localIPv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - remoteIPv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" - ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00" - ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03" + localIPv4Addr = tcpip.Address("\x0a\x00\x00\x01") + remoteIPv4Addr = tcpip.Address("\x0a\x00\x00\x02") + ipv4SubnetAddr = tcpip.Address("\x0a\x00\x00\x00") + ipv4SubnetMask = tcpip.Address("\xff\xff\xff\x00") + ipv4Gateway = tcpip.Address("\x0a\x00\x00\x03") + localIPv6Addr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + remoteIPv6Addr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + ipv6SubnetAddr = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00") + ipv6SubnetMask = tcpip.Address("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00") + ipv6Gateway = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") nicID = 1 ) @@ -192,10 +193,6 @@ func (*testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBu panic("not implemented") } -func (*testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error { - return tcpip.ErrNotSupported -} - // ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. func (*testObject) ARPHardwareType() header.ARPHardwareType { panic("not implemented") @@ -206,7 +203,7 @@ func (*testObject) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.Net panic("not implemented") } -func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { +func buildIPv4Route(local, remote tcpip.Address) (*stack.Route, *tcpip.Error) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, @@ -222,7 +219,7 @@ func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { return s.FindRoute(nicID, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */) } -func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { +func buildIPv6Route(local, remote tcpip.Address) (*stack.Route, *tcpip.Error) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, @@ -299,6 +296,10 @@ func (t *testInterface) Enabled() bool { return !t.mu.disabled } +func (*testInterface) Promiscuous() bool { + return false +} + func (t *testInterface) setEnabled(v bool) { t.mu.Lock() defer t.mu.Unlock() @@ -343,11 +344,11 @@ func TestSourceAddressValidation(t *testing.T) { pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, localIPv6Addr, buffer.VectorisedView{})) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: header.ICMPv6MinimumSize, - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: ipv6.DefaultTTL, - SrcAddr: src, - DstAddr: localIPv6Addr, + PayloadLength: header.ICMPv6MinimumSize, + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: localIPv6Addr, }) e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), @@ -549,7 +550,7 @@ func TestIPv4Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{ + if err := ep.WritePacket(r, nil /* gso */, stack.NetworkHeaderParams{ Protocol: 123, TTL: 123, TOS: stack.DefaultTOS, @@ -558,59 +559,135 @@ func TestIPv4Send(t *testing.T) { } } -func TestIPv4Receive(t *testing.T) { - s := buildDummyStack(t) - proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) - nic := testInterface{ - testObject: testObject{ - t: t, - v4: true, +func TestReceive(t *testing.T) { + tests := []struct { + name string + protoFactory stack.NetworkProtocolFactory + protoNum tcpip.NetworkProtocolNumber + v4 bool + epAddr tcpip.AddressWithPrefix + handlePacket func(*testing.T, stack.NetworkEndpoint, *testInterface) + }{ + { + name: "IPv4", + protoFactory: ipv4.NewProtocol, + protoNum: ipv4.ProtocolNumber, + v4: true, + epAddr: localIPv4Addr.WithPrefix(), + handlePacket: func(t *testing.T, ep stack.NetworkEndpoint, nic *testInterface) { + const totalLen = header.IPv4MinimumSize + 30 /* payload length */ + + view := buffer.NewView(totalLen) + ip := header.IPv4(view) + ip.Encode(&header.IPv4Fields{ + TotalLength: totalLen, + TTL: ipv4.DefaultTTL, + Protocol: 10, + SrcAddr: remoteIPv4Addr, + DstAddr: localIPv4Addr, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + // Make payload be non-zero. + for i := header.IPv4MinimumSize; i < len(view); i++ { + view[i] = uint8(i) + } + + // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. + nic.testObject.protocol = 10 + nic.testObject.srcAddr = remoteIPv4Addr + nic.testObject.dstAddr = localIPv4Addr + nic.testObject.contents = view[header.IPv4MinimumSize:totalLen] + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: view.ToVectorisedView(), + }) + if ok := parse.IPv4(pkt); !ok { + t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) + } + ep.HandlePacket(pkt) + }, }, - } - ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject) - defer ep.Close() + { + name: "IPv6", + protoFactory: ipv6.NewProtocol, + protoNum: ipv6.ProtocolNumber, + v4: false, + epAddr: localIPv6Addr.WithPrefix(), + handlePacket: func(t *testing.T, ep stack.NetworkEndpoint, nic *testInterface) { + const payloadLen = 30 + view := buffer.NewView(header.IPv6MinimumSize + payloadLen) + ip := header.IPv6(view) + ip.Encode(&header.IPv6Fields{ + PayloadLength: payloadLen, + TransportProtocol: 10, + HopLimit: ipv6.DefaultTTL, + SrcAddr: remoteIPv6Addr, + DstAddr: localIPv6Addr, + }) - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } + // Make payload be non-zero. + for i := header.IPv6MinimumSize; i < len(view); i++ { + view[i] = uint8(i) + } - totalLen := header.IPv4MinimumSize + 30 - view := buffer.NewView(totalLen) - ip := header.IPv4(view) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(totalLen), - TTL: 20, - Protocol: 10, - SrcAddr: remoteIPv4Addr, - DstAddr: localIPv4Addr, - }) - ip.SetChecksum(^ip.CalculateChecksum()) + // Give packet to ipv6 endpoint, dispatcher will validate that it's ok. + nic.testObject.protocol = 10 + nic.testObject.srcAddr = remoteIPv6Addr + nic.testObject.dstAddr = localIPv6Addr + nic.testObject.contents = view[header.IPv6MinimumSize:][:payloadLen] - // Make payload be non-zero. - for i := header.IPv4MinimumSize; i < totalLen; i++ { - view[i] = uint8(i) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: view.ToVectorisedView(), + }) + if _, _, _, _, ok := parse.IPv6(pkt); !ok { + t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) + } + ep.HandlePacket(pkt) + }, + }, } - // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. - nic.testObject.protocol = 10 - nic.testObject.srcAddr = remoteIPv4Addr - nic.testObject.dstAddr = localIPv4Addr - nic.testObject.contents = view[header.IPv4MinimumSize:totalLen] + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{test.protoFactory}, + }) + nic := testInterface{ + testObject: testObject{ + t: t, + v4: test.v4, + }, + } + ep := s.NetworkProtocolInstance(test.protoNum).NewEndpoint(&nic, nil, nil, &nic.testObject) + defer ep.Close() - r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr) - if err != nil { - t.Fatalf("could not find route: %v", err) - } - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: view.ToVectorisedView(), - }) - if _, _, ok := proto.Parse(pkt); !ok { - t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) - } - r.PopulatePacketInfo(pkt) - ep.HandlePacket(pkt) - if nic.testObject.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) + } + + addressableEndpoint, ok := ep.(stack.AddressableEndpoint) + if !ok { + t.Fatalf("expected network endpoint with number = %d to implement stack.AddressableEndpoint", test.protoNum) + } + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(test.epAddr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", test.epAddr, err) + } else { + ep.DecRef() + } + + stat := s.Stats().IP.PacketsReceived + if got := stat.Value(); got != 0 { + t.Fatalf("got s.Stats().IP.PacketsReceived.Value() = %d, want = 0", got) + } + test.handlePacket(t, ep, &nic) + if nic.testObject.dataCalls != 1 { + t.Errorf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) + } + if got := stat.Value(); got != 1 { + t.Errorf("got s.Stats().IP.PacketsReceived.Value() = %d, want = 1", got) + } + }) } } @@ -634,10 +711,6 @@ func TestIPv4ReceiveControl(t *testing.T) { {"Non-zero fragment offset", 0, 100, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0}, {"Zero-length packet", 0, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + 8}, } - r, err := buildIPv4Route(localIPv4Addr, "\x0a\x00\x00\xbb") - if err != nil { - t.Fatal(err) - } for _, c := range cases { t.Run(c.name, func(t *testing.T) { s := buildDummyStack(t) @@ -705,8 +778,18 @@ func TestIPv4ReceiveControl(t *testing.T) { nic.testObject.typ = c.expectedTyp nic.testObject.extra = c.expectedExtra + addressableEndpoint, ok := ep.(stack.AddressableEndpoint) + if !ok { + t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint") + } + addr := localIPv4Addr.WithPrefix() + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + } else { + ep.DecRef() + } + pkt := truncatedPacket(view, c.trunc, header.IPv4MinimumSize) - r.PopulatePacketInfo(pkt) ep.HandlePacket(pkt) if want := c.expectedCount; nic.testObject.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want) @@ -716,7 +799,9 @@ func TestIPv4ReceiveControl(t *testing.T) { } func TestIPv4FragmentationReceive(t *testing.T) { - s := buildDummyStack(t) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + }) proto := s.NetworkProtocolInstance(ipv4.ProtocolNumber) nic := testInterface{ testObject: testObject{ @@ -774,11 +859,6 @@ func TestIPv4FragmentationReceive(t *testing.T) { nic.testObject.dstAddr = localIPv4Addr nic.testObject.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...) - r, err := buildIPv4Route(localIPv4Addr, remoteIPv4Addr) - if err != nil { - t.Fatalf("could not find route: %v", err) - } - // Send first segment. pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: frag1.ToVectorisedView(), @@ -786,7 +866,18 @@ func TestIPv4FragmentationReceive(t *testing.T) { if _, _, ok := proto.Parse(pkt); !ok { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } - r.PopulatePacketInfo(pkt) + + addressableEndpoint, ok := ep.(stack.AddressableEndpoint) + if !ok { + t.Fatal("expected IPv4 network endpoint to implement stack.AddressableEndpoint") + } + addr := localIPv4Addr.WithPrefix() + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + } else { + ep.DecRef() + } + ep.HandlePacket(pkt) if nic.testObject.dataCalls != 0 { t.Fatalf("Bad number of data calls: got %x, want 0", nic.testObject.dataCalls) @@ -799,7 +890,6 @@ func TestIPv4FragmentationReceive(t *testing.T) { if _, _, ok := proto.Parse(pkt); !ok { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } - r.PopulatePacketInfo(pkt) ep.HandlePacket(pkt) if nic.testObject.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) @@ -843,7 +933,7 @@ func TestIPv6Send(t *testing.T) { if err != nil { t.Fatalf("could not find route: %v", err) } - if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{ + if err := ep.WritePacket(r, nil /* gso */, stack.NetworkHeaderParams{ Protocol: 123, TTL: 123, TOS: stack.DefaultTOS, @@ -852,61 +942,6 @@ func TestIPv6Send(t *testing.T) { } } -func TestIPv6Receive(t *testing.T) { - s := buildDummyStack(t) - proto := s.NetworkProtocolInstance(ipv6.ProtocolNumber) - nic := testInterface{ - testObject: testObject{ - t: t, - }, - } - ep := proto.NewEndpoint(&nic, nil, nil, &nic.testObject) - defer ep.Close() - - if err := ep.Enable(); err != nil { - t.Fatalf("ep.Enable(): %s", err) - } - - totalLen := header.IPv6MinimumSize + 30 - view := buffer.NewView(totalLen) - ip := header.IPv6(view) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(totalLen - header.IPv6MinimumSize), - NextHeader: 10, - HopLimit: 20, - SrcAddr: remoteIPv6Addr, - DstAddr: localIPv6Addr, - }) - - // Make payload be non-zero. - for i := header.IPv6MinimumSize; i < totalLen; i++ { - view[i] = uint8(i) - } - - // Give packet to ipv6 endpoint, dispatcher will validate that it's ok. - nic.testObject.protocol = 10 - nic.testObject.srcAddr = remoteIPv6Addr - nic.testObject.dstAddr = localIPv6Addr - nic.testObject.contents = view[header.IPv6MinimumSize:totalLen] - - r, err := buildIPv6Route(localIPv6Addr, remoteIPv6Addr) - if err != nil { - t.Fatalf("could not find route: %v", err) - } - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: view.ToVectorisedView(), - }) - if _, _, ok := proto.Parse(pkt); !ok { - t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) - } - r.PopulatePacketInfo(pkt) - ep.HandlePacket(pkt) - if nic.testObject.dataCalls != 1 { - t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) - } -} - func TestIPv6ReceiveControl(t *testing.T) { newUint16 := func(v uint16) *uint16 { return &v } @@ -933,13 +968,6 @@ func TestIPv6ReceiveControl(t *testing.T) { {"Non-zero fragment offset", 0, newUint16(100), header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0}, {"Zero-length packet", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + 8}, } - r, err := buildIPv6Route( - localIPv6Addr, - "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa", - ) - if err != nil { - t.Fatal(err) - } for _, c := range cases { t.Run(c.name, func(t *testing.T) { s := buildDummyStack(t) @@ -965,11 +993,11 @@ func TestIPv6ReceiveControl(t *testing.T) { // Create the outer IPv6 header. ip := header.IPv6(view) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 20, - SrcAddr: outerSrcAddr, - DstAddr: localIPv6Addr, + PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: 20, + SrcAddr: outerSrcAddr, + DstAddr: localIPv6Addr, }) // Create the ICMP header. @@ -979,28 +1007,27 @@ func TestIPv6ReceiveControl(t *testing.T) { icmp.SetIdent(0xdead) icmp.SetSequence(0xbeef) - // Create the inner IPv6 header. - ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6PayloadOffset:]) - ip.Encode(&header.IPv6Fields{ - PayloadLength: 100, - NextHeader: 10, - HopLimit: 20, - SrcAddr: localIPv6Addr, - DstAddr: remoteIPv6Addr, - }) - + var extHdrs header.IPv6ExtHdrSerializer // Build the fragmentation header if needed. if c.fragmentOffset != nil { - ip.SetNextHeader(header.IPv6FragmentHeader) - frag := header.IPv6Fragment(view[2*header.IPv6MinimumSize+header.ICMPv6MinimumSize:]) - frag.Encode(&header.IPv6FragmentFields{ - NextHeader: 10, + extHdrs = append(extHdrs, &header.IPv6SerializableFragmentExtHdr{ FragmentOffset: *c.fragmentOffset, M: true, Identification: 0x12345678, }) } + // Create the inner IPv6 header. + ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6PayloadOffset:]) + ip.Encode(&header.IPv6Fields{ + PayloadLength: 100, + TransportProtocol: 10, + HopLimit: 20, + SrcAddr: localIPv6Addr, + DstAddr: remoteIPv6Addr, + ExtensionHeaders: extHdrs, + }) + // Make payload be non-zero. for i := dataOffset; i < len(view); i++ { view[i] = uint8(i) @@ -1018,8 +1045,17 @@ func TestIPv6ReceiveControl(t *testing.T) { // Set ICMPv6 checksum. icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIPv6Addr, buffer.VectorisedView{})) + addressableEndpoint, ok := ep.(stack.AddressableEndpoint) + if !ok { + t.Fatal("expected IPv6 network endpoint to implement stack.AddressableEndpoint") + } + addr := localIPv6Addr.WithPrefix() + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + } else { + ep.DecRef() + } pkt := truncatedPacket(view, c.trunc, header.IPv6MinimumSize) - r.PopulatePacketInfo(pkt) ep.HandlePacket(pkt) if want := c.expectedCount; nic.testObject.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want) @@ -1052,7 +1088,19 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { dataBuf := [dataLen]byte{1, 2, 3, 4} data := dataBuf[:] - ipv4Options := header.IPv4Options{0, 1, 0, 1} + ipv4Options := header.IPv4OptionsSerializer{ + &header.IPv4SerializableListEndOption{}, + &header.IPv4SerializableNOPOption{}, + &header.IPv4SerializableListEndOption{}, + &header.IPv4SerializableNOPOption{}, + } + + expectOptions := header.IPv4Options{ + byte(header.IPv4OptionListEndType), + byte(header.IPv4OptionNOPType), + byte(header.IPv4OptionListEndType), + byte(header.IPv4OptionNOPType), + } ipv6FragmentExtHdrBuf := [header.IPv6FragmentExtHdrLength]byte{transportProto, 0, 62, 4, 1, 2, 3, 4} ipv6FragmentExtHdr := ipv6FragmentExtHdrBuf[:] @@ -1202,7 +1250,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { nicAddr: localIPv4Addr, remoteAddr: remoteIPv4Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { - ipHdrLen := header.IPv4MinimumSize + ipv4Options.AllocationSize() + ipHdrLen := int(header.IPv4MinimumSize + ipv4Options.Length()) totalLen := ipHdrLen + len(data) hdr := buffer.NewPrependable(totalLen) if n := copy(hdr.Prepend(len(data)), data); n != len(data) { @@ -1225,7 +1273,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { netHdr := pkt.NetworkHeader() - hdrLen := header.IPv4MinimumSize + len(ipv4Options) + hdrLen := int(header.IPv4MinimumSize + ipv4Options.Length()) if len(netHdr.View()) != hdrLen { t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen) } @@ -1235,7 +1283,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { checker.DstAddr(remoteIPv4Addr), checker.IPv4HeaderLength(hdrLen), checker.IPFullLength(uint16(hdrLen+len(data))), - checker.IPv4Options(ipv4Options), + checker.IPv4Options(expectOptions), checker.IPPayload(data), ) }, @@ -1247,7 +1295,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { nicAddr: localIPv4Addr, remoteAddr: remoteIPv4Addr, pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { - ip := header.IPv4(make([]byte, header.IPv4MinimumSize+ipv4Options.AllocationSize())) + ip := header.IPv4(make([]byte, header.IPv4MinimumSize+ipv4Options.Length())) ip.Encode(&header.IPv4Fields{ Protocol: transportProto, TTL: ipv4.DefaultTTL, @@ -1266,7 +1314,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { netHdr := pkt.NetworkHeader() - hdrLen := header.IPv4MinimumSize + len(ipv4Options) + hdrLen := int(header.IPv4MinimumSize + ipv4Options.Length()) if len(netHdr.View()) != hdrLen { t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen) } @@ -1276,7 +1324,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { checker.DstAddr(remoteIPv4Addr), checker.IPv4HeaderLength(hdrLen), checker.IPFullLength(uint16(hdrLen+len(data))), - checker.IPv4Options(ipv4Options), + checker.IPv4Options(expectOptions), checker.IPPayload(data), ) }, @@ -1295,10 +1343,10 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { } ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - NextHeader: transportProto, - HopLimit: ipv6.DefaultTTL, - SrcAddr: src, - DstAddr: header.IPv4Any, + TransportProtocol: transportProto, + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, }) return hdr.View().ToVectorisedView() }, @@ -1338,10 +1386,12 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { } ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - NextHeader: uint8(header.IPv6FragmentExtHdrIdentifier), - HopLimit: ipv6.DefaultTTL, - SrcAddr: src, - DstAddr: header.IPv4Any, + // NB: we're lying about transport protocol here to verify the raw + // fragment header bytes. + TransportProtocol: tcpip.TransportProtocolNumber(header.IPv6FragmentExtHdrIdentifier), + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, }) return hdr.View().ToVectorisedView() }, @@ -1373,10 +1423,10 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - NextHeader: transportProto, - HopLimit: ipv6.DefaultTTL, - SrcAddr: src, - DstAddr: header.IPv4Any, + TransportProtocol: transportProto, + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, }) return buffer.View(ip).ToVectorisedView() }, @@ -1408,10 +1458,10 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - NextHeader: transportProto, - HopLimit: ipv6.DefaultTTL, - SrcAddr: src, - DstAddr: header.IPv4Any, + TransportProtocol: transportProto, + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, }) return buffer.View(ip[:len(ip)-1]).ToVectorisedView() }, diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 6252614ec..32f53f217 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -6,6 +6,7 @@ go_library( name = "ipv4", srcs = [ "icmp.go", + "igmp.go", "ipv4.go", ], visibility = ["//visibility:public"], @@ -17,6 +18,7 @@ go_library( "//pkg/tcpip/header/parse", "//pkg/tcpip/network/fragmentation", "//pkg/tcpip/network/hash", + "//pkg/tcpip/network/ip", "//pkg/tcpip/stack", ], ) @@ -24,7 +26,10 @@ go_library( go_test( name = "ipv4_test", size = "small", - srcs = ["ipv4_test.go"], + srcs = [ + "igmp_test.go", + "ipv4_test.go", + ], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 9b5e37fee..8e392f86c 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -63,7 +63,7 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { stats := e.protocol.stack.Stats() - received := stats.ICMP.V4PacketsReceived + received := stats.ICMP.V4.PacketsReceived // TODO(gvisor.dev/issue/170): ICMP packets don't have their // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a // full explanation. @@ -90,7 +90,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { iph := header.IPv4(pkt.NetworkHeader().View()) var newOptions header.IPv4Options - if len(iph) > header.IPv4MinimumSize { + if opts := iph.Options(); len(opts) != 0 { // RFC 1122 section 3.2.2.6 (page 43) (and similar for other round trip // type ICMP packets): // If a Record Route and/or Time Stamp option is received in an @@ -106,7 +106,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { } else { op = &optionUsageReceive{} } - aux, tmp, err := e.processIPOptions(pkt, iph.Options(), op) + aux, tmp, err := e.processIPOptions(pkt, opts, op) if err != nil { switch { case @@ -130,7 +130,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { case header.ICMPv4Echo: received.Echo.Increment() - sent := stats.ICMP.V4PacketsSent + sent := stats.ICMP.V4.PacketsSent if !e.protocol.stack.AllowICMPMessage() { sent.RateLimited.Increment() return @@ -290,6 +290,13 @@ type icmpReasonProtoUnreachable struct{} func (*icmpReasonProtoUnreachable) isICMPReason() {} +// icmpReasonTTLExceeded is an error where a packet's time to live exceeded in +// transit to its final destination, as per RFC 792 page 6, Time Exceeded +// Message. +type icmpReasonTTLExceeded struct{} + +func (*icmpReasonTTLExceeded) isICMPReason() {} + // icmpReasonReassemblyTimeout is an error where insufficient fragments are // received to complete reassembly of a packet within a configured time after // the reception of the first-arriving fragment of that packet. @@ -342,17 +349,37 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi return nil } + // If we hit a TTL Exceeded error, then we know we are operating as a router. + // As per RFC 792 page 6, Time Exceeded Message, + // + // If the gateway processing a datagram finds the time to live field + // is zero it must discard the datagram. The gateway may also notify + // the source host via the time exceeded message. + // + // ... + // + // Code 0 may be received from a gateway. ... + // + // Note, Code 0 is the TTL exceeded error. + // + // If we are operating as a router/gateway, don't use the packet's destination + // address as the response's source address as we should not not own the + // destination address of a packet we are forwarding. + localAddr := origIPHdrDst + if _, ok := reason.(*icmpReasonTTLExceeded); ok { + localAddr = "" + } // Even if we were able to receive a packet from some remote, we may not have // a route to it - the remote may be blocked via routing rules. We must always // consult our routing table and find a route to the remote before sending any // packet. - route, err := p.stack.FindRoute(pkt.NICID, origIPHdrDst, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */) + route, err := p.stack.FindRoute(pkt.NICID, localAddr, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */) if err != nil { return err } defer route.Release() - sent := p.stack.Stats().ICMP.V4PacketsSent + sent := p.stack.Stats().ICMP.V4.PacketsSent if !p.stack.AllowICMPMessage() { sent.RateLimited.Increment() return nil @@ -454,6 +481,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi icmpHdr.SetType(header.ICMPv4DstUnreachable) icmpHdr.SetCode(header.ICMPv4ProtoUnreachable) counter = sent.DstUnreachable + case *icmpReasonTTLExceeded: + icmpHdr.SetType(header.ICMPv4TimeExceeded) + icmpHdr.SetCode(header.ICMPv4TTLExceeded) + counter = sent.TimeExceeded case *icmpReasonReassemblyTimeout: icmpHdr.SetType(header.ICMPv4TimeExceeded) icmpHdr.SetCode(header.ICMPv4ReassemblyTimeout) @@ -483,3 +514,18 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi counter.Increment() return nil } + +// OnReassemblyTimeout implements fragmentation.TimeoutHandler. +func (p *protocol) OnReassemblyTimeout(pkt *stack.PacketBuffer) { + // OnReassemblyTimeout sends a Time Exceeded Message, as per RFC 792: + // + // If a host reassembling a fragmented datagram cannot complete the + // reassembly due to missing fragments within its time limit it discards the + // datagram, and it may send a time exceeded message. + // + // If fragment zero is not available then no time exceeded need be sent at + // all. + if pkt != nil { + p.returnError(&icmpReasonReassemblyTimeout{}, pkt) + } +} diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go new file mode 100644 index 000000000..da88d65d1 --- /dev/null +++ b/pkg/tcpip/network/ipv4/igmp.go @@ -0,0 +1,345 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv4 + +import ( + "fmt" + "sync/atomic" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/network/ip" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +const ( + // igmpV1PresentDefault is the initial state for igmpV1Present in the + // igmpState. As per RFC 2236 Page 9 says "No IGMPv1 Router Present ... is + // the initial state." + igmpV1PresentDefault = 0 + + // v1RouterPresentTimeout from RFC 2236 Section 8.11, Page 18 + // See note on igmpState.igmpV1Present for more detail. + v1RouterPresentTimeout = 400 * time.Second + + // v1MaxRespTime from RFC 2236 Section 4, Page 5. "The IGMPv1 router + // will send General Queries with the Max Response Time set to 0. This MUST + // be interpreted as a value of 100 (10 seconds)." + // + // Note that the Max Response Time field is a value in units of deciseconds. + v1MaxRespTime = 10 * time.Second + + // UnsolicitedReportIntervalMax is the maximum delay between sending + // unsolicited IGMP reports. + // + // Obtained from RFC 2236 Section 8.10, Page 19. + UnsolicitedReportIntervalMax = 10 * time.Second +) + +// IGMPOptions holds options for IGMP. +type IGMPOptions struct { + // Enabled indicates whether IGMP will be performed. + // + // When enabled, IGMP may transmit IGMP report and leave messages when + // joining and leaving multicast groups respectively, and handle incoming + // IGMP packets. + // + // This field is ignored and is always assumed to be false for interfaces + // without neighbouring nodes (e.g. loopback). + Enabled bool +} + +var _ ip.MulticastGroupProtocol = (*igmpState)(nil) + +// igmpState is the per-interface IGMP state. +// +// igmpState.init() MUST be called after creating an IGMP state. +type igmpState struct { + // The IPv4 endpoint this igmpState is for. + ep *endpoint + + genericMulticastProtocol ip.GenericMulticastProtocolState + + // igmpV1Present is for maintaining compatibility with IGMPv1 Routers, from + // RFC 2236 Section 4 Page 6: "The IGMPv1 router expects Version 1 + // Membership Reports in response to its Queries, and will not pay + // attention to Version 2 Membership Reports. Therefore, a state variable + // MUST be kept for each interface, describing whether the multicast + // Querier on that interface is running IGMPv1 or IGMPv2. This variable + // MUST be based upon whether or not an IGMPv1 query was heard in the last + // [Version 1 Router Present Timeout] seconds". + // + // Must be accessed with atomic operations. Holds a value of 1 when true, 0 + // when false. + igmpV1Present uint32 + + // igmpV1Job is scheduled when this interface receives an IGMPv1 style + // message, upon expiration the igmpV1Present flag is cleared. + // igmpV1Job may not be nil once igmpState is initialized. + igmpV1Job *tcpip.Job +} + +// Enabled implements ip.MulticastGroupProtocol. +func (igmp *igmpState) Enabled() bool { + // No need to perform IGMP on loopback interfaces since they don't have + // neighbouring nodes. + return igmp.ep.protocol.options.IGMP.Enabled && !igmp.ep.nic.IsLoopback() && igmp.ep.Enabled() +} + +// SendReport implements ip.MulticastGroupProtocol. +// +// Precondition: igmp.ep.mu must be read locked. +func (igmp *igmpState) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) { + igmpType := header.IGMPv2MembershipReport + if igmp.v1Present() { + igmpType = header.IGMPv1MembershipReport + } + return igmp.writePacket(groupAddress, groupAddress, igmpType) +} + +// SendLeave implements ip.MulticastGroupProtocol. +// +// Precondition: igmp.ep.mu must be read locked. +func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) *tcpip.Error { + // As per RFC 2236 Section 6, Page 8: "If the interface state says the + // Querier is running IGMPv1, this action SHOULD be skipped. If the flag + // saying we were the last host to report is cleared, this action MAY be + // skipped." + if igmp.v1Present() { + return nil + } + _, err := igmp.writePacket(header.IPv4AllRoutersGroup, groupAddress, header.IGMPLeaveGroup) + return err +} + +// init sets up an igmpState struct, and is required to be called before using +// a new igmpState. +// +// Must only be called once for the lifetime of igmp. +func (igmp *igmpState) init(ep *endpoint) { + igmp.ep = ep + igmp.genericMulticastProtocol.Init(&ep.mu.RWMutex, ip.GenericMulticastProtocolOptions{ + Rand: ep.protocol.stack.Rand(), + Clock: ep.protocol.stack.Clock(), + Protocol: igmp, + MaxUnsolicitedReportDelay: UnsolicitedReportIntervalMax, + AllNodesAddress: header.IPv4AllSystems, + }) + igmp.igmpV1Present = igmpV1PresentDefault + igmp.igmpV1Job = ep.protocol.stack.NewJob(&ep.mu, func() { + igmp.setV1Present(false) + }) +} + +// handleIGMP handles an IGMP packet. +// +// Precondition: igmp.ep.mu must be locked. +func (igmp *igmpState) handleIGMP(pkt *stack.PacketBuffer) { + stats := igmp.ep.protocol.stack.Stats() + received := stats.IGMP.PacketsReceived + headerView, ok := pkt.Data.PullUp(header.IGMPMinimumSize) + if !ok { + received.Invalid.Increment() + return + } + h := header.IGMP(headerView) + + // Temporarily reset the checksum field to 0 in order to calculate the proper + // checksum. + wantChecksum := h.Checksum() + h.SetChecksum(0) + gotChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */) + h.SetChecksum(wantChecksum) + + if gotChecksum != wantChecksum { + received.ChecksumErrors.Increment() + return + } + + switch h.Type() { + case header.IGMPMembershipQuery: + received.MembershipQuery.Increment() + if len(headerView) < header.IGMPQueryMinimumSize { + received.Invalid.Increment() + return + } + igmp.handleMembershipQuery(h.GroupAddress(), h.MaxRespTime()) + case header.IGMPv1MembershipReport: + received.V1MembershipReport.Increment() + if len(headerView) < header.IGMPReportMinimumSize { + received.Invalid.Increment() + return + } + igmp.handleMembershipReport(h.GroupAddress()) + case header.IGMPv2MembershipReport: + received.V2MembershipReport.Increment() + if len(headerView) < header.IGMPReportMinimumSize { + received.Invalid.Increment() + return + } + igmp.handleMembershipReport(h.GroupAddress()) + case header.IGMPLeaveGroup: + received.LeaveGroup.Increment() + // As per RFC 2236 Section 6, Page 7: "IGMP messages other than Query or + // Report, are ignored in all states" + + default: + // As per RFC 2236 Section 2.1 Page 3: "Unrecognized message types should + // be silently ignored. New message types may be used by newer versions of + // IGMP, by multicast routing protocols, or other uses." + received.Unrecognized.Increment() + } +} + +func (igmp *igmpState) v1Present() bool { + return atomic.LoadUint32(&igmp.igmpV1Present) == 1 +} + +func (igmp *igmpState) setV1Present(v bool) { + if v { + atomic.StoreUint32(&igmp.igmpV1Present, 1) + } else { + atomic.StoreUint32(&igmp.igmpV1Present, 0) + } +} + +// handleMembershipQuery handles a membership query. +// +// Precondition: igmp.ep.mu must be locked. +func (igmp *igmpState) handleMembershipQuery(groupAddress tcpip.Address, maxRespTime time.Duration) { + // As per RFC 2236 Section 6, Page 10: If the maximum response time is zero + // then change the state to note that an IGMPv1 router is present and + // schedule the query received Job. + if maxRespTime == 0 && igmp.Enabled() { + igmp.igmpV1Job.Cancel() + igmp.igmpV1Job.Schedule(v1RouterPresentTimeout) + igmp.setV1Present(true) + maxRespTime = v1MaxRespTime + } + + igmp.genericMulticastProtocol.HandleQueryLocked(groupAddress, maxRespTime) +} + +// handleMembershipReport handles a membership report. +// +// Precondition: igmp.ep.mu must be locked. +func (igmp *igmpState) handleMembershipReport(groupAddress tcpip.Address) { + igmp.genericMulticastProtocol.HandleReportLocked(groupAddress) +} + +// writePacket assembles and sends an IGMP packet. +// +// Precondition: igmp.ep.mu must be read locked. +func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip.Address, igmpType header.IGMPType) (bool, *tcpip.Error) { + igmpData := header.IGMP(buffer.NewView(header.IGMPReportMinimumSize)) + igmpData.SetType(igmpType) + igmpData.SetGroupAddress(groupAddress) + igmpData.SetChecksum(header.IGMPCalculateChecksum(igmpData)) + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(igmp.ep.MaxHeaderLength()), + Data: buffer.View(igmpData).ToVectorisedView(), + }) + + addressEndpoint := igmp.ep.acquireOutgoingPrimaryAddressRLocked(destAddress, false /* allowExpired */) + if addressEndpoint == nil { + return false, nil + } + localAddr := addressEndpoint.AddressWithPrefix().Address + addressEndpoint.DecRef() + addressEndpoint = nil + igmp.ep.addIPHeader(localAddr, destAddress, pkt, stack.NetworkHeaderParams{ + Protocol: header.IGMPProtocolNumber, + TTL: header.IGMPTTL, + TOS: stack.DefaultTOS, + }, header.IPv4OptionsSerializer{ + &header.IPv4SerializableRouterAlertOption{}, + }) + + sentStats := igmp.ep.protocol.stack.Stats().IGMP.PacketsSent + if err := igmp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv4Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { + sentStats.Dropped.Increment() + return false, err + } + switch igmpType { + case header.IGMPv1MembershipReport: + sentStats.V1MembershipReport.Increment() + case header.IGMPv2MembershipReport: + sentStats.V2MembershipReport.Increment() + case header.IGMPLeaveGroup: + sentStats.LeaveGroup.Increment() + default: + panic(fmt.Sprintf("unrecognized igmp type = %d", igmpType)) + } + return true, nil +} + +// joinGroup handles adding a new group to the membership map, setting up the +// IGMP state for the group, and sending and scheduling the required +// messages. +// +// If the group already exists in the membership map, returns +// tcpip.ErrDuplicateAddress. +// +// Precondition: igmp.ep.mu must be locked. +func (igmp *igmpState) joinGroup(groupAddress tcpip.Address) { + igmp.genericMulticastProtocol.JoinGroupLocked(groupAddress) +} + +// isInGroup returns true if the specified group has been joined locally. +// +// Precondition: igmp.ep.mu must be read locked. +func (igmp *igmpState) isInGroup(groupAddress tcpip.Address) bool { + return igmp.genericMulticastProtocol.IsLocallyJoinedRLocked(groupAddress) +} + +// leaveGroup handles removing the group from the membership map, cancels any +// delay timers associated with that group, and sends the Leave Group message +// if required. +// +// Precondition: igmp.ep.mu must be locked. +func (igmp *igmpState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error { + // LeaveGroup returns false only if the group was not joined. + if igmp.genericMulticastProtocol.LeaveGroupLocked(groupAddress) { + return nil + } + + return tcpip.ErrBadLocalAddress +} + +// softLeaveAll leaves all groups from the perspective of IGMP, but remains +// joined locally. +// +// Precondition: igmp.ep.mu must be locked. +func (igmp *igmpState) softLeaveAll() { + igmp.genericMulticastProtocol.MakeAllNonMemberLocked() +} + +// initializeAll attemps to initialize the IGMP state for each group that has +// been joined locally. +// +// Precondition: igmp.ep.mu must be locked. +func (igmp *igmpState) initializeAll() { + igmp.genericMulticastProtocol.InitializeGroupsLocked() +} + +// sendQueuedReports attempts to send any reports that are queued for sending. +// +// Precondition: igmp.ep.mu must be locked. +func (igmp *igmpState) sendQueuedReports() { + igmp.genericMulticastProtocol.SendQueuedReportsLocked() +} diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go new file mode 100644 index 000000000..1ee573ac8 --- /dev/null +++ b/pkg/tcpip/network/ipv4/igmp_test.go @@ -0,0 +1,215 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv4_test + +import ( + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +const ( + linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") + addr = tcpip.Address("\x0a\x00\x00\x01") + multicastAddr = tcpip.Address("\xe0\x00\x00\x03") + nicID = 1 +) + +// validateIgmpPacket checks that a passed PacketInfo is an IPv4 IGMP packet +// sent to the provided address with the passed fields set. Raises a t.Error if +// any field does not match. +func validateIgmpPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.Address, igmpType header.IGMPType, maxRespTime byte, groupAddress tcpip.Address) { + t.Helper() + + payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader())) + checker.IPv4(t, payload, + checker.SrcAddr(addr), + checker.DstAddr(remoteAddress), + // TTL for an IGMP message must be 1 as per RFC 2236 section 2. + checker.TTL(1), + checker.IPv4RouterAlert(), + checker.IGMP( + checker.IGMPType(igmpType), + checker.IGMPMaxRespTime(header.DecisecondToDuration(maxRespTime)), + checker.IGMPGroupAddress(groupAddress), + ), + ) +} + +func createStack(t *testing.T, igmpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) { + t.Helper() + + // Create an endpoint of queue size 1, since no more than 1 packets are ever + // queued in the tests in this file. + e := channel.New(1, 1280, linkAddr) + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocolWithOptions(ipv4.Options{ + IGMP: ipv4.IGMPOptions{ + Enabled: igmpEnabled, + }, + })}, + Clock: clock, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + return e, s, clock +} + +func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType header.IGMPType, maxRespTime byte, groupAddress tcpip.Address) { + buf := buffer.NewView(header.IPv4MinimumSize + header.IGMPQueryMinimumSize) + + ip := header.IPv4(buf) + ip.Encode(&header.IPv4Fields{ + TotalLength: uint16(len(buf)), + TTL: 1, + Protocol: uint8(header.IGMPProtocolNumber), + SrcAddr: header.IPv4Any, + DstAddr: header.IPv4AllSystems, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + igmp := header.IGMP(buf[header.IPv4MinimumSize:]) + igmp.SetType(igmpType) + igmp.SetMaxRespTime(maxRespTime) + igmp.SetGroupAddress(groupAddress) + igmp.SetChecksum(header.IGMPCalculateChecksum(igmp)) + + e.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) +} + +// TestIgmpV1Present tests the handling of the case where an IGMPv1 router is +// present on the network. The IGMP stack will then send IGMPv1 Membership +// reports for backwards compatibility. +func TestIgmpV1Present(t *testing.T) { + e, s, clock := createStack(t, true) + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err) + } + + if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { + t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err) + } + + // This NIC will send an IGMPv2 report immediately, before this test can get + // the IGMPv1 General Membership Query in. + p, ok := e.Read() + if !ok { + t.Fatal("unable to Read IGMP packet, expected V2MembershipReport") + } + if got := s.Stats().IGMP.PacketsSent.V2MembershipReport.Value(); got != 1 { + t.Fatalf("got V2MembershipReport messages sent = %d, want = 1", got) + } + validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) + if t.Failed() { + t.FailNow() + } + + // Inject an IGMPv1 General Membership Query which is identical to a standard + // membership query except the Max Response Time is set to 0, which will tell + // the stack that this is a router using IGMPv1. Send it to the all systems + // group which is the only group this host belongs to. + createAndInjectIGMPPacket(e, header.IGMPMembershipQuery, 0, header.IPv4AllSystems) + if got := s.Stats().IGMP.PacketsReceived.MembershipQuery.Value(); got != 1 { + t.Fatalf("got Membership Queries received = %d, want = 1", got) + } + + // Before advancing the clock, verify that this host has not sent a + // V1MembershipReport yet. + if got := s.Stats().IGMP.PacketsSent.V1MembershipReport.Value(); got != 0 { + t.Fatalf("got V1MembershipReport messages sent = %d, want = 0", got) + } + + // Verify the solicited Membership Report is sent. Now that this NIC has seen + // an IGMPv1 query, it should send an IGMPv1 Membership Report. + p, ok = e.Read() + if ok { + t.Fatalf("sent unexpected packet, expected V1MembershipReport only after advancing the clock = %+v", p.Pkt) + } + clock.Advance(ipv4.UnsolicitedReportIntervalMax) + p, ok = e.Read() + if !ok { + t.Fatal("unable to Read IGMP packet, expected V1MembershipReport") + } + if got := s.Stats().IGMP.PacketsSent.V1MembershipReport.Value(); got != 1 { + t.Fatalf("got V1MembershipReport messages sent = %d, want = 1", got) + } + validateIgmpPacket(t, p, multicastAddr, header.IGMPv1MembershipReport, 0, multicastAddr) +} + +func TestSendQueuedIGMPReports(t *testing.T) { + e, s, clock := createStack(t, true) + + // Joining a group without an assigned address should queue IGMP packets; none + // should be sent without an assigned address. + if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", ipv4.ProtocolNumber, nicID, multicastAddr, err) + } + reportStat := s.Stats().IGMP.PacketsSent.V2MembershipReport + if got := reportStat.Value(); got != 0 { + t.Errorf("got reportStat.Value() = %d, want = 0", got) + } + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("got unexpected packet = %#v", p) + } + + // The initial set of IGMP reports that were queued should be sent once an + // address is assigned. + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err) + } + if got := reportStat.Value(); got != 1 { + t.Errorf("got reportStat.Value() = %d, want = 1", got) + } + if p, ok := e.Read(); !ok { + t.Error("expected to send an IGMP membership report") + } else { + validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) + } + if t.Failed() { + t.FailNow() + } + clock.Advance(ipv4.UnsolicitedReportIntervalMax) + if got := reportStat.Value(); got != 2 { + t.Errorf("got reportStat.Value() = %d, want = 2", got) + } + if p, ok := e.Read(); !ok { + t.Error("expected to send an IGMP membership report") + } else { + validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) + } + if t.Failed() { + t.FailNow() + } + + // Should have no more packets to send after the initial set of unsolicited + // reports. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("got unexpected packet = %#v", p) + } +} diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index a376cb8ec..e9ff70d04 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -83,6 +83,7 @@ type endpoint struct { sync.RWMutex addressableEndpointState stack.AddressableEndpointState + igmp igmpState } } @@ -93,7 +94,10 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCa dispatcher: dispatcher, protocol: p, } + e.mu.Lock() e.mu.addressableEndpointState.Init(e) + e.mu.igmp.init(e) + e.mu.Unlock() return e } @@ -121,11 +125,22 @@ func (e *endpoint) Enable() *tcpip.Error { // We have no need for the address endpoint. ep.DecRef() + // Groups may have been joined while the endpoint was disabled, or the + // endpoint may have left groups from the perspective of IGMP when the + // endpoint was disabled. Either way, we need to let routers know to + // send us multicast traffic. + e.mu.igmp.initializeAll() + // As per RFC 1122 section 3.3.7, all hosts should join the all-hosts // multicast group. Note, the IANA calls the all-hosts multicast group the // all-systems multicast group. - _, err = e.mu.addressableEndpointState.JoinGroup(header.IPv4AllSystems) - return err + if err := e.joinGroupLocked(header.IPv4AllSystems); err != nil { + // joinGroupLocked only returns an error if the group address is not a valid + // IPv4 multicast address. + panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", header.IPv4AllSystems, err)) + } + + return nil } // Enabled implements stack.NetworkEndpoint. @@ -157,19 +172,27 @@ func (e *endpoint) Disable() { } func (e *endpoint) disableLocked() { - if !e.setEnabled(false) { + if !e.isEnabled() { return } // The endpoint may have already left the multicast group. - if _, err := e.mu.addressableEndpointState.LeaveGroup(header.IPv4AllSystems); err != nil && err != tcpip.ErrBadLocalAddress { + if err := e.leaveGroupLocked(header.IPv4AllSystems); err != nil && err != tcpip.ErrBadLocalAddress { panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv4AllSystems, err)) } + // Leave groups from the perspective of IGMP so that routers know that + // we are no longer interested in the group. + e.mu.igmp.softLeaveAll() + // The address may have already been removed. if err := e.mu.addressableEndpointState.RemovePermanentAddress(ipv4BroadcastAddr.Address); err != nil && err != tcpip.ErrBadLocalAddress { panic(fmt.Sprintf("unexpected error when removing address = %s: %s", ipv4BroadcastAddr.Address, err)) } + + if !e.setEnabled(false) { + panic("should have only done work to disable the endpoint if it was enabled") + } } // DefaultTTL is the default time-to-live value for this endpoint. @@ -198,37 +221,34 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return e.protocol.Number() } -func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) { +func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, options header.IPv4OptionsSerializer) { hdrLen := header.IPv4MinimumSize - var opts header.IPv4Options - if params.Options != nil { - var ok bool - if opts, ok = params.Options.(header.IPv4Options); !ok { - panic(fmt.Sprintf("want IPv4Options, got %T", params.Options)) - } - hdrLen += opts.AllocationSize() - if hdrLen > header.IPv4MaximumHeaderSize { - // Since we have no way to report an error we must either panic or create - // a packet which is different to what was requested. Choose panic as this - // would be a programming error that should be caught in testing. - panic(fmt.Sprintf("IPv4 Options %d bytes, Max %d", params.Options.AllocationSize(), header.IPv4MaximumOptionsSize)) - } + var optLen int + if options != nil { + optLen = int(options.Length()) + } + hdrLen += optLen + if hdrLen > header.IPv4MaximumHeaderSize { + // Since we have no way to report an error we must either panic or create + // a packet which is different to what was requested. Choose panic as this + // would be a programming error that should be caught in testing. + panic(fmt.Sprintf("IPv4 Options %d bytes, Max %d", optLen, header.IPv4MaximumOptionsSize)) } ip := header.IPv4(pkt.NetworkHeader().Push(hdrLen)) length := uint16(pkt.Size()) // RFC 6864 section 4.3 mandates uniqueness of ID values for non-atomic // datagrams. Since the DF bit is never being set here, all datagrams // are non-atomic and need an ID. - id := atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1) + id := atomic.AddUint32(&e.protocol.ids[hashRoute(srcAddr, dstAddr, params.Protocol, e.protocol.hashIV)%buckets], 1) ip.Encode(&header.IPv4Fields{ TotalLength: length, ID: uint16(id), TTL: params.TTL, TOS: params.TOS, Protocol: uint8(params.Protocol), - SrcAddr: r.LocalAddress, - DstAddr: r.RemoteAddress, - Options: opts, + SrcAddr: srcAddr, + DstAddr: dstAddr, + Options: options, }) ip.SetChecksum(^ip.CalculateChecksum()) pkt.NetworkProtocolNumber = ProtocolNumber @@ -259,17 +279,14 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { - e.addIPHeader(r, pkt, params) - return e.writePacket(r, gso, pkt) -} + e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */) -func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer) *tcpip.Error { // iptables filtering. All packets that reach here are locally // generated. nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok { // iptables is telling us to drop the packet. - r.Stats().IP.IPTablesOutputDropped.Increment() + e.protocol.stack.Stats().IP.IPTablesOutputDropped.Increment() return nil } @@ -286,24 +303,27 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet if err == nil { pkt := pkt.CloneToInbound() if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) - route.PopulatePacketInfo(pkt) // Since we rewrote the packet but it is being routed back to us, we can // safely assume the checksum is valid. pkt.RXTransportChecksumValidated = true - ep.HandlePacket(pkt) + ep.(*endpoint).handlePacket(pkt) } return nil } } + return e.writePacket(r, gso, pkt, false /* headerIncluded */) +} + +func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, headerIncluded bool) *tcpip.Error { if r.Loop&stack.PacketLoop != 0 { pkt := pkt.CloneToInbound() if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - loopedR := r.MakeLoopedRoute() - loopedR.PopulatePacketInfo(pkt) - loopedR.Release() - e.HandlePacket(pkt) + // If the packet was generated by the stack (not a raw/packet endpoint + // where a packet may be written with the header included), then we can + // safely assume the checksum is valid. + pkt.RXTransportChecksumValidated = !headerIncluded + e.handlePacket(pkt) } } if r.Loop&stack.PacketOut == 0 { @@ -347,7 +367,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { - e.addIPHeader(r, pkt, params) + e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* options */) networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size())) if err != nil { r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) @@ -374,8 +394,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) // iptables filtering. All packets that reach here are locally // generated. - ipt := e.protocol.stack.IPTables() - dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName) + dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, nicName) if len(dropped) == 0 && len(natPkts) == 0 { // Fast path: If no packets are to be dropped then we can just invoke the // faster WritePackets API directly. @@ -400,9 +419,10 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { pkt := pkt.CloneToInbound() if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) - route.PopulatePacketInfo(pkt) - ep.HandlePacket(pkt) + // Since we rewrote the packet but it is being routed back to us, we + // can safely assume the checksum is valid. + pkt.RXTransportChecksumValidated = true + ep.(*endpoint).handlePacket(pkt) } n++ continue @@ -461,7 +481,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu // non-atomic datagrams, so assign an ID to all such datagrams // according to the definition given in RFC 6864 section 4. if ip.Flags()&header.IPv4FlagDontFragment == 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 || ip.FragmentOffset() > 0 { - ip.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1))) + ip.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r.LocalAddress, r.RemoteAddress, 0 /* protocol */, e.protocol.hashIV)%buckets], 1))) } } @@ -479,16 +499,85 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu return tcpip.ErrMalformedHeader } - return e.writePacket(r, nil /* gso */, pkt) + return e.writePacket(r, nil /* gso */, pkt, true /* headerIncluded */) +} + +// forwardPacket attempts to forward a packet to its final destination. +func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error { + h := header.IPv4(pkt.NetworkHeader().View()) + ttl := h.TTL() + if ttl == 0 { + // As per RFC 792 page 6, Time Exceeded Message, + // + // If the gateway processing a datagram finds the time to live field + // is zero it must discard the datagram. The gateway may also notify + // the source host via the time exceeded message. + return e.protocol.returnError(&icmpReasonTTLExceeded{}, pkt) + } + + dstAddr := h.DestinationAddress() + + // Check if the destination is owned by the stack. + networkEndpoint, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr) + if err == nil { + networkEndpoint.(*endpoint).handlePacket(pkt) + return nil + } + if err != tcpip.ErrBadAddress { + return err + } + + r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */) + if err != nil { + return err + } + defer r.Release() + + // We need to do a deep copy of the IP packet because + // WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do + // not own it. + newHdr := header.IPv4(stack.PayloadSince(pkt.NetworkHeader())) + + // As per RFC 791 page 30, Time to Live, + // + // This field must be decreased at each point that the internet header + // is processed to reflect the time spent processing the datagram. + // Even if no local information is available on the time actually + // spent, the field must be decremented by 1. + newHdr.SetTTL(ttl - 1) + + return r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: buffer.View(newHdr).ToVectorisedView(), + })) } // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { + stats := e.protocol.stack.Stats() + stats.IP.PacketsReceived.Increment() + if !e.isEnabled() { + stats.IP.DisabledPacketsReceived.Increment() return } + // Loopback traffic skips the prerouting chain. + if !e.nic.IsLoopback() { + if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, ""); !ok { + // iptables is telling us to drop the packet. + stats.IP.IPTablesPreroutingDropped.Increment() + return + } + } + + e.handlePacket(pkt) +} + +// handlePacket is like HandlePacket except it does not perform the prerouting +// iptables hook. +func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { pkt.NICID = e.nic.ID() stats := e.protocol.stack.Stats() @@ -524,19 +613,46 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { return } + srcAddr := h.SourceAddress() + dstAddr := h.DestinationAddress() + // As per RFC 1122 section 3.2.1.3: // When a host sends any datagram, the IP source address MUST // be one of its own IP addresses (but not a broadcast or // multicast address). - if pkt.NetworkPacketInfo.RemoteAddressBroadcast || header.IsV4MulticastAddress(h.SourceAddress()) { + if srcAddr == header.IPv4Broadcast || header.IsV4MulticastAddress(srcAddr) { stats.IP.InvalidSourceAddressesReceived.Increment() return } + // Make sure the source address is not a subnet-local broadcast address. + if addressEndpoint := e.AcquireAssignedAddress(srcAddr, false /* createTemp */, stack.NeverPrimaryEndpoint); addressEndpoint != nil { + subnet := addressEndpoint.Subnet() + addressEndpoint.DecRef() + if subnet.IsBroadcast(srcAddr) { + stats.IP.InvalidSourceAddressesReceived.Increment() + return + } + } + + // The destination address should be an address we own or a group we joined + // for us to receive the packet. Otherwise, attempt to forward the packet. + if addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint); addressEndpoint != nil { + subnet := addressEndpoint.AddressWithPrefix().Subnet() + addressEndpoint.DecRef() + pkt.NetworkPacketInfo.LocalAddressBroadcast = subnet.IsBroadcast(dstAddr) || dstAddr == header.IPv4Broadcast + } else if !e.IsInGroup(dstAddr) { + if !e.protocol.Forwarding() { + stats.IP.InvalidDestinationAddressesReceived.Increment() + return + } + + _ = e.forwardPacket(pkt) + return + } // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. - ipt := e.protocol.stack.IPTables() - if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "", ""); !ok { // iptables is telling us to drop the packet. stats.IP.IPTablesInputDropped.Increment() return @@ -565,29 +681,8 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { return } - // Set up a callback in case we need to send a Time Exceeded Message, as per - // RFC 792: - // - // If a host reassembling a fragmented datagram cannot complete the - // reassembly due to missing fragments within its time limit it discards - // the datagram, and it may send a time exceeded message. - // - // If fragment zero is not available then no time exceeded need be sent at - // all. - var releaseCB func(bool) - if start == 0 { - pkt := pkt.Clone() - releaseCB = func(timedOut bool) { - if timedOut { - _ = e.protocol.returnError(&icmpReasonReassemblyTimeout{}, pkt) - } - } - } - - var ready bool - var err error proto := h.Protocol() - pkt.Data, _, ready, err = e.protocol.fragmentation.Process( + data, _, ready, err := e.protocol.fragmentation.Process( // As per RFC 791 section 2.3, the identification value is unique // for a source-destination pair and protocol. fragmentation.FragmentID{ @@ -600,8 +695,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { start+uint16(pkt.Data.Size())-1, h.More(), proto, - pkt.Data, - releaseCB, + pkt, ) if err != nil { stats.IP.MalformedPacketsReceived.Increment() @@ -611,6 +705,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { if !ready { return } + pkt.Data = data // The reassembler doesn't take care of fixing up the header, so we need // to do it here. @@ -628,11 +723,17 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { e.handleICMP(pkt) return } - if len(h.Options()) != 0 { + if p == header.IGMPProtocolNumber { + e.mu.Lock() + e.mu.igmp.handleIGMP(pkt) + e.mu.Unlock() + return + } + if opts := h.Options(); len(opts) != 0 { // TODO(gvisor.dev/issue/4586): // When we add forwarding support we should use the verified options // rather than just throwing them away. - aux, _, err := e.processIPOptions(pkt, h.Options(), &optionUsageReceive{}) + aux, _, err := e.processIPOptions(pkt, opts, &optionUsageReceive{}) if err != nil { switch { case @@ -683,7 +784,12 @@ func (e *endpoint) Close() { func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) { e.mu.Lock() defer e.mu.Unlock() - return e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated) + + ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated) + if err == nil { + e.mu.igmp.sendQueuedReports() + } + return ep, err } // RemovePermanentAddress implements stack.AddressableEndpoint. @@ -706,34 +812,26 @@ func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp boo defer e.mu.Unlock() loopback := e.nic.IsLoopback() - addressEndpoint := e.mu.addressableEndpointState.ReadOnly().AddrOrMatching(localAddr, allowTemp, func(addressEndpoint stack.AddressEndpoint) bool { + return e.mu.addressableEndpointState.AcquireAssignedAddressOrMatching(localAddr, func(addressEndpoint stack.AddressEndpoint) bool { subnet := addressEndpoint.Subnet() // IPv4 has a notion of a subnet broadcast address and considers the // loopback interface bound to an address's whole subnet (on linux). return subnet.IsBroadcast(localAddr) || (loopback && subnet.Contains(localAddr)) - }) - if addressEndpoint != nil { - return addressEndpoint - } - - if !allowTemp { - return nil - } - - addr := localAddr.WithPrefix() - addressEndpoint, err := e.mu.addressableEndpointState.AddAndAcquireTemporaryAddress(addr, tempPEB) - if err != nil { - // AddAddress only returns an error if the address is already assigned, - // but we just checked above if the address exists so we expect no error. - panic(fmt.Sprintf("e.mu.addressableEndpointState.AddAndAcquireTemporaryAddress(%s, %d): %s", addr, tempPEB, err)) - } - return addressEndpoint + }, allowTemp, tempPEB) } // AcquireOutgoingPrimaryAddress implements stack.AddressableEndpoint. func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint { e.mu.RLock() defer e.mu.RUnlock() + return e.acquireOutgoingPrimaryAddressRLocked(remoteAddr, allowExpired) +} + +// acquireOutgoingPrimaryAddressRLocked is like AcquireOutgoingPrimaryAddress +// but with locking requirements +// +// Precondition: igmp.ep.mu must be read locked. +func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint { return e.mu.addressableEndpointState.AcquireOutgoingPrimaryAddress(remoteAddr, allowExpired) } @@ -752,32 +850,48 @@ func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix { } // JoinGroup implements stack.GroupAddressableEndpoint. -func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) { +func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + return e.joinGroupLocked(addr) +} + +// joinGroupLocked is like JoinGroup but with locking requirements. +// +// Precondition: e.mu must be locked. +func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error { if !header.IsV4MulticastAddress(addr) { - return false, tcpip.ErrBadAddress + return tcpip.ErrBadAddress } - e.mu.Lock() - defer e.mu.Unlock() - return e.mu.addressableEndpointState.JoinGroup(addr) + e.mu.igmp.joinGroup(addr) + return nil } // LeaveGroup implements stack.GroupAddressableEndpoint. -func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) { +func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - return e.mu.addressableEndpointState.LeaveGroup(addr) + return e.leaveGroupLocked(addr) +} + +// leaveGroupLocked is like LeaveGroup but with locking requirements. +// +// Precondition: e.mu must be locked. +func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { + return e.mu.igmp.leaveGroup(addr) } // IsInGroup implements stack.GroupAddressableEndpoint. func (e *endpoint) IsInGroup(addr tcpip.Address) bool { e.mu.RLock() defer e.mu.RUnlock() - return e.mu.addressableEndpointState.IsInGroup(addr) + return e.mu.igmp.isInGroup(addr) } var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) var _ stack.NetworkProtocol = (*protocol)(nil) +var _ fragmentation.TimeoutHandler = (*protocol)(nil) type protocol struct { stack *stack.Stack @@ -798,6 +912,8 @@ type protocol struct { hashIV uint32 fragmentation *fragmentation.Fragmentation + + options Options } // Number returns the ipv4 protocol number. @@ -922,17 +1038,23 @@ func addressToUint32(addr tcpip.Address) uint32 { return uint32(addr[0]) | uint32(addr[1])<<8 | uint32(addr[2])<<16 | uint32(addr[3])<<24 } -// hashRoute calculates a hash value for the given route. It uses the source & -// destination address, the transport protocol number and a 32-bit number to -// generate the hash. -func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 { - a := addressToUint32(r.LocalAddress) - b := addressToUint32(r.RemoteAddress) +// hashRoute calculates a hash value for the given source/destination pair using +// the addresses, transport protocol number and a 32-bit number to generate the +// hash. +func hashRoute(srcAddr, dstAddr tcpip.Address, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 { + a := addressToUint32(srcAddr) + b := addressToUint32(dstAddr) return hash.Hash3Words(a, b, uint32(protocol), hashIV) } -// NewProtocol returns an IPv4 network protocol. -func NewProtocol(s *stack.Stack) stack.NetworkProtocol { +// Options holds options to configure a new protocol. +type Options struct { + // IGMP holds options for IGMP. + IGMP IGMPOptions +} + +// NewProtocolWithOptions returns an IPv4 network protocol. +func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory { ids := make([]uint32, buckets) // Randomly initialize hashIV and the ids. @@ -942,15 +1064,24 @@ func NewProtocol(s *stack.Stack) stack.NetworkProtocol { } hashIV := r[buckets] - return &protocol{ - stack: s, - ids: ids, - hashIV: hashIV, - defaultTTL: DefaultTTL, - fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock()), + return func(s *stack.Stack) stack.NetworkProtocol { + p := &protocol{ + stack: s, + ids: ids, + hashIV: hashIV, + defaultTTL: DefaultTTL, + options: opts, + } + p.fragmentation = fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p) + return p } } +// NewProtocol is equivalent to NewProtocolWithOptions with an empty Options. +func NewProtocol(s *stack.Stack) stack.NetworkProtocol { + return NewProtocolWithOptions(Options{})(s) +} + func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeader header.IPv4) (*stack.PacketBuffer, bool) { fragPkt, offset, copied, more := pf.BuildNextFragment() fragPkt.NetworkProtocolNumber = ProtocolNumber @@ -1063,6 +1194,12 @@ func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Addres } pointer := tsOpt.Pointer() + // RFC 791 page 22 states: "The smallest legal value is 5." + // Since the pointer is 1 based, and the header is 4 bytes long the + // pointer must point beyond the header therefore 4 or less is bad. + if pointer <= header.IPv4OptionTimestampHdrLength { + return header.IPv4OptTSPointerOffset, errIPv4TimestampOptInvalidPointer + } // To simplify processing below, base further work on the array of timestamps // beyond the header, rather than on the whole option. Also to aid // calculations set 'nextSlot' to be 0 based as in the packet it is 1 based. @@ -1149,7 +1286,15 @@ func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Ad return header.IPv4OptionLengthOffset, errIPv4RecordRouteOptInvalidLength } - nextSlot := rrOpt.Pointer() - 1 // Pointer is 1 based. + pointer := rrOpt.Pointer() + // RFC 791 page 20 states: + // The pointer is relative to this option, and the + // smallest legal value for the pointer is 4. + // Since the pointer is 1 based, and the header is 3 bytes long the + // pointer must point beyond the header therefore 3 or less is bad. + if pointer <= header.IPv4OptionRecordRouteHdrLength { + return header.IPv4OptRRPointerOffset, errIPv4RecordRouteOptInvalidPointer + } // RFC 791 page 21 says // If the route data area is already full (the pointer exceeds the @@ -1164,14 +1309,14 @@ func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Ad // do this (as do most implementations). It is probable that the inclusion // of these words is a copy/paste error from the timestamp option where // there are two failure reasons given. - if nextSlot >= optlen { + if pointer > optlen { return 0, nil } // The data area isn't full but there isn't room for a new entry. // Either Length or Pointer could be bad. We must select Pointer for Linux - // compatibility, even if only the length is bad. - if nextSlot+header.IPv4AddressSize > optlen { + // compatibility, even if only the length is bad. NB. pointer is 1 based. + if pointer+header.IPv4AddressSize > optlen+1 { if false { // This is what we would do if we were not being Linux compatible. // Check for bad pointer or length value. Must be a multiple of 4 after diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index c6e565455..1c4919b1e 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -15,9 +15,11 @@ package ipv4_test import ( + "bytes" "context" "encoding/hex" "fmt" + "io/ioutil" "math" "net" "testing" @@ -103,6 +105,163 @@ func TestExcludeBroadcast(t *testing.T) { }) } +func TestForwarding(t *testing.T) { + const ( + nicID1 = 1 + nicID2 = 2 + randomSequence = 123 + randomIdent = 42 + ) + + ipv4Addr1 := tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("10.0.0.1").To4()), + PrefixLen: 8, + } + ipv4Addr2 := tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("11.0.0.1").To4()), + PrefixLen: 8, + } + remoteIPv4Addr1 := tcpip.Address(net.ParseIP("10.0.0.2").To4()) + remoteIPv4Addr2 := tcpip.Address(net.ParseIP("11.0.0.2").To4()) + + tests := []struct { + name string + TTL uint8 + expectErrorICMP bool + }{ + { + name: "TTL of zero", + TTL: 0, + expectErrorICMP: true, + }, + { + name: "TTL of one", + TTL: 1, + expectErrorICMP: false, + }, + { + name: "TTL of two", + TTL: 2, + expectErrorICMP: false, + }, + { + name: "Max TTL", + TTL: math.MaxUint8, + expectErrorICMP: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, + }) + // We expect at most a single packet in response to our ICMP Echo Request. + e1 := channel.New(1, ipv4.MaxTotalSize, "") + if err := s.CreateNIC(nicID1, e1); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } + ipv4ProtoAddr1 := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr1} + if err := s.AddProtocolAddress(nicID1, ipv4ProtoAddr1); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID1, ipv4ProtoAddr1, err) + } + + e2 := channel.New(1, ipv4.MaxTotalSize, "") + if err := s.CreateNIC(nicID2, e2); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) + } + ipv4ProtoAddr2 := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr2} + if err := s.AddProtocolAddress(nicID2, ipv4ProtoAddr2); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID2, ipv4ProtoAddr2, err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: ipv4Addr1.Subnet(), + NIC: nicID1, + }, + { + Destination: ipv4Addr2.Subnet(), + NIC: nicID2, + }, + }) + + if err := s.SetForwarding(header.IPv4ProtocolNumber, true); err != nil { + t.Fatalf("SetForwarding(%d, true): %s", header.IPv4ProtocolNumber, err) + } + + totalLen := uint16(header.IPv4MinimumSize + header.ICMPv4MinimumSize) + hdr := buffer.NewPrependable(int(totalLen)) + icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + icmp.SetIdent(randomIdent) + icmp.SetSequence(randomSequence) + icmp.SetType(header.ICMPv4Echo) + icmp.SetCode(header.ICMPv4UnusedCode) + icmp.SetChecksum(0) + icmp.SetChecksum(^header.Checksum(icmp, 0)) + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + TotalLength: totalLen, + Protocol: uint8(header.ICMPv4ProtocolNumber), + TTL: test.TTL, + SrcAddr: remoteIPv4Addr1, + DstAddr: remoteIPv4Addr2, + }) + ip.SetChecksum(0) + ip.SetChecksum(^ip.CalculateChecksum()) + requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + }) + e1.InjectInbound(header.IPv4ProtocolNumber, requestPkt) + + if test.expectErrorICMP { + reply, ok := e1.Read() + if !ok { + t.Fatal("expected ICMP TTL Exceeded packet through incoming NIC") + } + + checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())), + checker.SrcAddr(ipv4Addr1.Address), + checker.DstAddr(remoteIPv4Addr1), + checker.TTL(ipv4.DefaultTTL), + checker.ICMPv4( + checker.ICMPv4Checksum(), + checker.ICMPv4Type(header.ICMPv4TimeExceeded), + checker.ICMPv4Code(header.ICMPv4TTLExceeded), + checker.ICMPv4Payload([]byte(hdr.View())), + ), + ) + + if n := e2.Drain(); n != 0 { + t.Fatalf("got e2.Drain() = %d, want = 0", n) + } + } else { + reply, ok := e2.Read() + if !ok { + t.Fatal("expected ICMP Echo packet through outgoing NIC") + } + + checker.IPv4(t, header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())), + checker.SrcAddr(remoteIPv4Addr1), + checker.DstAddr(remoteIPv4Addr2), + checker.TTL(test.TTL-1), + checker.ICMPv4( + checker.ICMPv4Checksum(), + checker.ICMPv4Type(header.ICMPv4Echo), + checker.ICMPv4Code(header.ICMPv4UnusedCode), + checker.ICMPv4Payload(nil), + ), + ) + + if n := e1.Drain(); n != 0 { + t.Fatalf("got e1.Drain() = %d, want = 0", n) + } + } + }) + } +} + // TestIPv4Sanity sends IP/ICMP packets with various problems to the stack and // checks the response. func TestIPv4Sanity(t *testing.T) { @@ -319,7 +478,7 @@ func TestIPv4Sanity(t *testing.T) { 68, 7, 5, 0, // ^ ^ Linux points here which is wrong. // | Not a multiple of 4 - 1, 2, 3, + 1, 2, 3, 0, }, shouldFail: true, expectErrorICMP: true, @@ -398,6 +557,56 @@ func TestIPv4Sanity(t *testing.T) { }, }, { + // Timestamp pointer uses one based counting so 0 is invalid. + name: "timestamp pointer invalid", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 68, 8, 0, 0x00, + // ^ 0 instead of 5 or more. + 0, 0, 0, 0, + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 2, + }, + { + // Timestamp pointer cannot be less than 5. It must point past the header + // which is 4 bytes. (1 based counting) + name: "timestamp pointer too small by 1", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 68, 8, header.IPv4OptionTimestampHdrLength, 0x00, + // ^ header is 4 bytes, so 4 should fail. + 0, 0, 0, 0, + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 2, + }, + { + name: "valid timestamp pointer", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 68, 8, header.IPv4OptionTimestampHdrLength + 1, 0x00, + // ^ header is 4 bytes, so 5 should succeed. + 0, 0, 0, 0, + }, + replyOptions: header.IPv4Options{ + 68, 8, 9, 0x00, + 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock + }, + }, + { // Needs 8 bytes for a type 1 timestamp but there are only 4 free. name: "bad timer element alignment", maxTotalLength: ipv4.MaxTotalSize, @@ -528,7 +737,61 @@ func TestIPv4Sanity(t *testing.T) { }, }, { - // Confirm linux bug for bug compatibility. + // Pointer uses one based counting so 0 is invalid. + name: "record route pointer zero", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 7, 8, 0, // 3 byte header + 0, 0, 0, 0, + 0, + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 2, + }, + { + // Pointer must be 4 or more as it must point past the 3 byte header + // using 1 based counting. 3 should fail. + name: "record route pointer too small by 1", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 7, 8, header.IPv4OptionRecordRouteHdrLength, // 3 byte header + 0, 0, 0, 0, + 0, + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 2, + }, + { + // Pointer must be 4 or more as it must point past the 3 byte header + // using 1 based counting. Check 4 passes. (Duplicates "single + // record route with room") + name: "valid record route pointer", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: header.IPv4Options{ + 7, 7, header.IPv4OptionRecordRouteHdrLength + 1, // 3 byte header + 0, 0, 0, 0, + 0, + }, + replyOptions: header.IPv4Options{ + 7, 7, 8, // 3 byte header + 192, 168, 1, 58, // New IP Address. + 0, // padding to multiple of 4 bytes. + }, + }, + { + // Confirm Linux bug for bug compatibility. // Linux returns slot 22 but the error is in slot 21. name: "multiple record route with not enough room", maxTotalLength: ipv4.MaxTotalSize, @@ -599,9 +862,12 @@ func TestIPv4Sanity(t *testing.T) { }, }) - ipHeaderLength := header.IPv4MinimumSize + test.options.AllocationSize() + if len(test.options)%4 != 0 { + t.Fatalf("options must be aligned to 32 bits, invalid test options: %x (len=%d)", test.options, len(test.options)) + } + ipHeaderLength := header.IPv4MinimumSize + len(test.options) if ipHeaderLength > header.IPv4MaximumHeaderSize { - t.Fatalf("too many bytes in options: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize) + t.Fatalf("IP header length too large: got = %d, want <= %d ", ipHeaderLength, header.IPv4MaximumHeaderSize) } totalLen := uint16(ipHeaderLength + header.ICMPv4MinimumSize) hdr := buffer.NewPrependable(int(totalLen)) @@ -618,16 +884,26 @@ func TestIPv4Sanity(t *testing.T) { if test.maxTotalLength < totalLen { totalLen = test.maxTotalLength } + ip.Encode(&header.IPv4Fields{ TotalLength: totalLen, Protocol: test.transportProtocol, TTL: test.TTL, SrcAddr: remoteIPv4Addr, DstAddr: ipv4Addr.Address, - Options: test.options, }) if test.headerLength != 0 { ip.SetHeaderLength(test.headerLength) + } else { + // Set the calculated header length, since we may manually add options. + ip.SetHeaderLength(uint8(ipHeaderLength)) + } + if len(test.options) != 0 { + // Copy options manually. We do not use Encode for options so we can + // verify malformed options with handcrafted payloads. + if want, got := copy(ip.Options(), test.options), len(test.options); want != got { + t.Fatalf("got copy(ip.Options(), test.options) = %d, want = %d", got, want) + } } ip.SetChecksum(0) ipHeaderChecksum := ip.CalculateChecksum() @@ -2049,6 +2325,28 @@ func TestReceiveFragments(t *testing.T) { }, expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2}, }, + { + name: "Two fragments with MF flag reassembled into a maximum UDP packet", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 0, + payload: ipv4Payload4Addr1ToAddr2[:65512], + }, + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 65512, + payload: ipv4Payload4Addr1ToAddr2[65512:], + }, + }, + expectedPayloads: nil, + }, } for _, test := range tests { @@ -2112,18 +2410,26 @@ func TestReceiveFragments(t *testing.T) { t.Errorf("got UDP Rx Packets = %d, want = %d", got, want) } + const rcvSize = 65536 // Account for reassembled packets. for i, expectedPayload := range test.expectedPayloads { - gotPayload, _, err := ep.Read(nil) + var buf bytes.Buffer + result, err := ep.Read(&buf, rcvSize, tcpip.ReadOptions{}) if err != nil { - t.Fatalf("(i=%d) Read(nil): %s", i, err) + t.Fatalf("(i=%d) Read: %s", i, err) } - if diff := cmp.Diff(buffer.View(expectedPayload), gotPayload); diff != "" { + if diff := cmp.Diff(tcpip.ReadResult{ + Count: len(expectedPayload), + Total: len(expectedPayload), + }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" { + t.Errorf("(i=%d) ep.Read: unexpected result (-want +got):\n%s", i, diff) + } + if diff := cmp.Diff(expectedPayload, buf.Bytes()); diff != "" { t.Errorf("(i=%d) got UDP payload mismatch (-want +got):\n%s", i, diff) } } - if gotPayload, _, err := ep.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("(last) got Read(nil) = (%x, _, %v), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock) + if res, err := ep.Read(ioutil.Discard, rcvSize, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock { + t.Fatalf("(last) got Read = (%v, %v), want = (_, %s)", res, err, tcpip.ErrWouldBlock) } }) } @@ -2242,7 +2548,7 @@ func TestWriteStats(t *testing.T) { test.setup(t, rt.Stack()) - nWritten, _ := writer.writePackets(&rt, pkts) + nWritten, _ := writer.writePackets(rt, pkts) if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent { t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent) @@ -2259,7 +2565,7 @@ func TestWriteStats(t *testing.T) { } } -func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route { +func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, }) @@ -2441,9 +2747,6 @@ func TestPacketQueing(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) } - if err := s.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { - t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, arp.ProtocolNumber, arp.ProtocolAddress, err) - } if err := s.AddProtocolAddress(nicID, host1IPv4Addr); err != nil { t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, host1IPv4Addr, err) } diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index 0ac24a6fb..afa45aefe 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -8,6 +8,7 @@ go_library( "dhcpv6configurationfromndpra_string.go", "icmp.go", "ipv6.go", + "mld.go", "ndp.go", ], visibility = ["//visibility:public"], @@ -19,6 +20,7 @@ go_library( "//pkg/tcpip/header/parse", "//pkg/tcpip/network/fragmentation", "//pkg/tcpip/network/hash", + "//pkg/tcpip/network/ip", "//pkg/tcpip/stack", ], ) @@ -49,3 +51,19 @@ go_test( "@com_github_google_go_cmp//cmp:go_default_library", ], ) + +go_test( + name = "ipv6_x_test", + size = "small", + srcs = ["mld_test.go"], + deps = [ + ":ipv6", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/checker", + "//pkg/tcpip/faketime", + "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 8502b848c..6ee162713 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -126,8 +126,8 @@ func getTargetLinkAddr(it header.NDPOptionIterator) (tcpip.LinkAddress, bool) { func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { stats := e.protocol.stack.Stats().ICMP - sent := stats.V6PacketsSent - received := stats.V6PacketsReceived + sent := stats.V6.PacketsSent + received := stats.V6.PacketsReceived // TODO(gvisor.dev/issue/170): ICMP packets don't have their // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a // full explanation. @@ -163,7 +163,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { } // TODO(b/112892170): Meaningfully handle all ICMP types. - switch h.Type() { + switch icmpType := h.Type(); icmpType { case header.ICMPv6PacketTooBig: received.PacketTooBig.Increment() hdr, ok := pkt.Data.PullUp(header.ICMPv6PacketTooBigMinimumSize) @@ -358,7 +358,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber packet := header.ICMPv6(pkt.TransportHeader().Push(neighborAdvertSize)) packet.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(packet.NDPPayload()) + na := header.NDPNeighborAdvert(packet.MessageBody()) // As per RFC 4861 section 7.2.4: // @@ -644,8 +644,39 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { return } + case header.ICMPv6MulticastListenerQuery, header.ICMPv6MulticastListenerReport, header.ICMPv6MulticastListenerDone: + switch icmpType { + case header.ICMPv6MulticastListenerQuery: + received.MulticastListenerQuery.Increment() + case header.ICMPv6MulticastListenerReport: + received.MulticastListenerReport.Increment() + case header.ICMPv6MulticastListenerDone: + received.MulticastListenerDone.Increment() + default: + panic(fmt.Sprintf("unrecognized MLD message = %d", icmpType)) + } + + if pkt.Data.Size()-header.ICMPv6HeaderSize < header.MLDMinimumSize { + received.Invalid.Increment() + return + } + + switch icmpType { + case header.ICMPv6MulticastListenerQuery: + e.mu.Lock() + e.mu.mld.handleMulticastListenerQuery(header.MLD(payload.ToView())) + e.mu.Unlock() + case header.ICMPv6MulticastListenerReport: + e.mu.Lock() + e.mu.mld.handleMulticastListenerReport(header.MLD(payload.ToView())) + e.mu.Unlock() + case header.ICMPv6MulticastListenerDone: + default: + panic(fmt.Sprintf("unrecognized MLD message = %d", icmpType)) + } + default: - received.Invalid.Increment() + received.Unrecognized.Increment() } } @@ -681,12 +712,12 @@ func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remot pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber packet := header.ICMPv6(pkt.TransportHeader().Push(neighborSolicitSize)) packet.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(packet.NDPPayload()) + ns := header.NDPNeighborSolicit(packet.MessageBody()) ns.SetTargetAddress(targetAddr) ns.Options().Serialize(optsSerializer) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - stat := p.stack.Stats().ICMP.V6PacketsSent + stat := p.stack.Stats().ICMP.V6.PacketsSent if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, @@ -750,6 +781,12 @@ type icmpReasonPortUnreachable struct{} func (*icmpReasonPortUnreachable) isICMPReason() {} +// icmpReasonHopLimitExceeded is an error where a packet's hop limit exceeded in +// transit to its final destination, as per RFC 4443 section 3.3. +type icmpReasonHopLimitExceeded struct{} + +func (*icmpReasonHopLimitExceeded) isICMPReason() {} + // icmpReasonReassemblyTimeout is an error where insufficient fragments are // received to complete reassembly of a packet within a configured time after // the reception of the first-arriving fragment of that packet. @@ -790,29 +827,49 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi allowResponseToMulticast = reason.respondToMulticast } - if (!allowResponseToMulticast && header.IsV6MulticastAddress(origIPHdrDst)) || origIPHdrSrc == header.IPv6Any { + isOrigDstMulticast := header.IsV6MulticastAddress(origIPHdrDst) + if (!allowResponseToMulticast && isOrigDstMulticast) || origIPHdrSrc == header.IPv6Any { return nil } + // If we hit a Hop Limit Exceeded error, then we know we are operating as a + // router. As per RFC 4443 section 3.3: + // + // If a router receives a packet with a Hop Limit of zero, or if a + // router decrements a packet's Hop Limit to zero, it MUST discard the + // packet and originate an ICMPv6 Time Exceeded message with Code 0 to + // the source of the packet. This indicates either a routing loop or + // too small an initial Hop Limit value. + // + // If we are operating as a router, do not use the packet's destination + // address as the response's source address as we should not own the + // destination address of a packet we are forwarding. + // + // If the packet was originally destined to a multicast address, then do not + // use the packet's destination address as the source for the response ICMP + // packet as "multicast addresses must not be used as source addresses in IPv6 + // packets", as per RFC 4291 section 2.7. + localAddr := origIPHdrDst + if _, ok := reason.(*icmpReasonHopLimitExceeded); ok || isOrigDstMulticast { + localAddr = "" + } // Even if we were able to receive a packet from some remote, we may not have // a route to it - the remote may be blocked via routing rules. We must always // consult our routing table and find a route to the remote before sending any // packet. - route, err := p.stack.FindRoute(pkt.NICID, origIPHdrDst, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */) + route, err := p.stack.FindRoute(pkt.NICID, localAddr, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */) if err != nil { return err } defer route.Release() stats := p.stack.Stats().ICMP - sent := stats.V6PacketsSent + sent := stats.V6.PacketsSent if !p.stack.AllowICMPMessage() { sent.RateLimited.Increment() return nil } - network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View() - if pkt.TransportProtocolNumber == header.ICMPv6ProtocolNumber { // TODO(gvisor.dev/issues/3810): Sort this out when ICMP headers are stored. // Unfortunately at this time ICMP Packets do not have a transport @@ -830,6 +887,8 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi } } + network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View() + // As per RFC 4443 section 2.4 // // (c) Every ICMPv6 error message (type < 128) MUST include @@ -873,6 +932,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi icmpHdr.SetType(header.ICMPv6DstUnreachable) icmpHdr.SetCode(header.ICMPv6PortUnreachable) counter = sent.DstUnreachable + case *icmpReasonHopLimitExceeded: + icmpHdr.SetType(header.ICMPv6TimeExceeded) + icmpHdr.SetCode(header.ICMPv6HopLimitExceeded) + counter = sent.TimeExceeded case *icmpReasonReassemblyTimeout: icmpHdr.SetType(header.ICMPv6TimeExceeded) icmpHdr.SetCode(header.ICMPv6ReassemblyTimeout) @@ -896,3 +959,16 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpi counter.Increment() return nil } + +// OnReassemblyTimeout implements fragmentation.TimeoutHandler. +func (p *protocol) OnReassemblyTimeout(pkt *stack.PacketBuffer) { + // OnReassemblyTimeout sends a Time Exceeded Message as per RFC 2460 Section + // 4.5: + // + // If the first fragment (i.e., the one with a Fragment Offset of zero) has + // been received, an ICMP Time Exceeded -- Fragment Reassembly Time Exceeded + // message should be sent to the source of that fragment. + if pkt != nil { + p.returnError(&icmpReasonReassemblyTimeout{}, pkt) + } +} diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 76013daa1..bbce1ef78 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -144,11 +144,14 @@ func (*testInterface) Enabled() bool { return true } +func (*testInterface) Promiscuous() bool { + return false +} + func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - r := stack.Route{ - NetProto: protocol, - RemoteLinkAddress: remoteLinkAddr, - } + var r stack.Route + r.NetProto = protocol + r.ResolveWith(remoteLinkAddr) return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt) } @@ -174,13 +177,8 @@ func TestICMPCounts(t *testing.T) { TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, UseNeighborCache: test.useNeighborCache, }) - { - if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { - t.Fatalf("CreateNIC(_, _) = %s", err) - } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) - } + if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(_, _) = %s", err) } { subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) @@ -206,11 +204,16 @@ func TestICMPCounts(t *testing.T) { t.Fatalf("ep.Enable(): %s", err) } - r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err) + addressableEndpoint, ok := ep.(stack.AddressableEndpoint) + if !ok { + t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint") + } + addr := lladdr0.WithPrefix() + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + } else { + ep.DecRef() } - defer r.Release() var tllData [header.NDPLinkLayerAddressSize]byte header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ @@ -267,6 +270,22 @@ func TestICMPCounts(t *testing.T) { typ: header.ICMPv6RedirectMsg, size: header.ICMPv6MinimumSize, }, + { + typ: header.ICMPv6MulticastListenerQuery, + size: header.MLDMinimumSize + header.ICMPv6HeaderSize, + }, + { + typ: header.ICMPv6MulticastListenerReport, + size: header.MLDMinimumSize + header.ICMPv6HeaderSize, + }, + { + typ: header.ICMPv6MulticastListenerDone, + size: header.MLDMinimumSize + header.ICMPv6HeaderSize, + }, + { + typ: 255, /* Unrecognized */ + size: 50, + }, } handleIPv6Payload := func(icmp header.ICMPv6) { @@ -276,13 +295,12 @@ func TestICMPCounts(t *testing.T) { }) ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(icmp)), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: r.LocalAddress, - DstAddr: r.RemoteAddress, + PayloadLength: uint16(len(icmp)), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) - r.PopulatePacketInfo(pkt) ep.HandlePacket(pkt) } @@ -290,7 +308,7 @@ func TestICMPCounts(t *testing.T) { icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) copy(icmp[typ.size:], typ.extraData) icmp.SetType(typ.typ) - icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView())) + icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView())) handleIPv6Payload(icmp) } @@ -298,7 +316,7 @@ func TestICMPCounts(t *testing.T) { // Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented. handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize))) - icmpv6Stats := s.Stats().ICMP.V6PacketsReceived + icmpv6Stats := s.Stats().ICMP.V6.PacketsReceived visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) { if got, want := s.Value(), uint64(1); got != want { t.Errorf("got %s = %d, want = %d", name, got, want) @@ -317,13 +335,8 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, UseNeighborCache: true, }) - { - if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { - t.Fatalf("CreateNIC(_, _) = %s", err) - } - if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) - } + if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(_, _) = %s", err) } { subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) @@ -349,11 +362,16 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { t.Fatalf("ep.Enable(): %s", err) } - r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err) + addressableEndpoint, ok := ep.(stack.AddressableEndpoint) + if !ok { + t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint") + } + addr := lladdr0.WithPrefix() + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + } else { + ep.DecRef() } - defer r.Release() var tllData [header.NDPLinkLayerAddressSize]byte header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ @@ -410,6 +428,22 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { typ: header.ICMPv6RedirectMsg, size: header.ICMPv6MinimumSize, }, + { + typ: header.ICMPv6MulticastListenerQuery, + size: header.MLDMinimumSize + header.ICMPv6HeaderSize, + }, + { + typ: header.ICMPv6MulticastListenerReport, + size: header.MLDMinimumSize + header.ICMPv6HeaderSize, + }, + { + typ: header.ICMPv6MulticastListenerDone, + size: header.MLDMinimumSize + header.ICMPv6HeaderSize, + }, + { + typ: 255, /* Unrecognized */ + size: 50, + }, } handleIPv6Payload := func(icmp header.ICMPv6) { @@ -419,13 +453,12 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { }) ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(icmp)), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: r.LocalAddress, - DstAddr: r.RemoteAddress, + PayloadLength: uint16(len(icmp)), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) - r.PopulatePacketInfo(pkt) ep.HandlePacket(pkt) } @@ -433,7 +466,7 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) copy(icmp[typ.size:], typ.extraData) icmp.SetType(typ.typ) - icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView())) + icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView())) handleIPv6Payload(icmp) } @@ -441,7 +474,7 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { // Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented. handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize))) - icmpv6Stats := s.Stats().ICMP.V6PacketsReceived + icmpv6Stats := s.Stats().ICMP.V6.PacketsReceived visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) { if got, want := s.Value(), uint64(1); got != want { t.Errorf("got %s = %d, want = %d", name, got, want) @@ -566,7 +599,7 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header. return } - if len(args.remoteLinkAddr) != 0 && args.remoteLinkAddr != pi.Route.RemoteLinkAddress { + if len(args.remoteLinkAddr) != 0 && pi.Route.RemoteLinkAddress != args.remoteLinkAddr { t.Errorf("got remote link address = %s, want = %s", pi.Route.RemoteLinkAddress, args.remoteLinkAddr) } @@ -819,11 +852,11 @@ func TestICMPChecksumValidationSimple(t *testing.T) { } ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(icmp)), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, + PayloadLength: uint16(len(icmp)), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}), @@ -831,7 +864,7 @@ func TestICMPChecksumValidationSimple(t *testing.T) { e.InjectInbound(ProtocolNumber, pkt) } - stats := s.Stats().ICMP.V6PacketsReceived + stats := s.Stats().ICMP.V6.PacketsReceived invalid := stats.Invalid routerOnly := stats.RouterOnlyPacketsDroppedByHost typStat := typ.statCounter(stats) @@ -896,11 +929,11 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { errorICMPBody := func(view buffer.View) { ip := header.IPv6(view) ip.Encode(&header.IPv6Fields{ - PayloadLength: simpleBodySize, - NextHeader: 10, - HopLimit: 20, - SrcAddr: lladdr0, - DstAddr: lladdr1, + PayloadLength: simpleBodySize, + TransportProtocol: 10, + HopLimit: 20, + SrcAddr: lladdr0, + DstAddr: lladdr1, }) simpleBody(view[header.IPv6MinimumSize:]) } @@ -1014,11 +1047,11 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(icmpSize), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, + PayloadLength: uint16(icmpSize), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), @@ -1026,7 +1059,7 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { e.InjectInbound(ProtocolNumber, pkt) } - stats := s.Stats().ICMP.V6PacketsReceived + stats := s.Stats().ICMP.V6.PacketsReceived invalid := stats.Invalid typStat := typ.statCounter(stats) @@ -1074,11 +1107,11 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { errorICMPBody := func(view buffer.View) { ip := header.IPv6(view) ip.Encode(&header.IPv6Fields{ - PayloadLength: simpleBodySize, - NextHeader: 10, - HopLimit: 20, - SrcAddr: lladdr0, - DstAddr: lladdr1, + PayloadLength: simpleBodySize, + TransportProtocol: 10, + HopLimit: 20, + SrcAddr: lladdr0, + DstAddr: lladdr1, }) simpleBody(view[header.IPv6MinimumSize:]) } @@ -1193,11 +1226,11 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(size + payloadSize), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, + PayloadLength: uint16(size + payloadSize), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}), @@ -1205,7 +1238,7 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { e.InjectInbound(ProtocolNumber, pkt) } - stats := s.Stats().ICMP.V6PacketsReceived + stats := s.Stats().ICMP.V6.PacketsReceived invalid := stats.Invalid typStat := typ.statCounter(stats) @@ -1411,11 +1444,11 @@ func TestPacketQueing(t *testing.T) { payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(udp.ProtocolNumber), - HopLimit: DefaultTTL, - SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, - DstAddr: host1IPv6Addr.AddressWithPrefix.Address, + PayloadLength: uint16(payloadLength), + TransportProtocol: udp.ProtocolNumber, + HopLimit: DefaultTTL, + SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, + DstAddr: host1IPv6Addr.AddressWithPrefix.Address, }) e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), @@ -1453,11 +1486,11 @@ func TestPacketQueing(t *testing.T) { pkt.SetChecksum(header.ICMPv6Checksum(pkt, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, buffer.VectorisedView{})) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: header.ICMPv6MinimumSize, - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: DefaultTTL, - SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, - DstAddr: host1IPv6Addr.AddressWithPrefix.Address, + PayloadLength: header.ICMPv6MinimumSize, + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: DefaultTTL, + SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, + DstAddr: host1IPv6Addr.AddressWithPrefix.Address, }) e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), @@ -1502,7 +1535,7 @@ func TestPacketQueing(t *testing.T) { } s.SetRouteTable([]tcpip.Route{ - tcpip.Route{ + { Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), NIC: nicID, }, @@ -1541,7 +1574,7 @@ func TestPacketQueing(t *testing.T) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize) pkt := header.ICMPv6(hdr.Prepend(naSize)) pkt.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(pkt.NDPPayload()) + na := header.NDPNeighborAdvert(pkt.MessageBody()) na.SetSolicitedFlag(true) na.SetOverrideFlag(true) na.SetTargetAddress(host2IPv6Addr.AddressWithPrefix.Address) @@ -1552,11 +1585,11 @@ func TestPacketQueing(t *testing.T) { payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: header.NDPHopLimit, - SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, - DstAddr: host1IPv6Addr.AddressWithPrefix.Address, + PayloadLength: uint16(payloadLength), + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: header.NDPHopLimit, + SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, + DstAddr: host1IPv6Addr.AddressWithPrefix.Address, }) e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), @@ -1590,7 +1623,7 @@ func TestCallsToNeighborCache(t *testing.T) { nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize icmp := header.ICMPv6(buffer.NewView(nsSize)) icmp.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ns := header.NDPNeighborSolicit(icmp.MessageBody()) ns.SetTargetAddress(lladdr0) return icmp }, @@ -1610,7 +1643,7 @@ func TestCallsToNeighborCache(t *testing.T) { nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize icmp := header.ICMPv6(buffer.NewView(nsSize)) icmp.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ns := header.NDPNeighborSolicit(icmp.MessageBody()) ns.SetTargetAddress(lladdr0) ns.Options().Serialize(header.NDPOptionsSerializer{ header.NDPSourceLinkLayerAddressOption(linkAddr1), @@ -1627,7 +1660,7 @@ func TestCallsToNeighborCache(t *testing.T) { nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize icmp := header.ICMPv6(buffer.NewView(nsSize)) icmp.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ns := header.NDPNeighborSolicit(icmp.MessageBody()) ns.SetTargetAddress(lladdr0) return icmp }, @@ -1643,7 +1676,7 @@ func TestCallsToNeighborCache(t *testing.T) { nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize icmp := header.ICMPv6(buffer.NewView(nsSize)) icmp.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ns := header.NDPNeighborSolicit(icmp.MessageBody()) ns.SetTargetAddress(lladdr0) ns.Options().Serialize(header.NDPOptionsSerializer{ header.NDPSourceLinkLayerAddressOption(linkAddr1), @@ -1660,7 +1693,7 @@ func TestCallsToNeighborCache(t *testing.T) { naSize := header.ICMPv6NeighborAdvertMinimumSize icmp := header.ICMPv6(buffer.NewView(naSize)) icmp.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na := header.NDPNeighborAdvert(icmp.MessageBody()) na.SetSolicitedFlag(true) na.SetOverrideFlag(false) na.SetTargetAddress(lladdr1) @@ -1681,7 +1714,7 @@ func TestCallsToNeighborCache(t *testing.T) { naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize icmp := header.ICMPv6(buffer.NewView(naSize)) icmp.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na := header.NDPNeighborAdvert(icmp.MessageBody()) na.SetSolicitedFlag(true) na.SetOverrideFlag(false) na.SetTargetAddress(lladdr1) @@ -1700,7 +1733,7 @@ func TestCallsToNeighborCache(t *testing.T) { naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize icmp := header.ICMPv6(buffer.NewView(naSize)) icmp.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na := header.NDPNeighborAdvert(icmp.MessageBody()) na.SetSolicitedFlag(false) na.SetOverrideFlag(false) na.SetTargetAddress(lladdr1) @@ -1720,7 +1753,7 @@ func TestCallsToNeighborCache(t *testing.T) { naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize icmp := header.ICMPv6(buffer.NewView(naSize)) icmp.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na := header.NDPNeighborAdvert(icmp.MessageBody()) na.SetSolicitedFlag(false) na.SetOverrideFlag(false) na.SetTargetAddress(lladdr1) @@ -1775,30 +1808,31 @@ func TestCallsToNeighborCache(t *testing.T) { t.Fatalf("ep.Enable(): %s", err) } - r, err := s.FindRoute(nicID, lladdr0, test.source, ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err) + addressableEndpoint, ok := ep.(stack.AddressableEndpoint) + if !ok { + t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint") + } + addr := lladdr0.WithPrefix() + if ep, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + } else { + ep.DecRef() } - defer r.Release() - - // TODO(gvisor.dev/issue/4517): Remove the need for this manual patch. - r.LocalAddress = test.destination icmp := test.createPacket() - icmp.SetChecksum(header.ICMPv6Checksum(icmp, r.RemoteAddress, r.LocalAddress, buffer.VectorisedView{})) + icmp.SetChecksum(header.ICMPv6Checksum(icmp, test.source, test.destination, buffer.VectorisedView{})) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: header.IPv6MinimumSize, Data: buffer.View(icmp).ToVectorisedView(), }) ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(icmp)), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: r.RemoteAddress, - DstAddr: r.LocalAddress, + PayloadLength: uint16(len(icmp)), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: header.NDPHopLimit, + SrcAddr: test.source, + DstAddr: test.destination, }) - r.PopulatePacketInfo(pkt) ep.HandlePacket(pkt) // Confirm the endpoint calls the correct NUDHandler method. diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 0526190cc..f2018d073 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -19,6 +19,7 @@ import ( "encoding/binary" "fmt" "hash/fnv" + "math" "sort" "sync/atomic" "time" @@ -34,7 +35,9 @@ import ( ) const ( + // ReassembleTimeout controls how long a fragment will be held. // As per RFC 8200 section 4.5: + // // If insufficient fragments are received to complete reassembly of a packet // within 60 seconds of the reception of the first-arriving fragment of that // packet, reassembly of that packet must be abandoned. @@ -58,6 +61,108 @@ const ( buckets = 2048 ) +// policyTable is the default policy table defined in RFC 6724 section 2.1. +// +// A more human-readable version: +// +// Prefix Precedence Label +// ::1/128 50 0 +// ::/0 40 1 +// ::ffff:0:0/96 35 4 +// 2002::/16 30 2 +// 2001::/32 5 5 +// fc00::/7 3 13 +// ::/96 1 3 +// fec0::/10 1 11 +// 3ffe::/16 1 12 +// +// The table is sorted by prefix length so longest-prefix match can be easily +// achieved. +// +// We willingly left out ::/96, fec0::/10 and 3ffe::/16 since those prefix +// assignments are deprecated. +// +// As per RFC 4291 section 2.5.5.1 (for ::/96), +// +// The "IPv4-Compatible IPv6 address" is now deprecated because the +// current IPv6 transition mechanisms no longer use these addresses. +// New or updated implementations are not required to support this +// address type. +// +// As per RFC 3879 section 4 (for fec0::/10), +// +// This document formally deprecates the IPv6 site-local unicast prefix +// defined in [RFC3513], i.e., 1111111011 binary or FEC0::/10. +// +// As per RFC 3701 section 1 (for 3ffe::/16), +// +// As clearly stated in [TEST-NEW], the addresses for the 6bone are +// temporary and will be reclaimed in the future. It further states +// that all users of these addresses (within the 3FFE::/16 prefix) will +// be required to renumber at some time in the future. +// +// and section 2, +// +// Thus after the pTLA allocation cutoff date January 1, 2004, it is +// REQUIRED that no new 6bone 3FFE pTLAs be allocated. +// +// MUST NOT BE MODIFIED. +var policyTable = [...]struct { + subnet tcpip.Subnet + + label uint8 +}{ + // ::1/128 + { + subnet: header.IPv6Loopback.WithPrefix().Subnet(), + label: 0, + }, + // ::ffff:0:0/96 + { + subnet: header.IPv4MappedIPv6Subnet, + label: 4, + }, + // 2001::/32 (Teredo prefix as per RFC 4380 section 2.6). + { + subnet: tcpip.AddressWithPrefix{ + Address: "\x20\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + PrefixLen: 32, + }.Subnet(), + label: 5, + }, + // 2002::/16 (6to4 prefix as per RFC 3056 section 2). + { + subnet: tcpip.AddressWithPrefix{ + Address: "\x20\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + PrefixLen: 16, + }.Subnet(), + label: 2, + }, + // fc00::/7 (Unique local addresses as per RFC 4193 section 3.1). + { + subnet: tcpip.AddressWithPrefix{ + Address: "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + PrefixLen: 7, + }.Subnet(), + label: 13, + }, + // ::/0 + { + subnet: header.IPv6EmptySubnet, + label: 1, + }, +} + +func getLabel(addr tcpip.Address) uint8 { + for _, p := range policyTable { + if p.subnet.Contains(addr) { + return p.label + } + } + + panic(fmt.Sprintf("should have a label for address = %s", addr)) +} + var _ stack.GroupAddressableEndpoint = (*endpoint)(nil) var _ stack.AddressableEndpoint = (*endpoint)(nil) var _ stack.NetworkEndpoint = (*endpoint)(nil) @@ -83,6 +188,7 @@ type endpoint struct { addressableEndpointState stack.AddressableEndpointState ndp ndpState + mld mldState } } @@ -118,6 +224,45 @@ type OpaqueInterfaceIdentifierOptions struct { SecretKey []byte } +// onAddressAssignedLocked handles an address being assigned. +// +// Precondition: e.mu must be exclusively locked. +func (e *endpoint) onAddressAssignedLocked(addr tcpip.Address) { + // As per RFC 2710 section 3, + // + // All MLD messages described in this document are sent with a link-local + // IPv6 Source Address, ... + // + // If we just completed DAD for a link-local address, then attempt to send any + // queued MLD reports. Note, we may have sent reports already for some of the + // groups before we had a valid link-local address to use as the source for + // the MLD messages, but that was only so that MLD snooping switches are aware + // of our membership to groups - routers would not have handled those reports. + // + // As per RFC 3590 section 4, + // + // MLD Report and Done messages are sent with a link-local address as + // the IPv6 source address, if a valid address is available on the + // interface. If a valid link-local address is not available (e.g., one + // has not been configured), the message is sent with the unspecified + // address (::) as the IPv6 source address. + // + // Once a valid link-local address is available, a node SHOULD generate + // new MLD Report messages for all multicast addresses joined on the + // interface. + // + // Routers receiving an MLD Report or Done message with the unspecified + // address as the IPv6 source address MUST silently discard the packet + // without taking any action on the packets contents. + // + // Snooping switches MUST manage multicast forwarding state based on MLD + // Report and Done messages sent with the unspecified address as the + // IPv6 source address. + if header.IsV6LinkLocalAddress(addr) { + e.mu.mld.sendQueuedReports() + } +} + // InvalidateDefaultRouter implements stack.NDPEndpoint. func (e *endpoint) InvalidateDefaultRouter(rtr tcpip.Address) { e.mu.Lock() @@ -224,6 +369,12 @@ func (e *endpoint) Enable() *tcpip.Error { return nil } + // Groups may have been joined when the endpoint was disabled, or the + // endpoint may have left groups from the perspective of MLD when the + // endpoint was disabled. Either way, we need to let routers know to + // send us multicast traffic. + e.mu.mld.initializeAll() + // Join the IPv6 All-Nodes Multicast group if the stack is configured to // use IPv6. This is required to ensure that this node properly receives // and responds to the various NDP messages that are destined to the @@ -241,8 +392,10 @@ func (e *endpoint) Enable() *tcpip.Error { // (NDP NS) messages may be sent to the All-Nodes multicast group if the // source address of the NDP NS is the unspecified address, as per RFC 4861 // section 7.2.4. - if _, err := e.mu.addressableEndpointState.JoinGroup(header.IPv6AllNodesMulticastAddress); err != nil { - return err + if err := e.joinGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil { + // joinGroupLocked only returns an error if the group address is not a valid + // IPv6 multicast address. + panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", header.IPv6AllNodesMulticastAddress, err)) } // Perform DAD on the all the unicast IPv6 endpoints that are in the permanent @@ -251,7 +404,7 @@ func (e *endpoint) Enable() *tcpip.Error { // Addresses may have aleady completed DAD but in the time since the endpoint // was last enabled, other devices may have acquired the same addresses. var err *tcpip.Error - e.mu.addressableEndpointState.ReadOnly().ForEach(func(addressEndpoint stack.AddressEndpoint) bool { + e.mu.addressableEndpointState.ForEachEndpoint(func(addressEndpoint stack.AddressEndpoint) bool { addr := addressEndpoint.AddressWithPrefix().Address if !header.IsV6UnicastAddress(addr) { return true @@ -273,7 +426,7 @@ func (e *endpoint) Enable() *tcpip.Error { } // Do not auto-generate an IPv6 link-local address for loopback devices. - if e.protocol.autoGenIPv6LinkLocal && !e.nic.IsLoopback() { + if e.protocol.options.AutoGenLinkLocal && !e.nic.IsLoopback() { // The valid and preferred lifetime is infinite for the auto-generated // link-local address. e.mu.ndp.doSLAAC(header.IPv6LinkLocalPrefix.Subnet(), header.NDPInfiniteLifetime, header.NDPInfiniteLifetime) @@ -322,7 +475,7 @@ func (e *endpoint) Disable() { } func (e *endpoint) disableLocked() { - if !e.setEnabled(false) { + if !e.Enabled() { return } @@ -331,9 +484,17 @@ func (e *endpoint) disableLocked() { e.stopDADForPermanentAddressesLocked() // The endpoint may have already left the multicast group. - if _, err := e.mu.addressableEndpointState.LeaveGroup(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress { + if err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress { panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv6AllNodesMulticastAddress, err)) } + + // Leave groups from the perspective of MLD so that routers know that + // we are no longer interested in the group. + e.mu.mld.softLeaveAll() + + if !e.setEnabled(false) { + panic("should have only done work to disable the endpoint if it was enabled") + } } // stopDADForPermanentAddressesLocked stops DAD for all permaneent addresses. @@ -341,7 +502,7 @@ func (e *endpoint) disableLocked() { // Precondition: e.mu must be write locked. func (e *endpoint) stopDADForPermanentAddressesLocked() { // Stop DAD for all the tentative unicast addresses. - e.mu.addressableEndpointState.ReadOnly().ForEach(func(addressEndpoint stack.AddressEndpoint) bool { + e.mu.addressableEndpointState.ForEachEndpoint(func(addressEndpoint stack.AddressEndpoint) bool { if addressEndpoint.GetKind() != stack.PermanentTentative { return true } @@ -373,19 +534,27 @@ func (e *endpoint) MTU() uint32 { // MaxHeaderLength returns the maximum length needed by ipv6 headers (and // underlying protocols). func (e *endpoint) MaxHeaderLength() uint16 { + // TODO(gvisor.dev/issues/5035): The maximum header length returned here does + // not open the possibility for the caller to know about size required for + // extension headers. return e.nic.MaxHeaderLength() + header.IPv6MinimumSize } -func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) { - length := uint16(pkt.Size()) - ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) +func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, extensionHeaders header.IPv6ExtHdrSerializer) { + extHdrsLen := extensionHeaders.Length() + length := pkt.Size() + extensionHeaders.Length() + if length > math.MaxUint16 { + panic(fmt.Sprintf("IPv6 payload too large: %d, must be <= %d", length, math.MaxUint16)) + } + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen)) ip.Encode(&header.IPv6Fields{ - PayloadLength: length, - NextHeader: uint8(params.Protocol), - HopLimit: params.TTL, - TrafficClass: params.TOS, - SrcAddr: r.LocalAddress, - DstAddr: r.RemoteAddress, + PayloadLength: uint16(length), + TransportProtocol: params.Protocol, + HopLimit: params.TTL, + TrafficClass: params.TOS, + SrcAddr: srcAddr, + DstAddr: dstAddr, + ExtensionHeaders: extensionHeaders, }) pkt.NetworkProtocolNumber = ProtocolNumber } @@ -440,18 +609,14 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { - e.addIPHeader(r, pkt, params) - return e.writePacket(r, gso, pkt, params.Protocol) -} + e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* extensionHeaders */) -func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber) *tcpip.Error { // iptables filtering. All packets that reach here are locally // generated. nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - ipt := e.protocol.stack.IPTables() - if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok { // iptables is telling us to drop the packet. - r.Stats().IP.IPTablesOutputDropped.Increment() + e.protocol.stack.Stats().IP.IPTablesOutputDropped.Increment() return nil } @@ -467,24 +632,27 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { pkt := pkt.CloneToInbound() if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) - route.PopulatePacketInfo(pkt) // Since we rewrote the packet but it is being routed back to us, we can // safely assume the checksum is valid. pkt.RXTransportChecksumValidated = true - ep.HandlePacket(pkt) + ep.(*endpoint).handlePacket(pkt) } return nil } } + return e.writePacket(r, gso, pkt, params.Protocol, false /* headerIncluded */) +} + +func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber, headerIncluded bool) *tcpip.Error { if r.Loop&stack.PacketLoop != 0 { pkt := pkt.CloneToInbound() if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - loopedR := r.MakeLoopedRoute() - loopedR.PopulatePacketInfo(pkt) - loopedR.Release() - e.HandlePacket(pkt) + // If the packet was generated by the stack (not a raw/packet endpoint + // where a packet may be written with the header included), then we can + // safely assume the checksum is valid. + pkt.RXTransportChecksumValidated = !headerIncluded + e.handlePacket(pkt) } } if r.Loop&stack.PacketOut == 0 { @@ -530,7 +698,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe linkMTU := e.nic.MTU() for pb := pkts.Front(); pb != nil; pb = pb.Next() { - e.addIPHeader(r, pb, params) + e.addIPHeader(r.LocalAddress, r.RemoteAddress, pb, params, nil /* extensionHeaders */) networkMTU, err := calculateNetworkMTU(linkMTU, uint32(pb.NetworkHeader().View().Size())) if err != nil { @@ -558,8 +726,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // iptables filtering. All packets that reach here are locally // generated. nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - ipt := e.protocol.stack.IPTables() - dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName) + dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, nicName) if len(dropped) == 0 && len(natPkts) == 0 { // Fast path: If no packets are to be dropped then we can just invoke the // faster WritePackets API directly. @@ -584,9 +751,10 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { pkt := pkt.CloneToInbound() if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { - route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) - route.PopulatePacketInfo(pkt) - ep.HandlePacket(pkt) + // Since we rewrote the packet but it is being routed back to us, we + // can safely assume the checksum is valid. + pkt.RXTransportChecksumValidated = true + ep.(*endpoint).handlePacket(pkt) } n++ continue @@ -640,16 +808,85 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu return tcpip.ErrMalformedHeader } - return e.writePacket(r, nil /* gso */, pkt, proto) + return e.writePacket(r, nil /* gso */, pkt, proto, true /* headerIncluded */) +} + +// forwardPacket attempts to forward a packet to its final destination. +func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) *tcpip.Error { + h := header.IPv6(pkt.NetworkHeader().View()) + hopLimit := h.HopLimit() + if hopLimit <= 1 { + // As per RFC 4443 section 3.3, + // + // If a router receives a packet with a Hop Limit of zero, or if a + // router decrements a packet's Hop Limit to zero, it MUST discard the + // packet and originate an ICMPv6 Time Exceeded message with Code 0 to + // the source of the packet. This indicates either a routing loop or + // too small an initial Hop Limit value. + return e.protocol.returnError(&icmpReasonHopLimitExceeded{}, pkt) + } + + dstAddr := h.DestinationAddress() + + // Check if the destination is owned by the stack. + networkEndpoint, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr) + if err == nil { + networkEndpoint.(*endpoint).handlePacket(pkt) + return nil + } + if err != tcpip.ErrBadAddress { + return err + } + + r, err := e.protocol.stack.FindRoute(0, "", dstAddr, ProtocolNumber, false /* multicastLoop */) + if err != nil { + return err + } + defer r.Release() + + // We need to do a deep copy of the IP packet because + // WriteHeaderIncludedPacket takes ownership of the packet buffer, but we do + // not own it. + newHdr := header.IPv6(stack.PayloadSince(pkt.NetworkHeader())) + + // As per RFC 8200 section 3, + // + // Hop Limit 8-bit unsigned integer. Decremented by 1 by + // each node that forwards the packet. + newHdr.SetHopLimit(hopLimit - 1) + + return r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: buffer.View(newHdr).ToVectorisedView(), + })) } // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { + stats := e.protocol.stack.Stats() + stats.IP.PacketsReceived.Increment() + if !e.isEnabled() { + stats.IP.DisabledPacketsReceived.Increment() return } + // Loopback traffic skips the prerouting chain. + if !e.nic.IsLoopback() { + if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, nil, e.MainAddress().Address, ""); !ok { + // iptables is telling us to drop the packet. + stats.IP.IPTablesPreroutingDropped.Increment() + return + } + } + + e.handlePacket(pkt) +} + +// handlePacket is like HandlePacket except it does not perform the prerouting +// iptables hook. +func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { pkt.NICID = e.nic.ID() stats := e.protocol.stack.Stats() @@ -669,6 +906,20 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { return } + // The destination address should be an address we own or a group we joined + // for us to receive the packet. Otherwise, attempt to forward the packet. + if addressEndpoint := e.AcquireAssignedAddress(dstAddr, e.nic.Promiscuous(), stack.CanBePrimaryEndpoint); addressEndpoint != nil { + addressEndpoint.DecRef() + } else if !e.IsInGroup(dstAddr) { + if !e.protocol.Forwarding() { + stats.IP.InvalidDestinationAddressesReceived.Increment() + return + } + + _ = e.forwardPacket(pkt) + return + } + // vv consists of: // - Any IPv6 header bytes after the first 40 (i.e. extensions). // - The transport header, if present. @@ -681,8 +932,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // iptables filtering. All packets that reach here are intended for // this machine and need not be forwarded. - ipt := e.protocol.stack.IPTables() - if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, nil, "", ""); !ok { // iptables is telling us to drop the packet. stats.IP.IPTablesInputDropped.Increment() return @@ -888,18 +1138,6 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { return } - // Set up a callback in case we need to send a Time Exceeded Message as - // per RFC 2460 Section 4.5. - var releaseCB func(bool) - if start == 0 { - pkt := pkt.Clone() - releaseCB = func(timedOut bool) { - if timedOut { - _ = e.protocol.returnError(&icmpReasonReassemblyTimeout{}, pkt) - } - } - } - // Note that pkt doesn't have its transport header set after reassembly, // and won't until DeliverNetworkPacket sets it. data, proto, ready, err := e.protocol.fragmentation.Process( @@ -914,17 +1152,17 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { start+uint16(fragmentPayloadLen)-1, extHdr.More(), uint8(rawPayload.Identifier), - rawPayload.Buf, - releaseCB, + pkt, ) if err != nil { stats.IP.MalformedPacketsReceived.Increment() stats.IP.MalformedFragmentsReceived.Increment() return } - pkt.Data = data if ready { + pkt.Data = data + // We create a new iterator with the reassembled packet because we could // have more extension headers in the reassembled payload, as per RFC // 8200 section 4.5. We also use the NextHeader value from the first @@ -1023,9 +1261,16 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // // Which when taken together indicate that an unknown protocol should // be treated as an unrecognized next header value. + // The location of the Next Header field is in a different place in + // the initial IPv6 header than it is in the extension headers so + // treat it specially. + prevHdrIDOffset := uint32(header.IPv6NextHeaderOffset) + if previousHeaderStart != 0 { + prevHdrIDOffset = previousHeaderStart + } _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6UnknownHeader, - pointer: it.ParseOffset(), + pointer: prevHdrIDOffset, }, pkt) default: panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res)) @@ -1033,12 +1278,11 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { } default: - _ = e.protocol.returnError(&icmpReasonParameterProblem{ - code: header.ICMPv6UnknownHeader, - pointer: it.ParseOffset(), - }, pkt) - stats.UnknownProtocolRcvdPackets.Increment() - return + // Since the iterator returns IPv6RawPayloadHeader for unknown Extension + // Header IDs this should never happen unless we missed a supported type + // here. + panic(fmt.Sprintf("unrecognized type from it.Next() = %T", extHdr)) + } } } @@ -1086,11 +1330,6 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre return addressEndpoint, nil } - snmc := header.SolicitedNodeAddr(addr.Address) - if _, err := e.mu.addressableEndpointState.JoinGroup(snmc); err != nil { - return nil, err - } - addressEndpoint.SetKind(stack.PermanentTentative) if e.Enabled() { @@ -1099,6 +1338,13 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre } } + snmc := header.SolicitedNodeAddr(addr.Address) + if err := e.joinGroupLocked(snmc); err != nil { + // joinGroupLocked only returns an error if the group address is not a valid + // IPv6 multicast address. + panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", snmc, err)) + } + return addressEndpoint, nil } @@ -1144,7 +1390,8 @@ func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEn } snmc := header.SolicitedNodeAddr(addr.Address) - if _, err := e.mu.addressableEndpointState.LeaveGroup(snmc); err != nil && err != tcpip.ErrBadLocalAddress { + // The endpoint may have already left the multicast group. + if err := e.leaveGroupLocked(snmc); err != nil && err != tcpip.ErrBadLocalAddress { return err } @@ -1167,7 +1414,7 @@ func (e *endpoint) hasPermanentAddressRLocked(addr tcpip.Address) bool { // // Precondition: e.mu must be read or write locked. func (e *endpoint) getAddressRLocked(localAddr tcpip.Address) stack.AddressEndpoint { - return e.mu.addressableEndpointState.ReadOnly().Lookup(localAddr) + return e.mu.addressableEndpointState.GetAddress(localAddr) } // MainAddress implements stack.AddressableEndpoint. @@ -1199,6 +1446,26 @@ func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allow return e.acquireOutgoingPrimaryAddressRLocked(remoteAddr, allowExpired) } +// getLinkLocalAddressRLocked returns a link-local address from the primary list +// of addresses, if one is available. +// +// See stack.PrimaryEndpointBehavior for more details about the primary list. +// +// Precondition: e.mu must be read locked. +func (e *endpoint) getLinkLocalAddressRLocked() tcpip.Address { + var linkLocalAddr tcpip.Address + e.mu.addressableEndpointState.ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) bool { + if addressEndpoint.IsAssigned(false /* allowExpired */) { + if addr := addressEndpoint.AddressWithPrefix().Address; header.IsV6LinkLocalAddress(addr) { + linkLocalAddr = addr + return false + } + } + return true + }) + return linkLocalAddr +} + // acquireOutgoingPrimaryAddressRLocked is like AcquireOutgoingPrimaryAddress // but with locking requirements. // @@ -1208,7 +1475,11 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address // RFC 6724 section 5. type addrCandidate struct { addressEndpoint stack.AddressEndpoint + addr tcpip.Address scope header.IPv6AddressScope + + label uint8 + matchingPrefix uint8 } if len(remoteAddr) == 0 { @@ -1218,10 +1489,10 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address // Create a candidate set of available addresses we can potentially use as a // source address. var cs []addrCandidate - e.mu.addressableEndpointState.ReadOnly().ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) { + e.mu.addressableEndpointState.ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) bool { // If r is not valid for outgoing connections, it is not a valid endpoint. if !addressEndpoint.IsAssigned(allowExpired) { - return + return true } addr := addressEndpoint.AddressWithPrefix().Address @@ -1235,8 +1506,13 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address cs = append(cs, addrCandidate{ addressEndpoint: addressEndpoint, + addr: addr, scope: scope, + label: getLabel(addr), + matchingPrefix: remoteAddr.MatchingPrefix(addr), }) + + return true }) remoteScope, err := header.ScopeForIPv6Address(remoteAddr) @@ -1245,18 +1521,20 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", remoteAddr, err)) } + remoteLabel := getLabel(remoteAddr) + // Sort the addresses as per RFC 6724 section 5 rules 1-3. // - // TODO(b/146021396): Implement rules 4-8 of RFC 6724 section 5. + // TODO(b/146021396): Implement rules 4, 5 of RFC 6724 section 5. sort.Slice(cs, func(i, j int) bool { sa := cs[i] sb := cs[j] // Prefer same address as per RFC 6724 section 5 rule 1. - if sa.addressEndpoint.AddressWithPrefix().Address == remoteAddr { + if sa.addr == remoteAddr { return true } - if sb.addressEndpoint.AddressWithPrefix().Address == remoteAddr { + if sb.addr == remoteAddr { return false } @@ -1273,11 +1551,29 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address return sbDep } + // Prefer matching label as per RFC 6724 section 5 rule 6. + if sa, sb := sa.label == remoteLabel, sb.label == remoteLabel; sa != sb { + if sa { + return true + } + if sb { + return false + } + } + // Prefer temporary addresses as per RFC 6724 section 5 rule 7. if saTemp, sbTemp := sa.addressEndpoint.ConfigType() == stack.AddressConfigSlaacTemp, sb.addressEndpoint.ConfigType() == stack.AddressConfigSlaacTemp; saTemp != sbTemp { return saTemp } + // Use longest matching prefix as per RFC 6724 section 5 rule 8. + if sa.matchingPrefix > sb.matchingPrefix { + return true + } + if sb.matchingPrefix > sa.matchingPrefix { + return false + } + // sa and sb are equal, return the endpoint that is closest to the front of // the primary endpoint list. return i < j @@ -1309,35 +1605,52 @@ func (e *endpoint) PermanentAddresses() []tcpip.AddressWithPrefix { } // JoinGroup implements stack.GroupAddressableEndpoint. -func (e *endpoint) JoinGroup(addr tcpip.Address) (bool, *tcpip.Error) { +func (e *endpoint) JoinGroup(addr tcpip.Address) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + return e.joinGroupLocked(addr) +} + +// joinGroupLocked is like JoinGroup but with locking requirements. +// +// Precondition: e.mu must be locked. +func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error { if !header.IsV6MulticastAddress(addr) { - return false, tcpip.ErrBadAddress + return tcpip.ErrBadAddress } - e.mu.Lock() - defer e.mu.Unlock() - return e.mu.addressableEndpointState.JoinGroup(addr) + e.mu.mld.joinGroup(addr) + return nil } // LeaveGroup implements stack.GroupAddressableEndpoint. -func (e *endpoint) LeaveGroup(addr tcpip.Address) (bool, *tcpip.Error) { +func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - return e.mu.addressableEndpointState.LeaveGroup(addr) + return e.leaveGroupLocked(addr) +} + +// leaveGroupLocked is like LeaveGroup but with locking requirements. +// +// Precondition: e.mu must be locked. +func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { + return e.mu.mld.leaveGroup(addr) } // IsInGroup implements stack.GroupAddressableEndpoint. func (e *endpoint) IsInGroup(addr tcpip.Address) bool { e.mu.RLock() defer e.mu.RUnlock() - return e.mu.addressableEndpointState.IsInGroup(addr) + return e.mu.mld.isInGroup(addr) } var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) var _ stack.NetworkProtocol = (*protocol)(nil) +var _ fragmentation.TimeoutHandler = (*protocol)(nil) type protocol struct { - stack *stack.Stack + stack *stack.Stack + options Options mu struct { sync.RWMutex @@ -1361,26 +1674,6 @@ type protocol struct { forwarding uint32 fragmentation *fragmentation.Fragmentation - - // ndpDisp is the NDP event dispatcher that is used to send the netstack - // integrator NDP related events. - ndpDisp NDPDispatcher - - // ndpConfigs is the default NDP configurations used by an IPv6 endpoint. - ndpConfigs NDPConfigurations - - // opaqueIIDOpts hold the options for generating opaque interface identifiers - // (IIDs) as outlined by RFC 7217. - opaqueIIDOpts OpaqueInterfaceIdentifierOptions - - // tempIIDSeed is used to seed the initial temporary interface identifier - // history value used to generate IIDs for temporary SLAAC addresses. - tempIIDSeed []byte - - // autoGenIPv6LinkLocal determines whether or not the stack attempts to - // auto-generate an IPv6 link-local address for newly enabled non-loopback - // NICs. See the AutoGenIPv6LinkLocal field of Options for more details. - autoGenIPv6LinkLocal bool } // Number returns the ipv6 protocol number. @@ -1413,16 +1706,11 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L dispatcher: dispatcher, protocol: p, } + e.mu.Lock() e.mu.addressableEndpointState.Init(e) - e.mu.ndp = ndpState{ - ep: e, - configs: p.ndpConfigs, - dad: make(map[tcpip.Address]dadState), - defaultRouters: make(map[tcpip.Address]defaultRouterState), - onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState), - slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState), - } - e.mu.ndp.initializeTempAddrState() + e.mu.ndp.init(e) + e.mu.mld.init(e) + e.mu.Unlock() p.mu.Lock() defer p.mu.Unlock() @@ -1545,17 +1833,17 @@ type Options struct { // NDPConfigs is the default NDP configurations used by interfaces. NDPConfigs NDPConfigurations - // AutoGenIPv6LinkLocal determines whether or not the stack attempts to - // auto-generate an IPv6 link-local address for newly enabled non-loopback + // AutoGenLinkLocal determines whether or not the stack attempts to + // auto-generate a link-local address for newly enabled non-loopback // NICs. // // Note, setting this to true does not mean that a link-local address is // assigned right away, or at all. If Duplicate Address Detection is enabled, // an address is only assigned if it successfully resolves. If it fails, no - // further attempts are made to auto-generate an IPv6 link-local adddress. + // further attempts are made to auto-generate a link-local adddress. // // The generated link-local address follows RFC 4291 Appendix A guidelines. - AutoGenIPv6LinkLocal bool + AutoGenLinkLocal bool // NDPDisp is the NDP event dispatcher that an integrator can provide to // receive NDP related events. @@ -1579,6 +1867,9 @@ type Options struct { // seed that is too small would reduce randomness and increase predictability, // defeating the purpose of temporary SLAAC addresses. TempIIDSeed []byte + + // MLD holds options for MLD. + MLD MLDOptions } // NewProtocolWithOptions returns an IPv6 network protocol. @@ -1590,17 +1881,13 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory { return func(s *stack.Stack) stack.NetworkProtocol { p := &protocol{ - stack: s, - fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock()), - ids: ids, - hashIV: hashIV, - - ndpDisp: opts.NDPDisp, - ndpConfigs: opts.NDPConfigs, - opaqueIIDOpts: opts.OpaqueIIDOpts, - tempIIDSeed: opts.TempIIDSeed, - autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal, + stack: s, + options: opts, + + ids: ids, + hashIV: hashIV, } + p.fragmentation = fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p) p.mu.eps = make(map[*endpoint]struct{}) p.SetDefaultTTL(DefaultTTL) return p @@ -1644,24 +1931,25 @@ func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeaders hea fragPkt.NetworkProtocolNumber = ProtocolNumber originalIPHeadersLength := len(originalIPHeaders) - fragmentIPHeadersLength := originalIPHeadersLength + header.IPv6FragmentHeaderSize + + s := header.IPv6ExtHdrSerializer{&header.IPv6SerializableFragmentExtHdr{ + FragmentOffset: uint16(offset / header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit), + M: more, + Identification: id, + }} + + fragmentIPHeadersLength := originalIPHeadersLength + s.Length() fragmentIPHeaders := header.IPv6(fragPkt.NetworkHeader().Push(fragmentIPHeadersLength)) - fragPkt.NetworkProtocolNumber = ProtocolNumber // Copy the IPv6 header and any extension headers already populated. if copied := copy(fragmentIPHeaders, originalIPHeaders); copied != originalIPHeadersLength { panic(fmt.Sprintf("wrong number of bytes copied into fragmentIPHeaders: got %d, want %d", copied, originalIPHeadersLength)) } - fragmentIPHeaders.SetNextHeader(header.IPv6FragmentHeader) - fragmentIPHeaders.SetPayloadLength(uint16(copied + fragmentIPHeadersLength - header.IPv6MinimumSize)) - fragmentHeader := header.IPv6Fragment(fragmentIPHeaders[originalIPHeadersLength:]) - fragmentHeader.Encode(&header.IPv6FragmentFields{ - M: more, - FragmentOffset: uint16(offset / header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit), - Identification: id, - NextHeader: uint8(transportProto), - }) + nextHeader, _ := s.Serialize(transportProto, fragmentIPHeaders[originalIPHeadersLength:]) + + fragmentIPHeaders.SetNextHeader(nextHeader) + fragmentIPHeaders.SetPayloadLength(uint16(copied + fragmentIPHeadersLength - header.IPv6MinimumSize)) return fragPkt, more } diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 1bfcdde25..360025b20 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -15,9 +15,12 @@ package ipv6 import ( + "bytes" "encoding/hex" "fmt" + "io/ioutil" "math" + "net" "testing" "github.com/google/go-cmp/cmp" @@ -50,6 +53,7 @@ const ( fragmentExtHdrID = uint8(header.IPv6FragmentExtHdrIdentifier) destinationExtHdrID = uint8(header.IPv6DestinationOptionsExtHdrIdentifier) noNextHdrID = uint8(header.IPv6NoNextHeaderIdentifier) + unknownHdrID = uint8(header.IPv6UnknownExtHdrIdentifier) extraHeaderReserve = 50 ) @@ -67,18 +71,18 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 255, - SrcAddr: src, - DstAddr: dst, + PayloadLength: uint16(payloadLength), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: 255, + SrcAddr: src, + DstAddr: dst, }) e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), })) - stats := s.Stats().ICMP.V6PacketsReceived + stats := s.Stats().ICMP.V6.PacketsReceived if got := stats.NeighborAdvert.Value(); got != want { t.Fatalf("got NeighborAdvert = %d, want = %d", got, want) @@ -125,11 +129,11 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(udp.ProtocolNumber), - HopLimit: 255, - SrcAddr: src, - DstAddr: dst, + PayloadLength: uint16(payloadLength), + TransportProtocol: udp.ProtocolNumber, + HopLimit: 255, + SrcAddr: src, + DstAddr: dst, }) e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -572,6 +576,33 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { expectICMP: false, }, { + name: "unknown next header (first)", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + nextHdr, 0, 63, 4, 1, 2, 3, 4, + }, unknownHdrID + }, + shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownHeader, + pointer: header.IPv6NextHeaderOffset, + }, + { + name: "unknown next header (not first)", + extHdr: func(nextHdr uint8) ([]byte, uint8) { + return []byte{ + unknownHdrID, 0, + 63, 4, 1, 2, 3, 4, + }, hopByHopExtHdrID + }, + shouldAccept: false, + expectICMP: true, + ICMPType: header.ICMPv6ParamProblem, + ICMPCode: header.ICMPv6UnknownHeader, + pointer: header.IPv6FixedHeaderSize, + }, + { name: "destination with unknown option skippable action", extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{ @@ -754,11 +785,6 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { pointer: header.IPv6FixedHeaderSize, }, { - name: "No next header", - extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, noNextHdrID }, - shouldAccept: false, - }, - { name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with skippable unknown)", extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{ @@ -820,13 +846,14 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { }, } + const mtu = header.IPv6MinimumMTU for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) - e := channel.New(1, header.IPv6MinimumMTU, linkAddr1) + e := channel.New(1, mtu, linkAddr1) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -872,7 +899,13 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { Length: uint16(udpLength), }) copy(u.Payload(), udpPayload) - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, addr2, uint16(udpLength)) + + dstAddr := tcpip.Address(addr2) + if test.multicast { + dstAddr = header.IPv6AllNodesMulticastAddress + } + + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, dstAddr, uint16(udpLength)) sum = header.Checksum(udpPayload, sum) u.SetChecksum(^u.CalculateChecksum(sum)) @@ -883,16 +916,14 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { // Serialize IPv6 fixed header. payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - dstAddr := tcpip.Address(addr2) - if test.multicast { - dstAddr = header.IPv6AllNodesMulticastAddress - } ip.Encode(&header.IPv6Fields{ PayloadLength: uint16(payloadLength), - NextHeader: ipv6NextHdr, - HopLimit: 255, - SrcAddr: addr1, - DstAddr: dstAddr, + // We're lying about transport protocol here to be able to generate + // raw extension headers from the test definitions. + TransportProtocol: tcpip.TransportProtocolNumber(ipv6NextHdr), + HopLimit: 255, + SrcAddr: addr1, + DstAddr: dstAddr, }) e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -951,17 +982,24 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { if got := stats.Value(); got != 1 { t.Errorf("got UDP Rx Packets = %d, want = 1", got) } - gotPayload, _, err := ep.Read(nil) + var buf bytes.Buffer + result, err := ep.Read(&buf, mtu, tcpip.ReadOptions{}) if err != nil { - t.Fatalf("Read(nil): %s", err) + t.Fatalf("Read: %s", err) + } + if diff := cmp.Diff(tcpip.ReadResult{ + Count: len(udpPayload), + Total: len(udpPayload), + }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" { + t.Errorf("Read: unexpected result (-want +got):\n%s", diff) } - if diff := cmp.Diff(buffer.View(udpPayload), gotPayload); diff != "" { + if diff := cmp.Diff(udpPayload, buf.Bytes()); diff != "" { t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) } // Should not have any more UDP packets. - if gotPayload, _, err := ep.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got Read(nil) = (%x, _, %v), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock) + if res, err := ep.Read(ioutil.Discard, mtu, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock { + t.Fatalf("got Read = (%v, %v), want = (_, %s)", res, err, tcpip.ErrWouldBlock) } }) } @@ -981,9 +1019,10 @@ func TestReceiveIPv6Fragments(t *testing.T) { udpPayload2Length = 128 // Used to test cases where the fragment blocks are not a multiple of // the fragment block size of 8 (RFC 8200 section 4.5). - udpPayload3Length = 127 - udpPayload4Length = header.IPv6MaximumPayloadSize - header.UDPMinimumSize - fragmentExtHdrLen = 8 + udpPayload3Length = 127 + udpPayload4Length = header.IPv6MaximumPayloadSize - header.UDPMinimumSize + udpMaximumSizeMinus15 = header.UDPMaximumSize - 15 + fragmentExtHdrLen = 8 // Note, not all routing extension headers will be 8 bytes but this test // uses 8 byte routing extension headers for most sub tests. routingExtHdrLen = 8 @@ -1327,14 +1366,14 @@ func TestReceiveIPv6Fragments(t *testing.T) { dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( - fragmentExtHdrLen+65520, + fragmentExtHdrLen+udpMaximumSizeMinus15, []buffer.View{ // Fragment extension header. // // Fragment offset = 0, More = true, ID = 1 buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), - ipv6Payload4Addr1ToAddr2[:65520], + ipv6Payload4Addr1ToAddr2[:udpMaximumSizeMinus15], }, ), }, @@ -1343,14 +1382,17 @@ func TestReceiveIPv6Fragments(t *testing.T) { dstAddr: addr2, nextHdr: fragmentExtHdrID, data: buffer.NewVectorisedView( - fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-65520, + fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-udpMaximumSizeMinus15, []buffer.View{ // Fragment extension header. // - // Fragment offset = 8190, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 255, 240, 0, 0, 0, 1}), + // Fragment offset = udpMaximumSizeMinus15/8, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, + udpMaximumSizeMinus15 >> 8, + udpMaximumSizeMinus15 & 0xff, + 0, 0, 0, 1}), - ipv6Payload4Addr1ToAddr2[65520:], + ipv6Payload4Addr1ToAddr2[udpMaximumSizeMinus15:], }, ), }, @@ -1358,6 +1400,47 @@ func TestReceiveIPv6Fragments(t *testing.T) { expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2}, }, { + name: "Two fragments with MF flag reassembled into a maximum UDP packet", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+udpMaximumSizeMinus15, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload4Addr1ToAddr2[:udpMaximumSizeMinus15], + }, + ), + }, + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-udpMaximumSizeMinus15, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = udpMaximumSizeMinus15/8, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, + udpMaximumSizeMinus15 >> 8, + (udpMaximumSizeMinus15 & 0xff) + 1, + 0, 0, 0, 1}), + + ipv6Payload4Addr1ToAddr2[udpMaximumSizeMinus15:], + }, + ), + }, + }, + expectedPayloads: nil, + }, + { name: "Two fragments with per-fragment routing header with zero segments left", fragments: []fragmentData{ { @@ -1876,10 +1959,12 @@ func TestReceiveIPv6Fragments(t *testing.T) { ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: uint16(f.data.Size()), - NextHeader: f.nextHdr, - HopLimit: 255, - SrcAddr: f.srcAddr, - DstAddr: f.dstAddr, + // We're lying about transport protocol here so that we can generate + // raw extension headers for the tests. + TransportProtocol: tcpip.TransportProtocolNumber(f.nextHdr), + HopLimit: 255, + SrcAddr: f.srcAddr, + DstAddr: f.dstAddr, }) vv := hdr.View().ToVectorisedView() @@ -1894,18 +1979,20 @@ func TestReceiveIPv6Fragments(t *testing.T) { t.Errorf("got UDP Rx Packets = %d, want = %d", got, want) } + const rcvSize = 65536 // Account for reassembled packets. for i, p := range test.expectedPayloads { - gotPayload, _, err := ep.Read(nil) + var buf bytes.Buffer + _, err := ep.Read(&buf, rcvSize, tcpip.ReadOptions{}) if err != nil { - t.Fatalf("(i=%d) Read(nil): %s", i, err) + t.Fatalf("(i=%d) Read: %s", i, err) } - if diff := cmp.Diff(buffer.View(p), gotPayload); diff != "" { + if diff := cmp.Diff(p, buf.Bytes()); diff != "" { t.Errorf("(i=%d) got UDP payload mismatch (-want +got):\n%s", i, diff) } } - if gotPayload, _, err := ep.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("(last) got Read(nil) = (%x, _, %v), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock) + if res, err := ep.Read(ioutil.Discard, rcvSize, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock { + t.Fatalf("(last) got Read = (%v, %v), want = (_, %s)", res, err, tcpip.ErrWouldBlock) } }) } @@ -1924,7 +2011,7 @@ func TestInvalidIPv6Fragments(t *testing.T) { type fragmentData struct { ipv6Fields header.IPv6Fields - ipv6FragmentFields header.IPv6FragmentFields + ipv6FragmentFields header.IPv6SerializableFragmentExtHdr payload []byte } @@ -1943,14 +2030,13 @@ func TestInvalidIPv6Fragments(t *testing.T) { fragments: []fragmentData{ { ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 9, - NextHeader: header.IPv6FragmentHeader, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, + PayloadLength: header.IPv6FragmentHeaderSize + 9, + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, }, - ipv6FragmentFields: header.IPv6FragmentFields{ - NextHeader: uint8(header.UDPProtocolNumber), + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ FragmentOffset: 0 >> 3, M: true, Identification: ident, @@ -1970,14 +2056,13 @@ func TestInvalidIPv6Fragments(t *testing.T) { fragments: []fragmentData{ { ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 16, - NextHeader: header.IPv6FragmentHeader, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, + PayloadLength: header.IPv6FragmentHeaderSize + 16, + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, }, - ipv6FragmentFields: header.IPv6FragmentFields{ - NextHeader: uint8(header.UDPProtocolNumber), + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ FragmentOffset: ((header.IPv6MaximumPayloadSize + 1) - 16) >> 3, M: false, Identification: ident, @@ -2018,10 +2103,9 @@ func TestInvalidIPv6Fragments(t *testing.T) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)) - ip.Encode(&f.ipv6Fields) - - fragHDR := header.IPv6Fragment(hdr.View()[header.IPv6MinimumSize:]) - fragHDR.Encode(&f.ipv6FragmentFields) + encodeArgs := f.ipv6Fields + encodeArgs.ExtensionHeaders = append(encodeArgs.ExtensionHeaders, &f.ipv6FragmentFields) + ip.Encode(&encodeArgs) vv := hdr.View().ToVectorisedView() vv.AppendView(f.payload) @@ -2083,7 +2167,7 @@ func TestFragmentReassemblyTimeout(t *testing.T) { type fragmentData struct { ipv6Fields header.IPv6Fields - ipv6FragmentFields header.IPv6FragmentFields + ipv6FragmentFields header.IPv6SerializableFragmentExtHdr payload []byte } @@ -2097,14 +2181,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) { fragments: []fragmentData{ { ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 16, - NextHeader: header.IPv6FragmentHeader, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, + PayloadLength: header.IPv6FragmentHeaderSize + 16, + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, }, - ipv6FragmentFields: header.IPv6FragmentFields{ - NextHeader: uint8(header.UDPProtocolNumber), + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ FragmentOffset: 0, M: true, Identification: ident, @@ -2119,14 +2202,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) { fragments: []fragmentData{ { ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 16, - NextHeader: header.IPv6FragmentHeader, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, + PayloadLength: header.IPv6FragmentHeaderSize + 16, + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, }, - ipv6FragmentFields: header.IPv6FragmentFields{ - NextHeader: uint8(header.UDPProtocolNumber), + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ FragmentOffset: 0, M: true, Identification: ident, @@ -2135,14 +2217,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) { }, { ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 16, - NextHeader: header.IPv6FragmentHeader, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, + PayloadLength: header.IPv6FragmentHeaderSize + 16, + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, }, - ipv6FragmentFields: header.IPv6FragmentFields{ - NextHeader: uint8(header.UDPProtocolNumber), + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ FragmentOffset: 0, M: true, Identification: ident, @@ -2157,14 +2238,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) { fragments: []fragmentData{ { ipv6Fields: header.IPv6Fields{ - PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), - NextHeader: header.IPv6FragmentHeader, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, + PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, }, - ipv6FragmentFields: header.IPv6FragmentFields{ - NextHeader: uint8(header.UDPProtocolNumber), + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ FragmentOffset: 8, M: false, Identification: ident, @@ -2179,14 +2259,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) { fragments: []fragmentData{ { ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 16, - NextHeader: header.IPv6FragmentHeader, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, + PayloadLength: header.IPv6FragmentHeaderSize + 16, + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, }, - ipv6FragmentFields: header.IPv6FragmentFields{ - NextHeader: uint8(header.UDPProtocolNumber), + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ FragmentOffset: 0, M: true, Identification: ident, @@ -2195,14 +2274,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) { }, { ipv6Fields: header.IPv6Fields{ - PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), - NextHeader: header.IPv6FragmentHeader, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, + PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, }, - ipv6FragmentFields: header.IPv6FragmentFields{ - NextHeader: uint8(header.UDPProtocolNumber), + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ FragmentOffset: 8, M: false, Identification: ident, @@ -2217,14 +2295,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) { fragments: []fragmentData{ { ipv6Fields: header.IPv6Fields{ - PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), - NextHeader: header.IPv6FragmentHeader, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, + PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, }, - ipv6FragmentFields: header.IPv6FragmentFields{ - NextHeader: uint8(header.UDPProtocolNumber), + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ FragmentOffset: 8, M: false, Identification: ident, @@ -2233,14 +2310,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) { }, { ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 16, - NextHeader: header.IPv6FragmentHeader, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, + PayloadLength: header.IPv6FragmentHeaderSize + 16, + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, }, - ipv6FragmentFields: header.IPv6FragmentFields{ - NextHeader: uint8(header.UDPProtocolNumber), + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ FragmentOffset: 0, M: true, Identification: ident, @@ -2279,10 +2355,11 @@ func TestFragmentReassemblyTimeout(t *testing.T) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)) - ip.Encode(&f.ipv6Fields) + encodeArgs := f.ipv6Fields + encodeArgs.ExtensionHeaders = append(encodeArgs.ExtensionHeaders, &f.ipv6FragmentFields) + ip.Encode(&encodeArgs) fragHDR := header.IPv6Fragment(hdr.View()[header.IPv6MinimumSize:]) - fragHDR.Encode(&f.ipv6FragmentFields) vv := hdr.View().ToVectorisedView() vv.AppendView(f.payload) @@ -2438,7 +2515,7 @@ func TestWriteStats(t *testing.T) { test.setup(t, rt.Stack()) - nWritten, _ := writer.writePackets(&rt, pkts) + nWritten, _ := writer.writePackets(rt, pkts) if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent { t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent) @@ -2455,7 +2532,7 @@ func TestWriteStats(t *testing.T) { } } -func buildRoute(t *testing.T, ep stack.LinkEndpoint) stack.Route { +func buildRoute(t *testing.T, ep stack.LinkEndpoint) *stack.Route { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, }) @@ -2821,3 +2898,160 @@ func TestFragmentationErrors(t *testing.T) { }) } } + +func TestForwarding(t *testing.T) { + const ( + nicID1 = 1 + nicID2 = 2 + randomSequence = 123 + randomIdent = 42 + ) + + ipv6Addr1 := tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("10::1").To16()), + PrefixLen: 64, + } + ipv6Addr2 := tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("11::1").To16()), + PrefixLen: 64, + } + remoteIPv6Addr1 := tcpip.Address(net.ParseIP("10::2").To16()) + remoteIPv6Addr2 := tcpip.Address(net.ParseIP("11::2").To16()) + + tests := []struct { + name string + TTL uint8 + expectErrorICMP bool + }{ + { + name: "TTL of zero", + TTL: 0, + expectErrorICMP: true, + }, + { + name: "TTL of one", + TTL: 1, + expectErrorICMP: true, + }, + { + name: "TTL of two", + TTL: 2, + expectErrorICMP: false, + }, + { + name: "TTL of three", + TTL: 3, + expectErrorICMP: false, + }, + { + name: "Max TTL", + TTL: math.MaxUint8, + expectErrorICMP: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, + }) + // We expect at most a single packet in response to our ICMP Echo Request. + e1 := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID1, e1); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } + ipv6ProtoAddr1 := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: ipv6Addr1} + if err := s.AddProtocolAddress(nicID1, ipv6ProtoAddr1); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID1, ipv6ProtoAddr1, err) + } + + e2 := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID2, e2); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) + } + ipv6ProtoAddr2 := tcpip.ProtocolAddress{Protocol: ProtocolNumber, AddressWithPrefix: ipv6Addr2} + if err := s.AddProtocolAddress(nicID2, ipv6ProtoAddr2); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID2, ipv6ProtoAddr2, err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: ipv6Addr1.Subnet(), + NIC: nicID1, + }, + { + Destination: ipv6Addr2.Subnet(), + NIC: nicID2, + }, + }) + + if err := s.SetForwarding(ProtocolNumber, true); err != nil { + t.Fatalf("SetForwarding(%d, true): %s", ProtocolNumber, err) + } + + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6MinimumSize) + icmp := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) + icmp.SetIdent(randomIdent) + icmp.SetSequence(randomSequence) + icmp.SetType(header.ICMPv6EchoRequest) + icmp.SetCode(header.ICMPv6UnusedCode) + icmp.SetChecksum(0) + icmp.SetChecksum(header.ICMPv6Checksum(icmp, remoteIPv6Addr1, remoteIPv6Addr2, buffer.VectorisedView{})) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: header.ICMPv6MinimumSize, + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: test.TTL, + SrcAddr: remoteIPv6Addr1, + DstAddr: remoteIPv6Addr2, + }) + requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + }) + e1.InjectInbound(ProtocolNumber, requestPkt) + + if test.expectErrorICMP { + reply, ok := e1.Read() + if !ok { + t.Fatal("expected ICMP Hop Limit Exceeded packet through incoming NIC") + } + + checker.IPv6(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())), + checker.SrcAddr(ipv6Addr1.Address), + checker.DstAddr(remoteIPv6Addr1), + checker.TTL(DefaultTTL), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6TimeExceeded), + checker.ICMPv6Code(header.ICMPv6HopLimitExceeded), + checker.ICMPv6Payload([]byte(hdr.View())), + ), + ) + + if n := e2.Drain(); n != 0 { + t.Fatalf("got e2.Drain() = %d, want = 0", n) + } + } else { + reply, ok := e2.Read() + if !ok { + t.Fatal("expected ICMP Echo Request packet through outgoing NIC") + } + + checker.IPv6(t, header.IPv6(stack.PayloadSince(reply.Pkt.NetworkHeader())), + checker.SrcAddr(remoteIPv6Addr1), + checker.DstAddr(remoteIPv6Addr2), + checker.TTL(test.TTL-1), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6EchoRequest), + checker.ICMPv6Code(header.ICMPv6UnusedCode), + checker.ICMPv6Payload(nil), + ), + ) + + if n := e1.Drain(); n != 0 { + t.Fatalf("got e1.Drain() = %d, want = 0", n) + } + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go new file mode 100644 index 000000000..e8d1e7a79 --- /dev/null +++ b/pkg/tcpip/network/ipv6/mld.go @@ -0,0 +1,262 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv6 + +import ( + "fmt" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/network/ip" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +const ( + // UnsolicitedReportIntervalMax is the maximum delay between sending + // unsolicited MLD reports. + // + // Obtained from RFC 2710 Section 7.10. + UnsolicitedReportIntervalMax = 10 * time.Second +) + +// MLDOptions holds options for MLD. +type MLDOptions struct { + // Enabled indicates whether MLD will be performed. + // + // When enabled, MLD may transmit MLD report and done messages when + // joining and leaving multicast groups respectively, and handle incoming + // MLD packets. + // + // This field is ignored and is always assumed to be false for interfaces + // without neighbouring nodes (e.g. loopback). + Enabled bool +} + +var _ ip.MulticastGroupProtocol = (*mldState)(nil) + +// mldState is the per-interface MLD state. +// +// mldState.init MUST be called to initialize the MLD state. +type mldState struct { + // The IPv6 endpoint this mldState is for. + ep *endpoint + + genericMulticastProtocol ip.GenericMulticastProtocolState +} + +// Enabled implements ip.MulticastGroupProtocol. +func (mld *mldState) Enabled() bool { + // No need to perform MLD on loopback interfaces since they don't have + // neighbouring nodes. + return mld.ep.protocol.options.MLD.Enabled && !mld.ep.nic.IsLoopback() && mld.ep.Enabled() +} + +// SendReport implements ip.MulticastGroupProtocol. +// +// Precondition: mld.ep.mu must be read locked. +func (mld *mldState) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) { + return mld.writePacket(groupAddress, groupAddress, header.ICMPv6MulticastListenerReport) +} + +// SendLeave implements ip.MulticastGroupProtocol. +// +// Precondition: mld.ep.mu must be read locked. +func (mld *mldState) SendLeave(groupAddress tcpip.Address) *tcpip.Error { + _, err := mld.writePacket(header.IPv6AllRoutersMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone) + return err +} + +// init sets up an mldState struct, and is required to be called before using +// a new mldState. +// +// Must only be called once for the lifetime of mld. +func (mld *mldState) init(ep *endpoint) { + mld.ep = ep + mld.genericMulticastProtocol.Init(&ep.mu.RWMutex, ip.GenericMulticastProtocolOptions{ + Rand: ep.protocol.stack.Rand(), + Clock: ep.protocol.stack.Clock(), + Protocol: mld, + MaxUnsolicitedReportDelay: UnsolicitedReportIntervalMax, + AllNodesAddress: header.IPv6AllNodesMulticastAddress, + }) +} + +// handleMulticastListenerQuery handles a query message. +// +// Precondition: mld.ep.mu must be locked. +func (mld *mldState) handleMulticastListenerQuery(mldHdr header.MLD) { + mld.genericMulticastProtocol.HandleQueryLocked(mldHdr.MulticastAddress(), mldHdr.MaximumResponseDelay()) +} + +// handleMulticastListenerReport handles a report message. +// +// Precondition: mld.ep.mu must be locked. +func (mld *mldState) handleMulticastListenerReport(mldHdr header.MLD) { + mld.genericMulticastProtocol.HandleReportLocked(mldHdr.MulticastAddress()) +} + +// joinGroup handles joining a new group and sending and scheduling the required +// messages. +// +// If the group is already joined, returns tcpip.ErrDuplicateAddress. +// +// Precondition: mld.ep.mu must be locked. +func (mld *mldState) joinGroup(groupAddress tcpip.Address) { + mld.genericMulticastProtocol.JoinGroupLocked(groupAddress) +} + +// isInGroup returns true if the specified group has been joined locally. +// +// Precondition: mld.ep.mu must be read locked. +func (mld *mldState) isInGroup(groupAddress tcpip.Address) bool { + return mld.genericMulticastProtocol.IsLocallyJoinedRLocked(groupAddress) +} + +// leaveGroup handles removing the group from the membership map, cancels any +// delay timers associated with that group, and sends the Done message, if +// required. +// +// Precondition: mld.ep.mu must be locked. +func (mld *mldState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error { + // LeaveGroup returns false only if the group was not joined. + if mld.genericMulticastProtocol.LeaveGroupLocked(groupAddress) { + return nil + } + + return tcpip.ErrBadLocalAddress +} + +// softLeaveAll leaves all groups from the perspective of MLD, but remains +// joined locally. +// +// Precondition: mld.ep.mu must be locked. +func (mld *mldState) softLeaveAll() { + mld.genericMulticastProtocol.MakeAllNonMemberLocked() +} + +// initializeAll attemps to initialize the MLD state for each group that has +// been joined locally. +// +// Precondition: mld.ep.mu must be locked. +func (mld *mldState) initializeAll() { + mld.genericMulticastProtocol.InitializeGroupsLocked() +} + +// sendQueuedReports attempts to send any reports that are queued for sending. +// +// Precondition: mld.ep.mu must be locked. +func (mld *mldState) sendQueuedReports() { + mld.genericMulticastProtocol.SendQueuedReportsLocked() +} + +// writePacket assembles and sends an MLD packet. +// +// Precondition: mld.ep.mu must be read locked. +func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldType header.ICMPv6Type) (bool, *tcpip.Error) { + sentStats := mld.ep.protocol.stack.Stats().ICMP.V6.PacketsSent + var mldStat *tcpip.StatCounter + switch mldType { + case header.ICMPv6MulticastListenerReport: + mldStat = sentStats.MulticastListenerReport + case header.ICMPv6MulticastListenerDone: + mldStat = sentStats.MulticastListenerDone + default: + panic(fmt.Sprintf("unrecognized mld type = %d", mldType)) + } + + icmp := header.ICMPv6(buffer.NewView(header.ICMPv6HeaderSize + header.MLDMinimumSize)) + icmp.SetType(mldType) + header.MLD(icmp.MessageBody()).SetMulticastAddress(groupAddress) + // As per RFC 2710 section 3, + // + // All MLD messages described in this document are sent with a link-local + // IPv6 Source Address, an IPv6 Hop Limit of 1, and an IPv6 Router Alert + // option in a Hop-by-Hop Options header. + // + // However, this would cause problems with Duplicate Address Detection with + // the first address as MLD snooping switches may not send multicast traffic + // that DAD depends on to the node performing DAD without the MLD report, as + // documented in RFC 4816: + // + // Note that when a node joins a multicast address, it typically sends a + // Multicast Listener Discovery (MLD) report message [RFC2710] [RFC3810] + // for the multicast address. In the case of Duplicate Address + // Detection, the MLD report message is required in order to inform MLD- + // snooping switches, rather than routers, to forward multicast packets. + // In the above description, the delay for joining the multicast address + // thus means delaying transmission of the corresponding MLD report + // message. Since the MLD specifications do not request a random delay + // to avoid race conditions, just delaying Neighbor Solicitation would + // cause congestion by the MLD report messages. The congestion would + // then prevent the MLD-snooping switches from working correctly and, as + // a result, prevent Duplicate Address Detection from working. The + // requirement to include the delay for the MLD report in this case + // avoids this scenario. [RFC3590] also talks about some interaction + // issues between Duplicate Address Detection and MLD, and specifies + // which source address should be used for the MLD report in this case. + // + // As per RFC 3590 section 4, we should still send out MLD reports with an + // unspecified source address if we do not have an assigned link-local + // address to use as the source address to ensure DAD works as expected on + // networks with MLD snooping switches: + // + // MLD Report and Done messages are sent with a link-local address as + // the IPv6 source address, if a valid address is available on the + // interface. If a valid link-local address is not available (e.g., one + // has not been configured), the message is sent with the unspecified + // address (::) as the IPv6 source address. + // + // Once a valid link-local address is available, a node SHOULD generate + // new MLD Report messages for all multicast addresses joined on the + // interface. + // + // Routers receiving an MLD Report or Done message with the unspecified + // address as the IPv6 source address MUST silently discard the packet + // without taking any action on the packets contents. + // + // Snooping switches MUST manage multicast forwarding state based on MLD + // Report and Done messages sent with the unspecified address as the + // IPv6 source address. + localAddress := mld.ep.getLinkLocalAddressRLocked() + if len(localAddress) == 0 { + localAddress = header.IPv6Any + } + + icmp.SetChecksum(header.ICMPv6Checksum(icmp, localAddress, destAddress, buffer.VectorisedView{})) + + extensionHeaders := header.IPv6ExtHdrSerializer{ + header.IPv6SerializableHopByHopExtHdr{ + &header.IPv6RouterAlertOption{Value: header.IPv6RouterAlertMLD}, + }, + } + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(mld.ep.MaxHeaderLength()) + extensionHeaders.Length(), + Data: buffer.View(icmp).ToVectorisedView(), + }) + + mld.ep.addIPHeader(localAddress, destAddress, pkt, stack.NetworkHeaderParams{ + Protocol: header.ICMPv6ProtocolNumber, + TTL: header.MLDHopLimit, + }, extensionHeaders) + if err := mld.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { + sentStats.Dropped.Increment() + return false, err + } + mldStat.Increment() + return localAddress != header.IPv6Any, nil +} diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go new file mode 100644 index 000000000..f6ffa7133 --- /dev/null +++ b/pkg/tcpip/network/ipv6/mld_test.go @@ -0,0 +1,297 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv6_test + +import ( + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +const ( + linkLocalAddr = "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + globalAddr = "\x0a\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + globalMulticastAddr = "\xff\x05\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" +) + +var ( + linkLocalAddrSNMC = header.SolicitedNodeAddr(linkLocalAddr) + globalAddrSNMC = header.SolicitedNodeAddr(globalAddr) +) + +func validateMLDPacket(t *testing.T, p buffer.View, localAddress, remoteAddress tcpip.Address, mldType header.ICMPv6Type, groupAddress tcpip.Address) { + t.Helper() + + checker.IPv6WithExtHdr(t, p, + checker.IPv6ExtHdr( + checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD)), + ), + checker.SrcAddr(localAddress), + checker.DstAddr(remoteAddress), + // Hop Limit for an MLD message must be 1 as per RFC 2710 section 3. + checker.TTL(1), + checker.MLD(mldType, header.MLDMinimumSize, + checker.MLDMaxRespDelay(0), + checker.MLDMulticastAddress(groupAddress), + ), + ) +} + +func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + MLD: ipv6.MLDOptions{ + Enabled: true, + }, + })}, + }) + e := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + + // The stack will join an address's solicited node multicast address when + // an address is added. An MLD report message should be sent for the + // solicited-node group. + if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC) + } + + // The stack will leave an address's solicited node multicast address when + // an address is removed. An MLD done message should be sent for the + // solicited-node group. + if err := s.RemoveAddress(nicID, linkLocalAddr); err != nil { + t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, linkLocalAddr, err) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a done message to be sent") + } else { + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, header.IPv6AllRoutersMulticastAddress, header.ICMPv6MulticastListenerDone, linkLocalAddrSNMC) + } +} + +func TestSendQueuedMLDReports(t *testing.T) { + const ( + nicID = 1 + maxReports = 2 + ) + + tests := []struct { + name string + dadTransmits uint8 + retransmitTimer time.Duration + }{ + { + name: "DAD Disabled", + dadTransmits: 0, + retransmitTimer: 0, + }, + { + name: "DAD Enabled", + dadTransmits: 1, + retransmitTimer: time.Second, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + dadResolutionTime := test.retransmitTimer * time.Duration(test.dadTransmits) + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + DupAddrDetectTransmits: test.dadTransmits, + RetransmitTimer: test.retransmitTimer, + }, + MLD: ipv6.MLDOptions{ + Enabled: true, + }, + })}, + Clock: clock, + }) + + // Allow space for an extra packet so we can observe packets that were + // unexpectedly sent. + e := channel.New(maxReports+int(test.dadTransmits)+1 /* extra */, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + + resolveDAD := func(addr, snmc tcpip.Address) { + clock.Advance(dadResolutionTime) + if p, ok := e.Read(); !ok { + t.Fatal("expected DAD packet") + } else { + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(header.IPv6Any), + checker.DstAddr(snmc), + checker.TTL(header.NDPHopLimit), + checker.NDPNS( + checker.NDPNSTargetAddress(addr), + checker.NDPNSOptions(nil), + )) + } + } + + var reportCounter uint64 + reportStat := s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport + if got := reportStat.Value(); got != reportCounter { + t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) + } + var doneCounter uint64 + doneStat := s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone + if got := doneStat.Value(); got != doneCounter { + t.Errorf("got doneStat.Value() = %d, want = %d", got, doneCounter) + } + + // Joining a group without an assigned address should send an MLD report + // with the unspecified address. + if err := s.JoinGroup(ipv6.ProtocolNumber, nicID, globalMulticastAddr); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, globalMulticastAddr, err) + } + reportCounter++ + if got := reportStat.Value(); got != reportCounter { + t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Errorf("expected MLD report for %s", globalMulticastAddr) + } else { + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, globalMulticastAddr, header.ICMPv6MulticastListenerReport, globalMulticastAddr) + } + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Errorf("got unexpected packet = %#v", p) + } + if t.Failed() { + t.FailNow() + } + + // Adding a global address should not send reports for the already joined + // group since we should only send queued reports when a link-local + // addres sis assigned. + // + // Note, we will still expect to send a report for the global address's + // solicited node address from the unspecified address as per RFC 3590 + // section 4. + if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint); err != nil { + t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint, err) + } + reportCounter++ + if got := reportStat.Value(); got != reportCounter { + t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Errorf("expected MLD report for %s", globalAddrSNMC) + } else { + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, globalAddrSNMC, header.ICMPv6MulticastListenerReport, globalAddrSNMC) + } + if dadResolutionTime != 0 { + // Reports should not be sent when the address resolves. + resolveDAD(globalAddr, globalAddrSNMC) + if got := reportStat.Value(); got != reportCounter { + t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) + } + } + // Leave the group since we don't care about the global address's + // solicited node multicast group membership. + if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, globalAddrSNMC); err != nil { + t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, globalAddrSNMC, err) + } + if got := doneStat.Value(); got != doneCounter { + t.Errorf("got doneStat.Value() = %d, want = %d", got, doneCounter) + } + if p, ok := e.Read(); ok { + t.Errorf("got unexpected packet = %#v", p) + } + if t.Failed() { + t.FailNow() + } + + // Adding a link-local address should send a report for its solicited node + // address and globalMulticastAddr. + if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint); err != nil { + t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint, err) + } + if dadResolutionTime != 0 { + reportCounter++ + if got := reportStat.Value(); got != reportCounter { + t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Errorf("expected MLD report for %s", linkLocalAddrSNMC) + } else { + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC) + } + resolveDAD(linkLocalAddr, linkLocalAddrSNMC) + } + + // We expect two batches of reports to be sent (1 batch when the + // link-local address is assigned, and another after the maximum + // unsolicited report interval. + for i := 0; i < 2; i++ { + // We expect reports to be sent (one for globalMulticastAddr and another + // for linkLocalAddrSNMC). + reportCounter += maxReports + if got := reportStat.Value(); got != reportCounter { + t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) + } + + addrs := map[tcpip.Address]bool{ + globalMulticastAddr: false, + linkLocalAddrSNMC: false, + } + for range addrs { + p, ok := e.Read() + if !ok { + t.Fatalf("expected MLD report for %s and %s; addrs = %#v", globalMulticastAddr, linkLocalAddrSNMC, addrs) + } + + addr := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())).DestinationAddress() + if seen, ok := addrs[addr]; !ok { + t.Fatalf("got unexpected packet destined to %s", addr) + } else if seen { + t.Fatalf("got another packet destined to %s", addr) + } + + addrs[addr] = true + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, addr, header.ICMPv6MulticastListenerReport, addr) + + clock.Advance(ipv6.UnsolicitedReportIntervalMax) + } + } + + // Should not send any more reports. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Errorf("got unexpected packet = %#v", p) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index 40da011f8..d515eb622 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -20,6 +20,7 @@ import ( "math/rand" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -459,6 +460,9 @@ func (c *NDPConfigurations) validate() { // ndpState is the per-interface NDP state. type ndpState struct { + // Do not allow overwriting this state. + _ sync.NoCopy + // The IPv6 endpoint this ndpState is for. ep *endpoint @@ -471,17 +475,8 @@ type ndpState struct { // The default routers discovered through Router Advertisements. defaultRouters map[tcpip.Address]defaultRouterState - rtrSolicit struct { - // The timer used to send the next router solicitation message. - timer tcpip.Timer - - // Used to let the Router Solicitation timer know that it has been stopped. - // - // Must only be read from or written to while protected by the lock of - // the IPv6 endpoint this ndpState is associated with. MUST be set when the - // timer is set. - done *bool - } + // The job used to send the next router solicitation message. + rtrSolicitJob *tcpip.Job // The on-link prefixes discovered through Router Advertisements' Prefix // Information option. @@ -507,7 +502,7 @@ type ndpState struct { // to the DAD goroutine that DAD should stop. type dadState struct { // The DAD timer to send the next NS message, or resolve the address. - timer tcpip.Timer + job *tcpip.Job // Used to let the DAD timer know that it has been stopped. // @@ -648,96 +643,73 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE // Consider DAD to have resolved even if no DAD messages were actually // transmitted. - if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, true, nil) } + ndp.ep.onAddressAssignedLocked(addr) return nil } - var done bool - var timer tcpip.Timer - // We initially start a timer to fire immediately because some of the DAD work - // cannot be done while holding the IPv6 endpoint's lock. This is effectively - // the same as starting a goroutine but we use a timer that fires immediately - // so we can reset it for the next DAD iteration. - timer = ndp.ep.protocol.stack.Clock().AfterFunc(0, func() { - ndp.ep.mu.Lock() - defer ndp.ep.mu.Unlock() - - if done { - // If we reach this point, it means that the DAD timer fired after - // another goroutine already obtained the IPv6 endpoint lock and stopped - // DAD before this function obtained the NIC lock. Simply return here and - // do nothing further. - return - } + state := dadState{ + job: ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { + state, ok := ndp.dad[addr] + if !ok { + panic(fmt.Sprintf("ndpdad: DAD timer fired but missing state for %s on NIC(%d)", addr, ndp.ep.nic.ID())) + } - if addressEndpoint.GetKind() != stack.PermanentTentative { - // The endpoint should still be marked as tentative since we are still - // performing DAD on it. - panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.ep.nic.ID())) - } + if addressEndpoint.GetKind() != stack.PermanentTentative { + // The endpoint should still be marked as tentative since we are still + // performing DAD on it. + panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.ep.nic.ID())) + } - dadDone := remaining == 0 - - var err *tcpip.Error - if !dadDone { - // Use the unspecified address as the source address when performing DAD. - addressEndpoint := ndp.ep.acquireAddressOrCreateTempLocked(header.IPv6Any, true /* createTemp */, stack.NeverPrimaryEndpoint) - - // Do not hold the lock when sending packets which may be a long running - // task or may block link address resolution. We know this is safe - // because immediately after obtaining the lock again, we check if DAD - // has been stopped before doing any work with the IPv6 endpoint. Note, - // DAD would be stopped if the IPv6 endpoint was disabled or closed, or if - // the address was removed. - ndp.ep.mu.Unlock() - err = ndp.sendDADPacket(addr, addressEndpoint) - ndp.ep.mu.Lock() - addressEndpoint.DecRef() - } + dadDone := remaining == 0 - if done { - // If we reach this point, it means that DAD was stopped after we released - // the IPv6 endpoint's read lock and before we obtained the write lock. - return - } + var err *tcpip.Error + if !dadDone { + err = ndp.sendDADPacket(addr, addressEndpoint) + } - if dadDone { - // DAD has resolved. - addressEndpoint.SetKind(stack.Permanent) - } else if err == nil { - // DAD is not done and we had no errors when sending the last NDP NS, - // schedule the next DAD timer. - remaining-- - timer.Reset(ndp.configs.RetransmitTimer) - return - } + if dadDone { + // DAD has resolved. + addressEndpoint.SetKind(stack.Permanent) + } else if err == nil { + // DAD is not done and we had no errors when sending the last NDP NS, + // schedule the next DAD timer. + remaining-- + state.job.Schedule(ndp.configs.RetransmitTimer) + return + } - // At this point we know that either DAD is done or we hit an error sending - // the last NDP NS. Either way, clean up addr's DAD state and let the - // integrator know DAD has completed. - delete(ndp.dad, addr) + // At this point we know that either DAD is done or we hit an error + // sending the last NDP NS. Either way, clean up addr's DAD state and let + // the integrator know DAD has completed. + delete(ndp.dad, addr) - if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { - ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, dadDone, err) - } + if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { + ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, dadDone, err) + } - // If DAD resolved for a stable SLAAC address, attempt generation of a - // temporary SLAAC address. - if dadDone && addressEndpoint.ConfigType() == stack.AddressConfigSlaac { - // Reset the generation attempts counter as we are starting the generation - // of a new address for the SLAAC prefix. - ndp.regenerateTempSLAACAddr(addressEndpoint.AddressWithPrefix().Subnet(), true /* resetGenAttempts */) - } - }) + if dadDone { + if addressEndpoint.ConfigType() == stack.AddressConfigSlaac { + // Reset the generation attempts counter as we are starting the + // generation of a new address for the SLAAC prefix. + ndp.regenerateTempSLAACAddr(addressEndpoint.AddressWithPrefix().Subnet(), true /* resetGenAttempts */) + } - ndp.dad[addr] = dadState{ - timer: timer, - done: &done, + ndp.ep.onAddressAssignedLocked(addr) + } + }), } + // We initially start a timer to fire immediately because some of the DAD work + // cannot be done while holding the IPv6 endpoint's lock. This is effectively + // the same as starting a goroutine but we use a timer that fires immediately + // so we can reset it for the next DAD iteration. + state.job.Schedule(0) + ndp.dad[addr] = state + return nil } @@ -745,55 +717,31 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE // addr. // // addr must be a tentative IPv6 address on ndp's IPv6 endpoint. -// -// The IPv6 endpoint that ndp belongs to MUST NOT be locked. func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.AddressEndpoint) *tcpip.Error { snmc := header.SolicitedNodeAddr(addr) - r, err := ndp.ep.protocol.stack.FindRoute(ndp.ep.nic.ID(), header.IPv6Any, snmc, ProtocolNumber, false /* multicastLoop */) - if err != nil { - return err - } - defer r.Release() - - // Route should resolve immediately since snmc is a multicast address so a - // remote link address can be calculated without a resolution process. - if c, err := r.Resolve(nil); err != nil { - // Do not consider the NIC being unknown or disabled as a fatal error. - // Since this method is required to be called when the IPv6 endpoint is not - // locked, the NIC could have been disabled or removed by another goroutine. - if err == tcpip.ErrUnknownNICID || err != tcpip.ErrInvalidEndpointState { - return err - } - - panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.ep.nic.ID(), err)) - } else if c != nil { - panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.ep.nic.ID())) - } - - icmpData := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize)) - icmpData.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(icmpData.NDPPayload()) + icmp := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize)) + icmp.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(icmp.MessageBody()) ns.SetTargetAddress(addr) - icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + icmp.SetChecksum(header.ICMPv6Checksum(icmp, header.IPv6Any, snmc, buffer.VectorisedView{})) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength()), - Data: buffer.View(icmpData).ToVectorisedView(), + ReserveHeaderBytes: int(ndp.ep.MaxHeaderLength()), + Data: buffer.View(icmp).ToVectorisedView(), }) - sent := r.Stats().ICMP.V6PacketsSent - if err := r.WritePacket(nil, - stack.NetworkHeaderParams{ - Protocol: header.ICMPv6ProtocolNumber, - TTL: header.NDPHopLimit, - }, pkt, - ); err != nil { + sent := ndp.ep.protocol.stack.Stats().ICMP.V6.PacketsSent + ndp.ep.addIPHeader(header.IPv6Any, snmc, pkt, stack.NetworkHeaderParams{ + Protocol: header.ICMPv6ProtocolNumber, + TTL: header.NDPHopLimit, + }, nil /* extensionHeaders */) + + if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(snmc), nil /* gso */, ProtocolNumber, pkt); err != nil { sent.Dropped.Increment() return err } sent.NeighborSolicit.Increment() - return nil } @@ -812,18 +760,11 @@ func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) { return } - if dad.timer != nil { - dad.timer.Stop() - dad.timer = nil - - *dad.done = true - dad.done = nil - } - + dad.job.Cancel() delete(ndp.dad, addr) // Let the integrator know DAD did not resolve. - if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, false, nil) } } @@ -846,7 +787,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { // Only worry about the DHCPv6 configuration if we have an NDPDispatcher as we // only inform the dispatcher on configuration changes. We do nothing else // with the information. - if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { var configuration DHCPv6ConfigurationFromNDPRA switch { case ra.ManagedAddrConfFlag(): @@ -903,20 +844,20 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { for opt, done, _ := it.Next(); !done; opt, done, _ = it.Next() { switch opt := opt.(type) { case header.NDPRecursiveDNSServer: - if ndp.ep.protocol.ndpDisp == nil { + if ndp.ep.protocol.options.NDPDisp == nil { continue } addrs, _ := opt.Addresses() - ndp.ep.protocol.ndpDisp.OnRecursiveDNSServerOption(ndp.ep.nic.ID(), addrs, opt.Lifetime()) + ndp.ep.protocol.options.NDPDisp.OnRecursiveDNSServerOption(ndp.ep.nic.ID(), addrs, opt.Lifetime()) case header.NDPDNSSearchList: - if ndp.ep.protocol.ndpDisp == nil { + if ndp.ep.protocol.options.NDPDisp == nil { continue } domainNames, _ := opt.DomainNames() - ndp.ep.protocol.ndpDisp.OnDNSSearchListOption(ndp.ep.nic.ID(), domainNames, opt.Lifetime()) + ndp.ep.protocol.options.NDPDisp.OnDNSSearchListOption(ndp.ep.nic.ID(), domainNames, opt.Lifetime()) case header.NDPPrefixInformation: prefix := opt.Subnet() @@ -964,7 +905,7 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { delete(ndp.defaultRouters, ip) // Let the integrator know a discovered default router is invalidated. - if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { ndpDisp.OnDefaultRouterInvalidated(ndp.ep.nic.ID(), ip) } } @@ -976,7 +917,7 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { // // The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { - ndpDisp := ndp.ep.protocol.ndpDisp + ndpDisp := ndp.ep.protocol.options.NDPDisp if ndpDisp == nil { return } @@ -1006,7 +947,7 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { // // The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) { - ndpDisp := ndp.ep.protocol.ndpDisp + ndpDisp := ndp.ep.protocol.options.NDPDisp if ndpDisp == nil { return } @@ -1047,7 +988,7 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) { delete(ndp.onLinkPrefixes, prefix) // Let the integrator know a discovered on-link prefix is invalidated. - if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { ndpDisp.OnOnLinkPrefixInvalidated(ndp.ep.nic.ID(), prefix) } } @@ -1225,7 +1166,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { // The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) addAndAcquireSLAACAddr(addr tcpip.AddressWithPrefix, configType stack.AddressConfigType, deprecated bool) stack.AddressEndpoint { // Inform the integrator that we have a new SLAAC address. - ndpDisp := ndp.ep.protocol.ndpDisp + ndpDisp := ndp.ep.protocol.options.NDPDisp if ndpDisp == nil { return nil } @@ -1272,7 +1213,7 @@ func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixSt } dadCounter := state.generationAttempts + state.stableAddr.localGenerationFailures - if oIID := ndp.ep.protocol.opaqueIIDOpts; oIID.NICNameFromID != nil { + if oIID := ndp.ep.protocol.options.OpaqueIIDOpts; oIID.NICNameFromID != nil { addrBytes = header.AppendOpaqueInterfaceIdentifier( addrBytes[:header.IIDOffsetInIPv6Address], prefix, @@ -1676,7 +1617,7 @@ func (ndp *ndpState) deprecateSLAACAddress(addressEndpoint stack.AddressEndpoint } addressEndpoint.SetDeprecated(true) - if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { ndpDisp.OnAutoGenAddressDeprecated(ndp.ep.nic.ID(), addressEndpoint.AddressWithPrefix()) } } @@ -1701,7 +1642,7 @@ func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, state slaacPrefi // // The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidatePrefix bool) { - if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addr) } @@ -1761,7 +1702,7 @@ func (ndp *ndpState) invalidateTempSLAACAddr(tempAddrs map[tcpip.Address]tempSLA // // The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidateAddr bool) { - if ndpDisp := ndp.ep.protocol.ndpDisp; ndpDisp != nil { + if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addr) } @@ -1859,7 +1800,7 @@ func (ndp *ndpState) cleanupState(hostOnly bool) { // // The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) startSolicitingRouters() { - if ndp.rtrSolicit.timer != nil { + if ndp.rtrSolicitJob != nil { // We are already soliciting routers. return } @@ -1876,56 +1817,14 @@ func (ndp *ndpState) startSolicitingRouters() { delay = time.Duration(rand.Int63n(int64(ndp.configs.MaxRtrSolicitationDelay))) } - var done bool - ndp.rtrSolicit.done = &done - ndp.rtrSolicit.timer = ndp.ep.protocol.stack.Clock().AfterFunc(delay, func() { - ndp.ep.mu.Lock() - if done { - // If we reach this point, it means that the RS timer fired after another - // goroutine already obtained the IPv6 endpoint lock and stopped - // solicitations. Simply return here and do nothing further. - ndp.ep.mu.Unlock() - return - } - + ndp.rtrSolicitJob = ndp.ep.protocol.stack.NewJob(&ndp.ep.mu, func() { // As per RFC 4861 section 4.1, the source of the RS is an address assigned // to the sending interface, or the unspecified address if no address is // assigned to the sending interface. - addressEndpoint := ndp.ep.acquireOutgoingPrimaryAddressRLocked(header.IPv6AllRoutersMulticastAddress, false) - if addressEndpoint == nil { - // Incase this ends up creating a new temporary address, we need to hold - // onto the endpoint until a route is obtained. If we decrement the - // reference count before obtaing a route, the address's resources would - // be released and attempting to obtain a route after would fail. Once a - // route is obtainted, it is safe to decrement the reference count since - // obtaining a route increments the address's reference count. - addressEndpoint = ndp.ep.acquireAddressOrCreateTempLocked(header.IPv6Any, true /* createTemp */, stack.NeverPrimaryEndpoint) - } - ndp.ep.mu.Unlock() - - localAddr := addressEndpoint.AddressWithPrefix().Address - r, err := ndp.ep.protocol.stack.FindRoute(ndp.ep.nic.ID(), localAddr, header.IPv6AllRoutersMulticastAddress, ProtocolNumber, false /* multicastLoop */) - addressEndpoint.DecRef() - if err != nil { - return - } - defer r.Release() - - // Route should resolve immediately since - // header.IPv6AllRoutersMulticastAddress is a multicast address so a - // remote link address can be calculated without a resolution process. - if c, err := r.Resolve(nil); err != nil { - // Do not consider the NIC being unknown or disabled as a fatal error. - // Since this method is required to be called when the IPv6 endpoint is - // not locked, the IPv6 endpoint could have been disabled or removed by - // another goroutine. - if err == tcpip.ErrUnknownNICID || err == tcpip.ErrInvalidEndpointState { - return - } - - panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.nic.ID(), err)) - } else if c != nil { - panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.ep.nic.ID())) + localAddr := header.IPv6Any + if addressEndpoint := ndp.ep.acquireOutgoingPrimaryAddressRLocked(header.IPv6AllRoutersMulticastAddress, false); addressEndpoint != nil { + localAddr = addressEndpoint.AddressWithPrefix().Address + addressEndpoint.DecRef() } // As per RFC 4861 section 4.1, an NDP RS SHOULD include the source @@ -1936,30 +1835,31 @@ func (ndp *ndpState) startSolicitingRouters() { // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by // LinkEndpoint.LinkAddress) before reaching this point. var optsSerializer header.NDPOptionsSerializer - if localAddr != header.IPv6Any && header.IsValidUnicastEthernetAddress(r.LocalLinkAddress) { + linkAddress := ndp.ep.nic.LinkAddress() + if localAddr != header.IPv6Any && header.IsValidUnicastEthernetAddress(linkAddress) { optsSerializer = header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(r.LocalLinkAddress), + header.NDPSourceLinkLayerAddressOption(linkAddress), } } payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + int(optsSerializer.Length()) icmpData := header.ICMPv6(buffer.NewView(payloadSize)) icmpData.SetType(header.ICMPv6RouterSolicit) - rs := header.NDPRouterSolicit(icmpData.NDPPayload()) + rs := header.NDPRouterSolicit(icmpData.MessageBody()) rs.Options().Serialize(optsSerializer) - icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) + icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, localAddr, header.IPv6AllRoutersMulticastAddress, buffer.VectorisedView{})) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(r.MaxHeaderLength()), + ReserveHeaderBytes: int(ndp.ep.MaxHeaderLength()), Data: buffer.View(icmpData).ToVectorisedView(), }) - sent := r.Stats().ICMP.V6PacketsSent - if err := r.WritePacket(nil, - stack.NetworkHeaderParams{ - Protocol: header.ICMPv6ProtocolNumber, - TTL: header.NDPHopLimit, - }, pkt, - ); err != nil { + sent := ndp.ep.protocol.stack.Stats().ICMP.V6.PacketsSent + ndp.ep.addIPHeader(localAddr, header.IPv6AllRoutersMulticastAddress, pkt, stack.NetworkHeaderParams{ + Protocol: header.ICMPv6ProtocolNumber, + TTL: header.NDPHopLimit, + }, nil /* extensionHeaders */) + + if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { sent.Dropped.Increment() log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.ep.nic.ID(), err) // Don't send any more messages if we had an error. @@ -1969,21 +1869,12 @@ func (ndp *ndpState) startSolicitingRouters() { remaining-- } - ndp.ep.mu.Lock() - if done || remaining == 0 { - ndp.rtrSolicit.timer = nil - ndp.rtrSolicit.done = nil - } else if ndp.rtrSolicit.timer != nil { - // Note, we need to explicitly check to make sure that - // the timer field is not nil because if it was nil but - // we still reached this point, then we know the IPv6 endpoint - // was requested to stop soliciting routers so we don't - // need to send the next Router Solicitation message. - ndp.rtrSolicit.timer.Reset(ndp.configs.RtrSolicitationInterval) + if remaining != 0 { + ndp.rtrSolicitJob.Schedule(ndp.configs.RtrSolicitationInterval) } - ndp.ep.mu.Unlock() }) + ndp.rtrSolicitJob.Schedule(delay) } // stopSolicitingRouters stops soliciting routers. If routers are not currently @@ -1991,22 +1882,28 @@ func (ndp *ndpState) startSolicitingRouters() { // // The IPv6 endpoint that ndp belongs to MUST be locked. func (ndp *ndpState) stopSolicitingRouters() { - if ndp.rtrSolicit.timer == nil { + if ndp.rtrSolicitJob == nil { // Nothing to do. return } - *ndp.rtrSolicit.done = true - ndp.rtrSolicit.timer.Stop() - ndp.rtrSolicit.timer = nil - ndp.rtrSolicit.done = nil + ndp.rtrSolicitJob.Cancel() + ndp.rtrSolicitJob = nil } -// initializeTempAddrState initializes state related to temporary SLAAC -// addresses. -func (ndp *ndpState) initializeTempAddrState() { - header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.tempIIDSeed, ndp.ep.nic.ID()) +func (ndp *ndpState) init(ep *endpoint) { + if ndp.dad != nil { + panic("attempted to initialize NDP state twice") + } + + ndp.ep = ep + ndp.configs = ep.protocol.options.NDPConfigs + ndp.dad = make(map[tcpip.Address]dadState) + ndp.defaultRouters = make(map[tcpip.Address]defaultRouterState) + ndp.onLinkPrefixes = make(map[tcpip.Subnet]onLinkPrefixState) + ndp.slaacPrefixes = make(map[tcpip.Subnet]slaacPrefixState) + header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.options.TempIIDSeed, ndp.ep.nic.ID()) if MaxDesyncFactor != 0 { ndp.temporaryAddressDesyncFactor = time.Duration(rand.Int63n(int64(MaxDesyncFactor))) } diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 981d1371a..b1a5a5510 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -45,10 +45,6 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeig if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { t.Fatalf("CreateNIC(_) = %s", err) } - if err := s.AddAddress(1, ProtocolNumber, llladdr); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, llladdr, err) - } - { subnet, err := tcpip.NewSubnet(rlladdr, tcpip.AddressMask(strings.Repeat("\xff", len(rlladdr)))) if err != nil { @@ -73,6 +69,17 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeig } t.Cleanup(ep.Close) + addressableEndpoint, ok := ep.(stack.AddressableEndpoint) + if !ok { + t.Fatalf("expected network endpoint to implement stack.AddressableEndpoint") + } + addr := llladdr.WithPrefix() + if addressEP, err := addressableEndpoint.AddAndAcquirePermanentAddress(addr, stack.CanBePrimaryEndpoint, stack.AddressConfigStatic, false /* deprecated */); err != nil { + t.Fatalf("addressableEndpoint.AddAndAcquirePermanentAddress(%s, CanBePrimaryEndpoint, AddressConfigStatic, false): %s", addr, err) + } else { + addressEP.DecRef() + } + return s, ep } @@ -198,7 +205,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + ns := header.NDPNeighborSolicit(pkt.MessageBody()) ns.SetTargetAddress(lladdr0) opts := ns.Options() copy(opts, test.optsBuf) @@ -206,14 +213,14 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 255, - SrcAddr: lladdr1, - DstAddr: lladdr0, + PayloadLength: uint16(payloadLength), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: 255, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) - invalid := s.Stats().ICMP.V6PacketsReceived.Invalid + invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid // Invalid count should initially be 0. if got := invalid.Value(); got != 0 { @@ -304,7 +311,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + ns := header.NDPNeighborSolicit(pkt.MessageBody()) ns.SetTargetAddress(lladdr0) opts := ns.Options() copy(opts, test.optsBuf) @@ -312,23 +319,23 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 255, - SrcAddr: lladdr1, - DstAddr: lladdr0, + PayloadLength: uint16(payloadLength), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: 255, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) - invalid := s.Stats().ICMP.V6PacketsReceived.Invalid + invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid // Invalid count should initially be 0. if got := invalid.Value(); got != 0 { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) neighbors, err := s.Neighbors(nicID) if err != nil { @@ -574,7 +581,7 @@ func TestNeighorSolicitationResponse(t *testing.T) { } s.SetRouteTable([]tcpip.Route{ - tcpip.Route{ + { Destination: header.IPv6EmptySubnet, NIC: 1, }, @@ -584,7 +591,7 @@ func TestNeighorSolicitationResponse(t *testing.T) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + ns := header.NDPNeighborSolicit(pkt.MessageBody()) ns.SetTargetAddress(nicAddr) opts := ns.Options() opts.Serialize(test.nsOpts) @@ -592,14 +599,14 @@ func TestNeighorSolicitationResponse(t *testing.T) { payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 255, - SrcAddr: test.nsSrc, - DstAddr: test.nsDst, + PayloadLength: uint16(payloadLength), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: 255, + SrcAddr: test.nsSrc, + DstAddr: test.nsDst, }) - invalid := s.Stats().ICMP.V6PacketsReceived.Invalid + invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid // Invalid count should initially be 0. if got := invalid.Value(); got != 0 { @@ -665,7 +672,7 @@ func TestNeighorSolicitationResponse(t *testing.T) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) pkt.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(pkt.NDPPayload()) + na := header.NDPNeighborAdvert(pkt.MessageBody()) na.SetSolicitedFlag(true) na.SetOverrideFlag(true) na.SetTargetAddress(test.nsSrc) @@ -674,11 +681,11 @@ func TestNeighorSolicitationResponse(t *testing.T) { payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: test.nsSrc, - DstAddr: nicAddr, + PayloadLength: uint16(payloadLength), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: header.NDPHopLimit, + SrcAddr: test.nsSrc, + DstAddr: nicAddr, }) e.InjectLinkAddr(ProtocolNumber, "", stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), @@ -770,7 +777,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) pkt.SetType(header.ICMPv6NeighborAdvert) - ns := header.NDPNeighborAdvert(pkt.NDPPayload()) + ns := header.NDPNeighborAdvert(pkt.MessageBody()) ns.SetTargetAddress(lladdr1) opts := ns.Options() copy(opts, test.optsBuf) @@ -778,14 +785,14 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 255, - SrcAddr: lladdr1, - DstAddr: lladdr0, + PayloadLength: uint16(payloadLength), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: 255, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) - invalid := s.Stats().ICMP.V6PacketsReceived.Invalid + invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid // Invalid count should initially be 0. if got := invalid.Value(); got != 0 { @@ -883,7 +890,7 @@ func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *test hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) pkt.SetType(header.ICMPv6NeighborAdvert) - ns := header.NDPNeighborAdvert(pkt.NDPPayload()) + ns := header.NDPNeighborAdvert(pkt.MessageBody()) ns.SetTargetAddress(lladdr1) opts := ns.Options() copy(opts, test.optsBuf) @@ -891,23 +898,23 @@ func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *test payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 255, - SrcAddr: lladdr1, - DstAddr: lladdr0, + PayloadLength: uint16(payloadLength), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: 255, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) - invalid := s.Stats().ICMP.V6PacketsReceived.Invalid + invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid // Invalid count should initially be 0. if got := invalid.Value(); got != 0 { t.Fatalf("got invalid = %d, want = 0", got) } - e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), - }) + })) neighbors, err := s.Neighbors(nicID) if err != nil { @@ -961,46 +968,36 @@ func TestNDPValidation(t *testing.T) { for _, stackTyp := range stacks { t.Run(stackTyp.name, func(t *testing.T) { - setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint, stack.Route) { + setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint) { t.Helper() // Create a stack with the assigned link-local address lladdr0 // and an endpoint to lladdr1. s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1, stackTyp.useNeighborCache) - r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err) - } - - return s, ep, r + return s, ep } - handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) { - nextHdr := uint8(header.ICMPv6ProtocolNumber) - var extensions buffer.View + handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint) { + var extHdrs header.IPv6ExtHdrSerializer if atomicFragment { - extensions = buffer.NewView(header.IPv6FragmentExtHdrLength) - extensions[0] = nextHdr - nextHdr = uint8(header.IPv6FragmentExtHdrIdentifier) + extHdrs = append(extHdrs, &header.IPv6SerializableFragmentExtHdr{}) } + extHdrsLen := extHdrs.Length() pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.IPv6MinimumSize + len(extensions), + ReserveHeaderBytes: header.IPv6MinimumSize + extHdrsLen, Data: payload.ToVectorisedView(), }) - ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + len(extensions))) + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(payload) + len(extensions)), - NextHeader: nextHdr, - HopLimit: hopLimit, - SrcAddr: r.LocalAddress, - DstAddr: r.RemoteAddress, + PayloadLength: uint16(len(payload) + extHdrsLen), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: hopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, + ExtensionHeaders: extHdrs, }) - if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) { - t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n) - } - r.PopulatePacketInfo(pkt) ep.HandlePacket(pkt) } @@ -1114,15 +1111,14 @@ func TestNDPValidation(t *testing.T) { t.Run(name, func(t *testing.T) { for _, test := range subTests { t.Run(test.name, func(t *testing.T) { - s, ep, r := setup(t) - defer r.Release() + s, ep := setup(t) if isRouter { // Enabling forwarding makes the stack act as a router. s.SetForwarding(ProtocolNumber, true) } - stats := s.Stats().ICMP.V6PacketsReceived + stats := s.Stats().ICMP.V6.PacketsReceived invalid := stats.Invalid routerOnly := stats.RouterOnlyPacketsDroppedByHost typStat := typ.statCounter(stats) @@ -1131,7 +1127,7 @@ func TestNDPValidation(t *testing.T) { copy(icmp[typ.size:], typ.extraData) icmp.SetType(typ.typ) icmp.SetCode(test.code) - icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView())) + icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], lladdr0, lladdr1, buffer.View(typ.extraData).ToVectorisedView())) // Rx count of the NDP message should initially be 0. if got := typStat.Value(); got != 0 { @@ -1152,7 +1148,7 @@ func TestNDPValidation(t *testing.T) { t.FailNow() } - handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep, &r) + handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep) // Rx count of the NDP packet should have increased. if got := typStat.Value(); got != 1 { @@ -1346,19 +1342,19 @@ func TestRouterAdvertValidation(t *testing.T) { pkt := header.ICMPv6(hdr.Prepend(icmpSize)) pkt.SetType(header.ICMPv6RouterAdvert) pkt.SetCode(test.code) - copy(pkt.NDPPayload(), test.ndpPayload) + copy(pkt.MessageBody(), test.ndpPayload) payloadLength := hdr.UsedLength() pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: test.hopLimit, - SrcAddr: test.src, - DstAddr: header.IPv6AllNodesMulticastAddress, + PayloadLength: uint16(payloadLength), + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: test.hopLimit, + SrcAddr: test.src, + DstAddr: header.IPv6AllNodesMulticastAddress, }) - stats := s.Stats().ICMP.V6PacketsReceived + stats := s.Stats().ICMP.V6.PacketsReceived invalid := stats.Invalid rxRA := stats.RouterAdvert diff --git a/pkg/tcpip/network/multicast_group_test.go b/pkg/tcpip/network/multicast_group_test.go new file mode 100644 index 000000000..0f4f0e1e1 --- /dev/null +++ b/pkg/tcpip/network/multicast_group_test.go @@ -0,0 +1,1261 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ip_test + +import ( + "fmt" + "strings" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +const ( + linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") + + ipv4Addr = tcpip.Address("\x0a\x00\x00\x01") + ipv6Addr = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + + ipv4MulticastAddr1 = tcpip.Address("\xe0\x00\x00\x03") + ipv4MulticastAddr2 = tcpip.Address("\xe0\x00\x00\x04") + ipv4MulticastAddr3 = tcpip.Address("\xe0\x00\x00\x05") + ipv6MulticastAddr1 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") + ipv6MulticastAddr2 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04") + ipv6MulticastAddr3 = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x05") + + igmpMembershipQuery = uint8(header.IGMPMembershipQuery) + igmpv1MembershipReport = uint8(header.IGMPv1MembershipReport) + igmpv2MembershipReport = uint8(header.IGMPv2MembershipReport) + igmpLeaveGroup = uint8(header.IGMPLeaveGroup) + mldQuery = uint8(header.ICMPv6MulticastListenerQuery) + mldReport = uint8(header.ICMPv6MulticastListenerReport) + mldDone = uint8(header.ICMPv6MulticastListenerDone) + + maxUnsolicitedReports = 2 +) + +var ( + // unsolicitedIGMPReportIntervalMaxTenthSec is the maximum amount of time the + // NIC will wait before sending an unsolicited report after joining a + // multicast group, in deciseconds. + unsolicitedIGMPReportIntervalMaxTenthSec = func() uint8 { + const decisecond = time.Second / 10 + if ipv4.UnsolicitedReportIntervalMax%decisecond != 0 { + panic(fmt.Sprintf("UnsolicitedReportIntervalMax of %d is a lossy conversion to deciseconds", ipv4.UnsolicitedReportIntervalMax)) + } + return uint8(ipv4.UnsolicitedReportIntervalMax / decisecond) + }() + + ipv6AddrSNMC = header.SolicitedNodeAddr(ipv6Addr) +) + +// validateMLDPacket checks that a passed PacketInfo is an IPv6 MLD packet +// sent to the provided address with the passed fields set. +func validateMLDPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.Address, mldType uint8, maxRespTime byte, groupAddress tcpip.Address) { + t.Helper() + + payload := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())) + checker.IPv6WithExtHdr(t, payload, + checker.IPv6ExtHdr( + checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD)), + ), + checker.SrcAddr(ipv6Addr), + checker.DstAddr(remoteAddress), + // Hop Limit for an MLD message must be 1 as per RFC 2710 section 3. + checker.TTL(1), + checker.MLD(header.ICMPv6Type(mldType), header.MLDMinimumSize, + checker.MLDMaxRespDelay(time.Duration(maxRespTime)*time.Millisecond), + checker.MLDMulticastAddress(groupAddress), + ), + ) +} + +// validateIGMPPacket checks that a passed PacketInfo is an IPv4 IGMP packet +// sent to the provided address with the passed fields set. +func validateIGMPPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.Address, igmpType uint8, maxRespTime byte, groupAddress tcpip.Address) { + t.Helper() + + payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader())) + checker.IPv4(t, payload, + checker.SrcAddr(ipv4Addr), + checker.DstAddr(remoteAddress), + // TTL for an IGMP message must be 1 as per RFC 2236 section 2. + checker.TTL(1), + checker.IPv4RouterAlert(), + checker.IGMP( + checker.IGMPType(header.IGMPType(igmpType)), + checker.IGMPMaxRespTime(header.DecisecondToDuration(maxRespTime)), + checker.IGMPGroupAddress(groupAddress), + ), + ) +} + +func createStack(t *testing.T, v4, mgpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) { + t.Helper() + + e := channel.New(maxUnsolicitedReports, header.IPv6MinimumMTU, linkAddr) + s, clock := createStackWithLinkEndpoint(t, v4, mgpEnabled, e) + return e, s, clock +} + +func createStackWithLinkEndpoint(t *testing.T, v4, mgpEnabled bool, e stack.LinkEndpoint) (*stack.Stack, *faketime.ManualClock) { + t.Helper() + + igmpEnabled := v4 && mgpEnabled + mldEnabled := !v4 && mgpEnabled + + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ + ipv4.NewProtocolWithOptions(ipv4.Options{ + IGMP: ipv4.IGMPOptions{ + Enabled: igmpEnabled, + }, + }), + ipv6.NewProtocolWithOptions(ipv6.Options{ + MLD: ipv6.MLDOptions{ + Enabled: mldEnabled, + }, + }), + }, + Clock: clock, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, ipv4Addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, ipv4Addr, err) + } + if err := s.AddAddress(nicID, ipv6.ProtocolNumber, ipv6Addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, ipv6Addr, err) + } + + return s, clock +} + +// checkInitialIPv6Groups checks the initial IPv6 groups that a NIC will join +// when it is created with an IPv6 address. +// +// To not interfere with tests, checkInitialIPv6Groups will leave the added +// address's solicited node multicast group so that the tests can all assume +// the NIC has not joined any IPv6 groups. +func checkInitialIPv6Groups(t *testing.T, e *channel.Endpoint, s *stack.Stack, clock *faketime.ManualClock) (reportCounter uint64, leaveCounter uint64) { + t.Helper() + + stats := s.Stats().ICMP.V6.PacketsSent + + reportCounter++ + if got := stats.MulticastListenerReport.Value(); got != reportCounter { + t.Errorf("got stats.MulticastListenerReport.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + validateMLDPacket(t, p, ipv6AddrSNMC, mldReport, 0, ipv6AddrSNMC) + } + + // Leave the group to not affect the tests. This is fine since we are not + // testing DAD or the solicited node address specifically. + if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, ipv6AddrSNMC); err != nil { + t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, ipv6AddrSNMC, err) + } + leaveCounter++ + if got := stats.MulticastListenerDone.Value(); got != leaveCounter { + t.Errorf("got stats.MulticastListenerDone.Value() = %d, want = %d", got, leaveCounter) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6AddrSNMC) + } + + // Should not send any more packets. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet = %#v", p) + } + + return reportCounter, leaveCounter +} + +// createAndInjectIGMPPacket creates and injects an IGMP packet with the +// specified fields. +// +// Note, the router alert option is not included in this packet. +// +// TODO(b/162198658): set the router alert option. +func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType byte, maxRespTime byte, groupAddress tcpip.Address) { + buf := buffer.NewView(header.IPv4MinimumSize + header.IGMPQueryMinimumSize) + + ip := header.IPv4(buf) + ip.Encode(&header.IPv4Fields{ + TotalLength: uint16(len(buf)), + TTL: header.IGMPTTL, + Protocol: uint8(header.IGMPProtocolNumber), + SrcAddr: header.IPv4Any, + DstAddr: header.IPv4AllSystems, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + igmp := header.IGMP(buf[header.IPv4MinimumSize:]) + igmp.SetType(header.IGMPType(igmpType)) + igmp.SetMaxRespTime(maxRespTime) + igmp.SetGroupAddress(groupAddress) + igmp.SetChecksum(header.IGMPCalculateChecksum(igmp)) + + e.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) +} + +// createAndInjectMLDPacket creates and injects an MLD packet with the +// specified fields. +// +// Note, the router alert option is not included in this packet. +// +// TODO(b/162198658): set the router alert option. +func createAndInjectMLDPacket(e *channel.Endpoint, mldType uint8, maxRespDelay byte, groupAddress tcpip.Address) { + icmpSize := header.ICMPv6HeaderSize + header.MLDMinimumSize + buf := buffer.NewView(header.IPv6MinimumSize + icmpSize) + + ip := header.IPv6(buf) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(icmpSize), + HopLimit: header.MLDHopLimit, + TransportProtocol: header.ICMPv6ProtocolNumber, + SrcAddr: header.IPv4Any, + DstAddr: header.IPv6AllNodesMulticastAddress, + }) + + icmp := header.ICMPv6(buf[header.IPv6MinimumSize:]) + icmp.SetType(header.ICMPv6Type(mldType)) + mld := header.MLD(icmp.MessageBody()) + mld.SetMaximumResponseDelay(uint16(maxRespDelay)) + mld.SetMulticastAddress(groupAddress) + icmp.SetChecksum(header.ICMPv6Checksum(icmp, header.IPv6Any, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) + + e.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) +} + +// TestMGPDisabled tests that the multicast group protocol is not enabled by +// default. +func TestMGPDisabled(t *testing.T) { + tests := []struct { + name string + protoNum tcpip.NetworkProtocolNumber + multicastAddr tcpip.Address + sentReportStat func(*stack.Stack) *tcpip.StatCounter + receivedQueryStat func(*stack.Stack) *tcpip.StatCounter + rxQuery func(*channel.Endpoint) + }{ + { + name: "IGMP", + protoNum: ipv4.ProtocolNumber, + multicastAddr: ipv4MulticastAddr1, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsSent.V2MembershipReport + }, + receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsReceived.MembershipQuery + }, + rxQuery: func(e *channel.Endpoint) { + createAndInjectIGMPPacket(e, igmpMembershipQuery, unsolicitedIGMPReportIntervalMaxTenthSec, header.IPv4Any) + }, + }, + { + name: "MLD", + protoNum: ipv6.ProtocolNumber, + multicastAddr: ipv6MulticastAddr1, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport + }, + receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery + }, + rxQuery: func(e *channel.Endpoint) { + createAndInjectMLDPacket(e, mldQuery, 0, header.IPv6Any) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, false /* mgpEnabled */) + + // This NIC may join multicast groups when it is enabled but since MGP is + // disabled, no reports should be sent. + sentReportStat := test.sentReportStat(s) + if got := sentReportStat.Value(); got != 0 { + t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) + } + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet, stack with disabled MGP sent packet = %#v", p.Pkt) + } + + // Test joining a specific group explicitly and verify that no reports are + // sent. + if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) + } + if got := sentReportStat.Value(); got != 0 { + t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) + } + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %#v", p.Pkt) + } + + // Inject a general query message. This should only trigger a report to be + // sent if the MGP was enabled. + test.rxQuery(e) + if got := test.receivedQueryStat(s).Value(); got != 1 { + t.Fatalf("got receivedQueryStat(_).Value() = %d, want = 1", got) + } + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet, stack with disabled IGMP sent packet = %+v", p.Pkt) + } + }) + } +} + +func TestMGPReceiveCounters(t *testing.T) { + tests := []struct { + name string + headerType uint8 + maxRespTime byte + groupAddress tcpip.Address + statCounter func(*stack.Stack) *tcpip.StatCounter + rxMGPkt func(*channel.Endpoint, byte, byte, tcpip.Address) + }{ + { + name: "IGMP Membership Query", + headerType: igmpMembershipQuery, + maxRespTime: unsolicitedIGMPReportIntervalMaxTenthSec, + groupAddress: header.IPv4Any, + statCounter: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsReceived.MembershipQuery + }, + rxMGPkt: createAndInjectIGMPPacket, + }, + { + name: "IGMPv1 Membership Report", + headerType: igmpv1MembershipReport, + maxRespTime: 0, + groupAddress: header.IPv4AllSystems, + statCounter: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsReceived.V1MembershipReport + }, + rxMGPkt: createAndInjectIGMPPacket, + }, + { + name: "IGMPv2 Membership Report", + headerType: igmpv2MembershipReport, + maxRespTime: 0, + groupAddress: header.IPv4AllSystems, + statCounter: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsReceived.V2MembershipReport + }, + rxMGPkt: createAndInjectIGMPPacket, + }, + { + name: "IGMP Leave Group", + headerType: igmpLeaveGroup, + maxRespTime: 0, + groupAddress: header.IPv4AllRoutersGroup, + statCounter: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsReceived.LeaveGroup + }, + rxMGPkt: createAndInjectIGMPPacket, + }, + { + name: "MLD Query", + headerType: mldQuery, + maxRespTime: 0, + groupAddress: header.IPv6Any, + statCounter: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery + }, + rxMGPkt: createAndInjectMLDPacket, + }, + { + name: "MLD Report", + headerType: mldReport, + maxRespTime: 0, + groupAddress: header.IPv6Any, + statCounter: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerReport + }, + rxMGPkt: createAndInjectMLDPacket, + }, + { + name: "MLD Done", + headerType: mldDone, + maxRespTime: 0, + groupAddress: header.IPv6Any, + statCounter: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerDone + }, + rxMGPkt: createAndInjectMLDPacket, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e, s, _ := createStack(t, len(test.groupAddress) == header.IPv4AddressSize /* v4 */, true /* mgpEnabled */) + + test.rxMGPkt(e, test.headerType, test.maxRespTime, test.groupAddress) + if got := test.statCounter(s).Value(); got != 1 { + t.Fatalf("got %s received = %d, want = 1", test.name, got) + } + }) + } +} + +// TestMGPJoinGroup tests that when explicitly joining a multicast group, the +// stack schedules and sends correct Membership Reports. +func TestMGPJoinGroup(t *testing.T) { + tests := []struct { + name string + protoNum tcpip.NetworkProtocolNumber + multicastAddr tcpip.Address + maxUnsolicitedResponseDelay time.Duration + sentReportStat func(*stack.Stack) *tcpip.StatCounter + receivedQueryStat func(*stack.Stack) *tcpip.StatCounter + validateReport func(*testing.T, channel.PacketInfo) + checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) + }{ + { + name: "IGMP", + protoNum: ipv4.ProtocolNumber, + multicastAddr: ipv4MulticastAddr1, + maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsSent.V2MembershipReport + }, + receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsReceived.MembershipQuery + }, + validateReport: func(t *testing.T, p channel.PacketInfo) { + t.Helper() + + validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1) + }, + }, + { + name: "MLD", + protoNum: ipv6.ProtocolNumber, + multicastAddr: ipv6MulticastAddr1, + maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport + }, + receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery + }, + validateReport: func(t *testing.T, p channel.PacketInfo) { + t.Helper() + + validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1) + }, + checkInitialGroups: checkInitialIPv6Groups, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) + + var reportCounter uint64 + if test.checkInitialGroups != nil { + reportCounter, _ = test.checkInitialGroups(t, e, s, clock) + } + + // Test joining a specific address explicitly and verify a Report is sent + // immediately. + if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) + } + reportCounter++ + sentReportStat := test.sentReportStat(s) + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + test.validateReport(t, p) + } + if t.Failed() { + t.FailNow() + } + + // Verify the second report is sent by the maximum unsolicited response + // interval. + p, ok := e.Read() + if ok { + t.Fatalf("sent unexpected packet, expected report only after advancing the clock = %#v", p.Pkt) + } + clock.Advance(test.maxUnsolicitedResponseDelay) + reportCounter++ + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + test.validateReport(t, p) + } + + // Should not send any more packets. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet = %#v", p) + } + }) + } +} + +// TestMGPLeaveGroup tests that when leaving a previously joined multicast +// group the stack sends a leave/done message. +func TestMGPLeaveGroup(t *testing.T) { + tests := []struct { + name string + protoNum tcpip.NetworkProtocolNumber + multicastAddr tcpip.Address + sentReportStat func(*stack.Stack) *tcpip.StatCounter + sentLeaveStat func(*stack.Stack) *tcpip.StatCounter + validateReport func(*testing.T, channel.PacketInfo) + validateLeave func(*testing.T, channel.PacketInfo) + checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) + }{ + { + name: "IGMP", + protoNum: ipv4.ProtocolNumber, + multicastAddr: ipv4MulticastAddr1, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsSent.V2MembershipReport + }, + sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsSent.LeaveGroup + }, + validateReport: func(t *testing.T, p channel.PacketInfo) { + t.Helper() + + validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1) + }, + validateLeave: func(t *testing.T, p channel.PacketInfo) { + t.Helper() + + validateIGMPPacket(t, p, header.IPv4AllRoutersGroup, igmpLeaveGroup, 0, ipv4MulticastAddr1) + }, + }, + { + name: "MLD", + protoNum: ipv6.ProtocolNumber, + multicastAddr: ipv6MulticastAddr1, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport + }, + sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone + }, + validateReport: func(t *testing.T, p channel.PacketInfo) { + t.Helper() + + validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1) + }, + validateLeave: func(t *testing.T, p channel.PacketInfo) { + t.Helper() + + validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6MulticastAddr1) + }, + checkInitialGroups: checkInitialIPv6Groups, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) + + var reportCounter uint64 + var leaveCounter uint64 + if test.checkInitialGroups != nil { + reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock) + } + + if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) + } + reportCounter++ + if got := test.sentReportStat(s).Value(); got != reportCounter { + t.Errorf("got sentReportStat(_).Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + test.validateReport(t, p) + } + if t.Failed() { + t.FailNow() + } + + // Leaving the group should trigger an leave/done message to be sent. + if err := s.LeaveGroup(test.protoNum, nicID, test.multicastAddr); err != nil { + t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err) + } + leaveCounter++ + if got := test.sentLeaveStat(s).Value(); got != leaveCounter { + t.Fatalf("got sentLeaveStat(_).Value() = %d, want = %d", got, leaveCounter) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a leave message to be sent") + } else { + test.validateLeave(t, p) + } + + // Should not send any more packets. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet = %#v", p) + } + }) + } +} + +// TestMGPQueryMessages tests that a report is sent in response to query +// messages. +func TestMGPQueryMessages(t *testing.T) { + tests := []struct { + name string + protoNum tcpip.NetworkProtocolNumber + multicastAddr tcpip.Address + maxUnsolicitedResponseDelay time.Duration + sentReportStat func(*stack.Stack) *tcpip.StatCounter + receivedQueryStat func(*stack.Stack) *tcpip.StatCounter + rxQuery func(*channel.Endpoint, uint8, tcpip.Address) + validateReport func(*testing.T, channel.PacketInfo) + maxRespTimeToDuration func(uint8) time.Duration + checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) + }{ + { + name: "IGMP", + protoNum: ipv4.ProtocolNumber, + multicastAddr: ipv4MulticastAddr1, + maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsSent.V2MembershipReport + }, + receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsReceived.MembershipQuery + }, + rxQuery: func(e *channel.Endpoint, maxRespTime uint8, groupAddress tcpip.Address) { + createAndInjectIGMPPacket(e, igmpMembershipQuery, maxRespTime, groupAddress) + }, + validateReport: func(t *testing.T, p channel.PacketInfo) { + t.Helper() + + validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1) + }, + maxRespTimeToDuration: header.DecisecondToDuration, + }, + { + name: "MLD", + protoNum: ipv6.ProtocolNumber, + multicastAddr: ipv6MulticastAddr1, + maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport + }, + receivedQueryStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsReceived.MulticastListenerQuery + }, + rxQuery: func(e *channel.Endpoint, maxRespTime uint8, groupAddress tcpip.Address) { + createAndInjectMLDPacket(e, mldQuery, maxRespTime, groupAddress) + }, + validateReport: func(t *testing.T, p channel.PacketInfo) { + t.Helper() + + validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1) + }, + maxRespTimeToDuration: func(d uint8) time.Duration { + return time.Duration(d) * time.Millisecond + }, + checkInitialGroups: checkInitialIPv6Groups, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + subTests := []struct { + name string + multicastAddr tcpip.Address + expectReport bool + }{ + { + name: "Unspecified", + multicastAddr: tcpip.Address(strings.Repeat("\x00", len(test.multicastAddr))), + expectReport: true, + }, + { + name: "Specified", + multicastAddr: test.multicastAddr, + expectReport: true, + }, + { + name: "Specified other address", + multicastAddr: func() tcpip.Address { + addrBytes := []byte(test.multicastAddr) + addrBytes[len(addrBytes)-1]++ + return tcpip.Address(addrBytes) + }(), + expectReport: false, + }, + } + + for _, subTest := range subTests { + t.Run(subTest.name, func(t *testing.T) { + e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) + + var reportCounter uint64 + if test.checkInitialGroups != nil { + reportCounter, _ = test.checkInitialGroups(t, e, s, clock) + } + + if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) + } + sentReportStat := test.sentReportStat(s) + for i := 0; i < maxUnsolicitedReports; i++ { + sentReportStat := test.sentReportStat(s) + reportCounter++ + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("(i=%d) got sentReportStat.Value() = %d, want = %d", i, got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Fatalf("expected %d-th report message to be sent", i) + } else { + test.validateReport(t, p) + } + clock.Advance(test.maxUnsolicitedResponseDelay) + } + if t.Failed() { + t.FailNow() + } + + // Should not send any more packets until a query. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet = %#v", p) + } + + // Receive a query message which should trigger a report to be sent at + // some time before the maximum response time if the report is + // targeted at the host. + const maxRespTime = 100 + test.rxQuery(e, maxRespTime, subTest.multicastAddr) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet = %#v", p.Pkt) + } + + if subTest.expectReport { + clock.Advance(test.maxRespTimeToDuration(maxRespTime)) + reportCounter++ + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + test.validateReport(t, p) + } + } + + // Should not send any more packets. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet = %#v", p) + } + }) + } + }) + } +} + +// TestMGPQueryMessages tests that no further reports or leave/done messages +// are sent after receiving a report. +func TestMGPReportMessages(t *testing.T) { + tests := []struct { + name string + protoNum tcpip.NetworkProtocolNumber + multicastAddr tcpip.Address + sentReportStat func(*stack.Stack) *tcpip.StatCounter + sentLeaveStat func(*stack.Stack) *tcpip.StatCounter + rxReport func(*channel.Endpoint) + validateReport func(*testing.T, channel.PacketInfo) + maxRespTimeToDuration func(uint8) time.Duration + checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) + }{ + { + name: "IGMP", + protoNum: ipv4.ProtocolNumber, + multicastAddr: ipv4MulticastAddr1, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsSent.V2MembershipReport + }, + sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsSent.LeaveGroup + }, + rxReport: func(e *channel.Endpoint) { + createAndInjectIGMPPacket(e, igmpv2MembershipReport, 0, ipv4MulticastAddr1) + }, + validateReport: func(t *testing.T, p channel.PacketInfo) { + t.Helper() + + validateIGMPPacket(t, p, ipv4MulticastAddr1, igmpv2MembershipReport, 0, ipv4MulticastAddr1) + }, + maxRespTimeToDuration: header.DecisecondToDuration, + }, + { + name: "MLD", + protoNum: ipv6.ProtocolNumber, + multicastAddr: ipv6MulticastAddr1, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport + }, + sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone + }, + rxReport: func(e *channel.Endpoint) { + createAndInjectMLDPacket(e, mldReport, 0, ipv6MulticastAddr1) + }, + validateReport: func(t *testing.T, p channel.PacketInfo) { + t.Helper() + + validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1) + }, + maxRespTimeToDuration: func(d uint8) time.Duration { + return time.Duration(d) * time.Millisecond + }, + checkInitialGroups: checkInitialIPv6Groups, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) + + var reportCounter uint64 + var leaveCounter uint64 + if test.checkInitialGroups != nil { + reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock) + } + + if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) + } + sentReportStat := test.sentReportStat(s) + reportCounter++ + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + test.validateReport(t, p) + } + if t.Failed() { + t.FailNow() + } + + // Receiving a report for a group we joined should cancel any further + // reports. + test.rxReport(e) + clock.Advance(time.Hour) + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); ok { + t.Errorf("sent unexpected packet = %#v", p) + } + if t.Failed() { + t.FailNow() + } + + // Leaving a group after getting a report should not send a leave/done + // message. + if err := s.LeaveGroup(test.protoNum, nicID, test.multicastAddr); err != nil { + t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err) + } + clock.Advance(time.Hour) + if got := test.sentLeaveStat(s).Value(); got != leaveCounter { + t.Fatalf("got sentLeaveStat(_).Value() = %d, want = %d", got, leaveCounter) + } + + // Should not send any more packets. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet = %#v", p) + } + }) + } +} + +func TestMGPWithNICLifecycle(t *testing.T) { + tests := []struct { + name string + protoNum tcpip.NetworkProtocolNumber + multicastAddrs []tcpip.Address + finalMulticastAddr tcpip.Address + maxUnsolicitedResponseDelay time.Duration + sentReportStat func(*stack.Stack) *tcpip.StatCounter + sentLeaveStat func(*stack.Stack) *tcpip.StatCounter + validateReport func(*testing.T, channel.PacketInfo, tcpip.Address) + validateLeave func(*testing.T, channel.PacketInfo, tcpip.Address) + getAndCheckGroupAddress func(*testing.T, map[tcpip.Address]bool, channel.PacketInfo) tcpip.Address + checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) + }{ + { + name: "IGMP", + protoNum: ipv4.ProtocolNumber, + multicastAddrs: []tcpip.Address{ipv4MulticastAddr1, ipv4MulticastAddr2}, + finalMulticastAddr: ipv4MulticastAddr3, + maxUnsolicitedResponseDelay: ipv4.UnsolicitedReportIntervalMax, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsSent.V2MembershipReport + }, + sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsSent.LeaveGroup + }, + validateReport: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) { + t.Helper() + + validateIGMPPacket(t, p, addr, igmpv2MembershipReport, 0, addr) + }, + validateLeave: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) { + t.Helper() + + validateIGMPPacket(t, p, header.IPv4AllRoutersGroup, igmpLeaveGroup, 0, addr) + }, + getAndCheckGroupAddress: func(t *testing.T, seen map[tcpip.Address]bool, p channel.PacketInfo) tcpip.Address { + t.Helper() + + ipv4 := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader())) + if got := tcpip.TransportProtocolNumber(ipv4.Protocol()); got != header.IGMPProtocolNumber { + t.Fatalf("got ipv4.Protocol() = %d, want = %d", got, header.IGMPProtocolNumber) + } + addr := header.IGMP(ipv4.Payload()).GroupAddress() + s, ok := seen[addr] + if !ok { + t.Fatalf("unexpectedly got a packet for group %s", addr) + } + if s { + t.Fatalf("already saw packet for group %s", addr) + } + seen[addr] = true + return addr + }, + }, + { + name: "MLD", + protoNum: ipv6.ProtocolNumber, + multicastAddrs: []tcpip.Address{ipv6MulticastAddr1, ipv6MulticastAddr2}, + finalMulticastAddr: ipv6MulticastAddr3, + maxUnsolicitedResponseDelay: ipv6.UnsolicitedReportIntervalMax, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport + }, + sentLeaveStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone + }, + validateReport: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) { + t.Helper() + + validateMLDPacket(t, p, addr, mldReport, 0, addr) + }, + validateLeave: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) { + t.Helper() + + validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, addr) + }, + getAndCheckGroupAddress: func(t *testing.T, seen map[tcpip.Address]bool, p channel.PacketInfo) tcpip.Address { + t.Helper() + + ipv6 := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())) + + ipv6HeaderIter := header.MakeIPv6PayloadIterator( + header.IPv6ExtensionHeaderIdentifier(ipv6.NextHeader()), + buffer.View(ipv6.Payload()).ToVectorisedView(), + ) + + var transport header.IPv6RawPayloadHeader + for { + h, done, err := ipv6HeaderIter.Next() + if err != nil { + t.Fatalf("ipv6HeaderIter.Next(): %s", err) + } + if done { + t.Fatalf("ipv6HeaderIter.Next() = (%T, %t, _), want = (_, false, _)", h, done) + } + if t, ok := h.(header.IPv6RawPayloadHeader); ok { + transport = t + break + } + } + + if got := tcpip.TransportProtocolNumber(transport.Identifier); got != header.ICMPv6ProtocolNumber { + t.Fatalf("got ipv6.NextHeader() = %d, want = %d", got, header.ICMPv6ProtocolNumber) + } + icmpv6 := header.ICMPv6(transport.Buf.ToView()) + if got := icmpv6.Type(); got != header.ICMPv6MulticastListenerReport && got != header.ICMPv6MulticastListenerDone { + t.Fatalf("got icmpv6.Type() = %d, want = %d or %d", got, header.ICMPv6MulticastListenerReport, header.ICMPv6MulticastListenerDone) + } + addr := header.MLD(icmpv6.MessageBody()).MulticastAddress() + s, ok := seen[addr] + if !ok { + t.Fatalf("unexpectedly got a packet for group %s", addr) + } + if s { + t.Fatalf("already saw packet for group %s", addr) + } + seen[addr] = true + return addr + }, + checkInitialGroups: checkInitialIPv6Groups, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) + + var reportCounter uint64 + var leaveCounter uint64 + if test.checkInitialGroups != nil { + reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock) + } + + sentReportStat := test.sentReportStat(s) + for _, a := range test.multicastAddrs { + if err := s.JoinGroup(test.protoNum, nicID, a); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, a, err) + } + reportCounter++ + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Fatalf("expected a report message to be sent for %s", a) + } else { + test.validateReport(t, p, a) + } + } + if t.Failed() { + t.FailNow() + } + + // Leave messages should be sent for the joined groups when the NIC is + // disabled. + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("DisableNIC(%d): %s", nicID, err) + } + sentLeaveStat := test.sentLeaveStat(s) + leaveCounter += uint64(len(test.multicastAddrs)) + if got := sentLeaveStat.Value(); got != leaveCounter { + t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter) + } + { + seen := make(map[tcpip.Address]bool) + for _, a := range test.multicastAddrs { + seen[a] = false + } + + for i := range test.multicastAddrs { + p, ok := e.Read() + if !ok { + t.Fatalf("expected (%d-th) leave message to be sent", i) + } + + test.validateLeave(t, p, test.getAndCheckGroupAddress(t, seen, p)) + } + } + if t.Failed() { + t.FailNow() + } + + // Reports should be sent for the joined groups when the NIC is enabled. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("EnableNIC(%d): %s", nicID, err) + } + reportCounter += uint64(len(test.multicastAddrs)) + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) + } + { + seen := make(map[tcpip.Address]bool) + for _, a := range test.multicastAddrs { + seen[a] = false + } + + for i := range test.multicastAddrs { + p, ok := e.Read() + if !ok { + t.Fatalf("expected (%d-th) report message to be sent", i) + } + + test.validateReport(t, p, test.getAndCheckGroupAddress(t, seen, p)) + } + } + if t.Failed() { + t.FailNow() + } + + // Joining/leaving a group while disabled should not send any messages. + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("DisableNIC(%d): %s", nicID, err) + } + leaveCounter += uint64(len(test.multicastAddrs)) + if got := sentLeaveStat.Value(); got != leaveCounter { + t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter) + } + for i := range test.multicastAddrs { + if _, ok := e.Read(); !ok { + t.Fatalf("expected (%d-th) leave message to be sent", i) + } + } + for _, a := range test.multicastAddrs { + if err := s.LeaveGroup(test.protoNum, nicID, a); err != nil { + t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, a, err) + } + if got := sentLeaveStat.Value(); got != leaveCounter { + t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter) + } + if p, ok := e.Read(); ok { + t.Fatalf("leaving group %s on disabled NIC sent unexpected packet = %#v", a, p.Pkt) + } + } + if err := s.JoinGroup(test.protoNum, nicID, test.finalMulticastAddr); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.finalMulticastAddr, err) + } + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); ok { + t.Fatalf("joining group %s on disabled NIC sent unexpected packet = %#v", test.finalMulticastAddr, p.Pkt) + } + + // A report should only be sent for the group we last joined after + // enabling the NIC since the original groups were all left. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("EnableNIC(%d): %s", nicID, err) + } + reportCounter++ + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + test.validateReport(t, p, test.finalMulticastAddr) + } + + clock.Advance(test.maxUnsolicitedResponseDelay) + reportCounter++ + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + test.validateReport(t, p, test.finalMulticastAddr) + } + + // Should not send any more packets. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet = %#v", p) + } + }) + } +} + +// TestMGPDisabledOnLoopback tests that the multicast group protocol is not +// performed on loopback interfaces since they have no neighbours. +func TestMGPDisabledOnLoopback(t *testing.T) { + tests := []struct { + name string + protoNum tcpip.NetworkProtocolNumber + multicastAddr tcpip.Address + sentReportStat func(*stack.Stack) *tcpip.StatCounter + }{ + { + name: "IGMP", + protoNum: ipv4.ProtocolNumber, + multicastAddr: ipv4MulticastAddr1, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsSent.V2MembershipReport + }, + }, + { + name: "MLD", + protoNum: ipv6.ProtocolNumber, + multicastAddr: ipv6MulticastAddr1, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s, clock := createStackWithLinkEndpoint(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */, loopback.New()) + + sentReportStat := test.sentReportStat(s) + if got := sentReportStat.Value(); got != 0 { + t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) + } + clock.Advance(time.Hour) + if got := sentReportStat.Value(); got != 0 { + t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) + } + + // Test joining a specific group explicitly and verify that no reports are + // sent. + if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) + } + if got := sentReportStat.Value(); got != 0 { + t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) + } + clock.Advance(time.Hour) + if got := sentReportStat.Value(); got != 0 { + t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) + } + }) + } +} diff --git a/pkg/tcpip/network/testutil/testutil.go b/pkg/tcpip/network/testutil/testutil.go index 7cc52985e..5c3363759 100644 --- a/pkg/tcpip/network/testutil/testutil.go +++ b/pkg/tcpip/network/testutil/testutil.go @@ -85,21 +85,6 @@ func (ep *MockLinkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts st return n, nil } -// WriteRawPacket implements LinkEndpoint.WriteRawPacket. -func (ep *MockLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - if ep.allowPackets == 0 { - return ep.err - } - ep.allowPackets-- - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: vv, - }) - ep.WrittenPackets = append(ep.WrittenPackets, pkt) - - return nil -} - // Attach implements LinkEndpoint.Attach. func (*MockLinkEndpoint) Attach(stack.NetworkDispatcher) {} diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go index 51d428049..4777163cd 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/main.go +++ b/pkg/tcpip/sample/tun_tcp_connect/main.go @@ -44,6 +44,7 @@ import ( "bufio" "fmt" "log" + "math" "math/rand" "net" "os" @@ -200,7 +201,7 @@ func main() { // connection from its side. wq.EventRegister(&waitEntry, waiter.EventIn) for { - v, _, err := ep.Read(nil) + _, err := ep.Read(os.Stdout, math.MaxUint16, tcpip.ReadOptions{}) if err != nil { if err == tcpip.ErrClosedForReceive { break @@ -213,8 +214,6 @@ func main() { log.Fatal("Read() failed:", err) } - - os.Stdout.Write(v) } wq.EventUnregister(&waitEntry) diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go index 8e0ee1cd7..a80fa0474 100644 --- a/pkg/tcpip/sample/tun_tcp_echo/main.go +++ b/pkg/tcpip/sample/tun_tcp_echo/main.go @@ -20,8 +20,10 @@ package main import ( + "bytes" "flag" "log" + "math" "math/rand" "net" "os" @@ -54,7 +56,8 @@ func echo(wq *waiter.Queue, ep tcpip.Endpoint) { defer wq.EventUnregister(&waitEntry) for { - v, _, err := ep.Read(nil) + var buf bytes.Buffer + _, err := ep.Read(&buf, math.MaxUint16, tcpip.ReadOptions{}) if err != nil { if err == tcpip.ErrWouldBlock { <-notifyCh @@ -64,7 +67,7 @@ func echo(wq *waiter.Queue, ep tcpip.Endpoint) { return } - ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{}) + ep.Write(tcpip.SlicePayload(buf.Bytes()), tcpip.WriteOptions{}) } } @@ -148,10 +151,6 @@ func main() { log.Fatal(err) } - if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { - log.Fatal(err) - } - subnet, err := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", len(addr))), tcpip.AddressMask(strings.Repeat("\x00", len(addr)))) if err != nil { log.Fatal(err) diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go new file mode 100644 index 000000000..f3ad40fdf --- /dev/null +++ b/pkg/tcpip/socketops.go @@ -0,0 +1,520 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcpip + +import ( + "sync/atomic" + + "gvisor.dev/gvisor/pkg/sync" +) + +// SocketOptionsHandler holds methods that help define endpoint specific +// behavior for socket level socket options. These must be implemented by +// endpoints to get notified when socket level options are set. +type SocketOptionsHandler interface { + // OnReuseAddressSet is invoked when SO_REUSEADDR is set for an endpoint. + OnReuseAddressSet(v bool) + + // OnReusePortSet is invoked when SO_REUSEPORT is set for an endpoint. + OnReusePortSet(v bool) + + // OnKeepAliveSet is invoked when SO_KEEPALIVE is set for an endpoint. + OnKeepAliveSet(v bool) + + // OnDelayOptionSet is invoked when TCP_NODELAY is set for an endpoint. + // Note that v will be the inverse of TCP_NODELAY option. + OnDelayOptionSet(v bool) + + // OnCorkOptionSet is invoked when TCP_CORK is set for an endpoint. + OnCorkOptionSet(v bool) + + // LastError is invoked when SO_ERROR is read for an endpoint. + LastError() *Error + + // UpdateLastError updates the endpoint specific last error field. + UpdateLastError(err *Error) + + // HasNIC is invoked to check if the NIC is valid for SO_BINDTODEVICE. + HasNIC(v int32) bool +} + +// DefaultSocketOptionsHandler is an embeddable type that implements no-op +// implementations for SocketOptionsHandler methods. +type DefaultSocketOptionsHandler struct{} + +var _ SocketOptionsHandler = (*DefaultSocketOptionsHandler)(nil) + +// OnReuseAddressSet implements SocketOptionsHandler.OnReuseAddressSet. +func (*DefaultSocketOptionsHandler) OnReuseAddressSet(bool) {} + +// OnReusePortSet implements SocketOptionsHandler.OnReusePortSet. +func (*DefaultSocketOptionsHandler) OnReusePortSet(bool) {} + +// OnKeepAliveSet implements SocketOptionsHandler.OnKeepAliveSet. +func (*DefaultSocketOptionsHandler) OnKeepAliveSet(bool) {} + +// OnDelayOptionSet implements SocketOptionsHandler.OnDelayOptionSet. +func (*DefaultSocketOptionsHandler) OnDelayOptionSet(bool) {} + +// OnCorkOptionSet implements SocketOptionsHandler.OnCorkOptionSet. +func (*DefaultSocketOptionsHandler) OnCorkOptionSet(bool) {} + +// LastError implements SocketOptionsHandler.LastError. +func (*DefaultSocketOptionsHandler) LastError() *Error { + return nil +} + +// UpdateLastError implements SocketOptionsHandler.UpdateLastError. +func (*DefaultSocketOptionsHandler) UpdateLastError(*Error) {} + +// HasNIC implements SocketOptionsHandler.HasNIC. +func (*DefaultSocketOptionsHandler) HasNIC(int32) bool { + return false +} + +// SocketOptions contains all the variables which store values for SOL_SOCKET, +// SOL_IP, SOL_IPV6 and SOL_TCP level options. +// +// +stateify savable +type SocketOptions struct { + handler SocketOptionsHandler + + // These fields are accessed and modified using atomic operations. + + // broadcastEnabled determines whether datagram sockets are allowed to + // send packets to a broadcast address. + broadcastEnabled uint32 + + // passCredEnabled determines whether SCM_CREDENTIALS socket control + // messages are enabled. + passCredEnabled uint32 + + // noChecksumEnabled determines whether UDP checksum is disabled while + // transmitting for this socket. + noChecksumEnabled uint32 + + // reuseAddressEnabled determines whether Bind() should allow reuse of + // local address. + reuseAddressEnabled uint32 + + // reusePortEnabled determines whether to permit multiple sockets to be + // bound to an identical socket address. + reusePortEnabled uint32 + + // keepAliveEnabled determines whether TCP keepalive is enabled for this + // socket. + keepAliveEnabled uint32 + + // multicastLoopEnabled determines whether multicast packets sent over a + // non-loopback interface will be looped back. + multicastLoopEnabled uint32 + + // receiveTOSEnabled is used to specify if the TOS ancillary message is + // passed with incoming packets. + receiveTOSEnabled uint32 + + // receiveTClassEnabled is used to specify if the IPV6_TCLASS ancillary + // message is passed with incoming packets. + receiveTClassEnabled uint32 + + // receivePacketInfoEnabled is used to specify if more inforamtion is + // provided with incoming packets such as interface index and address. + receivePacketInfoEnabled uint32 + + // hdrIncludeEnabled is used to indicate for a raw endpoint that all packets + // being written have an IP header and the endpoint should not attach an IP + // header. + hdrIncludedEnabled uint32 + + // v6OnlyEnabled is used to determine whether an IPv6 socket is to be + // restricted to sending and receiving IPv6 packets only. + v6OnlyEnabled uint32 + + // quickAckEnabled is used to represent the value of TCP_QUICKACK option. + // It currently does not have any effect on the TCP endpoint. + quickAckEnabled uint32 + + // delayOptionEnabled is used to specify if data should be sent out immediately + // by the transport protocol. For TCP, it determines if the Nagle algorithm + // is on or off. + delayOptionEnabled uint32 + + // corkOptionEnabled is used to specify if data should be held until segments + // are full by the TCP transport protocol. + corkOptionEnabled uint32 + + // receiveOriginalDstAddress is used to specify if the original destination of + // the incoming packet should be returned as an ancillary message. + receiveOriginalDstAddress uint32 + + // recvErrEnabled determines whether extended reliable error message passing + // is enabled. + recvErrEnabled uint32 + + // errQueue is the per-socket error queue. It is protected by errQueueMu. + errQueueMu sync.Mutex `state:"nosave"` + errQueue sockErrorList + + // bindToDevice determines the device to which the socket is bound. + bindToDevice int32 + + // mu protects the access to the below fields. + mu sync.Mutex `state:"nosave"` + + // linger determines the amount of time the socket should linger before + // close. We currently implement this option for TCP socket only. + linger LingerOption +} + +// InitHandler initializes the handler. This must be called before using the +// socket options utility. +func (so *SocketOptions) InitHandler(handler SocketOptionsHandler) { + so.handler = handler +} + +func storeAtomicBool(addr *uint32, v bool) { + var val uint32 + if v { + val = 1 + } + atomic.StoreUint32(addr, val) +} + +// SetLastError sets the last error for a socket. +func (so *SocketOptions) SetLastError(err *Error) { + so.handler.UpdateLastError(err) +} + +// GetBroadcast gets value for SO_BROADCAST option. +func (so *SocketOptions) GetBroadcast() bool { + return atomic.LoadUint32(&so.broadcastEnabled) != 0 +} + +// SetBroadcast sets value for SO_BROADCAST option. +func (so *SocketOptions) SetBroadcast(v bool) { + storeAtomicBool(&so.broadcastEnabled, v) +} + +// GetPassCred gets value for SO_PASSCRED option. +func (so *SocketOptions) GetPassCred() bool { + return atomic.LoadUint32(&so.passCredEnabled) != 0 +} + +// SetPassCred sets value for SO_PASSCRED option. +func (so *SocketOptions) SetPassCred(v bool) { + storeAtomicBool(&so.passCredEnabled, v) +} + +// GetNoChecksum gets value for SO_NO_CHECK option. +func (so *SocketOptions) GetNoChecksum() bool { + return atomic.LoadUint32(&so.noChecksumEnabled) != 0 +} + +// SetNoChecksum sets value for SO_NO_CHECK option. +func (so *SocketOptions) SetNoChecksum(v bool) { + storeAtomicBool(&so.noChecksumEnabled, v) +} + +// GetReuseAddress gets value for SO_REUSEADDR option. +func (so *SocketOptions) GetReuseAddress() bool { + return atomic.LoadUint32(&so.reuseAddressEnabled) != 0 +} + +// SetReuseAddress sets value for SO_REUSEADDR option. +func (so *SocketOptions) SetReuseAddress(v bool) { + storeAtomicBool(&so.reuseAddressEnabled, v) + so.handler.OnReuseAddressSet(v) +} + +// GetReusePort gets value for SO_REUSEPORT option. +func (so *SocketOptions) GetReusePort() bool { + return atomic.LoadUint32(&so.reusePortEnabled) != 0 +} + +// SetReusePort sets value for SO_REUSEPORT option. +func (so *SocketOptions) SetReusePort(v bool) { + storeAtomicBool(&so.reusePortEnabled, v) + so.handler.OnReusePortSet(v) +} + +// GetKeepAlive gets value for SO_KEEPALIVE option. +func (so *SocketOptions) GetKeepAlive() bool { + return atomic.LoadUint32(&so.keepAliveEnabled) != 0 +} + +// SetKeepAlive sets value for SO_KEEPALIVE option. +func (so *SocketOptions) SetKeepAlive(v bool) { + storeAtomicBool(&so.keepAliveEnabled, v) + so.handler.OnKeepAliveSet(v) +} + +// GetMulticastLoop gets value for IP_MULTICAST_LOOP option. +func (so *SocketOptions) GetMulticastLoop() bool { + return atomic.LoadUint32(&so.multicastLoopEnabled) != 0 +} + +// SetMulticastLoop sets value for IP_MULTICAST_LOOP option. +func (so *SocketOptions) SetMulticastLoop(v bool) { + storeAtomicBool(&so.multicastLoopEnabled, v) +} + +// GetReceiveTOS gets value for IP_RECVTOS option. +func (so *SocketOptions) GetReceiveTOS() bool { + return atomic.LoadUint32(&so.receiveTOSEnabled) != 0 +} + +// SetReceiveTOS sets value for IP_RECVTOS option. +func (so *SocketOptions) SetReceiveTOS(v bool) { + storeAtomicBool(&so.receiveTOSEnabled, v) +} + +// GetReceiveTClass gets value for IPV6_RECVTCLASS option. +func (so *SocketOptions) GetReceiveTClass() bool { + return atomic.LoadUint32(&so.receiveTClassEnabled) != 0 +} + +// SetReceiveTClass sets value for IPV6_RECVTCLASS option. +func (so *SocketOptions) SetReceiveTClass(v bool) { + storeAtomicBool(&so.receiveTClassEnabled, v) +} + +// GetReceivePacketInfo gets value for IP_PKTINFO option. +func (so *SocketOptions) GetReceivePacketInfo() bool { + return atomic.LoadUint32(&so.receivePacketInfoEnabled) != 0 +} + +// SetReceivePacketInfo sets value for IP_PKTINFO option. +func (so *SocketOptions) SetReceivePacketInfo(v bool) { + storeAtomicBool(&so.receivePacketInfoEnabled, v) +} + +// GetHeaderIncluded gets value for IP_HDRINCL option. +func (so *SocketOptions) GetHeaderIncluded() bool { + return atomic.LoadUint32(&so.hdrIncludedEnabled) != 0 +} + +// SetHeaderIncluded sets value for IP_HDRINCL option. +func (so *SocketOptions) SetHeaderIncluded(v bool) { + storeAtomicBool(&so.hdrIncludedEnabled, v) +} + +// GetV6Only gets value for IPV6_V6ONLY option. +func (so *SocketOptions) GetV6Only() bool { + return atomic.LoadUint32(&so.v6OnlyEnabled) != 0 +} + +// SetV6Only sets value for IPV6_V6ONLY option. +// +// Preconditions: the backing TCP or UDP endpoint must be in initial state. +func (so *SocketOptions) SetV6Only(v bool) { + storeAtomicBool(&so.v6OnlyEnabled, v) +} + +// GetQuickAck gets value for TCP_QUICKACK option. +func (so *SocketOptions) GetQuickAck() bool { + return atomic.LoadUint32(&so.quickAckEnabled) != 0 +} + +// SetQuickAck sets value for TCP_QUICKACK option. +func (so *SocketOptions) SetQuickAck(v bool) { + storeAtomicBool(&so.quickAckEnabled, v) +} + +// GetDelayOption gets inverted value for TCP_NODELAY option. +func (so *SocketOptions) GetDelayOption() bool { + return atomic.LoadUint32(&so.delayOptionEnabled) != 0 +} + +// SetDelayOption sets inverted value for TCP_NODELAY option. +func (so *SocketOptions) SetDelayOption(v bool) { + storeAtomicBool(&so.delayOptionEnabled, v) + so.handler.OnDelayOptionSet(v) +} + +// GetCorkOption gets value for TCP_CORK option. +func (so *SocketOptions) GetCorkOption() bool { + return atomic.LoadUint32(&so.corkOptionEnabled) != 0 +} + +// SetCorkOption sets value for TCP_CORK option. +func (so *SocketOptions) SetCorkOption(v bool) { + storeAtomicBool(&so.corkOptionEnabled, v) + so.handler.OnCorkOptionSet(v) +} + +// GetReceiveOriginalDstAddress gets value for IP(V6)_RECVORIGDSTADDR option. +func (so *SocketOptions) GetReceiveOriginalDstAddress() bool { + return atomic.LoadUint32(&so.receiveOriginalDstAddress) != 0 +} + +// SetReceiveOriginalDstAddress sets value for IP(V6)_RECVORIGDSTADDR option. +func (so *SocketOptions) SetReceiveOriginalDstAddress(v bool) { + storeAtomicBool(&so.receiveOriginalDstAddress, v) +} + +// GetRecvError gets value for IP*_RECVERR option. +func (so *SocketOptions) GetRecvError() bool { + return atomic.LoadUint32(&so.recvErrEnabled) != 0 +} + +// SetRecvError sets value for IP*_RECVERR option. +func (so *SocketOptions) SetRecvError(v bool) { + storeAtomicBool(&so.recvErrEnabled, v) + if !v { + so.pruneErrQueue() + } +} + +// GetLastError gets value for SO_ERROR option. +func (so *SocketOptions) GetLastError() *Error { + return so.handler.LastError() +} + +// GetOutOfBandInline gets value for SO_OOBINLINE option. +func (*SocketOptions) GetOutOfBandInline() bool { + return true +} + +// SetOutOfBandInline sets value for SO_OOBINLINE option. We currently do not +// support disabling this option. +func (*SocketOptions) SetOutOfBandInline(bool) {} + +// GetLinger gets value for SO_LINGER option. +func (so *SocketOptions) GetLinger() LingerOption { + so.mu.Lock() + linger := so.linger + so.mu.Unlock() + return linger +} + +// SetLinger sets value for SO_LINGER option. +func (so *SocketOptions) SetLinger(linger LingerOption) { + so.mu.Lock() + so.linger = linger + so.mu.Unlock() +} + +// SockErrOrigin represents the constants for error origin. +type SockErrOrigin uint8 + +const ( + // SockExtErrorOriginNone represents an unknown error origin. + SockExtErrorOriginNone SockErrOrigin = iota + + // SockExtErrorOriginLocal indicates a local error. + SockExtErrorOriginLocal + + // SockExtErrorOriginICMP indicates an IPv4 ICMP error. + SockExtErrorOriginICMP + + // SockExtErrorOriginICMP6 indicates an IPv6 ICMP error. + SockExtErrorOriginICMP6 +) + +// IsICMPErr indicates if the error originated from an ICMP error. +func (origin SockErrOrigin) IsICMPErr() bool { + return origin == SockExtErrorOriginICMP || origin == SockExtErrorOriginICMP6 +} + +// SockError represents a queue entry in the per-socket error queue. +// +// +stateify savable +type SockError struct { + sockErrorEntry + + // Err is the error caused by the errant packet. + Err *Error + // ErrOrigin indicates the error origin. + ErrOrigin SockErrOrigin + // ErrType is the type in the ICMP header. + ErrType uint8 + // ErrCode is the code in the ICMP header. + ErrCode uint8 + // ErrInfo is additional info about the error. + ErrInfo uint32 + + // Payload is the errant packet's payload. + Payload []byte + // Dst is the original destination address of the errant packet. + Dst FullAddress + // Offender is the original sender address of the errant packet. + Offender FullAddress + // NetProto is the network protocol being used to transmit the packet. + NetProto NetworkProtocolNumber +} + +// pruneErrQueue resets the queue. +func (so *SocketOptions) pruneErrQueue() { + so.errQueueMu.Lock() + so.errQueue.Reset() + so.errQueueMu.Unlock() +} + +// DequeueErr dequeues a socket extended error from the error queue and returns +// it. Returns nil if queue is empty. +func (so *SocketOptions) DequeueErr() *SockError { + so.errQueueMu.Lock() + defer so.errQueueMu.Unlock() + + err := so.errQueue.Front() + if err != nil { + so.errQueue.Remove(err) + } + return err +} + +// PeekErr returns the error in the front of the error queue. Returns nil if +// the error queue is empty. +func (so *SocketOptions) PeekErr() *SockError { + so.errQueueMu.Lock() + defer so.errQueueMu.Unlock() + return so.errQueue.Front() +} + +// QueueErr inserts the error at the back of the error queue. +// +// Preconditions: so.GetRecvError() == true. +func (so *SocketOptions) QueueErr(err *SockError) { + so.errQueueMu.Lock() + defer so.errQueueMu.Unlock() + so.errQueue.PushBack(err) +} + +// QueueLocalErr queues a local error onto the local queue. +func (so *SocketOptions) QueueLocalErr(err *Error, net NetworkProtocolNumber, info uint32, dst FullAddress, payload []byte) { + so.QueueErr(&SockError{ + Err: err, + ErrOrigin: SockExtErrorOriginLocal, + ErrInfo: info, + Payload: payload, + Dst: dst, + NetProto: net, + }) +} + +// GetBindToDevice gets value for SO_BINDTODEVICE option. +func (so *SocketOptions) GetBindToDevice() int32 { + return atomic.LoadInt32(&so.bindToDevice) +} + +// SetBindToDevice sets value for SO_BINDTODEVICE option. +func (so *SocketOptions) SetBindToDevice(bindToDevice int32) *Error { + if !so.handler.HasNIC(bindToDevice) { + return ErrUnknownDevice + } + + atomic.StoreInt32(&so.bindToDevice, bindToDevice) + return nil +} diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index d09ebe7fa..bb30556cf 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test", "most_shards") load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) @@ -112,7 +112,7 @@ go_test( "transport_demuxer_test.go", "transport_test.go", ], - shard_count = 20, + shard_count = most_shards, deps = [ ":stack", "//pkg/rand", @@ -120,6 +120,7 @@ go_test( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", @@ -131,7 +132,6 @@ go_test( "//pkg/tcpip/transport/udp", "//pkg/waiter", "@com_github_google_go_cmp//cmp:go_default_library", - "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", ], ) @@ -148,7 +148,6 @@ go_test( ], library = ":stack", deps = [ - "//pkg/sleep", "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go index 9478f3fb7..cd423bf71 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state.go +++ b/pkg/tcpip/stack/addressable_endpoint_state.go @@ -21,7 +21,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" ) -var _ GroupAddressableEndpoint = (*AddressableEndpointState)(nil) var _ AddressableEndpoint = (*AddressableEndpointState)(nil) // AddressableEndpointState is an implementation of an AddressableEndpoint. @@ -37,10 +36,6 @@ type AddressableEndpointState struct { endpoints map[tcpip.Address]*addressState primary []*addressState - - // groups holds the mapping between group addresses and the number of times - // they have been joined. - groups map[tcpip.Address]uint32 } } @@ -53,65 +48,33 @@ func (a *AddressableEndpointState) Init(networkEndpoint NetworkEndpoint) { a.mu.Lock() defer a.mu.Unlock() a.mu.endpoints = make(map[tcpip.Address]*addressState) - a.mu.groups = make(map[tcpip.Address]uint32) -} - -// ReadOnlyAddressableEndpointState provides read-only access to an -// AddressableEndpointState. -type ReadOnlyAddressableEndpointState struct { - inner *AddressableEndpointState } -// AddrOrMatching returns an endpoint for the passed address that is consisdered -// bound to the wrapped AddressableEndpointState. +// GetAddress returns the AddressEndpoint for the passed address. // -// If addr is an exact match with an existing address, that address is returned. -// Otherwise, f is called with each address and the address that f returns true -// for is returned. -// -// Returns nil of no address matches. -func (m ReadOnlyAddressableEndpointState) AddrOrMatching(addr tcpip.Address, spoofingOrPrimiscuous bool, f func(AddressEndpoint) bool) AddressEndpoint { - m.inner.mu.RLock() - defer m.inner.mu.RUnlock() - - if ep, ok := m.inner.mu.endpoints[addr]; ok { - if ep.IsAssigned(spoofingOrPrimiscuous) && ep.IncRef() { - return ep - } - } - - for _, ep := range m.inner.mu.endpoints { - if ep.IsAssigned(spoofingOrPrimiscuous) && f(ep) && ep.IncRef() { - return ep - } - } - - return nil -} - -// Lookup returns the AddressEndpoint for the passed address. +// GetAddress does not increment the address's reference count or check if the +// address is considered bound to the endpoint. // -// Returns nil if the passed address is not associated with the -// AddressableEndpointState. -func (m ReadOnlyAddressableEndpointState) Lookup(addr tcpip.Address) AddressEndpoint { - m.inner.mu.RLock() - defer m.inner.mu.RUnlock() +// Returns nil if the passed address is not associated with the endpoint. +func (a *AddressableEndpointState) GetAddress(addr tcpip.Address) AddressEndpoint { + a.mu.RLock() + defer a.mu.RUnlock() - ep, ok := m.inner.mu.endpoints[addr] + ep, ok := a.mu.endpoints[addr] if !ok { return nil } return ep } -// ForEach calls f for each address pair. +// ForEachEndpoint calls f for each address. // -// If f returns false, f is no longer be called. -func (m ReadOnlyAddressableEndpointState) ForEach(f func(AddressEndpoint) bool) { - m.inner.mu.RLock() - defer m.inner.mu.RUnlock() +// Once f returns false, f will no longer be called. +func (a *AddressableEndpointState) ForEachEndpoint(f func(AddressEndpoint) bool) { + a.mu.RLock() + defer a.mu.RUnlock() - for _, ep := range m.inner.mu.endpoints { + for _, ep := range a.mu.endpoints { if !f(ep) { return } @@ -120,18 +83,16 @@ func (m ReadOnlyAddressableEndpointState) ForEach(f func(AddressEndpoint) bool) // ForEachPrimaryEndpoint calls f for each primary address. // -// If f returns false, f is no longer be called. -func (m ReadOnlyAddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint)) { - m.inner.mu.RLock() - defer m.inner.mu.RUnlock() - for _, ep := range m.inner.mu.primary { - f(ep) - } -} +// Once f returns false, f will no longer be called. +func (a *AddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint) bool) { + a.mu.RLock() + defer a.mu.RUnlock() -// ReadOnly returns a readonly reference to a. -func (a *AddressableEndpointState) ReadOnly() ReadOnlyAddressableEndpointState { - return ReadOnlyAddressableEndpointState{inner: a} + for _, ep := range a.mu.primary { + if !f(ep) { + return + } + } } func (a *AddressableEndpointState) releaseAddressState(addrState *addressState) { @@ -335,11 +296,6 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address func (a *AddressableEndpointState) RemovePermanentAddress(addr tcpip.Address) *tcpip.Error { a.mu.Lock() defer a.mu.Unlock() - - if _, ok := a.mu.groups[addr]; ok { - panic(fmt.Sprintf("group address = %s must be removed with LeaveGroup", addr)) - } - return a.removePermanentAddressLocked(addr) } @@ -471,8 +427,19 @@ func (a *AddressableEndpointState) acquirePrimaryAddressRLocked(isValid func(*ad return deprecatedEndpoint } -// AcquireAssignedAddress implements AddressableEndpoint. -func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint { +// AcquireAssignedAddressOrMatching returns an address endpoint that is +// considered assigned to the addressable endpoint. +// +// If the address is an exact match with an existing address, that address is +// returned. Otherwise, if f is provided, f is called with each address and +// the address that f returns true for is returned. +// +// If there is no matching address, a temporary address will be returned if +// allowTemp is true. +// +// Regardless how the address was obtained, it will be acquired before it is +// returned. +func (a *AddressableEndpointState) AcquireAssignedAddressOrMatching(localAddr tcpip.Address, f func(AddressEndpoint) bool, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint { a.mu.Lock() defer a.mu.Unlock() @@ -488,6 +455,14 @@ func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Addres return addrState } + if f != nil { + for _, addrState := range a.mu.endpoints { + if addrState.IsAssigned(allowTemp) && f(addrState) && addrState.IncRef() { + return addrState + } + } + } + if !allowTemp { return nil } @@ -520,6 +495,11 @@ func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Addres return ep } +// AcquireAssignedAddress implements AddressableEndpoint. +func (a *AddressableEndpointState) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp bool, tempPEB PrimaryEndpointBehavior) AddressEndpoint { + return a.AcquireAssignedAddressOrMatching(localAddr, nil, allowTemp, tempPEB) +} + // AcquireOutgoingPrimaryAddress implements AddressableEndpoint. func (a *AddressableEndpointState) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) AddressEndpoint { a.mu.RLock() @@ -588,72 +568,11 @@ func (a *AddressableEndpointState) PermanentAddresses() []tcpip.AddressWithPrefi return addrs } -// JoinGroup implements GroupAddressableEndpoint. -func (a *AddressableEndpointState) JoinGroup(group tcpip.Address) (bool, *tcpip.Error) { - a.mu.Lock() - defer a.mu.Unlock() - - joins, ok := a.mu.groups[group] - if !ok { - ep, err := a.addAndAcquireAddressLocked(group.WithPrefix(), NeverPrimaryEndpoint, AddressConfigStatic, false /* deprecated */, true /* permanent */) - if err != nil { - return false, err - } - // We have no need for the address endpoint. - a.decAddressRefLocked(ep) - } - - a.mu.groups[group] = joins + 1 - return !ok, nil -} - -// LeaveGroup implements GroupAddressableEndpoint. -func (a *AddressableEndpointState) LeaveGroup(group tcpip.Address) (bool, *tcpip.Error) { - a.mu.Lock() - defer a.mu.Unlock() - - joins, ok := a.mu.groups[group] - if !ok { - return false, tcpip.ErrBadLocalAddress - } - - if joins == 1 { - a.removeGroupAddressLocked(group) - delete(a.mu.groups, group) - return true, nil - } - - a.mu.groups[group] = joins - 1 - return false, nil -} - -// IsInGroup implements GroupAddressableEndpoint. -func (a *AddressableEndpointState) IsInGroup(group tcpip.Address) bool { - a.mu.RLock() - defer a.mu.RUnlock() - _, ok := a.mu.groups[group] - return ok -} - -func (a *AddressableEndpointState) removeGroupAddressLocked(group tcpip.Address) { - if err := a.removePermanentAddressLocked(group); err != nil { - // removePermanentEndpointLocked would only return an error if group is - // not bound to the addressable endpoint, but we know it MUST be assigned - // since we have group in our map of groups. - panic(fmt.Sprintf("error removing group address = %s: %s", group, err)) - } -} - // Cleanup forcefully leaves all groups and removes all permanent addresses. func (a *AddressableEndpointState) Cleanup() { a.mu.Lock() defer a.mu.Unlock() - for group := range a.mu.groups { - a.removeGroupAddressLocked(group) - } - a.mu.groups = make(map[tcpip.Address]uint32) - for _, ep := range a.mu.endpoints { // removePermanentEndpointLocked returns tcpip.ErrBadLocalAddress if ep is // not a permanent address. diff --git a/pkg/tcpip/stack/addressable_endpoint_state_test.go b/pkg/tcpip/stack/addressable_endpoint_state_test.go index 26787d0a3..140f146f6 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state_test.go +++ b/pkg/tcpip/stack/addressable_endpoint_state_test.go @@ -53,25 +53,9 @@ func TestAddressableEndpointStateCleanup(t *testing.T) { ep.DecRef() } - group := tcpip.Address("\x02") - if added, err := s.JoinGroup(group); err != nil { - t.Fatalf("s.JoinGroup(%s): %s", group, err) - } else if !added { - t.Fatalf("got s.JoinGroup(%s) = false, want = true", group) - } - if !s.IsInGroup(group) { - t.Fatalf("got s.IsInGroup(%s) = false, want = true", group) - } - s.Cleanup() - { - ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint) - if ep != nil { - ep.DecRef() - t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix()) - } - } - if s.IsInGroup(group) { - t.Fatalf("got s.IsInGroup(%s) = true, want = false", group) + if ep := s.AcquireAssignedAddress(addr.Address, false /* allowTemp */, stack.NeverPrimaryEndpoint); ep != nil { + ep.DecRef() + t.Fatalf("got s.AcquireAssignedAddress(%s, false, NeverPrimaryEndpoint) = %s, want = nil", addr.Address, ep.AddressWithPrefix()) } } diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 9a17efcba..5e649cca6 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -142,19 +142,19 @@ func (cn *conn) timedOut(now time.Time) bool { // update the connection tracking state. // -// Precondition: ct.mu must be held. -func (ct *conn) updateLocked(tcpHeader header.TCP, hook Hook) { +// Precondition: cn.mu must be held. +func (cn *conn) updateLocked(tcpHeader header.TCP, hook Hook) { // Update the state of tcb. tcb assumes it's always initialized on the // client. However, we only need to know whether the connection is // established or not, so the client/server distinction isn't important. // TODO(gvisor.dev/issue/170): Add support in tcpconntrack to handle // other tcp states. - if ct.tcb.IsEmpty() { - ct.tcb.Init(tcpHeader) - } else if hook == ct.tcbHook { - ct.tcb.UpdateStateOutbound(tcpHeader) + if cn.tcb.IsEmpty() { + cn.tcb.Init(tcpHeader) + } else if hook == cn.tcbHook { + cn.tcb.UpdateStateOutbound(tcpHeader) } else { - ct.tcb.UpdateStateInbound(tcpHeader) + cn.tcb.UpdateStateInbound(tcpHeader) } } diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index 7a501acdc..93e8e1c51 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -74,8 +74,30 @@ func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 { } func (f *fwdTestNetworkEndpoint) HandlePacket(pkt *PacketBuffer) { - // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) + netHdr := pkt.NetworkHeader().View() + _, dst := f.proto.ParseAddresses(netHdr) + + addressEndpoint := f.AcquireAssignedAddress(dst, f.nic.Promiscuous(), CanBePrimaryEndpoint) + if addressEndpoint != nil { + addressEndpoint.DecRef() + // Dispatch the packet to the transport protocol. + f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(netHdr[protocolNumberOffset]), pkt) + return + } + + r, err := f.proto.stack.FindRoute(0, "", dst, fwdTestNetNumber, false /* multicastLoop */) + if err != nil { + return + } + defer r.Release() + + vv := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) + pkt = NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: vv.ToView().ToVectorisedView(), + }) + // TODO(b/143425874) Decrease the TTL field in forwarded packets. + _ = r.WriteHeaderIncludedPacket(pkt) } func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { @@ -106,8 +128,13 @@ func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuf panic("not implemented") } -func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error { - return tcpip.ErrNotSupported +func (f *fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error { + // The network header should not already be populated. + if _, ok := pkt.NetworkHeader().Consume(fwdTestNetHeaderLen); !ok { + return tcpip.ErrMalformedHeader + } + + return f.nic.WritePacket(r, nil /* gso */, fwdTestNetNumber, pkt) } func (f *fwdTestNetworkEndpoint) Close() { @@ -117,6 +144,8 @@ func (f *fwdTestNetworkEndpoint) Close() { // fwdTestNetworkProtocol is a network-layer protocol that implements Address // resolution. type fwdTestNetworkProtocol struct { + stack *Stack + addrCache *linkAddrCache neigh *neighborCache addrResolveDelay time.Duration @@ -280,7 +309,7 @@ func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress { func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { p := fwdTestPacketInfo{ - RemoteLinkAddress: r.RemoteLinkAddress, + RemoteLinkAddress: r.RemoteLinkAddress(), LocalLinkAddress: r.LocalLinkAddress, Pkt: pkt, } @@ -304,20 +333,6 @@ func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBuffer return n, nil } -// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. -func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - p := fwdTestPacketInfo{ - Pkt: NewPacketBuffer(PacketBufferOptions{Data: vv}), - } - - select { - case e.C <- p: - default: - } - - return nil -} - // Wait implements stack.LinkEndpoint.Wait. func (*fwdTestLinkEndpoint) Wait() {} @@ -334,7 +349,10 @@ func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protoco func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborCache bool) (ep1, ep2 *fwdTestLinkEndpoint) { // Create a stack with the network protocol and two NICs. s := New(Options{ - NetworkProtocols: []NetworkProtocolFactory{func(*Stack) NetworkProtocol { return proto }}, + NetworkProtocols: []NetworkProtocolFactory{func(s *Stack) NetworkProtocol { + proto.stack = s + return proto + }}, UseNeighborCache: useNeighborCache, }) @@ -542,6 +560,38 @@ func TestForwardingWithNoResolver(t *testing.T) { } } +func TestForwardingResolutionFailsForQueuedPackets(t *testing.T) { + proto := &fwdTestNetworkProtocol{ + addrResolveDelay: 50 * time.Millisecond, + onLinkAddressResolved: func(*linkAddrCache, *neighborCache, tcpip.Address, tcpip.LinkAddress) { + // Don't resolve the link address. + }, + } + + ep1, ep2 := fwdTestNetFactory(t, proto, true /* useNeighborCache */) + + const numPackets int = 5 + // These packets will all be enqueued in the packet queue to wait for link + // address resolution. + for i := 0; i < numPackets; i++ { + buf := buffer.NewView(30) + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + } + + // All packets should fail resolution. + // TODO(gvisor.dev/issue/5141): Use a fake clock. + for i := 0; i < numPackets; i++ { + select { + case got := <-ep2.C: + t.Fatalf("got %#v; packets should have failed resolution and not been forwarded", got) + case <-time.After(100 * time.Millisecond): + } + } +} + func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { tests := []struct { name string diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 2d8c883cd..09c7811fa 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -45,13 +45,13 @@ const reaperDelay = 5 * time.Second func DefaultTables() *IPTables { return &IPTables{ v4Tables: [NumTables]Table{ - NATID: Table{ + NATID: { Rules: []Rule{ - Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, - Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, - Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, - Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, - Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + {Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + {Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + {Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + {Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + {Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, }, BuiltinChains: [NumHooks]int{ Prerouting: 0, @@ -68,11 +68,11 @@ func DefaultTables() *IPTables { Postrouting: 3, }, }, - MangleID: Table{ + MangleID: { Rules: []Rule{ - Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, - Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, - Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + {Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + {Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + {Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, }, BuiltinChains: [NumHooks]int{ Prerouting: 0, @@ -86,12 +86,12 @@ func DefaultTables() *IPTables { Postrouting: HookUnset, }, }, - FilterID: Table{ + FilterID: { Rules: []Rule{ - Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, - Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, - Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, - Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + {Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + {Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + {Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + {Target: &ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, }, BuiltinChains: [NumHooks]int{ Prerouting: HookUnset, @@ -110,13 +110,13 @@ func DefaultTables() *IPTables { }, }, v6Tables: [NumTables]Table{ - NATID: Table{ + NATID: { Rules: []Rule{ - Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, - Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, - Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, - Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, - Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + {Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + {Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + {Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + {Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + {Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, }, BuiltinChains: [NumHooks]int{ Prerouting: 0, @@ -133,11 +133,11 @@ func DefaultTables() *IPTables { Postrouting: 3, }, }, - MangleID: Table{ + MangleID: { Rules: []Rule{ - Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, - Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, - Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + {Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + {Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + {Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, }, BuiltinChains: [NumHooks]int{ Prerouting: 0, @@ -151,12 +151,12 @@ func DefaultTables() *IPTables { Postrouting: HookUnset, }, }, - FilterID: Table{ + FilterID: { Rules: []Rule{ - Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, - Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, - Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, - Rule{Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + {Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + {Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + {Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, + {Target: &ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, }, BuiltinChains: [NumHooks]int{ Prerouting: HookUnset, @@ -175,9 +175,9 @@ func DefaultTables() *IPTables { }, }, priorities: [NumHooks][]TableID{ - Prerouting: []TableID{MangleID, NATID}, - Input: []TableID{NATID, FilterID}, - Output: []TableID{MangleID, NATID, FilterID}, + Prerouting: {MangleID, NATID}, + Input: {NATID, FilterID}, + Output: {MangleID, NATID, FilterID}, }, connections: ConnTrack{ seed: generateRandUint32(), diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 4b86c1be9..56a3e7861 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -56,7 +56,7 @@ const ( // Postrouting happens just before a packet goes out on the wire. Postrouting - // The total number of hooks. + // NumHooks is the total number of hooks. NumHooks ) @@ -273,14 +273,12 @@ func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, nicName string) boo return true } - // If the interface name ends with '+', any interface which begins - // with the name should be matched. + // If the interface name ends with '+', any interface which + // begins with the name should be matched. ifName := fl.OutputInterface - matches := true + matches := nicName == ifName if strings.HasSuffix(ifName, "+") { matches = strings.HasPrefix(nicName, ifName[:n-1]) - } else { - matches = nicName == ifName } return fl.OutputInterfaceInvert != matches } diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index c9b13cd0e..792f4f170 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -18,7 +18,6 @@ import ( "fmt" "time" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -58,9 +57,6 @@ const ( incomplete entryState = iota // ready means that the address has been resolved and can be used. ready - // failed means that address resolution timed out and the address - // could not be resolved. - failed ) // String implements Stringer. @@ -70,8 +66,6 @@ func (s entryState) String() string { return "incomplete" case ready: return "ready" - case failed: - return "failed" default: return fmt.Sprintf("unknown(%d)", s) } @@ -80,40 +74,48 @@ func (s entryState) String() string { // A linkAddrEntry is an entry in the linkAddrCache. // This struct is thread-compatible. type linkAddrEntry struct { + // linkAddrEntryEntry access is synchronized by the linkAddrCache lock. linkAddrEntryEntry + // TODO(gvisor.dev/issue/5150): move these fields under mu. + // mu protects the fields below. + mu sync.RWMutex + addr tcpip.FullAddress linkAddr tcpip.LinkAddress expiration time.Time s entryState - // wakers is a set of waiters for address resolution result. Anytime - // state transitions out of incomplete these waiters are notified. - wakers map[*sleep.Waker]struct{} - - // done is used to allow callers to wait on address resolution. It is nil iff - // s is incomplete and resolution is not yet in progress. + // done is closed when address resolution is complete. It is nil iff s is + // incomplete and resolution is not yet in progress. done chan struct{} + + // onResolve is called with the result of address resolution. + onResolve []func(tcpip.LinkAddress, bool) } -// changeState sets the entry's state to ns, notifying any waiters. +func (e *linkAddrEntry) notifyCompletionLocked(linkAddr tcpip.LinkAddress) { + for _, callback := range e.onResolve { + callback(linkAddr, len(linkAddr) != 0) + } + e.onResolve = nil + if ch := e.done; ch != nil { + close(ch) + e.done = nil + } +} + +// changeStateLocked sets the entry's state to ns. // // The entry's expiration is bumped up to the greater of itself and the passed // expiration; the zero value indicates immediate expiration, and is set // unconditionally - this is an implementation detail that allows for entries // to be reused. -func (e *linkAddrEntry) changeState(ns entryState, expiration time.Time) { - // Notify whoever is waiting on address resolution when transitioning - // out of incomplete. - if e.s == incomplete && ns != incomplete { - for w := range e.wakers { - w.Assert() - } - e.wakers = nil - if ch := e.done; ch != nil { - close(ch) - } - e.done = nil +// +// Precondition: e.mu must be locked +func (e *linkAddrEntry) changeStateLocked(ns entryState, expiration time.Time) { + if e.s == incomplete && ns == ready { + e.notifyCompletionLocked(e.linkAddr) } if expiration.IsZero() || expiration.After(e.expiration) { @@ -122,10 +124,6 @@ func (e *linkAddrEntry) changeState(ns entryState, expiration time.Time) { e.s = ns } -func (e *linkAddrEntry) removeWaker(w *sleep.Waker) { - delete(e.wakers, w) -} - // add adds a k -> v mapping to the cache. func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) { // Calculate expiration time before acquiring the lock, since expiration is @@ -135,10 +133,12 @@ func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) { c.cache.Lock() entry := c.getOrCreateEntryLocked(k) - entry.linkAddr = v - - entry.changeState(ready, expiration) c.cache.Unlock() + + entry.mu.Lock() + defer entry.mu.Unlock() + entry.linkAddr = v + entry.changeStateLocked(ready, expiration) } // getOrCreateEntryLocked retrieves a cache entry associated with k. The @@ -159,13 +159,14 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEnt var entry *linkAddrEntry if len(c.cache.table) == linkAddrCacheSize { entry = c.cache.lru.Back() + entry.mu.Lock() delete(c.cache.table, entry.addr) c.cache.lru.Remove(entry) - // Wake waiters and mark the soon-to-be-reused entry as expired. Note - // that the state passed doesn't matter when the zero time is passed. - entry.changeState(failed, time.Time{}) + // Wake waiters and mark the soon-to-be-reused entry as expired. + entry.notifyCompletionLocked("" /* linkAddr */) + entry.mu.Unlock() } else { entry = new(linkAddrEntry) } @@ -180,9 +181,12 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEnt } // get reports any known link address for k. -func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { +func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { if linkRes != nil { if addr, ok := linkRes.ResolveStaticAddress(k.Addr); ok { + if onResolve != nil { + onResolve(addr, true) + } return addr, nil, nil } } @@ -190,56 +194,35 @@ func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, lo c.cache.Lock() defer c.cache.Unlock() entry := c.getOrCreateEntryLocked(k) + entry.mu.Lock() + defer entry.mu.Unlock() + switch s := entry.s; s { - case ready, failed: + case ready: if !time.Now().After(entry.expiration) { // Not expired. - switch s { - case ready: - return entry.linkAddr, nil, nil - case failed: - return entry.linkAddr, nil, tcpip.ErrNoLinkAddress - default: - panic(fmt.Sprintf("invalid cache entry state: %s", s)) + if onResolve != nil { + onResolve(entry.linkAddr, true) } + return entry.linkAddr, nil, nil } - entry.changeState(incomplete, time.Time{}) + entry.changeStateLocked(incomplete, time.Time{}) fallthrough case incomplete: - if waker != nil { - if entry.wakers == nil { - entry.wakers = make(map[*sleep.Waker]struct{}) - } - entry.wakers[waker] = struct{}{} + if onResolve != nil { + entry.onResolve = append(entry.onResolve, onResolve) } - if entry.done == nil { - // Address resolution needs to be initiated. - if linkRes == nil { - return entry.linkAddr, nil, tcpip.ErrNoLinkAddress - } - entry.done = make(chan struct{}) go c.startAddressResolution(k, linkRes, localAddr, nic, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. } - return entry.linkAddr, entry.done, tcpip.ErrWouldBlock default: panic(fmt.Sprintf("invalid cache entry state: %s", s)) } } -// removeWaker removes a waker previously added through get(). -func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) { - c.cache.Lock() - defer c.cache.Unlock() - - if entry, ok := c.cache.table[k]; ok { - entry.removeWaker(waker) - } -} - func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, done <-chan struct{}) { for i := 0; ; i++ { // Send link request, then wait for the timeout limit and check @@ -257,9 +240,9 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link } } -// checkLinkRequest checks whether previous attempt to resolve address has succeeded -// and mark the entry accordingly, e.g. ready, failed, etc. Return true if request -// can stop, false if another request should be sent. +// checkLinkRequest checks whether previous attempt to resolve address has +// succeeded and mark the entry accordingly. Returns true if request can stop, +// false if another request should be sent. func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, attempt int) bool { c.cache.Lock() defer c.cache.Unlock() @@ -268,16 +251,20 @@ func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, att // Entry was evicted from the cache. return true } + entry.mu.Lock() + defer entry.mu.Unlock() + switch s := entry.s; s { - case ready, failed: - // Entry was made ready by resolver or failed. Either way we're done. + case ready: + // Entry was made ready by resolver. case incomplete: if attempt+1 < c.resolutionAttempts { // No response yet, need to send another ARP request. return false } - // Max number of retries reached, mark entry as failed. - entry.changeState(failed, now.Add(c.ageLimit)) + // Max number of retries reached, delete entry. + entry.notifyCompletionLocked("" /* linkAddr */) + delete(c.cache.table, k) default: panic(fmt.Sprintf("invalid cache entry state: %s", s)) } diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index d2e37f38d..6883045b5 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -21,7 +21,6 @@ import ( "testing" "time" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -50,6 +49,7 @@ type testLinkAddressResolver struct { } func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { + // TODO(gvisor.dev/issue/5141): Use a fake clock. time.AfterFunc(r.delay, func() { r.fakeRequest(targetAddr) }) if f := r.onLinkAddressRequest; f != nil { f() @@ -78,16 +78,18 @@ func (*testLinkAddressResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumbe } func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressResolver) (tcpip.LinkAddress, *tcpip.Error) { - w := sleep.Waker{} - s := sleep.Sleeper{} - s.AddWaker(&w, 123) - defer s.Done() - + var attemptedResolution bool for { - if got, _, err := c.get(addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock { - return got, err + got, ch, err := c.get(addr, linkRes, "", nil, nil) + if err == tcpip.ErrWouldBlock { + if attemptedResolution { + return got, tcpip.ErrNoLinkAddress + } + attemptedResolution = true + <-ch + continue } - s.Fetch(true) + return got, err } } @@ -116,16 +118,19 @@ func TestCacheOverflow(t *testing.T) { } } // The earliest entries should no longer be in the cache. + c.cache.Lock() + defer c.cache.Unlock() for i := len(testAddrs) - 1; i >= len(testAddrs)-linkAddrCacheSize; i-- { e := testAddrs[i] - if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { - t.Errorf("check %d, c.get(%q), got error: %v, want: error ErrNoLinkAddress", i, string(e.addr.Addr), err) + if entry, ok := c.cache.table[e.addr]; ok { + t.Errorf("unexpected entry at c.cache.table[%q]: %#v", string(e.addr.Addr), entry) } } } func TestCacheConcurrent(t *testing.T) { c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) + linkRes := &testLinkAddressResolver{cache: c} var wg sync.WaitGroup for r := 0; r < 16; r++ { @@ -133,7 +138,6 @@ func TestCacheConcurrent(t *testing.T) { go func() { for _, e := range testAddrs { c.add(e.addr, e.linkAddr) - c.get(e.addr, nil, "", nil, nil) // make work for gotsan } wg.Done() }() @@ -144,7 +148,7 @@ func TestCacheConcurrent(t *testing.T) { // can fit in the cache, so our eviction strategy requires that // the last entry be present and the first be missing. e := testAddrs[len(testAddrs)-1] - got, _, err := c.get(e.addr, nil, "", nil, nil) + got, _, err := c.get(e.addr, linkRes, "", nil, nil) if err != nil { t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) } @@ -153,18 +157,22 @@ func TestCacheConcurrent(t *testing.T) { } e = testAddrs[0] - if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { - t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) + c.cache.Lock() + defer c.cache.Unlock() + if entry, ok := c.cache.table[e.addr]; ok { + t.Errorf("unexpected entry at c.cache.table[%q]: %#v", string(e.addr.Addr), entry) } } func TestCacheAgeLimit(t *testing.T) { c := newLinkAddrCache(1*time.Millisecond, 1*time.Second, 3) + linkRes := &testLinkAddressResolver{cache: c} + e := testAddrs[0] c.add(e.addr, e.linkAddr) time.Sleep(50 * time.Millisecond) - if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { - t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) + if _, _, err := c.get(e.addr, linkRes, "", nil, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got c.get(%q) = %s, want = ErrWouldBlock", string(e.addr.Addr), err) } } @@ -282,71 +290,3 @@ func TestStaticResolution(t *testing.T) { t.Errorf("c.get(%q)=%q, want %q", string(addr), string(got), string(want)) } } - -// TestCacheWaker verifies that RemoveWaker removes a waker previously added -// through get(). -func TestCacheWaker(t *testing.T) { - c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) - - // First, sanity check that wakers are working. - { - linkRes := &testLinkAddressResolver{cache: c} - s := sleep.Sleeper{} - defer s.Done() - - const wakerID = 1 - w := sleep.Waker{} - s.AddWaker(&w, wakerID) - - e := testAddrs[0] - - if _, _, err := c.get(e.addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.get(%q, _, _, _, _) = %s, want = %s", e.addr.Addr, err, tcpip.ErrWouldBlock) - } - id, ok := s.Fetch(true /* block */) - if !ok { - t.Fatal("got s.Fetch(true) = (_, false), want = (_, true)") - } - if id != wakerID { - t.Fatalf("got s.Fetch(true) = (%d, %t), want = (%d, true)", id, ok, wakerID) - } - - if got, _, err := c.get(e.addr, linkRes, "", nil, nil); err != nil { - t.Fatalf("c.get(%q, _, _, _, _): %s", e.addr.Addr, err) - } else if got != e.linkAddr { - t.Fatalf("got c.get(%q) = %q, want = %q", e.addr.Addr, got, e.linkAddr) - } - } - - // Check that RemoveWaker works. - { - linkRes := &testLinkAddressResolver{cache: c} - s := sleep.Sleeper{} - defer s.Done() - - const wakerID = 2 // different than the ID used in the sanity check - w := sleep.Waker{} - s.AddWaker(&w, wakerID) - - e := testAddrs[1] - linkRes.onLinkAddressRequest = func() { - // Remove the waker before the linkAddrCache has the opportunity to send - // a notification. - c.removeWaker(e.addr, &w) - } - - if _, _, err := c.get(e.addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.get(%q, _, _, _, _) = %s, want = %s", e.addr.Addr, err, tcpip.ErrWouldBlock) - } - - if got, err := getBlocking(c, e.addr, linkRes); err != nil { - t.Fatalf("c.get(%q, _, _, _, _): %s", e.addr.Addr, err) - } else if got != e.linkAddr { - t.Fatalf("c.get(%q) = %q, want = %q", e.addr.Addr, got, e.linkAddr) - } - - if id, ok := s.Fetch(false /* block */); ok { - t.Fatalf("unexpected notification from waker with id %d", id) - } - } -} diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 73a01c2dd..61636cae5 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -26,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" @@ -352,7 +353,7 @@ func TestDADDisabled(t *testing.T) { } // We should not have sent any NDP NS messages. - if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != 0 { + if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got != 0 { t.Fatalf("got NeighborSolicit = %d, want = 0", got) } } @@ -465,14 +466,18 @@ func TestDADResolve(t *testing.T) { if err != tcpip.ErrNoRoute { t.Errorf("got FindRoute(%d, '', %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute) } - r.Release() + if r != nil { + r.Release() + } } { r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false) if err != tcpip.ErrNoRoute { t.Errorf("got FindRoute(%d, %s, %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr1, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute) } - r.Release() + if r != nil { + r.Release() + } } if t.Failed() { @@ -510,7 +515,9 @@ func TestDADResolve(t *testing.T) { } else if r.LocalAddress != addr1 { t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, addr1) } - r.Release() + if r != nil { + r.Release() + } } if t.Failed() { @@ -518,7 +525,7 @@ func TestDADResolve(t *testing.T) { } // Should not have sent any more NS messages. - if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != uint64(test.dupAddrDetectTransmits) { + if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got != uint64(test.dupAddrDetectTransmits) { t.Fatalf("got NeighborSolicit = %d, want = %d", got, test.dupAddrDetectTransmits) } @@ -563,18 +570,18 @@ func rxNDPSolicit(e *channel.Endpoint, tgt tcpip.Address) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborSolicitMinimumSize) pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize)) pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + ns := header.NDPNeighborSolicit(pkt.MessageBody()) ns.SetTargetAddress(tgt) snmc := header.SolicitedNodeAddr(tgt) pkt.SetChecksum(header.ICMPv6Checksum(pkt, header.IPv6Any, snmc, buffer.VectorisedView{})) payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: 255, - SrcAddr: header.IPv6Any, - DstAddr: snmc, + PayloadLength: uint16(payloadLength), + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: 255, + SrcAddr: header.IPv6Any, + DstAddr: snmc, }) e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()})) } @@ -605,7 +612,7 @@ func TestDADFail(t *testing.T) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize) pkt := header.ICMPv6(hdr.Prepend(naSize)) pkt.SetType(header.ICMPv6NeighborAdvert) - na := header.NDPNeighborAdvert(pkt.NDPPayload()) + na := header.NDPNeighborAdvert(pkt.MessageBody()) na.SetSolicitedFlag(true) na.SetOverrideFlag(true) na.SetTargetAddress(tgt) @@ -616,11 +623,11 @@ func TestDADFail(t *testing.T) { payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: 255, - SrcAddr: tgt, - DstAddr: header.IPv6AllNodesMulticastAddress, + PayloadLength: uint16(payloadLength), + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: 255, + SrcAddr: tgt, + DstAddr: header.IPv6AllNodesMulticastAddress, }) e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()})) }, @@ -666,7 +673,7 @@ func TestDADFail(t *testing.T) { // Receive a packet to simulate an address conflict. test.rxPkt(e, addr1) - stat := test.getStat(s.Stats().ICMP.V6PacketsReceived) + stat := test.getStat(s.Stats().ICMP.V6.PacketsReceived) if got := stat.Value(); got != 1 { t.Fatalf("got stat = %d, want = 1", got) } @@ -803,7 +810,7 @@ func TestDADStop(t *testing.T) { } // Should not have sent more than 1 NS message. - if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got > 1 { + if got := s.Stats().ICMP.V6.PacketsSent.NeighborSolicit.Value(); got > 1 { t.Errorf("got NeighborSolicit = %d, want <= 1", got) } }) @@ -982,7 +989,7 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo pkt := header.ICMPv6(hdr.Prepend(icmpSize)) pkt.SetType(header.ICMPv6RouterAdvert) pkt.SetCode(0) - raPayload := pkt.NDPPayload() + raPayload := pkt.MessageBody() ra := header.NDPRouterAdvert(raPayload) // Populate the Router Lifetime. binary.BigEndian.PutUint16(raPayload[2:], rl) @@ -1004,11 +1011,11 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo payloadLength := hdr.UsedLength() iph := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) iph.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: header.NDPHopLimit, - SrcAddr: ip, - DstAddr: header.IPv6AllNodesMulticastAddress, + PayloadLength: uint16(payloadLength), + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: header.NDPHopLimit, + SrcAddr: ip, + DstAddr: header.IPv6AllNodesMulticastAddress, }) return stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -2162,8 +2169,8 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) { NDPConfigs: ipv6.NDPConfigurations{ AutoGenTempGlobalAddresses: true, }, - NDPDisp: &ndpDisp, - AutoGenIPv6LinkLocal: true, + NDPDisp: &ndpDisp, + AutoGenLinkLocal: true, })}, }) @@ -2843,9 +2850,7 @@ func addrForNewConnectionTo(t *testing.T, s *stack.Stack, addr tcpip.FullAddress t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) } defer ep.Close() - if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { - t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err) - } + ep.SocketOptions().SetV6Only(true) if err := ep.Connect(addr); err != nil { t.Fatalf("ep.Connect(%+v): %s", addr, err) } @@ -2879,9 +2884,7 @@ func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullA t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) } defer ep.Close() - if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { - t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err) - } + ep.SocketOptions().SetV6Only(true) if err := ep.Bind(addr); err != nil { t.Fatalf("ep.Bind(%+v): %s", addr, err) } @@ -3250,9 +3253,7 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) } defer ep.Close() - if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { - t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err) - } + ep.SocketOptions().SetV6Only(true) if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute { t.Errorf("got ep.Connect(%+v) = %s, want = %s", dstAddr, err, tcpip.ErrNoRoute) @@ -4044,9 +4045,9 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { ndpConfigs.AutoGenAddressConflictRetries = maxRetries s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal, - NDPConfigs: ndpConfigs, - NDPDisp: &ndpDisp, + AutoGenLinkLocal: addrType.autoGenLinkLocal, + NDPConfigs: ndpConfigs, + NDPDisp: &ndpDisp, OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ NICNameFromID: func(_ tcpip.NICID, nicName string) string { return nicName @@ -4179,9 +4180,9 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { e := channel.New(0, 1280, linkAddr1) s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal, - NDPConfigs: addrType.ndpConfigs, - NDPDisp: &ndpDisp, + AutoGenLinkLocal: addrType.autoGenLinkLocal, + NDPConfigs: addrType.ndpConfigs, + NDPDisp: &ndpDisp, })}, }) if err := s.CreateNIC(nicID, e); err != nil { @@ -4708,7 +4709,7 @@ func TestCleanupNDPState(t *testing.T) { } s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - AutoGenIPv6LinkLocal: true, + AutoGenLinkLocal: true, NDPConfigs: ipv6.NDPConfigurations{ HandleRAs: true, DiscoverDefaultRouters: true, @@ -5174,113 +5175,99 @@ func TestRouterSolicitation(t *testing.T) { }, } - // This Run will not return until the parallel tests finish. - // - // We need this because we need to do some teardown work after the - // parallel tests complete. - // - // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for - // more details. - t.Run("group", func(t *testing.T) { - for _, test := range tests { - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + e := channelLinkWithHeaderLength{ + Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr), + headerLength: test.linkHeaderLen, + } + e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired + waitForPkt := func(timeout time.Duration) { + t.Helper() - e := channelLinkWithHeaderLength{ - Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr), - headerLength: test.linkHeaderLen, + clock.Advance(timeout) + p, ok := e.Read() + if !ok { + t.Fatal("expected router solicitation packet") } - e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired - waitForPkt := func(timeout time.Duration) { - t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - p, ok := e.ReadContext(ctx) - if !ok { - t.Fatal("timed out waiting for packet") - return - } - if p.Proto != header.IPv6ProtocolNumber { - t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) - } + if p.Proto != header.IPv6ProtocolNumber { + t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber) + } - // Make sure the right remote link address is used. - if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); p.Route.RemoteLinkAddress != want { - t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) - } + // Make sure the right remote link address is used. + if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); p.Route.RemoteLinkAddress != want { + t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) + } - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(test.expectedSrcAddr), - checker.DstAddr(header.IPv6AllRoutersMulticastAddress), - checker.TTL(header.NDPHopLimit), - checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)), - ) + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(test.expectedSrcAddr), + checker.DstAddr(header.IPv6AllRoutersMulticastAddress), + checker.TTL(header.NDPHopLimit), + checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)), + ) - if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { - t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want) - } - } - waitForNothing := func(timeout time.Duration) { - t.Helper() - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - if _, ok := e.ReadContext(ctx); ok { - t.Fatal("unexpectedly got a packet") - } - } - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ipv6.NDPConfigurations{ - MaxRtrSolicitations: test.maxRtrSolicit, - RtrSolicitationInterval: test.rtrSolicitInt, - MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, - }, - })}, - }) - if err := s.CreateNIC(nicID, &e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want { + t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want) } + } + waitForNothing := func(timeout time.Duration) { + t.Helper() - if addr := test.nicAddr; addr != "" { - if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err) - } + clock.Advance(timeout) + if p, ok := e.Read(); ok { + t.Fatalf("unexpectedly got a packet = %#v", p) } + } + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + MaxRtrSolicitations: test.maxRtrSolicit, + RtrSolicitationInterval: test.rtrSolicitInt, + MaxRtrSolicitationDelay: test.maxRtrSolicitDelay, + }, + })}, + Clock: clock, + }) + if err := s.CreateNIC(nicID, &e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } - // Make sure each RS is sent at the right time. - remaining := test.maxRtrSolicit - if remaining > 0 { - waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultAsyncPositiveEventTimeout) - remaining-- + if addr := test.nicAddr; addr != "" { + if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err) } + } - for ; remaining > 0; remaining-- { - if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout { - waitForNothing(test.effectiveRtrSolicitInt - defaultAsyncNegativeEventTimeout) - waitForPkt(defaultAsyncPositiveEventTimeout) - } else { - waitForPkt(test.effectiveRtrSolicitInt + defaultAsyncPositiveEventTimeout) - } - } + // Make sure each RS is sent at the right time. + remaining := test.maxRtrSolicit + if remaining > 0 { + waitForPkt(test.effectiveMaxRtrSolicitDelay) + remaining-- + } - // Make sure no more RS. - if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay { - waitForNothing(test.effectiveRtrSolicitInt + defaultAsyncNegativeEventTimeout) + for ; remaining > 0; remaining-- { + if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout { + waitForNothing(test.effectiveRtrSolicitInt - time.Nanosecond) + waitForPkt(time.Nanosecond) } else { - waitForNothing(test.effectiveMaxRtrSolicitDelay + defaultAsyncNegativeEventTimeout) + waitForPkt(test.effectiveRtrSolicitInt) } + } - // Make sure the counter got properly - // incremented. - if got, want := s.Stats().ICMP.V6PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want { - t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want) - } - }) - } - }) + // Make sure no more RS. + if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay { + waitForNothing(test.effectiveRtrSolicitInt) + } else { + waitForNothing(test.effectiveMaxRtrSolicitDelay) + } + + if got, want := s.Stats().ICMP.V6.PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want { + t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want) + } + }) + } } func TestStopStartSolicitingRouters(t *testing.T) { diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go index 177bf5516..c15f10e76 100644 --- a/pkg/tcpip/stack/neighbor_cache.go +++ b/pkg/tcpip/stack/neighbor_cache.go @@ -17,16 +17,22 @@ package stack import ( "fmt" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" ) const neighborCacheSize = 512 // max entries per interface +// NeighborStats holds metrics for the neighbor table. +type NeighborStats struct { + // FailedEntryLookups counts the number of lookups performed on an entry in + // Failed state. + FailedEntryLookups *tcpip.StatCounter +} + // neighborCache maps IP addresses to link addresses. It uses the Least // Recently Used (LRU) eviction strategy to implement a bounded cache for -// dynmically acquired entries. It contains the state machine and configuration +// dynamically acquired entries. It contains the state machine and configuration // for running Neighbor Unreachability Detection (NUD). // // There are two types of entries in the neighbor cache: @@ -92,9 +98,7 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkA n.dynamic.lru.Remove(e) n.dynamic.count-- - e.dispatchRemoveEventLocked() - e.setStateLocked(Unknown) - e.notifyWakersLocked() + e.removeLocked() e.mu.Unlock() } n.cache[remoteAddr] = entry @@ -103,21 +107,27 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkA return entry } -// entry looks up the neighbor cache for translating address to link address -// (e.g. IP -> MAC). If the LinkEndpoint requests address resolution and there -// is a LinkAddressResolver registered with the network protocol, the cache -// attempts to resolve the address and returns ErrWouldBlock. If a Waker is -// provided, it will be notified when address resolution is complete (success -// or not). +// entry looks up neighbor information matching the remote address, and returns +// it if readily available. +// +// Returns ErrWouldBlock if the link address is not readily available, along +// with a notification channel for the caller to block on. Triggers address +// resolution asynchronously. +// +// If onResolve is provided, it will be called either immediately, if resolution +// is not required, or when address resolution is complete, with the resolved +// link address and whether resolution succeeded. After any callbacks have been +// called, the returned notification channel is closed. +// +// NB: if a callback is provided, it should not call into the neighbor cache. // // If specified, the local address must be an address local to the interface the // neighbor cache belongs to. The local address is the source address of a // packet prompting NUD/link address resolution. // -// If address resolution is required, ErrNoLinkAddress and a notification -// channel is returned for the top level caller to block. Channel is closed -// once address resolution is complete (success or not). -func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, w *sleep.Waker) (NeighborEntry, <-chan struct{}, *tcpip.Error) { +// TODO(gvisor.dev/issue/5151): Don't return the neighbor entry. +func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(tcpip.LinkAddress, bool)) (NeighborEntry, <-chan struct{}, *tcpip.Error) { + // TODO(gvisor.dev/issue/5149): Handle static resolution in route.Resolve. if linkAddr, ok := linkRes.ResolveStaticAddress(remoteAddr); ok { e := NeighborEntry{ Addr: remoteAddr, @@ -125,6 +135,9 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA State: Static, UpdatedAtNanos: 0, } + if onResolve != nil { + onResolve(linkAddr, true) + } return e, nil, nil } @@ -142,47 +155,36 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA // of packets to a neighbor. While reasserting a neighbor's reachability, // a node continues sending packets to that neighbor using the cached // link-layer address." + if onResolve != nil { + onResolve(entry.neigh.LinkAddr, true) + } return entry.neigh, nil, nil - case Unknown, Incomplete: - entry.addWakerLocked(w) - + case Unknown, Incomplete, Failed: + if onResolve != nil { + entry.onResolve = append(entry.onResolve, onResolve) + } if entry.done == nil { // Address resolution needs to be initiated. - if linkRes == nil { - return entry.neigh, nil, tcpip.ErrNoLinkAddress - } entry.done = make(chan struct{}) } - entry.handlePacketQueuedLocked(localAddr) return entry.neigh, entry.done, tcpip.ErrWouldBlock - case Failed: - return entry.neigh, nil, tcpip.ErrNoLinkAddress default: panic(fmt.Sprintf("Invalid cache entry state: %s", s)) } } -// removeWaker removes a waker that has been added when link resolution for -// addr was requested. -func (n *neighborCache) removeWaker(addr tcpip.Address, waker *sleep.Waker) { - n.mu.Lock() - if entry, ok := n.cache[addr]; ok { - delete(entry.wakers, waker) - } - n.mu.Unlock() -} - // entries returns all entries in the neighbor cache. func (n *neighborCache) entries() []NeighborEntry { - entries := make([]NeighborEntry, 0, len(n.cache)) n.mu.RLock() + defer n.mu.RUnlock() + + entries := make([]NeighborEntry, 0, len(n.cache)) for _, entry := range n.cache { entry.mu.RLock() entries = append(entries, entry.neigh) entry.mu.RUnlock() } - n.mu.RUnlock() return entries } @@ -214,32 +216,13 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd return } - // Notify that resolution has been interrupted, just in case the entry was - // in the Incomplete or Probe state. - entry.dispatchRemoveEventLocked() - entry.setStateLocked(Unknown) - entry.notifyWakersLocked() + entry.removeLocked() entry.mu.Unlock() } n.cache[addr] = newStaticNeighborEntry(n.nic, addr, linkAddr, n.state) } -// removeEntryLocked removes the specified entry from the neighbor cache. -func (n *neighborCache) removeEntryLocked(entry *neighborEntry) { - if entry.neigh.State != Static { - n.dynamic.lru.Remove(entry) - n.dynamic.count-- - } - if entry.neigh.State != Failed { - entry.dispatchRemoveEventLocked() - } - entry.setStateLocked(Unknown) - entry.notifyWakersLocked() - - delete(n.cache, entry.neigh.Addr) -} - // removeEntry removes a dynamic or static entry by address from the neighbor // cache. Returns true if the entry was found and deleted. func (n *neighborCache) removeEntry(addr tcpip.Address) bool { @@ -254,7 +237,13 @@ func (n *neighborCache) removeEntry(addr tcpip.Address) bool { entry.mu.Lock() defer entry.mu.Unlock() - n.removeEntryLocked(entry) + if entry.neigh.State != Static { + n.dynamic.lru.Remove(entry) + n.dynamic.count-- + } + + entry.removeLocked() + delete(n.cache, entry.neigh.Addr) return true } @@ -265,9 +254,7 @@ func (n *neighborCache) clear() { for _, entry := range n.cache { entry.mu.Lock() - entry.dispatchRemoveEventLocked() - entry.setStateLocked(Unknown) - entry.notifyWakersLocked() + entry.removeLocked() entry.mu.Unlock() } diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index ed33418f3..a2ed6ae2a 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -28,7 +28,6 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/faketime" ) @@ -80,17 +79,20 @@ func entryDiffOptsWithSort() []cmp.Option { func newTestNeighborCache(nudDisp NUDDispatcher, config NUDConfigurations, clock tcpip.Clock) *neighborCache { config.resetInvalidFields() rng := rand.New(rand.NewSource(time.Now().UnixNano())) - return &neighborCache{ + neigh := &neighborCache{ nic: &NIC{ stack: &Stack{ clock: clock, nudDisp: nudDisp, }, - id: 1, + id: 1, + stats: makeNICStats(), }, state: NewNUDState(config, rng), cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), } + neigh.nic.neigh = neigh + return neigh } // testEntryStore contains a set of IP to NeighborEntry mappings. @@ -187,15 +189,18 @@ type testNeighborResolver struct { entries *testEntryStore delay time.Duration onLinkAddressRequest func() + dropReplies bool } var _ LinkAddressResolver = (*testNeighborResolver)(nil) func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { - // Delay handling the request to emulate network latency. - r.clock.AfterFunc(r.delay, func() { - r.fakeRequest(targetAddr) - }) + if !r.dropReplies { + // Delay handling the request to emulate network latency. + r.clock.AfterFunc(r.delay, func() { + r.fakeRequest(targetAddr) + }) + } // Execute post address resolution action, if available. if f := r.onLinkAddressRequest; f != nil { @@ -288,10 +293,10 @@ func TestNeighborCacheEntry(t *testing.T) { entry, ok := store.entry(0) if !ok { - t.Fatalf("store.entry(0) not found") + t.Fatal("store.entry(0) not found") } if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) @@ -324,7 +329,7 @@ func TestNeighborCacheEntry(t *testing.T) { } if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil { - t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err) + t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err) } // No more events should have been dispatched. @@ -351,11 +356,11 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { entry, ok := store.entry(0) if !ok { - t.Fatalf("store.entry(0) not found") + t.Fatal("store.entry(0) not found") } if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) @@ -410,7 +415,7 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { } if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } } @@ -458,7 +463,7 @@ func (c *testContext) overflowCache(opts overflowOptions) error { return fmt.Errorf("c.store.entry(%d) not found", i) } if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { - return fmt.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + return fmt.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } c.clock.Advance(c.neigh.config().RetransmitTimer) @@ -510,7 +515,7 @@ func (c *testContext) overflowCache(opts overflowOptions) error { } // Expect to find only the most recent entries. The order of entries reported - // by entries() is undeterministic, so entries have to be sorted before + // by entries() is nondeterministic, so entries have to be sorted before // comparison. wantUnsortedEntries := opts.wantStaticEntries for i := c.store.size() - neighborCacheSize; i < c.store.size(); i++ { @@ -572,10 +577,10 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { // Add a dynamic entry entry, ok := c.store.entry(0) if !ok { - t.Fatalf("c.store.entry(0) not found") + t.Fatal("c.store.entry(0) not found") } if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } c.clock.Advance(c.neigh.config().RetransmitTimer) wantEvents := []testEntryEventInfo{ @@ -647,7 +652,7 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) { // Add a static entry entry, ok := c.store.entry(0) if !ok { - t.Fatalf("c.store.entry(0) not found") + t.Fatal("c.store.entry(0) not found") } staticLinkAddr := entry.LinkAddr + "static" c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) @@ -691,7 +696,7 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) // Add a static entry entry, ok := c.store.entry(0) if !ok { - t.Fatalf("c.store.entry(0) not found") + t.Fatal("c.store.entry(0) not found") } staticLinkAddr := entry.LinkAddr + "static" c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) @@ -753,7 +758,7 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { // Add a static entry entry, ok := c.store.entry(0) if !ok { - t.Fatalf("c.store.entry(0) not found") + t.Fatal("c.store.entry(0) not found") } staticLinkAddr := entry.LinkAddr + "static" c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) @@ -823,10 +828,10 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { // Add a dynamic entry entry, ok := c.store.entry(0) if !ok { - t.Fatalf("c.store.entry(0) not found") + t.Fatal("c.store.entry(0) not found") } if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ @@ -904,150 +909,6 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { } } -func TestNeighborCacheNotifiesWaker(t *testing.T) { - config := DefaultNUDConfigurations() - - nudDisp := testNUDDispatcher{} - clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } - - w := sleep.Waker{} - s := sleep.Sleeper{} - const wakerID = 1 - s.AddWaker(&w, wakerID) - - entry, ok := store.entry(0) - if !ok { - t.Fatalf("store.entry(0) not found") - } - _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, &w) - if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, _ = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) - } - if doneCh == nil { - t.Fatalf("expected done channel from neigh.entry(%s, '', _, _)", entry.Addr) - } - clock.Advance(typicalLatency) - - select { - case <-doneCh: - default: - t.Fatal("expected notification from done channel") - } - - id, ok := s.Fetch(false /* block */) - if !ok { - t.Errorf("expected waker to be notified after neigh.entry(%s, '', _, _)", entry.Addr) - } - if id != wakerID { - t.Errorf("got s.Fetch(false) = %d, want = %d", id, wakerID) - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - }, - }, - } - nudDisp.mu.Lock() - defer nudDisp.mu.Unlock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) - } -} - -func TestNeighborCacheRemoveWaker(t *testing.T) { - config := DefaultNUDConfigurations() - - nudDisp := testNUDDispatcher{} - clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } - - w := sleep.Waker{} - s := sleep.Sleeper{} - const wakerID = 1 - s.AddWaker(&w, wakerID) - - entry, ok := store.entry(0) - if !ok { - t.Fatalf("store.entry(0) not found") - } - _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, &w) - if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, _) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) - } - if doneCh == nil { - t.Fatalf("expected done channel from neigh.entry(%s, '', _, _)", entry.Addr) - } - - // Remove the waker before the neighbor cache has the opportunity to send a - // notification. - neigh.removeWaker(entry.Addr, &w) - clock.Advance(typicalLatency) - - select { - case <-doneCh: - default: - t.Fatal("expected notification from done channel") - } - - if id, ok := s.Fetch(false /* block */); ok { - t.Errorf("unexpected notification from waker with id %d", id) - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - }, - }, - } - nudDisp.mu.Lock() - defer nudDisp.mu.Unlock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) - } -} - func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { config := DefaultNUDConfigurations() // Stay in Reachable so the cache can overflow @@ -1059,12 +920,12 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { entry, ok := c.store.entry(0) if !ok { - t.Fatalf("c.store.entry(0) not found") + t.Fatal("c.store.entry(0) not found") } c.neigh.addStaticEntry(entry.Addr, entry.LinkAddr) e, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil) if err != nil { - t.Errorf("unexpected error from c.neigh.entry(%s, \"\", _, nil): %s", entry.Addr, err) + t.Errorf("unexpected error from c.neigh.entry(%s, \"\", _, nil, nil): %s", entry.Addr, err) } want := NeighborEntry{ Addr: entry.Addr, @@ -1072,7 +933,7 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { State: Static, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("c.neigh.entry(%s, \"\", _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + t.Errorf("c.neigh.entry(%s, \"\", _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } wantEvents := []testEntryEventInfo{ @@ -1126,10 +987,10 @@ func TestNeighborCacheClear(t *testing.T) { // Add a dynamic entry. entry, ok := store.entry(0) if !ok { - t.Fatalf("store.entry(0) not found") + t.Fatal("store.entry(0) not found") } if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) @@ -1184,7 +1045,7 @@ func TestNeighborCacheClear(t *testing.T) { } } - // Clear shoud remove both dynamic and static entries. + // Clear should remove both dynamic and static entries. neigh.clear() // Remove events dispatched from clear() have no deterministic order so they @@ -1231,10 +1092,10 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { // Add a dynamic entry entry, ok := c.store.entry(0) if !ok { - t.Fatalf("c.store.entry(0) not found") + t.Fatal("c.store.entry(0) not found") } if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ @@ -1315,7 +1176,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { frequentlyUsedEntry, ok := store.entry(0) if !ok { - t.Fatalf("store.entry(0) not found") + t.Fatal("store.entry(0) not found") } // The following logic is very similar to overflowCache, but @@ -1327,15 +1188,22 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { if !ok { t.Fatalf("store.entry(%d) not found", i) } - _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if !ok { + t.Fatal("expected successful address resolution") + } + if linkAddr != entry.LinkAddr { + t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + } + }) if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) select { - case <-doneCh: + case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) } wantEvents := []testEntryEventInfo{ { @@ -1370,7 +1238,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { // Periodically refresh the frequently used entry if i%(neighborCacheSize/2) == 0 { if _, _, err := neigh.entry(frequentlyUsedEntry.Addr, "", linkRes, nil); err != nil { - t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", frequentlyUsedEntry.Addr, err) + t.Errorf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", frequentlyUsedEntry.Addr, err) } } @@ -1378,15 +1246,23 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { if !ok { t.Fatalf("store.entry(%d) not found", i) } - _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) + + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if !ok { + t.Fatal("expected successful address resolution") + } + if linkAddr != entry.LinkAddr { + t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + } + }) if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) select { - case <-doneCh: + case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) } // An entry should have been removed, as per the LRU eviction strategy @@ -1432,7 +1308,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { } // Expect to find only the frequently used entry and the most recent entries. - // The order of entries reported by entries() is undeterministic, so entries + // The order of entries reported by entries() is nondeterministic, so entries // have to be sorted before comparison. wantUnsortedEntries := []NeighborEntry{ { @@ -1491,12 +1367,12 @@ func TestNeighborCacheConcurrent(t *testing.T) { go func(entry NeighborEntry) { defer wg.Done() if e, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil && err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, nil, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, tcpip.ErrWouldBlock) } }(entry) } - // Wait for all gorountines to send a request + // Wait for all goroutines to send a request wg.Wait() // Process all the requests for a single entry concurrently @@ -1506,7 +1382,7 @@ func TestNeighborCacheConcurrent(t *testing.T) { // All goroutines add in the same order and add more values than can fit in // the cache. Our eviction strategy requires that the last entries are // present, up to the size of the neighbor cache, and the rest are missing. - // The order of entries reported by entries() is undeterministic, so entries + // The order of entries reported by entries() is nondeterministic, so entries // have to be sorted before comparison. var wantUnsortedEntries []NeighborEntry for i := store.size() - neighborCacheSize; i < store.size(); i++ { @@ -1544,27 +1420,32 @@ func TestNeighborCacheReplace(t *testing.T) { // Add an entry entry, ok := store.entry(0) if !ok { - t.Fatalf("store.entry(0) not found") + t.Fatal("store.entry(0) not found") } - _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) + + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if !ok { + t.Fatal("expected successful address resolution") + } + if linkAddr != entry.LinkAddr { + t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + } + }) if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) select { - case <-doneCh: + case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) } // Verify the entry exists { - e, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) + e, _, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != nil { - t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err) - } - if doneCh != nil { - t.Errorf("unexpected done channel from neigh.entry(%s, '', _, nil): %v", entry.Addr, doneCh) + t.Errorf("unexpected error from neigh.entry(%s, '', _, _, nil): %s", entry.Addr, err) } if t.Failed() { t.FailNow() @@ -1575,7 +1456,7 @@ func TestNeighborCacheReplace(t *testing.T) { State: Reachable, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + t.Errorf("neigh.entry(%s, '', _, _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } } @@ -1584,7 +1465,7 @@ func TestNeighborCacheReplace(t *testing.T) { { entry, ok := store.entry(1) if !ok { - t.Fatalf("store.entry(1) not found") + t.Fatal("store.entry(1) not found") } updatedLinkAddr = entry.LinkAddr } @@ -1601,7 +1482,7 @@ func TestNeighborCacheReplace(t *testing.T) { { e, _, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != nil { - t.Fatalf("neigh.entry(%s, '', _, nil): %s", entry.Addr, err) + t.Fatalf("neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err) } want := NeighborEntry{ Addr: entry.Addr, @@ -1609,7 +1490,7 @@ func TestNeighborCacheReplace(t *testing.T) { State: Delay, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } clock.Advance(config.DelayFirstProbeTime + typicalLatency) } @@ -1619,7 +1500,7 @@ func TestNeighborCacheReplace(t *testing.T) { e, _, err := neigh.entry(entry.Addr, "", linkRes, nil) clock.Advance(typicalLatency) if err != nil { - t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err) + t.Errorf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err) } want := NeighborEntry{ Addr: entry.Addr, @@ -1627,7 +1508,7 @@ func TestNeighborCacheReplace(t *testing.T) { State: Reachable, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } } } @@ -1651,18 +1532,35 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { }, } - // First, sanity check that resolution is working entry, ok := store.entry(0) if !ok { - t.Fatalf("store.entry(0) not found") + t.Fatal("store.entry(0) not found") } - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + + // First, sanity check that resolution is working + { + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if !ok { + t.Fatal("expected successful address resolution") + } + if linkAddr != entry.LinkAddr { + t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + } + }) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + } + clock.Advance(typicalLatency) + select { + case <-ch: + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + } } - clock.Advance(typicalLatency) + got, _, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != nil { - t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err) + t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err) } want := NeighborEntry{ Addr: entry.Addr, @@ -1670,20 +1568,35 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { State: Reachable, } if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } - // Verify that address resolution for an unknown address returns ErrNoLinkAddress + // Verify address resolution fails for an unknown address. before := atomic.LoadUint32(&requestCount) entry.Addr += "2" - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) - } - waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes) - clock.Advance(waitFor) - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrNoLinkAddress { - t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrNoLinkAddress) + { + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if ok { + t.Error("expected unsuccessful address resolution") + } + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr) + } + if t.Failed() { + t.FailNow() + } + }) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + } + waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes) + clock.Advance(waitFor) + select { + case <-ch: + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + } } maxAttempts := neigh.config().MaxUnicastProbes @@ -1711,15 +1624,129 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { entry, ok := store.entry(0) if !ok { - t.Fatalf("store.entry(0) not found") + t.Fatal("store.entry(0) not found") } - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if ok { + t.Error("expected unsuccessful address resolution") + } + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr) + } + if t.Failed() { + t.FailNow() + } + }) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrNoLinkAddress { - t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrNoLinkAddress) + + select { + case <-ch: + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + } +} + +// TestNeighborCacheRetryResolution simulates retrying communication after +// failing to perform address resolution. +func TestNeighborCacheRetryResolution(t *testing.T) { + config := DefaultNUDConfigurations() + clock := faketime.NewManualClock() + neigh := newTestNeighborCache(nil, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + // Simulate a faulty link. + dropReplies: true, + } + + entry, ok := store.entry(0) + if !ok { + t.Fatal("store.entry(0) not found") + } + + // Perform address resolution with a faulty link, which will fail. + { + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if ok { + t.Error("expected unsuccessful address resolution") + } + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr) + } + if t.Failed() { + t.FailNow() + } + }) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + } + waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) + clock.Advance(waitFor) + + select { + case <-ch: + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + } + } + + // Verify the entry is in Failed state. + wantEntries := []NeighborEntry{ + { + Addr: entry.Addr, + LinkAddr: "", + State: Failed, + }, + } + if diff := cmp.Diff(neigh.entries(), wantEntries, entryDiffOptsWithSort()...); diff != "" { + t.Fatalf("neighbor entries mismatch (-got, +want):\n%s", diff) + } + + // Retry address resolution with a working link. + linkRes.dropReplies = false + { + incompleteEntry, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if linkAddr != entry.LinkAddr { + t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + } + }) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + } + if incompleteEntry.State != Incomplete { + t.Fatalf("got entry.State = %s, want = %s", incompleteEntry.State, Incomplete) + } + clock.Advance(typicalLatency) + + select { + case <-ch: + if !ok { + t.Fatal("expected successful address resolution") + } + reachableEntry, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + if err != nil { + t.Fatalf("neigh.entry(%s, '', _, _, nil): %v", entry.Addr, err) + } + if reachableEntry.Addr != entry.Addr { + t.Fatalf("got entry.Addr = %s, want = %s", reachableEntry.Addr, entry.Addr) + } + if reachableEntry.LinkAddr != entry.LinkAddr { + t.Fatalf("got entry.LinkAddr = %s, want = %s", reachableEntry.LinkAddr, entry.LinkAddr) + } + if reachableEntry.State != Reachable { + t.Fatalf("got entry.State = %s, want = %s", reachableEntry.State.String(), Reachable.String()) + } + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + } } } @@ -1739,7 +1766,7 @@ func TestNeighborCacheStaticResolution(t *testing.T) { got, _, err := neigh.entry(testEntryBroadcastAddr, "", linkRes, nil) if err != nil { - t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", testEntryBroadcastAddr, err) + t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", testEntryBroadcastAddr, err) } want := NeighborEntry{ Addr: testEntryBroadcastAddr, @@ -1747,7 +1774,7 @@ func TestNeighborCacheStaticResolution(t *testing.T) { State: Static, } if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, diff) + t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, diff) } } @@ -1772,12 +1799,23 @@ func BenchmarkCacheClear(b *testing.B) { if !ok { b.Fatalf("store.entry(%d) not found", i) } - _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) + + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if !ok { + b.Fatal("expected successful address resolution") + } + if linkAddr != entry.LinkAddr { + b.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + } + }) if err != tcpip.ErrWouldBlock { - b.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + b.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } - if doneCh != nil { - <-doneCh + + select { + case <-ch: + default: + b.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) } } diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index 493e48031..75afb3001 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -19,7 +19,6 @@ import ( "sync" "time" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -67,8 +66,7 @@ const ( // Static describes entries that have been explicitly added by the user. They // do not expire and are not deleted until explicitly removed. Static - // Failed means traffic should not be sent to this neighbor since attempts of - // reachability have returned inconclusive. + // Failed means recent attempts of reachability have returned inconclusive. Failed ) @@ -93,16 +91,13 @@ type neighborEntry struct { neigh NeighborEntry - // wakers is a set of waiters for address resolution result. Anytime state - // transitions out of incomplete these waiters are notified. It is nil iff - // address resolution is ongoing and no clients are waiting for the result. - wakers map[*sleep.Waker]struct{} - - // done is used to allow callers to wait on address resolution. It is nil - // iff nudState is not Reachable and address resolution is not yet in - // progress. + // done is closed when address resolution is complete. It is nil iff s is + // incomplete and resolution is not yet in progress. done chan struct{} + // onResolve is called with the result of address resolution. + onResolve []func(tcpip.LinkAddress, bool) + isRouter bool job *tcpip.Job } @@ -143,25 +138,15 @@ func newStaticNeighborEntry(nic *NIC, addr tcpip.Address, linkAddr tcpip.LinkAdd } } -// addWaker adds w to the list of wakers waiting for address resolution. -// Assumes the entry has already been appropriately locked. -func (e *neighborEntry) addWakerLocked(w *sleep.Waker) { - if w == nil { - return - } - if e.wakers == nil { - e.wakers = make(map[*sleep.Waker]struct{}) - } - e.wakers[w] = struct{}{} -} - -// notifyWakersLocked notifies those waiting for address resolution, whether it -// succeeded or failed. Assumes the entry has already been appropriately locked. -func (e *neighborEntry) notifyWakersLocked() { - for w := range e.wakers { - w.Assert() +// notifyCompletionLocked notifies those waiting for address resolution, with +// the link address if resolution completed successfully. +// +// Precondition: e.mu MUST be locked. +func (e *neighborEntry) notifyCompletionLocked(succeeded bool) { + for _, callback := range e.onResolve { + callback(e.neigh.LinkAddr, succeeded) } - e.wakers = nil + e.onResolve = nil if ch := e.done; ch != nil { close(ch) e.done = nil @@ -170,6 +155,8 @@ func (e *neighborEntry) notifyWakersLocked() { // dispatchAddEventLocked signals to stack's NUD Dispatcher that the entry has // been added. +// +// Precondition: e.mu MUST be locked. func (e *neighborEntry) dispatchAddEventLocked() { if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { nudDisp.OnNeighborAdded(e.nic.id, e.neigh) @@ -178,6 +165,8 @@ func (e *neighborEntry) dispatchAddEventLocked() { // dispatchChangeEventLocked signals to stack's NUD Dispatcher that the entry // has changed state or link-layer address. +// +// Precondition: e.mu MUST be locked. func (e *neighborEntry) dispatchChangeEventLocked() { if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { nudDisp.OnNeighborChanged(e.nic.id, e.neigh) @@ -186,23 +175,41 @@ func (e *neighborEntry) dispatchChangeEventLocked() { // dispatchRemoveEventLocked signals to stack's NUD Dispatcher that the entry // has been removed. +// +// Precondition: e.mu MUST be locked. func (e *neighborEntry) dispatchRemoveEventLocked() { if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { nudDisp.OnNeighborRemoved(e.nic.id, e.neigh) } } +// cancelJobLocked cancels the currently scheduled action, if there is one. +// Entries in Unknown, Stale, or Static state do not have a scheduled action. +// +// Precondition: e.mu MUST be locked. +func (e *neighborEntry) cancelJobLocked() { + if job := e.job; job != nil { + job.Cancel() + } +} + +// removeLocked prepares the entry for removal. +// +// Precondition: e.mu MUST be locked. +func (e *neighborEntry) removeLocked() { + e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds() + e.dispatchRemoveEventLocked() + e.cancelJobLocked() + e.notifyCompletionLocked(false /* succeeded */) +} + // setStateLocked transitions the entry to the specified state immediately. // // Follows the logic defined in RFC 4861 section 7.3.3. // -// e.mu MUST be locked. +// Precondition: e.mu MUST be locked. func (e *neighborEntry) setStateLocked(next NeighborState) { - // Cancel the previously scheduled action, if there is one. Entries in - // Unknown, Stale, or Static state do not have scheduled actions. - if timer := e.job; timer != nil { - timer.Cancel() - } + e.cancelJobLocked() prev := e.neigh.State e.neigh.State = next @@ -257,11 +264,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { e.job.Schedule(immediateDuration) case Failed: - e.notifyWakersLocked() - e.job = e.nic.stack.newJob(&e.mu, func() { - e.nic.neigh.removeEntryLocked(e) - }) - e.job.Schedule(config.UnreachableTime) + e.notifyCompletionLocked(false /* succeeded */) case Unknown, Stale, Static: // Do nothing @@ -275,8 +278,14 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { // being queued for outgoing transmission. // // Follows the logic defined in RFC 4861 section 7.3.3. +// +// Precondition: e.mu MUST be locked. func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { switch e.neigh.State { + case Failed: + e.nic.stats.Neighbor.FailedEntryLookups.Increment() + + fallthrough case Unknown: e.neigh.State = Incomplete e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds() @@ -309,7 +318,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { // implementation may find it convenient in some cases to return errors // to the sender by taking the offending packet, generating an ICMP // error message, and then delivering it (locally) through the generic - // error-handling routines.' - RFC 4861 section 2.1 + // error-handling routines." - RFC 4861 section 2.1 e.dispatchRemoveEventLocked() e.setStateLocked(Failed) return @@ -347,9 +356,8 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { e.setStateLocked(Delay) e.dispatchChangeEventLocked() - case Incomplete, Reachable, Delay, Probe, Static, Failed: + case Incomplete, Reachable, Delay, Probe, Static: // Do nothing - default: panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State)) } @@ -359,18 +367,30 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { // Neighbor Solicitation for ARP or NDP, respectively). // // Follows the logic defined in RFC 4861 section 7.2.3. +// +// Precondition: e.mu MUST be locked. func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) { // Probes MUST be silently discarded if the target address is tentative, does // not exist, or not bound to the NIC as per RFC 4861 section 7.2.3. These // checks MUST be done by the NetworkEndpoint. switch e.neigh.State { - case Unknown, Incomplete, Failed: + case Unknown, Failed: e.neigh.LinkAddr = remoteLinkAddr e.setStateLocked(Stale) - e.notifyWakersLocked() e.dispatchAddEventLocked() + case Incomplete: + // "If an entry already exists, and the cached link-layer address + // differs from the one in the received Source Link-Layer option, the + // cached address should be replaced by the received address, and the + // entry's reachability state MUST be set to STALE." + // - RFC 4861 section 7.2.3 + e.neigh.LinkAddr = remoteLinkAddr + e.setStateLocked(Stale) + e.notifyCompletionLocked(true /* succeeded */) + e.dispatchChangeEventLocked() + case Reachable, Delay, Probe: if e.neigh.LinkAddr != remoteLinkAddr { e.neigh.LinkAddr = remoteLinkAddr @@ -403,6 +423,8 @@ func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) { // not be possible. SEND uses RSA key pairs to produce Cryptographically // Generated Addresses (CGA), as defined in RFC 3972. This ensures that the // claimed source of an NDP message is the owner of the claimed address. +// +// Precondition: e.mu MUST be locked. func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) { switch e.neigh.State { case Incomplete: @@ -421,7 +443,7 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla } e.dispatchChangeEventLocked() e.isRouter = flags.IsRouter - e.notifyWakersLocked() + e.notifyCompletionLocked(true /* succeeded */) // "Note that the Override flag is ignored if the entry is in the // INCOMPLETE state." - RFC 4861 section 7.2.5 @@ -456,7 +478,7 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla wasReachable := e.neigh.State == Reachable // Set state to Reachable again to refresh timers. e.setStateLocked(Reachable) - e.notifyWakersLocked() + e.notifyCompletionLocked(true /* succeeded */) if !wasReachable { e.dispatchChangeEventLocked() } @@ -494,6 +516,8 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla // handleUpperLevelConfirmationLocked processes an incoming upper-level protocol // (e.g. TCP acknowledgements) reachability confirmation. +// +// Precondition: e.mu MUST be locked. func (e *neighborEntry) handleUpperLevelConfirmationLocked() { switch e.neigh.State { case Reachable, Stale, Delay, Probe: diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index c2b763325..ec34ffa5a 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -25,7 +25,6 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -73,35 +72,36 @@ func eventDiffOptsWithSort() []cmp.Option { // The following unit tests exercise every state transition and verify its // behavior with RFC 4681. // -// | From | To | Cause | Action | Event | -// | ========== | ========== | ========================================== | =============== | ======= | -// | Unknown | Unknown | Confirmation w/ unknown address | | Added | -// | Unknown | Incomplete | Packet queued to unknown address | Send probe | Added | -// | Unknown | Stale | Probe w/ unknown address | | Added | -// | Incomplete | Incomplete | Retransmit timer expired | Send probe | Changed | -// | Incomplete | Reachable | Solicited confirmation | Notify wakers | Changed | -// | Incomplete | Stale | Unsolicited confirmation | Notify wakers | Changed | -// | Incomplete | Failed | Max probes sent without reply | Notify wakers | Removed | -// | Reachable | Reachable | Confirmation w/ different isRouter flag | Update IsRouter | | -// | Reachable | Stale | Reachable timer expired | | Changed | -// | Reachable | Stale | Probe or confirmation w/ different address | | Changed | -// | Stale | Reachable | Solicited override confirmation | Update LinkAddr | Changed | -// | Stale | Reachable | Solicited confirmation w/o address | Notify wakers | Changed | -// | Stale | Stale | Override confirmation | Update LinkAddr | Changed | -// | Stale | Stale | Probe w/ different address | Update LinkAddr | Changed | -// | Stale | Delay | Packet sent | | Changed | -// | Delay | Reachable | Upper-layer confirmation | | Changed | -// | Delay | Reachable | Solicited override confirmation | Update LinkAddr | Changed | -// | Delay | Reachable | Solicited confirmation w/o address | Notify wakers | Changed | -// | Delay | Stale | Probe or confirmation w/ different address | | Changed | -// | Delay | Probe | Delay timer expired | Send probe | Changed | -// | Probe | Reachable | Solicited override confirmation | Update LinkAddr | Changed | -// | Probe | Reachable | Solicited confirmation w/ same address | Notify wakers | Changed | -// | Probe | Reachable | Solicited confirmation w/o address | Notify wakers | Changed | -// | Probe | Stale | Probe or confirmation w/ different address | | Changed | -// | Probe | Probe | Retransmit timer expired | Send probe | Changed | -// | Probe | Failed | Max probes sent without reply | Notify wakers | Removed | -// | Failed | | Unreachability timer expired | Delete entry | | +// | From | To | Cause | Update | Action | Event | +// | ========== | ========== | ========================================== | ======== | ===========| ======= | +// | Unknown | Unknown | Confirmation w/ unknown address | | | Added | +// | Unknown | Incomplete | Packet queued to unknown address | | Send probe | Added | +// | Unknown | Stale | Probe | | | Added | +// | Incomplete | Incomplete | Retransmit timer expired | | Send probe | Changed | +// | Incomplete | Reachable | Solicited confirmation | LinkAddr | Notify | Changed | +// | Incomplete | Stale | Unsolicited confirmation | LinkAddr | Notify | Changed | +// | Incomplete | Stale | Probe | LinkAddr | Notify | Changed | +// | Incomplete | Failed | Max probes sent without reply | | Notify | Removed | +// | Reachable | Reachable | Confirmation w/ different isRouter flag | IsRouter | | | +// | Reachable | Stale | Reachable timer expired | | | Changed | +// | Reachable | Stale | Probe or confirmation w/ different address | | | Changed | +// | Stale | Reachable | Solicited override confirmation | LinkAddr | | Changed | +// | Stale | Reachable | Solicited confirmation w/o address | | Notify | Changed | +// | Stale | Stale | Override confirmation | LinkAddr | | Changed | +// | Stale | Stale | Probe w/ different address | LinkAddr | | Changed | +// | Stale | Delay | Packet sent | | | Changed | +// | Delay | Reachable | Upper-layer confirmation | | | Changed | +// | Delay | Reachable | Solicited override confirmation | LinkAddr | | Changed | +// | Delay | Reachable | Solicited confirmation w/o address | | Notify | Changed | +// | Delay | Stale | Probe or confirmation w/ different address | | | Changed | +// | Delay | Probe | Delay timer expired | | Send probe | Changed | +// | Probe | Reachable | Solicited override confirmation | LinkAddr | | Changed | +// | Probe | Reachable | Solicited confirmation w/ same address | | Notify | Changed | +// | Probe | Reachable | Solicited confirmation w/o address | | Notify | Changed | +// | Probe | Stale | Probe or confirmation w/ different address | | | Changed | +// | Probe | Probe | Retransmit timer expired | | | Changed | +// | Probe | Failed | Max probes sent without reply | | Notify | Removed | +// | Failed | Incomplete | Packet queued | | Send probe | Added | type testEntryEventType uint8 @@ -228,6 +228,7 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e clock: clock, nudDisp: &disp, }, + stats: makeNICStats(), } nic.networkEndpoints = map[tcpip.NetworkProtocolNumber]NetworkEndpoint{ header.IPv6ProtocolNumber: (&testIPv6Protocol{}).NewEndpoint(&nic, nil, nil, nil), @@ -256,8 +257,8 @@ func TestEntryInitiallyUnknown(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - if got, want := e.neigh.State, Unknown; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Unknown { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Unknown) } e.mu.Unlock() @@ -289,8 +290,8 @@ func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) { Override: false, IsRouter: false, }) - if got, want := e.neigh.State, Unknown; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Unknown { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Unknown) } e.mu.Unlock() @@ -318,8 +319,8 @@ func TestEntryUnknownToIncomplete(t *testing.T) { e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - if got, want := e.neigh.State, Incomplete; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) } e.mu.Unlock() @@ -365,8 +366,8 @@ func TestEntryUnknownToStale(t *testing.T) { e.mu.Lock() e.handleProbeLocked(entryTestLinkAddr1) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) } e.mu.Unlock() @@ -404,8 +405,8 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - if got, want := e.neigh.State, Incomplete; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) } updatedAtNanos := e.neigh.UpdatedAtNanos e.mu.Unlock() @@ -558,21 +559,15 @@ func TestEntryIncompleteToReachable(t *testing.T) { nudDisp.mu.Unlock() } -// TestEntryAddsAndClearsWakers verifies that wakers are added when -// addWakerLocked is called and cleared when address resolution finishes. In -// this case, address resolution will finish when transitioning from Incomplete -// to Reachable. -func TestEntryAddsAndClearsWakers(t *testing.T) { +func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { c := DefaultNUDConfigurations() e, nudDisp, linkRes, clock := entryTestSetup(c) - w := sleep.Waker{} - s := sleep.Sleeper{} - s.AddWaker(&w, 123) - defer s.Done() - e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) + } e.mu.Unlock() runImmediatelyScheduledJobs(clock) @@ -591,26 +586,16 @@ func TestEntryAddsAndClearsWakers(t *testing.T) { } e.mu.Lock() - if got := e.wakers; got != nil { - t.Errorf("got e.wakers = %v, want = nil", got) - } - e.addWakerLocked(&w) - if got, want := w.IsAsserted(), false; got != want { - t.Errorf("waker.IsAsserted() = %t, want = %t", got, want) - } - if e.wakers == nil { - t.Error("expected e.wakers to be non-nil") - } e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: true, Override: false, - IsRouter: false, + IsRouter: true, }) - if e.wakers != nil { - t.Errorf("got e.wakers = %v, want = nil", e.wakers) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) } - if got, want := w.IsAsserted(), true; got != want { - t.Errorf("waker.IsAsserted() = %t, want = %t", got, want) + if !e.isRouter { + t.Errorf("got e.isRouter = %t, want = true", e.isRouter) } e.mu.Unlock() @@ -641,7 +626,7 @@ func TestEntryAddsAndClearsWakers(t *testing.T) { nudDisp.mu.Unlock() } -func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { +func TestEntryIncompleteToStaleWhenUnsolicitedConfirmation(t *testing.T) { c := DefaultNUDConfigurations() e, nudDisp, linkRes, clock := entryTestSetup(c) @@ -661,22 +646,20 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { }, } linkRes.mu.Lock() - if diff := cmp.Diff(linkRes.probes, wantProbes); diff != "" { + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } - linkRes.mu.Unlock() e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, + Solicited: false, Override: false, - IsRouter: true, + IsRouter: false, }) - if e.neigh.State != Reachable { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) - } - if !e.isRouter { - t.Errorf("got e.isRouter = %t, want = true", e.isRouter) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) } e.mu.Unlock() @@ -696,7 +679,7 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { Entry: NeighborEntry{ Addr: entryTestAddr1, LinkAddr: entryTestLinkAddr1, - State: Reachable, + State: Stale, }, }, } @@ -707,7 +690,7 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { nudDisp.mu.Unlock() } -func TestEntryIncompleteToStale(t *testing.T) { +func TestEntryIncompleteToStaleWhenProbe(t *testing.T) { c := DefaultNUDConfigurations() e, nudDisp, linkRes, clock := entryTestSetup(c) @@ -734,11 +717,7 @@ func TestEntryIncompleteToStale(t *testing.T) { } e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) + e.handleProbeLocked(entryTestLinkAddr1) if e.neigh.State != Stale { t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) } @@ -778,8 +757,8 @@ func TestEntryIncompleteToFailed(t *testing.T) { e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - if got, want := e.neigh.State, Incomplete; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) } e.mu.Unlock() @@ -839,8 +818,8 @@ func TestEntryIncompleteToFailed(t *testing.T) { nudDisp.mu.Unlock() e.mu.Lock() - if got, want := e.neigh.State, Failed; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Failed { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Failed) } e.mu.Unlock() } @@ -883,8 +862,8 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { Override: false, IsRouter: true, }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) } if got, want := e.isRouter, true; got != want { t.Errorf("got e.isRouter = %t, want = %t", got, want) @@ -930,8 +909,8 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { nudDisp.mu.Unlock() e.mu.Lock() - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) } e.mu.Unlock() } @@ -1081,8 +1060,8 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) { nudDisp.mu.Unlock() e.mu.Lock() - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) } e.mu.Unlock() } @@ -2379,8 +2358,8 @@ func TestEntryDelayToProbe(t *testing.T) { IsRouter: false, }) e.handlePacketQueuedLocked(entryTestAddr2) - if got, want := e.neigh.State, Delay; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Delay { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) } e.mu.Unlock() @@ -2445,8 +2424,8 @@ func TestEntryDelayToProbe(t *testing.T) { nudDisp.mu.Unlock() e.mu.Lock() - if got, want := e.neigh.State, Probe; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) } e.mu.Unlock() } @@ -2503,12 +2482,12 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { } e.mu.Lock() - if got, want := e.neigh.State, Probe; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) } e.handleProbeLocked(entryTestLinkAddr2) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) } e.mu.Unlock() @@ -2618,16 +2597,16 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { } e.mu.Lock() - if got, want := e.neigh.State, Probe; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) } e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ Solicited: false, Override: true, IsRouter: false, }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) } e.mu.Unlock() @@ -2738,16 +2717,16 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { } e.mu.Lock() - if got, want := e.neigh.State, Probe; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) } e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: true, IsRouter: false, }) - if got, want := e.neigh.State, Probe; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) } if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want { t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) @@ -2834,16 +2813,16 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) { } e.mu.Lock() - if got, want := e.neigh.State, Probe; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) } e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ Solicited: true, Override: true, IsRouter: false, }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) } if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) @@ -2962,16 +2941,16 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { } e.mu.Lock() - if got, want := e.neigh.State, Probe; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) } e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ Solicited: true, Override: true, IsRouter: false, }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) } if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) @@ -3099,16 +3078,16 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin } e.mu.Lock() - if got, want := e.neigh.State, Probe; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) } e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: true, Override: false, IsRouter: false, }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) } e.mu.Unlock() @@ -3433,72 +3412,61 @@ func TestEntryProbeToFailed(t *testing.T) { nudDisp.mu.Unlock() } -func TestEntryFailedGetsDeleted(t *testing.T) { +func TestEntryFailedToIncomplete(t *testing.T) { c := DefaultNUDConfigurations() c.MaxMulticastProbes = 3 - c.MaxUnicastProbes = 3 e, nudDisp, linkRes, clock := entryTestSetup(c) - // Verify the cache contains the entry. - if _, ok := e.nic.neigh.cache[entryTestAddr1]; !ok { - t.Errorf("expected entry %q to exist in the neighbor cache", entryTestAddr1) - } - + // TODO(gvisor.dev/issue/4872): Use helper functions to start entry tests in + // their expected state. e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) + } e.mu.Unlock() - runImmediatelyScheduledJobs(clock) - { - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.probes = nil - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) - } + waitFor := c.RetransmitTimer * time.Duration(c.MaxMulticastProbes) + clock.Advance(waitFor) + + wantProbes := []entryTestProbeInfo{ + // The Incomplete-to-Incomplete state transition is tested here by + // verifying that 3 reachability probes were sent. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Failed { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Failed) + } e.mu.Unlock() - waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + c.UnreachableTime - clock.Advance(waitFor) - { - wantProbes := []entryTestProbeInfo{ - // The next three probe are sent in Probe. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) - } + e.mu.Lock() + e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) } + e.mu.Unlock() wantEvents := []testEntryEventInfo{ { @@ -3511,39 +3479,21 @@ func TestEntryFailedGetsDeleted(t *testing.T) { }, }, { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, - }, - }, - { - EventType: entryTestChanged, + EventType: entryTestRemoved, NICID: entryTestNICID, Entry: NeighborEntry{ Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, }, }, { - EventType: entryTestRemoved, + EventType: entryTestAdded, NICID: entryTestNICID, Entry: NeighborEntry{ Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, }, }, } @@ -3552,9 +3502,4 @@ func TestEntryFailedGetsDeleted(t *testing.T) { t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) } nudDisp.mu.Unlock() - - // Verify the cache no longer contains the entry. - if _, ok := e.nic.neigh.cache[entryTestAddr1]; ok { - t.Errorf("entry %q should have been deleted from the neighbor cache", entryTestAddr1) - } } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 60c81a3aa..4a34805b5 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -20,7 +20,6 @@ import ( "reflect" "sync/atomic" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -54,18 +53,20 @@ type NIC struct { sync.RWMutex spoofing bool promiscuous bool - // packetEPs is protected by mu, but the contained PacketEndpoint - // values are not. - packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint + // packetEPs is protected by mu, but the contained packetEndpointList are + // not. + packetEPs map[tcpip.NetworkProtocolNumber]*packetEndpointList } } -// NICStats includes transmitted and received stats. +// NICStats hold statistics for a NIC. type NICStats struct { Tx DirectionStats Rx DirectionStats DisabledRx DirectionStats + + Neighbor NeighborStats } func makeNICStats() NICStats { @@ -80,6 +81,39 @@ type DirectionStats struct { Bytes *tcpip.StatCounter } +type packetEndpointList struct { + mu sync.RWMutex + + // eps is protected by mu, but the contained PacketEndpoint values are not. + eps []PacketEndpoint +} + +func (p *packetEndpointList) add(ep PacketEndpoint) { + p.mu.Lock() + defer p.mu.Unlock() + p.eps = append(p.eps, ep) +} + +func (p *packetEndpointList) remove(ep PacketEndpoint) { + p.mu.Lock() + defer p.mu.Unlock() + for i, epOther := range p.eps { + if epOther == ep { + p.eps = append(p.eps[:i], p.eps[i+1:]...) + break + } + } +} + +// forEach calls fn with each endpoints in p while holding the read lock on p. +func (p *packetEndpointList) forEach(fn func(PacketEndpoint)) { + p.mu.RLock() + defer p.mu.RUnlock() + for _, ep := range p.eps { + fn(ep) + } +} + // newNIC returns a new NIC using the default NDP configurations from stack. func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICContext) *NIC { // TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For @@ -100,7 +134,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC stats: makeNICStats(), networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), } - nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint) + nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList) // Check for Neighbor Unreachability Detection support. var nud NUDHandler @@ -123,11 +157,11 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC // Register supported packet and network endpoint protocols. for _, netProto := range header.Ethertypes { - nic.mu.packetEPs[netProto] = []PacketEndpoint{} + nic.mu.packetEPs[netProto] = new(packetEndpointList) } for _, netProto := range stack.networkProtocols { netNum := netProto.Number() - nic.mu.packetEPs[netNum] = nil + nic.mu.packetEPs[netNum] = new(packetEndpointList) nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic) } @@ -170,7 +204,7 @@ func (n *NIC) disable() { // // n MUST be locked. func (n *NIC) disableLocked() { - if !n.setEnabled(false) { + if !n.Enabled() { return } @@ -182,6 +216,10 @@ func (n *NIC) disableLocked() { for _, ep := range n.networkEndpoints { ep.Disable() } + + if !n.setEnabled(false) { + panic("should have only done work to disable the NIC if it was enabled") + } } // enable enables n. @@ -232,7 +270,8 @@ func (n *NIC) setPromiscuousMode(enable bool) { n.mu.Unlock() } -func (n *NIC) isPromiscuousMode() bool { +// Promiscuous implements NetworkInterface. +func (n *NIC) Promiscuous() bool { n.mu.RLock() rv := n.mu.promiscuous n.mu.RUnlock() @@ -255,16 +294,18 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb // the same unresolved IP address, and transmit the saved // packet when the address has been resolved. // - // RFC 4861 section 5.2 (for IPv6): - // Once the IP address of the next-hop node is known, the sender - // examines the Neighbor Cache for link-layer information about that - // neighbor. If no entry exists, the sender creates one, sets its state - // to INCOMPLETE, initiates Address Resolution, and then queues the data - // packet pending completion of address resolution. + // RFC 4861 section 7.2.2 (for IPv6): + // While waiting for address resolution to complete, the sender MUST, for + // each neighbor, retain a small queue of packets waiting for address + // resolution to complete. The queue MUST hold at least one packet, and MAY + // contain more. However, the number of queued packets per neighbor SHOULD + // be limited to some small value. When a queue overflows, the new arrival + // SHOULD replace the oldest entry. Once address resolution completes, the + // node transmits any queued packets. if ch, err := r.Resolve(nil); err != nil { if err == tcpip.ErrWouldBlock { - r := r.Clone() - n.stack.linkResQueue.enqueue(ch, &r, protocol, pkt) + r.Acquire() + n.stack.linkResQueue.enqueue(ch, r, protocol, pkt) return nil } return err @@ -276,9 +317,11 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb // WritePacketToRemote implements NetworkInterface. func (n *NIC) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { r := Route{ - NetProto: protocol, - RemoteLinkAddress: remoteLinkAddr, + routeInfo: routeInfo{ + NetProto: protocol, + }, } + r.ResolveWith(remoteLinkAddr) return n.writePacket(&r, gso, protocol, pkt) } @@ -320,16 +363,21 @@ func (n *NIC) setSpoofing(enable bool) { // primaryAddress returns an address that can be used to communicate with // remoteAddr. func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) AssignableAddressEndpoint { - n.mu.RLock() - spoofing := n.mu.spoofing - n.mu.RUnlock() - ep, ok := n.networkEndpoints[protocol] if !ok { return nil } - return ep.AcquireOutgoingPrimaryAddress(remoteAddr, spoofing) + addressableEndpoint, ok := ep.(AddressableEndpoint) + if !ok { + return nil + } + + n.mu.RLock() + spoofing := n.mu.spoofing + n.mu.RUnlock() + + return addressableEndpoint.AcquireOutgoingPrimaryAddress(remoteAddr, spoofing) } type getAddressBehaviour int @@ -388,11 +436,17 @@ func (n *NIC) getAddressOrCreateTemp(protocol tcpip.NetworkProtocolNumber, addre // getAddressOrCreateTempInner is like getAddressEpOrCreateTemp except a boolean // is passed to indicate whether or not we should generate temporary endpoints. func (n *NIC) getAddressOrCreateTempInner(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, createTemp bool, peb PrimaryEndpointBehavior) AssignableAddressEndpoint { - if ep, ok := n.networkEndpoints[protocol]; ok { - return ep.AcquireAssignedAddress(address, createTemp, peb) + ep, ok := n.networkEndpoints[protocol] + if !ok { + return nil } - return nil + addressableEndpoint, ok := ep.(AddressableEndpoint) + if !ok { + return nil + } + + return addressableEndpoint.AcquireAssignedAddress(address, createTemp, peb) } // addAddress adds a new address to n, so that it starts accepting packets @@ -403,7 +457,12 @@ func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpo return tcpip.ErrUnknownProtocol } - addressEndpoint, err := ep.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */) + addressableEndpoint, ok := ep.(AddressableEndpoint) + if !ok { + return tcpip.ErrNotSupported + } + + addressEndpoint, err := addressableEndpoint.AddAndAcquirePermanentAddress(protocolAddress.AddressWithPrefix, peb, AddressConfigStatic, false /* deprecated */) if err == nil { // We have no need for the address endpoint. addressEndpoint.DecRef() @@ -416,7 +475,12 @@ func (n *NIC) addAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpo func (n *NIC) allPermanentAddresses() []tcpip.ProtocolAddress { var addrs []tcpip.ProtocolAddress for p, ep := range n.networkEndpoints { - for _, a := range ep.PermanentAddresses() { + addressableEndpoint, ok := ep.(AddressableEndpoint) + if !ok { + continue + } + + for _, a := range addressableEndpoint.PermanentAddresses() { addrs = append(addrs, tcpip.ProtocolAddress{Protocol: p, AddressWithPrefix: a}) } } @@ -427,7 +491,12 @@ func (n *NIC) allPermanentAddresses() []tcpip.ProtocolAddress { func (n *NIC) primaryAddresses() []tcpip.ProtocolAddress { var addrs []tcpip.ProtocolAddress for p, ep := range n.networkEndpoints { - for _, a := range ep.PrimaryAddresses() { + addressableEndpoint, ok := ep.(AddressableEndpoint) + if !ok { + continue + } + + for _, a := range addressableEndpoint.PrimaryAddresses() { addrs = append(addrs, tcpip.ProtocolAddress{Protocol: p, AddressWithPrefix: a}) } } @@ -445,13 +514,23 @@ func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWit return tcpip.AddressWithPrefix{} } - return ep.MainAddress() + addressableEndpoint, ok := ep.(AddressableEndpoint) + if !ok { + return tcpip.AddressWithPrefix{} + } + + return addressableEndpoint.MainAddress() } // removeAddress removes an address from n. func (n *NIC) removeAddress(addr tcpip.Address) *tcpip.Error { for _, ep := range n.networkEndpoints { - if err := ep.RemovePermanentAddress(addr); err == tcpip.ErrBadLocalAddress { + addressableEndpoint, ok := ep.(AddressableEndpoint) + if !ok { + continue + } + + if err := addressableEndpoint.RemovePermanentAddress(addr); err == tcpip.ErrBadLocalAddress { continue } else { return err @@ -469,14 +548,6 @@ func (n *NIC) neighbors() ([]NeighborEntry, *tcpip.Error) { return n.neigh.entries(), nil } -func (n *NIC) removeWaker(addr tcpip.Address, w *sleep.Waker) { - if n.neigh == nil { - return - } - - n.neigh.removeWaker(addr, w) -} - func (n *NIC) addStaticNeighbor(addr tcpip.Address, linkAddress tcpip.LinkAddress) *tcpip.Error { if n.neigh == nil { return tcpip.ErrNotSupported @@ -524,8 +595,7 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address return tcpip.ErrNotSupported } - _, err := gep.JoinGroup(addr) - return err + return gep.JoinGroup(addr) } // leaveGroup decrements the count for the given multicast address, and when it @@ -541,11 +611,7 @@ func (n *NIC) leaveGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Addres return tcpip.ErrNotSupported } - if _, err := gep.LeaveGroup(addr); err != nil { - return err - } - - return nil + return gep.LeaveGroup(addr) } // isInGroup returns true if n has joined the multicast group addr. @@ -564,13 +630,6 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool { return false } -func (n *NIC) handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, remotelinkAddr tcpip.LinkAddress, addressEndpoint AssignableAddressEndpoint, pkt *PacketBuffer) { - r := makeRoute(protocol, dst, src, n, n, addressEndpoint, false /* handleLocal */, false /* multicastLoop */) - defer r.Release() - r.PopulatePacketInfo(pkt) - n.getNetworkEndpoint(protocol).HandlePacket(pkt) -} - // DeliverNetworkPacket finds the appropriate network protocol endpoint and // hands the packet over for further processing. This function is called when // the NIC receives a packet from the link endpoint. @@ -592,7 +651,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp n.stats.Rx.Packets.Increment() n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data.Size())) - netProto, ok := n.stack.networkProtocols[protocol] + networkEndpoint, ok := n.networkEndpoints[protocol] if !ok { n.mu.RUnlock() n.stack.stats.UnknownProtocolRcvdPackets.Increment() @@ -607,21 +666,26 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp pkt.RXTransportChecksumValidated = n.LinkEndpoint.Capabilities()&CapabilityRXChecksumOffload != 0 // Are any packet type sockets listening for this network protocol? - packetEPs := n.mu.packetEPs[protocol] - // Add any other packet type sockets that may be listening for all protocols. - packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...) + protoEPs := n.mu.packetEPs[protocol] + // Other packet type sockets that are listening for all protocols. + anyEPs := n.mu.packetEPs[header.EthernetProtocolAll] n.mu.RUnlock() - for _, ep := range packetEPs { + + // Deliver to interested packet endpoints without holding NIC lock. + deliverPacketEPs := func(ep PacketEndpoint) { p := pkt.Clone() p.PktType = tcpip.PacketHost ep.HandlePacket(n.id, local, protocol, p) } - - if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber { - n.stack.stats.IP.PacketsReceived.Increment() + if protoEPs != nil { + protoEPs.forEach(deliverPacketEPs) + } + if anyEPs != nil { + anyEPs.forEach(deliverPacketEPs) } // Parse headers. + netProto := n.stack.NetworkProtocolInstance(protocol) transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt) if !ok { // The packet is too small to contain a network header. @@ -636,9 +700,8 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp } } - src, dst := netProto.ParseAddresses(pkt.NetworkHeader().View()) - if n.stack.handleLocal && !n.IsLoopback() { + src, _ := netProto.ParseAddresses(pkt.NetworkHeader().View()) if r := n.getAddress(protocol, src); r != nil { r.DecRef() @@ -651,78 +714,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp } } - // Loopback traffic skips the prerouting chain. - if !n.IsLoopback() { - // iptables filtering. - ipt := n.stack.IPTables() - address := n.primaryAddress(protocol) - if ok := ipt.Check(Prerouting, pkt, nil, nil, address.Address, ""); !ok { - // iptables is telling us to drop the packet. - n.stack.stats.IP.IPTablesPreroutingDropped.Increment() - return - } - } - - if addressEndpoint := n.getAddress(protocol, dst); addressEndpoint != nil { - n.handlePacket(protocol, dst, src, remote, addressEndpoint, pkt) - return - } - - // This NIC doesn't care about the packet. Find a NIC that cares about the - // packet and forward it to the NIC. - // - // TODO: Should we be forwarding the packet even if promiscuous? - if n.stack.Forwarding(protocol) { - r, err := n.stack.FindRoute(0, "", dst, protocol, false /* multicastLoop */) - if err != nil { - n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() - return - } - - // Found a NIC. - n := r.localAddressNIC - if addressEndpoint := n.getAddressOrCreateTempInner(protocol, dst, false, NeverPrimaryEndpoint); addressEndpoint != nil { - if n.isValidForOutgoing(addressEndpoint) { - pkt.NICID = n.ID() - r.RemoteAddress = src - pkt.NetworkPacketInfo = r.networkPacketInfo() - n.getNetworkEndpoint(protocol).HandlePacket(pkt) - addressEndpoint.DecRef() - r.Release() - return - } - - addressEndpoint.DecRef() - } - - // n doesn't have a destination endpoint. - // Send the packet out of n. - // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease - // the TTL field for ipv4/ipv6. - - // pkt may have set its header and may not have enough headroom for - // link-layer header for the other link to prepend. Here we create a new - // packet to forward. - fwdPkt := NewPacketBuffer(PacketBufferOptions{ - ReserveHeaderBytes: int(n.LinkEndpoint.MaxHeaderLength()), - // We need to do a deep copy of the IP packet because WritePacket (and - // friends) take ownership of the packet buffer, but we do not own it. - Data: PayloadSince(pkt.NetworkHeader()).ToVectorisedView(), - }) - - // TODO(b/143425874) Decrease the TTL field in forwarded packets. - if err := n.WritePacket(&r, nil, protocol, fwdPkt); err != nil { - n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() - } - - r.Release() - return - } - - // If a packet socket handled the packet, don't treat it as invalid. - if len(packetEPs) == 0 { - n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment() - } + networkEndpoint.HandlePacket(pkt) } // DeliverOutboundPacket implements NetworkDispatcher.DeliverOutboundPacket. @@ -731,16 +723,17 @@ func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc // We do not deliver to protocol specific packet endpoints as on Linux // only ETH_P_ALL endpoints get outbound packets. // Add any other packet sockets that maybe listening for all protocols. - packetEPs := n.mu.packetEPs[header.EthernetProtocolAll] + eps := n.mu.packetEPs[header.EthernetProtocolAll] n.mu.RUnlock() - for _, ep := range packetEPs { + + eps.forEach(func(ep PacketEndpoint) { p := pkt.Clone() p.PktType = tcpip.PacketOutgoing // Add the link layer header as outgoing packets are intercepted // before the link layer header is created. n.LinkEndpoint.AddHeader(local, remote, protocol, p) ep.HandlePacket(n.id, local, protocol, p) - } + }) } // DeliverTransportPacket delivers the packets to the appropriate transport @@ -893,7 +886,7 @@ func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep Pa if !ok { return tcpip.ErrNotSupported } - n.mu.packetEPs[netProto] = append(eps, ep) + eps.add(ep) return nil } @@ -906,13 +899,7 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep if !ok { return } - - for i, epOther := range eps { - if epOther == ep { - n.mu.packetEPs[netProto] = append(eps[:i], eps[i+1:]...) - return - } - } + eps.remove(ep) } // isValidForOutgoing returns true if the endpoint can be used to send out a diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go index ab629b3a4..12d67409a 100644 --- a/pkg/tcpip/stack/nud.go +++ b/pkg/tcpip/stack/nud.go @@ -109,14 +109,6 @@ const ( // // Default taken from MAX_NEIGHBOR_ADVERTISEMENT of RFC 4861 section 10. defaultMaxReachbilityConfirmations = 3 - - // defaultUnreachableTime is the default duration for how long an entry will - // remain in the FAILED state before being removed from the neighbor cache. - // - // Note, there is no equivalent protocol constant defined in RFC 4861. It - // leaves the specifics of any garbage collection mechanism up to the - // implementation. - defaultUnreachableTime = 5 * time.Second ) // NUDDispatcher is the interface integrators of netstack must implement to @@ -278,10 +270,6 @@ type NUDConfigurations struct { // TODO(gvisor.dev/issue/2246): Discuss if implementation of this NUD // configuration option is necessary. MaxReachabilityConfirmations uint32 - - // UnreachableTime describes how long an entry will remain in the FAILED - // state before being removed from the neighbor cache. - UnreachableTime time.Duration } // DefaultNUDConfigurations returns a NUDConfigurations populated with default @@ -299,7 +287,6 @@ func DefaultNUDConfigurations() NUDConfigurations { MaxUnicastProbes: defaultMaxUnicastProbes, MaxAnycastDelayTime: defaultMaxAnycastDelayTime, MaxReachabilityConfirmations: defaultMaxReachbilityConfirmations, - UnreachableTime: defaultUnreachableTime, } } @@ -329,9 +316,6 @@ func (c *NUDConfigurations) resetInvalidFields() { if c.MaxUnicastProbes == 0 { c.MaxUnicastProbes = defaultMaxUnicastProbes } - if c.UnreachableTime == 0 { - c.UnreachableTime = defaultUnreachableTime - } } // calcMaxRandomFactor calculates the maximum value of the random factor used @@ -416,7 +400,7 @@ func (s *NUDState) ReachableTime() time.Duration { s.config.BaseReachableTime != s.prevBaseReachableTime || s.config.MinRandomFactor != s.prevMinRandomFactor || s.config.MaxRandomFactor != s.prevMaxRandomFactor { - return s.recomputeReachableTimeLocked() + s.recomputeReachableTimeLocked() } return s.reachableTime } @@ -442,7 +426,7 @@ func (s *NUDState) ReachableTime() time.Duration { // random value gets re-computed at least once every few hours. // // s.mu MUST be locked for writing. -func (s *NUDState) recomputeReachableTimeLocked() time.Duration { +func (s *NUDState) recomputeReachableTimeLocked() { s.prevBaseReachableTime = s.config.BaseReachableTime s.prevMinRandomFactor = s.config.MinRandomFactor s.prevMaxRandomFactor = s.config.MaxRandomFactor @@ -462,5 +446,4 @@ func (s *NUDState) recomputeReachableTimeLocked() time.Duration { } s.expiration = time.Now().Add(2 * time.Hour) - return s.reachableTime } diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go index 8cffb9fc6..7bca1373e 100644 --- a/pkg/tcpip/stack/nud_test.go +++ b/pkg/tcpip/stack/nud_test.go @@ -37,7 +37,6 @@ const ( defaultMaxUnicastProbes = 3 defaultMaxAnycastDelayTime = time.Second defaultMaxReachbilityConfirmations = 3 - defaultUnreachableTime = 5 * time.Second defaultFakeRandomNum = 0.5 ) @@ -565,58 +564,6 @@ func TestNUDConfigurationsMaxUnicastProbes(t *testing.T) { } } -func TestNUDConfigurationsUnreachableTime(t *testing.T) { - tests := []struct { - name string - unreachableTime time.Duration - want time.Duration - }{ - // Invalid cases - { - name: "EqualToZero", - unreachableTime: 0, - want: defaultUnreachableTime, - }, - // Valid cases - { - name: "MoreThanZero", - unreachableTime: time.Millisecond, - want: time.Millisecond, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - const nicID = 1 - - c := stack.DefaultNUDConfigurations() - c.UnreachableTime = test.unreachableTime - - e := channel.New(0, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - s := stack.New(stack.Options{ - // A neighbor cache is required to store NUDConfigurations. The - // networking stack will only allocate neighbor caches if a protocol - // providing link address resolution is specified (e.g. ARP or IPv6). - NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NUDConfigs: c, - UseNeighborCache: true, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - sc, err := s.NUDConfigurations(nicID) - if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) - } - if got := sc.UnreachableTime; got != test.want { - t.Errorf("got UnreachableTime = %q, want = %q", got, test.want) - } - }) - } -} - // TestNUDStateReachableTime verifies the correctness of the ReachableTime // computation. func TestNUDStateReachableTime(t *testing.T) { diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go index 5d364a2b0..4a3adcf33 100644 --- a/pkg/tcpip/stack/pending_packets.go +++ b/pkg/tcpip/stack/pending_packets.go @@ -103,7 +103,7 @@ func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, pro for _, p := range packets { if cancelled { p.route.Stats().IP.OutgoingPacketErrors.Increment() - } else if _, err := p.route.Resolve(nil); err != nil { + } else if p.route.IsResolutionRequired() { p.route.Stats().IP.OutgoingPacketErrors.Increment() } else { p.route.outgoingNIC.writePacket(p.route, nil /* gso */, p.proto, p.pkt) diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 00e9a82ae..4795208b4 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -17,7 +17,6 @@ package stack import ( "fmt" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -65,10 +64,6 @@ const ( // NetworkPacketInfo holds information about a network layer packet. type NetworkPacketInfo struct { - // RemoteAddressBroadcast is true if the packet's remote address is a - // broadcast address. - RemoteAddressBroadcast bool - // LocalAddressBroadcast is true if the packet's local address is a broadcast // address. LocalAddressBroadcast bool @@ -89,7 +84,7 @@ type TransportEndpoint interface { // HandleControlPacket is called by the stack when new control (e.g. // ICMP) packets arrive to this transport endpoint. // HandleControlPacket takes ownership of pkt. - HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt *PacketBuffer) + HandleControlPacket(typ ControlType, extra uint32, pkt *PacketBuffer) // Abort initiates an expedited endpoint teardown. It puts the endpoint // in a closed state and frees all resources associated with it. This @@ -263,15 +258,6 @@ const ( PacketLoop ) -// NetOptions is an interface that allows us to pass network protocol specific -// options through the Stack layer code. -type NetOptions interface { - // AllocationSize returns the amount of memory that must be allocated to - // hold the options given that the value must be rounded up to the next - // multiple of 4 bytes. - AllocationSize() int -} - // NetworkHeaderParams are the header parameters given as input by the // transport endpoint to the network. type NetworkHeaderParams struct { @@ -283,10 +269,6 @@ type NetworkHeaderParams struct { // TOS refers to TypeOfService or TrafficClass field of the IP-header. TOS uint8 - - // Options is a set of options to add to a network header (or nil). - // It will be protocol specific opaque information from higher layers. - Options NetOptions } // GroupAddressableEndpoint is an endpoint that supports group addressing. @@ -295,14 +277,10 @@ type NetworkHeaderParams struct { // endpoints may associate themselves with the same identifier (group address). type GroupAddressableEndpoint interface { // JoinGroup joins the specified group. - // - // Returns true if the group was newly joined. - JoinGroup(group tcpip.Address) (bool, *tcpip.Error) + JoinGroup(group tcpip.Address) *tcpip.Error // LeaveGroup attempts to leave the specified group. - // - // Returns tcpip.ErrBadLocalAddress if the endpoint has not joined the group. - LeaveGroup(group tcpip.Address) (bool, *tcpip.Error) + LeaveGroup(group tcpip.Address) *tcpip.Error // IsInGroup returns true if the endpoint is a member of the specified group. IsInGroup(group tcpip.Address) bool @@ -518,6 +496,9 @@ type NetworkInterface interface { // Enabled returns true if the interface is enabled. Enabled() bool + // Promiscuous returns true if the interface is in promiscuous mode. + Promiscuous() bool + // WritePacketToRemote writes the packet to the given remote link address. WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error } @@ -525,8 +506,6 @@ type NetworkInterface interface { // NetworkEndpoint is the interface that needs to be implemented by endpoints // of network layer protocols (e.g., ipv4, ipv6). type NetworkEndpoint interface { - AddressableEndpoint - // Enable enables the endpoint. // // Must only be called when the stack is in a state that allows the endpoint @@ -742,10 +721,6 @@ type LinkEndpoint interface { // endpoint. Capabilities() LinkEndpointCapabilities - // WriteRawPacket writes a packet directly to the link. The packet - // should already have an ethernet header. It takes ownership of vv. - WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error - // Attach attaches the data link layer endpoint to the network-layer // dispatcher of the stack. // @@ -823,19 +798,26 @@ type LinkAddressCache interface { // AddLinkAddress adds a link address to the cache. AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) - // GetLinkAddress looks up the cache to translate address to link address (e.g. IP -> MAC). - // If the LinkEndpoint requests address resolution and there is a LinkAddressResolver - // registered with the network protocol, the cache attempts to resolve the address - // and returns ErrWouldBlock. Waker is notified when address resolution is - // complete (success or not). + // GetLinkAddress finds the link address corresponding to the remote address + // (e.g. IP -> MAC). // - // If address resolution is required, ErrNoLinkAddress and a notification channel is - // returned for the top level caller to block. Channel is closed once address resolution - // is complete (success or not). - GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) - - // RemoveWaker removes a waker that has been added in GetLinkAddress(). - RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) + // Returns a link address for the remote address, if readily available. + // + // Returns ErrWouldBlock if the link address is not readily available, along + // with a notification channel for the caller to block on. Triggers address + // resolution asynchronously. + // + // If onResolve is provided, it will be called either immediately, if + // resolution is not required, or when address resolution is complete, with + // the resolved link address and whether resolution succeeded. After any + // callbacks have been called, the returned notification channel is closed. + // + // If specified, the local address must be an address local to the interface + // the neighbor cache belongs to. The local address is the source address of + // a packet prompting NUD/link address resolution. + // + // TODO(gvisor.dev/issue/5151): Don't return the link address. + GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) } // RawFactory produces endpoints for writing various types of raw packets. diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 15ff437c7..b0251d0b4 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -17,20 +17,53 @@ package stack import ( "fmt" - "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" ) // Route represents a route through the networking stack to a given destination. +// +// It is safe to call Route's methods from multiple goroutines. +// +// The exported fields are immutable. +// +// TODO(gvisor.dev/issue/4902): Unexpose immutable fields. type Route struct { + routeInfo + + // localAddressNIC is the interface the address is associated with. + // TODO(gvisor.dev/issue/4548): Remove this field once we can query the + // address's assigned status without the NIC. + localAddressNIC *NIC + + mu struct { + sync.RWMutex + + // localAddressEndpoint is the local address this route is associated with. + localAddressEndpoint AssignableAddressEndpoint + + // remoteLinkAddress is the link-layer (MAC) address of the next hop in the + // route. + remoteLinkAddress tcpip.LinkAddress + } + + // outgoingNIC is the interface this route uses to write packets. + outgoingNIC *NIC + + // linkCache is set if link address resolution is enabled for this protocol on + // the route's NIC. + linkCache LinkAddressCache + + // linkRes is set if link address resolution is enabled for this protocol on + // the route's NIC. + linkRes LinkAddressResolver +} + +type routeInfo struct { // RemoteAddress is the final destination of the route. RemoteAddress tcpip.Address - // RemoteLinkAddress is the link-layer (MAC) address of the - // final destination of the route. - RemoteLinkAddress tcpip.LinkAddress - // LocalAddress is the local address where the route starts. LocalAddress tcpip.Address @@ -46,47 +79,48 @@ type Route struct { // Loop controls where WritePacket should send packets. Loop PacketLooping +} - // localAddressNIC is the interface the address is associated with. - // TODO(gvisor.dev/issue/4548): Remove this field once we can query the - // address's assigned status without the NIC. - localAddressNIC *NIC - - // localAddressEndpoint is the local address this route is associated with. - localAddressEndpoint AssignableAddressEndpoint - - // outgoingNIC is the interface this route uses to write packets. - outgoingNIC *NIC +// RouteInfo contains all of Route's exported fields. +type RouteInfo struct { + routeInfo - // linkCache is set if link address resolution is enabled for this protocol on - // the route's NIC. - linkCache LinkAddressCache + // RemoteLinkAddress is the link-layer (MAC) address of the next hop in the + // route. + RemoteLinkAddress tcpip.LinkAddress +} - // linkRes is set if link address resolution is enabled for this protocol on - // the route's NIC. - linkRes LinkAddressResolver +// GetFields returns a RouteInfo with all of r's exported fields. This allows +// callers to store the route's fields without retaining a reference to it. +func (r *Route) GetFields() RouteInfo { + return RouteInfo{ + routeInfo: r.routeInfo, + RemoteLinkAddress: r.RemoteLinkAddress(), + } } // constructAndValidateRoute validates and initializes a route. It takes // ownership of the provided local address. // // Returns an empty route if validation fails. -func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *NIC, gateway, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) Route { - addrWithPrefix := addressEndpoint.AddressWithPrefix() +func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *NIC, gateway, localAddr, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) *Route { + if len(localAddr) == 0 { + localAddr = addressEndpoint.AddressWithPrefix().Address + } - if localAddressNIC != outgoingNIC && header.IsV6LinkLocalAddress(addrWithPrefix.Address) { + if localAddressNIC != outgoingNIC && header.IsV6LinkLocalAddress(localAddr) { addressEndpoint.DecRef() - return Route{} + return nil } // If no remote address is provided, use the local address. if len(remoteAddr) == 0 { - remoteAddr = addrWithPrefix.Address + remoteAddr = localAddr } r := makeRoute( netProto, - addrWithPrefix.Address, + localAddr, remoteAddr, outgoingNIC, localAddressNIC, @@ -99,8 +133,8 @@ func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndp // broadcast it. if len(gateway) > 0 { r.NextHop = gateway - } else if subnet := addrWithPrefix.Subnet(); subnet.IsBroadcast(remoteAddr) { - r.RemoteLinkAddress = header.EthernetBroadcastAddress + } else if subnet := addressEndpoint.Subnet(); subnet.IsBroadcast(remoteAddr) { + r.ResolveWith(header.EthernetBroadcastAddress) } return r @@ -108,11 +142,15 @@ func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndp // makeRoute initializes a new route. It takes ownership of the provided // AssignableAddressEndpoint. -func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) Route { +func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) *Route { if localAddressNIC.stack != outgoingNIC.stack { panic(fmt.Sprintf("cannot create a route with NICs from different stacks")) } + if len(localAddr) == 0 { + localAddr = localAddressEndpoint.AddressWithPrefix().Address + } + loop := PacketOut // TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the @@ -133,18 +171,23 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop) } -func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) Route { - r := Route{ - NetProto: netProto, - LocalAddress: localAddr, - LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(), - RemoteAddress: remoteAddr, - localAddressNIC: localAddressNIC, - localAddressEndpoint: localAddressEndpoint, - outgoingNIC: outgoingNIC, - Loop: loop, +func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) *Route { + r := &Route{ + routeInfo: routeInfo{ + NetProto: netProto, + LocalAddress: localAddr, + LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(), + RemoteAddress: remoteAddr, + Loop: loop, + }, + localAddressNIC: localAddressNIC, + outgoingNIC: outgoingNIC, } + r.mu.Lock() + r.mu.localAddressEndpoint = localAddressEndpoint + r.mu.Unlock() + if r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 { if linkRes, ok := r.outgoingNIC.stack.linkAddrResolvers[r.NetProto]; ok { r.linkRes = linkRes @@ -159,7 +202,7 @@ func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr // provided AssignableAddressEndpoint. // // A local route is a route to a destination that is local to the stack. -func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint) Route { +func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint) *Route { loop := PacketLoop // TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the // link endpoint level. We can remove this check once loopback interfaces @@ -170,26 +213,12 @@ func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop) } -// PopulatePacketInfo populates a packet buffer's packet information fields. -// -// TODO(gvisor.dev/issue/4688): Remove this once network packets are handled by -// the network layer. -func (r *Route) PopulatePacketInfo(pkt *PacketBuffer) { - if r.local() { - pkt.RXTransportChecksumValidated = true - } - pkt.NetworkPacketInfo = r.networkPacketInfo() -} - -// networkPacketInfo returns the network packet information of the route. -// -// TODO(gvisor.dev/issue/4688): Remove this once network packets are handled by -// the network layer. -func (r *Route) networkPacketInfo() NetworkPacketInfo { - return NetworkPacketInfo{ - RemoteAddressBroadcast: r.IsOutboundBroadcast(), - LocalAddressBroadcast: r.isInboundBroadcast(), - } +// RemoteLinkAddress returns the link-layer (MAC) address of the next hop in +// the route. +func (r *Route) RemoteLinkAddress() tcpip.LinkAddress { + r.mu.RLock() + defer r.mu.RUnlock() + return r.mu.remoteLinkAddress } // NICID returns the id of the NIC from which this route originates. @@ -253,22 +282,26 @@ func (r *Route) GSOMaxSize() uint32 { // ResolveWith immediately resolves a route with the specified remote link // address. func (r *Route) ResolveWith(addr tcpip.LinkAddress) { - r.RemoteLinkAddress = addr + r.mu.Lock() + defer r.mu.Unlock() + r.mu.remoteLinkAddress = addr } -// Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in -// case address resolution requires blocking, e.g. wait for ARP reply. Waker is -// notified when address resolution is complete (success or not). -// -// If address resolution is required, ErrNoLinkAddress and a notification channel is -// returned for the top level caller to block. Channel is closed once address resolution -// is complete (success or not). +// Resolve attempts to resolve the link address if necessary. // -// The NIC r uses must not be locked. -func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { - if !r.IsResolutionRequired() { +// Returns tcpip.ErrWouldBlock if address resolution requires blocking (e.g. +// waiting for ARP reply). If address resolution is required, a notification +// channel is also returned for the caller to block on. The channel is closed +// once address resolution is complete (successful or not). If a callback is +// provided, it will be called when address resolution is complete, regardless +// of success or failure. +func (r *Route) Resolve(afterResolve func()) (<-chan struct{}, *tcpip.Error) { + r.mu.Lock() + + if !r.isResolutionRequiredRLocked() { // Nothing to do if there is no cache (which does the resolution on cache miss) or // link address is already known. + r.mu.Unlock() return nil, nil } @@ -276,7 +309,8 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { if nextAddr == "" { // Local link address is already known. if r.RemoteAddress == r.LocalAddress { - r.RemoteLinkAddress = r.LocalLinkAddress + r.mu.remoteLinkAddress = r.LocalLinkAddress + r.mu.Unlock() return nil, nil } nextAddr = r.RemoteAddress @@ -289,38 +323,36 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { linkAddressResolutionRequestLocalAddr = r.LocalAddress } + // Increment the route's reference count because finishResolution retains a + // reference to the route and releases it when called. + r.acquireLocked() + r.mu.Unlock() + + finishResolution := func(linkAddress tcpip.LinkAddress, ok bool) { + if ok { + r.ResolveWith(linkAddress) + } + if afterResolve != nil { + afterResolve() + } + r.Release() + } + if neigh := r.outgoingNIC.neigh; neigh != nil { - entry, ch, err := neigh.entry(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, waker) + _, ch, err := neigh.entry(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, finishResolution) if err != nil { return ch, err } - r.RemoteLinkAddress = entry.LinkAddr return nil, nil } - linkAddr, ch, err := r.linkCache.GetLinkAddress(r.outgoingNIC.ID(), nextAddr, linkAddressResolutionRequestLocalAddr, r.NetProto, waker) + _, ch, err := r.linkCache.GetLinkAddress(r.outgoingNIC.ID(), nextAddr, linkAddressResolutionRequestLocalAddr, r.NetProto, finishResolution) if err != nil { return ch, err } - r.RemoteLinkAddress = linkAddr return nil, nil } -// RemoveWaker removes a waker that has been added in Resolve(). -func (r *Route) RemoveWaker(waker *sleep.Waker) { - nextAddr := r.NextHop - if nextAddr == "" { - nextAddr = r.RemoteAddress - } - - if neigh := r.outgoingNIC.neigh; neigh != nil { - neigh.removeWaker(nextAddr, waker) - return - } - - r.linkCache.RemoveWaker(r.outgoingNIC.ID(), nextAddr, waker) -} - // local returns true if the route is a local route. func (r *Route) local() bool { return r.Loop == PacketLoop || r.outgoingNIC.IsLoopback() @@ -331,7 +363,13 @@ func (r *Route) local() bool { // // The NICs the route is associated with must not be locked. func (r *Route) IsResolutionRequired() bool { - if !r.isValidForOutgoing() || r.RemoteLinkAddress != "" || r.local() { + r.mu.RLock() + defer r.mu.RUnlock() + return r.isResolutionRequiredRLocked() +} + +func (r *Route) isResolutionRequiredRLocked() bool { + if !r.isValidForOutgoingRLocked() || r.mu.remoteLinkAddress != "" || r.local() { return false } @@ -339,11 +377,18 @@ func (r *Route) IsResolutionRequired() bool { } func (r *Route) isValidForOutgoing() bool { + r.mu.RLock() + defer r.mu.RUnlock() + return r.isValidForOutgoingRLocked() +} + +func (r *Route) isValidForOutgoingRLocked() bool { if !r.outgoingNIC.Enabled() { return false } - if !r.localAddressNIC.isValidForOutgoing(r.localAddressEndpoint) { + localAddressEndpoint := r.mu.localAddressEndpoint + if localAddressEndpoint == nil || !r.localAddressNIC.isValidForOutgoing(localAddressEndpoint) { return false } @@ -395,39 +440,31 @@ func (r *Route) MTU() uint32 { return r.outgoingNIC.getNetworkEndpoint(r.NetProto).MTU() } -// Release frees all resources associated with the route. +// Release decrements the reference counter of the resources associated with the +// route. func (r *Route) Release() { - if r.localAddressEndpoint != nil { - r.localAddressEndpoint.DecRef() - r.localAddressEndpoint = nil + r.mu.Lock() + defer r.mu.Unlock() + + if ep := r.mu.localAddressEndpoint; ep != nil { + ep.DecRef() } } -// Clone clones the route. -func (r *Route) Clone() Route { - if r.localAddressEndpoint != nil { - if !r.localAddressEndpoint.IncRef() { +// Acquire increments the reference counter of the resources associated with the +// route. +func (r *Route) Acquire() { + r.mu.RLock() + defer r.mu.RUnlock() + r.acquireLocked() +} + +func (r *Route) acquireLocked() { + if ep := r.mu.localAddressEndpoint; ep != nil { + if !ep.IncRef() { panic(fmt.Sprintf("failed to increment reference count for local address endpoint = %s", r.LocalAddress)) } } - return *r -} - -// MakeLoopedRoute duplicates the given route with special handling for routes -// used for sending multicast or broadcast packets. In those cases the -// multicast/broadcast address is the remote address when sending out, but for -// incoming (looped) packets it becomes the local address. Similarly, the local -// interface address that was the local address going out becomes the remote -// address coming in. This is different to unicast routes where local and -// remote addresses remain the same as they identify location (local vs remote) -// not direction (source vs destination). -func (r *Route) MakeLoopedRoute() Route { - l := r.Clone() - if r.RemoteAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(r.RemoteAddress) || header.IsV6MulticastAddress(r.RemoteAddress) { - l.RemoteAddress, l.LocalAddress = l.LocalAddress, l.RemoteAddress - l.RemoteLinkAddress = l.LocalLinkAddress - } - return l } // Stack returns the instance of the Stack that owns this route. @@ -440,7 +477,14 @@ func (r *Route) isV4Broadcast(addr tcpip.Address) bool { return true } - subnet := r.localAddressEndpoint.Subnet() + r.mu.RLock() + localAddressEndpoint := r.mu.localAddressEndpoint + r.mu.RUnlock() + if localAddressEndpoint == nil { + return false + } + + subnet := localAddressEndpoint.Subnet() return subnet.IsBroadcast(addr) } @@ -450,27 +494,3 @@ func (r *Route) IsOutboundBroadcast() bool { // Only IPv4 has a notion of broadcast. return r.isV4Broadcast(r.RemoteAddress) } - -// isInboundBroadcast returns true if the route is for an inbound broadcast -// packet. -func (r *Route) isInboundBroadcast() bool { - // Only IPv4 has a notion of broadcast. - return r.isV4Broadcast(r.LocalAddress) -} - -// ReverseRoute returns new route with given source and destination address. -func (r *Route) ReverseRoute(src tcpip.Address, dst tcpip.Address) Route { - return Route{ - NetProto: r.NetProto, - LocalAddress: dst, - LocalLinkAddress: r.RemoteLinkAddress, - RemoteAddress: src, - RemoteLinkAddress: r.LocalLinkAddress, - Loop: r.Loop, - localAddressNIC: r.localAddressNIC, - localAddressEndpoint: r.localAddressEndpoint, - outgoingNIC: r.outgoingNIC, - linkCache: r.linkCache, - linkRes: r.linkRes, - } -} diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 0fe157128..114643b03 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -29,7 +29,6 @@ import ( "golang.org/x/time/rate" "gvisor.dev/gvisor/pkg/rand" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -82,6 +81,7 @@ type TCPRACKState struct { FACK seqnum.Value RTT time.Duration Reord bool + DSACKSeen bool } // TCPEndpointID is the unique 4 tuple that identifies a given endpoint. @@ -170,6 +170,9 @@ type TCPSenderState struct { // Outstanding is the number of packets in flight. Outstanding int + // SackedOut is the number of packets which have been selectively acked. + SackedOut int + // SndWnd is the send window size in bytes. SndWnd seqnum.Size @@ -1080,7 +1083,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { flags := NICStateFlags{ Up: true, // Netstack interfaces are always up. Running: nic.Enabled(), - Promiscuous: nic.isPromiscuousMode(), + Promiscuous: nic.Promiscuous(), Loopback: nic.IsLoopback(), } nics[id] = NICInfo{ @@ -1117,6 +1120,16 @@ func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, return s.AddAddressWithOptions(id, protocol, addr, CanBePrimaryEndpoint) } +// AddAddressWithPrefix is the same as AddAddress, but allows you to specify +// the address prefix. +func (s *Stack) AddAddressWithPrefix(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.AddressWithPrefix) *tcpip.Error { + ap := tcpip.ProtocolAddress{ + Protocol: protocol, + AddressWithPrefix: addr, + } + return s.AddProtocolAddressWithOptions(id, ap, CanBePrimaryEndpoint) +} + // AddProtocolAddress adds a new network-layer protocol address to the // specified NIC. func (s *Stack) AddProtocolAddress(id tcpip.NICID, protocolAddress tcpip.ProtocolAddress) *tcpip.Error { @@ -1207,10 +1220,10 @@ func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netP // from the specified NIC. // // Precondition: s.mu must be read locked. -func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) { +func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) *Route { localAddressEndpoint := localAddressNIC.getAddressOrCreateTempInner(netProto, localAddr, false /* createTemp */, NeverPrimaryEndpoint) if localAddressEndpoint == nil { - return Route{}, false + return nil } var outgoingNIC *NIC @@ -1234,12 +1247,12 @@ func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, re // route. if outgoingNIC == nil { localAddressEndpoint.DecRef() - return Route{}, false + return nil } r := makeLocalRoute( netProto, - localAddressEndpoint.AddressWithPrefix().Address, + localAddr, remoteAddr, outgoingNIC, localAddressNIC, @@ -1248,10 +1261,10 @@ func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, re if r.IsOutboundBroadcast() { r.Release() - return Route{}, false + return nil } - return r, true + return r } // findLocalRouteRLocked returns a local route. @@ -1260,26 +1273,26 @@ func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, re // is, a local route is a route where packets never have to leave the stack. // // Precondition: s.mu must be read locked. -func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) { +func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) *Route { if len(localAddr) == 0 { localAddr = remoteAddr } if localAddressNICID == 0 { for _, localAddressNIC := range s.nics { - if r, ok := s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto); ok { - return r, true + if r := s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto); r != nil { + return r } } - return Route{}, false + return nil } if localAddressNIC, ok := s.nics[localAddressNICID]; ok { return s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto) } - return Route{}, false + return nil } // FindRoute creates a route to the given destination address, leaving through @@ -1293,7 +1306,7 @@ func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, // If no local address is provided, the stack will select a local address. If no // remote address is provided, the stack wil use a remote address equal to the // local address. -func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (Route, *tcpip.Error) { +func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (*Route, *tcpip.Error) { s.mu.RLock() defer s.mu.RUnlock() @@ -1304,7 +1317,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n needRoute := !(isLocalBroadcast || isMulticast || isLinkLocal || isLoopback) if s.handleLocal && !isMulticast && !isLocalBroadcast { - if r, ok := s.findLocalRouteRLocked(id, localAddr, remoteAddr, netProto); ok { + if r := s.findLocalRouteRLocked(id, localAddr, remoteAddr, netProto); r != nil { return r, nil } } @@ -1316,7 +1329,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil { return makeRoute( netProto, - addressEndpoint.AddressWithPrefix().Address, + localAddr, remoteAddr, nic, /* outboundNIC */ nic, /* localAddressNIC*/ @@ -1328,9 +1341,9 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n } if isLoopback { - return Route{}, tcpip.ErrBadLocalAddress + return nil, tcpip.ErrBadLocalAddress } - return Route{}, tcpip.ErrNetworkUnreachable + return nil, tcpip.ErrNetworkUnreachable } canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalAddress(localAddr) && !isLinkLocal @@ -1353,8 +1366,8 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n if needRoute { gateway = route.Gateway } - r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop) - if r == (Route{}) { + r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop) + if r == nil { panic(fmt.Sprintf("non-forwarding route validation failed with route table entry = %#v, id = %d, localAddr = %s, remoteAddr = %s", route, id, localAddr, remoteAddr)) } return r, nil @@ -1390,13 +1403,13 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n if id != 0 { if aNIC, ok := s.nics[id]; ok { if addressEndpoint := s.getAddressEP(aNIC, localAddr, remoteAddr, netProto); addressEndpoint != nil { - if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) { + if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop); r != nil { return r, nil } } } - return Route{}, tcpip.ErrNoRoute + return nil, tcpip.ErrNoRoute } if id == 0 { @@ -1408,7 +1421,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n continue } - if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) { + if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, localAddr, remoteAddr, s.handleLocal, multicastLoop); r != nil { return r, nil } } @@ -1416,12 +1429,12 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n } if needRoute { - return Route{}, tcpip.ErrNoRoute + return nil, tcpip.ErrNoRoute } if header.IsV6LoopbackAddress(remoteAddr) { - return Route{}, tcpip.ErrBadLocalAddress + return nil, tcpip.ErrBadLocalAddress } - return Route{}, tcpip.ErrNetworkUnreachable + return nil, tcpip.ErrNetworkUnreachable } // CheckNetworkProtocol checks if a given network protocol is enabled in the @@ -1506,7 +1519,7 @@ func (s *Stack) AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr t } // GetLinkAddress implements LinkAddressCache.GetLinkAddress. -func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { +func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { s.mu.RLock() nic := s.nics[nicID] if nic == nil { @@ -1517,7 +1530,7 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr} linkRes := s.linkAddrResolvers[protocol] - return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic, waker) + return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic, onResolve) } // Neighbors returns all IP to MAC address associations. @@ -1533,29 +1546,6 @@ func (s *Stack) Neighbors(nicID tcpip.NICID) ([]NeighborEntry, *tcpip.Error) { return nic.neighbors() } -// RemoveWaker removes a waker that has been added when link resolution for -// addr was requested. -func (s *Stack) RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) { - if s.useNeighborCache { - s.mu.RLock() - nic, ok := s.nics[nicID] - s.mu.RUnlock() - - if ok { - nic.removeWaker(addr, waker) - } - return - } - - s.mu.RLock() - defer s.mu.RUnlock() - - if nic := s.nics[nicID]; nic == nil { - fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr} - s.linkAddrCache.removeWaker(fullAddr, waker) - } -} - // AddStaticNeighbor statically associates an IP address to a MAC address. func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) *tcpip.Error { s.mu.RLock() @@ -1809,49 +1799,20 @@ func (s *Stack) unregisterPacketEndpointLocked(nicID tcpip.NICID, netProto tcpip nic.unregisterPacketEndpoint(netProto, ep) } -// WritePacket writes data directly to the specified NIC. It adds an ethernet -// header based on the arguments. -func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error { +// WritePacketToRemote writes a payload on the specified NIC using the provided +// network protocol and remote link address. +func (s *Stack) WritePacketToRemote(nicID tcpip.NICID, remote tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error { s.mu.Lock() nic, ok := s.nics[nicID] s.mu.Unlock() if !ok { return tcpip.ErrUnknownDevice } - - // Add our own fake ethernet header. - ethFields := header.EthernetFields{ - SrcAddr: nic.LinkEndpoint.LinkAddress(), - DstAddr: dst, - Type: netProto, - } - fakeHeader := make(header.Ethernet, header.EthernetMinimumSize) - fakeHeader.Encode(ðFields) - vv := buffer.View(fakeHeader).ToVectorisedView() - vv.Append(payload) - - if err := nic.LinkEndpoint.WriteRawPacket(vv); err != nil { - return err - } - - return nil -} - -// WriteRawPacket writes data directly to the specified NIC without adding any -// headers. -func (s *Stack) WriteRawPacket(nicID tcpip.NICID, payload buffer.VectorisedView) *tcpip.Error { - s.mu.Lock() - nic, ok := s.nics[nicID] - s.mu.Unlock() - if !ok { - return tcpip.ErrUnknownDevice - } - - if err := nic.LinkEndpoint.WriteRawPacket(payload); err != nil { - return err - } - - return nil + pkt := NewPacketBuffer(PacketBufferOptions{ + ReserveHeaderBytes: int(nic.MaxHeaderLength()), + Data: payload, + }) + return nic.WritePacketToRemote(remote, nil, netProto, pkt) } // NetworkProtocolInstance returns the protocol instance in the stack for the @@ -1911,7 +1872,6 @@ func (s *Stack) RemoveTCPProbe() { // JoinGroup joins the given multicast group on the given NIC. func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error { - // TODO: notify network of subscription via igmp protocol. s.mu.RLock() defer s.mu.RUnlock() @@ -2158,3 +2118,43 @@ func (s *Stack) networkProtocolNumbers() []tcpip.NetworkProtocolNumber { } return protos } + +func isSubnetBroadcastOnNIC(nic *NIC, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool { + addressEndpoint := nic.getAddressOrCreateTempInner(protocol, addr, false /* createTemp */, NeverPrimaryEndpoint) + if addressEndpoint == nil { + return false + } + + subnet := addressEndpoint.Subnet() + addressEndpoint.DecRef() + return subnet.IsBroadcast(addr) +} + +// IsSubnetBroadcast returns true if the provided address is a subnet-local +// broadcast address on the specified NIC and protocol. +// +// Returns false if the NIC is unknown or if the protocol is unknown or does +// not support addressing. +// +// If the NIC is not specified, the stack will check all NICs. +func (s *Stack) IsSubnetBroadcast(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool { + s.mu.RLock() + defer s.mu.RUnlock() + + if nicID != 0 { + nic, ok := s.nics[nicID] + if !ok { + return false + } + + return isSubnetBroadcastOnNIC(nic, protocol, addr) + } + + for _, nic := range s.nics { + if isSubnetBroadcastOnNIC(nic, protocol, addr) { + return true + } + } + + return false +} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index dedfdd435..856ebf6d4 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -27,7 +27,6 @@ import ( "time" "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -112,7 +111,15 @@ func (*fakeNetworkEndpoint) DefaultTTL() uint8 { func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) { // Increment the received packet count in the protocol descriptor. netHdr := pkt.NetworkHeader().View() - f.proto.packetCount[int(netHdr[dstAddrOffset])%len(f.proto.packetCount)]++ + + dst := tcpip.Address(netHdr[dstAddrOffset:][:1]) + addressEndpoint := f.AcquireAssignedAddress(dst, f.nic.Promiscuous(), stack.CanBePrimaryEndpoint) + if addressEndpoint == nil { + return + } + addressEndpoint.DecRef() + + f.proto.packetCount[int(dst[0])%len(f.proto.packetCount)]++ // Handle control packets. if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) { @@ -159,9 +166,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params hdr[protocolNumberOffset] = byte(params.Protocol) if r.Loop&stack.PacketLoop != 0 { - pkt := pkt.Clone() - r.PopulatePacketInfo(pkt) - f.HandlePacket(pkt) + f.HandlePacket(pkt.Clone()) } if r.Loop&stack.PacketOut == 0 { return nil @@ -401,7 +406,7 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro return send(r, payload) } -func send(r stack.Route, payload buffer.View) *tcpip.Error { +func send(r *stack.Route, payload buffer.View) *tcpip.Error { return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()), Data: payload.ToVectorisedView(), @@ -419,7 +424,7 @@ func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.En } } -func testSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View) { +func testSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View) { t.Helper() ep.Drain() if err := send(r, payload); err != nil { @@ -430,7 +435,7 @@ func testSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer. } } -func testFailingSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { +func testFailingSend(t *testing.T, r *stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { t.Helper() if gotErr := send(r, payload); gotErr != wantErr { t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr) @@ -1557,15 +1562,15 @@ func TestSpoofingNoAddress(t *testing.T) { // testSendTo(t, s, remoteAddr, ep, nil) } -func verifyRoute(gotRoute, wantRoute stack.Route) error { +func verifyRoute(gotRoute, wantRoute *stack.Route) error { if gotRoute.LocalAddress != wantRoute.LocalAddress { return fmt.Errorf("bad local address: got %s, want = %s", gotRoute.LocalAddress, wantRoute.LocalAddress) } if gotRoute.RemoteAddress != wantRoute.RemoteAddress { return fmt.Errorf("bad remote address: got %s, want = %s", gotRoute.RemoteAddress, wantRoute.RemoteAddress) } - if gotRoute.RemoteLinkAddress != wantRoute.RemoteLinkAddress { - return fmt.Errorf("bad remote link address: got %s, want = %s", gotRoute.RemoteLinkAddress, wantRoute.RemoteLinkAddress) + if got, want := gotRoute.RemoteLinkAddress(), wantRoute.RemoteLinkAddress(); got != want { + return fmt.Errorf("bad remote link address: got %s, want = %s", got, want) } if gotRoute.NextHop != wantRoute.NextHop { return fmt.Errorf("bad next-hop address: got %s, want = %s", gotRoute.NextHop, wantRoute.NextHop) @@ -1597,7 +1602,10 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { if err != nil { t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) } - if err := verifyRoute(r, stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil { + var wantRoute stack.Route + wantRoute.LocalAddress = header.IPv4Any + wantRoute.RemoteAddress = header.IPv4Broadcast + if err := verifyRoute(r, &wantRoute); err != nil { t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) } @@ -1651,7 +1659,10 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { if err != nil { t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) } - if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + var wantRoute stack.Route + wantRoute.LocalAddress = nic1Addr.Address + wantRoute.RemoteAddress = header.IPv4Broadcast + if err := verifyRoute(r, &wantRoute); err != nil { t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) } @@ -1661,7 +1672,10 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { if err != nil { t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err) } - if err := verifyRoute(r, stack.Route{LocalAddress: nic2Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + wantRoute = stack.Route{} + wantRoute.LocalAddress = nic2Addr.Address + wantRoute.RemoteAddress = header.IPv4Broadcast + if err := verifyRoute(r, &wantRoute); err != nil { t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err) } @@ -1677,7 +1691,10 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { if err != nil { t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err) } - if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + wantRoute = stack.Route{} + wantRoute.LocalAddress = nic1Addr.Address + wantRoute.RemoteAddress = header.IPv4Broadcast + if err := verifyRoute(r, &wantRoute); err != nil { t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err) } } @@ -2214,88 +2231,6 @@ func TestNICStats(t *testing.T) { } } -func TestNICForwarding(t *testing.T) { - const nicID1 = 1 - const nicID2 = 2 - const dstAddr = tcpip.Address("\x03") - - tests := []struct { - name string - headerLen uint16 - }{ - { - name: "Zero header length", - }, - { - name: "Non-zero header length", - headerLen: 16, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) - s.SetForwarding(fakeNetNumber, true) - - ep1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID1, ep1); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) - } - if err := s.AddAddress(nicID1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress(%d, %d, 0x01): %s", nicID1, fakeNetNumber, err) - } - - ep2 := channelLinkWithHeaderLength{ - Endpoint: channel.New(10, defaultMTU, ""), - headerLength: test.headerLen, - } - if err := s.CreateNIC(nicID2, &ep2); err != nil { - t.Fatalf("CreateNIC(%d, _): %s", nicID2, err) - } - if err := s.AddAddress(nicID2, fakeNetNumber, "\x02"); err != nil { - t.Fatalf("AddAddress(%d, %d, 0x02): %s", nicID2, fakeNetNumber, err) - } - - // Route all packets to dstAddr to NIC 2. - { - subnet, err := tcpip.NewSubnet(dstAddr, "\xff") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicID2}}) - } - - // Send a packet to dstAddr. - buf := buffer.NewView(30) - buf[dstAddrOffset] = dstAddr[0] - ep1.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - pkt, ok := ep2.Read() - if !ok { - t.Fatal("packet not forwarded") - } - - // Test that the link's MaxHeaderLength is honoured. - if capacity, want := pkt.Pkt.AvailableHeaderBytes(), int(test.headerLen); capacity != want { - t.Errorf("got LinkHeader.AvailableLength() = %d, want = %d", capacity, want) - } - - // Test that forwarding increments Tx stats correctly. - if got, want := s.NICInfo()[nicID2].Stats.Tx.Packets.Value(), uint64(1); got != want { - t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want) - } - - if got, want := s.NICInfo()[nicID2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want { - t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) - } - }) - } -} - // TestNICContextPreservation tests that you can read out via stack.NICInfo the // Context data you pass via NICContext.Context in stack.CreateNICWithOptions. func TestNICContextPreservation(t *testing.T) { @@ -2483,9 +2418,9 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) { } opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - AutoGenIPv6LinkLocal: test.autoGen, - NDPDisp: &ndpDisp, - OpaqueIIDOpts: test.iidOpts, + AutoGenLinkLocal: test.autoGen, + NDPDisp: &ndpDisp, + OpaqueIIDOpts: test.iidOpts, })}, } @@ -2578,8 +2513,8 @@ func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) { t.Run(test.name, func(t *testing.T) { opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - AutoGenIPv6LinkLocal: true, - OpaqueIIDOpts: test.opaqueIIDOpts, + AutoGenLinkLocal: true, + OpaqueIIDOpts: test.opaqueIIDOpts, })}, } @@ -2612,9 +2547,9 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { ndpConfigs := ipv6.DefaultNDPConfigurations() opts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPConfigs: ndpConfigs, - AutoGenIPv6LinkLocal: true, - NDPDisp: &ndpDisp, + NDPConfigs: ndpConfigs, + AutoGenLinkLocal: true, + NDPDisp: &ndpDisp, })}, } @@ -2803,8 +2738,16 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") globalAddr1 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") globalAddr2 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - nicID = 1 - lifetimeSeconds = 9999 + globalAddr3 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") + ipv4MappedIPv6Addr1 = tcpip.Address("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x01") + ipv4MappedIPv6Addr2 = tcpip.Address("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x02") + toredoAddr1 = tcpip.Address("\x20\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + toredoAddr2 = tcpip.Address("\x20\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + ipv6ToIPv4Addr1 = tcpip.Address("\x20\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + ipv6ToIPv4Addr2 = tcpip.Address("\x20\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + + nicID = 1 + lifetimeSeconds = 9999 ) prefix1, _, stableGlobalAddr1 := prefixSubnetAddr(0, linkAddr1) @@ -2821,139 +2764,191 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { slaacPrefixForTempAddrBeforeNICAddrAdd tcpip.AddressWithPrefix nicAddrs []tcpip.Address slaacPrefixForTempAddrAfterNICAddrAdd tcpip.AddressWithPrefix - connectAddr tcpip.Address + remoteAddr tcpip.Address expectedLocalAddr tcpip.Address }{ - // Test Rule 1 of RFC 6724 section 5. + // Test Rule 1 of RFC 6724 section 5 (prefer same address). { name: "Same Global most preferred (last address)", - nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, - connectAddr: globalAddr1, + nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1}, + remoteAddr: globalAddr1, expectedLocalAddr: globalAddr1, }, { name: "Same Global most preferred (first address)", - nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1}, - connectAddr: globalAddr1, + nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1}, + remoteAddr: globalAddr1, expectedLocalAddr: globalAddr1, }, { name: "Same Link Local most preferred (last address)", - nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1}, - connectAddr: linkLocalAddr1, + nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1}, + remoteAddr: linkLocalAddr1, expectedLocalAddr: linkLocalAddr1, }, { name: "Same Link Local most preferred (first address)", - nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, - connectAddr: linkLocalAddr1, + nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1}, + remoteAddr: linkLocalAddr1, expectedLocalAddr: linkLocalAddr1, }, { name: "Same Unique Local most preferred (last address)", - nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, linkLocalAddr1}, - connectAddr: uniqueLocalAddr1, + nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1}, + remoteAddr: uniqueLocalAddr1, expectedLocalAddr: uniqueLocalAddr1, }, { name: "Same Unique Local most preferred (first address)", - nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1}, - connectAddr: uniqueLocalAddr1, + nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1}, + remoteAddr: uniqueLocalAddr1, expectedLocalAddr: uniqueLocalAddr1, }, - // Test Rule 2 of RFC 6724 section 5. + // Test Rule 2 of RFC 6724 section 5 (prefer appropriate scope). { name: "Global most preferred (last address)", - nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, - connectAddr: globalAddr2, + nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1}, + remoteAddr: globalAddr2, expectedLocalAddr: globalAddr1, }, { name: "Global most preferred (first address)", - nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1}, - connectAddr: globalAddr2, + nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1}, + remoteAddr: globalAddr2, expectedLocalAddr: globalAddr1, }, { name: "Link Local most preferred (last address)", - nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1}, - connectAddr: linkLocalAddr2, + nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1}, + remoteAddr: linkLocalAddr2, expectedLocalAddr: linkLocalAddr1, }, { name: "Link Local most preferred (first address)", - nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, - connectAddr: linkLocalAddr2, + nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1}, + remoteAddr: linkLocalAddr2, expectedLocalAddr: linkLocalAddr1, }, { name: "Link Local most preferred for link local multicast (last address)", - nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1}, - connectAddr: linkLocalMulticastAddr, + nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1}, + remoteAddr: linkLocalMulticastAddr, expectedLocalAddr: linkLocalAddr1, }, { name: "Link Local most preferred for link local multicast (first address)", - nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, - connectAddr: linkLocalMulticastAddr, + nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1}, + remoteAddr: linkLocalMulticastAddr, expectedLocalAddr: linkLocalAddr1, }, + + // Test Rule 6 of 6724 section 5 (prefer matching label). { name: "Unique Local most preferred (last address)", - nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, linkLocalAddr1}, - connectAddr: uniqueLocalAddr2, + nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, ipv4MappedIPv6Addr1, toredoAddr1, ipv6ToIPv4Addr1}, + remoteAddr: uniqueLocalAddr2, expectedLocalAddr: uniqueLocalAddr1, }, { name: "Unique Local most preferred (first address)", - nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1}, - connectAddr: uniqueLocalAddr2, + nicAddrs: []tcpip.Address{globalAddr1, ipv4MappedIPv6Addr1, toredoAddr1, ipv6ToIPv4Addr1, uniqueLocalAddr1}, + remoteAddr: uniqueLocalAddr2, expectedLocalAddr: uniqueLocalAddr1, }, + { + name: "Toredo most preferred (first address)", + nicAddrs: []tcpip.Address{toredoAddr1, uniqueLocalAddr1, globalAddr1, ipv4MappedIPv6Addr1, ipv6ToIPv4Addr1}, + remoteAddr: toredoAddr2, + expectedLocalAddr: toredoAddr1, + }, + { + name: "Toredo most preferred (last address)", + nicAddrs: []tcpip.Address{globalAddr1, ipv4MappedIPv6Addr1, ipv6ToIPv4Addr1, uniqueLocalAddr1, toredoAddr1}, + remoteAddr: toredoAddr2, + expectedLocalAddr: toredoAddr1, + }, + { + name: "6To4 most preferred (first address)", + nicAddrs: []tcpip.Address{ipv6ToIPv4Addr1, toredoAddr1, uniqueLocalAddr1, globalAddr1, ipv4MappedIPv6Addr1}, + remoteAddr: ipv6ToIPv4Addr2, + expectedLocalAddr: ipv6ToIPv4Addr1, + }, + { + name: "6To4 most preferred (last address)", + nicAddrs: []tcpip.Address{globalAddr1, ipv4MappedIPv6Addr1, uniqueLocalAddr1, toredoAddr1, ipv6ToIPv4Addr1}, + remoteAddr: ipv6ToIPv4Addr2, + expectedLocalAddr: ipv6ToIPv4Addr1, + }, + { + name: "IPv4 mapped IPv6 most preferred (first address)", + nicAddrs: []tcpip.Address{ipv4MappedIPv6Addr1, ipv6ToIPv4Addr1, toredoAddr1, uniqueLocalAddr1, globalAddr1}, + remoteAddr: ipv4MappedIPv6Addr2, + expectedLocalAddr: ipv4MappedIPv6Addr1, + }, + { + name: "IPv4 mapped IPv6 most preferred (last address)", + nicAddrs: []tcpip.Address{globalAddr1, ipv6ToIPv4Addr1, uniqueLocalAddr1, toredoAddr1, ipv4MappedIPv6Addr1}, + remoteAddr: ipv4MappedIPv6Addr2, + expectedLocalAddr: ipv4MappedIPv6Addr1, + }, - // Test Rule 7 of RFC 6724 section 5. + // Test Rule 7 of RFC 6724 section 5 (prefer temporary addresses). { name: "Temp Global most preferred (last address)", slaacPrefixForTempAddrBeforeNICAddrAdd: prefix1, nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, - connectAddr: globalAddr2, + remoteAddr: globalAddr2, expectedLocalAddr: tempGlobalAddr1, }, { name: "Temp Global most preferred (first address)", nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, slaacPrefixForTempAddrAfterNICAddrAdd: prefix1, - connectAddr: globalAddr2, + remoteAddr: globalAddr2, expectedLocalAddr: tempGlobalAddr1, }, + // Test Rule 8 of RFC 6724 section 5 (use longest matching prefix). + { + name: "Longest prefix matched most preferred (first address)", + nicAddrs: []tcpip.Address{globalAddr2, globalAddr1}, + remoteAddr: globalAddr3, + expectedLocalAddr: globalAddr2, + }, + { + name: "Longest prefix matched most preferred (last address)", + nicAddrs: []tcpip.Address{globalAddr1, globalAddr2}, + remoteAddr: globalAddr3, + expectedLocalAddr: globalAddr2, + }, + // Test returning the endpoint that is closest to the front when // candidate addresses are "equal" from the perspective of RFC 6724 // section 5. { name: "Unique Local for Global", nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, uniqueLocalAddr2}, - connectAddr: globalAddr2, + remoteAddr: globalAddr2, expectedLocalAddr: uniqueLocalAddr1, }, { name: "Link Local for Global", nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2}, - connectAddr: globalAddr2, + remoteAddr: globalAddr2, expectedLocalAddr: linkLocalAddr1, }, { name: "Link Local for Unique Local", nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2}, - connectAddr: uniqueLocalAddr2, + remoteAddr: uniqueLocalAddr2, expectedLocalAddr: linkLocalAddr1, }, { name: "Temp Global for Global", slaacPrefixForTempAddrBeforeNICAddrAdd: prefix1, slaacPrefixForTempAddrAfterNICAddrAdd: prefix2, - connectAddr: globalAddr1, + remoteAddr: globalAddr1, expectedLocalAddr: tempGlobalAddr2, }, } @@ -2975,12 +2970,6 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv6EmptySubnet, - Gateway: llAddr3, - NIC: nicID, - }}) - s.AddLinkAddress(nicID, llAddr3, linkAddr3) if test.slaacPrefixForTempAddrBeforeNICAddrAdd != (tcpip.AddressWithPrefix{}) { e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, test.slaacPrefixForTempAddrBeforeNICAddrAdd, true, true, lifetimeSeconds, lifetimeSeconds)) @@ -3000,7 +2989,23 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { t.FailNow() } - if got := addrForNewConnectionTo(t, s, tcpip.FullAddress{Addr: test.connectAddr, NIC: nicID, Port: 1234}); got != test.expectedLocalAddr { + netEP, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) + } + + addressableEndpoint, ok := netEP.(stack.AddressableEndpoint) + if !ok { + t.Fatal("network endpoint is not addressable") + } + + addressEP := addressableEndpoint.AcquireOutgoingPrimaryAddress(test.remoteAddr, false /* allowExpired */) + if addressEP == nil { + t.Fatal("expected a non-nil address endpoint") + } + defer addressEP.DecRef() + + if got := addressEP.AddressWithPrefix().Address; got != test.expectedLocalAddr { t.Errorf("got local address = %s, want = %s", got, test.expectedLocalAddr) } }) @@ -3427,11 +3432,16 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { remNetSubnetBcast := remNetSubnet.Broadcast() tests := []struct { - name string - nicAddr tcpip.ProtocolAddress - routes []tcpip.Route - remoteAddr tcpip.Address - expectedRoute stack.Route + name string + nicAddr tcpip.ProtocolAddress + routes []tcpip.Route + remoteAddr tcpip.Address + expectedLocalAddress tcpip.Address + expectedRemoteAddress tcpip.Address + expectedRemoteLinkAddress tcpip.LinkAddress + expectedNextHop tcpip.Address + expectedNetProto tcpip.NetworkProtocolNumber + expectedLoop stack.PacketLooping }{ // Broadcast to a locally attached subnet populates the broadcast MAC. { @@ -3446,14 +3456,12 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { NIC: nicID1, }, }, - remoteAddr: ipv4SubnetBcast, - expectedRoute: stack.Route{ - LocalAddress: ipv4Addr.Address, - RemoteAddress: ipv4SubnetBcast, - RemoteLinkAddress: header.EthernetBroadcastAddress, - NetProto: header.IPv4ProtocolNumber, - Loop: stack.PacketOut | stack.PacketLoop, - }, + remoteAddr: ipv4SubnetBcast, + expectedLocalAddress: ipv4Addr.Address, + expectedRemoteAddress: ipv4SubnetBcast, + expectedRemoteLinkAddress: header.EthernetBroadcastAddress, + expectedNetProto: header.IPv4ProtocolNumber, + expectedLoop: stack.PacketOut | stack.PacketLoop, }, // Broadcast to a locally attached /31 subnet does not populate the // broadcast MAC. @@ -3469,13 +3477,11 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { NIC: nicID1, }, }, - remoteAddr: ipv4Subnet31Bcast, - expectedRoute: stack.Route{ - LocalAddress: ipv4AddrPrefix31.Address, - RemoteAddress: ipv4Subnet31Bcast, - NetProto: header.IPv4ProtocolNumber, - Loop: stack.PacketOut, - }, + remoteAddr: ipv4Subnet31Bcast, + expectedLocalAddress: ipv4AddrPrefix31.Address, + expectedRemoteAddress: ipv4Subnet31Bcast, + expectedNetProto: header.IPv4ProtocolNumber, + expectedLoop: stack.PacketOut, }, // Broadcast to a locally attached /32 subnet does not populate the // broadcast MAC. @@ -3491,13 +3497,11 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { NIC: nicID1, }, }, - remoteAddr: ipv4Subnet32Bcast, - expectedRoute: stack.Route{ - LocalAddress: ipv4AddrPrefix32.Address, - RemoteAddress: ipv4Subnet32Bcast, - NetProto: header.IPv4ProtocolNumber, - Loop: stack.PacketOut, - }, + remoteAddr: ipv4Subnet32Bcast, + expectedLocalAddress: ipv4AddrPrefix32.Address, + expectedRemoteAddress: ipv4Subnet32Bcast, + expectedNetProto: header.IPv4ProtocolNumber, + expectedLoop: stack.PacketOut, }, // IPv6 has no notion of a broadcast. { @@ -3512,13 +3516,11 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { NIC: nicID1, }, }, - remoteAddr: ipv6SubnetBcast, - expectedRoute: stack.Route{ - LocalAddress: ipv6Addr.Address, - RemoteAddress: ipv6SubnetBcast, - NetProto: header.IPv6ProtocolNumber, - Loop: stack.PacketOut, - }, + remoteAddr: ipv6SubnetBcast, + expectedLocalAddress: ipv6Addr.Address, + expectedRemoteAddress: ipv6SubnetBcast, + expectedNetProto: header.IPv6ProtocolNumber, + expectedLoop: stack.PacketOut, }, // Broadcast to a remote subnet in the route table is send to the next-hop // gateway. @@ -3535,14 +3537,12 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { NIC: nicID1, }, }, - remoteAddr: remNetSubnetBcast, - expectedRoute: stack.Route{ - LocalAddress: ipv4Addr.Address, - RemoteAddress: remNetSubnetBcast, - NextHop: ipv4Gateway, - NetProto: header.IPv4ProtocolNumber, - Loop: stack.PacketOut, - }, + remoteAddr: remNetSubnetBcast, + expectedLocalAddress: ipv4Addr.Address, + expectedRemoteAddress: remNetSubnetBcast, + expectedNextHop: ipv4Gateway, + expectedNetProto: header.IPv4ProtocolNumber, + expectedLoop: stack.PacketOut, }, // Broadcast to an unknown subnet follows the default route. Note that this // is essentially just routing an unknown destination IP, because w/o any @@ -3560,14 +3560,12 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { NIC: nicID1, }, }, - remoteAddr: remNetSubnetBcast, - expectedRoute: stack.Route{ - LocalAddress: ipv4Addr.Address, - RemoteAddress: remNetSubnetBcast, - NextHop: ipv4Gateway, - NetProto: header.IPv4ProtocolNumber, - Loop: stack.PacketOut, - }, + remoteAddr: remNetSubnetBcast, + expectedLocalAddress: ipv4Addr.Address, + expectedRemoteAddress: remNetSubnetBcast, + expectedNextHop: ipv4Gateway, + expectedNetProto: header.IPv4ProtocolNumber, + expectedLoop: stack.PacketOut, }, } @@ -3596,10 +3594,27 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { t.Fatalf("got unexpected address length = %d bytes", l) } - if r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, test.remoteAddr, netProto, false /* multicastLoop */); err != nil { + r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, test.remoteAddr, netProto, false /* multicastLoop */) + if err != nil { t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, test.remoteAddr, netProto, err) - } else if diff := cmp.Diff(r, test.expectedRoute, cmpopts.IgnoreUnexported(r)); diff != "" { - t.Errorf("route mismatch (-want +got):\n%s", diff) + } + if r.LocalAddress != test.expectedLocalAddress { + t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, test.expectedLocalAddress) + } + if r.RemoteAddress != test.expectedRemoteAddress { + t.Errorf("got r.RemoteAddress = %s, want = %s", r.RemoteAddress, test.expectedRemoteAddress) + } + if got := r.RemoteLinkAddress(); got != test.expectedRemoteLinkAddress { + t.Errorf("got r.RemoteLinkAddress() = %s, want = %s", got, test.expectedRemoteLinkAddress) + } + if r.NextHop != test.expectedNextHop { + t.Errorf("got r.NextHop = %s, want = %s", r.NextHop, test.expectedNextHop) + } + if r.NetProto != test.expectedNetProto { + t.Errorf("got r.NetProto = %d, want = %d", r.NetProto, test.expectedNetProto) + } + if r.Loop != test.expectedLoop { + t.Errorf("got r.Loop = %x, want = %x", r.Loop, test.expectedLoop) } }) } @@ -4167,10 +4182,12 @@ func TestFindRouteWithForwarding(t *testing.T) { s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}}) r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */) + if r != nil { + defer r.Release() + } if err != test.findRouteErr { t.Fatalf("FindRoute(%d, %s, %s, %d, false) = %s, want = %s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, err, test.findRouteErr) } - defer r.Release() if test.findRouteErr != nil { return @@ -4228,3 +4245,63 @@ func TestFindRouteWithForwarding(t *testing.T) { }) } } + +func TestWritePacketToRemote(t *testing.T) { + const nicID = 1 + const MTU = 1280 + e := channel.New(1, MTU, linkAddr1) + s := stack.New(stack.Options{}) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("CreateNIC(%d) = %s", nicID, err) + } + tests := []struct { + name string + protocol tcpip.NetworkProtocolNumber + payload []byte + }{ + { + name: "SuccessIPv4", + protocol: header.IPv4ProtocolNumber, + payload: []byte{1, 2, 3, 4}, + }, + { + name: "SuccessIPv6", + protocol: header.IPv6ProtocolNumber, + payload: []byte{5, 6, 7, 8}, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if err := s.WritePacketToRemote(nicID, linkAddr2, test.protocol, buffer.View(test.payload).ToVectorisedView()); err != nil { + t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s", err) + } + + pkt, ok := e.Read() + if got, want := ok, true; got != want { + t.Fatalf("e.Read() = %t, want %t", got, want) + } + if got, want := pkt.Proto, test.protocol; got != want { + t.Fatalf("pkt.Proto = %d, want %d", got, want) + } + if pkt.Route.RemoteLinkAddress != linkAddr2 { + t.Fatalf("pkt.Route.RemoteAddress = %s, want %s", pkt.Route.RemoteLinkAddress, linkAddr2) + } + if diff := cmp.Diff(pkt.Pkt.Data.ToView(), buffer.View(test.payload)); diff != "" { + t.Errorf("pkt.Pkt.Data mismatch (-want +got):\n%s", diff) + } + }) + } + + t.Run("InvalidNICID", func(t *testing.T) { + if got, want := s.WritePacketToRemote(234, linkAddr2, header.IPv4ProtocolNumber, buffer.View([]byte{1}).ToVectorisedView()), tcpip.ErrUnknownDevice; got != want { + t.Fatalf("s.WritePacketToRemote(_, _, _, _) = %s, want = %s", got, want) + } + pkt, ok := e.Read() + if got, want := ok, false; got != want { + t.Fatalf("e.Read() = %t, %v; want %t", got, pkt, want) + } + }) +} diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index f183ec6e4..07b2818d2 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -182,7 +182,8 @@ func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *Packet epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. } -// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. +// handleControlPacket delivers a control packet to the transport endpoint +// identified by id. func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt *PacketBuffer) { epsByNIC.mu.RLock() defer epsByNIC.mu.RUnlock() @@ -199,7 +200,7 @@ func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpoint // broadcast like we are doing with handlePacket above? // multiPortEndpoints are guaranteed to have at least one element. - selectEndpoint(id, mpep, epsByNIC.seed).HandleControlPacket(id, typ, extra, pkt) + selectEndpoint(id, mpep, epsByNIC.seed).HandleControlPacket(typ, extra, pkt) } // registerEndpoint returns true if it succeeds. It fails and returns diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 41a8e5ad0..859278f0b 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -15,6 +15,7 @@ package stack_test import ( + "io/ioutil" "math" "math/rand" "testing" @@ -141,11 +142,11 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI // Initialize the IP header. ip := header.IPv6(buf) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(header.UDPMinimumSize + len(payload)), - NextHeader: uint8(udp.ProtocolNumber), - HopLimit: 65, - SrcAddr: testSrcAddrV6, - DstAddr: testDstAddrV6, + PayloadLength: uint16(header.UDPMinimumSize + len(payload)), + TransportProtocol: udp.ProtocolNumber, + HopLimit: 65, + SrcAddr: testSrcAddrV6, + DstAddr: testDstAddrV6, }) // Initialize the UDP header. @@ -307,12 +308,9 @@ func TestBindToDeviceDistribution(t *testing.T) { }(ep) defer ep.Close() - if err := ep.SetSockOptBool(tcpip.ReusePortOption, endpoint.reuse); err != nil { - t.Fatalf("SetSockOptBool(ReusePortOption, %t) on endpoint %d failed: %s", endpoint.reuse, i, err) - } - bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice) - if err := ep.SetSockOpt(&bindToDeviceOption); err != nil { - t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", bindToDeviceOption, bindToDeviceOption, i, err) + ep.SocketOptions().SetReusePort(endpoint.reuse) + if err := ep.SocketOptions().SetBindToDevice(int32(endpoint.bindToDevice)); err != nil { + t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", endpoint.bindToDevice, endpoint.bindToDevice, i, err) } var dstAddr tcpip.Address @@ -354,7 +352,7 @@ func TestBindToDeviceDistribution(t *testing.T) { } ep := <-pollChannel - if _, _, err := ep.Read(nil); err != nil { + if _, err := ep.Read(ioutil.Discard, math.MaxUint16, tcpip.ReadOptions{}); err != nil { t.Fatalf("Read on endpoint %d failed: %s", eps[ep], err) } stats[ep]++ diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index c457b67a2..0ff32c6ea 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -15,12 +15,12 @@ package stack_test import ( + "io" "testing" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" @@ -39,14 +39,18 @@ const ( // use it. type fakeTransportEndpoint struct { stack.TransportEndpointInfo + tcpip.DefaultSocketOptionsHandler proto *fakeTransportProtocol peerAddr tcpip.Address - route stack.Route + route *stack.Route uniqueID uint64 // acceptQueue is non-nil iff bound. - acceptQueue []fakeTransportEndpoint + acceptQueue []*fakeTransportEndpoint + + // ops is used to set and get socket options. + ops tcpip.SocketOptions } func (f *fakeTransportEndpoint) Info() tcpip.EndpointInfo { @@ -59,8 +63,14 @@ func (*fakeTransportEndpoint) Stats() tcpip.EndpointStats { func (*fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {} +func (f *fakeTransportEndpoint) SocketOptions() *tcpip.SocketOptions { + return &f.ops +} + func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint { - return &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} + ep := &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} + ep.ops.InitHandler(ep) + return ep } func (f *fakeTransportEndpoint) Abort() { @@ -68,6 +78,7 @@ func (f *fakeTransportEndpoint) Abort() { } func (f *fakeTransportEndpoint) Close() { + // TODO(gvisor.dev/issue/5153): Consider retaining the route. f.route.Release() } @@ -75,8 +86,8 @@ func (*fakeTransportEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask return mask } -func (*fakeTransportEndpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { - return buffer.View{}, tcpip.ControlMessages{}, nil +func (*fakeTransportEndpoint) Read(io.Writer, int, tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { + return tcpip.ReadResult{}, nil } func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { @@ -100,30 +111,16 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions return int64(len(v)), nil, nil } -func (*fakeTransportEndpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { - return 0, tcpip.ControlMessages{}, nil -} - // SetSockOpt sets a socket option. Currently not supported. func (*fakeTransportEndpoint) SetSockOpt(tcpip.SettableSocketOption) *tcpip.Error { return tcpip.ErrInvalidEndpointState } -// SetSockOptBool sets a socket option. Currently not supported. -func (*fakeTransportEndpoint) SetSockOptBool(tcpip.SockOptBool, bool) *tcpip.Error { - return tcpip.ErrInvalidEndpointState -} - // SetSockOptInt sets a socket option. Currently not supported. func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOptInt, int) *tcpip.Error { return tcpip.ErrInvalidEndpointState } -// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. -func (*fakeTransportEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - return false, tcpip.ErrUnknownProtocolOption -} - // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { return -1, tcpip.ErrUnknownProtocolOption @@ -147,16 +144,16 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { if err != nil { return tcpip.ErrNoRoute } - defer r.Release() // Try to register so that we can start receiving packets. f.ID.RemoteAddress = addr.Addr err = f.proto.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */) if err != nil { + r.Release() return err } - f.route = r.Clone() + f.route = r return nil } @@ -186,7 +183,7 @@ func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *wai } a := f.acceptQueue[0] f.acceptQueue = f.acceptQueue[1:] - return &a, nil, nil + return a, nil, nil } func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error { @@ -201,7 +198,7 @@ func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error { ); err != nil { return err } - f.acceptQueue = []fakeTransportEndpoint{} + f.acceptQueue = []*fakeTransportEndpoint{} return nil } @@ -227,7 +224,7 @@ func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt * } route.ResolveWith(pkt.SourceLinkAddress()) - f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{ + ep := &fakeTransportEndpoint{ TransportEndpointInfo: stack.TransportEndpointInfo{ ID: f.ID, NetProto: f.NetProto, @@ -235,10 +232,12 @@ func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt * proto: f.proto, peerAddr: route.RemoteAddress, route: route, - }) + } + ep.ops.InitHandler(ep) + f.acceptQueue = append(f.acceptQueue, ep) } -func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, *stack.PacketBuffer) { +func (f *fakeTransportEndpoint) HandleControlPacket(stack.ControlType, uint32, *stack.PacketBuffer) { // Increment the number of received control packets. f.proto.controlCount++ } @@ -553,87 +552,3 @@ func TestTransportOptions(t *testing.T) { t.Fatalf("got tcpip.TCPModerateReceiveBufferOption = false, want = true") } } - -func TestTransportForwarding(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - TransportProtocols: []stack.TransportProtocolFactory{fakeTransFactory}, - }) - s.SetForwarding(fakeNetNumber, true) - - // TODO(b/123449044): Change this to a channel NIC. - ep1 := loopback.New() - if err := s.CreateNIC(1, ep1); err != nil { - t.Fatalf("CreateNIC #1 failed: %v", err) - } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { - t.Fatalf("AddAddress #1 failed: %v", err) - } - - ep2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, ep2); err != nil { - t.Fatalf("CreateNIC #2 failed: %v", err) - } - if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { - t.Fatalf("AddAddress #2 failed: %v", err) - } - - // Route all packets to address 3 to NIC 2 and all packets to address - // 1 to NIC 1. - { - subnet0, err := tcpip.NewSubnet("\x03", "\xff") - if err != nil { - t.Fatal(err) - } - subnet1, err := tcpip.NewSubnet("\x01", "\xff") - if err != nil { - t.Fatal(err) - } - s.SetRouteTable([]tcpip.Route{ - {Destination: subnet0, Gateway: "\x00", NIC: 2}, - {Destination: subnet1, Gateway: "\x00", NIC: 1}, - }) - } - - wq := waiter.Queue{} - ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq) - if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) - } - - if err := ep.Bind(tcpip.FullAddress{Addr: "\x01", NIC: 1}); err != nil { - t.Fatalf("Bind failed: %v", err) - } - - // Send a packet to address 1 from address 3. - req := buffer.NewView(30) - req[0] = 1 - req[1] = 3 - req[2] = byte(fakeTransNumber) - ep2.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: req.ToVectorisedView(), - })) - - aep, _, err := ep.Accept(nil) - if err != nil || aep == nil { - t.Fatalf("Accept failed: %v, %v", aep, err) - } - - resp := buffer.NewView(30) - if _, _, err := aep.Write(tcpip.SlicePayload(resp), tcpip.WriteOptions{}); err != nil { - t.Fatalf("Write failed: %v", err) - } - - p, ok := ep2.Read() - if !ok { - t.Fatal("Response packet not forwarded") - } - - nh := stack.PayloadSince(p.Pkt.NetworkHeader()) - if dst := nh[0]; dst != 3 { - t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) - } - if src := nh[1]; src != 1 { - t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) - } -} diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 3ab2b7654..f798056c0 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -31,6 +31,7 @@ package tcpip import ( "errors" "fmt" + "io" "math/bits" "reflect" "strconv" @@ -39,7 +40,6 @@ import ( "time" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/waiter" ) @@ -49,8 +49,9 @@ const ipv4AddressSize = 4 // Error represents an error in the netstack error space. Using a special type // ensures that errors outside of this space are not accidentally introduced. // -// Note: to support save / restore, it is important that all tcpip errors have -// distinct error messages. +// All errors must have unique msg strings. +// +// +stateify savable type Error struct { msg string @@ -112,6 +113,7 @@ var ( ErrNotPermitted = &Error{msg: "operation not permitted"} ErrAddressFamilyNotSupported = &Error{msg: "address family not supported by protocol"} ErrMalformedHeader = &Error{msg: "header is malformed"} + ErrBadBuffer = &Error{msg: "bad buffer"} ) var messageToError map[string]*Error @@ -161,6 +163,7 @@ func StringToError(s string) *Error { ErrNotPermitted, ErrAddressFamilyNotSupported, ErrMalformedHeader, + ErrBadBuffer, } messageToError = make(map[string]*Error) @@ -247,6 +250,54 @@ func (a Address) WithPrefix() AddressWithPrefix { } } +// Unspecified returns true if the address is unspecified. +func (a Address) Unspecified() bool { + for _, b := range a { + if b != 0 { + return false + } + } + return true +} + +// MatchingPrefix returns the matching prefix length in bits. +// +// Panics if b and a have different lengths. +func (a Address) MatchingPrefix(b Address) uint8 { + const bitsInAByte = 8 + + if len(a) != len(b) { + panic(fmt.Sprintf("addresses %s and %s do not have the same length", a, b)) + } + + var prefix uint8 + for i := range a { + aByte := a[i] + bByte := b[i] + + if aByte == bByte { + prefix += bitsInAByte + continue + } + + // Count the remaining matching bits in the byte from MSbit to LSBbit. + mask := uint8(1) << (bitsInAByte - 1) + for { + if aByte&mask == bByte&mask { + prefix++ + mask >>= 1 + continue + } + + break + } + + break + } + + return prefix +} + // AddressMask is a bitmask for an address. type AddressMask string @@ -447,6 +498,21 @@ func (s SlicePayload) Payload(size int) ([]byte, *Error) { return s[:size], nil } +var _ io.Writer = (*SliceWriter)(nil) + +// SliceWriter implements io.Writer for slices. +type SliceWriter []byte + +// Write implements io.Writer.Write. +func (s *SliceWriter) Write(b []byte) (int, error) { + n := copy(*s, b) + *s = (*s)[n:] + if n < len(b) { + return n, io.ErrShortWrite + } + return n, nil +} + // A ControlMessages contains socket control messages for IP sockets. // // +stateify savable @@ -481,6 +547,17 @@ type ControlMessages struct { // PacketInfo holds interface and address data on an incoming packet. PacketInfo IPPacketInfo + + // HasOriginalDestinationAddress indicates whether OriginalDstAddress is + // set. + HasOriginalDstAddress bool + + // OriginalDestinationAddress holds the original destination address + // and port of the incoming packet. + OriginalDstAddress FullAddress + + // SockErr is the dequeued socket error on recvmsg(MSG_ERRQUEUE). + SockErr *SockError } // PacketOwner is used to get UID and GID of the packet. @@ -492,6 +569,40 @@ type PacketOwner interface { GID() uint32 } +// ReadOptions contains options for Endpoint.Read. +type ReadOptions struct { + // Peek indicates whether this read is a peek. + Peek bool + + // NeedRemoteAddr indicates whether to return the remote address, if + // supported. + NeedRemoteAddr bool + + // NeedLinkPacketInfo indicates whether to return the link-layer information, + // if supported. + NeedLinkPacketInfo bool +} + +// ReadResult represents result for a successful Endpoint.Read. +type ReadResult struct { + // Count is the number of bytes received and written to the buffer. + Count int + + // Total is the number of bytes of the received packet. This can be used to + // determine whether the read is truncated. + Total int + + // ControlMessages is the control messages received. + ControlMessages ControlMessages + + // RemoteAddr is the remote address if ReadOptions.NeedAddr is true. + RemoteAddr FullAddress + + // LinkPacketInfo is the link-layer information of the received packet if + // ReadOptions.NeedLinkPacketInfo is true. + LinkPacketInfo LinkPacketInfo +} + // Endpoint is the interface implemented by transport protocols (e.g., tcp, udp) // that exposes functionality like read, write, connect, etc. to users of the // networking stack. @@ -506,11 +617,15 @@ type Endpoint interface { // Abort is best effort; implementing Abort with Close is acceptable. Abort() - // Read reads data from the endpoint and optionally returns the sender. + // Read reads data from the endpoint and optionally writes to dst. + // + // This method does not block if there is no data pending; in this case, + // ErrWouldBlock is returned. // - // This method does not block if there is no data pending. It will also - // either return an error or data, never both. - Read(*FullAddress) (buffer.View, ControlMessages, *Error) + // If non-zero number of bytes are successfully read and written to dst, err + // must be nil. Otherwise, if dst failed to write anything, ErrBadBuffer + // should be returned. + Read(dst io.Writer, count int, opts ReadOptions) (res ReadResult, err *Error) // Write writes data to the endpoint's peer. This method does not block if // the data cannot be written. @@ -532,11 +647,6 @@ type Endpoint interface { // not). The channel is only non-nil in this case. Write(Payloader, WriteOptions) (int64, <-chan struct{}, *Error) - // Peek reads data without consuming it from the endpoint. - // - // This method does not block if there is no data pending. - Peek([][]byte) (int64, ControlMessages, *Error) - // Connect connects the endpoint to its peer. Specifying a NIC is // optional. // @@ -593,10 +703,6 @@ type Endpoint interface { // SetSockOpt sets a socket option. SetSockOpt(opt SettableSocketOption) *Error - // SetSockOptBool sets a socket option, for simple cases where a value - // has the bool type. - SetSockOptBool(opt SockOptBool, v bool) *Error - // SetSockOptInt sets a socket option, for simple cases where a value // has the int type. SetSockOptInt(opt SockOptInt, v int) *Error @@ -604,10 +710,6 @@ type Endpoint interface { // GetSockOpt gets a socket option. GetSockOpt(opt GettableSocketOption) *Error - // GetSockOptBool gets a socket option for simple cases where a return - // value has the bool type. - GetSockOptBool(SockOptBool) (bool, *Error) - // GetSockOptInt gets a socket option for simple cases where a return // value has the int type. GetSockOptInt(SockOptInt) (int, *Error) @@ -634,6 +736,10 @@ type Endpoint interface { // LastError clears and returns the last error reported by the endpoint. LastError() *Error + + // SocketOptions returns the structure which contains all the socket + // level options. + SocketOptions() *SocketOptions } // LinkPacketInfo holds Link layer information for a received packet. @@ -647,17 +753,6 @@ type LinkPacketInfo struct { PktType PacketType } -// PacketEndpoint are additional methods that are only implemented by Packet -// endpoints. -type PacketEndpoint interface { - // ReadPacket reads a datagram/packet from the endpoint and optionally - // returns the sender and additional LinkPacketInfo. - // - // This method does not block if there is no data pending. It will also - // either return an error or data, never both. - ReadPacket(*FullAddress, *LinkPacketInfo) (buffer.View, ControlMessages, *Error) -} - // EndpointInfo is the interface implemented by each endpoint info struct. type EndpointInfo interface { // IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo @@ -690,84 +785,6 @@ type WriteOptions struct { Atomic bool } -// SockOptBool represents socket options which values have the bool type. -type SockOptBool int - -const ( - // BroadcastOption is used by SetSockOptBool/GetSockOptBool to specify - // whether datagram sockets are allowed to send packets to a broadcast - // address. - BroadcastOption SockOptBool = iota - - // CorkOption is used by SetSockOptBool/GetSockOptBool to specify if - // data should be held until segments are full by the TCP transport - // protocol. - CorkOption - - // DelayOption is used by SetSockOptBool/GetSockOptBool to specify if - // data should be sent out immediately by the transport protocol. For - // TCP, it determines if the Nagle algorithm is on or off. - DelayOption - - // KeepaliveEnabledOption is used by SetSockOptBool/GetSockOptBool to - // specify whether TCP keepalive is enabled for this socket. - KeepaliveEnabledOption - - // MulticastLoopOption is used by SetSockOptBool/GetSockOptBool to - // specify whether multicast packets sent over a non-loopback interface - // will be looped back. - MulticastLoopOption - - // NoChecksumOption is used by SetSockOptBool/GetSockOptBool to specify - // whether UDP checksum is disabled for this socket. - NoChecksumOption - - // PasscredOption is used by SetSockOptBool/GetSockOptBool to specify - // whether SCM_CREDENTIALS socket control messages are enabled. - // - // Only supported on Unix sockets. - PasscredOption - - // QuickAckOption is stubbed out in SetSockOptBool/GetSockOptBool. - QuickAckOption - - // ReceiveTClassOption is used by SetSockOptBool/GetSockOptBool to - // specify if the IPV6_TCLASS ancillary message is passed with incoming - // packets. - ReceiveTClassOption - - // ReceiveTOSOption is used by SetSockOptBool/GetSockOptBool to specify - // if the TOS ancillary message is passed with incoming packets. - ReceiveTOSOption - - // ReceiveIPPacketInfoOption is used by SetSockOptBool/GetSockOptBool to - // specify if more inforamtion is provided with incoming packets such as - // interface index and address. - ReceiveIPPacketInfoOption - - // ReuseAddressOption is used by SetSockOptBool/GetSockOptBool to - // specify whether Bind() should allow reuse of local address. - ReuseAddressOption - - // ReusePortOption is used by SetSockOptBool/GetSockOptBool to permit - // multiple sockets to be bound to an identical socket address. - ReusePortOption - - // V6OnlyOption is used by SetSockOptBool/GetSockOptBool to specify - // whether an IPv6 socket is to be restricted to sending and receiving - // IPv6 packets only. - V6OnlyOption - - // IPHdrIncludedOption is used by SetSockOpt to indicate for a raw - // endpoint that all packets being written have an IP header and the - // endpoint should not attach an IP header. - IPHdrIncludedOption - - // AcceptConnOption is used by GetSockOptBool to indicate if the - // socket is a listening socket. - AcceptConnOption -) - // SockOptInt represents socket options which values have the int type. type SockOptInt int @@ -977,14 +994,6 @@ type SettableSocketOption interface { isSettableSocketOption() } -// BindToDeviceOption is used by SetSockOpt/GetSockOpt to specify that sockets -// should bind only on a specific NIC. -type BindToDeviceOption NICID - -func (*BindToDeviceOption) isGettableSocketOption() {} - -func (*BindToDeviceOption) isSettableSocketOption() {} - // TCPInfoOption is used by GetSockOpt to expose TCP statistics. // // TODO(b/64800844): Add and populate stat fields. @@ -1159,14 +1168,6 @@ type RemoveMembershipOption MembershipOption func (*RemoveMembershipOption) isSettableSocketOption() {} -// OutOfBandInlineOption is used by SetSockOpt/GetSockOpt to specify whether -// TCP out-of-band data is delivered along with the normal in-band data. -type OutOfBandInlineOption int - -func (*OutOfBandInlineOption) isGettableSocketOption() {} - -func (*OutOfBandInlineOption) isSettableSocketOption() {} - // SocketDetachFilterOption is used by SetSockOpt to detach a previously attached // classic BPF filter on a given endpoint. type SocketDetachFilterOption int @@ -1216,10 +1217,6 @@ type LingerOption struct { Timeout time.Duration } -func (*LingerOption) isGettableSocketOption() {} - -func (*LingerOption) isSettableSocketOption() {} - // IPPacketInfo is the message structure for IP_PKTINFO. // // +stateify savable @@ -1390,6 +1387,18 @@ type ICMPv6PacketStats struct { // RedirectMsg is the total number of ICMPv6 redirect message packets // counted. RedirectMsg *StatCounter + + // MulticastListenerQuery is the total number of Multicast Listener Query + // messages counted. + MulticastListenerQuery *StatCounter + + // MulticastListenerReport is the total number of Multicast Listener Report + // messages counted. + MulticastListenerReport *StatCounter + + // MulticastListenerDone is the total number of Multicast Listener Done + // messages counted. + MulticastListenerDone *StatCounter } // ICMPv4SentPacketStats collects outbound ICMPv4-specific stats. @@ -1431,6 +1440,10 @@ type ICMPv6SentPacketStats struct { type ICMPv6ReceivedPacketStats struct { ICMPv6PacketStats + // Unrecognized is the total number of ICMPv6 packets received that the + // transport layer does not know how to parse. + Unrecognized *StatCounter + // Invalid is the total number of ICMPv6 packets received that the // transport layer could not parse. Invalid *StatCounter @@ -1440,33 +1453,102 @@ type ICMPv6ReceivedPacketStats struct { RouterOnlyPacketsDroppedByHost *StatCounter } -// ICMPStats collects ICMP-specific stats (both v4 and v6). -type ICMPStats struct { +// ICMPv4Stats collects ICMPv4-specific stats. +type ICMPv4Stats struct { // ICMPv4SentPacketStats contains counts of sent packets by ICMPv4 packet type // and a single count of packets which failed to write to the link // layer. - V4PacketsSent ICMPv4SentPacketStats + PacketsSent ICMPv4SentPacketStats // ICMPv4ReceivedPacketStats contains counts of received packets by ICMPv4 // packet type and a single count of invalid packets received. - V4PacketsReceived ICMPv4ReceivedPacketStats + PacketsReceived ICMPv4ReceivedPacketStats +} +// ICMPv6Stats collects ICMPv6-specific stats. +type ICMPv6Stats struct { // ICMPv6SentPacketStats contains counts of sent packets by ICMPv6 packet type // and a single count of packets which failed to write to the link // layer. - V6PacketsSent ICMPv6SentPacketStats + PacketsSent ICMPv6SentPacketStats // ICMPv6ReceivedPacketStats contains counts of received packets by ICMPv6 // packet type and a single count of invalid packets received. - V6PacketsReceived ICMPv6ReceivedPacketStats + PacketsReceived ICMPv6ReceivedPacketStats +} + +// ICMPStats collects ICMP-specific stats (both v4 and v6). +type ICMPStats struct { + // V4 contains the ICMPv4-specifics stats. + V4 ICMPv4Stats + + // V6 contains the ICMPv4-specifics stats. + V6 ICMPv6Stats +} + +// IGMPPacketStats enumerates counts for all IGMP packet types. +type IGMPPacketStats struct { + // MembershipQuery is the total number of Membership Query messages counted. + MembershipQuery *StatCounter + + // V1MembershipReport is the total number of Version 1 Membership Report + // messages counted. + V1MembershipReport *StatCounter + + // V2MembershipReport is the total number of Version 2 Membership Report + // messages counted. + V2MembershipReport *StatCounter + + // LeaveGroup is the total number of Leave Group messages counted. + LeaveGroup *StatCounter +} + +// IGMPSentPacketStats collects outbound IGMP-specific stats. +type IGMPSentPacketStats struct { + IGMPPacketStats + + // Dropped is the total number of IGMP packets dropped. + Dropped *StatCounter +} + +// IGMPReceivedPacketStats collects inbound IGMP-specific stats. +type IGMPReceivedPacketStats struct { + IGMPPacketStats + + // Invalid is the total number of IGMP packets received that IGMP could not + // parse. + Invalid *StatCounter + + // ChecksumErrors is the total number of IGMP packets dropped due to bad + // checksums. + ChecksumErrors *StatCounter + + // Unrecognized is the total number of unrecognized messages counted, these + // are silently ignored for forward-compatibilty. + Unrecognized *StatCounter +} + +// IGMPStats colelcts IGMP-specific stats. +type IGMPStats struct { + // IGMPSentPacketStats contains counts of sent packets by IGMP packet type + // and a single count of invalid packets received. + PacketsSent IGMPSentPacketStats + + // IGMPReceivedPacketStats contains counts of received packets by IGMP packet + // type and a single count of invalid packets received. + PacketsReceived IGMPReceivedPacketStats } // IPStats collects IP-specific stats (both v4 and v6). type IPStats struct { // PacketsReceived is the total number of IP packets received from the - // link layer in nic.DeliverNetworkPacket. + // link layer. PacketsReceived *StatCounter + // DisabledPacketsReceived is the total number of IP packets received from the + // link layer when the IP layer is disabled. + DisabledPacketsReceived *StatCounter + // InvalidDestinationAddressesReceived is the total number of IP packets // received with an unknown or invalid destination address. InvalidDestinationAddressesReceived *StatCounter @@ -1662,6 +1744,9 @@ type Stats struct { // ICMP breaks out ICMP-specific stats (both v4 and v6). ICMP ICMPStats + // IGMP breaks out IGMP-specific stats. + IGMP IGMPStats + // IP breaks out IP-specific stats (both v4 and v6). IP IPStats diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go index 1c8e2bc34..9bd563c46 100644 --- a/pkg/tcpip/tcpip_test.go +++ b/pkg/tcpip/tcpip_test.go @@ -226,3 +226,87 @@ func TestAddressWithPrefixSubnet(t *testing.T) { } } } + +func TestAddressUnspecified(t *testing.T) { + tests := []struct { + addr Address + unspecified bool + }{ + { + addr: "", + unspecified: true, + }, + { + addr: "\x00", + unspecified: true, + }, + { + addr: "\x01", + unspecified: false, + }, + { + addr: "\x00\x00", + unspecified: true, + }, + { + addr: "\x01\x00", + unspecified: false, + }, + { + addr: "\x00\x01", + unspecified: false, + }, + { + addr: "\x01\x01", + unspecified: false, + }, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("addr=%s", test.addr), func(t *testing.T) { + if got := test.addr.Unspecified(); got != test.unspecified { + t.Fatalf("got addr.Unspecified() = %t, want = %t", got, test.unspecified) + } + }) + } +} + +func TestAddressMatchingPrefix(t *testing.T) { + tests := []struct { + addrA Address + addrB Address + prefix uint8 + }{ + { + addrA: "\x01\x01", + addrB: "\x01\x01", + prefix: 16, + }, + { + addrA: "\x01\x01", + addrB: "\x01\x00", + prefix: 15, + }, + { + addrA: "\x01\x01", + addrB: "\x81\x00", + prefix: 0, + }, + { + addrA: "\x01\x01", + addrB: "\x01\x80", + prefix: 8, + }, + { + addrA: "\x01\x01", + addrB: "\x02\x80", + prefix: 6, + }, + } + + for _, test := range tests { + if got := test.addrA.MatchingPrefix(test.addrB); got != test.prefix { + t.Errorf("got (%s).MatchingPrefix(%s) = %d, want = %d", test.addrA, test.addrB, got, test.prefix) + } + } +} diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index 9b0f3b675..ca1e88e99 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -15,16 +15,19 @@ go_test( deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/checker", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/ethernet", "//pkg/tcpip/link/loopback", + "//pkg/tcpip/link/nested", "//pkg/tcpip/link/pipe", "//pkg/tcpip/network/arp", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", "//pkg/tcpip/transport/icmp", + "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", "//pkg/waiter", "@com_github_google_go_cmp//cmp:go_default_library", diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go index bf7594268..4c2084d19 100644 --- a/pkg/tcpip/tests/integration/forward_test.go +++ b/pkg/tcpip/tests/integration/forward_test.go @@ -15,13 +15,16 @@ package integration_test import ( + "bytes" "net" "testing" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/ethernet" + "gvisor.dev/gvisor/pkg/tcpip/link/nested" "gvisor.dev/gvisor/pkg/tcpip/link/pipe" "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" @@ -31,6 +34,33 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +var _ stack.NetworkDispatcher = (*endpointWithDestinationCheck)(nil) +var _ stack.LinkEndpoint = (*endpointWithDestinationCheck)(nil) + +// newEthernetEndpoint returns an ethernet link endpoint that wraps an inner +// link endpoint and checks the destination link address before delivering +// network packets to the network dispatcher. +// +// See ethernet.Endpoint for more details. +func newEthernetEndpoint(ep stack.LinkEndpoint) *endpointWithDestinationCheck { + var e endpointWithDestinationCheck + e.Endpoint.Init(ethernet.New(ep), &e) + return &e +} + +// endpointWithDestinationCheck is a link endpoint that checks the destination +// link address before delivering network packets to the network dispatcher. +type endpointWithDestinationCheck struct { + nested.Endpoint +} + +// DeliverNetworkPacket implements stack.NetworkDispatcher. +func (e *endpointWithDestinationCheck) DeliverNetworkPacket(src, dst tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + if dst == e.Endpoint.LinkAddress() || dst == header.EthernetBroadcastAddress || header.IsMulticastEthernetAddress(dst) { + e.Endpoint.DeliverNetworkPacket(src, dst, proto, pkt) + } +} + func TestForwarding(t *testing.T) { const ( host1NICID = 1 @@ -209,16 +239,16 @@ func TestForwarding(t *testing.T) { host1NIC, routerNIC1 := pipe.New(linkAddr1, linkAddr2) routerNIC2, host2NIC := pipe.New(linkAddr3, linkAddr4) - if err := host1Stack.CreateNIC(host1NICID, ethernet.New(host1NIC)); err != nil { + if err := host1Stack.CreateNIC(host1NICID, newEthernetEndpoint(host1NIC)); err != nil { t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) } - if err := routerStack.CreateNIC(routerNICID1, ethernet.New(routerNIC1)); err != nil { + if err := routerStack.CreateNIC(routerNICID1, newEthernetEndpoint(routerNIC1)); err != nil { t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID1, err) } - if err := routerStack.CreateNIC(routerNICID2, ethernet.New(routerNIC2)); err != nil { + if err := routerStack.CreateNIC(routerNICID2, newEthernetEndpoint(routerNIC2)); err != nil { t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID2, err) } - if err := host2Stack.CreateNIC(host2NICID, ethernet.New(host2NIC)); err != nil { + if err := host2Stack.CreateNIC(host2NICID, newEthernetEndpoint(host2NIC)); err != nil { t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err) } @@ -229,19 +259,6 @@ func TestForwarding(t *testing.T) { t.Fatalf("routerStack.SetForwarding(%d): %s", ipv6.ProtocolNumber, err) } - if err := host1Stack.AddAddress(host1NICID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { - t.Fatalf("host1Stack.AddAddress(%d, %d, %s): %s", host1NICID, arp.ProtocolNumber, arp.ProtocolAddress, err) - } - if err := routerStack.AddAddress(routerNICID1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { - t.Fatalf("routerStack.AddAddress(%d, %d, %s): %s", routerNICID1, arp.ProtocolNumber, arp.ProtocolAddress, err) - } - if err := routerStack.AddAddress(routerNICID2, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { - t.Fatalf("routerStack.AddAddress(%d, %d, %s): %s", routerNICID2, arp.ProtocolNumber, arp.ProtocolAddress, err) - } - if err := host2Stack.AddAddress(host2NICID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { - t.Fatalf("host2Stack.AddAddress(%d, %d, %s): %s", host2NICID, arp.ProtocolNumber, arp.ProtocolAddress, err) - } - if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil { t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err) } @@ -268,58 +285,58 @@ func TestForwarding(t *testing.T) { } host1Stack.SetRouteTable([]tcpip.Route{ - tcpip.Route{ + { Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), NIC: host1NICID, }, - tcpip.Route{ + { Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), NIC: host1NICID, }, - tcpip.Route{ + { Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), Gateway: routerNIC1IPv4Addr.AddressWithPrefix.Address, NIC: host1NICID, }, - tcpip.Route{ + { Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), Gateway: routerNIC1IPv6Addr.AddressWithPrefix.Address, NIC: host1NICID, }, }) routerStack.SetRouteTable([]tcpip.Route{ - tcpip.Route{ + { Destination: routerNIC1IPv4Addr.AddressWithPrefix.Subnet(), NIC: routerNICID1, }, - tcpip.Route{ + { Destination: routerNIC1IPv6Addr.AddressWithPrefix.Subnet(), NIC: routerNICID1, }, - tcpip.Route{ + { Destination: routerNIC2IPv4Addr.AddressWithPrefix.Subnet(), NIC: routerNICID2, }, - tcpip.Route{ + { Destination: routerNIC2IPv6Addr.AddressWithPrefix.Subnet(), NIC: routerNICID2, }, }) host2Stack.SetRouteTable([]tcpip.Route{ - tcpip.Route{ + { Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), NIC: host2NICID, }, - tcpip.Route{ + { Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), NIC: host2NICID, }, - tcpip.Route{ + { Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), Gateway: routerNIC2IPv4Addr.AddressWithPrefix.Address, NIC: host2NICID, }, - tcpip.Route{ + { Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), Gateway: routerNIC2IPv6Addr.AddressWithPrefix.Address, NIC: host2NICID, @@ -366,24 +383,33 @@ func TestForwarding(t *testing.T) { // Wait for the endpoint to be readable. <-ch - var addr tcpip.FullAddress - v, _, err := ep.Read(&addr) + var buf bytes.Buffer + opts := tcpip.ReadOptions{NeedRemoteAddr: true} + res, err := ep.Read(&buf, len(data), opts) if err != nil { - t.Fatalf("ep.Read(_): %s", err) + t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err) } - if diff := cmp.Diff(v, buffer.View(data)); diff != "" { - t.Errorf("received data mismatch (-want +got):\n%s", diff) + if diff := cmp.Diff(tcpip.ReadResult{ + Count: len(data), + Total: len(data), + RemoteAddr: tcpip.FullAddress{Addr: expectedFrom}, + }, res, checker.IgnoreCmpPath( + "ControlMessages", + "RemoteAddr.NIC", + "RemoteAddr.Port", + )); diff != "" { + t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) } - if addr.Addr != expectedFrom { - t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, expectedFrom) + if diff := cmp.Diff(buf.Bytes(), data); diff != "" { + t.Errorf("received data mismatch (-want +got):\n%s", diff) } if t.Failed() { t.FailNow() } - return addr + return res.RemoteAddr } addr := read(epsAndAddrs.serverReadableCH, epsAndAddrs.serverEP, data, epsAndAddrs.clientAddr) diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index fe7c1bb3d..b4bffaec1 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -15,14 +15,14 @@ package integration_test import ( + "bytes" "net" "testing" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/link/ethernet" "gvisor.dev/gvisor/pkg/tcpip/link/pipe" "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" @@ -87,21 +87,21 @@ func TestPing(t *testing.T) { transProto tcpip.TransportProtocolNumber netProto tcpip.NetworkProtocolNumber remoteAddr tcpip.Address - icmpBuf func(*testing.T) buffer.View + icmpBuf func(*testing.T) []byte }{ { name: "IPv4 Ping", transProto: icmp.ProtocolNumber4, netProto: ipv4.ProtocolNumber, remoteAddr: ipv4Addr2.AddressWithPrefix.Address, - icmpBuf: func(t *testing.T) buffer.View { + icmpBuf: func(t *testing.T) []byte { data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} hdr := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize+len(data))) hdr.SetType(header.ICMPv4Echo) if n := copy(hdr.Payload(), data[:]); n != len(data) { t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data)) } - return buffer.View(hdr) + return hdr }, }, { @@ -109,14 +109,14 @@ func TestPing(t *testing.T) { transProto: icmp.ProtocolNumber6, netProto: ipv6.ProtocolNumber, remoteAddr: ipv6Addr2.AddressWithPrefix.Address, - icmpBuf: func(t *testing.T) buffer.View { + icmpBuf: func(t *testing.T) []byte { data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} hdr := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize+len(data))) hdr.SetType(header.ICMPv6EchoRequest) if n := copy(hdr.Payload(), data[:]); n != len(data) { t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data)) } - return buffer.View(hdr) + return hdr }, }, } @@ -133,20 +133,13 @@ func TestPing(t *testing.T) { host1NIC, host2NIC := pipe.New(linkAddr1, linkAddr2) - if err := host1Stack.CreateNIC(host1NICID, ethernet.New(host1NIC)); err != nil { + if err := host1Stack.CreateNIC(host1NICID, newEthernetEndpoint(host1NIC)); err != nil { t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) } - if err := host2Stack.CreateNIC(host2NICID, ethernet.New(host2NIC)); err != nil { + if err := host2Stack.CreateNIC(host2NICID, newEthernetEndpoint(host2NIC)); err != nil { t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err) } - if err := host1Stack.AddAddress(host1NICID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { - t.Fatalf("host1Stack.AddAddress(%d, %d, %s): %s", host1NICID, arp.ProtocolNumber, arp.ProtocolAddress, err) - } - if err := host2Stack.AddAddress(host2NICID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { - t.Fatalf("host2Stack.AddAddress(%d, %d, %s): %s", host2NICID, arp.ProtocolNumber, arp.ProtocolAddress, err) - } - if err := host1Stack.AddProtocolAddress(host1NICID, ipv4Addr1); err != nil { t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, ipv4Addr1, err) } @@ -161,21 +154,21 @@ func TestPing(t *testing.T) { } host1Stack.SetRouteTable([]tcpip.Route{ - tcpip.Route{ + { Destination: ipv4Addr1.AddressWithPrefix.Subnet(), NIC: host1NICID, }, - tcpip.Route{ + { Destination: ipv6Addr1.AddressWithPrefix.Subnet(), NIC: host1NICID, }, }) host2Stack.SetRouteTable([]tcpip.Route{ - tcpip.Route{ + { Destination: ipv4Addr2.AddressWithPrefix.Subnet(), NIC: host2NICID, }, - tcpip.Route{ + { Destination: ipv6Addr2.AddressWithPrefix.Subnet(), NIC: host2NICID, }, @@ -208,16 +201,25 @@ func TestPing(t *testing.T) { // Wait for the endpoint to be readable. <-waiterCH - var addr tcpip.FullAddress - v, _, err := ep.Read(&addr) + var buf bytes.Buffer + opts := tcpip.ReadOptions{NeedRemoteAddr: true} + res, err := ep.Read(&buf, len(icmpBuf), opts) if err != nil { - t.Fatalf("ep.Read(_): %s", err) + t.Fatalf("ep.Read(_, %d, %#v): %s", len(icmpBuf), opts, err) } - if diff := cmp.Diff(v[icmpDataOffset:], icmpBuf[icmpDataOffset:]); diff != "" { - t.Errorf("received data mismatch (-want +got):\n%s", diff) + if diff := cmp.Diff(tcpip.ReadResult{ + Count: buf.Len(), + Total: buf.Len(), + RemoteAddr: tcpip.FullAddress{Addr: test.remoteAddr}, + }, res, checker.IgnoreCmpPath( + "ControlMessages", + "RemoteAddr.NIC", + "RemoteAddr.Port", + )); diff != "" { + t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) } - if addr.Addr != test.remoteAddr { - t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.remoteAddr) + if diff := cmp.Diff(buf.Bytes()[icmpDataOffset:], icmpBuf[icmpDataOffset:]); diff != "" { + t.Errorf("received data mismatch (-want +got):\n%s", diff) } }) } diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index 421da1add..cb6169cfc 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -15,17 +15,20 @@ package integration_test import ( + "bytes" "testing" "time" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" ) @@ -70,8 +73,8 @@ func TestInitialLoopbackAddresses(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocolWithOptions(ipv6.Options{ - NDPDisp: &ndpDispatcher{}, - AutoGenIPv6LinkLocal: true, + NDPDisp: &ndpDispatcher{}, + AutoGenLinkLocal: true, OpaqueIIDOpts: ipv6.OpaqueInterfaceIdentifierOptions{ NICNameFromID: func(nicID tcpip.NICID, nicName string) string { t.Fatalf("should not attempt to get name for NIC with ID = %d; nicName = %s", nicID, nicName) @@ -93,9 +96,10 @@ func TestInitialLoopbackAddresses(t *testing.T) { } } -// TestLoopbackAcceptAllInSubnet tests that a loopback interface considers -// itself bound to all addresses in the subnet of an assigned address. -func TestLoopbackAcceptAllInSubnet(t *testing.T) { +// TestLoopbackAcceptAllInSubnetUDP tests that a loopback interface considers +// itself bound to all addresses in the subnet of an assigned address and UDP +// traffic is sent/received correctly. +func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) { const ( nicID = 1 localPort = 80 @@ -107,7 +111,7 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) { Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr, } - ipv4Bytes := []byte(ipv4Addr.Address) + ipv4Bytes := []byte(ipv4ProtocolAddress.AddressWithPrefix.Address) ipv4Bytes[len(ipv4Bytes)-1]++ otherIPv4Address := tcpip.Address(ipv4Bytes) @@ -129,7 +133,7 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) { { name: "IPv4 bind to wildcard and send to assigned address", addAddress: ipv4ProtocolAddress, - dstAddr: ipv4Addr.Address, + dstAddr: ipv4ProtocolAddress.AddressWithPrefix.Address, expectRx: true, }, { @@ -148,7 +152,7 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) { name: "IPv4 bind to other subnet-local address and send to assigned address", addAddress: ipv4ProtocolAddress, bindAddr: otherIPv4Address, - dstAddr: ipv4Addr.Address, + dstAddr: ipv4ProtocolAddress.AddressWithPrefix.Address, expectRx: false, }, { @@ -161,7 +165,7 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) { { name: "IPv4 bind to assigned address and send to other subnet-local address", addAddress: ipv4ProtocolAddress, - bindAddr: ipv4Addr.Address, + bindAddr: ipv4ProtocolAddress.AddressWithPrefix.Address, dstAddr: otherIPv4Address, expectRx: false, }, @@ -194,11 +198,11 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) { t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, test.addAddress, err) } s.SetRouteTable([]tcpip.Route{ - tcpip.Route{ + { Destination: header.IPv4EmptySubnet, NIC: nicID, }, - tcpip.Route{ + { Destination: header.IPv6EmptySubnet, NIC: nicID, }, @@ -236,17 +240,28 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) { t.Fatalf("got sep.Write(_, _) = (%d, _, nil), want = (%d, _, nil)", n, want) } - if gotPayload, _, err := rep.Read(nil); test.expectRx { + var buf bytes.Buffer + opts := tcpip.ReadOptions{NeedRemoteAddr: true} + if res, err := rep.Read(&buf, len(data), opts); test.expectRx { if err != nil { - t.Fatalf("reep.Read(nil): %s", err) + t.Fatalf("rep.Read(_, %d, %#v): %s", len(data), opts, err) } - if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" { - t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) + if diff := cmp.Diff(tcpip.ReadResult{ + Count: buf.Len(), + Total: buf.Len(), + RemoteAddr: tcpip.FullAddress{ + Addr: test.addAddress.AddressWithPrefix.Address, + }, + }, res, + checker.IgnoreCmpPath("ControlMessages", "RemoteAddr.NIC", "RemoteAddr.Port"), + ); diff != "" { + t.Errorf("rep.Read: unexpected result (-want +got):\n%s", diff) } - } else { - if err != tcpip.ErrWouldBlock { - t.Fatalf("got rep.Read(nil) = (%x, _, %s), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock) + if diff := cmp.Diff(data, buf.Bytes()); diff != "" { + t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) } + } else if err != tcpip.ErrWouldBlock { + t.Fatalf("got rep.Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), tcpip.ErrWouldBlock) } }) } @@ -276,7 +291,7 @@ func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) { t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err) } s.SetRouteTable([]tcpip.Route{ - tcpip.Route{ + { Destination: header.IPv4EmptySubnet, NIC: nicID, }, @@ -312,3 +327,168 @@ func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) { t.Fatalf("got r.WritePacket(nil, %#v, _) = %s, want = %s", params, err, tcpip.ErrInvalidEndpointState) } } + +// TestLoopbackAcceptAllInSubnetTCP tests that a loopback interface considers +// itself bound to all addresses in the subnet of an assigned address and TCP +// traffic is sent/received correctly. +func TestLoopbackAcceptAllInSubnetTCP(t *testing.T) { + const ( + nicID = 1 + localPort = 80 + ) + + ipv4ProtocolAddress := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4Addr, + } + ipv4ProtocolAddress.AddressWithPrefix.PrefixLen = 8 + ipv4Bytes := []byte(ipv4ProtocolAddress.AddressWithPrefix.Address) + ipv4Bytes[len(ipv4Bytes)-1]++ + otherIPv4Address := tcpip.Address(ipv4Bytes) + + ipv6ProtocolAddress := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: ipv6Addr, + } + ipv6Bytes := []byte(ipv6Addr.Address) + ipv6Bytes[len(ipv6Bytes)-1]++ + otherIPv6Address := tcpip.Address(ipv6Bytes) + + tests := []struct { + name string + addAddress tcpip.ProtocolAddress + bindAddr tcpip.Address + dstAddr tcpip.Address + expectAccept bool + }{ + { + name: "IPv4 bind to wildcard and send to assigned address", + addAddress: ipv4ProtocolAddress, + dstAddr: ipv4ProtocolAddress.AddressWithPrefix.Address, + expectAccept: true, + }, + { + name: "IPv4 bind to wildcard and send to other subnet-local address", + addAddress: ipv4ProtocolAddress, + dstAddr: otherIPv4Address, + expectAccept: true, + }, + { + name: "IPv4 bind to wildcard send to other address", + addAddress: ipv4ProtocolAddress, + dstAddr: remoteIPv4Addr, + expectAccept: false, + }, + { + name: "IPv4 bind to other subnet-local address and send to assigned address", + addAddress: ipv4ProtocolAddress, + bindAddr: otherIPv4Address, + dstAddr: ipv4ProtocolAddress.AddressWithPrefix.Address, + expectAccept: false, + }, + { + name: "IPv4 bind and send to other subnet-local address", + addAddress: ipv4ProtocolAddress, + bindAddr: otherIPv4Address, + dstAddr: otherIPv4Address, + expectAccept: true, + }, + { + name: "IPv4 bind to assigned address and send to other subnet-local address", + addAddress: ipv4ProtocolAddress, + bindAddr: ipv4ProtocolAddress.AddressWithPrefix.Address, + dstAddr: otherIPv4Address, + expectAccept: false, + }, + + { + name: "IPv6 bind and send to assigned address", + addAddress: ipv6ProtocolAddress, + bindAddr: ipv6Addr.Address, + dstAddr: ipv6Addr.Address, + expectAccept: true, + }, + { + name: "IPv6 bind to wildcard and send to other subnet-local address", + addAddress: ipv6ProtocolAddress, + dstAddr: otherIPv6Address, + expectAccept: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + }) + if err := s.CreateNIC(nicID, loopback.New()); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddProtocolAddress(nicID, test.addAddress); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, test.addAddress, err) + } + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv4EmptySubnet, + NIC: nicID, + }, + { + Destination: header.IPv6EmptySubnet, + NIC: nicID, + }, + }) + + var wq waiter.Queue + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + listeningEndpoint, err := s.NewEndpoint(tcp.ProtocolNumber, test.addAddress.Protocol, &wq) + if err != nil { + t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err) + } + defer listeningEndpoint.Close() + + bindAddr := tcpip.FullAddress{Addr: test.bindAddr, Port: localPort} + if err := listeningEndpoint.Bind(bindAddr); err != nil { + t.Fatalf("listeningEndpoint.Bind(%#v): %s", bindAddr, err) + } + + if err := listeningEndpoint.Listen(1); err != nil { + t.Fatalf("listeningEndpoint.Listen(1): %s", err) + } + + connectingEndpoint, err := s.NewEndpoint(tcp.ProtocolNumber, test.addAddress.Protocol, &wq) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err) + } + defer connectingEndpoint.Close() + + connectAddr := tcpip.FullAddress{ + Addr: test.dstAddr, + Port: localPort, + } + if err := connectingEndpoint.Connect(connectAddr); err != tcpip.ErrConnectStarted { + t.Fatalf("connectingEndpoint.Connect(%#v): %s", connectAddr, err) + } + + if !test.expectAccept { + if _, _, err := listeningEndpoint.Accept(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got listeningEndpoint.Accept(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) + } + return + } + + // Wait for the listening endpoint to be "readable". That is, wait for a + // new connection. + <-ch + var addr tcpip.FullAddress + if _, _, err := listeningEndpoint.Accept(&addr); err != nil { + t.Fatalf("listeningEndpoint.Accept(nil): %s", err) + } + if addr.Addr != test.addAddress.AddressWithPrefix.Address { + t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.addAddress.AddressWithPrefix.Address) + } + }) + } +} diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index 1eecd7957..b42375695 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -15,12 +15,14 @@ package integration_test import ( + "bytes" "net" "testing" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" @@ -35,6 +37,9 @@ import ( const ( defaultMTU = 1280 ttl = 255 + + remotePort = 5555 + localPort = 80 ) var ( @@ -96,11 +101,11 @@ func TestPingMulticastBroadcast(t *testing.T) { pkt.SetChecksum(header.ICMPv6Checksum(pkt, remoteIPv6Addr, dst, buffer.VectorisedView{})) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: header.ICMPv6MinimumSize, - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: ttl, - SrcAddr: remoteIPv6Addr, - DstAddr: dst, + PayloadLength: header.ICMPv6MinimumSize, + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: ttl, + SrcAddr: remoteIPv6Addr, + DstAddr: dst, }) e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -151,21 +156,21 @@ func TestPingMulticastBroadcast(t *testing.T) { } ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr} if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv4ProtoAddr, err) + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err) } ipv6ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: ipv6Addr} if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv6ProtoAddr, err) + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv6ProtoAddr, err) } // Default routes for IPv4 and IPv6 so ICMP can find a route to the remote // node when attempting to send the ICMP Echo Reply. s.SetRouteTable([]tcpip.Route{ - tcpip.Route{ + { Destination: header.IPv6EmptySubnet, NIC: nicID, }, - tcpip.Route{ + { Destination: header.IPv4EmptySubnet, NIC: nicID, }, @@ -215,163 +220,219 @@ func TestPingMulticastBroadcast(t *testing.T) { } +func rxIPv4UDP(e *channel.Endpoint, src, dst tcpip.Address, data []byte) { + payloadLen := header.UDPMinimumSize + len(data) + totalLen := header.IPv4MinimumSize + payloadLen + hdr := buffer.NewPrependable(totalLen) + u := header.UDP(hdr.Prepend(payloadLen)) + u.Encode(&header.UDPFields{ + SrcPort: remotePort, + DstPort: localPort, + Length: uint16(payloadLen), + }) + copy(u.Payload(), data) + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(payloadLen)) + sum = header.Checksum(data, sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + TotalLength: uint16(totalLen), + Protocol: uint8(udp.ProtocolNumber), + TTL: ttl, + SrcAddr: src, + DstAddr: dst, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) +} + +func rxIPv6UDP(e *channel.Endpoint, src, dst tcpip.Address, data []byte) { + payloadLen := header.UDPMinimumSize + len(data) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadLen) + u := header.UDP(hdr.Prepend(payloadLen)) + u.Encode(&header.UDPFields{ + SrcPort: remotePort, + DstPort: localPort, + Length: uint16(payloadLen), + }) + copy(u.Payload(), data) + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(payloadLen)) + sum = header.Checksum(data, sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLen), + TransportProtocol: udp.ProtocolNumber, + HopLimit: ttl, + SrcAddr: src, + DstAddr: dst, + }) + + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) +} + // TestIncomingMulticastAndBroadcast tests receiving a packet destined to some // multicast or broadcast address. func TestIncomingMulticastAndBroadcast(t *testing.T) { - const ( - nicID = 1 - remotePort = 5555 - localPort = 80 - ) + const nicID = 1 data := []byte{1, 2, 3, 4} - rxIPv4UDP := func(e *channel.Endpoint, dst tcpip.Address) { - payloadLen := header.UDPMinimumSize + len(data) - totalLen := header.IPv4MinimumSize + payloadLen - hdr := buffer.NewPrependable(totalLen) - u := header.UDP(hdr.Prepend(payloadLen)) - u.Encode(&header.UDPFields{ - SrcPort: remotePort, - DstPort: localPort, - Length: uint16(payloadLen), - }) - copy(u.Payload(), data) - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, remoteIPv4Addr, dst, uint16(payloadLen)) - sum = header.Checksum(data, sum) - u.SetChecksum(^u.CalculateChecksum(sum)) - - ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) - ip.Encode(&header.IPv4Fields{ - TotalLength: uint16(totalLen), - Protocol: uint8(udp.ProtocolNumber), - TTL: ttl, - SrcAddr: remoteIPv4Addr, - DstAddr: dst, - }) - ip.SetChecksum(^ip.CalculateChecksum()) - - e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - } - - rxIPv6UDP := func(e *channel.Endpoint, dst tcpip.Address) { - payloadLen := header.UDPMinimumSize + len(data) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadLen) - u := header.UDP(hdr.Prepend(payloadLen)) - u.Encode(&header.UDPFields{ - SrcPort: remotePort, - DstPort: localPort, - Length: uint16(payloadLen), - }) - copy(u.Payload(), data) - sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, remoteIPv6Addr, dst, uint16(payloadLen)) - sum = header.Checksum(data, sum) - u.SetChecksum(^u.CalculateChecksum(sum)) - - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLen), - NextHeader: uint8(udp.ProtocolNumber), - HopLimit: ttl, - SrcAddr: remoteIPv6Addr, - DstAddr: dst, - }) - - e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) - } - tests := []struct { - name string - bindAddr tcpip.Address - dstAddr tcpip.Address - expectRx bool + name string + proto tcpip.NetworkProtocolNumber + remoteAddr tcpip.Address + localAddr tcpip.AddressWithPrefix + rxUDP func(*channel.Endpoint, tcpip.Address, tcpip.Address, []byte) + bindAddr tcpip.Address + dstAddr tcpip.Address + expectRx bool }{ { - name: "IPv4 unicast binding to unicast", - bindAddr: ipv4Addr.Address, - dstAddr: ipv4Addr.Address, - expectRx: true, + name: "IPv4 unicast binding to unicast", + proto: header.IPv4ProtocolNumber, + remoteAddr: remoteIPv4Addr, + localAddr: ipv4Addr, + rxUDP: rxIPv4UDP, + bindAddr: ipv4Addr.Address, + dstAddr: ipv4Addr.Address, + expectRx: true, }, { - name: "IPv4 unicast binding to broadcast", - bindAddr: header.IPv4Broadcast, - dstAddr: ipv4Addr.Address, - expectRx: false, + name: "IPv4 unicast binding to broadcast", + proto: header.IPv4ProtocolNumber, + remoteAddr: remoteIPv4Addr, + localAddr: ipv4Addr, + rxUDP: rxIPv4UDP, + bindAddr: header.IPv4Broadcast, + dstAddr: ipv4Addr.Address, + expectRx: false, }, { - name: "IPv4 unicast binding to wildcard", - dstAddr: ipv4Addr.Address, - expectRx: true, + name: "IPv4 unicast binding to wildcard", + proto: header.IPv4ProtocolNumber, + remoteAddr: remoteIPv4Addr, + localAddr: ipv4Addr, + rxUDP: rxIPv4UDP, + dstAddr: ipv4Addr.Address, + expectRx: true, }, { - name: "IPv4 directed broadcast binding to subnet broadcast", - bindAddr: ipv4SubnetBcast, - dstAddr: ipv4SubnetBcast, - expectRx: true, + name: "IPv4 directed broadcast binding to subnet broadcast", + proto: header.IPv4ProtocolNumber, + remoteAddr: remoteIPv4Addr, + localAddr: ipv4Addr, + rxUDP: rxIPv4UDP, + bindAddr: ipv4SubnetBcast, + dstAddr: ipv4SubnetBcast, + expectRx: true, }, { - name: "IPv4 directed broadcast binding to broadcast", - bindAddr: header.IPv4Broadcast, - dstAddr: ipv4SubnetBcast, - expectRx: false, + name: "IPv4 directed broadcast binding to broadcast", + proto: header.IPv4ProtocolNumber, + remoteAddr: remoteIPv4Addr, + localAddr: ipv4Addr, + rxUDP: rxIPv4UDP, + bindAddr: header.IPv4Broadcast, + dstAddr: ipv4SubnetBcast, + expectRx: false, }, { - name: "IPv4 directed broadcast binding to wildcard", - dstAddr: ipv4SubnetBcast, - expectRx: true, + name: "IPv4 directed broadcast binding to wildcard", + proto: header.IPv4ProtocolNumber, + remoteAddr: remoteIPv4Addr, + localAddr: ipv4Addr, + rxUDP: rxIPv4UDP, + dstAddr: ipv4SubnetBcast, + expectRx: true, }, { - name: "IPv4 broadcast binding to broadcast", - bindAddr: header.IPv4Broadcast, - dstAddr: header.IPv4Broadcast, - expectRx: true, + name: "IPv4 broadcast binding to broadcast", + proto: header.IPv4ProtocolNumber, + remoteAddr: remoteIPv4Addr, + localAddr: ipv4Addr, + rxUDP: rxIPv4UDP, + bindAddr: header.IPv4Broadcast, + dstAddr: header.IPv4Broadcast, + expectRx: true, }, { - name: "IPv4 broadcast binding to subnet broadcast", - bindAddr: ipv4SubnetBcast, - dstAddr: header.IPv4Broadcast, - expectRx: false, + name: "IPv4 broadcast binding to subnet broadcast", + proto: header.IPv4ProtocolNumber, + remoteAddr: remoteIPv4Addr, + localAddr: ipv4Addr, + rxUDP: rxIPv4UDP, + bindAddr: ipv4SubnetBcast, + dstAddr: header.IPv4Broadcast, + expectRx: false, }, { - name: "IPv4 broadcast binding to wildcard", - dstAddr: ipv4SubnetBcast, - expectRx: true, + name: "IPv4 broadcast binding to wildcard", + proto: header.IPv4ProtocolNumber, + remoteAddr: remoteIPv4Addr, + localAddr: ipv4Addr, + rxUDP: rxIPv4UDP, + dstAddr: ipv4SubnetBcast, + expectRx: true, }, { - name: "IPv4 all-systems multicast binding to all-systems multicast", - bindAddr: header.IPv4AllSystems, - dstAddr: header.IPv4AllSystems, - expectRx: true, + name: "IPv4 all-systems multicast binding to all-systems multicast", + proto: header.IPv4ProtocolNumber, + remoteAddr: remoteIPv4Addr, + localAddr: ipv4Addr, + rxUDP: rxIPv4UDP, + bindAddr: header.IPv4AllSystems, + dstAddr: header.IPv4AllSystems, + expectRx: true, }, { - name: "IPv4 all-systems multicast binding to wildcard", - dstAddr: header.IPv4AllSystems, - expectRx: true, + name: "IPv4 all-systems multicast binding to wildcard", + proto: header.IPv4ProtocolNumber, + remoteAddr: remoteIPv4Addr, + localAddr: ipv4Addr, + rxUDP: rxIPv4UDP, + dstAddr: header.IPv4AllSystems, + expectRx: true, }, { - name: "IPv4 all-systems multicast binding to unicast", - bindAddr: ipv4Addr.Address, - dstAddr: header.IPv4AllSystems, - expectRx: false, + name: "IPv4 all-systems multicast binding to unicast", + proto: header.IPv4ProtocolNumber, + remoteAddr: remoteIPv4Addr, + localAddr: ipv4Addr, + rxUDP: rxIPv4UDP, + bindAddr: ipv4Addr.Address, + dstAddr: header.IPv4AllSystems, + expectRx: false, }, // IPv6 has no notion of a broadcast. { - name: "IPv6 unicast binding to wildcard", - dstAddr: ipv6Addr.Address, - expectRx: true, + name: "IPv6 unicast binding to wildcard", + dstAddr: ipv6Addr.Address, + proto: header.IPv6ProtocolNumber, + remoteAddr: remoteIPv6Addr, + localAddr: ipv6Addr, + rxUDP: rxIPv6UDP, + expectRx: true, }, { - name: "IPv6 broadcast-like address binding to wildcard", - dstAddr: ipv6SubnetBcast, - expectRx: false, + name: "IPv6 broadcast-like address binding to wildcard", + dstAddr: ipv6SubnetBcast, + proto: header.IPv6ProtocolNumber, + remoteAddr: remoteIPv6Addr, + localAddr: ipv6Addr, + rxUDP: rxIPv6UDP, + expectRx: false, }, } @@ -385,52 +446,41 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } - ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr} - if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv4ProtoAddr, err) - } - ipv6ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: ipv6Addr} - if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv6ProtoAddr, err) - } - - var netproto tcpip.NetworkProtocolNumber - var rxUDP func(*channel.Endpoint, tcpip.Address) - switch l := len(test.dstAddr); l { - case header.IPv4AddressSize: - netproto = header.IPv4ProtocolNumber - rxUDP = rxIPv4UDP - case header.IPv6AddressSize: - netproto = header.IPv6ProtocolNumber - rxUDP = rxIPv6UDP - default: - t.Fatalf("got unexpected address length = %d bytes", l) + protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr} + if err := s.AddProtocolAddress(nicID, protoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err) } var wq waiter.Queue - ep, err := s.NewEndpoint(udp.ProtocolNumber, netproto, &wq) + ep, err := s.NewEndpoint(udp.ProtocolNumber, test.proto, &wq) if err != nil { - t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, netproto, err) + t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.proto, err) } defer ep.Close() bindAddr := tcpip.FullAddress{Addr: test.bindAddr, Port: localPort} if err := ep.Bind(bindAddr); err != nil { - t.Fatalf("ep.Bind(%+v): %s", bindAddr, err) + t.Fatalf("ep.Bind(%#v): %s", bindAddr, err) } - rxUDP(e, test.dstAddr) - if gotPayload, _, err := ep.Read(nil); test.expectRx { + test.rxUDP(e, test.remoteAddr, test.dstAddr, data) + var buf bytes.Buffer + var opts tcpip.ReadOptions + if res, err := ep.Read(&buf, len(data), opts); test.expectRx { if err != nil { - t.Fatalf("Read(nil): %s", err) + t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err) } - if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" { - t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) + if diff := cmp.Diff(tcpip.ReadResult{ + Count: buf.Len(), + Total: buf.Len(), + }, res, checker.IgnoreCmpPath("ControlMessages")); diff != "" { + t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) } - } else { - if err != tcpip.ErrWouldBlock { - t.Fatalf("got Read(nil) = (%x, _, %s), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock) + if diff := cmp.Diff(data, buf.Bytes()); diff != "" { + t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) } + } else if err != tcpip.ErrWouldBlock { + t.Fatalf("got Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), tcpip.ErrWouldBlock) } }) } @@ -476,11 +526,11 @@ func TestReuseAddrAndBroadcast(t *testing.T) { }, } if err := s.AddProtocolAddress(nicID, protoAddr); err != nil { - t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, protoAddr, err) + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err) } s.SetRouteTable([]tcpip.Route{ - tcpip.Route{ + { // We use the empty subnet instead of just the loopback subnet so we // also have a route to the IPv4 Broadcast address. Destination: header.IPv4EmptySubnet, @@ -510,23 +560,18 @@ func TestReuseAddrAndBroadcast(t *testing.T) { } defer ep.Close() - if err := ep.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { - t.Fatalf("eps[%d].SetSockOptBool(tcpip.ReuseAddressOption, true): %s", len(eps), err) - } - - if err := ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil { - t.Fatalf("eps[%d].SetSockOptBool(tcpip.BroadcastOption, true): %s", len(eps), err) - } + ep.SocketOptions().SetReuseAddress(true) + ep.SocketOptions().SetBroadcast(true) bindAddr := tcpip.FullAddress{Port: localPort} if bindWildcard { if err := ep.Bind(bindAddr); err != nil { - t.Fatalf("eps[%d].Bind(%+v): %s", len(eps), bindAddr, err) + t.Fatalf("eps[%d].Bind(%#v): %s", len(eps), bindAddr, err) } } else { bindAddr.Addr = test.broadcastAddr if err := ep.Bind(bindAddr); err != nil { - t.Fatalf("eps[%d].Bind(%+v): %s", len(eps), bindAddr, err) + t.Fatalf("eps[%d].Bind(%#v): %s", len(eps), bindAddr, err) } } @@ -552,9 +597,19 @@ func TestReuseAddrAndBroadcast(t *testing.T) { // Wait for the endpoint to become readable. <-rep.ch - if gotPayload, _, err := rep.ep.Read(nil); err != nil { - t.Errorf("(eps[%d] write) eps[%d].Read(nil): %s", i, j, err) - } else if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" { + var buf bytes.Buffer + result, err := rep.ep.Read(&buf, len(data), tcpip.ReadOptions{}) + if err != nil { + t.Errorf("(eps[%d] write) eps[%d].Read: %s", i, j, err) + continue + } + if diff := cmp.Diff(tcpip.ReadResult{ + Count: buf.Len(), + Total: buf.Len(), + }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" { + t.Errorf("(eps[%d] write) eps[%d].Read: unexpected result (-want +got):\n%s", i, j, diff) + } + if diff := cmp.Diff([]byte(data), buf.Bytes()); diff != "" { t.Errorf("(eps[%d] write) got UDP payload from eps[%d] mismatch (-want +got):\n%s", i, j, diff) } } @@ -562,3 +617,153 @@ func TestReuseAddrAndBroadcast(t *testing.T) { }) } } + +func TestUDPAddRemoveMembershipSocketOption(t *testing.T) { + const ( + nicID = 1 + ) + + data := []byte{1, 2, 3, 4} + + tests := []struct { + name string + proto tcpip.NetworkProtocolNumber + remoteAddr tcpip.Address + localAddr tcpip.AddressWithPrefix + rxUDP func(*channel.Endpoint, tcpip.Address, tcpip.Address, []byte) + multicastAddr tcpip.Address + }{ + { + name: "IPv4 unicast binding to unicast", + multicastAddr: "\xe0\x01\x02\x03", + proto: header.IPv4ProtocolNumber, + remoteAddr: remoteIPv4Addr, + localAddr: ipv4Addr, + rxUDP: rxIPv4UDP, + }, + { + name: "IPv6 broadcast-like address binding to wildcard", + multicastAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02\x03\x04", + proto: header.IPv6ProtocolNumber, + remoteAddr: remoteIPv6Addr, + localAddr: ipv6Addr, + rxUDP: rxIPv6UDP, + }, + } + + subTests := []struct { + name string + specifyNICID bool + specifyNICAddr bool + }{ + { + name: "Specify NIC ID and NIC address", + specifyNICID: true, + specifyNICAddr: true, + }, + { + name: "Don't specify NIC ID or NIC address", + specifyNICID: false, + specifyNICAddr: false, + }, + { + name: "Specify NIC ID but don't specify NIC address", + specifyNICID: true, + specifyNICAddr: false, + }, + { + name: "Don't specify NIC ID but specify NIC address", + specifyNICID: false, + specifyNICAddr: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, subTest := range subTests { + t.Run(subTest.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + }) + e := channel.New(0, defaultMTU, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + protoAddr := tcpip.ProtocolAddress{Protocol: test.proto, AddressWithPrefix: test.localAddr} + if err := s.AddProtocolAddress(nicID, protoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err) + } + + // Set the route table so that UDP can find a NIC that is + // routable to the multicast address when the NIC isn't specified. + if !subTest.specifyNICID && !subTest.specifyNICAddr { + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv6EmptySubnet, + NIC: nicID, + }, + { + Destination: header.IPv4EmptySubnet, + NIC: nicID, + }, + }) + } + + var wq waiter.Queue + ep, err := s.NewEndpoint(udp.ProtocolNumber, test.proto, &wq) + if err != nil { + t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.proto, err) + } + defer ep.Close() + + bindAddr := tcpip.FullAddress{Port: localPort} + if err := ep.Bind(bindAddr); err != nil { + t.Fatalf("ep.Bind(%#v): %s", bindAddr, err) + } + + memOpt := tcpip.MembershipOption{MulticastAddr: test.multicastAddr} + if subTest.specifyNICID { + memOpt.NIC = nicID + } + if subTest.specifyNICAddr { + memOpt.InterfaceAddr = test.localAddr.Address + } + + // We should receive UDP packets to the group once we join the + // multicast group. + addOpt := tcpip.AddMembershipOption(memOpt) + if err := ep.SetSockOpt(&addOpt); err != nil { + t.Fatalf("ep.SetSockOpt(&%#v): %s", addOpt, err) + } + test.rxUDP(e, test.remoteAddr, test.multicastAddr, data) + var buf bytes.Buffer + result, err := ep.Read(&buf, len(data), tcpip.ReadOptions{}) + if err != nil { + t.Fatalf("ep.Read: %s", err) + } else { + if diff := cmp.Diff(tcpip.ReadResult{ + Count: buf.Len(), + Total: buf.Len(), + }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" { + t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) + } + if diff := cmp.Diff(data, buf.Bytes()); diff != "" { + t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) + } + } + + // We should not receive UDP packets to the group once we leave + // the multicast group. + removeOpt := tcpip.RemoveMembershipOption(memOpt) + if err := ep.SetSockOpt(&removeOpt); err != nil { + t.Fatalf("ep.SetSockOpt(&%#v): %s", removeOpt, err) + } + if _, err := ep.Read(&buf, 1, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock { + t.Fatalf("got ep.Read = (_, %s), want = (_, %s)", err, tcpip.ErrWouldBlock) + } + }) + } + }) + } +} diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go index 02fc47015..52cf89b54 100644 --- a/pkg/tcpip/tests/integration/route_test.go +++ b/pkg/tcpip/tests/integration/route_test.go @@ -15,11 +15,14 @@ package integration_test import ( + "bytes" + "math" "testing" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" @@ -203,16 +206,25 @@ func TestLocalPing(t *testing.T) { // Wait for the endpoint to become readable. <-ch - var addr tcpip.FullAddress - v, _, err := ep.Read(&addr) + var buf bytes.Buffer + opts := tcpip.ReadOptions{NeedRemoteAddr: true} + res, err := ep.Read(&buf, math.MaxUint16, opts) if err != nil { - t.Fatalf("ep.Read(_): %s", err) + t.Fatalf("ep.Read(_, %d, %#v): %s", math.MaxUint16, opts, err) } - if diff := cmp.Diff(v[icmpDataOffset:], buffer.View(payload[icmpDataOffset:])); diff != "" { - t.Errorf("received data mismatch (-want +got):\n%s", diff) + if diff := cmp.Diff(tcpip.ReadResult{ + Count: buf.Len(), + Total: buf.Len(), + RemoteAddr: tcpip.FullAddress{Addr: test.localAddr}, + }, res, checker.IgnoreCmpPath( + "ControlMessages", + "RemoteAddr.NIC", + "RemoteAddr.Port", + )); diff != "" { + t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) } - if addr.Addr != test.localAddr { - t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.localAddr) + if diff := cmp.Diff(buf.Bytes()[icmpDataOffset:], []byte(payload[icmpDataOffset:])); diff != "" { + t.Errorf("received data mismatch (-want +got):\n%s", diff) } test.checkLinkEndpoint(t, e) @@ -338,14 +350,27 @@ func TestLocalUDP(t *testing.T) { <-serverCH var clientAddr tcpip.FullAddress - if v, _, err := server.Read(&clientAddr); err != nil { + var readBuf bytes.Buffer + if read, err := server.Read(&readBuf, math.MaxUint16, tcpip.ReadOptions{NeedRemoteAddr: true}); err != nil { t.Fatalf("server.Read(_): %s", err) } else { - if diff := cmp.Diff(buffer.View(clientPayload), v); diff != "" { - t.Errorf("server read clientPayload mismatch (-want +got):\n%s", diff) + clientAddr = read.RemoteAddr + + if diff := cmp.Diff(tcpip.ReadResult{ + Count: readBuf.Len(), + Total: readBuf.Len(), + RemoteAddr: tcpip.FullAddress{ + Addr: test.canBePrimaryAddr.AddressWithPrefix.Address, + }, + }, read, checker.IgnoreCmpPath( + "ControlMessages", + "RemoteAddr.NIC", + "RemoteAddr.Port", + )); diff != "" { + t.Errorf("server.Read: unexpected result (-want +got):\n%s", diff) } - if clientAddr.Addr != test.canBePrimaryAddr.AddressWithPrefix.Address { - t.Errorf("got clientAddr.Addr = %s, want = %s", clientAddr.Addr, test.canBePrimaryAddr.AddressWithPrefix.Address) + if diff := cmp.Diff(buffer.View(clientPayload), buffer.View(readBuf.Bytes())); diff != "" { + t.Errorf("server read clientPayload mismatch (-want +got):\n%s", diff) } if t.Failed() { t.FailNow() @@ -367,15 +392,23 @@ func TestLocalUDP(t *testing.T) { // Wait for the client endpoint to become readable. <-clientCH - var gotServerAddr tcpip.FullAddress - if v, _, err := client.Read(&gotServerAddr); err != nil { + readBuf.Reset() + if read, err := client.Read(&readBuf, math.MaxUint16, tcpip.ReadOptions{NeedRemoteAddr: true}); err != nil { t.Fatalf("client.Read(_): %s", err) } else { - if diff := cmp.Diff(buffer.View(serverPayload), v); diff != "" { - t.Errorf("client read serverPayload mismatch (-want +got):\n%s", diff) + if diff := cmp.Diff(tcpip.ReadResult{ + Count: readBuf.Len(), + Total: readBuf.Len(), + RemoteAddr: tcpip.FullAddress{Addr: serverAddr.Addr}, + }, read, checker.IgnoreCmpPath( + "ControlMessages", + "RemoteAddr.NIC", + "RemoteAddr.Port", + )); diff != "" { + t.Errorf("client.Read: unexpected result (-want +got):\n%s", diff) } - if gotServerAddr.Addr != serverAddr.Addr { - t.Errorf("got gotServerAddr.Addr = %s, want = %s", gotServerAddr.Addr, serverAddr.Addr) + if diff := cmp.Diff(buffer.View(serverPayload), buffer.View(readBuf.Bytes())); diff != "" { + t.Errorf("client read serverPayload mismatch (-want +got):\n%s", diff) } if t.Failed() { t.FailNow() diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 763cd8f84..c32fe5c4f 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -15,6 +15,8 @@ package icmp import ( + "io" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -49,6 +51,7 @@ const ( // +stateify savable type endpoint struct { stack.TransportEndpointInfo + tcpip.DefaultSocketOptionsHandler // The following fields are initialized at creation time and are // immutable. @@ -71,18 +74,19 @@ type endpoint struct { // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags state endpointState - route stack.Route `state:"manual"` + route *stack.Route `state:"manual"` ttl uint8 stats tcpip.TransportEndpointStats `state:"nosave"` - // linger is used for SO_LINGER socket option. - linger tcpip.LingerOption // owner is used to get uid and gid of the packet. owner tcpip.PacketOwner + + // ops is used to get socket level options. + ops tcpip.SocketOptions } func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return &endpoint{ + ep := &endpoint{ stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{ NetProto: netProto, @@ -93,7 +97,9 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt sndBufSize: 32 * 1024, state: stateInitial, uniqueID: s.UniqueID(), - }, nil + } + ep.ops.InitHandler(ep) + return ep, nil } // UniqueID implements stack.TransportEndpoint.UniqueID. @@ -126,7 +132,10 @@ func (e *endpoint) Close() { } e.rcvMu.Unlock() - e.route.Release() + if e.route != nil { + e.route.Release() + e.route = nil + } // Update the state. e.state = stateClosed @@ -139,13 +148,13 @@ func (e *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (e *endpoint) ModerateRecvBuf(copied int) {} +// SetOwner implements tcpip.Endpoint.SetOwner. func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } -// Read reads data from the endpoint. This method does not block if -// there is no data pending. -func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { +// Read implements tcpip.Endpoint.Read. +func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { e.rcvMu.Lock() if e.rcvList.Empty() { @@ -155,20 +164,34 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess err = tcpip.ErrClosedForReceive } e.rcvMu.Unlock() - return buffer.View{}, tcpip.ControlMessages{}, err + return tcpip.ReadResult{}, err } p := e.rcvList.Front() - e.rcvList.Remove(p) - e.rcvBufSize -= p.data.Size() + if !opts.Peek { + e.rcvList.Remove(p) + e.rcvBufSize -= p.data.Size() + } e.rcvMu.Unlock() - if addr != nil { - *addr = p.senderAddress + res := tcpip.ReadResult{ + Total: p.data.Size(), + ControlMessages: tcpip.ControlMessages{ + HasTimestamp: true, + Timestamp: p.timestamp, + }, + } + if opts.NeedRemoteAddr { + res.RemoteAddr = p.senderAddress } - return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: p.timestamp}, nil + n, err := p.data.ReadTo(dst, count, opts.Peek) + if n == 0 && err != nil { + return res, tcpip.ErrBadBuffer + } + res.Count = n + return res, nil } // prepareForWrite prepares the endpoint for sending data. In particular, it @@ -264,26 +287,8 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } } - var route *stack.Route - if to == nil { - route = &e.route - - if route.IsResolutionRequired() { - // Promote lock to exclusive if using a shared route, - // given that it may need to change in Route.Resolve() - // call below. - e.mu.RUnlock() - defer e.mu.RLock() - - e.mu.Lock() - defer e.mu.Unlock() - - // Recheck state after lock was re-acquired. - if e.state != stateConnected { - return 0, nil, tcpip.ErrInvalidEndpointState - } - } - } else { + route := e.route + if to != nil { // Reject destination address if it goes through a different // NIC than the endpoint was bound to. nicID := to.NIC @@ -307,7 +312,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } defer r.Release() - route = &r + route = r } if route.IsResolutionRequired() { @@ -339,27 +344,8 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return int64(len(v)), nil, nil } -// Peek only returns data from a single datagram, so do nothing here. -func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { - return 0, tcpip.ControlMessages{}, nil -} - // SetSockOpt sets a socket option. func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { - switch v := opt.(type) { - case *tcpip.SocketDetachFilterOption: - return nil - - case *tcpip.LingerOption: - e.mu.Lock() - e.linger = *v - e.mu.Unlock() - } - return nil -} - -// SetSockOptBool sets a socket option. Currently not supported. -func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { return nil } @@ -375,17 +361,6 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil } -// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. -func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - switch opt { - case tcpip.KeepaliveEnabledOption, tcpip.AcceptConnOption: - return false, nil - - default: - return false, tcpip.ErrUnknownProtocolOption - } -} - // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { @@ -423,16 +398,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { - switch o := opt.(type) { - case *tcpip.LingerOption: - e.mu.Lock() - *o = e.linger - e.mu.Unlock() - return nil - - default: - return tcpip.ErrUnknownProtocolOption - } + return tcpip.ErrUnknownProtocolOption } func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) *tcpip.Error { @@ -548,7 +514,6 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { if err != nil { return err } - defer r.Release() id := stack.TransportEndpointID{ LocalAddress: r.LocalAddress, @@ -563,11 +528,12 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { id, err = e.registerWithStack(nicID, netProtos, id) if err != nil { + r.Release() return err } e.ID = id - e.route = r.Clone() + e.route = r e.RegisterNICID = nicID e.state = stateConnected @@ -823,7 +789,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { +func (e *endpoint) HandleControlPacket(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { } // State implements tcpip.Endpoint.State. The ICMP endpoint currently doesn't @@ -853,3 +819,8 @@ func (*endpoint) Wait() {} func (*endpoint) LastError() *tcpip.Error { return nil } + +// SocketOptions implements tcpip.Endpoint.SocketOptions. +func (e *endpoint) SocketOptions() *tcpip.SocketOptions { + return &e.ops +} diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 31831a6d8..3ab060751 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -26,6 +26,7 @@ package packet import ( "fmt" + "io" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -60,6 +61,8 @@ type packet struct { // +stateify savable type endpoint struct { stack.TransportEndpointInfo + tcpip.DefaultSocketOptionsHandler + // The following fields are initialized at creation time and are // immutable. stack *stack.Stack `state:"manual"` @@ -83,12 +86,13 @@ type endpoint struct { stats tcpip.TransportEndpointStats `state:"nosave"` bound bool boundNIC tcpip.NICID - // linger is used for SO_LINGER socket option. - linger tcpip.LingerOption // lastErrorMu protects lastError. lastErrorMu sync.Mutex `state:"nosave"` lastError *tcpip.Error `state:".(string)"` + + // ops is used to get socket level options. + ops tcpip.SocketOptions } // NewEndpoint returns a new packet endpoint. @@ -104,6 +108,7 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb rcvBufSizeMax: 32 * 1024, sndBufSize: 32 * 1024, } + ep.ops.InitHandler(ep) // Override with stack defaults. var ss stack.SendBufferSizeOption @@ -156,8 +161,8 @@ func (ep *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (ep *endpoint) ModerateRecvBuf(copied int) {} -// Read implements tcpip.PacketEndpoint.ReadPacket. -func (ep *endpoint) ReadPacket(addr *tcpip.FullAddress, info *tcpip.LinkPacketInfo) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { +// Read implements tcpip.Endpoint.Read. +func (ep *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { ep.rcvMu.Lock() // If there's no data to read, return that read would block or that the @@ -169,29 +174,37 @@ func (ep *endpoint) ReadPacket(addr *tcpip.FullAddress, info *tcpip.LinkPacketIn err = tcpip.ErrClosedForReceive } ep.rcvMu.Unlock() - return buffer.View{}, tcpip.ControlMessages{}, err + return tcpip.ReadResult{}, err } packet := ep.rcvList.Front() - ep.rcvList.Remove(packet) - ep.rcvBufSize -= packet.data.Size() + if !opts.Peek { + ep.rcvList.Remove(packet) + ep.rcvBufSize -= packet.data.Size() + } ep.rcvMu.Unlock() - if addr != nil { - *addr = packet.senderAddr + res := tcpip.ReadResult{ + Total: packet.data.Size(), + ControlMessages: tcpip.ControlMessages{ + HasTimestamp: true, + Timestamp: packet.timestampNS, + }, } - - if info != nil { - *info = packet.packetInfo + if opts.NeedRemoteAddr { + res.RemoteAddr = packet.senderAddr + } + if opts.NeedLinkPacketInfo { + res.LinkPacketInfo = packet.packetInfo } - return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil -} - -// Read implements tcpip.Endpoint.Read. -func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { - return ep.ReadPacket(addr, nil) + n, err := packet.data.ReadTo(dst, count, opts.Peek) + if n == 0 && err != nil { + return res, tcpip.ErrBadBuffer + } + res.Count = n + return res, nil } func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { @@ -199,11 +212,6 @@ func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-cha return 0, nil, tcpip.ErrInvalidOptionValue } -// Peek implements tcpip.Endpoint.Peek. -func (*endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { - return 0, tcpip.ControlMessages{}, nil -} - // Disconnect implements tcpip.Endpoint.Disconnect. Packet sockets cannot be // disconnected, and this function always returns tpcip.ErrNotSupported. func (*endpoint) Disconnect() *tcpip.Error { @@ -300,26 +308,15 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // used with SetSockOpt, and this function always returns // tcpip.ErrNotSupported. func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { - switch v := opt.(type) { + switch opt.(type) { case *tcpip.SocketDetachFilterOption: return nil - case *tcpip.LingerOption: - ep.mu.Lock() - ep.linger = *v - ep.mu.Unlock() - return nil - default: return tcpip.ErrUnknownProtocolOption } } -// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. -func (ep *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption -} - // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { switch opt { @@ -373,28 +370,16 @@ func (ep *endpoint) LastError() *tcpip.Error { return err } -// GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { - switch o := opt.(type) { - case *tcpip.LingerOption: - ep.mu.Lock() - *o = ep.linger - ep.mu.Unlock() - return nil - - default: - return tcpip.ErrNotSupported - } +// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError. +func (ep *endpoint) UpdateLastError(err *tcpip.Error) { + ep.lastErrorMu.Lock() + ep.lastError = err + ep.lastErrorMu.Unlock() } -// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. -func (*endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - switch opt { - case tcpip.AcceptConnOption: - return false, nil - default: - return false, tcpip.ErrNotSupported - } +// GetSockOpt implements tcpip.Endpoint.GetSockOpt. +func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { + return tcpip.ErrNotSupported } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. @@ -548,4 +533,10 @@ func (ep *endpoint) Stats() tcpip.EndpointStats { return &ep.stats } +// SetOwner implements tcpip.Endpoint.SetOwner. func (ep *endpoint) SetOwner(owner tcpip.PacketOwner) {} + +// SocketOptions implements tcpip.Endpoint.SocketOptions. +func (ep *endpoint) SocketOptions() *tcpip.SocketOptions { + return &ep.ops +} diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 7b6a87ba9..dd260535f 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -27,6 +27,7 @@ package raw import ( "fmt" + "io" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -58,12 +59,13 @@ type rawPacket struct { // +stateify savable type endpoint struct { stack.TransportEndpointInfo + tcpip.DefaultSocketOptionsHandler + // The following fields are initialized at creation time and are // immutable. stack *stack.Stack `state:"manual"` waiterQueue *waiter.Queue associated bool - hdrIncluded bool // The following fields are used to manage the receive queue and are // protected by rcvMu. @@ -82,13 +84,14 @@ type endpoint struct { bound bool // route is the route to a remote network endpoint. It is set via // Connect(), and is valid only when conneted is true. - route stack.Route `state:"manual"` + route *stack.Route `state:"manual"` stats tcpip.TransportEndpointStats `state:"nosave"` - // linger is used for SO_LINGER socket option. - linger tcpip.LingerOption // owner is used to get uid and gid of the packet. owner tcpip.PacketOwner + + // ops is used to get socket level options. + ops tcpip.SocketOptions } // NewEndpoint returns a raw endpoint for the given protocols. @@ -111,8 +114,9 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt rcvBufSizeMax: 32 * 1024, sndBufSizeMax: 32 * 1024, associated: associated, - hdrIncluded: !associated, } + e.ops.InitHandler(e) + e.ops.SetHeaderIncluded(!associated) // Override with stack defaults. var ss stack.SendBufferSizeOption @@ -167,9 +171,11 @@ func (e *endpoint) Close() { e.rcvList.Remove(e.rcvList.Front()) } - if e.connected { + e.connected = false + + if e.route != nil { e.route.Release() - e.connected = false + e.route = nil } e.closed = true @@ -185,7 +191,7 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { } // Read implements tcpip.Endpoint.Read. -func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { +func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { e.rcvMu.Lock() // If there's no data to read, return that read would block or that the @@ -197,20 +203,34 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess err = tcpip.ErrClosedForReceive } e.rcvMu.Unlock() - return buffer.View{}, tcpip.ControlMessages{}, err + return tcpip.ReadResult{}, err } pkt := e.rcvList.Front() - e.rcvList.Remove(pkt) - e.rcvBufSize -= pkt.data.Size() + if !opts.Peek { + e.rcvList.Remove(pkt) + e.rcvBufSize -= pkt.data.Size() + } e.rcvMu.Unlock() - if addr != nil { - *addr = pkt.senderAddr + res := tcpip.ReadResult{ + Total: pkt.data.Size(), + ControlMessages: tcpip.ControlMessages{ + HasTimestamp: true, + Timestamp: pkt.timestampNS, + }, + } + if opts.NeedRemoteAddr { + res.RemoteAddr = pkt.senderAddr } - return pkt.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: pkt.timestampNS}, nil + n, err := pkt.data.ReadTo(dst, count, opts.Peek) + if n == 0 && err != nil { + return res, tcpip.ErrBadBuffer + } + res.Count = n + return res, nil } // Write implements tcpip.Endpoint.Write. @@ -220,6 +240,13 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return 0, nil, tcpip.ErrInvalidOptionValue } + if opts.To != nil { + // Raw sockets do not support sending to a IPv4 address on a IPv6 endpoint. + if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(opts.To.Addr) != header.IPv6AddressSize { + return 0, nil, tcpip.ErrInvalidOptionValue + } + } + n, ch, err := e.write(p, opts) switch err { case nil: @@ -249,24 +276,22 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } e.mu.RLock() + defer e.mu.RUnlock() if e.closed { - e.mu.RUnlock() return 0, nil, tcpip.ErrInvalidEndpointState } payloadBytes, err := p.FullPayload() if err != nil { - e.mu.RUnlock() return 0, nil, err } // If this is an unassociated socket and callee provided a nonzero // destination address, route using that address. - if e.hdrIncluded { + if e.ops.GetHeaderIncluded() { ip := header.IPv4(payloadBytes) if !ip.IsValid(len(payloadBytes)) { - e.mu.RUnlock() return 0, nil, tcpip.ErrInvalidOptionValue } dstAddr := ip.DestinationAddress() @@ -288,39 +313,16 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // If the user doesn't specify a destination, they should have // connected to another address. if !e.connected { - e.mu.RUnlock() return 0, nil, tcpip.ErrDestinationRequired } - if e.route.IsResolutionRequired() { - savedRoute := &e.route - // Promote lock to exclusive if using a shared route, - // given that it may need to change in finishWrite. - e.mu.RUnlock() - e.mu.Lock() - - // Make sure that the route didn't change during the - // time we didn't hold the lock. - if !e.connected || savedRoute != &e.route { - e.mu.Unlock() - return 0, nil, tcpip.ErrInvalidEndpointState - } - - n, ch, err := e.finishWrite(payloadBytes, savedRoute) - e.mu.Unlock() - return n, ch, err - } - - n, ch, err := e.finishWrite(payloadBytes, &e.route) - e.mu.RUnlock() - return n, ch, err + return e.finishWrite(payloadBytes, e.route) } // The caller provided a destination. Reject destination address if it // goes through a different NIC than the endpoint was bound to. nic := opts.To.NIC if e.bound && nic != 0 && nic != e.BindNICID { - e.mu.RUnlock() return 0, nil, tcpip.ErrNoRoute } @@ -328,13 +330,11 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // FindRoute will choose an appropriate source address. route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false) if err != nil { - e.mu.RUnlock() return 0, nil, err } - n, ch, err := e.finishWrite(payloadBytes, &route) + n, ch, err := e.finishWrite(payloadBytes, route) route.Release() - e.mu.RUnlock() return n, ch, err } @@ -353,7 +353,7 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, } } - if e.hdrIncluded { + if e.ops.GetHeaderIncluded() { pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.View(payloadBytes).ToVectorisedView(), }) @@ -378,11 +378,6 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, return int64(len(payloadBytes)), nil, nil } -// Peek implements tcpip.Endpoint.Peek. -func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { - return 0, tcpip.ControlMessages{}, nil -} - // Disconnect implements tcpip.Endpoint.Disconnect. func (*endpoint) Disconnect() *tcpip.Error { return tcpip.ErrNotSupported @@ -390,6 +385,11 @@ func (*endpoint) Disconnect() *tcpip.Error { // Connect implements tcpip.Endpoint.Connect. func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { + // Raw sockets do not support connecting to a IPv4 address on a IPv6 endpoint. + if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(addr.Addr) != header.IPv6AddressSize { + return tcpip.ErrAddressFamilyNotSupported + } + e.mu.Lock() defer e.mu.Unlock() @@ -418,11 +418,11 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { if err != nil { return err } - defer route.Release() if e.associated { // Re-register the endpoint with the appropriate NIC. if err := e.stack.RegisterRawTransportEndpoint(addr.NIC, e.NetProto, e.TransProto, e); err != nil { + route.Release() return err } e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e) @@ -430,7 +430,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } // Save the route we've connected via. - e.route = route.Clone() + e.route = route e.connected = true return nil @@ -513,33 +513,15 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // SetSockOpt implements tcpip.Endpoint.SetSockOpt. func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { - switch v := opt.(type) { + switch opt.(type) { case *tcpip.SocketDetachFilterOption: return nil - case *tcpip.LingerOption: - e.mu.Lock() - e.linger = *v - e.mu.Unlock() - return nil - default: return tcpip.ErrUnknownProtocolOption } } -// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. -func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { - switch opt { - case tcpip.IPHdrIncludedOption: - e.mu.Lock() - e.hdrIncluded = v - e.mu.Unlock() - return nil - } - return tcpip.ErrUnknownProtocolOption -} - // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { switch opt { @@ -586,33 +568,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { - switch o := opt.(type) { - case *tcpip.LingerOption: - e.mu.Lock() - *o = e.linger - e.mu.Unlock() - return nil - - default: - return tcpip.ErrUnknownProtocolOption - } -} - -// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. -func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - switch opt { - case tcpip.KeepaliveEnabledOption, tcpip.AcceptConnOption: - return false, nil - - case tcpip.IPHdrIncludedOption: - e.mu.Lock() - v := e.hdrIncluded - e.mu.Unlock() - return v, nil - - default: - return false, tcpip.ErrUnknownProtocolOption - } + return tcpip.ErrUnknownProtocolOption } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. @@ -647,6 +603,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { // HandlePacket implements stack.RawTransportEndpoint.HandlePacket. func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { + e.mu.RLock() e.rcvMu.Lock() // Drop the packet if our buffer is currently full or if this is an unassociated @@ -659,6 +616,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // sockets. if e.rcvClosed || !e.associated { e.rcvMu.Unlock() + e.mu.RUnlock() e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.ClosedReceiver.Increment() return @@ -666,6 +624,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { if e.rcvBufSize >= e.rcvBufSizeMax { e.rcvMu.Unlock() + e.mu.RUnlock() e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() return @@ -677,11 +636,13 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // If bound to a NIC, only accept data for that NIC. if e.BindNICID != 0 && e.BindNICID != pkt.NICID { e.rcvMu.Unlock() + e.mu.RUnlock() return } // If bound to an address, only accept data for that address. if e.BindAddr != "" && e.BindAddr != remoteAddr { e.rcvMu.Unlock() + e.mu.RUnlock() return } } @@ -690,6 +651,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // connected to. if e.connected && e.route.RemoteAddress != remoteAddr { e.rcvMu.Unlock() + e.mu.RUnlock() return } @@ -724,6 +686,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { e.rcvList.PushBack(packet) e.rcvBufSize += packet.data.Size() e.rcvMu.Unlock() + e.mu.RUnlock() e.stats.PacketsReceived.Increment() // Notify waiters that there's data to be read. if wasEmpty { @@ -753,6 +716,12 @@ func (e *endpoint) Stats() tcpip.EndpointStats { // Wait implements stack.TransportEndpoint.Wait. func (*endpoint) Wait() {} +// LastError implements tcpip.Endpoint.LastError. func (*endpoint) LastError() *tcpip.Error { return nil } + +// SocketOptions implements tcpip.Endpoint.SocketOptions. +func (e *endpoint) SocketOptions() *tcpip.SocketOptions { + return &e.ops +} diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go index 7d97cbdc7..4a7e1c039 100644 --- a/pkg/tcpip/transport/raw/endpoint_state.go +++ b/pkg/tcpip/transport/raw/endpoint_state.go @@ -73,7 +73,13 @@ func (e *endpoint) Resume(s *stack.Stack) { // If the endpoint is connected, re-connect. if e.connected { var err *tcpip.Error - e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, e.route.RemoteAddress, e.NetProto, false) + // TODO(gvisor.dev/issue/4906): Properly restore the route with the right + // remote address. We used to pass e.remote.RemoteAddress which was + // effectively the empty address but since moving e.route to hold a pointer + // to a route instead of the route by value, we pass the empty address + // directly. Obviously this was always wrong since we should provide the + // remote address we were connected to, to properly restore the route. + e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, "", e.NetProto, false) if err != nil { panic(err) } diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 518449602..7e81203ba 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_library", "go_test", "more_shards") load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) @@ -45,7 +45,9 @@ go_library( "rcv.go", "rcv_state.go", "reno.go", + "reno_recovery.go", "sack.go", + "sack_recovery.go", "sack_scoreboard.go", "segment.go", "segment_heap.go", @@ -91,7 +93,7 @@ go_test( "tcp_test.go", "tcp_timestamp_test.go", ], - shard_count = 10, + shard_count = more_shards, deps = [ ":tcp", "//pkg/rand", @@ -110,6 +112,7 @@ go_test( "//pkg/tcpip/transport/tcp/testing/context", "//pkg/test/testutil", "//pkg/waiter", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 47982ca41..2d96a65bd 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -213,7 +213,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i route.ResolveWith(s.remoteLinkAddr) n := newEndpoint(l.stack, netProto, queue) - n.v6only = l.v6Only + n.ops.SetV6Only(l.v6Only) n.ID = s.id n.boundNICID = s.nicID n.route = route @@ -235,11 +235,15 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i return n, nil } -// createEndpointAndPerformHandshake creates a new endpoint in connected state -// and then performs the TCP 3-way handshake. +// startHandshake creates a new endpoint in connecting state and then sends +// the SYN-ACK for the TCP 3-way handshake. It returns the state of the +// handshake in progress, which includes the new endpoint in the SYN-RCVD +// state. // -// The new endpoint is returned with e.mu held. -func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, *tcpip.Error) { +// On success, a handshake h is returned with h.ep.mu held. +// +// Precondition: if l.listenEP != nil, l.listenEP.mu must be locked. +func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*handshake, *tcpip.Error) { // Create new endpoint. irs := s.sequenceNumber isn := generateSecureISN(s.id, l.stack.Seed()) @@ -257,10 +261,8 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head // listenEP is nil when listenContext is used by tcp.Forwarder. deferAccept := time.Duration(0) if l.listenEP != nil { - l.listenEP.mu.Lock() if l.listenEP.EndpointState() != StateListen { - l.listenEP.mu.Unlock() // Ensure we release any registrations done by the newly // created endpoint. ep.mu.Unlock() @@ -278,16 +280,12 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head ep.mu.Unlock() ep.Close() - if l.listenEP != nil { - l.removePendingEndpoint(ep) - l.listenEP.mu.Unlock() - } + l.removePendingEndpoint(ep) return nil, tcpip.ErrConnectionAborted } deferAccept = l.listenEP.deferAccept - l.listenEP.mu.Unlock() } // Register new endpoint so that packets are routed to it. @@ -306,28 +304,33 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head ep.isRegistered = true - // Perform the 3-way handshake. - h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept) - if err := h.execute(); err != nil { - ep.mu.Unlock() - ep.Close() - ep.notifyAborted() - - if l.listenEP != nil { - l.removePendingEndpoint(ep) - } - - ep.drainClosingSegmentQueue() - + // Initialize and start the handshake. + h := ep.newPassiveHandshake(isn, irs, opts, deferAccept) + if err := h.start(); err != nil { + l.cleanupFailedHandshake(h) return nil, err } - ep.isConnectNotified = true + return h, nil +} - // Update the receive window scaling. We can't do it before the - // handshake because it's possible that the peer doesn't support window - // scaling. - ep.rcv.rcvWndScale = h.effectiveRcvWndScale() +// performHandshake performs a TCP 3-way handshake. On success, the new +// established endpoint is returned with e.mu held. +// +// Precondition: if l.listenEP != nil, l.listenEP.mu must be locked. +func (l *listenContext) performHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, *tcpip.Error) { + h, err := l.startHandshake(s, opts, queue, owner) + if err != nil { + return nil, err + } + ep := h.ep + if err := h.complete(); err != nil { + ep.stack.Stats().TCP.FailedConnectionAttempts.Increment() + ep.stats.FailedConnectionAttempts.Increment() + l.cleanupFailedHandshake(h) + return nil, err + } + l.cleanupCompletedHandshake(h) return ep, nil } @@ -354,6 +357,39 @@ func (l *listenContext) closeAllPendingEndpoints() { l.pending.Wait() } +// Precondition: h.ep.mu must be held. +func (l *listenContext) cleanupFailedHandshake(h *handshake) { + e := h.ep + e.mu.Unlock() + e.Close() + e.notifyAborted() + if l.listenEP != nil { + l.removePendingEndpoint(e) + } + e.drainClosingSegmentQueue() + e.h = nil +} + +// cleanupCompletedHandshake transfers any state from the completed handshake to +// the new endpoint. +// +// Precondition: h.ep.mu must be held. +func (l *listenContext) cleanupCompletedHandshake(h *handshake) { + e := h.ep + if l.listenEP != nil { + l.removePendingEndpoint(e) + } + e.isConnectNotified = true + + // Update the receive window scaling. We can't do it before the + // handshake because it's possible that the peer doesn't support window + // scaling. + e.rcv.rcvWndScale = e.h.effectiveRcvWndScale() + + // Clean up handshake state stored in the endpoint so that it can be GCed. + e.h = nil +} + // deliverAccepted delivers the newly-accepted endpoint to the listener. If the // endpoint has transitioned out of the listen state (acceptedChan is nil), // the new endpoint is closed instead. @@ -433,23 +469,40 @@ func (e *endpoint) notifyAborted() { // // A limited number of these goroutines are allowed before TCP starts using SYN // cookies to accept connections. -func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) { - defer ctx.synRcvdCount.dec() +// +// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked. +func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) *tcpip.Error { defer s.decRef() - n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{}, e.owner) + h, err := ctx.startHandshake(s, opts, &waiter.Queue{}, e.owner) if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() - e.decSynRcvdCount() - return + e.synRcvdCount-- + return err } - ctx.removePendingEndpoint(n) - e.decSynRcvdCount() - n.startAcceptedLoop() - e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() - e.deliverAccepted(n) + go func() { + defer ctx.synRcvdCount.dec() + if err := h.complete(); err != nil { + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() + ctx.cleanupFailedHandshake(h) + e.mu.Lock() + e.synRcvdCount-- + e.mu.Unlock() + return + } + ctx.cleanupCompletedHandshake(h) + e.mu.Lock() + e.synRcvdCount-- + e.mu.Unlock() + h.ep.startAcceptedLoop() + e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() + e.deliverAccepted(h.ep) + }() // S/R-SAFE: synRcvdCount is the barrier. + + return nil } func (e *endpoint) incSynRcvdCount() bool { @@ -462,12 +515,6 @@ func (e *endpoint) incSynRcvdCount() bool { return canInc } -func (e *endpoint) decSynRcvdCount() { - e.mu.Lock() - e.synRcvdCount-- - e.mu.Unlock() -} - func (e *endpoint) acceptQueueIsFull() bool { e.acceptMu.Lock() full := len(e.acceptedChan)+e.synRcvdCount >= cap(e.acceptedChan) @@ -477,6 +524,8 @@ func (e *endpoint) acceptQueueIsFull() bool { // handleListenSegment is called when a listening endpoint receives a segment // and needs to handle it. +// +// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked. func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Error { e.rcvListMu.Lock() rcvClosed := e.rcvClosed @@ -500,7 +549,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Er // backlog. if !e.acceptQueueIsFull() && e.incSynRcvdCount() { s.incRef() - go e.handleSynSegment(ctx, s, &opts) // S/R-SAFE: synRcvdCount is the barrier. + _ = e.handleSynSegment(ctx, s, &opts) return nil } ctx.synRcvdCount.dec() @@ -550,7 +599,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Er ack: s.sequenceNumber + 1, rcvWnd: ctx.rcvWnd, } - if err := e.sendSynTCP(&route, fields, synOpts); err != nil { + if err := e.sendSynTCP(route, fields, synOpts); err != nil { return err } e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment() @@ -703,7 +752,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Er // its own goroutine and is responsible for handling connection requests. func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) { e.mu.Lock() - v6Only := e.v6only + v6Only := e.ops.GetV6Only() ctx := newListenContext(e.stack, e, rcvWnd, v6Only, e.NetProto) defer func() { @@ -712,7 +761,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) { // to the endpoint. e.setEndpointState(StateClose) - // close any endpoints in SYN-RCVD state. + // Close any endpoints in SYN-RCVD state. ctx.closeAllPendingEndpoints() // Do cleanup if needed. @@ -729,7 +778,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) { e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr) }() - s := sleep.Sleeper{} + var s sleep.Sleeper s.AddWaker(&e.notificationWaker, wakerForNotification) s.AddWaker(&e.newSegmentWaker, wakerForNewSegment) for { diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 6e9015be1..a00ef97c6 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -16,6 +16,7 @@ package tcp import ( "encoding/binary" + "math" "time" "gvisor.dev/gvisor/pkg/rand" @@ -102,21 +103,26 @@ type handshake struct { // been received. This is required to stop retransmitting the // original SYN-ACK when deferAccept is enabled. acked bool + + // sendSYNOpts is the cached values for the SYN options to be sent. + sendSYNOpts header.TCPSynOptions } -func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake { - h := handshake{ - ep: ep, +func (e *endpoint) newHandshake() *handshake { + h := &handshake{ + ep: e, active: true, - rcvWnd: rcvWnd, - rcvWndScale: ep.rcvWndScaleForHandshake(), + rcvWnd: seqnum.Size(e.initialReceiveWindow()), + rcvWndScale: e.rcvWndScaleForHandshake(), } h.resetState() + // Store reference to handshake state in endpoint. + e.h = h return h } -func newPassiveHandshake(ep *endpoint, rcvWnd seqnum.Size, isn, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) handshake { - h := newHandshake(ep, rcvWnd) +func (e *endpoint) newPassiveHandshake(isn, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) *handshake { + h := e.newHandshake() h.resetToSynRcvd(isn, irs, opts, deferAccept) return h } @@ -128,7 +134,7 @@ func FindWndScale(wnd seqnum.Size) int { return 0 } - max := seqnum.Size(0xffff) + max := seqnum.Size(math.MaxUint16) s := 0 for wnd > max && s < header.MaxWndScale { s++ @@ -295,7 +301,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { if ttl == 0 { ttl = h.ep.route.DefaultTTL() } - h.ep.sendSynTCP(&h.ep.route, tcpFields{ + h.ep.sendSynTCP(h.ep.route, tcpFields{ id: h.ep.ID, ttl: ttl, tos: h.ep.sendTOS, @@ -356,7 +362,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { SACKPermitted: h.ep.sackPermitted, MSS: h.ep.amss, } - h.ep.sendSynTCP(&h.ep.route, tcpFields{ + h.ep.sendSynTCP(h.ep.route, tcpFields{ id: h.ep.ID, ttl: h.ep.ttl, tos: h.ep.sendTOS, @@ -456,7 +462,7 @@ func (h *handshake) processSegments() *tcpip.Error { func (h *handshake) resolveRoute() *tcpip.Error { // Set up the wakers. - s := sleep.Sleeper{} + var s sleep.Sleeper resolutionWaker := &sleep.Waker{} s.AddWaker(resolutionWaker, wakerForResolution) s.AddWaker(&h.ep.notificationWaker, wakerForNotification) @@ -464,24 +470,27 @@ func (h *handshake) resolveRoute() *tcpip.Error { // Initial action is to resolve route. index := wakerForResolution + attemptedResolution := false for { switch index { case wakerForResolution: - if _, err := h.ep.route.Resolve(resolutionWaker); err != tcpip.ErrWouldBlock { - if err == tcpip.ErrNoLinkAddress { - h.ep.stats.SendErrors.NoLinkAddr.Increment() - } else if err != nil { + if _, err := h.ep.route.Resolve(resolutionWaker.Assert); err != tcpip.ErrWouldBlock { + if err != nil { h.ep.stats.SendErrors.NoRoute.Increment() } // Either success (err == nil) or failure. return err } + if attemptedResolution { + h.ep.stats.SendErrors.NoLinkAddr.Increment() + return tcpip.ErrNoLinkAddress + } + attemptedResolution = true // Resolution not completed. Keep trying... case wakerForNotification: n := h.ep.fetchNotifications() if n¬ifyClose != 0 { - h.ep.route.RemoveWaker(resolutionWaker) return tcpip.ErrAborted } if n¬ifyDrain != 0 { @@ -491,7 +500,7 @@ func (h *handshake) resolveRoute() *tcpip.Error { h.ep.mu.Lock() } if n¬ifyError != 0 { - return h.ep.LastError() + return h.ep.lastErrorLocked() } } @@ -502,8 +511,9 @@ func (h *handshake) resolveRoute() *tcpip.Error { } } -// execute executes the TCP 3-way handshake. -func (h *handshake) execute() *tcpip.Error { +// start resolves the route if necessary and sends the first +// SYN/SYN-ACK. +func (h *handshake) start() *tcpip.Error { if h.ep.route.IsResolutionRequired() { if err := h.resolveRoute(); err != nil { return err @@ -511,19 +521,7 @@ func (h *handshake) execute() *tcpip.Error { } h.startTime = time.Now() - // Initialize the resend timer. - resendWaker := sleep.Waker{} - timeOut := time.Duration(time.Second) - rt := time.AfterFunc(timeOut, resendWaker.Assert) - defer rt.Stop() - - // Set up the wakers. - s := sleep.Sleeper{} - s.AddWaker(&resendWaker, wakerForResend) - s.AddWaker(&h.ep.notificationWaker, wakerForNotification) - s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment) - defer s.Done() - + h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route) var sackEnabled tcpip.TCPSACKEnabled if err := h.ep.stack.TransportProtocolOption(ProtocolNumber, &sackEnabled); err != nil { // If stack returned an error when checking for SACKEnabled @@ -531,10 +529,6 @@ func (h *handshake) execute() *tcpip.Error { sackEnabled = false } - // Send the initial SYN segment and loop until the handshake is - // completed. - h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route) - synOpts := header.TCPSynOptions{ WS: h.rcvWndScale, TS: true, @@ -544,9 +538,8 @@ func (h *handshake) execute() *tcpip.Error { MSS: h.ep.amss, } - // Execute is also called in a listen context so we want to make sure we - // only send the TS/SACK option when we received the TS/SACK in the - // initial SYN. + // start() is also called in a listen context so we want to make sure we only + // send the TS/SACK option when we received the TS/SACK in the initial SYN. if h.state == handshakeSynRcvd { synOpts.TS = h.ep.sendTSOk synOpts.SACKPermitted = h.ep.sackPermitted && bool(sackEnabled) @@ -557,7 +550,8 @@ func (h *handshake) execute() *tcpip.Error { } } - h.ep.sendSynTCP(&h.ep.route, tcpFields{ + h.sendSYNOpts = synOpts + h.ep.sendSynTCP(h.ep.route, tcpFields{ id: h.ep.ID, ttl: h.ep.ttl, tos: h.ep.sendTOS, @@ -566,7 +560,25 @@ func (h *handshake) execute() *tcpip.Error { ack: h.ackNum, rcvWnd: h.rcvWnd, }, synOpts) + return nil +} + +// complete completes the TCP 3-way handshake initiated by h.start(). +func (h *handshake) complete() *tcpip.Error { + // Set up the wakers. + var s sleep.Sleeper + resendWaker := sleep.Waker{} + s.AddWaker(&resendWaker, wakerForResend) + s.AddWaker(&h.ep.notificationWaker, wakerForNotification) + s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment) + defer s.Done() + // Initialize the resend timer. + timer, err := newBackoffTimer(time.Second, MaxRTO, resendWaker.Assert) + if err != nil { + return err + } + defer timer.stop() for h.state != handshakeCompleted { // Unlock before blocking, and reacquire again afterwards (h.ep.mu is held // throughout handshake processing). @@ -576,11 +588,9 @@ func (h *handshake) execute() *tcpip.Error { switch index { case wakerForResend: - timeOut *= 2 - if timeOut > MaxRTO { - return tcpip.ErrTimeout + if err := timer.reset(); err != nil { + return err } - rt.Reset(timeOut) // Resend the SYN/SYN-ACK only if the following conditions hold. // - It's an active handshake (deferAccept does not apply) // - It's a passive handshake and we have not yet got the final-ACK. @@ -590,7 +600,7 @@ func (h *handshake) execute() *tcpip.Error { // the connection with another ACK or data (as ACKs are never // retransmitted on their own). if h.active || !h.acked || h.deferAccept != 0 && time.Since(h.startTime) > h.deferAccept { - h.ep.sendSynTCP(&h.ep.route, tcpFields{ + h.ep.sendSynTCP(h.ep.route, tcpFields{ id: h.ep.ID, ttl: h.ep.ttl, tos: h.ep.sendTOS, @@ -598,7 +608,7 @@ func (h *handshake) execute() *tcpip.Error { seq: h.iss, ack: h.ackNum, rcvWnd: h.rcvWnd, - }, synOpts) + }, h.sendSYNOpts) } case wakerForNotification: @@ -624,9 +634,8 @@ func (h *handshake) execute() *tcpip.Error { h.ep.mu.Lock() } if n¬ifyError != 0 { - return h.ep.LastError() + return h.ep.lastErrorLocked() } - case wakerForNewSegment: if err := h.processSegments(); err != nil { return err @@ -637,6 +646,34 @@ func (h *handshake) execute() *tcpip.Error { return nil } +type backoffTimer struct { + timeout time.Duration + maxTimeout time.Duration + t *time.Timer +} + +func newBackoffTimer(timeout, maxTimeout time.Duration, f func()) (*backoffTimer, *tcpip.Error) { + if timeout > maxTimeout { + return nil, tcpip.ErrTimeout + } + bt := &backoffTimer{timeout: timeout, maxTimeout: maxTimeout} + bt.t = time.AfterFunc(timeout, f) + return bt, nil +} + +func (bt *backoffTimer) reset() *tcpip.Error { + bt.timeout *= 2 + if bt.timeout > MaxRTO { + return tcpip.ErrTimeout + } + bt.t.Reset(bt.timeout) + return nil +} + +func (bt *backoffTimer) stop() { + bt.t.Stop() +} + func parseSynSegmentOptions(s *segment) header.TCPSynOptions { synOpts := header.ParseSynOptions(s.options, s.flagIsSet(header.TCPFlagAck)) if synOpts.TS { @@ -785,8 +822,8 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso data = data.Clone(nil) optLen := len(tf.opts) - if tf.rcvWnd > 0xffff { - tf.rcvWnd = 0xffff + if tf.rcvWnd > math.MaxUint16 { + tf.rcvWnd = math.MaxUint16 } mss := int(gso.MSS) @@ -830,8 +867,8 @@ func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso // network endpoint and under the provided identity. func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) *tcpip.Error { optLen := len(tf.opts) - if tf.rcvWnd > 0xffff { - tf.rcvWnd = 0xffff + if tf.rcvWnd > math.MaxUint16 { + tf.rcvWnd = math.MaxUint16 } if r.Loop&stack.PacketLoop == 0 && gso != nil && gso.Type == stack.GSOSW && int(gso.MSS) < data.Size() { @@ -906,7 +943,7 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqn sackBlocks = e.sack.Blocks[:e.sack.NumBlocks] } options := e.makeOptions(sackBlocks) - err := e.sendTCP(&e.route, tcpFields{ + err := e.sendTCP(e.route, tcpFields{ id: e.ID, ttl: e.ttl, tos: e.sendTOS, @@ -967,7 +1004,7 @@ func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { // Only send a reset if the connection is being aborted for a reason // other than receiving a reset. e.setEndpointState(StateError) - e.HardError = err + e.hardError = err if err != tcpip.ErrConnectionReset && err != tcpip.ErrTimeout { // The exact sequence number to be used for the RST is the same as the // one used by Linux. We need to handle the case of window being shrunk @@ -1045,7 +1082,7 @@ func (e *endpoint) transitionToStateCloseLocked() { // to any other listening endpoint. We reply with RST if we cannot find one. func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) { ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, s.nicID) - if ep == nil && e.NetProto == header.IPv6ProtocolNumber && e.EndpointInfo.TransportEndpointInfo.ID.LocalAddress.To4() != "" { + if ep == nil && e.NetProto == header.IPv6ProtocolNumber && e.TransportEndpointInfo.ID.LocalAddress.To4() != "" { // Dual-stack socket, try IPv4. ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, s.nicID) } @@ -1106,7 +1143,7 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { // delete the TCB, and return. case StateCloseWait: e.transitionToStateCloseLocked() - e.HardError = tcpip.ErrAborted + e.hardError = tcpip.ErrAborted e.notifyProtocolGoroutine(notifyTickleWorker) return false, nil default: @@ -1251,7 +1288,7 @@ func (e *endpoint) keepaliveTimerExpired() *tcpip.Error { userTimeout := e.userTimeout e.keepalive.Lock() - if !e.keepalive.enabled || !e.keepalive.timer.checkExpiration() { + if !e.SocketOptions().GetKeepAlive() || !e.keepalive.timer.checkExpiration() { e.keepalive.Unlock() return nil } @@ -1288,7 +1325,7 @@ func (e *endpoint) resetKeepaliveTimer(receivedData bool) { } // Start the keepalive timer IFF it's enabled and there is no pending // data to send. - if !e.keepalive.enabled || e.snd == nil || e.snd.sndUna != e.snd.sndNxt { + if !e.SocketOptions().GetKeepAlive() || e.snd == nil || e.snd.sndUna != e.snd.sndNxt { e.keepalive.timer.disable() e.keepalive.Unlock() return @@ -1318,9 +1355,9 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ epilogue := func() { // e.mu is expected to be hold upon entering this section. - if e.snd != nil { e.snd.resendTimer.cleanup() + e.snd.rc.probeTimer.cleanup() } if closeTimer != nil { @@ -1342,20 +1379,13 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } if handshake { - // This is an active connection, so we must initiate the 3-way - // handshake, and then inform potential waiters about its - // completion. - initialRcvWnd := e.initialReceiveWindow() - h := newHandshake(e, seqnum.Size(initialRcvWnd)) - h.ep.setEndpointState(StateSynSent) - - if err := h.execute(); err != nil { + if err := e.h.complete(); err != nil { e.lastErrorMu.Lock() e.lastError = err e.lastErrorMu.Unlock() e.setEndpointState(StateError) - e.HardError = err + e.hardError = err e.workerCleanup = true // Lock released below. @@ -1364,9 +1394,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } } - e.keepalive.timer.init(&e.keepalive.waker) - defer e.keepalive.timer.cleanup() - drained := e.drainDone != nil if drained { close(e.drainDone) @@ -1411,6 +1438,10 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ }, }, { + w: &e.snd.rc.probeWaker, + f: e.snd.probeTimerExpired, + }, + { w: &e.newSegmentWaker, f: func() *tcpip.Error { return e.handleSegments(false /* fastPath */) @@ -1489,7 +1520,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } // Initialize the sleeper based on the wakers in funcs. - s := sleep.Sleeper{} + var s sleep.Sleeper for i := range funcs { s.AddWaker(funcs[i].w, i) } @@ -1613,7 +1644,7 @@ func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func() } extTW, newSyn := e.rcv.handleTimeWaitSegment(s) if newSyn { - info := e.EndpointInfo.TransportEndpointInfo + info := e.TransportEndpointInfo newID := info.ID newID.RemoteAddress = "" newID.RemotePort = 0 @@ -1676,7 +1707,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) { const notification = 2 const timeWaitDone = 3 - s := sleep.Sleeper{} + var s sleep.Sleeper defer s.Done() s.AddWaker(&e.newSegmentWaker, newSegment) s.AddWaker(&e.notificationWaker, notification) diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index a6f25896b..1d1b01a6c 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -405,14 +405,6 @@ func testV4Accept(t *testing.T, c *context.Context) { } } - // Make sure we get the same error when calling the original ep and the - // new one. This validates that v4-mapped endpoints are still able to - // query the V6Only flag, whereas pure v4 endpoints are not. - _, expected := c.EP.GetSockOptBool(tcpip.V6OnlyOption) - if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != expected { - t.Fatalf("GetSockOpt returned unexpected value: got %v, want %v", err, expected) - } - // Check the peer address. addr, err := nep.GetRemoteAddress() if err != nil { @@ -530,12 +522,12 @@ func TestV6AcceptOnV6(t *testing.T) { c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) var addr tcpip.FullAddress - nep, _, err := c.EP.Accept(&addr) + _, _, err := c.EP.Accept(&addr) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - nep, _, err = c.EP.Accept(&addr) + _, _, err = c.EP.Accept(&addr) if err != nil { t.Fatalf("Accept failed: %v", err) } @@ -548,12 +540,6 @@ func TestV6AcceptOnV6(t *testing.T) { if addr.Addr != context.TestV6Addr { t.Errorf("Unexpected remote address: got %s, want %s", addr.Addr, context.TestV6Addr) } - - // Make sure we can still query the v6 only status of the new endpoint, - // that is, that it is in fact a v6 socket. - if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != nil { - t.Errorf("GetSockOptBool(tcpip.V6OnlyOption) failed: %s", err) - } } func TestV4AcceptOnV4(t *testing.T) { diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 23b9de8c5..25b180fa5 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -17,6 +17,7 @@ package tcp import ( "encoding/binary" "fmt" + "io" "math" "runtime" "strings" @@ -27,7 +28,6 @@ import ( "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/ports" @@ -310,16 +310,12 @@ type Stats struct { func (*Stats) IsEndpointStats() {} // EndpointInfo holds useful information about a transport endpoint which -// can be queried by monitoring tools. +// can be queried by monitoring tools. This exists to allow tcp-only state to +// be exposed. // // +stateify savable type EndpointInfo struct { stack.TransportEndpointInfo - - // HardError is meaningful only when state is stateError. It stores the - // error to be returned when read/write syscalls are called and the - // endpoint is in this state. HardError is protected by endpoint mu. - HardError *tcpip.Error `state:".(string)"` } // IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo @@ -367,6 +363,7 @@ func (*EndpointInfo) IsEndpointInfo() {} // +stateify savable type endpoint struct { EndpointInfo + tcpip.DefaultSocketOptionsHandler // endpointEntry is used to queue endpoints for processing to the // a given tcp processor goroutine. @@ -386,20 +383,38 @@ type endpoint struct { waiterQueue *waiter.Queue `state:"wait"` uniqueID uint64 + // hardError is meaningful only when state is stateError. It stores the + // error to be returned when read/write syscalls are called and the + // endpoint is in this state. hardError is protected by endpoint mu. + hardError *tcpip.Error `state:".(string)"` + // lastError represents the last error that the endpoint reported; // access to it is protected by the following mutex. lastErrorMu sync.Mutex `state:"nosave"` lastError *tcpip.Error `state:".(string)"` - // The following fields are used to manage the receive queue. The - // protocol goroutine adds ready-for-delivery segments to rcvList, - // which are returned by Read() calls to users. + // rcvReadMu synchronizes calls to Read. // - // Once the peer has closed its send side, rcvClosed is set to true - // to indicate to users that no more data is coming. + // mu and rcvListMu are temporarily released during data copying. rcvReadMu + // must be held during each read to ensure atomicity, so that multiple reads + // do not interleave. + // + // rcvReadMu should be held before holding mu. + rcvReadMu sync.Mutex `state:"nosave"` + + // rcvListMu synchronizes access to rcvList. // // rcvListMu can be taken after the endpoint mu below. - rcvListMu sync.Mutex `state:"nosave"` + rcvListMu sync.Mutex `state:"nosave"` + + // rcvList is the queue for ready-for-delivery segments. + // + // rcvReadMu, mu and rcvListMu must be held, in the stated order, to read data + // and removing segments from list. A range of segment can be determined, then + // temporarily release mu and rcvListMu while processing the segment range. + // This allows new segments to be appended to the list while processing. + // + // rcvListMu must be held to append segments to list. rcvList segmentList `state:"wait"` rcvClosed bool // rcvBufSize is the total size of the receive buffer. @@ -421,7 +436,10 @@ type endpoint struct { // mu protects all endpoint fields unless documented otherwise. mu must // be acquired before interacting with the endpoint fields. - mu sync.Mutex `state:"nosave"` + // + // During handshake, mu is locked by the protocol listen goroutine and + // released by the handshake completion goroutine. + mu sync.CrossGoroutineMutex `state:"nosave"` ownedByUser uint32 // state must be read/set using the EndpointState()/setEndpointState() @@ -436,13 +454,14 @@ type endpoint struct { isPortReserved bool `state:"manual"` isRegistered bool `state:"manual"` boundNICID tcpip.NICID - route stack.Route `state:"manual"` + route *stack.Route `state:"manual"` ttl uint8 - v6only bool isConnectNotified bool - // TCP should never broadcast but Linux nevertheless supports enabling/ - // disabling SO_BROADCAST, albeit as a NOOP. - broadcast bool + + // h stores a reference to the current handshake state if the endpoint is in + // the SYN-SENT or SYN-RECV states, in which case endpoint == endpoint.h.ep. + // nil otherwise. + h *handshake `state:"nosave"` // portFlags stores the current values of port related flags. portFlags ports.Flags @@ -489,6 +508,9 @@ type endpoint struct { // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags + // tcpRecovery is the loss deteoction algorithm used by TCP. + tcpRecovery tcpip.TCPRecovery + // sackPermitted is set to true if the peer sends the TCPSACKPermitted // option in the SYN/SYN-ACK. sackPermitted bool @@ -496,32 +518,14 @@ type endpoint struct { // sack holds TCP SACK related information for this endpoint. sack SACKInfo - // bindToDevice is set to the NIC on which to bind or disabled if 0. - bindToDevice tcpip.NICID - // delay enables Nagle's algorithm. // // delay is a boolean (0 is false) and must be accessed atomically. delay uint32 - // cork holds back segments until full. - // - // cork is a boolean (0 is false) and must be accessed atomically. - cork uint32 - // scoreboard holds TCP SACK Scoreboard information for this endpoint. scoreboard *SACKScoreboard - // The options below aren't implemented, but we remember the user - // settings because applications expect to be able to set/query these - // options. - - // slowAck holds the negated state of quick ack. It is stubbed out and - // does nothing. - // - // slowAck is a boolean (0 is false) and must be accessed atomically. - slowAck uint32 - // segmentQueue is used to hand received segments to the protocol // goroutine. Segments are queued as long as the queue is not full, // and dropped when it is. @@ -683,8 +687,8 @@ type endpoint struct { // owner is used to get uid and gid of the packet. owner tcpip.PacketOwner - // linger is used for SO_LINGER socket option. - linger tcpip.LingerOption + // ops is used to get socket level options. + ops tcpip.SocketOptions } // UniqueID implements stack.TransportEndpoint.UniqueID. @@ -696,7 +700,7 @@ func (e *endpoint) UniqueID() uint64 { // // If userMSS is non-zero and is not greater than the maximum possible MSS for // r, it will be used; otherwise, the maximum possible MSS will be used. -func calculateAdvertisedMSS(userMSS uint16, r stack.Route) uint16 { +func calculateAdvertisedMSS(userMSS uint16, r *stack.Route) uint16 { // The maximum possible MSS is dependent on the route. // TODO(b/143359391): Respect TCP Min and Max size. maxMSS := uint16(r.MTU() - header.TCPMinimumSize) @@ -845,7 +849,6 @@ func (e *endpoint) recentTimestamp() uint32 { // +stateify savable type keepalive struct { sync.Mutex `state:"nosave"` - enabled bool idle time.Duration interval time.Duration count int @@ -879,6 +882,9 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue windowClamp: DefaultReceiveBufferSize, maxSynRetries: DefaultSynRetries, } + e.ops.InitHandler(e) + e.ops.SetMulticastLoop(true) + e.ops.SetQuickAck(true) var ss tcpip.TCPSendBufferSizeRangeOption if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil { @@ -902,7 +908,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue var de tcpip.TCPDelayEnabled if err := s.TransportProtocolOption(ProtocolNumber, &de); err == nil && de { - e.SetSockOptBool(tcpip.DelayOption, true) + e.ops.SetDelayOption(true) } var tcpLT tcpip.TCPLingerTimeoutOption @@ -915,6 +921,8 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue e.maxSynRetries = uint8(synRetries) } + s.TransportProtocolOption(ProtocolNumber, &e.tcpRecovery) + if p := s.GetTCPProbe(); p != nil { e.probe = p } @@ -922,6 +930,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue e.segmentQueue.ep = e e.tsOffset = timeStampOffset() e.acceptCond = sync.NewCond(&e.acceptMu) + e.keepalive.timer.init(&e.keepalive.waker) return e } @@ -1043,7 +1052,8 @@ func (e *endpoint) Close() { return } - if e.linger.Enabled && e.linger.Timeout == 0 { + linger := e.SocketOptions().GetLinger() + if linger.Enabled && linger.Timeout == 0 { s := e.EndpointState() isResetState := s == StateEstablished || s == StateCloseWait || s == StateFinWait1 || s == StateFinWait2 || s == StateSynRecv if isResetState { @@ -1146,6 +1156,7 @@ func (e *endpoint) cleanupLocked() { // Close all endpoints that might have been accepted by TCP but not by // the client. e.closePendingAcceptableConnectionsLocked() + e.keepalive.timer.cleanup() e.workerCleanup = false @@ -1162,7 +1173,11 @@ func (e *endpoint) cleanupLocked() { e.boundPortFlags = ports.Flags{} e.boundDest = tcpip.FullAddress{} - e.route.Release() + if e.route != nil { + e.route.Release() + e.route = nil + } + e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) } @@ -1272,11 +1287,20 @@ func (e *endpoint) ModerateRecvBuf(copied int) { e.rcvListMu.Unlock() } +// SetOwner implements tcpip.Endpoint.SetOwner. func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } -func (e *endpoint) LastError() *tcpip.Error { +// Preconditions: e.mu must be held to call this function. +func (e *endpoint) hardErrorLocked() *tcpip.Error { + err := e.hardError + e.hardError = nil + return err +} + +// Preconditions: e.mu must be held to call this function. +func (e *endpoint) lastErrorLocked() *tcpip.Error { e.lastErrorMu.Lock() defer e.lastErrorMu.Unlock() err := e.lastError @@ -1284,8 +1308,88 @@ func (e *endpoint) LastError() *tcpip.Error { return err } -// Read reads data from the endpoint. -func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { +// LastError implements tcpip.Endpoint.LastError. +func (e *endpoint) LastError() *tcpip.Error { + e.LockUser() + defer e.UnlockUser() + if err := e.hardErrorLocked(); err != nil { + return err + } + return e.lastErrorLocked() +} + +// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError. +func (e *endpoint) UpdateLastError(err *tcpip.Error) { + e.LockUser() + e.lastErrorMu.Lock() + e.lastError = err + e.lastErrorMu.Unlock() + e.UnlockUser() +} + +// Read implements tcpip.Endpoint.Read. +func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { + e.rcvReadMu.Lock() + defer e.rcvReadMu.Unlock() + + // N.B. Here we get a range of segments to be processed. It is safe to not + // hold rcvListMu when processing, since we hold rcvReadMu to ensure only we + // can remove segments from the list through commitRead(). + first, last, serr := e.startRead() + if serr != nil { + if serr == tcpip.ErrClosedForReceive { + e.stats.ReadErrors.ReadClosed.Increment() + } + return tcpip.ReadResult{}, serr + } + + var err error + done := 0 + s := first + for s != nil && done < count { + var n int + n, err = s.data.ReadTo(dst, count-done, opts.Peek) + // Book keeping first then error handling. + + done += n + + if opts.Peek { + // For peek, we use the (first, last) range of segment returned from + // startRead. We don't consume the receive buffer, so commitRead should + // not be called. + // + // N.B. It is important to use `last` to determine the last segment, since + // appending can happen while we process, and will lead to data race. + if s == last { + break + } + s = s.Next() + } else { + // N.B. commitRead() conveniently returns the next segment to read, after + // removing the data/segment that is read. + s = e.commitRead(n) + } + + if err != nil { + break + } + } + + // If something is read, we must report it. Report error when nothing is read. + if done == 0 && err != nil { + return tcpip.ReadResult{}, tcpip.ErrBadBuffer + } + return tcpip.ReadResult{ + Count: done, + Total: done, + }, nil +} + +// startRead checks that endpoint is in a readable state, and return the +// inclusive range of segments that can be read. +// +// Precondition: e.rcvReadMu must be held. +func (e *endpoint) startRead() (first, last *segment, err *tcpip.Error) { e.LockUser() defer e.UnlockUser() @@ -1294,7 +1398,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, // on a receive. It can expect to read any data after the handshake // is complete. RFC793, section 3.9, p58. if e.EndpointState() == StateSynSent { - return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrWouldBlock + return nil, nil, tcpip.ErrWouldBlock } // The endpoint can be read if it's connected, or if it's already closed @@ -1302,59 +1406,69 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, // would cause the state to become StateError so we should allow the // reads to proceed before returning a ECONNRESET. e.rcvListMu.Lock() + defer e.rcvListMu.Unlock() + bufUsed := e.rcvBufUsed if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 { - e.rcvListMu.Unlock() - he := e.HardError if s == StateError { - return buffer.View{}, tcpip.ControlMessages{}, he + if err := e.hardErrorLocked(); err != nil { + return nil, nil, err + } + return nil, nil, tcpip.ErrClosedForReceive } e.stats.ReadErrors.NotConnected.Increment() - return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrNotConnected - } - - v, err := e.readLocked() - e.rcvListMu.Unlock() - - if err == tcpip.ErrClosedForReceive { - e.stats.ReadErrors.ReadClosed.Increment() + return nil, nil, tcpip.ErrNotConnected } - return v, tcpip.ControlMessages{}, err -} -func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { if e.rcvBufUsed == 0 { if e.rcvClosed || !e.EndpointState().connected() { - return buffer.View{}, tcpip.ErrClosedForReceive + return nil, nil, tcpip.ErrClosedForReceive } - return buffer.View{}, tcpip.ErrWouldBlock + return nil, nil, tcpip.ErrWouldBlock } - s := e.rcvList.Front() - views := s.data.Views() - v := views[s.viewToDeliver] - s.viewToDeliver++ + return e.rcvList.Front(), e.rcvList.Back(), nil +} - var delta int - if s.viewToDeliver >= len(views) { +// commitRead commits a read of done bytes and returns the next non-empty +// segment to read. Data read from the segment must have also been removed from +// the segment in order for this method to work correctly. +// +// It is performance critical to call commitRead frequently when servicing a big +// Read request, so TCP can make progress timely. Right now, it is designed to +// do this per segment read, hence this method conveniently returns the next +// segment to read while holding the lock. +// +// Precondition: e.rcvReadMu must be held. +func (e *endpoint) commitRead(done int) *segment { + e.LockUser() + defer e.UnlockUser() + e.rcvListMu.Lock() + defer e.rcvListMu.Unlock() + + memDelta := 0 + s := e.rcvList.Front() + for s != nil && s.data.Size() == 0 { e.rcvList.Remove(s) - // We only free up receive buffer space when the segment is released as the - // segment is still holding on to the views even though some views have been - // read out to the user. - delta = s.segMemSize() + // Memory is only considered released when the whole segment has been + // read. + memDelta += s.segMemSize() s.decRef() + s = e.rcvList.Front() } + e.rcvBufUsed -= done - e.rcvBufUsed -= len(v) - // If the window was small before this read and if the read freed up - // enough buffer space, to either fit an aMSS or half a receive buffer - // (whichever smaller), then notify the protocol goroutine to send a - // window update. - if crossed, above := e.windowCrossedACKThresholdLocked(delta); crossed && above { - e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) + if memDelta > 0 { + // If the window was small before this read and if the read freed up + // enough buffer space, to either fit an aMSS or half a receive buffer + // (whichever smaller), then notify the protocol goroutine to send a + // window update. + if crossed, above := e.windowCrossedACKThresholdLocked(memDelta); crossed && above { + e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) + } } - return v, nil + return e.rcvList.Front() } // isEndpointWritableLocked checks if a given endpoint is writable @@ -1363,9 +1477,13 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { // indicating the reason why it's not writable. // Caller must hold e.mu and e.sndBufMu func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) { + // The endpoint cannot be written to if it's not connected. switch s := e.EndpointState(); { case s == StateError: - return 0, e.HardError + if err := e.hardErrorLocked(); err != nil { + return 0, err + } + return 0, tcpip.ErrClosedForSend case !s.connecting() && !s.connected(): return 0, tcpip.ErrClosedForSend case s.connecting(): @@ -1426,104 +1544,38 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return 0, nil, perr } - queueAndSend := func() (int64, <-chan struct{}, *tcpip.Error) { - // Add data to the send queue. - s := newOutgoingSegment(e.ID, v) - e.sndBufUsed += len(v) - e.sndBufInQueue += seqnum.Size(len(v)) - e.sndQueue.PushBack(s) - e.sndBufMu.Unlock() - - // Do the work inline. - e.handleWrite() - e.UnlockUser() - return int64(len(v)), nil, nil - } - - if opts.Atomic { - // Locks released in queueAndSend() - return queueAndSend() - } - - // Since we released locks in between it's possible that the - // endpoint transitioned to a CLOSED/ERROR states so make - // sure endpoint is still writable before trying to write. - e.LockUser() - e.sndBufMu.Lock() - avail, err = e.isEndpointWritableLocked() - if err != nil { - e.sndBufMu.Unlock() - e.UnlockUser() - e.stats.WriteErrors.WriteClosed.Increment() - return 0, nil, err - } - - // Discard any excess data copied in due to avail being reduced due - // to a simultaneous write call to the socket. - if avail < len(v) { - v = v[:avail] - } - - // Locks released in queueAndSend() - return queueAndSend() -} - -// Peek reads data without consuming it from the endpoint. -// -// This method does not block if there is no data pending. -func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { - e.LockUser() - defer e.UnlockUser() - - // The endpoint can be read if it's connected, or if it's already closed - // but has some pending unread data. - if s := e.EndpointState(); !s.connected() && s != StateClose { - if s == StateError { - return 0, tcpip.ControlMessages{}, e.HardError + if !opts.Atomic { + // Since we released locks in between it's possible that the + // endpoint transitioned to a CLOSED/ERROR states so make + // sure endpoint is still writable before trying to write. + e.LockUser() + e.sndBufMu.Lock() + avail, err := e.isEndpointWritableLocked() + if err != nil { + e.sndBufMu.Unlock() + e.UnlockUser() + e.stats.WriteErrors.WriteClosed.Increment() + return 0, nil, err } - e.stats.ReadErrors.InvalidEndpointState.Increment() - return 0, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState - } - e.rcvListMu.Lock() - defer e.rcvListMu.Unlock() - - if e.rcvBufUsed == 0 { - if e.rcvClosed || !e.EndpointState().connected() { - e.stats.ReadErrors.ReadClosed.Increment() - return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive + // Discard any excess data copied in due to avail being reduced due + // to a simultaneous write call to the socket. + if avail < len(v) { + v = v[:avail] } - return 0, tcpip.ControlMessages{}, tcpip.ErrWouldBlock } - // Make a copy of vec so we can modify the slide headers. - vec = append([][]byte(nil), vec...) - - var num int64 - for s := e.rcvList.Front(); s != nil; s = s.Next() { - views := s.data.Views() - - for i := s.viewToDeliver; i < len(views); i++ { - v := views[i] - - for len(v) > 0 { - if len(vec) == 0 { - return num, tcpip.ControlMessages{}, nil - } - if len(vec[0]) == 0 { - vec = vec[1:] - continue - } - - n := copy(vec[0], v) - v = v[n:] - vec[0] = vec[0][n:] - num += int64(n) - } - } - } + // Add data to the send queue. + s := newOutgoingSegment(e.ID, v) + e.sndBufUsed += len(v) + e.sndBufInQueue += seqnum.Size(len(v)) + e.sndQueue.PushBack(s) + e.sndBufMu.Unlock() - return num, tcpip.ControlMessages{}, nil + // Do the work inline. + e.handleWrite() + e.UnlockUser() + return int64(len(v)), nil, nil } // selectWindowLocked returns the new window without checking for shrinking or scaling @@ -1595,77 +1647,39 @@ func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed boo return false, false } -// SetSockOptBool sets a socket option. -func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { - switch opt { - - case tcpip.BroadcastOption: - e.LockUser() - e.broadcast = v - e.UnlockUser() - - case tcpip.CorkOption: - e.LockUser() - if !v { - atomic.StoreUint32(&e.cork, 0) - - // Handle the corked data. - e.sndWaker.Assert() - } else { - atomic.StoreUint32(&e.cork, 1) - } - e.UnlockUser() - - case tcpip.DelayOption: - if v { - atomic.StoreUint32(&e.delay, 1) - } else { - atomic.StoreUint32(&e.delay, 0) - - // Handle delayed data. - e.sndWaker.Assert() - } - - case tcpip.KeepaliveEnabledOption: - e.keepalive.Lock() - e.keepalive.enabled = v - e.keepalive.Unlock() - e.notifyProtocolGoroutine(notifyKeepaliveChanged) - - case tcpip.QuickAckOption: - o := uint32(1) - if v { - o = 0 - } - atomic.StoreUint32(&e.slowAck, o) - - case tcpip.ReuseAddressOption: - e.LockUser() - e.portFlags.TupleOnly = v - e.UnlockUser() - - case tcpip.ReusePortOption: - e.LockUser() - e.portFlags.LoadBalanced = v - e.UnlockUser() +// OnReuseAddressSet implements tcpip.SocketOptionsHandler.OnReuseAddressSet. +func (e *endpoint) OnReuseAddressSet(v bool) { + e.LockUser() + e.portFlags.TupleOnly = v + e.UnlockUser() +} - case tcpip.V6OnlyOption: - // We only recognize this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return tcpip.ErrInvalidEndpointState - } +// OnReusePortSet implements tcpip.SocketOptionsHandler.OnReusePortSet. +func (e *endpoint) OnReusePortSet(v bool) { + e.LockUser() + e.portFlags.LoadBalanced = v + e.UnlockUser() +} - // We only allow this to be set when we're in the initial state. - if e.EndpointState() != StateInitial { - return tcpip.ErrInvalidEndpointState - } +// OnKeepAliveSet implements tcpip.SocketOptionsHandler.OnKeepAliveSet. +func (e *endpoint) OnKeepAliveSet(v bool) { + e.notifyProtocolGoroutine(notifyKeepaliveChanged) +} - e.LockUser() - e.v6only = v - e.UnlockUser() +// OnDelayOptionSet implements tcpip.SocketOptionsHandler.OnDelayOptionSet. +func (e *endpoint) OnDelayOptionSet(v bool) { + if !v { + // Handle delayed data. + e.sndWaker.Assert() } +} - return nil +// OnCorkOptionSet implements tcpip.SocketOptionsHandler.OnCorkOptionSet. +func (e *endpoint) OnCorkOptionSet(v bool) { + if !v { + // Handle the corked data. + e.sndWaker.Assert() + } } // SetSockOptInt sets a socket option. @@ -1825,18 +1839,13 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil } +func (e *endpoint) HasNIC(id int32) bool { + return id == 0 || e.stack.HasNIC(tcpip.NICID(id)) +} + // SetSockOpt sets a socket option. func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { switch v := opt.(type) { - case *tcpip.BindToDeviceOption: - id := tcpip.NICID(*v) - if id != 0 && !e.stack.HasNIC(id) { - return tcpip.ErrUnknownDevice - } - e.LockUser() - e.bindToDevice = id - e.UnlockUser() - case *tcpip.KeepaliveIdleOption: e.keepalive.Lock() e.keepalive.idle = time.Duration(*v) @@ -1849,9 +1858,6 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { e.keepalive.Unlock() e.notifyProtocolGoroutine(notifyKeepaliveChanged) - case *tcpip.OutOfBandInlineOption: - // We don't currently support disabling this option. - case *tcpip.TCPUserTimeoutOption: e.LockUser() e.userTimeout = time.Duration(*v) @@ -1920,11 +1926,6 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { case *tcpip.SocketDetachFilterOption: return nil - case *tcpip.LingerOption: - e.LockUser() - e.linger = *v - e.UnlockUser() - default: return nil } @@ -1947,72 +1948,6 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) { return e.rcvBufUsed, nil } -// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. -func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - switch opt { - case tcpip.BroadcastOption: - e.LockUser() - v := e.broadcast - e.UnlockUser() - return v, nil - - case tcpip.CorkOption: - return atomic.LoadUint32(&e.cork) != 0, nil - - case tcpip.DelayOption: - return atomic.LoadUint32(&e.delay) != 0, nil - - case tcpip.KeepaliveEnabledOption: - e.keepalive.Lock() - v := e.keepalive.enabled - e.keepalive.Unlock() - - return v, nil - - case tcpip.QuickAckOption: - v := atomic.LoadUint32(&e.slowAck) == 0 - return v, nil - - case tcpip.ReuseAddressOption: - e.LockUser() - v := e.portFlags.TupleOnly - e.UnlockUser() - - return v, nil - - case tcpip.ReusePortOption: - e.LockUser() - v := e.portFlags.LoadBalanced - e.UnlockUser() - - return v, nil - - case tcpip.V6OnlyOption: - // We only recognize this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return false, tcpip.ErrUnknownProtocolOption - } - - e.LockUser() - v := e.v6only - e.UnlockUser() - - return v, nil - - case tcpip.MulticastLoopOption: - return true, nil - - case tcpip.AcceptConnOption: - e.LockUser() - defer e.UnlockUser() - - return e.EndpointState() == StateListen, nil - - default: - return false, tcpip.ErrUnknownProtocolOption - } -} - // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { @@ -2091,11 +2026,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { switch o := opt.(type) { - case *tcpip.BindToDeviceOption: - e.LockUser() - *o = tcpip.BindToDeviceOption(e.bindToDevice) - e.UnlockUser() - case *tcpip.TCPInfoOption: *o = tcpip.TCPInfoOption{} e.LockUser() @@ -2123,10 +2053,6 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { *o = tcpip.TCPUserTimeoutOption(e.userTimeout) e.UnlockUser() - case *tcpip.OutOfBandInlineOption: - // We don't currently support disabling this option. - *o = 1 - case *tcpip.CongestionControlOption: e.LockUser() *o = e.cc @@ -2155,11 +2081,6 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { Port: port, } - case *tcpip.LingerOption: - e.LockUser() - *o = e.linger - e.UnlockUser() - default: return tcpip.ErrUnknownProtocolOption } @@ -2169,7 +2090,7 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { // checkV4MappedLocked determines the effective network protocol and converts // addr to its canonical form. func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { - unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only) + unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.ops.GetV6Only()) if err != nil { return tcpip.FullAddress{}, 0, err } @@ -2185,6 +2106,8 @@ func (*endpoint) Disconnect() *tcpip.Error { func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { err := e.connect(addr, true, true) if err != nil && !err.IgnoreStats() { + // Connect failed. Let's wake up any waiters. + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() } @@ -2244,7 +2167,10 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc return tcpip.ErrAlreadyConnecting case StateError: - return e.HardError + if err := e.hardErrorLocked(); err != nil { + return err + } + return tcpip.ErrConnectionAborted default: return tcpip.ErrInvalidEndpointState @@ -2302,11 +2228,12 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } } + bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, *tcpip.Error) { if sameAddr && p == e.ID.RemotePort { return false, nil } - if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr, nil /* testPort */); err != nil { + if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr, nil /* testPort */); err != nil { if err != tcpip.ErrPortInUse || !reuse { return false, nil } @@ -2344,15 +2271,15 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc tcpEP.notifyProtocolGoroutine(notifyAbort) tcpEP.UnlockUser() // Now try and Reserve again if it fails then we skip. - if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr, nil /* testPort */); err != nil { + if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr, nil /* testPort */); err != nil { return false, nil } } id := e.ID id.LocalPort = p - if err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.portFlags, e.bindToDevice); err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr) + if err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.portFlags, bindToDevice); err != nil { + e.stack.ReleasePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr) if err == tcpip.ErrPortInUse { return false, nil } @@ -2363,7 +2290,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc // the selected port. e.ID = id e.isPortReserved = true - e.boundBindToDevice = e.bindToDevice + e.boundBindToDevice = bindToDevice e.boundPortFlags = e.portFlags e.boundDest = addr return true, nil @@ -2374,7 +2301,8 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc e.isRegistered = true e.setEndpointState(StateConnecting) - e.route = r.Clone() + r.Acquire() + e.route = r e.boundNICID = nicID e.effectiveNetProtos = netProtos e.connectingAddress = connectingAddr @@ -2397,14 +2325,70 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } if run { - e.workerRunning = true - e.stack.Stats().TCP.ActiveConnectionOpenings.Increment() - go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save. + if err := e.startMainLoop(handshake); err != nil { + return err + } } return tcpip.ErrConnectStarted } +// startMainLoop sends the initial SYN and starts the main loop for the +// endpoint. +func (e *endpoint) startMainLoop(handshake bool) *tcpip.Error { + preloop := func() *tcpip.Error { + if handshake { + h := e.newHandshake() + e.setEndpointState(StateSynSent) + if err := h.start(); err != nil { + e.lastErrorMu.Lock() + e.lastError = err + e.lastErrorMu.Unlock() + + e.setEndpointState(StateError) + e.hardError = err + + // Call cleanupLocked to free up any reservations. + e.cleanupLocked() + return err + } + } + e.stack.Stats().TCP.ActiveConnectionOpenings.Increment() + return nil + } + + if e.route.IsResolutionRequired() { + // If the endpoint is closed between releasing e.mu and the goroutine below + // acquiring it, make sure that cleanup is deferred to the new goroutine. + e.workerRunning = true + + // Sending the initial SYN may block due to route resolution; do it in a + // separate goroutine to avoid blocking the syscall goroutine. + go func() { // S/R-SAFE: will be drained before save. + e.mu.Lock() + if err := preloop(); err != nil { + e.workerRunning = false + e.mu.Unlock() + return + } + e.mu.Unlock() + _ = e.protocolMainLoop(handshake, nil) + }() + return nil + } + + // No route resolution is required, so we can send the initial SYN here without + // blocking. This will hopefully reduce overall latency by overlapping time + // spent waiting for a SYN-ACK and time spent spinning up a new goroutine + // for the main loop. + if err := preloop(); err != nil { + return err + } + e.workerRunning = true + go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save. + return nil +} + // ConnectEndpoint is not supported. func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error { return tcpip.ErrInvalidEndpointState @@ -2642,7 +2626,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { // v6only set to false. if netProto == header.IPv6ProtocolNumber { stackHasV4 := e.stack.CheckNetworkProtocol(header.IPv4ProtocolNumber) - alsoBindToV4 := !e.v6only && addr.Addr == "" && stackHasV4 + alsoBindToV4 := !e.ops.GetV6Only() && addr.Addr == "" && stackHasV4 if alsoBindToV4 { netProtos = append(netProtos, header.IPv4ProtocolNumber) } @@ -2659,7 +2643,8 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { e.ID.LocalAddress = addr.Addr } - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, e.bindToDevice, tcpip.FullAddress{}, func(p uint16) bool { + bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, bindToDevice, tcpip.FullAddress{}, func(p uint16) bool { id := e.ID id.LocalPort = p // CheckRegisterTransportEndpoint should only return an error if there is a @@ -2670,7 +2655,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { // demuxer. Further connected endpoints always have a remote // address/port. Hence this will only return an error if there is a matching // listening endpoint. - if err := e.stack.CheckRegisterTransportEndpoint(nic, netProtos, ProtocolNumber, id, e.portFlags, e.bindToDevice); err != nil { + if err := e.stack.CheckRegisterTransportEndpoint(nic, netProtos, ProtocolNumber, id, e.portFlags, bindToDevice); err != nil { return false } return true @@ -2679,7 +2664,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { return err } - e.boundBindToDevice = e.bindToDevice + e.boundBindToDevice = bindToDevice e.boundPortFlags = e.portFlags // TODO(gvisor.dev/issue/3691): Add test to verify boundNICID is correct. e.boundNICID = nic @@ -2727,7 +2712,7 @@ func (e *endpoint) getRemoteAddress() tcpip.FullAddress { func (*endpoint) HandlePacket(stack.TransportEndpointID, *stack.PacketBuffer) { // TCP HandlePacket is not required anymore as inbound packets first - // land at the Dispatcher which then can either delivery using the + // land at the Dispatcher which then can either deliver using the // worker go routine or directly do the invoke the tcp processing inline // based on the state of the endpoint. } @@ -2743,8 +2728,43 @@ func (e *endpoint) enqueueSegment(s *segment) bool { return true } +func (e *endpoint) onICMPError(err *tcpip.Error, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) { + // Update last error first. + e.lastErrorMu.Lock() + e.lastError = err + e.lastErrorMu.Unlock() + + // Update the error queue if IP_RECVERR is enabled. + if e.SocketOptions().GetRecvError() { + e.SocketOptions().QueueErr(&tcpip.SockError{ + Err: err, + ErrOrigin: header.ICMPOriginFromNetProto(pkt.NetworkProtocolNumber), + ErrType: errType, + ErrCode: errCode, + ErrInfo: extra, + // Linux passes the payload with the TCP header. We don't know if the TCP + // header even exists, it may not for fragmented packets. + Payload: pkt.Data.ToView(), + Dst: tcpip.FullAddress{ + NIC: pkt.NICID, + Addr: e.ID.RemoteAddress, + Port: e.ID.RemotePort, + }, + Offender: tcpip.FullAddress{ + NIC: pkt.NICID, + Addr: e.ID.LocalAddress, + Port: e.ID.LocalPort, + }, + NetProto: pkt.NetworkProtocolNumber, + }) + } + + // Notify of the error. + e.notifyProtocolGoroutine(notifyError) +} + // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { +func (e *endpoint) HandleControlPacket(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { switch typ { case stack.ControlPacketTooBig: e.sndBufMu.Lock() @@ -2757,16 +2777,10 @@ func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.C e.notifyProtocolGoroutine(notifyMTUChanged) case stack.ControlNoRoute: - e.lastErrorMu.Lock() - e.lastError = tcpip.ErrNoRoute - e.lastErrorMu.Unlock() - e.notifyProtocolGoroutine(notifyError) + e.onICMPError(tcpip.ErrNoRoute, byte(header.ICMPv4DstUnreachable), byte(header.ICMPv4HostUnreachable), extra, pkt) case stack.ControlNetworkUnreachable: - e.lastErrorMu.Lock() - e.lastError = tcpip.ErrNetworkUnreachable - e.lastErrorMu.Unlock() - e.notifyProtocolGoroutine(notifyError) + e.onICMPError(tcpip.ErrNetworkUnreachable, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6NetworkUnreachable), extra, pkt) } } @@ -3024,6 +3038,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState { Ssthresh: e.snd.sndSsthresh, SndCAAckCount: e.snd.sndCAAckCount, Outstanding: e.snd.outstanding, + SackedOut: e.snd.sackedOut, SndWnd: e.snd.sndWnd, SndUna: e.snd.sndUna, SndNxt: e.snd.sndNxt, @@ -3054,13 +3069,14 @@ func (e *endpoint) completeState() stack.TCPEndpointState { } } - rc := e.snd.rc + rc := &e.snd.rc s.Sender.RACKState = stack.TCPRACKState{ XmitTime: rc.xmitTime, EndSequence: rc.endSequence, FACK: rc.fack, RTT: rc.rtt, Reord: rc.reorderSeen, + DSACKSeen: rc.dsackSeen, } return s } @@ -3105,7 +3121,7 @@ func (e *endpoint) State() uint32 { func (e *endpoint) Info() tcpip.EndpointInfo { e.LockUser() // Make a copy of the endpoint info. - ret := e.EndpointInfo + ret := e.TransportEndpointInfo e.UnlockUser() return &ret } @@ -3130,3 +3146,8 @@ func (e *endpoint) Wait() { <-notifyCh } } + +// SocketOptions implements tcpip.Endpoint.SocketOptions. +func (e *endpoint) SocketOptions() *tcpip.SocketOptions { + return &e.ops +} diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 2bcc5e1c2..ba67176b5 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -172,6 +172,7 @@ func (e *endpoint) afterLoad() { // Condition variables and mutexs are not S/R'ed so reinitialize // acceptCond with e.acceptMu. e.acceptCond = sync.NewCond(&e.acceptMu) + e.keepalive.timer.init(&e.keepalive.waker) stack.StackFromEnv.RegisterRestoredEndpoint(e) } @@ -320,21 +321,21 @@ func (e *endpoint) loadRecentTSTime(unix unixTime) { } // saveHardError is invoked by stateify. -func (e *EndpointInfo) saveHardError() string { - if e.HardError == nil { +func (e *endpoint) saveHardError() string { + if e.hardError == nil { return "" } - return e.HardError.String() + return e.hardError.String() } // loadHardError is invoked by stateify. -func (e *EndpointInfo) loadHardError(s string) { +func (e *endpoint) loadHardError(s string) { if s == "" { return } - e.HardError = tcpip.StringToError(s) + e.hardError = tcpip.StringToError(s) } // saveMeasureTime is invoked by stateify. diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 0664789da..596178625 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -152,7 +152,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, } f := r.forwarder - ep, err := f.listen.createEndpointAndPerformHandshake(r.segment, &header.TCPSynOptions{ + ep, err := f.listen.performHandshake(r.segment, &header.TCPSynOptions{ MSS: r.synOptions.MSS, WS: r.synOptions.WS, TS: r.synOptions.TS, diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 2329aca4b..c9e194f82 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -250,7 +250,7 @@ func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) *tcpip.Error ttl = route.DefaultTTL() } - return sendTCP(&route, tcpFields{ + return sendTCP(route, tcpFields{ id: s.id, ttl: ttl, tos: tos, @@ -405,7 +405,7 @@ func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) *tcpip.E case *tcpip.TCPRecovery: p.mu.RLock() - *v = tcpip.TCPRecovery(p.recovery) + *v = p.recovery p.mu.RUnlock() return nil @@ -543,7 +543,8 @@ func NewProtocol(s *stack.Stack) stack.TransportProtocol { minRTO: MinRTO, maxRTO: MaxRTO, maxRetries: MaxRetries, - recovery: tcpip.TCPRACKLossDetection, + // TODO(gvisor.dev/issue/5243): Set recovery to tcpip.TCPRACKLossDetection. + recovery: 0, } p.dispatcher.init(runtime.GOMAXPROCS(0)) return &p diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go index d312b1b8b..5a4ee70f5 100644 --- a/pkg/tcpip/transport/tcp/rack.go +++ b/pkg/tcpip/transport/tcp/rack.go @@ -17,9 +17,18 @@ package tcp import ( "time" + "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/seqnum" ) +// wcDelayedACKTimeout is the recommended maximum delayed ACK timer value as +// defined in https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.5. +// It stands for worst case delayed ACK timer (WCDelAckT). When FlightSize is +// 1, PTO is inflated by WCDelAckT time to compensate for a potential long +// delayed ACK timer at the receiver. +const wcDelayedACKTimeout = 200 * time.Millisecond + // RACK is a loss detection algorithm used in TCP to detect packet loss and // reordering using transmission timestamp of the packets instead of packet or // sequence counts. To use RACK, SACK should be enabled on the connection. @@ -29,12 +38,12 @@ import ( // // +stateify savable type rackControl struct { + // dsackSeen indicates if the connection has seen a DSACK. + dsackSeen bool + // endSequence is the ending TCP sequence number of rackControl.seg. endSequence seqnum.Value - // dsack indicates if the connection has seen a DSACK. - dsack bool - // fack is the highest selectively or cumulatively acknowledged // sequence. fack seqnum.Value @@ -54,6 +63,23 @@ type rackControl struct { // xmitTime is the latest transmission timestamp of rackControl.seg. xmitTime time.Time `state:".(unixTime)"` + + // probeTimer and probeWaker are used to schedule PTO for RACK TLP algorithm. + probeTimer timer `state:"nosave"` + probeWaker sleep.Waker `state:"nosave"` + + // tlpRxtOut indicates whether there is an unacknowledged + // TLP retransmission. + tlpRxtOut bool + + // tlpHighRxt the value of sender.sndNxt at the time of sending + // a TLP retransmission. + tlpHighRxt seqnum.Value +} + +// init initializes RACK specific fields. +func (rc *rackControl) init() { + rc.probeTimer.init(&rc.probeWaker) } // update will update the RACK related fields when an ACK has been received. @@ -122,3 +148,103 @@ func (rc *rackControl) detectReorder(seg *segment) { rc.reorderSeen = true } } + +// setDSACKSeen updates rack control if duplicate SACK is seen by the connection. +func (rc *rackControl) setDSACKSeen() { + rc.dsackSeen = true +} + +// shouldSchedulePTO dictates whether we should schedule a PTO or not. +// See https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.5.1. +func (s *sender) shouldSchedulePTO() bool { + // Schedule PTO only if RACK loss detection is enabled. + return s.ep.tcpRecovery&tcpip.TCPRACKLossDetection != 0 && + // The connection supports SACK. + s.ep.sackPermitted && + // The connection is not in loss recovery. + (s.state != RTORecovery && s.state != SACKRecovery) && + // The connection has no SACKed sequences in the SACK scoreboard. + s.ep.scoreboard.Sacked() == 0 +} + +// schedulePTO schedules the probe timeout as defined in +// https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.5.1. +func (s *sender) schedulePTO() { + pto := time.Second + s.rtt.Lock() + if s.rtt.srttInited && s.rtt.srtt > 0 { + pto = s.rtt.srtt * 2 + if s.outstanding == 1 { + pto += wcDelayedACKTimeout + } + } + s.rtt.Unlock() + + now := time.Now() + if s.resendTimer.enabled() { + if now.Add(pto).After(s.resendTimer.target) { + pto = s.resendTimer.target.Sub(now) + } + s.resendTimer.disable() + } + + s.rc.probeTimer.enable(pto) +} + +// probeTimerExpired is the same as TLP_send_probe() as defined in +// https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.5.2. +func (s *sender) probeTimerExpired() *tcpip.Error { + if !s.rc.probeTimer.checkExpiration() { + return nil + } + // TODO(gvisor.dev/issue/5084): Implement this pseudo algorithm. + // If an unsent segment exists AND + // the receive window allows new data to be sent: + // Transmit the lowest-sequence unsent segment of up to SMSS + // Increment FlightSize by the size of the newly-sent segment + // Else if TLPRxtOut is not set: + // Retransmit the highest-sequence segment sent so far + // TLPRxtOut = true + // TLPHighRxt = SND.NXT + // The cwnd remains unchanged + // If FlightSize != 0: + // Arm RTO timer only. + return nil +} + +// detectTLPRecovery detects if recovery was accomplished by the loss probes +// and updates TLP state accordingly. +// See https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.6.3. +func (s *sender) detectTLPRecovery(ack seqnum.Value, rcvdSeg *segment) { + if !(s.ep.sackPermitted && s.rc.tlpRxtOut) { + return + } + + // Step 1. + if s.isDupAck(rcvdSeg) && ack == s.rc.tlpHighRxt { + var sbAboveTLPHighRxt bool + for _, sb := range rcvdSeg.parsedOptions.SACKBlocks { + if s.rc.tlpHighRxt.LessThan(sb.End) { + sbAboveTLPHighRxt = true + break + } + } + if !sbAboveTLPHighRxt { + // TLP episode is complete. + s.rc.tlpRxtOut = false + } + } + + if s.rc.tlpRxtOut && s.rc.tlpHighRxt.LessThanEq(ack) { + // TLP episode is complete. + s.rc.tlpRxtOut = false + if !checkDSACK(rcvdSeg) { + // Step 2. Either the original packet or the retransmission (in the + // form of a probe) was lost. Invoke a congestion control response + // equivalent to fast recovery. + s.cc.HandleNDupAcks() + s.enterRecovery() + s.leaveRecovery() + } + } +} diff --git a/pkg/tcpip/transport/tcp/rack_state.go b/pkg/tcpip/transport/tcp/rack_state.go index c9dc7e773..76cad0831 100644 --- a/pkg/tcpip/transport/tcp/rack_state.go +++ b/pkg/tcpip/transport/tcp/rack_state.go @@ -27,3 +27,8 @@ func (rc *rackControl) saveXmitTime() unixTime { func (rc *rackControl) loadXmitTime(unix unixTime) { rc.xmitTime = time.Unix(unix.second, unix.nano) } + +// afterLoad is invoked by stateify. +func (rc *rackControl) afterLoad() { + rc.probeTimer.init(&rc.probeWaker) +} diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index 8e0b7c843..405a6dce7 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -16,6 +16,7 @@ package tcp import ( "container/heap" + "math" "time" "gvisor.dev/gvisor/pkg/tcpip" @@ -48,6 +49,10 @@ type receiver struct { rcvWndScale uint8 + // prevBufused is the snapshot of endpoint rcvBufUsed taken when we + // advertise a receive window. + prevBufUsed int + closed bool // pendingRcvdSegments is bounded by the receive buffer size of the @@ -80,9 +85,9 @@ func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool { // outgoing packets, we should use what we have advertised for acceptability // test. scaledWindowSize := r.rcvWnd >> r.rcvWndScale - if scaledWindowSize > 0xffff { + if scaledWindowSize > math.MaxUint16 { // This is what we actually put in the Window field. - scaledWindowSize = 0xffff + scaledWindowSize = math.MaxUint16 } advertisedWindowSize := scaledWindowSize << r.rcvWndScale return header.Acceptable(segSeq, segLen, r.rcvNxt, r.rcvNxt.Add(advertisedWindowSize)) @@ -106,6 +111,34 @@ func (r *receiver) currentWindow() (curWnd seqnum.Size) { func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { newWnd := r.ep.selectWindow() curWnd := r.currentWindow() + unackLen := int(r.ep.snd.maxSentAck.Size(r.rcvNxt)) + bufUsed := r.ep.receiveBufferUsed() + + // Grow the right edge of the window only for payloads larger than the + // the segment overhead OR if the application is actively consuming data. + // + // Avoiding growing the right edge otherwise, addresses a situation below: + // An application has been slow in reading data and we have burst of + // incoming segments lengths < segment overhead. Here, our available free + // memory would reduce drastically when compared to the advertised receive + // window. + // + // For example: With incoming 512 bytes segments, segment overhead of + // 552 bytes (at the time of writing this comment), with receive window + // starting from 1MB and with rcvAdvWndScale being 1, buffer would reach 0 + // when the curWnd is still 19436 bytes, because for every incoming segment + // newWnd would reduce by (552+512) >> rcvAdvWndScale (current value 1), + // while curWnd would reduce by 512 bytes. + // Such a situation causes us to keep tail dropping the incoming segments + // and never advertise zero receive window to the peer. + // + // Linux does a similar check for minimal sk_buff size (128): + // https://github.com/torvalds/linux/blob/d5beb3140f91b1c8a3d41b14d729aefa4dcc58bc/net/ipv4/tcp_input.c#L783 + // + // Also, if the application is reading the data, we keep growing the right + // edge, as we are still advertising a window that we think can be serviced. + toGrow := unackLen >= SegSize || bufUsed <= r.prevBufUsed + // Update rcvAcc only if new window is > previously advertised window. We // should never shrink the acceptable sequence space once it has been // advertised the peer. If we shrink the acceptable sequence space then we @@ -115,7 +148,7 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { // rcvWUP rcvNxt rcvAcc new rcvAcc // <=====curWnd ===> // <========= newWnd > curWnd ========= > - if r.rcvNxt.Add(seqnum.Size(curWnd)).LessThan(r.rcvNxt.Add(seqnum.Size(newWnd))) { + if r.rcvNxt.Add(seqnum.Size(curWnd)).LessThan(r.rcvNxt.Add(seqnum.Size(newWnd))) && toGrow { // If the new window moves the right edge, then update rcvAcc. r.rcvAcc = r.rcvNxt.Add(seqnum.Size(newWnd)) } else { @@ -130,11 +163,22 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { // receiver's estimated RTT. r.rcvWnd = newWnd r.rcvWUP = r.rcvNxt + r.prevBufUsed = bufUsed scaledWnd := r.rcvWnd >> r.rcvWndScale if scaledWnd == 0 { // Increment a metric if we are advertising an actual zero window. r.ep.stats.ReceiveErrors.ZeroRcvWindowState.Increment() } + + // If we started off with a window larger than what can he held in + // the 16bit window field, we ceil the value to the max value. + if scaledWnd > math.MaxUint16 { + scaledWnd = seqnum.Size(math.MaxUint16) + + // Ensure that the stashed receive window always reflects what + // is being advertised. + r.rcvWnd = scaledWnd << r.rcvWndScale + } return r.rcvNxt, scaledWnd } diff --git a/pkg/tcpip/transport/tcp/reno_recovery.go b/pkg/tcpip/transport/tcp/reno_recovery.go new file mode 100644 index 000000000..2aa708e97 --- /dev/null +++ b/pkg/tcpip/transport/tcp/reno_recovery.go @@ -0,0 +1,67 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +// renoRecovery stores the variables related to TCP Reno loss recovery +// algorithm. +// +// +stateify savable +type renoRecovery struct { + s *sender +} + +func newRenoRecovery(s *sender) *renoRecovery { + return &renoRecovery{s: s} +} + +func (rr *renoRecovery) DoRecovery(rcvdSeg *segment, fastRetransmit bool) { + ack := rcvdSeg.ackNumber + snd := rr.s + + // We are in fast recovery mode. Ignore the ack if it's out of range. + if !ack.InRange(snd.sndUna, snd.sndNxt+1) { + return + } + + // Don't count this as a duplicate if it is carrying data or + // updating the window. + if rcvdSeg.logicalLen() != 0 || snd.sndWnd != rcvdSeg.window { + return + } + + // Inflate the congestion window if we're getting duplicate acks + // for the packet we retransmitted. + if !fastRetransmit && ack == snd.fr.first { + // We received a dup, inflate the congestion window by 1 packet + // if we're not at the max yet. Only inflate the window if + // regular FastRecovery is in use, RFC6675 does not require + // inflating cwnd on duplicate ACKs. + if snd.sndCwnd < snd.fr.maxCwnd { + snd.sndCwnd++ + } + return + } + + // A partial ack was received. Retransmit this packet and remember it + // so that we don't retransmit it again. + // + // We don't inflate the window because we're putting the same packet + // back onto the wire. + // + // N.B. The retransmit timer will be reset by the caller. + snd.fr.first = ack + snd.dupAckCount = 0 + snd.resendSegment() +} diff --git a/pkg/tcpip/transport/tcp/sack_recovery.go b/pkg/tcpip/transport/tcp/sack_recovery.go new file mode 100644 index 000000000..7e813fa96 --- /dev/null +++ b/pkg/tcpip/transport/tcp/sack_recovery.go @@ -0,0 +1,120 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import "gvisor.dev/gvisor/pkg/tcpip/seqnum" + +// sackRecovery stores the variables related to TCP SACK loss recovery +// algorithm. +// +// +stateify savable +type sackRecovery struct { + s *sender +} + +func newSACKRecovery(s *sender) *sackRecovery { + return &sackRecovery{s: s} +} + +// handleSACKRecovery implements the loss recovery phase as described in RFC6675 +// section 5, step C. +func (sr *sackRecovery) handleSACKRecovery(limit int, end seqnum.Value) (dataSent bool) { + snd := sr.s + snd.SetPipe() + + if smss := int(snd.ep.scoreboard.SMSS()); limit > smss { + // Cap segment size limit to s.smss as SACK recovery requires + // that all retransmissions or new segments send during recovery + // be of <= SMSS. + limit = smss + } + + nextSegHint := snd.writeList.Front() + for snd.outstanding < snd.sndCwnd { + var nextSeg *segment + var rescueRtx bool + nextSeg, nextSegHint, rescueRtx = snd.NextSeg(nextSegHint) + if nextSeg == nil { + return dataSent + } + if !snd.isAssignedSequenceNumber(nextSeg) || snd.sndNxt.LessThanEq(nextSeg.sequenceNumber) { + // New data being sent. + + // Step C.3 described below is handled by + // maybeSendSegment which increments sndNxt when + // a segment is transmitted. + // + // Step C.3 "If any of the data octets sent in + // (C.1) are above HighData, HighData must be + // updated to reflect the transmission of + // previously unsent data." + // + // We pass s.smss as the limit as the Step 2) requires that + // new data sent should be of size s.smss or less. + if sent := snd.maybeSendSegment(nextSeg, limit, end); !sent { + return dataSent + } + dataSent = true + snd.outstanding++ + snd.writeNext = nextSeg.Next() + continue + } + + // Now handle the retransmission case where we matched either step 1,3 or 4 + // of the NextSeg algorithm. + // RFC 6675, Step C.4. + // + // "The estimate of the amount of data outstanding in the network + // must be updated by incrementing pipe by the number of octets + // transmitted in (C.1)." + snd.outstanding++ + dataSent = true + snd.sendSegment(nextSeg) + + segEnd := nextSeg.sequenceNumber.Add(nextSeg.logicalLen()) + if rescueRtx { + // We do the last part of rule (4) of NextSeg here to update + // RescueRxt as until this point we don't know if we are going + // to use the rescue transmission. + snd.fr.rescueRxt = snd.fr.last + } else { + // RFC 6675, Step C.2 + // + // "If any of the data octets sent in (C.1) are below + // HighData, HighRxt MUST be set to the highest sequence + // number of the retransmitted segment unless NextSeg () + // rule (4) was invoked for this retransmission." + snd.fr.highRxt = segEnd - 1 + } + } + return dataSent +} + +func (sr *sackRecovery) DoRecovery(rcvdSeg *segment, fastRetransmit bool) { + snd := sr.s + if fastRetransmit { + snd.resendSegment() + } + + // We are in fast recovery mode. Ignore the ack if it's out of range. + if ack := rcvdSeg.ackNumber; !ack.InRange(snd.sndUna, snd.sndNxt+1) { + return + } + + // RFC 6675 recovery algorithm step C 1-5. + end := snd.sndUna.Add(snd.sndWnd) + dataSent := sr.handleSACKRecovery(snd.maxPayloadSize, end) + snd.postXmit(dataSent) +} diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 2091989cc..c5a6d2fba 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -37,7 +37,7 @@ const ( // segment represents a TCP segment. It holds the payload and parsed TCP segment // information, and can be added to intrusive lists. -// segment is mostly immutable, the only field allowed to change is viewToDeliver. +// segment is mostly immutable, the only field allowed to change is data. // // +stateify savable type segment struct { @@ -60,10 +60,7 @@ type segment struct { hdr header.TCP // views is used as buffer for data when its length is large // enough to store a VectorisedView. - views [8]buffer.View `state:"nosave"` - // viewToDeliver keeps track of the next View that should be - // delivered by the Read endpoint. - viewToDeliver int + views [8]buffer.View `state:"nosave"` sequenceNumber seqnum.Value ackNumber seqnum.Value flags uint8 @@ -84,6 +81,9 @@ type segment struct { // acked indicates if the segment has already been SACKed. acked bool + + // dataMemSize is the memory used by data initially. + dataMemSize int } func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment { @@ -100,6 +100,7 @@ func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) * s.data = pkt.Data.Clone(s.views[:]) s.hdr = header.TCP(pkt.TransportHeader().View()) s.rcvdTime = time.Now() + s.dataMemSize = s.data.Size() return s } @@ -113,6 +114,7 @@ func newOutgoingSegment(id stack.TransportEndpointID, v buffer.View) *segment { s.views[0] = v s.data = buffer.NewVectorisedView(len(v), s.views[:1]) } + s.dataMemSize = s.data.Size() return s } @@ -127,12 +129,12 @@ func (s *segment) clone() *segment { netProto: s.netProto, nicID: s.nicID, remoteLinkAddr: s.remoteLinkAddr, - viewToDeliver: s.viewToDeliver, rcvdTime: s.rcvdTime, xmitTime: s.xmitTime, xmitCount: s.xmitCount, ep: s.ep, qFlags: s.qFlags, + dataMemSize: s.dataMemSize, } t.data = s.data.Clone(t.views[:]) return t @@ -204,7 +206,7 @@ func (s *segment) payloadSize() int { // segMemSize is the amount of memory used to hold the segment data and // the associated metadata. func (s *segment) segMemSize() int { - return segSize + s.data.Size() + return SegSize + s.dataMemSize } // parse populates the sequence & ack numbers, flags, and window fields of the diff --git a/pkg/tcpip/transport/tcp/segment_state.go b/pkg/tcpip/transport/tcp/segment_state.go index 7dc2741a6..7422d8c02 100644 --- a/pkg/tcpip/transport/tcp/segment_state.go +++ b/pkg/tcpip/transport/tcp/segment_state.go @@ -24,16 +24,11 @@ import ( func (s *segment) saveData() buffer.VectorisedView { // We cannot save s.data directly as s.data.views may alias to s.views, // which is not allowed by state framework (in-struct pointer). - v := make([]buffer.View, len(s.data.Views())) - // For views already delivered, we cannot save them directly as they may - // have already been sliced and saved elsewhere (e.g., readViews). - for i := 0; i < s.viewToDeliver; i++ { - v[i] = append([]byte(nil), s.data.Views()[i]...) + vs := make([]buffer.View, len(s.data.Views())) + for i, v := range s.data.Views() { + vs[i] = v } - for i := s.viewToDeliver; i < len(v); i++ { - v[i] = s.data.Views()[i] - } - return buffer.NewVectorisedView(s.data.Size(), v) + return buffer.NewVectorisedView(s.data.Size(), vs) } // loadData is invoked by stateify. diff --git a/pkg/tcpip/transport/tcp/segment_unsafe.go b/pkg/tcpip/transport/tcp/segment_unsafe.go index 0ab7b8f56..392ff0859 100644 --- a/pkg/tcpip/transport/tcp/segment_unsafe.go +++ b/pkg/tcpip/transport/tcp/segment_unsafe.go @@ -19,5 +19,6 @@ import ( ) const ( - segSize = int(unsafe.Sizeof(segment{})) + // SegSize is the minimal size of the segment overhead. + SegSize = int(unsafe.Sizeof(segment{})) ) diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index ab5fa4fb7..079d90848 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -18,7 +18,6 @@ import ( "fmt" "math" "sort" - "sync/atomic" "time" "gvisor.dev/gvisor/pkg/sleep" @@ -92,6 +91,17 @@ type congestionControl interface { PostRecovery() } +// lossRecovery is an interface that must be implemented by any supported +// loss recovery algorithm. +type lossRecovery interface { + // DoRecovery is invoked when loss is detected and segments need + // to be retransmitted. The cumulative or selective ACK is passed along + // with the flag which identifies whether the connection entered fast + // retransmit with this ACK and to retransmit the first unacknowledged + // segment. + DoRecovery(rcvdSeg *segment, fastRetransmit bool) +} + // sender holds the state necessary to send TCP segments. // // +stateify savable @@ -108,6 +118,9 @@ type sender struct { // fr holds state related to fast recovery. fr fastRecovery + // lr is the loss recovery algorithm used by the sender. + lr lossRecovery + // sndCwnd is the congestion window, in packets. sndCwnd int @@ -124,6 +137,9 @@ type sender struct { // that have been sent but not yet acknowledged. outstanding int + // sackedOut is the number of packets which are selectively acked. + sackedOut int + // sndWnd is the send window size. sndWnd seqnum.Size @@ -270,12 +286,16 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint gso: ep.gso != nil, } + s.rc.init() + if s.gso { s.ep.gso.MSS = uint16(maxPayloadSize) } s.cc = s.initCongestionControl(ep.cc) + s.lr = s.initLossRecovery() + // A negative sndWndScale means that no scaling is in use, otherwise we // store the scaling value. if sndWndScale > 0 { @@ -330,6 +350,14 @@ func (s *sender) initCongestionControl(congestionControlName tcpip.CongestionCon } } +// initLossRecovery initiates the loss recovery algorithm for the sender. +func (s *sender) initLossRecovery() lossRecovery { + if s.ep.sackPermitted { + return newSACKRecovery(s) + } + return newRenoRecovery(s) +} + // updateMaxPayloadSize updates the maximum payload size based on the given // MTU. If this is in response to "packet too big" control packets (indicated // by the count argument), it also reduces the number of outstanding packets and @@ -349,6 +377,7 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) { m = 1 } + oldMSS := s.maxPayloadSize s.maxPayloadSize = m if s.gso { s.ep.gso.MSS = uint16(m) @@ -371,6 +400,7 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) { // Rewind writeNext to the first segment exceeding the MTU. Do nothing // if it is already before such a packet. + nextSeg := s.writeNext for seg := s.writeList.Front(); seg != nil; seg = seg.Next() { if seg == s.writeNext { // We got to writeNext before we could find a segment @@ -378,16 +408,22 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) { break } - if seg.data.Size() > m { + if nextSeg == s.writeNext && seg.data.Size() > m { // We found a segment exceeding the MTU. Rewind // writeNext and try to retransmit it. - s.writeNext = seg - break + nextSeg = seg + } + + if s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) { + // Update sackedOut for new maximum payload size. + s.sackedOut -= s.pCount(seg, oldMSS) + s.sackedOut += s.pCount(seg, s.maxPayloadSize) } } // Since we likely reduced the number of outstanding packets, we may be // ready to send some more. + s.writeNext = nextSeg s.sendData() } @@ -497,6 +533,10 @@ func (s *sender) retransmitTimerExpired() bool { s.ep.stack.Stats().TCP.Timeouts.Increment() s.ep.stats.SendErrors.Timeouts.Increment() + // Set TLPRxtOut to false according to + // https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.6.1. + s.rc.tlpRxtOut = false + // Give up if we've waited more than a minute since the last resend or // if a user time out is set and we have exceeded the user specified // timeout since the first retransmission. @@ -550,7 +590,7 @@ func (s *sender) retransmitTimerExpired() bool { // We were attempting fast recovery but were not successful. // Leave the state. We don't need to update ssthresh because it // has already been updated when entered fast-recovery. - s.leaveFastRecovery() + s.leaveRecovery() } s.state = RTORecovery @@ -606,13 +646,13 @@ func (s *sender) retransmitTimerExpired() bool { // pCount returns the number of packets in the segment. Due to GSO, a segment // can be composed of multiple packets. -func (s *sender) pCount(seg *segment) int { +func (s *sender) pCount(seg *segment, maxPayloadSize int) int { size := seg.data.Size() if size == 0 { return 1 } - return (size-1)/s.maxPayloadSize + 1 + return (size-1)/maxPayloadSize + 1 } // splitSeg splits a given segment at the size specified and inserts the @@ -789,7 +829,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se } if !nextTooBig && seg.data.Size() < available { // Segment is not full. - if s.outstanding > 0 && atomic.LoadUint32(&s.ep.delay) != 0 { + if s.outstanding > 0 && s.ep.ops.GetDelayOption() { // Nagle's algorithm. From Wikipedia: // Nagle's algorithm works by // combining a number of small @@ -808,7 +848,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se // send space and MSS. // TODO(gvisor.dev/issue/2833): Drain the held segments after a // timeout. - if seg.data.Size() < s.maxPayloadSize && atomic.LoadUint32(&s.ep.cork) != 0 { + if seg.data.Size() < s.maxPayloadSize && s.ep.ops.GetCorkOption() { return false } } @@ -913,79 +953,6 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se return true } -// handleSACKRecovery implements the loss recovery phase as described in RFC6675 -// section 5, step C. -func (s *sender) handleSACKRecovery(limit int, end seqnum.Value) (dataSent bool) { - s.SetPipe() - - if smss := int(s.ep.scoreboard.SMSS()); limit > smss { - // Cap segment size limit to s.smss as SACK recovery requires - // that all retransmissions or new segments send during recovery - // be of <= SMSS. - limit = smss - } - - nextSegHint := s.writeList.Front() - for s.outstanding < s.sndCwnd { - var nextSeg *segment - var rescueRtx bool - nextSeg, nextSegHint, rescueRtx = s.NextSeg(nextSegHint) - if nextSeg == nil { - return dataSent - } - if !s.isAssignedSequenceNumber(nextSeg) || s.sndNxt.LessThanEq(nextSeg.sequenceNumber) { - // New data being sent. - - // Step C.3 described below is handled by - // maybeSendSegment which increments sndNxt when - // a segment is transmitted. - // - // Step C.3 "If any of the data octets sent in - // (C.1) are above HighData, HighData must be - // updated to reflect the transmission of - // previously unsent data." - // - // We pass s.smss as the limit as the Step 2) requires that - // new data sent should be of size s.smss or less. - if sent := s.maybeSendSegment(nextSeg, limit, end); !sent { - return dataSent - } - dataSent = true - s.outstanding++ - s.writeNext = nextSeg.Next() - continue - } - - // Now handle the retransmission case where we matched either step 1,3 or 4 - // of the NextSeg algorithm. - // RFC 6675, Step C.4. - // - // "The estimate of the amount of data outstanding in the network - // must be updated by incrementing pipe by the number of octets - // transmitted in (C.1)." - s.outstanding++ - dataSent = true - s.sendSegment(nextSeg) - - segEnd := nextSeg.sequenceNumber.Add(nextSeg.logicalLen()) - if rescueRtx { - // We do the last part of rule (4) of NextSeg here to update - // RescueRxt as until this point we don't know if we are going - // to use the rescue transmission. - s.fr.rescueRxt = s.fr.last - } else { - // RFC 6675, Step C.2 - // - // "If any of the data octets sent in (C.1) are below - // HighData, HighRxt MUST be set to the highest sequence - // number of the retransmitted segment unless NextSeg () - // rule (4) was invoked for this retransmission." - s.fr.highRxt = segEnd - 1 - } - } - return dataSent -} - func (s *sender) sendZeroWindowProbe() { ack, win := s.ep.rcv.getSendParams() s.unackZeroWindowProbes++ @@ -1014,6 +981,30 @@ func (s *sender) disableZeroWindowProbing() { s.resendTimer.disable() } +func (s *sender) postXmit(dataSent bool) { + if dataSent { + // We sent data, so we should stop the keepalive timer to ensure + // that no keepalives are sent while there is pending data. + s.ep.disableKeepaliveTimer() + } + + // If the sender has advertized zero receive window and we have + // data to be sent out, start zero window probing to query the + // the remote for it's receive window size. + if s.writeNext != nil && s.sndWnd == 0 { + s.enableZeroWindowProbing() + } + + // Enable the timer if we have pending data and it's not enabled yet. + if !s.resendTimer.enabled() && s.sndUna != s.sndNxt { + s.resendTimer.enable(s.rto) + } + // If we have no more pending data, start the keepalive timer. + if s.sndUna == s.sndNxt { + s.ep.resetKeepaliveTimer(false) + } +} + // sendData sends new data segments. It is called when data becomes available or // when the send window opens up. func (s *sender) sendData() { @@ -1034,55 +1025,29 @@ func (s *sender) sendData() { } var dataSent bool - - // RFC 6675 recovery algorithm step C 1-5. - if s.fr.active && s.ep.sackPermitted { - dataSent = s.handleSACKRecovery(s.maxPayloadSize, end) - } else { - for seg := s.writeNext; seg != nil && s.outstanding < s.sndCwnd; seg = seg.Next() { - cwndLimit := (s.sndCwnd - s.outstanding) * s.maxPayloadSize - if cwndLimit < limit { - limit = cwndLimit - } - if s.isAssignedSequenceNumber(seg) && s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) { - // Move writeNext along so that we don't try and scan data that - // has already been SACKED. - s.writeNext = seg.Next() - continue - } - if sent := s.maybeSendSegment(seg, limit, end); !sent { - break - } - dataSent = true - s.outstanding += s.pCount(seg) + for seg := s.writeNext; seg != nil && s.outstanding < s.sndCwnd; seg = seg.Next() { + cwndLimit := (s.sndCwnd - s.outstanding) * s.maxPayloadSize + if cwndLimit < limit { + limit = cwndLimit + } + if s.isAssignedSequenceNumber(seg) && s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) { + // Move writeNext along so that we don't try and scan data that + // has already been SACKED. s.writeNext = seg.Next() + continue } + if sent := s.maybeSendSegment(seg, limit, end); !sent { + break + } + dataSent = true + s.outstanding += s.pCount(seg, s.maxPayloadSize) + s.writeNext = seg.Next() } - if dataSent { - // We sent data, so we should stop the keepalive timer to ensure - // that no keepalives are sent while there is pending data. - s.ep.disableKeepaliveTimer() - } - - // If the sender has advertized zero receive window and we have - // data to be sent out, start zero window probing to query the - // the remote for it's receive window size. - if s.writeNext != nil && s.sndWnd == 0 { - s.enableZeroWindowProbing() - } - - // Enable the timer if we have pending data and it's not enabled yet. - if !s.resendTimer.enabled() && s.sndUna != s.sndNxt { - s.resendTimer.enable(s.rto) - } - // If we have no more pending data, start the keepalive timer. - if s.sndUna == s.sndNxt { - s.ep.resetKeepaliveTimer(false) - } + s.postXmit(dataSent) } -func (s *sender) enterFastRecovery() { +func (s *sender) enterRecovery() { s.fr.active = true // Save state to reflect we're now in fast recovery. // @@ -1090,6 +1055,7 @@ func (s *sender) enterFastRecovery() { // We inflate the cwnd by 3 to account for the 3 packets which triggered // the 3 duplicate ACKs and are now not in flight. s.sndCwnd = s.sndSsthresh + 3 + s.sackedOut = 0 s.fr.first = s.sndUna s.fr.last = s.sndNxt - 1 s.fr.maxCwnd = s.sndCwnd + s.outstanding @@ -1098,13 +1064,16 @@ func (s *sender) enterFastRecovery() { if s.ep.sackPermitted { s.state = SACKRecovery s.ep.stack.Stats().TCP.SACKRecovery.Increment() + // Set TLPRxtOut to false according to + // https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.6.1. + s.rc.tlpRxtOut = false return } s.state = FastRecovery s.ep.stack.Stats().TCP.FastRecovery.Increment() } -func (s *sender) leaveFastRecovery() { +func (s *sender) leaveRecovery() { s.fr.active = false s.fr.maxCwnd = 0 s.dupAckCount = 0 @@ -1115,57 +1084,6 @@ func (s *sender) leaveFastRecovery() { s.cc.PostRecovery() } -func (s *sender) handleFastRecovery(seg *segment) (rtx bool) { - ack := seg.ackNumber - // We are in fast recovery mode. Ignore the ack if it's out of - // range. - if !ack.InRange(s.sndUna, s.sndNxt+1) { - return false - } - - // Leave fast recovery if it acknowledges all the data covered by - // this fast recovery session. - if s.fr.last.LessThan(ack) { - s.leaveFastRecovery() - return false - } - - if s.ep.sackPermitted { - // When SACK is enabled we let retransmission be governed by - // the SACK logic. - return false - } - - // Don't count this as a duplicate if it is carrying data or - // updating the window. - if seg.logicalLen() != 0 || s.sndWnd != seg.window { - return false - } - - // Inflate the congestion window if we're getting duplicate acks - // for the packet we retransmitted. - if ack == s.fr.first { - // We received a dup, inflate the congestion window by 1 packet - // if we're not at the max yet. Only inflate the window if - // regular FastRecovery is in use, RFC6675 does not require - // inflating cwnd on duplicate ACKs. - if s.sndCwnd < s.fr.maxCwnd { - s.sndCwnd++ - } - return false - } - - // A partial ack was received. Retransmit this packet and - // remember it so that we don't retransmit it again. We don't - // inflate the window because we're putting the same packet back - // onto the wire. - // - // N.B. The retransmit timer will be reset by the caller. - s.fr.first = ack - s.dupAckCount = 0 - return true -} - // isAssignedSequenceNumber relies on the fact that we only set flags once a // sequencenumber is assigned and that is only done right before we send the // segment. As a result any segment that has a non-zero flag has a valid @@ -1228,26 +1146,15 @@ func (s *sender) SetPipe() { s.outstanding = pipe } -// checkDuplicateAck is called when an ack is received. It manages the state -// related to duplicate acks and determines if a retransmit is needed according -// to the rules in RFC 6582 (NewReno). -func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) { - ack := seg.ackNumber - if s.fr.active { - return s.handleFastRecovery(seg) - } - - // We're not in fast recovery yet. A segment is considered a duplicate - // only if it doesn't carry any data and doesn't update the send window, - // because if it does, it wasn't sent in response to an out-of-order - // segment. If SACK is enabled then we have an additional check to see - // if the segment carries new SACK information. If it does then it is - // considered a duplicate ACK as per RFC6675. - if ack != s.sndUna || seg.logicalLen() != 0 || s.sndWnd != seg.window || ack == s.sndNxt { - if !s.ep.sackPermitted || !seg.hasNewSACKInfo { - s.dupAckCount = 0 - return false - } +// detectLoss is called when an ack is received and returns whether a loss is +// detected. It manages the state related to duplicate acks and determines if +// a retransmit is needed according to the rules in RFC 6582 (NewReno). +func (s *sender) detectLoss(seg *segment) (fastRetransmit bool) { + // We're not in fast recovery yet. + + if !s.isDupAck(seg) { + s.dupAckCount = 0 + return false } s.dupAckCount++ @@ -1266,18 +1173,43 @@ func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) { // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 2 // // We only do the check here, the incrementing of last to the highest - // sequence number transmitted till now is done when enterFastRecovery + // sequence number transmitted till now is done when enterRecovery // is invoked. if !s.fr.last.LessThan(seg.ackNumber) { s.dupAckCount = 0 return false } s.cc.HandleNDupAcks() - s.enterFastRecovery() + s.enterRecovery() s.dupAckCount = 0 return true } +// isDupAck determines if seg is a duplicate ack as defined in +// https://tools.ietf.org/html/rfc5681#section-2. +func (s *sender) isDupAck(seg *segment) bool { + // A TCP that utilizes selective acknowledgments (SACKs) [RFC2018, RFC2883] + // can leverage the SACK information to determine when an incoming ACK is a + // "duplicate" (e.g., if the ACK contains previously unknown SACK + // information). + if s.ep.sackPermitted && !seg.hasNewSACKInfo { + return false + } + + // (a) The receiver of the ACK has outstanding data. + return s.sndUna != s.sndNxt && + // (b) The incoming acknowledgment carries no data. + seg.logicalLen() == 0 && + // (c) The SYN and FIN bits are both off. + !seg.flagIsSet(header.TCPFlagFin) && !seg.flagIsSet(header.TCPFlagSyn) && + // (d) the ACK number is equal to the greatest acknowledgment received on + // the given connection (TCP.UNA from RFC793). + seg.ackNumber == s.sndUna && + // (e) the advertised window in the incoming acknowledgment equals the + // advertised window in the last incoming acknowledgment. + s.sndWnd == seg.window +} + // Iterate the writeList and update RACK for each segment which is newly acked // either cumulatively or selectively. Loop through the segments which are // sacked, and update the RACK related variables and check for reordering. @@ -1285,36 +1217,85 @@ func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) { // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 // steps 2 and 3. func (s *sender) walkSACK(rcvdSeg *segment) { - if len(rcvdSeg.parsedOptions.SACKBlocks) == 0 { + // Look for DSACK block. + idx := 0 + n := len(rcvdSeg.parsedOptions.SACKBlocks) + if checkDSACK(rcvdSeg) { + s.rc.setDSACKSeen() + idx = 1 + n-- + } + + if n == 0 { return } // Sort the SACK blocks. The first block is the most recent unacked // block. The following blocks can be in arbitrary order. - sackBlocks := make([]header.SACKBlock, len(rcvdSeg.parsedOptions.SACKBlocks)) - copy(sackBlocks, rcvdSeg.parsedOptions.SACKBlocks) + sackBlocks := make([]header.SACKBlock, n) + copy(sackBlocks, rcvdSeg.parsedOptions.SACKBlocks[idx:]) sort.Slice(sackBlocks, func(i, j int) bool { return sackBlocks[j].Start.LessThan(sackBlocks[i].Start) }) seg := s.writeList.Front() for _, sb := range sackBlocks { - // This check excludes DSACK blocks. - if sb.Start.LessThanEq(rcvdSeg.ackNumber) || sb.Start.LessThanEq(s.sndUna) || s.sndNxt.LessThan(sb.End) { - continue - } - for seg != nil && seg.sequenceNumber.LessThan(sb.End) && seg.xmitCount != 0 { if sb.Start.LessThanEq(seg.sequenceNumber) && !seg.acked { s.rc.update(seg, rcvdSeg, s.ep.tsOffset) s.rc.detectReorder(seg) seg.acked = true + s.sackedOut += s.pCount(seg, s.maxPayloadSize) } seg = seg.Next() } } } +// checkDSACK checks if a DSACK is reported. +func checkDSACK(rcvdSeg *segment) bool { + n := len(rcvdSeg.parsedOptions.SACKBlocks) + if n == 0 { + return false + } + + sb := rcvdSeg.parsedOptions.SACKBlocks[0] + // Check if SACK block is invalid. + if sb.End.LessThan(sb.Start) { + return false + } + + // See: https://tools.ietf.org/html/rfc2883#section-5 DSACK is sent in + // at most one SACK block. DSACK is detected in the below two cases: + // * If the SACK sequence space is less than this cumulative ACK, it is + // an indication that the segment identified by the SACK block has + // been received more than once by the receiver. + // * If the sequence space in the first SACK block is greater than the + // cumulative ACK, then the sender next compares the sequence space + // in the first SACK block with the sequence space in the second SACK + // block, if there is one. This comparison can determine if the first + // SACK block is reporting duplicate data that lies above the + // cumulative ACK. + if sb.Start.LessThan(rcvdSeg.ackNumber) { + return true + } + + if n > 1 { + sb1 := rcvdSeg.parsedOptions.SACKBlocks[1] + if sb1.End.LessThan(sb1.Start) { + return false + } + + // If the first SACK block is fully covered by second SACK + // block, then the first block is a DSACK block. + if sb.End.LessThanEq(sb1.End) && sb1.Start.LessThanEq(sb.Start) { + return true + } + } + + return false +} + // handleRcvdSegment is called when a segment is received; it is responsible for // updating the send-related state. func (s *sender) handleRcvdSegment(rcvdSeg *segment) { @@ -1367,14 +1348,26 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { s.SetPipe() } - // Count the duplicates and do the fast retransmit if needed. - rtx := s.checkDuplicateAck(rcvdSeg) + ack := rcvdSeg.ackNumber + fastRetransmit := false + // Do not leave fast recovery, if the ACK is out of range. + if s.fr.active { + // Leave fast recovery if it acknowledges all the data covered by + // this fast recovery session. + if ack.InRange(s.sndUna, s.sndNxt+1) && s.fr.last.LessThan(ack) { + s.leaveRecovery() + } + } else { + // Detect loss by counting the duplicates and enter recovery. + fastRetransmit = s.detectLoss(rcvdSeg) + } + + // See if TLP based recovery was successful. + s.detectTLPRecovery(ack, rcvdSeg) // Stash away the current window size. s.sndWnd = rcvdSeg.window - ack := rcvdSeg.ackNumber - // Disable zero window probing if remote advertizes a non-zero receive // window. This can be with an ACK to the zero window probe (where the // acknumber refers to the already acknowledged byte) OR to any previously @@ -1429,10 +1422,10 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { datalen := seg.logicalLen() if datalen > ackLeft { - prevCount := s.pCount(seg) + prevCount := s.pCount(seg, s.maxPayloadSize) seg.data.TrimFront(int(ackLeft)) seg.sequenceNumber.UpdateForward(ackLeft) - s.outstanding -= prevCount - s.pCount(seg) + s.outstanding -= prevCount - s.pCount(seg, s.maxPayloadSize) break } @@ -1448,11 +1441,13 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { s.writeList.Remove(seg) - // If SACK is enabled then Only reduce outstanding if + // If SACK is enabled then only reduce outstanding if // the segment was not previously SACKED as these have // already been accounted for in SetPipe(). if !s.ep.sackPermitted || !s.ep.scoreboard.IsSACKED(seg.sackBlock()) { - s.outstanding -= s.pCount(seg) + s.outstanding -= s.pCount(seg, s.maxPayloadSize) + } else { + s.sackedOut -= s.pCount(seg, s.maxPayloadSize) } seg.decRef() ackLeft -= datalen @@ -1489,21 +1484,27 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { // Reset firstRetransmittedSegXmitTime to the zero value. s.firstRetransmittedSegXmitTime = time.Time{} s.resendTimer.disable() + s.rc.probeTimer.disable() } } + // Now that we've popped all acknowledged data from the retransmit // queue, retransmit if needed. - if rtx { - s.resendSegment() + if s.fr.active { + s.lr.DoRecovery(rcvdSeg, fastRetransmit) + // When SACK is enabled data sending is governed by steps in + // RFC 6675 Section 5 recovery steps A-C. + // See: https://tools.ietf.org/html/rfc6675#section-5. + if s.ep.sackPermitted { + return + } } // Send more data now that some of the pending data has been ack'd, or // that the window opened up, or the congestion window was inflated due // to a duplicate ack during fast recovery. This will also re-enable // the retransmit timer if needed. - if !s.ep.sackPermitted || s.fr.active || s.dupAckCount == 0 || rcvdSeg.hasNewSACKInfo { - s.sendData() - } + s.sendData() } // sendSegment sends the specified segment. diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go index d3f92b48c..9818ffa0f 100644 --- a/pkg/tcpip/transport/tcp/tcp_rack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go @@ -30,15 +30,17 @@ const ( maxPayload = 10 tsOptionSize = 12 maxTCPOptionSize = 40 + mtu = header.TCPMinimumSize + header.IPv4MinimumSize + maxTCPOptionSize + maxPayload ) // TestRACKUpdate tests the RACK related fields are updated when an ACK is // received on a SACK enabled connection. func TestRACKUpdate(t *testing.T) { - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxTCPOptionSize+maxPayload)) + c := context.New(t, uint32(mtu)) defer c.Cleanup() var xmitTime time.Time + probeDone := make(chan struct{}) c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { // Validate that the endpoint Sender.RACKState is what we expect. if state.Sender.RACKState.XmitTime.Before(xmitTime) { @@ -54,6 +56,7 @@ func TestRACKUpdate(t *testing.T) { if state.Sender.RACKState.RTT == 0 { t.Fatalf("RACK RTT failed to update when an ACK is received, got RACKState.RTT == 0 want != 0") } + close(probeDone) }) setStackSACKPermitted(t, c, true) createConnectedWithSACKAndTS(c) @@ -73,18 +76,20 @@ func TestRACKUpdate(t *testing.T) { c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) bytesRead += maxPayload c.SendAck(seqnum.Value(context.TestInitialSequenceNumber).Add(1), bytesRead) - time.Sleep(200 * time.Millisecond) + + // Wait for the probe function to finish processing the ACK before the + // test completes. + <-probeDone } // TestRACKDetectReorder tests that RACK detects packet reordering. func TestRACKDetectReorder(t *testing.T) { - c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxTCPOptionSize+maxPayload)) + c := context.New(t, uint32(mtu)) defer c.Cleanup() - const ackNum = 2 - var n int - ch := make(chan struct{}) + const ackNumToVerify = 2 + probeDone := make(chan struct{}) c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { gotSeq := state.Sender.RACKState.FACK wantSeq := state.Sender.SndNxt @@ -95,7 +100,7 @@ func TestRACKDetectReorder(t *testing.T) { } n++ - if n < ackNum { + if n < ackNumToVerify { if state.Sender.RACKState.Reord { t.Fatalf("RACK reorder detected when there is no reordering") } @@ -105,11 +110,11 @@ func TestRACKDetectReorder(t *testing.T) { if state.Sender.RACKState.Reord == false { t.Fatalf("RACK reorder detection failed") } - close(ch) + close(probeDone) }) setStackSACKPermitted(t, c, true) createConnectedWithSACKAndTS(c) - data := buffer.NewView(ackNum * maxPayload) + data := buffer.NewView(ackNumToVerify * maxPayload) for i := range data { data[i] = byte(i) } @@ -120,7 +125,7 @@ func TestRACKDetectReorder(t *testing.T) { } bytesRead := 0 - for i := 0; i < ackNum; i++ { + for i := 0; i < ackNumToVerify; i++ { c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) bytesRead += maxPayload } @@ -133,5 +138,393 @@ func TestRACKDetectReorder(t *testing.T) { // Wait for the probe function to finish processing the ACK before the // test completes. - <-ch + <-probeDone +} + +func sendAndReceive(t *testing.T, c *context.Context, numPackets int) buffer.View { + setStackSACKPermitted(t, c, true) + createConnectedWithSACKAndTS(c) + + data := buffer.NewView(numPackets * maxPayload) + for i := range data { + data[i] = byte(i) + } + + // Write the data. + if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + bytesRead := 0 + for i := 0; i < numPackets; i++ { + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) + bytesRead += maxPayload + } + + return data +} + +const ( + validDSACKDetected = 1 + failedToDetectDSACK = 2 + invalidDSACKDetected = 3 +) + +func addDSACKSeenCheckerProbe(t *testing.T, c *context.Context, numACK int, probeDone chan int) { + var n int + c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { + // Validate that RACK detects DSACK. + n++ + if n < numACK { + if state.Sender.RACKState.DSACKSeen { + probeDone <- invalidDSACKDetected + } + return + } + + if !state.Sender.RACKState.DSACKSeen { + probeDone <- failedToDetectDSACK + return + } + probeDone <- validDSACKDetected + }) +} + +// TestRACKDetectDSACK tests that RACK detects DSACK with duplicate segments. +// See: https://tools.ietf.org/html/rfc2883#section-4.1.1. +func TestRACKDetectDSACK(t *testing.T) { + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + + probeDone := make(chan int) + const ackNumToVerify = 2 + addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone) + + numPackets := 8 + data := sendAndReceive(t, c, numPackets) + + // Cumulative ACK for [1-5] packets. + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + bytesRead := 5 * maxPayload + c.SendAck(seq, bytesRead) + + // Expect retransmission of #6 packet. + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) + + // Send DSACK block for #6 packet indicating both + // initial and retransmitted packet are received and + // packets [1-7] are received. + start := c.IRS.Add(seqnum.Size(bytesRead)) + end := start.Add(maxPayload) + bytesRead += 2 * maxPayload + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) + + // Wait for the probe function to finish processing the + // ACK before the test completes. + err := <-probeDone + switch err { + case failedToDetectDSACK: + t.Fatalf("RACK DSACK detection failed") + case invalidDSACKDetected: + t.Fatalf("RACK DSACK detected when there is no duplicate SACK") + } +} + +// TestRACKDetectDSACKWithOutOfOrder tests that RACK detects DSACK with out of +// order segments. +// See: https://tools.ietf.org/html/rfc2883#section-4.1.2. +func TestRACKDetectDSACKWithOutOfOrder(t *testing.T) { + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + + probeDone := make(chan int) + const ackNumToVerify = 2 + addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone) + + numPackets := 10 + data := sendAndReceive(t, c, numPackets) + + // Cumulative ACK for [1-5] packets. + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + bytesRead := 5 * maxPayload + c.SendAck(seq, bytesRead) + + // Expect retransmission of #6 packet. + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) + + // Send DSACK block for #6 packet indicating both + // initial and retransmitted packet are received and + // packets [1-7] are received. + start := c.IRS.Add(seqnum.Size(bytesRead)) + end := start.Add(maxPayload) + bytesRead += 2 * maxPayload + // Send DSACK block for #6 along with out of + // order #9 packet is received. + start1 := c.IRS.Add(seqnum.Size(bytesRead) + maxPayload) + end1 := start1.Add(maxPayload) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}, {start1, end1}}) + + // Wait for the probe function to finish processing the + // ACK before the test completes. + err := <-probeDone + switch err { + case failedToDetectDSACK: + t.Fatalf("RACK DSACK detection failed") + case invalidDSACKDetected: + t.Fatalf("RACK DSACK detected when there is no duplicate SACK") + } +} + +// TestRACKDetectDSACKWithOutOfOrderDup tests that DSACK is detected on a +// duplicate of out of order packet. +// See: https://tools.ietf.org/html/rfc2883#section-4.1.3 +func TestRACKDetectDSACKWithOutOfOrderDup(t *testing.T) { + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + + probeDone := make(chan int) + const ackNumToVerify = 4 + addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone) + + numPackets := 10 + sendAndReceive(t, c, numPackets) + + // ACK [1-5] packets. + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + bytesRead := 5 * maxPayload + c.SendAck(seq, bytesRead) + + // Send SACK indicating #6 packet is missing and received #7 packet. + offset := seqnum.Size(bytesRead + maxPayload) + start := c.IRS.Add(1 + offset) + end := start.Add(maxPayload) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) + + // Send SACK with #6 packet is missing and received [7-8] packets. + end = start.Add(2 * maxPayload) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) + + // Consider #8 packet is duplicated on the network and send DSACK. + dsackStart := c.IRS.Add(1 + offset + maxPayload) + dsackEnd := dsackStart.Add(maxPayload) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{dsackStart, dsackEnd}, {start, end}}) + + // Wait for the probe function to finish processing the ACK before the + // test completes. + err := <-probeDone + switch err { + case failedToDetectDSACK: + t.Fatalf("RACK DSACK detection failed") + case invalidDSACKDetected: + t.Fatalf("RACK DSACK detected when there is no duplicate SACK") + } +} + +// TestRACKDetectDSACKSingleDup tests DSACK for a single duplicate subsegment. +// See: https://tools.ietf.org/html/rfc2883#section-4.2.1. +func TestRACKDetectDSACKSingleDup(t *testing.T) { + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + + probeDone := make(chan int) + const ackNumToVerify = 4 + addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone) + + numPackets := 4 + data := sendAndReceive(t, c, numPackets) + + // Send ACK for #1 packet. + bytesRead := maxPayload + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + c.SendAck(seq, bytesRead) + + // Missing [2-3] packets and received #4 packet. + seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1) + start := c.IRS.Add(1 + seqnum.Size(3*maxPayload)) + end := start.Add(seqnum.Size(maxPayload)) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) + + // Expect retransmission of #2 packet. + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) + + // ACK for retransmitted #2 packet. + bytesRead += maxPayload + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) + + // Simulate receving delayed subsegment of #2 packet and delayed #3 packet by + // sending DSACK block for the subsegment. + dsackStart := c.IRS.Add(1 + seqnum.Size(bytesRead)) + dsackEnd := dsackStart.Add(seqnum.Size(maxPayload / 2)) + c.SendAckWithSACK(seq, numPackets*maxPayload, []header.SACKBlock{{dsackStart, dsackEnd}}) + + // Wait for the probe function to finish processing the ACK before the + // test completes. + err := <-probeDone + switch err { + case failedToDetectDSACK: + t.Fatalf("RACK DSACK detection failed") + case invalidDSACKDetected: + t.Fatalf("RACK DSACK detected when there is no duplicate SACK") + } +} + +// TestRACKDetectDSACKDupWithCumulativeACK tests DSACK for two non-contiguous +// duplicate subsegments covered by the cumulative acknowledgement. +// See: https://tools.ietf.org/html/rfc2883#section-4.2.2. +func TestRACKDetectDSACKDupWithCumulativeACK(t *testing.T) { + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + + probeDone := make(chan int) + const ackNumToVerify = 5 + addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone) + + numPackets := 6 + data := sendAndReceive(t, c, numPackets) + + // Send ACK for #1 packet. + bytesRead := maxPayload + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + c.SendAck(seq, bytesRead) + + // Missing [2-5] packets and received #6 packet. + seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1) + start := c.IRS.Add(1 + seqnum.Size(5*maxPayload)) + end := start.Add(seqnum.Size(maxPayload)) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) + + // Expect retransmission of #2 packet. + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) + + // Received delayed #2 packet. + bytesRead += maxPayload + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) + + // Received delayed #4 packet. + start1 := c.IRS.Add(1 + seqnum.Size(3*maxPayload)) + end1 := start1.Add(seqnum.Size(maxPayload)) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start1, end1}, {start, end}}) + + // Simulate receiving retransmitted subsegment for #2 packet and delayed #3 + // packet by sending DSACK block for #2 packet. + dsackStart := c.IRS.Add(1 + seqnum.Size(maxPayload)) + dsackEnd := dsackStart.Add(seqnum.Size(maxPayload / 2)) + c.SendAckWithSACK(seq, 4*maxPayload, []header.SACKBlock{{dsackStart, dsackEnd}, {start, end}}) + + // Wait for the probe function to finish processing the ACK before the + // test completes. + err := <-probeDone + switch err { + case failedToDetectDSACK: + t.Fatalf("RACK DSACK detection failed") + case invalidDSACKDetected: + t.Fatalf("RACK DSACK detected when there is no duplicate SACK") + } +} + +// TestRACKDetectDSACKDup tests two non-contiguous duplicate subsegments not +// covered by the cumulative acknowledgement. +// See: https://tools.ietf.org/html/rfc2883#section-4.2.3. +func TestRACKDetectDSACKDup(t *testing.T) { + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + + probeDone := make(chan int) + const ackNumToVerify = 5 + addDSACKSeenCheckerProbe(t, c, ackNumToVerify, probeDone) + + numPackets := 7 + data := sendAndReceive(t, c, numPackets) + + // Send ACK for #1 packet. + bytesRead := maxPayload + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + c.SendAck(seq, bytesRead) + + // Missing [2-6] packets and SACK #7 packet. + seq = seqnum.Value(context.TestInitialSequenceNumber).Add(1) + start := c.IRS.Add(1 + seqnum.Size(6*maxPayload)) + end := start.Add(seqnum.Size(maxPayload)) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) + + // Received delayed #3 packet. + start1 := c.IRS.Add(1 + seqnum.Size(2*maxPayload)) + end1 := start1.Add(seqnum.Size(maxPayload)) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start1, end1}, {start, end}}) + + // Expect retransmission of #2 packet. + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) + + // Consider #2 packet has been dropped and SACK #4 packet. + start2 := c.IRS.Add(1 + seqnum.Size(3*maxPayload)) + end2 := start2.Add(seqnum.Size(maxPayload)) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start2, end2}, {start1, end1}, {start, end}}) + + // Simulate receiving retransmitted subsegment for #3 packet and delayed #5 + // packet by sending DSACK block for the subsegment. + dsackStart := c.IRS.Add(1 + seqnum.Size(2*maxPayload)) + dsackEnd := dsackStart.Add(seqnum.Size(maxPayload / 2)) + end1 = end1.Add(seqnum.Size(2 * maxPayload)) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{dsackStart, dsackEnd}, {start1, end1}}) + + // Wait for the probe function to finish processing the ACK before the + // test completes. + err := <-probeDone + switch err { + case failedToDetectDSACK: + t.Fatalf("RACK DSACK detection failed") + case invalidDSACKDetected: + t.Fatalf("RACK DSACK detected when there is no duplicate SACK") + } +} + +// TestRACKWithInvalidDSACKBlock tests that DSACK is not detected when DSACK +// is not the first SACK block. +func TestRACKWithInvalidDSACKBlock(t *testing.T) { + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + + probeDone := make(chan struct{}) + const ackNumToVerify = 2 + var n int + c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { + // Validate that RACK does not detect DSACK when DSACK block is + // not the first SACK block. + n++ + t.Helper() + if state.Sender.RACKState.DSACKSeen { + t.Fatalf("RACK DSACK detected when there is no duplicate SACK") + } + + if n == ackNumToVerify { + close(probeDone) + } + }) + + numPackets := 10 + data := sendAndReceive(t, c, numPackets) + + // Cumulative ACK for [1-5] packets. + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + bytesRead := 5 * maxPayload + c.SendAck(seq, bytesRead) + + // Expect retransmission of #6 packet. + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) + + // Send DSACK block for #6 packet indicating both + // initial and retransmitted packet are received and + // packets [1-7] are received. + start := c.IRS.Add(seqnum.Size(bytesRead)) + end := start.Add(maxPayload) + bytesRead += 2 * maxPayload + + // Send DSACK block as second block. + start1 := c.IRS.Add(seqnum.Size(bytesRead) + maxPayload) + end1 := start1.Add(maxPayload) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start1, end1}, {start, end}}) + + // Wait for the probe function to finish processing the + // ACK before the test completes. + <-probeDone } diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go index ef7f5719f..faf0c0ad7 100644 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go @@ -590,3 +590,45 @@ func TestSACKRecovery(t *testing.T) { expected++ } } + +// TestSACKUpdateSackedOut tests the sacked out field is updated when a SACK +// is received. +func TestSACKUpdateSackedOut(t *testing.T) { + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + + probeDone := make(chan struct{}) + ackNum := 0 + c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { + // Validate that the endpoint Sender.SackedOut is what we expect. + if state.Sender.SackedOut != 2 && ackNum == 0 { + t.Fatalf("SackedOut got updated to wrong value got: %v want: 2", state.Sender.SackedOut) + } + + if state.Sender.SackedOut != 0 && ackNum == 1 { + t.Fatalf("SackedOut got updated to wrong value got: %v want: 0", state.Sender.SackedOut) + } + if ackNum > 0 { + close(probeDone) + } + ackNum++ + }) + setStackSACKPermitted(t, c, true) + createConnectedWithSACKAndTS(c) + + sendAndReceive(t, c, 8) + + // ACK for [3-5] packets. + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + start := c.IRS.Add(seqnum.Size(1 + 3*maxPayload)) + bytesRead := 2 * maxPayload + end := start.Add(seqnum.Size(bytesRead)) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) + + bytesRead += 3 * maxPayload + c.SendAck(seq, bytesRead) + + // Wait for the probe function to finish processing the ACK before the + // test completes. + <-probeDone +} diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index fcc3c5000..aeceee7e0 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -17,10 +17,12 @@ package tcp_test import ( "bytes" "fmt" + "io/ioutil" "math" "testing" "time" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -40,6 +42,64 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +// endpointTester provides helper functions to test a tcpip.Endpoint. +type endpointTester struct { + ep tcpip.Endpoint +} + +// CheckReadError issues a read to the endpoint and checking for an error. +func (e *endpointTester) CheckReadError(t *testing.T, want *tcpip.Error) { + t.Helper() + res, got := e.ep.Read(ioutil.Discard, 1, tcpip.ReadOptions{}) + if got != want { + t.Fatalf("ep.Read = %s, want %s", got, want) + } + if diff := cmp.Diff(tcpip.ReadResult{}, res); diff != "" { + t.Errorf("ep.Read: unexpected non-zero result (-want +got):\n%s", diff) + } +} + +// CheckRead issues a read to the endpoint and checking for a success, returning +// the data read. +func (e *endpointTester) CheckRead(t *testing.T, count int) []byte { + t.Helper() + var buf bytes.Buffer + res, err := e.ep.Read(&buf, count, tcpip.ReadOptions{}) + if err != nil { + t.Fatalf("ep.Read = _, %s; want _, nil", err) + } + if diff := cmp.Diff(tcpip.ReadResult{ + Count: buf.Len(), + Total: buf.Len(), + }, res, checker.IgnoreCmpPath("ControlMessages")); diff != "" { + t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) + } + return buf.Bytes() +} + +// CheckReadFull reads from the endpoint for exactly count bytes. +func (e *endpointTester) CheckReadFull(t *testing.T, count int, notifyRead <-chan struct{}, timeout time.Duration) []byte { + t.Helper() + var buf bytes.Buffer + var done int + for done < count { + res, err := e.ep.Read(&buf, count-done, tcpip.ReadOptions{}) + if err == tcpip.ErrWouldBlock { + // Wait for receive to be notified. + select { + case <-notifyRead: + case <-time.After(timeout): + t.Fatalf("Timed out waiting for data to arrive") + } + continue + } else if err != nil { + t.Fatalf("ep.Read = _, %s; want _, nil", err) + } + done += res.Count + } + return buf.Bytes() +} + const ( // defaultMTU is the MTU, in bytes, used throughout the tests, except // where another value is explicitly used. It is chosen to match the MTU @@ -75,9 +135,6 @@ func TestGiveUpConnect(t *testing.T) { // Wait for ep to become writable. <-notifyCh - if err := ep.LastError(); err != tcpip.ErrAborted { - t.Fatalf("got ep.LastError() = %s, want = %s", err, tcpip.ErrAborted) - } // Call Connect again to retreive the handshake failure status // and stats updates. @@ -267,7 +324,7 @@ func TestTCPResetsSentNoICMP(t *testing.T) { } // Read outgoing ICMP stats and check no ICMP DstUnreachable was recorded. - sent := stats.ICMP.V4PacketsSent + sent := stats.ICMP.V4.PacketsSent if got, want := sent.DstUnreachable.Value(), uint64(0); got != want { t.Errorf("got ICMP DstUnreachable.Value() = %d, want = %d", got, want) } @@ -743,9 +800,7 @@ func TestSimpleReceive(t *testing.T) { c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) - } + ept := endpointTester{c.EP} data := []byte{1, 2, 3} c.SendPacket(data, &context.Headers{ @@ -765,11 +820,7 @@ func TestSimpleReceive(t *testing.T) { } // Receive data. - v, _, err := c.EP.Read(nil) - if err != nil { - t.Fatalf("Read failed: %s", err) - } - + v := ept.CheckRead(t, defaultMTU) if !bytes.Equal(data, v) { t.Fatalf("got data = %v, want = %v", v, data) } @@ -1383,9 +1434,8 @@ func TestConnectBindToDevice(t *testing.T) { defer c.Cleanup() c.Create(-1) - bindToDevice := tcpip.BindToDeviceOption(test.device) - if err := c.EP.SetSockOpt(&bindToDevice); err != nil { - t.Fatalf("c.EP.SetSockOpt(&%T(%d)): %s", bindToDevice, bindToDevice, err) + if err := c.EP.SocketOptions().SetBindToDevice(int32(test.device)); err != nil { + t.Fatalf("c.EP.SetSockOpt(&%T(%d)): %s", test.device, test.device, err) } // Start connection attempt. waitEntry, _ := waiter.NewChannelEntry(nil) @@ -1496,14 +1546,11 @@ func TestSynSent(t *testing.T) { t.Fatal("timed out waiting for packet to arrive") } + ept := endpointTester{c.EP} if test.reset { - if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionRefused { - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionRefused) - } + ept.CheckReadError(t, tcpip.ErrConnectionRefused) } else { - if _, _, err := c.EP.Read(nil); err != tcpip.ErrAborted { - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrAborted) - } + ept.CheckReadError(t, tcpip.ErrAborted) } if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { @@ -1528,9 +1575,8 @@ func TestOutOfOrderReceive(t *testing.T) { c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) - } + ept := endpointTester{c.EP} + ept.CheckReadError(t, tcpip.ErrWouldBlock) // Send second half of data first, with seqnum 3 ahead of expected. data := []byte{1, 2, 3, 4, 5, 6} @@ -1555,9 +1601,7 @@ func TestOutOfOrderReceive(t *testing.T) { // Wait 200ms and check that no data has been received. time.Sleep(200 * time.Millisecond) - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) - } + ept.CheckReadError(t, tcpip.ErrWouldBlock) // Send the first 3 bytes now. c.SendPacket(data[:3], &context.Headers{ @@ -1570,24 +1614,7 @@ func TestOutOfOrderReceive(t *testing.T) { }) // Receive data. - read := make([]byte, 0, 6) - for len(read) < len(data) { - v, _, err := c.EP.Read(nil) - if err != nil { - if err == tcpip.ErrWouldBlock { - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(5 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - continue - } - t.Fatalf("Read failed: %s", err) - } - - read = append(read, v...) - } + read := ept.CheckReadFull(t, 6, ch, 5*time.Second) // Check that we received the data in proper order. if !bytes.Equal(data, read) { @@ -1612,9 +1639,8 @@ func TestOutOfOrderFlood(t *testing.T) { rcvBufSz := math.MaxUint16 c.CreateConnected(789, 30000, rcvBufSz) - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) - } + ept := endpointTester{c.EP} + ept.CheckReadError(t, tcpip.ErrWouldBlock) // Send 100 packets before the actual one that is expected. data := []byte{1, 2, 3, 4, 5, 6} @@ -1689,9 +1715,8 @@ func TestRstOnCloseWithUnreadData(t *testing.T) { c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) - } + ept := endpointTester{c.EP} + ept.CheckReadError(t, tcpip.ErrWouldBlock) data := []byte{1, 2, 3} c.SendPacket(data, &context.Headers{ @@ -1758,9 +1783,8 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) - } + ept := endpointTester{c.EP} + ept.CheckReadError(t, tcpip.ErrWouldBlock) data := []byte{1, 2, 3} c.SendPacket(data, &context.Headers{ @@ -1841,17 +1865,14 @@ func TestShutdownRead(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) - } + ept := endpointTester{c.EP} + ept.CheckReadError(t, tcpip.ErrWouldBlock) if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil { t.Fatalf("Shutdown failed: %s", err) } - if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive { - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrClosedForReceive) - } + ept.CheckReadError(t, tcpip.ErrClosedForReceive) var want uint64 = 1 if got := c.EP.Stats().(*tcp.Stats).ReadErrors.ReadClosed.Value(); got != want { t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %d want %d", got, want) @@ -1869,10 +1890,8 @@ func TestFullWindowReceive(t *testing.T) { c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - _, _, err := c.EP.Read(nil) - if err != tcpip.ErrWouldBlock { - t.Fatalf("Read failed: %s", err) - } + ept := endpointTester{c.EP} + ept.CheckReadError(t, tcpip.ErrWouldBlock) // Fill up the window w/ tcp.SegOverheadFactor*rcvBufSz as netstack multiplies // the provided buffer value by tcp.SegOverheadFactor to calculate the actual @@ -1909,11 +1928,7 @@ func TestFullWindowReceive(t *testing.T) { ) // Receive data and check it. - v, _, err := c.EP.Read(nil) - if err != nil { - t.Fatalf("Read failed: %s", err) - } - + v := ept.CheckRead(t, defaultMTU) if !bytes.Equal(data, v) { t.Fatalf("got data = %v, want = %v", v, data) } @@ -1935,6 +1950,85 @@ func TestFullWindowReceive(t *testing.T) { ) } +// Test the stack receive window advertisement on receiving segments smaller than +// segment overhead. It tests for the right edge of the window to not grow when +// the endpoint is not being read from. +func TestSmallSegReceiveWindowAdvertisement(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + opt := tcpip.TCPReceiveBufferSizeRangeOption{ + Min: 1, + Default: tcp.DefaultReceiveBufferSize, + Max: tcp.DefaultReceiveBufferSize << tcp.FindWndScale(seqnum.Size(tcp.DefaultReceiveBufferSize)), + } + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + } + + c.AcceptWithOptions(tcp.FindWndScale(seqnum.Size(opt.Default)), header.TCPSynOptions{MSS: defaultIPv4MSS}) + + // Bump up the receive buffer size such that, when the receive window grows, + // the scaled window exceeds maxUint16. + if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, opt.Max); err != nil { + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed: %s", opt.Max, err) + } + + // Keep the payload size < segment overhead and such that it is a multiple + // of the window scaled value. This enables the test to perform equality + // checks on the incoming receive window. + payload := generateRandomPayload(t, (tcp.SegSize-1)&(1<<c.RcvdWindowScale)) + payloadLen := seqnum.Size(len(payload)) + iss := seqnum.Value(789) + seqNum := iss.Add(1) + + // Send payload to the endpoint and return the advertised receive window + // from the endpoint. + getIncomingRcvWnd := func() uint32 { + c.SendPacket(payload, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + SeqNum: seqNum, + AckNum: c.IRS.Add(1), + Flags: header.TCPFlagAck, + RcvWnd: 30000, + }) + seqNum = seqNum.Add(payloadLen) + + pkt := c.GetPacket() + return uint32(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.RcvdWindowScale + } + + // Read the advertised receive window with the ACK for payload. + rcvWnd := getIncomingRcvWnd() + + // Check if the subsequent ACK to our send has not grown the right edge of + // the window. + if got, want := getIncomingRcvWnd(), rcvWnd-uint32(len(payload)); got != want { + t.Fatalf("got incomingRcvwnd %d want %d", got, want) + } + + // Read the data so that the subsequent ACK from the endpoint + // grows the right edge of the window. + var buf bytes.Buffer + if _, err := c.EP.Read(&buf, math.MaxUint16, tcpip.ReadOptions{}); err != nil { + t.Fatalf("c.EP.Read: %s", err) + } + + // Check if we have received max uint16 as our advertised + // scaled window now after a read above. + maxRcv := uint32(math.MaxUint16 << c.RcvdWindowScale) + if got, want := getIncomingRcvWnd(), maxRcv; got != want { + t.Fatalf("got incomingRcvwnd %d want %d", got, want) + } + + // Check if the subsequent ACK to our send has not grown the right edge of + // the window. + if got, want := getIncomingRcvWnd(), maxRcv-uint32(len(payload)); got != want { + t.Fatalf("got incomingRcvwnd %d want %d", got, want) + } +} + func TestNoWindowShrinking(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -1953,9 +2047,9 @@ func TestNoWindowShrinking(t *testing.T) { c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) - } + ept := endpointTester{c.EP} + ept.CheckReadError(t, tcpip.ErrWouldBlock) + // Send a 1 byte payload so that we can record the current receive window. // Send a payload of half the size of rcvBufSize. seqNum := iss.Add(1) @@ -1977,11 +2071,7 @@ func TestNoWindowShrinking(t *testing.T) { } // Read the 1 byte payload we just sent. - v, _, err := c.EP.Read(nil) - if err != nil { - t.Fatalf("Read failed: %s", err) - } - if got, want := payload, v; !bytes.Equal(got, want) { + if got, want := payload, ept.CheckRead(t, 1); !bytes.Equal(got, want) { t.Fatalf("got data: %v, want: %v", got, want) } @@ -2054,24 +2144,8 @@ func TestNoWindowShrinking(t *testing.T) { ), ) - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(5 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - // Receive data and check it. - read := make([]byte, 0, rcvBufSize) - for len(read) < len(data) { - v, _, err := c.EP.Read(nil) - if err != nil { - t.Fatalf("Read failed: %s", err) - } - - read = append(read, v...) - } - + read := ept.CheckReadFull(t, len(data), ch, 5*time.Second) if !bytes.Equal(data, read) { t.Fatalf("got data = %v, want = %v", read, data) } @@ -2495,11 +2569,11 @@ func TestZeroScaledWindowReceive(t *testing.T) { // we need to read at 3 packets. sz := 0 for sz < defaultMTU*2 { - v, _, err := c.EP.Read(nil) + res, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{}) if err != nil { t.Fatalf("Read failed: %s", err) } - sz += len(v) + sz += res.Count } checker.IPv4(t, c.GetPacket(), @@ -2532,10 +2606,10 @@ func TestSegmentMerging(t *testing.T) { { "cork", func(ep tcpip.Endpoint) { - ep.SetSockOptBool(tcpip.CorkOption, true) + ep.SocketOptions().SetCorkOption(true) }, func(ep tcpip.Endpoint) { - ep.SetSockOptBool(tcpip.CorkOption, false) + ep.SocketOptions().SetCorkOption(false) }, }, } @@ -2627,7 +2701,7 @@ func TestDelay(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - c.EP.SetSockOptBool(tcpip.DelayOption, true) + c.EP.SocketOptions().SetDelayOption(true) var allData []byte for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { @@ -2675,7 +2749,7 @@ func TestUndelay(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - c.EP.SetSockOptBool(tcpip.DelayOption, true) + c.EP.SocketOptions().SetDelayOption(true) allData := [][]byte{{0}, {1, 2, 3}} for i, data := range allData { @@ -2708,7 +2782,7 @@ func TestUndelay(t *testing.T) { // Check that we don't get the second packet yet. c.CheckNoPacketTimeout("delayed second packet transmitted", 100*time.Millisecond) - c.EP.SetSockOptBool(tcpip.DelayOption, false) + c.EP.SocketOptions().SetDelayOption(false) // Check that data is received. second := c.GetPacket() @@ -2745,8 +2819,8 @@ func TestMSSNotDelayed(t *testing.T) { fn func(tcpip.Endpoint) }{ {"no-op", func(tcpip.Endpoint) {}}, - {"delay", func(ep tcpip.Endpoint) { ep.SetSockOptBool(tcpip.DelayOption, true) }}, - {"cork", func(ep tcpip.Endpoint) { ep.SetSockOptBool(tcpip.CorkOption, true) }}, + {"delay", func(ep tcpip.Endpoint) { ep.SocketOptions().SetDelayOption(true) }}, + {"cork", func(ep tcpip.Endpoint) { ep.SocketOptions().SetCorkOption(true) }}, } for _, test := range tests { @@ -3194,10 +3268,15 @@ func TestReceiveOnResetConnection(t *testing.T) { loop: for { - switch _, _, err := c.EP.Read(nil); err { + switch _, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{}); err { case tcpip.ErrWouldBlock: select { case <-ch: + // Expect the state to be StateError and subsequent Reads to fail with HardError. + if _, err := c.EP.Read(ioutil.Discard, math.MaxUint16, tcpip.ReadOptions{}); err != tcpip.ErrConnectionReset { + t.Fatalf("got c.EP.Read() = %s, want = %s", err, tcpip.ErrConnectionReset) + } + break loop case <-time.After(1 * time.Second): t.Fatalf("Timed out waiting for reset to arrive") } @@ -3207,14 +3286,10 @@ loop: t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset) } } - // Expect the state to be StateError and subsequent Reads to fail with HardError. - if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset { - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset) - } + if tcp.EndpointState(c.EP.State()) != tcp.StateError { t.Fatalf("got EP state is not StateError") } - if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 { t.Errorf("got stats.TCP.EstablishedResets.Value() = %d, want = 1", got) } @@ -3386,7 +3461,7 @@ func TestRetransmitIPv4IDUniqueness(t *testing.T) { checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), ), ) - idSet := map[uint16]struct{}{header.IPv4(pkt).ID(): struct{}{}} + idSet := map[uint16]struct{}{header.IPv4(pkt).ID(): {}} // Expect two retransmitted packets, and that all packets received have // unique IPv4 ID values. for i := 0; i <= 2; i++ { @@ -4089,9 +4164,8 @@ func TestReadAfterClosedState(t *testing.T) { c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) - } + ept := endpointTester{c.EP} + ept.CheckReadError(t, tcpip.ErrWouldBlock) // Shutdown immediately for write, check that we get a FIN. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { @@ -4149,35 +4223,31 @@ func TestReadAfterClosedState(t *testing.T) { } // Check that peek works. - peekBuf := make([]byte, 10) - n, _, err := c.EP.Peek([][]byte{peekBuf}) + var peekBuf bytes.Buffer + res, err := c.EP.Read(&peekBuf, 10, tcpip.ReadOptions{Peek: true}) if err != nil { t.Fatalf("Peek failed: %s", err) } - peekBuf = peekBuf[:n] - if !bytes.Equal(data, peekBuf) { - t.Fatalf("got data = %v, want = %v", peekBuf, data) + if got, want := res.Count, len(data); got != want { + t.Fatalf("res.Count = %d, want %d", got, want) } - - // Receive data. - v, _, err := c.EP.Read(nil) - if err != nil { - t.Fatalf("Read failed: %s", err) + if !bytes.Equal(data, peekBuf.Bytes()) { + t.Fatalf("got data = %v, want = %v", peekBuf.Bytes(), data) } + // Receive data. + v := ept.CheckRead(t, defaultMTU) if !bytes.Equal(data, v) { t.Fatalf("got data = %v, want = %v", v, data) } // Now that we drained the queue, check that functions fail with the // right error code. - if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive { - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrClosedForReceive) - } - - if _, _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive { - t.Fatalf("got c.EP.Peek(...) = %s, want = %s", err, tcpip.ErrClosedForReceive) + ept.CheckReadError(t, tcpip.ErrClosedForReceive) + var buf bytes.Buffer + if _, err := c.EP.Read(&buf, 1, tcpip.ReadOptions{Peek: true}); err != tcpip.ErrClosedForReceive { + t.Fatalf("c.EP.Read(_, _, {Peek: true}) = %v, %s; want _, %s", res, err, tcpip.ErrClosedForReceive) } } @@ -4193,9 +4263,7 @@ func TestReusePort(t *testing.T) { if err != nil { t.Fatalf("NewEndpoint failed; %s", err) } - if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { - t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) - } + c.EP.SocketOptions().SetReuseAddress(true) if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } @@ -4205,9 +4273,7 @@ func TestReusePort(t *testing.T) { if err != nil { t.Fatalf("NewEndpoint failed; %s", err) } - if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { - t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) - } + c.EP.SocketOptions().SetReuseAddress(true) if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } @@ -4218,9 +4284,7 @@ func TestReusePort(t *testing.T) { if err != nil { t.Fatalf("NewEndpoint failed; %s", err) } - if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { - t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) - } + c.EP.SocketOptions().SetReuseAddress(true) if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } @@ -4233,9 +4297,7 @@ func TestReusePort(t *testing.T) { if err != nil { t.Fatalf("NewEndpoint failed; %s", err) } - if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { - t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) - } + c.EP.SocketOptions().SetReuseAddress(true) if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } @@ -4246,9 +4308,7 @@ func TestReusePort(t *testing.T) { if err != nil { t.Fatalf("NewEndpoint failed; %s", err) } - if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { - t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) - } + c.EP.SocketOptions().SetReuseAddress(true) if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } @@ -4261,9 +4321,7 @@ func TestReusePort(t *testing.T) { if err != nil { t.Fatalf("NewEndpoint failed; %s", err) } - if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { - t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) - } + c.EP.SocketOptions().SetReuseAddress(true) if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } @@ -4443,7 +4501,7 @@ func TestBindToDeviceOption(t *testing.T) { name string setBindToDevice *tcpip.NICID setBindToDeviceError *tcpip.Error - getBindToDevice tcpip.BindToDeviceOption + getBindToDevice int32 }{ {"GetDefaultValue", nil, nil, 0}, {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0}, @@ -4453,15 +4511,13 @@ func TestBindToDeviceOption(t *testing.T) { for _, testAction := range testActions { t.Run(testAction.name, func(t *testing.T) { if testAction.setBindToDevice != nil { - bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) - if gotErr, wantErr := ep.SetSockOpt(&bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { + bindToDevice := int32(*testAction.setBindToDevice) + if gotErr, wantErr := ep.SocketOptions().SetBindToDevice(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { t.Errorf("got SetSockOpt(&%T(%d)) = %s, want = %s", bindToDevice, bindToDevice, gotErr, wantErr) } } - bindToDevice := tcpip.BindToDeviceOption(88888) - if err := ep.GetSockOpt(&bindToDevice); err != nil { - t.Errorf("GetSockOpt(&%T): %s", bindToDevice, err) - } else if bindToDevice != testAction.getBindToDevice { + bindToDevice := ep.SocketOptions().GetBindToDevice() + if bindToDevice != testAction.getBindToDevice { t.Errorf("got bindToDevice = %d, want %d", bindToDevice, testAction.getBindToDevice) } }) @@ -4558,17 +4614,8 @@ func TestSelfConnect(t *testing.T) { // Read back what was written. wq.EventUnregister(&waitEntry) wq.EventRegister(&waitEntry, waiter.EventIn) - rd, _, err := ep.Read(nil) - if err != nil { - if err != tcpip.ErrWouldBlock { - t.Fatalf("Read failed: %s", err) - } - <-notifyCh - rd, _, err = ep.Read(nil) - if err != nil { - t.Fatalf("Read failed: %s", err) - } - } + ept := endpointTester{ep} + rd := ept.CheckReadFull(t, len(data), notifyCh, 5*time.Second) if !bytes.Equal(data, rd) { t.Fatalf("got data = %v, want = %v", rd, data) @@ -4656,13 +4703,9 @@ func TestConnectAvoidsBoundPorts(t *testing.T) { switch network { case "ipv4": case "ipv6": - if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { - t.Fatalf("SetSockOptBool(V6OnlyOption(true)) failed: %s", err) - } + ep.SocketOptions().SetV6Only(true) case "dual": - if err := ep.SetSockOptBool(tcpip.V6OnlyOption, false); err != nil { - t.Fatalf("SetSockOptBool(V6OnlyOption(false)) failed: %s", err) - } + ep.SocketOptions().SetV6Only(false) default: t.Fatalf("unknown network: '%s'", network) } @@ -4998,9 +5041,7 @@ func TestKeepalive(t *testing.T) { if err := c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5); err != nil { t.Fatalf("c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5): %s", err) } - if err := c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true); err != nil { - t.Fatalf("c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true): %s", err) - } + c.EP.SocketOptions().SetKeepAlive(true) // 5 unacked keepalives are sent. ACK each one, and check that the // connection stays alive after 5. @@ -5027,9 +5068,8 @@ func TestKeepalive(t *testing.T) { } // Check that the connection is still alive. - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) - } + ept := endpointTester{c.EP} + ept.CheckReadError(t, tcpip.ErrWouldBlock) // Send some data and wait before ACKing it. Keepalives should be disabled // during this period. @@ -5118,9 +5158,7 @@ func TestKeepalive(t *testing.T) { t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got) } - if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout { - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout) - } + ept.CheckReadError(t, tcpip.ErrTimeout) if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) @@ -5660,16 +5698,14 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { t.Fatalf("Bind failed: %s", err) } - // Test acceptance. // Start listening. listenBacklog := 1 - portOffset := uint16(0) if err := c.EP.Listen(listenBacklog); err != nil { t.Fatalf("Listen failed: %s", err) } - executeHandshake(t, c, context.TestPort+portOffset, false) - portOffset++ + executeHandshake(t, c, context.TestPort, false) + // Wait for this to be delivered to the accept queue. time.Sleep(50 * time.Millisecond) @@ -5717,6 +5753,50 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { } } +func TestSYNRetransmit(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Create TCP endpoint. + var err *tcpip.Error + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + + // Bind to wildcard. + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + // Start listening. + if err := c.EP.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send the same SYN packet multiple times. We should still get a valid SYN-ACK + // reply. + irs := seqnum.Value(789) + for i := 0; i < 5; i++ { + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }) + } + + // Receive the SYN-ACK reply. + tcpCheckers := []checker.TransportChecker{ + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), + checker.TCPAckNum(uint32(irs) + 1), + } + checker.IPv4(t, c.GetPacket(), checker.TCP(tcpCheckers...)) +} + func TestSynRcvdBadSeqNumber(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -5971,9 +6051,8 @@ func TestEndpointBindListenAcceptState(t *testing.T) { t.Errorf("unexpected endpoint state: want %s, got %s", want, got) } - if _, _, err := ep.Read(nil); err != tcpip.ErrNotConnected { - t.Errorf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrNotConnected) - } + ept := endpointTester{ep} + ept.CheckReadError(t, tcpip.ErrNotConnected) if got := ep.Stats().(*tcp.Stats).ReadErrors.NotConnected.Value(); got != 1 { t.Errorf("got EP stats Stats.ReadErrors.NotConnected got %d want %d", got, 1) } @@ -6074,10 +6153,13 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // Introduce a 25ms latency by delaying the first byte. latency := 25 * time.Millisecond time.Sleep(latency) - rawEP.SendPacketWithTS([]byte{1}, tsVal) + // Send an initial payload with atleast segment overhead size. The receive + // window would not grow for smaller segments. + rawEP.SendPacketWithTS(make([]byte, tcp.SegSize), tsVal) pkt := rawEP.VerifyAndReturnACKWithTS(tsVal) rcvWnd := header.TCP(header.IPv4(pkt).Payload()).WindowSize() + time.Sleep(25 * time.Millisecond) // Allocate a large enough payload for the test. @@ -6125,7 +6207,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // Now read all the data from the endpoint and verify that advertised // window increases to the full available buffer size. for { - _, _, err := c.EP.Read(nil) + _, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{}) if err == tcpip.ErrWouldBlock { break } @@ -6249,11 +6331,11 @@ func TestReceiveBufferAutoTuning(t *testing.T) { // to happen before we measure the new window. totalCopied := 0 for { - b, _, err := c.EP.Read(nil) + res, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{}) if err == tcpip.ErrWouldBlock { break } - totalCopied += len(b) + totalCopied += res.Count } // Invoke the moderation API. This is required for auto-tuning @@ -6350,10 +6432,7 @@ func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcpip.T if err != nil { t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %s", err) } - gotDelayOption, err := ep.GetSockOptBool(tcpip.DelayOption) - if err != nil { - t.Fatalf("ep.GetSockOptBool(tcpip.DelayOption) failed: %s", err) - } + gotDelayOption := ep.SocketOptions().GetDelayOption() if gotDelayOption != wantDelayOption { t.Errorf("ep.GetSockOptBool(tcpip.DelayOption) got: %t, want: %t", gotDelayOption, wantDelayOption) } @@ -7173,9 +7252,8 @@ func TestTCPUserTimeout(t *testing.T) { ), ) - if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout { - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout) - } + ept := endpointTester{c.EP} + ept.CheckReadError(t, tcpip.ErrTimeout) if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want) @@ -7206,9 +7284,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { if err := c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10); err != nil { t.Fatalf("c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10): %s", err) } - if err := c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true); err != nil { - t.Fatalf("c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true): %s", err) - } + c.EP.SocketOptions().SetKeepAlive(true) // Set userTimeout to be the duration to be 1 keepalive // probes. Which means that after the first probe is sent @@ -7220,9 +7296,8 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { } // Check that the connection is still alive. - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) - } + ept := endpointTester{c.EP} + ept.CheckReadError(t, tcpip.ErrWouldBlock) // Now receive 1 keepalives, but don't ACK it. b := c.GetPacket() @@ -7261,9 +7336,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { ), ) - if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout { - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout) - } + ept.CheckReadError(t, tcpip.ErrTimeout) if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want) } @@ -7320,11 +7393,11 @@ func TestIncreaseWindowOnRead(t *testing.T) { // defaultMTU is a good enough estimate for the MSS used for this // connection. for read < defaultMTU*2 { - v, _, err := c.EP.Read(nil) + res, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{}) if err != nil { t.Fatalf("Read failed: %s", err) } - read += len(v) + read += res.Count } // After reading > MSS worth of data, we surely crossed MSS. See the ack: diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go index 0f9ed06cd..9e02d467d 100644 --- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go @@ -20,6 +20,7 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" @@ -105,11 +106,18 @@ func TestTimeStampEnabledConnect(t *testing.T) { // There should be 5 views to read and each of them should // contain the same data. for i := 0; i < 5; i++ { - got, _, err := c.EP.Read(nil) + var buf bytes.Buffer + result, err := c.EP.Read(&buf, len(data), tcpip.ReadOptions{}) if err != nil { t.Fatalf("Unexpected error from Read: %v", err) } - if want := data; bytes.Compare(got, want) != 0 { + if diff := cmp.Diff(tcpip.ReadResult{ + Count: buf.Len(), + Total: buf.Len(), + }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" { + t.Errorf("Read: unexpected result (-want +got):\n%s", diff) + } + if got, want := buf.Bytes(), data; bytes.Compare(got, want) != 0 { t.Fatalf("Data is different: got: %v, want: %v", got, want) } } @@ -286,11 +294,18 @@ func TestSegmentNotDroppedWhenTimestampMissing(t *testing.T) { } // Issue a read and we should data. - got, _, err := c.EP.Read(nil) + var buf bytes.Buffer + result, err := c.EP.Read(&buf, defaultMTU, tcpip.ReadOptions{}) if err != nil { t.Fatalf("Unexpected error from Read: %v", err) } - if want := data; bytes.Compare(got, want) != 0 { + if diff := cmp.Diff(tcpip.ReadResult{ + Count: buf.Len(), + Total: buf.Len(), + }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" { + t.Errorf("Read: unexpected result (-want +got):\n%s", diff) + } + if got, want := buf.Bytes(), data; bytes.Compare(got, want) != 0 { t.Fatalf("Data is different: got: %v, want: %v", got, want) } } diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index e6aa4fc4b..ee55f030c 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -592,9 +592,7 @@ func (c *Context) CreateV6Endpoint(v6only bool) { c.t.Fatalf("NewEndpoint failed: %v", err) } - if err := c.EP.SetSockOptBool(tcpip.V6OnlyOption, v6only); err != nil { - c.t.Fatalf("SetSockOpt failed failed: %v", err) - } + c.EP.SocketOptions().SetV6Only(v6only) } // GetV6Packet reads a single packet from the link layer endpoint of the context @@ -637,11 +635,11 @@ func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcp // Initialize the IP header. ip := header.IPv6(buf) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(header.TCPMinimumSize + len(payload)), - NextHeader: uint8(tcp.ProtocolNumber), - HopLimit: 65, - SrcAddr: src, - DstAddr: dst, + PayloadLength: uint16(header.TCPMinimumSize + len(payload)), + TransportProtocol: tcp.ProtocolNumber, + HopLimit: 65, + SrcAddr: src, + DstAddr: dst, }) // Initialize the TCP header. diff --git a/pkg/tcpip/transport/tcp/timer.go b/pkg/tcpip/transport/tcp/timer.go index 7981d469b..38a335840 100644 --- a/pkg/tcpip/transport/tcp/timer.go +++ b/pkg/tcpip/transport/tcp/timer.go @@ -84,6 +84,10 @@ func (t *timer) init(w *sleep.Waker) { // cleanup frees all resources associated with the timer. func (t *timer) cleanup() { + if t.timer == nil { + // No cleanup needed. + return + } t.timer.Stop() *t = timer{} } diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index c78549424..153e8c950 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -56,6 +56,8 @@ go_test( "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", + "//pkg/tcpip/transport/icmp", "//pkg/waiter", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 9bcb918bb..5d87f3a7e 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -16,8 +16,9 @@ package udp import ( "fmt" + "io" + "sync/atomic" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -30,10 +31,11 @@ import ( // +stateify savable type udpPacket struct { udpPacketEntry - senderAddress tcpip.FullAddress - packetInfo tcpip.IPPacketInfo - data buffer.VectorisedView `state:".(buffer.VectorisedView)"` - timestamp int64 + senderAddress tcpip.FullAddress + destinationAddress tcpip.FullAddress + packetInfo tcpip.IPPacketInfo + data buffer.VectorisedView `state:".(buffer.VectorisedView)"` + timestamp int64 // tos stores either the receiveTOS or receiveTClass value. tos uint8 } @@ -77,6 +79,7 @@ func (s EndpointState) String() string { // +stateify savable type endpoint struct { stack.TransportEndpointInfo + tcpip.DefaultSocketOptionsHandler // The following fields are initialized at creation time and do not // change throughout the lifetime of the endpoint. @@ -94,22 +97,19 @@ type endpoint struct { rcvClosed bool // The following fields are protected by the mu mutex. - mu sync.RWMutex `state:"nosave"` - sndBufSize int - sndBufSizeMax int + mu sync.RWMutex `state:"nosave"` + sndBufSize int + sndBufSizeMax int + // state must be read/set using the EndpointState()/setEndpointState() + // methods. state EndpointState - route stack.Route `state:"manual"` + route *stack.Route `state:"manual"` dstPort uint16 - v6only bool ttl uint8 multicastTTL uint8 multicastAddr tcpip.Address multicastNICID tcpip.NICID - multicastLoop bool portFlags ports.Flags - bindToDevice tcpip.NICID - broadcast bool - noChecksum bool lastErrorMu sync.Mutex `state:"nosave"` lastError *tcpip.Error `state:".(string)"` @@ -123,17 +123,6 @@ type endpoint struct { // applied while sending packets. Defaults to 0 as on Linux. sendTOS uint8 - // receiveTOS determines if the incoming IPv4 TOS header field is passed - // as ancillary data to ControlMessages on Read. - receiveTOS bool - - // receiveTClass determines if the incoming IPv6 TClass header field is - // passed as ancillary data to ControlMessages on Read. - receiveTClass bool - - // receiveIPPacketInfo determines if the packet info is returned by Read. - receiveIPPacketInfo bool - // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags @@ -155,8 +144,8 @@ type endpoint struct { // owner is used to get uid and gid of the packet. owner tcpip.PacketOwner - // linger is used for SO_LINGER socket option. - linger tcpip.LingerOption + // ops is used to get socket level options. + ops tcpip.SocketOptions } // +stateify savable @@ -186,13 +175,14 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue // // Linux defaults to TTL=1. multicastTTL: 1, - multicastLoop: true, rcvBufSizeMax: 32 * 1024, sndBufSizeMax: 32 * 1024, multicastMemberships: make(map[multicastMembership]struct{}), state: StateInitial, uniqueID: s.UniqueID(), } + e.ops.InitHandler(e) + e.ops.SetMulticastLoop(true) // Override with stack defaults. var ss stack.SendBufferSizeOption @@ -208,6 +198,20 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue return e } +// setEndpointState updates the state of the endpoint to state atomically. This +// method is unexported as the only place we should update the state is in this +// package but we allow the state to be read freely without holding e.mu. +// +// Precondition: e.mu must be held to call this method. +func (e *endpoint) setEndpointState(state EndpointState) { + atomic.StoreUint32((*uint32)(&e.state), uint32(state)) +} + +// EndpointState() returns the current state of the endpoint. +func (e *endpoint) EndpointState() EndpointState { + return EndpointState(atomic.LoadUint32((*uint32)(&e.state))) +} + // UniqueID implements stack.TransportEndpoint.UniqueID. func (e *endpoint) UniqueID() uint64 { return e.uniqueID @@ -222,6 +226,13 @@ func (e *endpoint) LastError() *tcpip.Error { return err } +// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError. +func (e *endpoint) UpdateLastError(err *tcpip.Error) { + e.lastErrorMu.Lock() + e.lastError = err + e.lastErrorMu.Unlock() +} + // Abort implements stack.TransportEndpoint.Abort. func (e *endpoint) Abort() { e.Close() @@ -233,7 +244,7 @@ func (e *endpoint) Close() { e.mu.Lock() e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite - switch e.state { + switch e.EndpointState() { case StateBound, StateConnected: e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{}) @@ -256,10 +267,13 @@ func (e *endpoint) Close() { } e.rcvMu.Unlock() - e.route.Release() + if e.route != nil { + e.route.Release() + e.route = nil + } // Update the state. - e.state = StateClosed + e.setEndpointState(StateClosed) e.mu.Unlock() @@ -269,11 +283,10 @@ func (e *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (e *endpoint) ModerateRecvBuf(copied int) {} -// Read reads data from the endpoint. This method does not block if -// there is no data pending. -func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { +// Read implements tcpip.Endpoint.Read. +func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { if err := e.LastError(); err != nil { - return buffer.View{}, tcpip.ControlMessages{}, err + return tcpip.ReadResult{}, err } e.rcvMu.Lock() @@ -285,41 +298,54 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess err = tcpip.ErrClosedForReceive } e.rcvMu.Unlock() - return buffer.View{}, tcpip.ControlMessages{}, err + return tcpip.ReadResult{}, err } p := e.rcvList.Front() - e.rcvList.Remove(p) - e.rcvBufSize -= p.data.Size() - e.rcvMu.Unlock() - - if addr != nil { - *addr = p.senderAddress + if !opts.Peek { + e.rcvList.Remove(p) + e.rcvBufSize -= p.data.Size() } + e.rcvMu.Unlock() + // Control Messages cm := tcpip.ControlMessages{ HasTimestamp: true, Timestamp: p.timestamp, } - e.mu.RLock() - receiveTOS := e.receiveTOS - receiveTClass := e.receiveTClass - receiveIPPacketInfo := e.receiveIPPacketInfo - e.mu.RUnlock() - if receiveTOS { + if e.ops.GetReceiveTOS() { cm.HasTOS = true cm.TOS = p.tos } - if receiveTClass { + if e.ops.GetReceiveTClass() { cm.HasTClass = true // Although TClass is an 8-bit value it's read in the CMsg as a uint32. cm.TClass = uint32(p.tos) } - if receiveIPPacketInfo { + if e.ops.GetReceivePacketInfo() { cm.HasIPPacketInfo = true cm.PacketInfo = p.packetInfo } - return p.data.ToView(), cm, nil + if e.ops.GetReceiveOriginalDstAddress() { + cm.HasOriginalDstAddress = true + cm.OriginalDstAddress = p.destinationAddress + } + + // Read Result + res := tcpip.ReadResult{ + Total: p.data.Size(), + ControlMessages: cm, + } + if opts.NeedRemoteAddr { + res.RemoteAddr = p.senderAddress + } + + n, err := p.data.ReadTo(dst, count, opts.Peek) + if n == 0 && err != nil { + return res, tcpip.ErrBadBuffer + } + res.Count = n + return res, nil } // prepareForWrite prepares the endpoint for sending data. In particular, it @@ -328,7 +354,7 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess // // Returns true for retry if preparation should be retried. func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) { - switch e.state { + switch e.EndpointState() { case StateInitial: case StateConnected: return false, nil @@ -350,7 +376,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // The state changed when we released the shared locked and re-acquired // it in exclusive mode. Try again. - if e.state != StateInitial { + if e.EndpointState() != StateInitial { return true, nil } @@ -365,9 +391,9 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // connectRoute establishes a route to the specified interface or the // configured multicast interface if no interface is specified and the // specified address is a multicast address. -func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) { +func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, *tcpip.Error) { localAddr := e.ID.LocalAddress - if isBroadcastOrMulticast(localAddr) { + if e.isBroadcastOrMulticast(nicID, netProto, localAddr) { // A packet can only originate from a unicast address (i.e., an interface). localAddr = "" } @@ -382,9 +408,9 @@ func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netPr } // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.multicastLoop) + r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.ops.GetMulticastLoop()) if err != nil { - return stack.Route{}, 0, err + return nil, 0, err } return r, nicID, nil } @@ -427,7 +453,13 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c to := opts.To e.mu.RLock() - defer e.mu.RUnlock() + lockReleased := false + defer func() { + if lockReleased { + return + } + e.mu.RUnlock() + }() // If we've shutdown with SHUT_WR we are in an invalid state for sending. if e.shutdownFlags&tcpip.ShutdownWrite != 0 { @@ -446,36 +478,9 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } } - var route *stack.Route - var resolve func(waker *sleep.Waker) (ch <-chan struct{}, err *tcpip.Error) - var dstPort uint16 - if to == nil { - route = &e.route - dstPort = e.dstPort - resolve = func(waker *sleep.Waker) (ch <-chan struct{}, err *tcpip.Error) { - // Promote lock to exclusive if using a shared route, given that it may - // need to change in Route.Resolve() call below. - e.mu.RUnlock() - e.mu.Lock() - - // Recheck state after lock was re-acquired. - if e.state != StateConnected { - err = tcpip.ErrInvalidEndpointState - } - if err == nil && route.IsResolutionRequired() { - ch, err = route.Resolve(waker) - } - - e.mu.Unlock() - e.mu.RLock() - - // Recheck state after lock was re-acquired. - if e.state != StateConnected { - err = tcpip.ErrInvalidEndpointState - } - return - } - } else { + route := e.route + dstPort := e.dstPort + if to != nil { // Reject destination address if it goes through a different // NIC than the endpoint was bound to. nicID := to.NIC @@ -503,17 +508,16 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } defer r.Release() - route = &r + route = r dstPort = dst.Port - resolve = route.Resolve } - if !e.broadcast && route.IsOutboundBroadcast() { + if !e.ops.GetBroadcast() && route.IsOutboundBroadcast() { return 0, nil, tcpip.ErrBroadcastDisabled } if route.IsResolutionRequired() { - if ch, err := resolve(nil); err != nil { + if ch, err := route.Resolve(nil); err != nil { if err == tcpip.ErrWouldBlock { return 0, ch, tcpip.ErrNoLinkAddress } @@ -527,6 +531,20 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } if len(v) > header.UDPMaximumPacketSize { // Payload can't possibly fit in a packet. + so := e.SocketOptions() + if so.GetRecvError() { + so.QueueLocalErr( + tcpip.ErrMessageTooLong, + route.NetProto, + header.UDPMaximumPacketSize, + tcpip.FullAddress{ + NIC: route.NICID(), + Addr: route.RemoteAddress, + Port: dstPort, + }, + v, + ) + } return 0, nil, tcpip.ErrMessageTooLong } @@ -539,83 +557,41 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c useDefaultTTL = false } - if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS, e.owner, e.noChecksum); err != nil { + localPort := e.ID.LocalPort + sendTOS := e.sendTOS + owner := e.owner + noChecksum := e.SocketOptions().GetNoChecksum() + lockReleased = true + e.mu.RUnlock() + + // Do not hold lock when sending as loopback is synchronous and if the UDP + // datagram ends up generating an ICMP response then it can result in a + // deadlock where the ICMP response handling ends up acquiring this endpoint's + // mutex using e.mu.RLock() in endpoint.HandleControlPacket which can cause a + // deadlock if another caller is trying to acquire e.mu in exclusive mode w/ + // e.mu.Lock(). Since e.mu.Lock() prevents any new read locks to ensure the + // lock can be eventually acquired. + // + // See: https://golang.org/pkg/sync/#RWMutex for details on why recursive read + // locking is prohibited. + if err := sendUDP(route, buffer.View(v).ToVectorisedView(), localPort, dstPort, ttl, useDefaultTTL, sendTOS, owner, noChecksum); err != nil { return 0, nil, err } return int64(len(v)), nil, nil } -// Peek only returns data from a single datagram, so do nothing here. -func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { - return 0, tcpip.ControlMessages{}, nil +// OnReuseAddressSet implements tcpip.SocketOptionsHandler.OnReuseAddressSet. +func (e *endpoint) OnReuseAddressSet(v bool) { + e.mu.Lock() + e.portFlags.MostRecent = v + e.mu.Unlock() } -// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. -func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { - switch opt { - case tcpip.BroadcastOption: - e.mu.Lock() - e.broadcast = v - e.mu.Unlock() - - case tcpip.MulticastLoopOption: - e.mu.Lock() - e.multicastLoop = v - e.mu.Unlock() - - case tcpip.NoChecksumOption: - e.mu.Lock() - e.noChecksum = v - e.mu.Unlock() - - case tcpip.ReceiveTOSOption: - e.mu.Lock() - e.receiveTOS = v - e.mu.Unlock() - - case tcpip.ReceiveTClassOption: - // We only support this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return tcpip.ErrNotSupported - } - - e.mu.Lock() - e.receiveTClass = v - e.mu.Unlock() - - case tcpip.ReceiveIPPacketInfoOption: - e.mu.Lock() - e.receiveIPPacketInfo = v - e.mu.Unlock() - - case tcpip.ReuseAddressOption: - e.mu.Lock() - e.portFlags.MostRecent = v - e.mu.Unlock() - - case tcpip.ReusePortOption: - e.mu.Lock() - e.portFlags.LoadBalanced = v - e.mu.Unlock() - - case tcpip.V6OnlyOption: - // We only recognize this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return tcpip.ErrInvalidEndpointState - } - - e.mu.Lock() - defer e.mu.Unlock() - - // We only allow this to be set when we're in the initial state. - if e.state != StateInitial { - return tcpip.ErrInvalidEndpointState - } - - e.v6only = v - } - - return nil +// OnReusePortSet implements tcpip.SocketOptionsHandler.OnReusePortSet. +func (e *endpoint) OnReusePortSet(v bool) { + e.mu.Lock() + e.portFlags.LoadBalanced = v + e.mu.Unlock() } // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. @@ -691,6 +667,10 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil } +func (e *endpoint) HasNIC(id int32) bool { + return id == 0 || e.stack.HasNIC(tcpip.NICID(id)) +} + // SetSockOpt implements tcpip.Endpoint.SetSockOpt. func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { switch v := opt.(type) { @@ -737,14 +717,9 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { nicID := v.NIC - // The interface address is considered not-set if it is empty or contains - // all-zeros. The former represent the zero-value in golang, the latter the - // same in a setsockopt(IP_ADD_MEMBERSHIP, &ip_mreqn) syscall. - allZeros := header.IPv4Any - if len(v.InterfaceAddr) == 0 || v.InterfaceAddr == allZeros { + if v.InterfaceAddr.Unspecified() { if nicID == 0 { - r, err := e.stack.FindRoute(0, "", v.MulticastAddr, header.IPv4ProtocolNumber, false /* multicastLoop */) - if err == nil { + if r, err := e.stack.FindRoute(0, "", v.MulticastAddr, e.NetProto, false /* multicastLoop */); err == nil { nicID = r.NICID() r.Release() } @@ -777,10 +752,9 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { } nicID := v.NIC - if v.InterfaceAddr == header.IPv4Any { + if v.InterfaceAddr.Unspecified() { if nicID == 0 { - r, err := e.stack.FindRoute(0, "", v.MulticastAddr, header.IPv4ProtocolNumber, false /* multicastLoop */) - if err == nil { + if r, err := e.stack.FindRoute(0, "", v.MulticastAddr, e.NetProto, false /* multicastLoop */); err == nil { nicID = r.NICID() r.Release() } @@ -807,107 +781,12 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { delete(e.multicastMemberships, memToRemove) - case *tcpip.BindToDeviceOption: - id := tcpip.NICID(*v) - if id != 0 && !e.stack.HasNIC(id) { - return tcpip.ErrUnknownDevice - } - e.mu.Lock() - e.bindToDevice = id - e.mu.Unlock() - case *tcpip.SocketDetachFilterOption: return nil - - case *tcpip.LingerOption: - e.mu.Lock() - e.linger = *v - e.mu.Unlock() } return nil } -// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. -func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - switch opt { - case tcpip.BroadcastOption: - e.mu.RLock() - v := e.broadcast - e.mu.RUnlock() - return v, nil - - case tcpip.KeepaliveEnabledOption: - return false, nil - - case tcpip.MulticastLoopOption: - e.mu.RLock() - v := e.multicastLoop - e.mu.RUnlock() - return v, nil - - case tcpip.NoChecksumOption: - e.mu.RLock() - v := e.noChecksum - e.mu.RUnlock() - return v, nil - - case tcpip.ReceiveTOSOption: - e.mu.RLock() - v := e.receiveTOS - e.mu.RUnlock() - return v, nil - - case tcpip.ReceiveTClassOption: - // We only support this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return false, tcpip.ErrNotSupported - } - - e.mu.RLock() - v := e.receiveTClass - e.mu.RUnlock() - return v, nil - - case tcpip.ReceiveIPPacketInfoOption: - e.mu.RLock() - v := e.receiveIPPacketInfo - e.mu.RUnlock() - return v, nil - - case tcpip.ReuseAddressOption: - e.mu.RLock() - v := e.portFlags.MostRecent - e.mu.RUnlock() - - return v, nil - - case tcpip.ReusePortOption: - e.mu.RLock() - v := e.portFlags.LoadBalanced - e.mu.RUnlock() - - return v, nil - - case tcpip.V6OnlyOption: - // We only recognize this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return false, tcpip.ErrUnknownProtocolOption - } - - e.mu.RLock() - v := e.v6only - e.mu.RUnlock() - - return v, nil - - case tcpip.AcceptConnOption: - return false, nil - - default: - return false, tcpip.ErrUnknownProtocolOption - } -} - // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { @@ -977,16 +856,6 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { } e.mu.Unlock() - case *tcpip.BindToDeviceOption: - e.mu.RLock() - *o = tcpip.BindToDeviceOption(e.bindToDevice) - e.mu.RUnlock() - - case *tcpip.LingerOption: - e.mu.RLock() - *o = e.linger - e.mu.RUnlock() - default: return tcpip.ErrUnknownProtocolOption } @@ -1046,7 +915,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u // checkV4MappedLocked determines the effective network protocol and converts // addr to its canonical form. func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { - unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only) + unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.ops.GetV6Only()) if err != nil { return tcpip.FullAddress{}, 0, err } @@ -1058,7 +927,7 @@ func (e *endpoint) Disconnect() *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - if e.state != StateConnected { + if e.EndpointState() != StateConnected { return nil } var ( @@ -1081,7 +950,7 @@ func (e *endpoint) Disconnect() *tcpip.Error { if err != nil { return err } - e.state = StateBound + e.setEndpointState(StateBound) boundPortFlags = e.boundPortFlags } else { if e.ID.LocalPort != 0 { @@ -1089,14 +958,14 @@ func (e *endpoint) Disconnect() *tcpip.Error { e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{}) e.boundPortFlags = ports.Flags{} } - e.state = StateInitial + e.setEndpointState(StateInitial) } e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, boundPortFlags, e.boundBindToDevice) e.ID = id e.boundBindToDevice = btd e.route.Release() - e.route = stack.Route{} + e.route = nil e.dstPort = 0 return nil @@ -1114,7 +983,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { nicID := addr.NIC var localPort uint16 - switch e.state { + switch e.EndpointState() { case StateInitial: case StateBound, StateConnected: localPort = e.ID.LocalPort @@ -1140,7 +1009,6 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { if err != nil { return err } - defer r.Release() id := stack.TransportEndpointID{ LocalAddress: e.ID.LocalAddress, @@ -1149,7 +1017,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { RemoteAddress: r.RemoteAddress, } - if e.state == StateInitial { + if e.EndpointState() == StateInitial { id.LocalAddress = r.LocalAddress } @@ -1157,7 +1025,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // packets on a different network protocol, so we register both even if // v6only is set to false and this is an ipv6 endpoint. netProtos := []tcpip.NetworkProtocolNumber{netProto} - if netProto == header.IPv6ProtocolNumber && !e.v6only { + if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() { netProtos = []tcpip.NetworkProtocolNumber{ header.IPv4ProtocolNumber, header.IPv6ProtocolNumber, @@ -1168,6 +1036,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { id, btd, err := e.registerWithStack(nicID, netProtos, id) if err != nil { + r.Release() return err } @@ -1178,12 +1047,12 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { e.ID = id e.boundBindToDevice = btd - e.route = r.Clone() + e.route = r e.dstPort = addr.Port e.RegisterNICID = nicID e.effectiveNetProtos = netProtos - e.state = StateConnected + e.setEndpointState(StateConnected) e.rcvMu.Lock() e.rcvReady = true @@ -1205,7 +1074,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { // A socket in the bound state can still receive multicast messages, // so we need to notify waiters on shutdown. - if e.state != StateBound && e.state != StateConnected { + if state := e.EndpointState(); state != StateBound && state != StateConnected { return tcpip.ErrNotConnected } @@ -1236,27 +1105,28 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcp } func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) { + bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) if e.ID.LocalPort == 0 { - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, e.bindToDevice, tcpip.FullAddress{}, nil /* testPort */) + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, bindToDevice, tcpip.FullAddress{}, nil /* testPort */) if err != nil { - return id, e.bindToDevice, err + return id, bindToDevice, err } id.LocalPort = port } e.boundPortFlags = e.portFlags - err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, e.bindToDevice) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, bindToDevice) if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, e.bindToDevice, tcpip.FullAddress{}) + e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, bindToDevice, tcpip.FullAddress{}) e.boundPortFlags = ports.Flags{} } - return id, e.bindToDevice, err + return id, bindToDevice, err } func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { // Don't allow binding once endpoint is not in the initial state // anymore. - if e.state != StateInitial { + if e.EndpointState() != StateInitial { return tcpip.ErrInvalidEndpointState } @@ -1269,7 +1139,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { // wildcard (empty) address, and this is an IPv6 endpoint with v6only // set to false. netProtos := []tcpip.NetworkProtocolNumber{netProto} - if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" { + if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() && addr.Addr == "" { netProtos = []tcpip.NetworkProtocolNumber{ header.IPv6ProtocolNumber, header.IPv4ProtocolNumber, @@ -1277,7 +1147,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { } nicID := addr.NIC - if len(addr.Addr) != 0 && !isBroadcastOrMulticast(addr.Addr) { + if len(addr.Addr) != 0 && !e.isBroadcastOrMulticast(addr.NIC, netProto, addr.Addr) { // A local unicast address was specified, verify that it's valid. nicID = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) if nicID == 0 { @@ -1300,7 +1170,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { e.effectiveNetProtos = netProtos // Mark endpoint as bound. - e.state = StateBound + e.setEndpointState(StateBound) e.rcvMu.Lock() e.rcvReady = true @@ -1332,7 +1202,7 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { defer e.mu.RUnlock() addr := e.ID.LocalAddress - if e.state == StateConnected { + if e.EndpointState() == StateConnected { addr = e.route.LocalAddress } @@ -1348,7 +1218,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - if e.state != StateConnected { + if e.EndpointState() != StateConnected { return tcpip.FullAddress{}, tcpip.ErrNotConnected } @@ -1447,6 +1317,11 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB Addr: id.RemoteAddress, Port: header.UDP(hdr).SourcePort(), }, + destinationAddress: tcpip.FullAddress{ + NIC: pkt.NICID, + Addr: id.LocalAddress, + Port: header.UDP(hdr).DestinationPort(), + }, } packet.data = pkt.Data e.rcvList.PushBack(packet) @@ -1477,28 +1352,71 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB } } +func (e *endpoint) onICMPError(err *tcpip.Error, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) { + // Update last error first. + e.lastErrorMu.Lock() + e.lastError = err + e.lastErrorMu.Unlock() + + // Update the error queue if IP_RECVERR is enabled. + if e.SocketOptions().GetRecvError() { + // Linux passes the payload without the UDP header. + var payload []byte + udp := header.UDP(pkt.Data.ToView()) + if len(udp) >= header.UDPMinimumSize { + payload = udp.Payload() + } + + e.SocketOptions().QueueErr(&tcpip.SockError{ + Err: err, + ErrOrigin: header.ICMPOriginFromNetProto(pkt.NetworkProtocolNumber), + ErrType: errType, + ErrCode: errCode, + ErrInfo: extra, + Payload: payload, + Dst: tcpip.FullAddress{ + NIC: pkt.NICID, + Addr: e.ID.RemoteAddress, + Port: e.ID.RemotePort, + }, + Offender: tcpip.FullAddress{ + NIC: pkt.NICID, + Addr: e.ID.LocalAddress, + Port: e.ID.LocalPort, + }, + NetProto: pkt.NetworkProtocolNumber, + }) + } + + // Notify of the error. + e.waiterQueue.Notify(waiter.EventErr) +} + // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { +func (e *endpoint) HandleControlPacket(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { if typ == stack.ControlPortUnreachable { - e.mu.RLock() - if e.state == StateConnected { - e.lastErrorMu.Lock() - e.lastError = tcpip.ErrConnectionRefused - e.lastErrorMu.Unlock() - e.mu.RUnlock() - - e.waiterQueue.Notify(waiter.EventErr) + if e.EndpointState() == StateConnected { + var errType byte + var errCode byte + switch pkt.NetworkProtocolNumber { + case header.IPv4ProtocolNumber: + errType = byte(header.ICMPv4DstUnreachable) + errCode = byte(header.ICMPv4PortUnreachable) + case header.IPv6ProtocolNumber: + errType = byte(header.ICMPv6DstUnreachable) + errCode = byte(header.ICMPv6PortUnreachable) + default: + panic(fmt.Sprintf("unsupported net proto for infering ICMP type and code: %d", pkt.NetworkProtocolNumber)) + } + e.onICMPError(tcpip.ErrConnectionRefused, errType, errCode, extra, pkt) return } - e.mu.RUnlock() } } // State implements tcpip.Endpoint.State. func (e *endpoint) State() uint32 { - e.mu.Lock() - defer e.mu.Unlock() - return uint32(e.state) + return uint32(e.EndpointState()) } // Info returns a copy of the endpoint info. @@ -1518,10 +1436,16 @@ func (e *endpoint) Stats() tcpip.EndpointStats { // Wait implements tcpip.Endpoint.Wait. func (*endpoint) Wait() {} -func isBroadcastOrMulticast(a tcpip.Address) bool { - return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a) +func (e *endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, addr tcpip.Address) bool { + return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) || e.stack.IsSubnetBroadcast(nicID, netProto, addr) } +// SetOwner implements tcpip.Endpoint.SetOwner. func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } + +// SocketOptions implements tcpip.Endpoint.SocketOptions. +func (e *endpoint) SocketOptions() *tcpip.SocketOptions { + return &e.ops +} diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 858c99a45..13b72dc88 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -98,7 +98,8 @@ func (e *endpoint) Resume(s *stack.Stack) { } } - if e.state != StateBound && e.state != StateConnected { + state := e.EndpointState() + if state != StateBound && state != StateConnected { return } @@ -113,12 +114,12 @@ func (e *endpoint) Resume(s *stack.Stack) { } var err *tcpip.Error - if e.state == StateConnected { - e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.multicastLoop) + if state == StateConnected { + e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.ops.GetMulticastLoop()) if err != nil { panic(err) } - } else if len(e.ID.LocalAddress) != 0 && !isBroadcastOrMulticast(e.ID.LocalAddress) { // stateBound + } else if len(e.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(e.RegisterNICID, netProto, e.ID.LocalAddress) { // stateBound // A local unicast address is specified, verify that it's valid. if e.stack.CheckLocalAddress(e.RegisterNICID, netProto, e.ID.LocalAddress) == 0 { panic(tcpip.ErrBadLocalAddress) diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index 14e4648cd..d7fc21f11 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -78,7 +78,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, route.ResolveWith(r.pkt.SourceLinkAddress()) ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue) - if err := r.stack.RegisterTransportEndpoint(r.pkt.NICID, []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, ep.bindToDevice); err != nil { + if err := r.stack.RegisterTransportEndpoint(r.pkt.NICID, []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, tcpip.NICID(ep.ops.GetBindToDevice())); err != nil { ep.Close() route.Release() return nil, err diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index c09c7aa86..c8da173f1 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -18,10 +18,12 @@ import ( "bytes" "context" "fmt" + "io/ioutil" "math/rand" "testing" "time" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" @@ -32,6 +34,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" ) @@ -54,6 +57,7 @@ const ( stackPort = 1234 testAddr = "\x0a\x00\x00\x02" testPort = 4096 + invalidPort = 8192 multicastAddr = "\xe8\x2b\xd3\xea" multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" broadcastAddr = header.IPv4Broadcast @@ -295,7 +299,8 @@ func newDualTestContext(t *testing.T, mtu uint32) *testContext { t.Helper() return newDualTestContextWithOptions(t, mtu, stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}, + HandleLocal: true, }) } @@ -360,13 +365,9 @@ func (c *testContext) createEndpointForFlow(flow testFlow) { c.createEndpoint(flow.sockProto()) if flow.isV6Only() { - if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { - c.t.Fatalf("SetSockOptBool failed: %s", err) - } + c.ep.SocketOptions().SetV6Only(true) } else if flow.isBroadcast() { - if err := c.ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil { - c.t.Fatalf("SetSockOptBool failed: %s", err) - } + c.ep.SocketOptions().SetBroadcast(true) } } @@ -453,12 +454,12 @@ func (c *testContext) buildV6Packet(payload []byte, h *header4Tuple) buffer.View // Initialize the IP header. ip := header.IPv6(buf) ip.Encode(&header.IPv6Fields{ - TrafficClass: testTOS, - PayloadLength: uint16(header.UDPMinimumSize + len(payload)), - NextHeader: uint8(udp.ProtocolNumber), - HopLimit: 65, - SrcAddr: h.srcAddr.Addr, - DstAddr: h.dstAddr.Addr, + TrafficClass: testTOS, + PayloadLength: uint16(header.UDPMinimumSize + len(payload)), + TransportProtocol: udp.ProtocolNumber, + HopLimit: 65, + SrcAddr: h.srcAddr.Addr, + DstAddr: h.dstAddr.Addr, }) // Initialize the UDP header. @@ -555,7 +556,7 @@ func TestBindToDeviceOption(t *testing.T) { name string setBindToDevice *tcpip.NICID setBindToDeviceError *tcpip.Error - getBindToDevice tcpip.BindToDeviceOption + getBindToDevice int32 }{ {"GetDefaultValue", nil, nil, 0}, {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0}, @@ -565,15 +566,13 @@ func TestBindToDeviceOption(t *testing.T) { for _, testAction := range testActions { t.Run(testAction.name, func(t *testing.T) { if testAction.setBindToDevice != nil { - bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) - if gotErr, wantErr := ep.SetSockOpt(&bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { + bindToDevice := int32(*testAction.setBindToDevice) + if gotErr, wantErr := ep.SocketOptions().SetBindToDevice(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { t.Errorf("got SetSockOpt(&%T(%d)) = %s, want = %s", bindToDevice, bindToDevice, gotErr, wantErr) } } - bindToDevice := tcpip.BindToDeviceOption(88888) - if err := ep.GetSockOpt(&bindToDevice); err != nil { - t.Errorf("GetSockOpt(&%T): %s", bindToDevice, err) - } else if bindToDevice != testAction.getBindToDevice { + bindToDevice := ep.SocketOptions().GetBindToDevice() + if bindToDevice != testAction.getBindToDevice { t.Errorf("got bindToDevice = %d, want = %d", bindToDevice, testAction.getBindToDevice) } }) @@ -598,13 +597,13 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe // Take a snapshot of the stats to validate them at the end of the test. epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() - var addr tcpip.FullAddress - v, cm, err := c.ep.Read(&addr) + var buf bytes.Buffer + res, err := c.ep.Read(&buf, defaultMTU, tcpip.ReadOptions{NeedRemoteAddr: true}) if err == tcpip.ErrWouldBlock { // Wait for data to become available. select { case <-ch: - v, cm, err = c.ep.Read(&addr) + res, err = c.ep.Read(&buf, defaultMTU, tcpip.ReadOptions{NeedRemoteAddr: true}) case <-time.After(300 * time.Millisecond): if packetShouldBeDropped { @@ -624,23 +623,32 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe } if packetShouldBeDropped { - c.t.Fatalf("Read unexpectedly received data from %s", addr.Addr) + c.t.Fatalf("Read unexpectedly received data from %s", res.RemoteAddr.Addr) } - // Check the peer address. + // Check the read result. h := flow.header4Tuple(incoming) - if addr.Addr != h.srcAddr.Addr { - c.t.Fatalf("got address = %s, want = %s", addr.Addr, h.srcAddr.Addr) + if diff := cmp.Diff(tcpip.ReadResult{ + Count: buf.Len(), + Total: buf.Len(), + RemoteAddr: tcpip.FullAddress{Addr: h.srcAddr.Addr}, + }, res, checker.IgnoreCmpPath( + "ControlMessages", // ControlMessages will be checked later. + "RemoteAddr.NIC", + "RemoteAddr.Port", + )); diff != "" { + c.t.Fatalf("Read: unexpected result (-want +got):\n%s", diff) } // Check the payload. + v := buf.Bytes() if !bytes.Equal(payload, v) { c.t.Fatalf("got payload = %x, want = %x", v, payload) } // Run any checkers against the ControlMessages. for _, f := range checkers { - f(c.t, cm) + f(c.t, res.ControlMessages) } c.checkEndpointReadStats(1, epstats, err) @@ -831,8 +839,8 @@ func TestV4ReadSelfSource(t *testing.T) { t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource) } - if _, _, err := c.ep.Read(nil); err != tt.wantErr { - t.Errorf("got c.ep.Read(nil) = %s, want = %s", err, tt.wantErr) + if _, err := c.ep.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{}); err != tt.wantErr { + t.Errorf("got c.ep.Read = %s, want = %s", err, tt.wantErr) } }) } @@ -974,7 +982,7 @@ func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) { // provided. func testWrite(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 { c.t.Helper() - return testWriteInternal(c, flow, true, checkers...) + return testWriteAndVerifyInternal(c, flow, true, checkers...) } // testWriteWithoutDestination sends a packet of the given test flow from the @@ -983,10 +991,10 @@ func testWrite(c *testContext, flow testFlow, checkers ...checker.NetworkChecker // checker functions provided. func testWriteWithoutDestination(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 { c.t.Helper() - return testWriteInternal(c, flow, false, checkers...) + return testWriteAndVerifyInternal(c, flow, false, checkers...) } -func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 { +func testWriteNoVerify(c *testContext, flow testFlow, setDest bool) buffer.View { c.t.Helper() // Take a snapshot of the stats to validate them at the end of the test. epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() @@ -1008,6 +1016,12 @@ func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ... c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) } c.checkEndpointWriteStats(1, epstats, err) + return payload +} + +func testWriteAndVerifyInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 { + c.t.Helper() + payload := testWriteNoVerify(c, flow, setDest) // Received the packet and check the payload. b := c.getPacketAndVerify(flow, checkers...) var udp header.UDP @@ -1152,6 +1166,39 @@ func TestV4WriteOnConnected(t *testing.T) { testWriteWithoutDestination(c, unicastV4) } +func TestWriteOnConnectedInvalidPort(t *testing.T) { + protocols := map[string]tcpip.NetworkProtocolNumber{ + "ipv4": ipv4.ProtocolNumber, + "ipv6": ipv6.ProtocolNumber, + } + for name, pn := range protocols { + t.Run(name, func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(pn) + if err := c.ep.Connect(tcpip.FullAddress{Addr: stackAddr, Port: invalidPort}); err != nil { + c.t.Fatalf("Connect failed: %s", err) + } + writeOpts := tcpip.WriteOptions{ + To: &tcpip.FullAddress{Addr: stackAddr, Port: invalidPort}, + } + payload := buffer.View(newPayload()) + n, _, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts) + if err != nil { + c.t.Fatalf("c.ep.Write(...) = %+s, want nil", err) + } + if got, want := n, int64(len(payload)); got != want { + c.t.Fatalf("c.ep.Write(...) wrote %d bytes, want %d bytes", got, want) + } + + if err := c.ep.LastError(); err != tcpip.ErrConnectionRefused { + c.t.Fatalf("expected c.ep.LastError() == ErrConnectionRefused, got: %+v", err) + } + }) + } +} + // TestWriteOnBoundToV4Multicast checks that we can send packets out of a socket // that is bound to a V4 multicast address. func TestWriteOnBoundToV4Multicast(t *testing.T) { @@ -1374,9 +1421,7 @@ func TestReadIPPacketInfo(t *testing.T) { } } - if err := c.ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, true); err != nil { - t.Fatalf("c.ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, true): %s", err) - } + c.ep.SocketOptions().SetReceivePacketInfo(true) testRead(c, test.flow, checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{ NIC: 1, @@ -1391,6 +1436,93 @@ func TestReadIPPacketInfo(t *testing.T) { } } +func TestReadRecvOriginalDstAddr(t *testing.T) { + tests := []struct { + name string + proto tcpip.NetworkProtocolNumber + flow testFlow + expectedOriginalDstAddr tcpip.FullAddress + }{ + { + name: "IPv4 unicast", + proto: header.IPv4ProtocolNumber, + flow: unicastV4, + expectedOriginalDstAddr: tcpip.FullAddress{1, stackAddr, stackPort}, + }, + { + name: "IPv4 multicast", + proto: header.IPv4ProtocolNumber, + flow: multicastV4, + // This should actually be a unicast address assigned to the interface. + // + // TODO(gvisor.dev/issue/3556): This check is validating incorrect + // behaviour. We still include the test so that once the bug is + // resolved, this test will start to fail and the individual tasked + // with fixing this bug knows to also fix this test :). + expectedOriginalDstAddr: tcpip.FullAddress{1, multicastAddr, stackPort}, + }, + { + name: "IPv4 broadcast", + proto: header.IPv4ProtocolNumber, + flow: broadcast, + // This should actually be a unicast address assigned to the interface. + // + // TODO(gvisor.dev/issue/3556): This check is validating incorrect + // behaviour. We still include the test so that once the bug is + // resolved, this test will start to fail and the individual tasked + // with fixing this bug knows to also fix this test :). + expectedOriginalDstAddr: tcpip.FullAddress{1, broadcastAddr, stackPort}, + }, + { + name: "IPv6 unicast", + proto: header.IPv6ProtocolNumber, + flow: unicastV6, + expectedOriginalDstAddr: tcpip.FullAddress{1, stackV6Addr, stackPort}, + }, + { + name: "IPv6 multicast", + proto: header.IPv6ProtocolNumber, + flow: multicastV6, + // This should actually be a unicast address assigned to the interface. + // + // TODO(gvisor.dev/issue/3556): This check is validating incorrect + // behaviour. We still include the test so that once the bug is + // resolved, this test will start to fail and the individual tasked + // with fixing this bug knows to also fix this test :). + expectedOriginalDstAddr: tcpip.FullAddress{1, multicastV6Addr, stackPort}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(test.proto) + + bindAddr := tcpip.FullAddress{Port: stackPort} + if err := c.ep.Bind(bindAddr); err != nil { + t.Fatalf("Bind(%#v): %s", bindAddr, err) + } + + if test.flow.isMulticast() { + ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: test.flow.getMcastAddr()} + if err := c.ep.SetSockOpt(&ifoptSet); err != nil { + c.t.Fatalf("SetSockOpt(&%#v): %s:", ifoptSet, err) + } + } + + c.ep.SocketOptions().SetReceiveOriginalDstAddress(true) + + testRead(c, test.flow, checker.ReceiveOriginalDstAddr(test.expectedOriginalDstAddr)) + + if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 { + t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got) + } + }) + } +} + func TestWriteIncrementsPacketsSent(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() @@ -1414,16 +1546,12 @@ func TestNoChecksum(t *testing.T) { c.createEndpointForFlow(flow) // Disable the checksum generation. - if err := c.ep.SetSockOptBool(tcpip.NoChecksumOption, true); err != nil { - t.Fatalf("SetSockOptBool failed: %s", err) - } + c.ep.SocketOptions().SetNoChecksum(true) // This option is effective on IPv4 only. testWrite(c, flow, checker.UDP(checker.NoChecksum(flow.isV4()))) // Enable the checksum generation. - if err := c.ep.SetSockOptBool(tcpip.NoChecksumOption, false); err != nil { - t.Fatalf("SetSockOptBool failed: %s", err) - } + c.ep.SocketOptions().SetNoChecksum(false) testWrite(c, flow, checker.UDP(checker.NoChecksum(false))) }) } @@ -1432,29 +1560,17 @@ func TestNoChecksum(t *testing.T) { var _ stack.NetworkInterface = (*testInterface)(nil) type testInterface struct { - stack.NetworkLinkEndpoint + stack.NetworkInterface } func (*testInterface) ID() tcpip.NICID { return 0 } -func (*testInterface) IsLoopback() bool { - return false -} - -func (*testInterface) Name() string { - return "" -} - func (*testInterface) Enabled() bool { return true } -func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error { - return tcpip.ErrNotSupported -} - func TestTTL(t *testing.T) { for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { @@ -1589,13 +1705,15 @@ func TestSetTClass(t *testing.T) { } func TestReceiveTosTClass(t *testing.T) { + const RcvTOSOpt = "ReceiveTosOption" + const RcvTClassOpt = "ReceiveTClassOption" + testCases := []struct { - name string - getReceiveOption tcpip.SockOptBool - tests []testFlow + name string + tests []testFlow }{ - {"ReceiveTosOption", tcpip.ReceiveTOSOption, []testFlow{unicastV4, broadcast}}, - {"ReceiveTClassOption", tcpip.ReceiveTClassOption, []testFlow{unicastV4in6, unicastV6, unicastV6Only, broadcastIn6}}, + {RcvTOSOpt, []testFlow{unicastV4, broadcast}}, + {RcvTClassOpt, []testFlow{unicastV4in6, unicastV6, unicastV6Only, broadcastIn6}}, } for _, testCase := range testCases { for _, flow := range testCase.tests { @@ -1604,29 +1722,32 @@ func TestReceiveTosTClass(t *testing.T) { defer c.cleanup() c.createEndpointForFlow(flow) - option := testCase.getReceiveOption name := testCase.name - // Verify that setting and reading the option works. - v, err := c.ep.GetSockOptBool(option) - if err != nil { - c.t.Errorf("GetSockOptBool(%s) failed: %s", name, err) + var optionGetter func() bool + var optionSetter func(bool) + switch name { + case RcvTOSOpt: + optionGetter = c.ep.SocketOptions().GetReceiveTOS + optionSetter = c.ep.SocketOptions().SetReceiveTOS + case RcvTClassOpt: + optionGetter = c.ep.SocketOptions().GetReceiveTClass + optionSetter = c.ep.SocketOptions().SetReceiveTClass + default: + t.Fatalf("unkown test variant: %s", name) } + + // Verify that setting and reading the option works. + v := optionGetter() // Test for expected default value. if v != false { c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, v, false) } want := true - if err := c.ep.SetSockOptBool(option, want); err != nil { - c.t.Fatalf("SetSockOptBool(%s, %t) failed: %s", name, want, err) - } - - got, err := c.ep.GetSockOptBool(option) - if err != nil { - c.t.Errorf("GetSockOptBool(%s) failed: %s", name, err) - } + optionSetter(want) + got := optionGetter() if got != want { c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, got, want) } @@ -1636,10 +1757,10 @@ func TestReceiveTosTClass(t *testing.T) { if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { c.t.Fatalf("Bind failed: %s", err) } - switch option { - case tcpip.ReceiveTClassOption: + switch name { + case RcvTClassOpt: testRead(c, flow, checker.ReceiveTClass(testTOS)) - case tcpip.ReceiveTOSOption: + case RcvTOSOpt: testRead(c, flow, checker.ReceiveTOS(testTOS)) default: t.Fatalf("unknown test variant: %s", name) @@ -1953,12 +2074,12 @@ func TestShortHeader(t *testing.T) { // Initialize the IP header. ip := header.IPv6(buf) ip.Encode(&header.IPv6Fields{ - TrafficClass: testTOS, - PayloadLength: uint16(udpSize), - NextHeader: uint8(udp.ProtocolNumber), - HopLimit: 65, - SrcAddr: h.srcAddr.Addr, - DstAddr: h.dstAddr.Addr, + TrafficClass: testTOS, + PayloadLength: uint16(udpSize), + TransportProtocol: udp.ProtocolNumber, + HopLimit: 65, + SrcAddr: h.srcAddr.Addr, + DstAddr: h.dstAddr.Addr, }) // Initialize the UDP header. @@ -2393,17 +2514,13 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { t.Fatalf("got ep.Write(_, _) = (%d, _, %v), want = (_, _, %v)", n, err, expectedErrWithoutBcastOpt) } - if err := ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil { - t.Fatalf("got SetSockOptBool(BroadcastOption, true): %s", err) - } + ep.SocketOptions().SetBroadcast(true) if n, _, err := ep.Write(data, opts); err != nil { t.Fatalf("got ep.Write(_, _) = (%d, _, %s), want = (_, _, nil)", n, err) } - if err := ep.SetSockOptBool(tcpip.BroadcastOption, false); err != nil { - t.Fatalf("got SetSockOptBool(BroadcastOption, false): %s", err) - } + ep.SocketOptions().SetBroadcast(false) if n, _, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt { t.Fatalf("got ep.Write(_, _) = (%d, _, %v), want = (_, _, %v)", n, err, expectedErrWithoutBcastOpt) |